Skip to content

Commit cf03982

Browse files
committed
Update xarray api to output dump the results into the
1 parent d74c55f commit cf03982

32 files changed

Lines changed: 163 additions & 37 deletions

pynncml/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from pynncml import datasets
33
from pynncml import single_cml_methods as scm
44
from pynncml import multiple_cmls_methods as mcm
5-
from pynncml import single_sml_methods as ssm
6-
from pynncml import metrics
5+
from pynncml import sml_methods as ssm
6+
from pynncml.cml_nn_training import metrics
77
from pynncml import neural_networks
8-
from pynncml import training_helpers
8+
from pynncml import cml_nn_training
99
from pynncml import simulation
1010
from pynncml.plot_common import change_x_axis_time_format, plot_wet_dry_detection_mark
1111

pynncml/apis/xarray_processing/xarray_inference_engine.py

Lines changed: 0 additions & 15 deletions
This file was deleted.
File renamed without changes.

pynncml/cml_methods/apis/xarray_processing/__init__.py

Whitespace-only changes.

pynncml/apis/xarray_processing/wet_dry_methods.py renamed to pynncml/cml_methods/apis/xarray_processing/wet_dry_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pynncml.apis.xarray_processing.xarray_inference_engine import XarrayInferenceEngine
1+
from pynncml.cml_methods.apis.xarray_processing.xarray_inference_engine import XarrayInferenceEngine
22
from pynncml.neural_networks import DNNType
33
from pynncml.single_cml_methods.wet_dry import wet_dry_network,statistics_wet_dry
44

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
2+
from torch import nn
3+
import numpy as np
4+
import xarray as xr
5+
6+
from pynncml.cml_methods.base_cml_method import BaseCMLProcessingMethod
7+
from pynncml.cml_methods.results_data_structure import CMLResultsDataStructure
8+
from pynncml.datasets.xarray_processing import xarray2link
9+
from pynncml.multiple_cmls_methods import InferMultipleCMLs
10+
11+
12+
class XarrayInferenceEngine(nn.Module):
13+
def __init__(self,in_cml2rain_method:BaseCMLProcessingMethod,is_recurrent=True,is_attenuation=False, *args, **kwargs):
14+
super().__init__(*args, **kwargs)
15+
self.inference_engine = InferMultipleCMLs(in_cml2rain_method,is_recurrent,is_attenuation)
16+
17+
18+
def forward(self, x_xarray):
19+
link_set=xarray2link(x_xarray)
20+
results_data= self.inference_engine(link_set)
21+
22+
new_var_dims = ('time', 'sublink_id', 'cml_id')
23+
new_var_coords = {
24+
'time': x_xarray.time,
25+
'sublink_id': x_xarray.sublink_id,
26+
'cml_id': x_xarray.cml_id
27+
}
28+
for rname in CMLResultsDataStructure.results_types_list():
29+
x_xarray[rname] = xr.DataArray(
30+
np.full((x_xarray.sizes['time'], x_xarray.sizes['sublink_id'], x_xarray.sizes['cml_id']), np.nan),
31+
coords=new_var_coords,
32+
dims=new_var_dims
33+
)
34+
35+
for i, link in enumerate(link_set):
36+
results=self.inference_engine.cml2rain.convert_output_results(results_data[i])
37+
for rname in CMLResultsDataStructure.results_types_list():
38+
# Create a new DataArray with the correct dimensions and coordinates
39+
new_values = xr.DataArray(
40+
getattr(results,rname),
41+
coords=[x_xarray.time],
42+
dims=['time']
43+
)
44+
# Assign the new values to the dataset at the specified slice.
45+
# This will create a new variable if it doesn't exist.
46+
x_xarray[rname].loc[dict(sublink_id=link.sublink_id, cml_id=link.cml_id)] = new_values
47+
48+
return x_xarray

pynncml/neural_networks/base_neural_network.py renamed to pynncml/cml_methods/base_cml_method.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from torch import nn
22

3+
from pynncml.cml_methods.results_data_structure import CMLResultsDataStructure
34
from pynncml.datasets.alignment import AttenuationType
45

56

@@ -9,3 +10,6 @@ def __init__(self,input_data_type:AttenuationType,input_rate:int,output_rate:int
910
self.input_data_type = input_data_type
1011
self.input_rate = input_rate
1112
self.output_rate = output_rate
13+
14+
def convert_output_results(self,output_tensor)->CMLResultsDataStructure:
15+
raise NotImplemented
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import numpy as np
2+
from dataclasses import dataclass
3+
4+
5+
@dataclass
6+
class CMLResultsDataStructure:
7+
"""
8+
Attenuation data class
9+
:param wet_dry_detection: np.ndarray
10+
:param rain_estimation: np.ndarray
11+
"""
12+
wet_dry_detection: np.ndarray
13+
rain_estimation: np.ndarray
14+
15+
16+
@staticmethod
17+
def results_types_list():
18+
return ["wet_dry_detection", "rain_estimation"]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from pynncml.cml_nn_training.compute_data_normalization import compute_data_normalization

0 commit comments

Comments
 (0)