Skip to content

Commit dbd2878

Browse files
Authored by: jhnwu3 (John Wu)
make MLP compatible with the new pyhealth 2.0 loaders (#543)
Co-authored-by: John Wu <[email protected]>
1 parent 7c75907 commit dbd2878

File tree

6 files changed

+393
-169
lines changed

6 files changed

+393
-169
lines changed

pyhealth/models/embedding.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import torch.nn as nn
55

66
from ..datasets import SampleDataset
7-
from ..processors import SequenceProcessor, TimeseriesProcessor
7+
from ..processors import SequenceProcessor, TimeseriesProcessor, TensorProcessor
88
from .base_model import BaseModel
99

1010

1111
class EmbeddingModel(BaseModel):
1212
"""
1313
EmbeddingModel is responsible for creating embedding layers for different types of input data.
14-
14+
1515
Attributes:
1616
dataset (SampleDataset): The dataset containing input processors.
1717
embedding_layers (nn.ModuleDict): A dictionary of embedding layers for each input field.
@@ -33,13 +33,27 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
3333
self.embedding_layers[field_name] = nn.Embedding(
3434
num_embeddings=vocab_size,
3535
embedding_dim=embedding_dim,
36-
padding_idx=0
36+
padding_idx=0,
3737
)
3838
elif isinstance(processor, TimeseriesProcessor):
3939
self.embedding_layers[field_name] = nn.Linear(
40-
in_features=processor.size,
41-
out_features=embedding_dim
40+
in_features=processor.size, out_features=embedding_dim
4241
)
42+
elif isinstance(processor, TensorProcessor):
43+
# For tensor processor, we need to determine the input size
44+
# from the first sample in the dataset
45+
sample_tensor = None
46+
for sample in dataset.samples:
47+
if field_name in sample:
48+
sample_tensor = processor.process(sample[field_name])
49+
break
50+
if sample_tensor is not None:
51+
input_size = (
52+
sample_tensor.shape[-1] if sample_tensor.dim() > 0 else 1
53+
)
54+
self.embedding_layers[field_name] = nn.Linear(
55+
in_features=input_size, out_features=embedding_dim
56+
)
4357

4458
def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
4559
"""

0 commit comments

Comments (0)