@@ -231,17 +231,18 @@ def __init__(self, params):
231231 # We should NOT modify the list itself. This would break the
232232 # hyperparameter algorithms.
233233 use_only_one_activation_type = False
234- if type (self .params .layer_activations ) == str :
235- use_only_one_activation_type = True
236- elif len (self .params .layer_activations ) > self .number_of_layers :
237- printout (
238- "Too many activation layers provided. The last" ,
239- str (
240- len (self .params .layer_activations ) - self .number_of_layers
241- ),
242- "activation function(s) will be ignored." ,
243- min_verbosity = 1 ,
244- )
234+
235+ if not isinstance (self .params .layer_activations , str ):
236+ if len (self .params .layer_activations ) > self .number_of_layers :
237+ printout (
238+ "Too many activation layers provided. The last" ,
239+ str (
240+ len (self .params .layer_activations )
241+ - self .number_of_layers
242+ ),
243+ "activation function(s) will be ignored." ,
244+ min_verbosity = 1 ,
245+ )
245246
246247 # Add the layers.
247248 # As this is a feedforward NN we always add linear layers, and then
@@ -256,22 +257,18 @@ def __init__(self, params):
256257 )
257258 )
258259 try :
259- if use_only_one_activation_type :
260- self .layers .append (
261- self ._activation_mappings [
262- self .params .layer_activations
263- ]()
260+ if isinstance (self .params .layer_activations , str ):
261+ self ._append_activation_function (
262+ self .params .layer_activations
264263 )
265264 else :
266- self .layers .append (
267- self ._activation_mappings [
268- self .params .layer_activations [i ]
269- ]()
265+ self ._append_activation_function (
266+ self .params .layer_activations [i ]
270267 )
271268 except KeyError :
272269 raise Exception ("Invalid activation type selected." )
273270 except IndexError :
274- # Layer without activation
271+ # No activation functions left to append at the end.
275272 pass
276273
277274 # Once everything is done, we can move the Network on the target
@@ -297,6 +294,24 @@ def forward(self, inputs):
297294 inputs = layer (inputs )
298295 return inputs
299296
297+ def _append_activation_function (self , activation_function ):
298+ """
299+ Append an activation function to the network.
300+
301+ Parameters
302+ ----------
303+ activation_function : str
304+ Activation function to be appended.
305+ """
306+ if activation_function is None :
307+ pass
308+ elif isinstance (activation_function , str ):
309+ self .layers .append (
310+ self ._activation_mappings [activation_function ]()
311+ )
312+ elif isinstance (activation_function , nn .Module ):
313+ self .layers .append (activation_function )
314+
300315
301316class LSTM (Network ):
302317 """Initialize this network as a LSTM network."""
0 commit comments