A user tried to use FqnToConfig with QATConfig, which quantizes via module swap. However, FqnToConfig seems to only work with tensor subclasses, not module swaps:
ao/torchao/quantization/quant_api.py, lines 485 to 491 in ff6d9e2:

    if (
        fqn_matches_fqn_config(module_fqn, config)
        or _module_param_matches_fqn_config(module, module_fqn, config)
        or ("_default" in config.fqn_to_config and _is_linear(module))
    ):
        # this replaces inplace, so no need to reassign
        _fqn_to_config_handler(module, module_fqn, config)
Looks like we do need to reassign for module swaps.
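To illustrate why the in-place handler silently drops module swaps, here is a minimal torch-free sketch (all class and function names are hypothetical stand-ins, not torchao code): rebinding a local parameter inside the handler does not change what the parent module holds, so a swap only takes effect when the handler does setattr on the parent.

```python
class Linear:
    """Stand-in for torch.nn.Linear."""


class FakeQuantizedLinear:
    """Stand-in for the module that a QAT-style config swaps in."""

    def __init__(self, original):
        self.original = original


def handler_inplace(module):
    # Mutating attributes of `module` in place would be visible to the
    # caller, but rebinding the local name is not: the parent module
    # still references the old object, so the swap is lost.
    module = FakeQuantizedLinear(module)  # noqa: F841


def handler_with_reassign(parent, name):
    # A module swap must be written back onto the parent.
    setattr(parent, name, FakeQuantizedLinear(getattr(parent, name)))


class Model:
    def __init__(self):
        self.l = Linear()


m = Model()
handler_inplace(m.l)
print(type(m.l).__name__)  # Linear -- the swap was silently dropped

m2 = Model()
handler_with_reassign(m2, "l")
print(type(m2.l).__name__)  # FakeQuantizedLinear
```

This matches the observed behavior: FqnToConfig's handler only mutates the module it receives, which is enough for tensor-subclass configs but not for configs that replace the module object itself.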
Minimal repro:
import torch
from torchao.quantization import quantize_, FqnToConfig
from torchao.quantization.qat import QATConfig
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.l = torch.nn.Linear(4096, 4096)
def forward(self, x):
return self.l(x)
m = M()
print("BEFORE:", type(m.l), m.l)
base_config = NVFP4DynamicActivationNVFP4WeightConfig()
name_to_config = {"l": QATConfig(base_config, step="prepare")}
quantize_(m, FqnToConfig(name_to_config), filter_fn=None) # doesn't work
print("AFTER:", type(m.l), m.l)
Output: (module was not swapped)
BEFORE: <class 'torch.nn.modules.linear.Linear'> Linear(in_features=4096, out_features=4096, bias=True)
AFTER: <class 'torch.nn.modules.linear.Linear'> Linear(in_features=4096, out_features=4096, bias=True)