FqnToConfig doesn't handle module swaps #3490

@andrewor14

Description

A user tried to use FqnToConfig with QATConfig, which relies on module swapping. However, FqnToConfig appears to work only with tensor subclasses, not module swaps:

if (
    fqn_matches_fqn_config(module_fqn, config)
    or _module_param_matches_fqn_config(module, module_fqn, config)
    or ("_default" in config.fqn_to_config and _is_linear(module))
):
    # this replaces inplace, so no need to reassign
    _fqn_to_config_handler(module, module_fqn, config)
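
The handler mutates the module in place, which is fine for tensor-subclass configs, but a module swap inside the handler can only rebind a local name; the caller's parent module still references the old child. A minimal illustration of the failure mode (swap_handler is a hypothetical stand-in for the handler):

import torch

def swap_handler(module):
    # Rebinding the local name does not touch the caller's module tree;
    # the parent still holds a reference to the original Linear.
    module = torch.nn.Identity()

m = torch.nn.Sequential(torch.nn.Linear(4, 4))
swap_handler(m[0])
print(type(m[0]))  # still torch.nn.Linear: the swap was silently dropped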

Looks like we do need to reassign for module swaps.
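
One possible direction, sketched under the assumption that _fqn_to_config_handler can return the (possibly new) module and that the call site can reach the root model; _reattach_by_fqn below is a hypothetical helper, not torchao's actual API:

import torch

def _reattach_by_fqn(model: torch.nn.Module, module_fqn: str,
                     new_module: torch.nn.Module) -> None:
    # Hypothetical helper: reattach a swapped module to its parent by FQN
    # so that module-swap configs like QATConfig take effect in the tree.
    parent_fqn, _, child_name = module_fqn.rpartition(".")
    parent = model.get_submodule(parent_fqn) if parent_fqn else model
    setattr(parent, child_name, new_module)

# At the call site, keep the handler's result instead of discarding it:
#     new_module = _fqn_to_config_handler(module, module_fqn, config)
#     if new_module is not module:
#         _reattach_by_fqn(model, module_fqn, new_module)

Reattaching via setattr on the parent is how module swaps are normally applied in PyTorch.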

Minimal repro:

import torch
from torchao.quantization import quantize_, FqnToConfig
from torchao.quantization.qat import QATConfig
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(4096, 4096)
    def forward(self, x):
        return self.l(x)

m = M()
print("BEFORE:", type(m.l), m.l)
base_config = NVFP4DynamicActivationNVFP4WeightConfig()
name_to_config = {"l": QATConfig(base_config, step="prepare")}
quantize_(m, FqnToConfig(name_to_config), filter_fn=None)  # doesn't work: m.l is not swapped
print("AFTER:", type(m.l), m.l)

Output: (module was not swapped)

BEFORE: <class 'torch.nn.modules.linear.Linear'> Linear(in_features=4096, out_features=4096, bias=True)
AFTER: <class 'torch.nn.modules.linear.Linear'> Linear(in_features=4096, out_features=4096, bias=True)
