mirror of
https://github.com/huggingface/diffusers.git
synced 2026-05-28 00:39:35 +08:00
This commit is contained in:
@@ -802,9 +802,9 @@ class TorchAoConfigMixin:
|
||||
"""
|
||||
|
||||
TORCHAO_QUANT_TYPES = {
|
||||
"int4wo": "Int4WeightOnlyConfig",
|
||||
"int8wo": "Int8WeightOnlyConfig",
|
||||
"int8dq": "Int8DynamicActivationInt8WeightConfig",
|
||||
"int4wo": {"quant_type_name": "Int4WeightOnlyConfig"},
|
||||
"int8wo": {"quant_type_name": "Int8WeightOnlyConfig"},
|
||||
"int8dq": {"quant_type_name": "Int8DynamicActivationInt8WeightConfig"},
|
||||
}
|
||||
|
||||
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
|
||||
@@ -814,12 +814,13 @@ class TorchAoConfigMixin:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_quant_config(config_name):
|
||||
config_cls = getattr(_torchao_quantization, config_name)
|
||||
return TorchAoConfig(config_cls())
|
||||
def _get_quant_config(config_kwargs):
|
||||
config_kwargs = config_kwargs.copy()
|
||||
config_cls = getattr(_torchao_quantization, config_kwargs.pop("quant_type_name"))
|
||||
return TorchAoConfig(config_cls(), **config_kwargs)
|
||||
|
||||
def _create_quantized_model(self, config_name, **extra_kwargs):
|
||||
config = self._get_quant_config(config_name)
|
||||
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
|
||||
config = self._get_quant_config(config_kwargs)
|
||||
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
|
||||
kwargs["quantization_config"] = config
|
||||
kwargs["device_map"] = str(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user