Skip to content

Commit 30beca2

Browse files
committed
Update code to fix negative att
1 parent cf03982 commit 30beca2

8 files changed

Lines changed: 150 additions & 59 deletions

File tree

examples/example_using_different_dataset.ipynb

Lines changed: 25 additions & 12 deletions
Large diffs are not rendered by default.

examples/rain_estimation_constant.ipynb

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
},
1212
{
1313
"cell_type": "code",
14-
"execution_count": 23,
15-
"outputs": [],
1614
"source": [
1715
"import pynncml as pnc \n",
1816
"from matplotlib import pyplot as plt"
@@ -22,8 +20,23 @@
2220
"pycharm": {
2321
"name": "#%%\n",
2422
"is_executing": false
23+
},
24+
"ExecuteTime": {
25+
"end_time": "2025-09-24T13:09:00.957505Z",
26+
"start_time": "2025-09-24T13:08:59.407782Z"
2527
}
26-
}
28+
},
29+
"outputs": [
30+
{
31+
"name": "stderr",
32+
"output_type": "stream",
33+
"text": [
34+
"/Users/haihab01/envs/research_base/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
35+
" from .autonotebook import tqdm as notebook_tqdm\n"
36+
]
37+
}
38+
],
39+
"execution_count": 1
2740
},
2841
{
2942
"cell_type": "markdown",
@@ -36,8 +49,6 @@
3649
},
3750
{
3851
"cell_type": "code",
39-
"execution_count": 24,
40-
"outputs": [],
4152
"source": [
4253
"open_cml_dataset = pnc.read_open_cml_dataset('../dataset/open_cml.p') # read OpenCML dataset"
4354
],
@@ -46,8 +57,26 @@
4657
"pycharm": {
4758
"name": "#%% \n",
4859
"is_executing": false
60+
},
61+
"ExecuteTime": {
62+
"end_time": "2025-09-24T13:09:01.043316Z",
63+
"start_time": "2025-09-24T13:09:00.971401Z"
4964
}
50-
}
65+
},
66+
"outputs": [
67+
{
68+
"ename": "AttributeError",
69+
"evalue": "module 'pynncml' has no attribute 'read_open_cml_dataset'",
70+
"output_type": "error",
71+
"traceback": [
72+
"\u001B[31m---------------------------------------------------------------------------\u001B[39m",
73+
"\u001B[31mAttributeError\u001B[39m Traceback (most recent call last)",
74+
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[2]\u001B[39m\u001B[32m, line 1\u001B[39m\n\u001B[32m----> \u001B[39m\u001B[32m1\u001B[39m open_cml_dataset = \u001B[43mpnc\u001B[49m\u001B[43m.\u001B[49m\u001B[43mread_open_cml_dataset\u001B[49m(\u001B[33m'\u001B[39m\u001B[33m../dataset/open_cml.p\u001B[39m\u001B[33m'\u001B[39m) \u001B[38;5;66;03m# read OpenCML dataset\u001B[39;00m\n",
75+
"\u001B[31mAttributeError\u001B[39m: module 'pynncml' has no attribute 'read_open_cml_dataset'"
76+
]
77+
}
78+
],
79+
"execution_count": 2
5180
},
5281
{
5382
"cell_type": "markdown",
@@ -205,4 +234,4 @@
205234
},
206235
"nbformat": 4,
207236
"nbformat_minor": 0
208-
}
237+
}

examples/rain_estimation_dynamic.ipynb

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
},
1010
{
1111
"cell_type": "code",
12-
"execution_count": 15,
1312
"metadata": {
1413
"collapsed": false,
1514
"jupyter": {
@@ -18,13 +17,27 @@
1817
"pycharm": {
1918
"is_executing": false,
2019
"name": "#%%\n"
20+
},
21+
"ExecuteTime": {
22+
"end_time": "2025-09-24T13:09:15.625122Z",
23+
"start_time": "2025-09-24T13:09:14.128129Z"
2124
}
2225
},
23-
"outputs": [],
2426
"source": [
2527
"import pynncml as pnc \n",
2628
"from matplotlib import pyplot as plt"
27-
]
29+
],
30+
"outputs": [
31+
{
32+
"name": "stderr",
33+
"output_type": "stream",
34+
"text": [
35+
"/Users/haihab01/envs/research_base/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
36+
" from .autonotebook import tqdm as notebook_tqdm\n"
37+
]
38+
}
39+
],
40+
"execution_count": 1
2841
},
2942
{
3043
"cell_type": "markdown",
@@ -35,7 +48,6 @@
3548
},
3649
{
3750
"cell_type": "code",
38-
"execution_count": 16,
3951
"metadata": {
4052
"collapsed": false,
4153
"jupyter": {
@@ -44,12 +56,29 @@
4456
"pycharm": {
4557
"is_executing": false,
4658
"name": "#%% \n"
59+
},
60+
"ExecuteTime": {
61+
"end_time": "2025-09-24T13:09:15.707062Z",
62+
"start_time": "2025-09-24T13:09:15.635194Z"
4763
}
4864
},
49-
"outputs": [],
5065
"source": [
5166
"open_cml_dataset = pnc.read_open_cml_dataset('../dataset/open_cml.p') # read OpenCML dataset"
52-
]
67+
],
68+
"outputs": [
69+
{
70+
"ename": "AttributeError",
71+
"evalue": "module 'pynncml' has no attribute 'read_open_cml_dataset'",
72+
"output_type": "error",
73+
"traceback": [
74+
"\u001B[31m---------------------------------------------------------------------------\u001B[39m",
75+
"\u001B[31mAttributeError\u001B[39m Traceback (most recent call last)",
76+
"\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[2]\u001B[39m\u001B[32m, line 1\u001B[39m\n\u001B[32m----> \u001B[39m\u001B[32m1\u001B[39m open_cml_dataset = \u001B[43mpnc\u001B[49m\u001B[43m.\u001B[49m\u001B[43mread_open_cml_dataset\u001B[49m(\u001B[33m'\u001B[39m\u001B[33m../dataset/open_cml.p\u001B[39m\u001B[33m'\u001B[39m) \u001B[38;5;66;03m# read OpenCML dataset\u001B[39;00m\n",
77+
"\u001B[31mAttributeError\u001B[39m: module 'pynncml' has no attribute 'read_open_cml_dataset'"
78+
]
79+
}
80+
],
81+
"execution_count": 2
5382
},
5483
{
5584
"cell_type": "markdown",

examples/wet_dry_classification.ipynb

Lines changed: 35 additions & 24 deletions
Large diffs are not rendered by default.

pynncml/cml_methods/apis/xarray_processing/xarray_inference_engine.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@
88
from pynncml.datasets.xarray_processing import xarray2link
99
from pynncml.multiple_cmls_methods import InferMultipleCMLs
1010

11+
def create_dataset_with_coords_only(original_ds):
12+
"""
13+
Creates a new, empty xarray Dataset with coordinates from a given dataset.
14+
"""
15+
# Simply create a new Dataset, passing the coords from the original.
16+
# The .coords property contains all the coordinate variables.
17+
new_ds = xr.Dataset(coords=original_ds.coords)
18+
return new_ds
19+
1120

1221
class XarrayInferenceEngine(nn.Module):
1322
def __init__(self,in_cml2rain_method:BaseCMLProcessingMethod,is_recurrent=True,is_attenuation=False, *args, **kwargs):
@@ -18,15 +27,15 @@ def __init__(self,in_cml2rain_method:BaseCMLProcessingMethod,is_recurrent=True,i
1827
def forward(self, x_xarray):
1928
link_set=xarray2link(x_xarray)
2029
results_data= self.inference_engine(link_set)
21-
30+
x_xarray_new=create_dataset_with_coords_only(x_xarray)
2231
new_var_dims = ('time', 'sublink_id', 'cml_id')
2332
new_var_coords = {
2433
'time': x_xarray.time,
2534
'sublink_id': x_xarray.sublink_id,
2635
'cml_id': x_xarray.cml_id
2736
}
2837
for rname in CMLResultsDataStructure.results_types_list():
29-
x_xarray[rname] = xr.DataArray(
38+
x_xarray_new[rname] = xr.DataArray(
3039
np.full((x_xarray.sizes['time'], x_xarray.sizes['sublink_id'], x_xarray.sizes['cml_id']), np.nan),
3140
coords=new_var_coords,
3241
dims=new_var_dims
@@ -43,6 +52,6 @@ def forward(self, x_xarray):
4352
)
4453
# Assign the new values to the dataset at the specified slice.
4554
# This will create a new variable if it doesn't exist.
46-
x_xarray[rname].loc[dict(sublink_id=link.sublink_id, cml_id=link.cml_id)] = new_values
55+
x_xarray_new[rname].loc[dict(sublink_id=link.sublink_id, cml_id=link.cml_id)] = new_values
4756

48-
return x_xarray
57+
return x_xarray_new

pynncml/datasets/link_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def attenuation(self) -> torch.Tensor:
392392
:return attenuation: torch.Tensor
393393
"""
394394
if self.has_tsl():
395-
return torch.tensor(-(self.link_tsl - self.link_rsl)).reshape(1, -1).float()
395+
return torch.tensor((self.link_tsl - self.link_rsl)).reshape(1, -1).float()
396396
else:
397397
return torch.tensor(-self.link_rsl).reshape(1, -1).float()
398398

tests/test_data_structure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_link_with_tsl(self):
7474
self.assertTrue(np.array_equal(l.time(), time.astype('datetime64[s]')))
7575
self.assertEqual(len(l.time()), TestDataStructure.n_samples)
7676
att = l.attenuation().numpy().flatten()
77-
self.assertTrue(np.round(np.sum(att + tsl - rsl) * 100) == 0)
77+
self.assertTrue(np.round(np.sum(att -tsl + rsl) * 100) == 0)
7878
l_min_max = l.create_min_max_link(10)
7979
self.assertTrue(len(l_min_max) == 10)
8080
self.assertEqual(len(l_min_max.attenuation().shape), 3)

tests/test_xarray_processing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22

33
import poligrain as plg
44

5-
from pynncml.apis.xarray_processing.wet_dry_methods import create_wet_dry_std
5+
from pynncml.cml_methods.apis.xarray_processing.wet_dry_methods import create_wet_dry_std
66

77

88
class TestOpenCML(unittest.TestCase):
99

10-
def test_poligrain_to_xarray(self):
10+
def test_poligrain_to_xarray_openmrg(self):
1111
(ds_rad,
1212
ds_cmls,
1313
ds_gauges_municp,
1414
ds_gauge_smhi) = plg.example_data.load_openmrg(data_dir="example_data", subset="8d")
15-
nn_base=create_wet_dry_std()
16-
nn_base(ds_cmls)
15+
nn_base=create_wet_dry_std(threshold=2.3,step=10)
16+
ds_cmls_wet_dry=nn_base(ds_cmls)
1717

0 commit comments

Comments
 (0)