Skip to content

Commit 53dd0ae

Browse files
authored
Merge pull request #2416 from NNPDF/fix_perhaps_nan_issue2322
Fix NaN accumulation in parallel runs
2 parents 604ea8e + 75b68e7 commit 53dd0ae

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

n3fit/src/n3fit/backends/keras_backend/MetaModel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
def _default_loss(y_true, y_pred): # pylint: disable=unused-argument
5151
"""Default loss to be used when the model is compiled with loss = Null
5252
(for instance if the prediction of the model is already the loss"""
53-
return ops.sum(y_pred)
53+
return ops.nansum(y_pred)
5454

5555

5656
class MetaModel(Model):
@@ -219,7 +219,7 @@ def losses_fun():
219219
# If we only have one dataset the output changes
220220
if len(out_names) == 2:
221221
predictions = [predictions]
222-
total_loss = ops.sum(predictions, axis=0)
222+
total_loss = ops.nansum(predictions, axis=0)
223223
ret = [total_loss] + predictions
224224
return dict(zip(out_names, ret))
225225

n3fit/src/n3fit/backends/keras_backend/operations.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
expand_dims,
4747
leaky_relu,
4848
reshape,
49+
nan_to_num,
4950
repeat,
5051
split,
5152
sum,
@@ -290,3 +291,8 @@ def tensor_splitter(ishape, split_sizes, axis=2, name="splitter"):
290291
lambda x: Kops.split(x, indices, axis=axis), output_shape=oshapes, name=name
291292
)
292293
return sp_layer
294+
295+
296+
def nansum(x, *args, **kwargs):
297+
"""Like np.nansum, returns the sum treating NaN as 0.0 (and inf as a very large number)."""
298+
return sum(nan_to_num(x), *args, **kwargs)

n3fit/src/n3fit/tests/test_backend.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
2-
This module tests the mathematical functions in the n3fit backend
3-
and ensures they do the same thing as their numpy counterparts
2+
This module tests the mathematical functions in the n3fit backend
3+
and ensures they do the same thing as their numpy counterparts
44
"""
55

66
import operator
@@ -153,3 +153,14 @@ def test_tensor_product():
153153

154154
def test_sum():
155155
numpy_check(op.sum, np.sum, mode='single')
156+
157+
158+
def test_nansum():
159+
"""Tests that sums with NaN in the arrays work as expected"""
160+
arr_nonan = np.array([1.0, 2.0])
161+
arr_nan = np.array([2.0, np.nan, 2.0])
162+
arr_axis_nan = np.array([[3.0, np.nan], [2.0, 6.0]])
163+
164+
np.testing.assert_allclose(op.nansum(arr_nonan), np.nansum(arr_nonan))
165+
np.testing.assert_allclose(op.nansum(arr_nan), np.nansum(arr_nan))
166+
np.testing.assert_allclose(op.nansum(arr_axis_nan, axis=0), np.nansum(arr_axis_nan, axis=0))

0 commit comments

Comments
 (0)