 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
 import math
-from typing import Callable, Literal, Optional, Sequence, Union
+from typing import Literal, Optional, Sequence, Union
 
 import torch
 import torch.nn as nn
@@ -281,7 +281,6 @@ def __init__(
         cond_dim: int,
         mlp_ratio: int = 1,
         activation: type[nn.Module] = nn.GELU,
-        gate_activation: Callable = lambda x: (x + 1.0),
     ):
         super().__init__()
 
@@ -302,8 +301,6 @@ def __init__(
             nn.Linear(hidden_features * mlp_ratio, hidden_features),
         )
 
-        self.gate_activation = gate_activation
-
     def forward(self, x: Tensor, cond: Tensor) -> Tensor:
         """
         Arguments:
@@ -315,7 +312,7 @@ def forward(self, x: Tensor, cond: Tensor) -> Tensor:
315312 """
316313
317314 shift_ , scale_ , gate_ = self .ada_ln (cond ).chunk (3 , dim = - 1 )
318- gate_ = self . gate_activation ( gate_ )
315+ gate_ = gate_ + 1.0 # Gate activation: `lambda x: x + 1`.
319316 y = (scale_ + 1 ) * x + shift_
320317 y = self .block (y )
321318 y = x + gate_ * y
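For reference, a minimal self-contained sketch of the adaptive-LayerNorm MLP block this hunk patches, with the gate activation hard-coded as `x + 1` exactly as above. The names (`AdaGateMLPSketch`, `ada_ln` as a single `nn.Linear`, `block`) are illustrative assumptions, not the layer layout of the patched file.

import torch
import torch.nn as nn
from torch import Tensor

class AdaGateMLPSketch(nn.Module):
    """Hypothetical sketch of an adaLN MLP block with a hard-coded `x + 1` gate."""

    def __init__(self, hidden_features: int, cond_dim: int, mlp_ratio: int = 1):
        super().__init__()
        # One projection of the conditioning into shift, scale and gate.
        self.ada_ln = nn.Linear(cond_dim, 3 * hidden_features)
        self.block = nn.Sequential(
            nn.Linear(hidden_features, hidden_features * mlp_ratio),
            nn.GELU(),
            nn.Linear(hidden_features * mlp_ratio, hidden_features),
        )

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        shift_, scale_, gate_ = self.ada_ln(cond).chunk(3, dim=-1)
        gate_ = gate_ + 1.0  # Hard-coded gate activation, as in the patch.
        y = (scale_ + 1) * x + shift_
        y = self.block(y)
        return x + gate_ * y  # Gated residual connection.

# Usage: out = AdaGateMLPSketch(64, 32)(torch.randn(8, 64), torch.randn(8, 32))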
@@ -665,7 +662,6 @@ def __init__(
         num_heads: int,
         mlp_ratio: int = 2,
         activation: type[nn.Module] = nn.GELU,
-        gate_activation: Callable = lambda x: (x + 1.0),
     ):
         """Initialize dit transformer block.
 
@@ -675,7 +671,6 @@ def __init__(
             num_heads: number of attention heads
             mlp_ratio: ratio for mlp hidden dimension
             activation: activation function
-            gate_activation: activation function for the gate
         """
         super().__init__()
 
@@ -705,7 +700,6 @@ def __init__(
             activation(),
             nn.Linear(hidden_features * mlp_ratio, hidden_features),
         )
-        self.gate_activation = gate_activation
 
         # layer norms
         self.norm1 = nn.LayerNorm(hidden_features)
@@ -737,8 +731,8 @@ def forward(self, x: Tensor, cond: Tensor) -> Tensor:
         mlp_shift = mlp_shift.view(batch_size, 1, -1)
         mlp_gate = mlp_gate.view(batch_size, 1, -1)
 
-        attn_gate = self.gate_activation(attn_gate)
-        mlp_gate = self.gate_activation(mlp_gate)
+        attn_gate = attn_gate + 1.0  # Gate activation: `lambda x: x + 1`.
+        mlp_gate = mlp_gate + 1.0  # Gate activation: `lambda x: x + 1`.
 
         # attention with adaptive ln
         x_norm = self.norm1(x)
@@ -776,7 +770,6 @@ def __init__(
         num_heads: int,
         mlp_ratio: int = 4,
         activation: type[nn.Module] = nn.GELU,
-        gate_activation: Callable = lambda x: (x + 1.0),
     ):
         super().__init__()
 
@@ -817,7 +810,6 @@ def __init__(
         self.norm1 = nn.LayerNorm(hidden_features)
         self.norm2 = nn.LayerNorm(hidden_features)
         self.norm3 = nn.LayerNorm(hidden_features)
-        self.gate_activation = gate_activation
 
     def forward(
         self, x: Tensor, cross_attention_condition: Tensor, time_condition: Tensor
@@ -846,8 +838,8 @@ def forward(
         mlp_shift = mlp_shift.unsqueeze(1)
         mlp_gate = mlp_gate.unsqueeze(1)
 
-        attn_gate = self.gate_activation(attn_gate)
-        mlp_gate = self.gate_activation(mlp_gate)
+        attn_gate = attn_gate + 1.0  # Gate activation: `lambda x: x + 1`.
+        mlp_gate = mlp_gate + 1.0  # Gate activation: `lambda x: x + 1`.
 
         # self-attention with adaptive ln
         x_norm = self.norm1(x)
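The two transformer-block hunks above apply the same `x + 1` gate once per branch. Below is a minimal sketch of that per-branch gating pattern, assuming a single conditioning projection into six modulation tensors and a standard `nn.MultiheadAttention`; the class name `DiTGateSketch` and the exact layer layout are assumptions for illustration, not the patched file's API.

import torch
import torch.nn as nn
from torch import Tensor

class DiTGateSketch(nn.Module):
    """Hypothetical sketch of per-branch `x + 1` gating in an adaLN transformer block."""

    def __init__(self, hidden_features: int, cond_dim: int, num_heads: int = 4, mlp_ratio: int = 2):
        super().__init__()
        # Conditioning -> (shift, scale, gate) for the attention branch and the MLP branch.
        self.ada_ln = nn.Linear(cond_dim, 6 * hidden_features)
        self.attn = nn.MultiheadAttention(hidden_features, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_features, hidden_features * mlp_ratio),
            nn.GELU(),
            nn.Linear(hidden_features * mlp_ratio, hidden_features),
        )
        self.norm1 = nn.LayerNorm(hidden_features)
        self.norm2 = nn.LayerNorm(hidden_features)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        batch_size = x.shape[0]
        mods = self.ada_ln(cond).chunk(6, dim=-1)
        # Broadcast the per-sample modulation over the sequence dimension.
        attn_shift, attn_scale, attn_gate, mlp_shift, mlp_scale, mlp_gate = (
            m.view(batch_size, 1, -1) for m in mods
        )
        attn_gate = attn_gate + 1.0  # Gate activation: `lambda x: x + 1`.
        mlp_gate = mlp_gate + 1.0  # Gate activation: `lambda x: x + 1`.

        # Attention branch with adaptive layer norm and gated residual.
        x_norm = (attn_scale + 1) * self.norm1(x) + attn_shift
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_gate * attn_out

        # MLP branch, modulated and gated the same way.
        y = (mlp_scale + 1) * self.norm2(x) + mlp_shift
        return x + mlp_gate * self.mlp(y)

# Usage: out = DiTGateSketch(64, 32)(torch.randn(8, 16, 64), torch.randn(8, 32))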