Skip to content

KAN: dtype mismatch (float vs. double) when calling fit #8

@edmondja

Description

@edmondja

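When running your demo notebook, `model.fit(X, y)` fails with the following dtype-mismatch error (float vs. double in a `F.linear` call inside the KAN layer):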

RuntimeError Traceback (most recent call last)
Cell In[25], line 10
7 X, y = make_classification(n_samples=5000, n_features=5, n_informative=3)
8 model = imodelsx.KANClassifier(hidden_layer_size=64, device='cpu',
9 regularize_activation=1.0, regularize_entropy=1.0)
---> 10 model.fit(X, y)
11 y_pred = model.predict(X)
12 print('Test acc', accuracy_score(y, y_pred))

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/imodelsx/kan/kan_sklearn.py:80, in KAN.fit(self, X, y, batch_size, lr, weight_decay, gamma)
78 x = x.view(-1, num_features).to(self.device)
79 optimizer.zero_grad()
---> 80 output = self.model(x).squeeze()
81 loss = criterion(output, labs.to(self.device).squeeze())
82 if isinstance(self, (KANGAMClassifier, KANGAMRegressor)):

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/imodelsx/kan/kan_modules.py:296, in KANModule.forward(self, x, update_grid)
294 if update_grid:
295 layer.update_grid(x)
--> 296 x = layer(x)
297 return x

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

File ~/opt/anaconda3/envs/py311/lib/python3.11/site-packages/imodelsx/kan/kan_modules.py:173, in KANLinearModule.forward(self, x)
170 def forward(self, x: torch.Tensor):
171 assert x.dim() == 2 and x.size(1) == self.in_features
--> 173 base_output = F.linear(self.base_activation(x), self.base_weight)
174 spline_output = F.linear(
175 self.b_splines(x).view(x.size(0), -1),
176 self.scaled_spline_weight.view(self.out_features, -1),
177 )
178 return base_output + spline_output

RuntimeError: expected m1 and m2 to have the same dtype, but got: float != double

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions