Commit 6ffe7b6

Remove gate_activation to allow pickling
1 parent cd7b0af commit 6ffe7b6
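
Why the removal matters, as a minimal sketch that is not part of this commit: a lambda stored as a module attribute cannot be serialized by Python's pickle module, while the inlined `x + 1.0` leaves nothing unpicklable on the module. `GatedBlockOld` and `GatedBlockNew` below are hypothetical stand-ins for the blocks touched in vector_field_nets.py.

import pickle

import torch.nn as nn


class GatedBlockOld(nn.Module):
    """Pre-commit pattern (hypothetical): a lambda default stored on the module."""

    def __init__(self, gate_activation=lambda x: x + 1.0):
        super().__init__()
        self.gate_activation = gate_activation  # lambdas cannot be pickled


class GatedBlockNew(nn.Module):
    """Post-commit pattern (hypothetical): the gate activation is inlined."""

    def forward(self, gate):
        return gate + 1.0  # same arithmetic as the removed lambda


try:
    pickle.dumps(GatedBlockOld())
except (pickle.PicklingError, AttributeError, TypeError) as err:
    print("old block fails to pickle:", err)

print("new block pickles to", len(pickle.dumps(GatedBlockNew())), "bytes")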

File tree

1 file changed: +6 -14 lines changed


sbi/neural_nets/net_builders/vector_field_nets.py

Lines changed: 6 additions & 14 deletions
@@ -2,7 +2,7 @@
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
 import math
-from typing import Callable, Literal, Optional, Sequence, Union
+from typing import Literal, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
@@ -281,7 +281,6 @@ def __init__(
         cond_dim: int,
         mlp_ratio: int = 1,
         activation: type[nn.Module] = nn.GELU,
-        gate_activation: Callable = lambda x: (x + 1.0),
     ):
         super().__init__()
 
@@ -302,8 +301,6 @@ def __init__(
             nn.Linear(hidden_features * mlp_ratio, hidden_features),
         )
 
-        self.gate_activation = gate_activation
-
     def forward(self, x: Tensor, cond: Tensor) -> Tensor:
         """
         Arguments:
@@ -315,7 +312,7 @@ def forward(self, x: Tensor, cond: Tensor) -> Tensor:
         """
 
         shift_, scale_, gate_ = self.ada_ln(cond).chunk(3, dim=-1)
-        gate_ = self.gate_activation(gate_)
+        gate_ = gate_ + 1.0  # Gate activation: `lambda x: x + 1`.
         y = (scale_ + 1) * x + shift_
         y = self.block(y)
         y = x + gate_ * y
@@ -665,7 +662,6 @@ def __init__(
         num_heads: int,
         mlp_ratio: int = 2,
         activation: type[nn.Module] = nn.GELU,
-        gate_activation: Callable = lambda x: (x + 1.0),
     ):
         """Initialize dit transformer block.
 
@@ -675,7 +671,6 @@ def __init__(
             num_heads: number of attention heads
             mlp_ratio: ratio for mlp hidden dimension
             activation: activation function
-            gate_activation: activation function for the gate
         """
         super().__init__()
 
@@ -705,7 +700,6 @@ def __init__(
             activation(),
             nn.Linear(hidden_features * mlp_ratio, hidden_features),
         )
-        self.gate_activation = gate_activation
 
         # layer norms
         self.norm1 = nn.LayerNorm(hidden_features)
@@ -737,8 +731,8 @@ def forward(self, x: Tensor, cond: Tensor) -> Tensor:
         mlp_shift = mlp_shift.view(batch_size, 1, -1)
         mlp_gate = mlp_gate.view(batch_size, 1, -1)
 
-        attn_gate = self.gate_activation(attn_gate)
-        mlp_gate = self.gate_activation(mlp_gate)
+        attn_gate = attn_gate + 1.0  # Gate activation: `lambda x: x + 1`.
+        mlp_gate = mlp_gate + 1.0  # Gate activation: `lambda x: x + 1`.
 
         # attention with adaptive ln
         x_norm = self.norm1(x)
@@ -776,7 +770,6 @@ def __init__(
         num_heads: int,
         mlp_ratio: int = 4,
         activation: type[nn.Module] = nn.GELU,
-        gate_activation: Callable = lambda x: (x + 1.0),
     ):
         super().__init__()
 
@@ -817,7 +810,6 @@ def __init__(
         self.norm1 = nn.LayerNorm(hidden_features)
         self.norm2 = nn.LayerNorm(hidden_features)
         self.norm3 = nn.LayerNorm(hidden_features)
-        self.gate_activation = gate_activation
 
     def forward(
        self, x: Tensor, cross_attention_condition: Tensor, time_condition: Tensor
@@ -846,8 +838,8 @@ def forward(
         mlp_shift = mlp_shift.unsqueeze(1)
         mlp_gate = mlp_gate.unsqueeze(1)
 
-        attn_gate = self.gate_activation(attn_gate)
-        mlp_gate = self.gate_activation(mlp_gate)
+        attn_gate = attn_gate + 1.0  # Gate activation: `lambda x: x + 1`.
+        mlp_gate = mlp_gate + 1.0  # Gate activation: `lambda x: x + 1`.
 
         # self-attention with adaptive ln
         x_norm = self.norm1(x)
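
For reference, a self-contained sketch of the adaptive-LayerNorm gating pattern that the diff leaves in place. The class name, layer choices, and dimensions below are assumptions for illustration, not the actual sbi implementation: the condition is projected to shift/scale/gate, and the gate activation is now just the inline `gate_ + 1.0`.

import pickle

import torch
import torch.nn as nn
from torch import Tensor


class AdaLNGatedBlock(nn.Module):
    """Hypothetical block mirroring the gating logic shown in the diff."""

    def __init__(self, hidden_features: int, cond_dim: int, mlp_ratio: int = 1):
        super().__init__()
        # Project the condition to (shift, scale, gate); chunked 3-way in forward().
        self.ada_ln = nn.Linear(cond_dim, 3 * hidden_features)
        self.block = nn.Sequential(
            nn.Linear(hidden_features, hidden_features * mlp_ratio),
            nn.GELU(),
            nn.Linear(hidden_features * mlp_ratio, hidden_features),
        )

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        shift_, scale_, gate_ = self.ada_ln(cond).chunk(3, dim=-1)
        gate_ = gate_ + 1.0  # Gate activation: `lambda x: x + 1`, inlined.
        y = (scale_ + 1) * x + shift_
        y = self.block(y)
        return x + gate_ * y  # gated residual connection


block = AdaLNGatedBlock(hidden_features=16, cond_dim=8)
out = block(torch.randn(4, 16), torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 16])

# With no lambda attribute, the block round-trips through pickle.
restored = pickle.loads(pickle.dumps(block))
print(type(restored).__name__)  # AdaLNGatedBlock

The arithmetic is identical to the removed `lambda x: (x + 1.0)`; only its location changed, so the modules no longer carry an unpicklable attribute.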
