Skip to content

Commit d615216

Browse files
Allow omitting the activation function between hidden layers; allow direct passing of torch modules
1 parent 63894ba commit d615216

File tree

2 files changed

+39
-21
lines changed

2 files changed

+39
-21
lines changed

mala/common/parameters.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,13 @@ class ParametersNetwork(ParametersBase):
350350
activation function is used for all layers (including the output layer,
351351
i.e., an output activation is used!). Otherwise, the activation
352352
functions are added layer by layer.
353+
Note that no activation function is applied between input layer and
354+
first hidden layer!
353355
Currently supported activation functions are:
354356
355357
- Sigmoid
356358
- ReLU
359+
- None (no activation used)
357360
- LeakyReLU (default)
358361
359362
loss_function_type : string

mala/network/network.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -231,17 +231,18 @@ def __init__(self, params):
231231
# We should NOT modify the list itself. This would break the
232232
# hyperparameter algorithms.
233233
use_only_one_activation_type = False
234-
if type(self.params.layer_activations) == str:
235-
use_only_one_activation_type = True
236-
elif len(self.params.layer_activations) > self.number_of_layers:
237-
printout(
238-
"Too many activation layers provided. The last",
239-
str(
240-
len(self.params.layer_activations) - self.number_of_layers
241-
),
242-
"activation function(s) will be ignored.",
243-
min_verbosity=1,
244-
)
234+
235+
if not isinstance(self.params.layer_activations, str):
236+
if len(self.params.layer_activations) > self.number_of_layers:
237+
printout(
238+
"Too many activation layers provided. The last",
239+
str(
240+
len(self.params.layer_activations)
241+
- self.number_of_layers
242+
),
243+
"activation function(s) will be ignored.",
244+
min_verbosity=1,
245+
)
245246

246247
# Add the layers.
247248
# As this is a feedforward NN we always add linear layers, and then
@@ -256,22 +257,18 @@ def __init__(self, params):
256257
)
257258
)
258259
try:
259-
if use_only_one_activation_type:
260-
self.layers.append(
261-
self._activation_mappings[
262-
self.params.layer_activations
263-
]()
260+
if isinstance(self.params.layer_activations, str):
261+
self._append_activation_function(
262+
self.params.layer_activations
264263
)
265264
else:
266-
self.layers.append(
267-
self._activation_mappings[
268-
self.params.layer_activations[i]
269-
]()
265+
self._append_activation_function(
266+
self.params.layer_activations[i]
270267
)
271268
except KeyError:
272269
raise Exception("Invalid activation type selected.")
273270
except IndexError:
274-
# Layer without activation
271+
# No activation functions left to append at the end.
275272
pass
276273

277274
# Once everything is done, we can move the Network on the target
@@ -297,6 +294,24 @@ def forward(self, inputs):
297294
inputs = layer(inputs)
298295
return inputs
299296

297+
def _append_activation_function(self, activation_function):
298+
"""
299+
Append an activation function to the network.
300+
301+
Parameters
302+
----------
303+
activation_function : str
304+
Activation function to be appended.
305+
"""
306+
if activation_function is None:
307+
pass
308+
elif isinstance(activation_function, str):
309+
self.layers.append(
310+
self._activation_mappings[activation_function]()
311+
)
312+
elif isinstance(activation_function, nn.Module):
313+
self.layers.append(activation_function)
314+
300315

301316
class LSTM(Network):
302317
"""Initialize this network as a LSTM network."""

0 commit comments

Comments
 (0)