Skip to content

Commit 062173f

Browse files
author
kshitij-maths
committed
Improve parallel import robustness and include missing unit tests
1 parent 24fb01e commit 062173f

File tree

4 files changed

+87
-23
lines changed

4 files changed

+87
-23
lines changed

ezyrb/parallel/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@
1717

1818
from .reduction import Reduction
1919
from .pod import POD
20-
from .ae import AE
21-
from .ae_eddl import AE_EDDL
20+
try:
21+
from .ae import AE
22+
from .ae_eddl import AE_EDDL
23+
except ImportError:
24+
AE = None
25+
AE_EDDL = None
2226
from .approximation import Approximation
2327
from .rbf import RBF
2428
from .linear import Linear

ezyrb/parallel/pod.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,17 @@
1111
from numpy.linalg import eigh
1212
import numpy as np
1313

14-
from pycompss.api.task import task
15-
from pycompss.api.parameter import INOUT, IN
14+
try:
15+
from pycompss.api.task import task
16+
from pycompss.api.parameter import INOUT, IN
17+
except ImportError:
18+
# Fallback: Define a 'do-nothing' decorator and dummy constants
19+
def task(*args, **kwargs):
20+
return lambda f: f
21+
22+
INOUT = None
23+
IN = None
24+
1625
from .reduction import Reduction
1726

1827

ezyrb/parallel/reducedordermodel.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import copy
55
import pickle
66
import numpy as np
7-
from scipy.spatial.qhull import Delaunay
7+
from scipy.spatial import Delaunay
88
from sklearn.model_selection import KFold
99
from pycompss.api.api import compss_wait_on
1010

@@ -60,30 +60,32 @@ def fit(self, *args, **kwargs):
6060
:param \*args: additional parameters to pass to the `fit` method.
6161
:param \**kwargs: additional parameters to pass to the `fit` method.
6262
"""
63-
self.reduction.fit(self.database.snapshots.T)
64-
reduced_output = self.reduction.transform(
65-
self.database.snapshots.T, self.scaler_red
66-
)
63+
self.reduction.fit(self.database.snapshots_matrix.T)
64+
reduced_output = self.reduction.transform(self.database.snapshots_matrix.T).T
6765

6866
self.approximation.fit(
69-
self.database.parameters, reduced_output, *args, **kwargs
67+
self.database.parameters_matrix, reduced_output, *args, **kwargs
7068
)
7169

7270
return self
7371

7472
def predict(self, mu):
75-
"""
76-
Calculate predicted solution for given mu
73+
r"""
74+
Predict the solution for given parameters mu.
75+
76+
This method distributes the evaluation tasks across the
77+
available computational nodes using the PyCOMPSs framework.
78+
79+
:param numpy.ndarray mu: The parameters $\mu \in \mathbb{R}^d$ to evaluate.
80+
:return: The predicted snapshot $u(\mu)$.
7781
"""
7882
mu = np.atleast_2d(mu)
79-
if hasattr(self, "database") and self.database.scaler_parameters:
80-
mu = self.database.scaler_parameters.transform(mu)
81-
82-
predicted_red_sol = self.approximation.predict(mu, self.scaler_red)
83+
84+
predicted_red_sol = self.approximation.predict(mu)
8385

8486
predicted_sol = self.reduction.inverse_transform(
85-
predicted_red_sol, self.database
86-
)
87+
predicted_red_sol.T
88+
).T
8789

8890
return predicted_sol
8991

@@ -166,8 +168,8 @@ def kfold_cv_error(self, n_splits, *args, norm=np.linalg.norm, **kwargs):
166168
).fit(*args, **kwargs)
167169

168170
test = self.database[test_index]
169-
predicted_test.append(rom.predict(test.parameters))
170-
original_test.append(test.snapshots)
171+
predicted_test.append(rom.predict(test.parameters_matrix))
172+
original_test.append(test.snapshots_matrix)
171173

172174
predicted_test = compss_wait_on(predicted_test)
173175
for j in range(len(predicted_test)):
@@ -216,8 +218,8 @@ def loo_error(self, *args, norm=np.linalg.norm, **kwargs):
216218
copy.deepcopy(self.approximation),
217219
).fit(*args, **kwargs)
218220

219-
predicted_test.append(rom.predict(test_db.parameters))
220-
original_test.append(test_db.snapshots)
221+
predicted_test.append(rom.predict(test_db.parameters_matrix))
222+
original_test.append(test_db.snapshots_matrix)
221223

222224
predicted_test = compss_wait_on(predicted_test)
223225
for j in range(len(predicted_test)):
@@ -246,7 +248,7 @@ def optimal_mu(self, error=None, k=1):
246248
if error is None:
247249
error = self.loo_error()
248250

249-
mu = self.database.parameters
251+
mu = self.database.parameters_matrix
250252
tria = Delaunay(mu)
251253

252254
error_on_simplex = np.array(

tests/test_parallel.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import unittest
2+
import numpy as np
3+
import warnings
4+
5+
from ezyrb import Database, POD, RBF
6+
from ezyrb.parallel import ReducedOrderModel
7+
8+
class TestParallelROM(unittest.TestCase):
9+
def setUp(self):
10+
self.params = np.array([[1.0], [2.0], [3.0]])
11+
self.snapshots = np.array([[10.0], [20.0], [30.0]])
12+
self.db = Database(self.params, self.snapshots)
13+
14+
def test_initialization(self):
15+
rom = ReducedOrderModel(self.db, POD(), RBF())
16+
self.assertEqual(rom.database.parameters_matrix.shape, (3, 1))
17+
self.assertIsNotNone(rom.reduction)
18+
self.assertIsNotNone(rom.approximation)
19+
20+
def test_fit(self):
21+
rom = ReducedOrderModel(self.db, POD(), RBF())
22+
rom.fit()
23+
self.assertTrue(hasattr(rom.reduction, 'modes'))
24+
25+
def test_predict_scalar(self):
26+
rom = ReducedOrderModel(self.db, POD(), RBF())
27+
rom.fit()
28+
pred = rom.predict([1.5])
29+
self.assertTrue(pred is not None)
30+
31+
def test_predict_db(self):
32+
rom = ReducedOrderModel(self.db, POD(), RBF())
33+
rom.fit()
34+
pred = rom.predict(self.params)
35+
36+
if hasattr(pred, 'snapshots_matrix'):
37+
np.testing.assert_allclose(pred.snapshots_matrix, self.snapshots, rtol=1e-5)
38+
else:
39+
np.testing.assert_allclose(pred, self.snapshots, rtol=1e-5)
40+
41+
def test_wrong_dimensions(self):
42+
with self.assertRaises(Exception):
43+
bad_params = np.array([[1.0], [2.0]]) # Only 2 params
44+
bad_db = Database(bad_params, self.snapshots) # 3 snapshots
45+
rom = ReducedOrderModel(bad_db, POD(), RBF())
46+
rom.fit()
47+
48+
if __name__ == '__main__':
49+
unittest.main()

0 commit comments

Comments
 (0)