fix
Some checks failed
Secret Leaks / trufflehog (push) Has been cancelled

This commit is contained in:
sayakpaul
2026-04-17 10:43:21 +05:30
parent 5be0434232
commit 78bbf2a3c2

View File

@@ -802,9 +802,9 @@ class TorchAoConfigMixin:
"""
TORCHAO_QUANT_TYPES = {
"int4wo": "Int4WeightOnlyConfig",
"int8wo": "Int8WeightOnlyConfig",
"int8dq": "Int8DynamicActivationInt8WeightConfig",
"int4wo": {"quant_type_name": "Int4WeightOnlyConfig"},
"int8wo": {"quant_type_name": "Int8WeightOnlyConfig"},
"int8dq": {"quant_type_name": "Int8DynamicActivationInt8WeightConfig"},
}
TORCHAO_EXPECTED_MEMORY_REDUCTIONS = {
@@ -814,12 +814,13 @@ class TorchAoConfigMixin:
}
@staticmethod
def _get_quant_config(config_name):
config_cls = getattr(_torchao_quantization, config_name)
return TorchAoConfig(config_cls())
def _get_quant_config(config_kwargs):
config_kwargs = config_kwargs.copy()
config_cls = getattr(_torchao_quantization, config_kwargs.pop("quant_type_name"))
return TorchAoConfig(config_cls(), **config_kwargs)
def _create_quantized_model(self, config_name, **extra_kwargs):
config = self._get_quant_config(config_name)
def _create_quantized_model(self, config_kwargs, **extra_kwargs):
config = self._get_quant_config(config_kwargs)
kwargs = getattr(self, "pretrained_model_kwargs", {}).copy()
kwargs["quantization_config"] = config
kwargs["device_map"] = str(torch_device)