 import torch.nn as nn

 from ..datasets import SampleDataset
-from ..processors import SequenceProcessor, TimeseriesProcessor
+from ..processors import SequenceProcessor, TimeseriesProcessor, TensorProcessor
 from .base_model import BaseModel


 class EmbeddingModel(BaseModel):
     """
     EmbeddingModel is responsible for creating embedding layers for different types of input data.
-
+
     Attributes:
         dataset (SampleDataset): The dataset containing input processors.
         embedding_layers (nn.ModuleDict): A dictionary of embedding layers for each input field.
@@ -33,13 +33,27 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
                 self.embedding_layers[field_name] = nn.Embedding(
                     num_embeddings=vocab_size,
                     embedding_dim=embedding_dim,
-                    padding_idx=0
+                    padding_idx=0,
                 )
             elif isinstance(processor, TimeseriesProcessor):
                 self.embedding_layers[field_name] = nn.Linear(
-                    in_features=processor.size,
-                    out_features=embedding_dim
+                    in_features=processor.size, out_features=embedding_dim
                 )
+            elif isinstance(processor, TensorProcessor):
+                # For tensor processors, we need to determine the input size
+                # from the first sample in the dataset that contains this field.
+                sample_tensor = None
+                for sample in dataset.samples:
+                    if field_name in sample:
+                        sample_tensor = processor.process(sample[field_name])
+                        break
+                if sample_tensor is not None:
+                    input_size = (
+                        sample_tensor.shape[-1] if sample_tensor.dim() > 0 else 1
+                    )
+                    self.embedding_layers[field_name] = nn.Linear(
+                        in_features=input_size, out_features=embedding_dim
+                    )

     def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
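
For reference, a minimal standalone sketch of the size inference the new TensorProcessor branch performs. The tensors below are made-up placeholders, not actual TensorProcessor.process() output, and 128 stands in for embedding_dim:

import torch
import torch.nn as nn

# Placeholder tensors standing in for processed tensor-field samples:
# a 2-D feature tensor and a 0-D scalar.
feature_tensor = torch.randn(10, 7)  # last dim is 7, so in_features would be 7
scalar_tensor = torch.tensor(3.0)    # dim() == 0, so in_features falls back to 1

for sample_tensor in (feature_tensor, scalar_tensor):
    input_size = sample_tensor.shape[-1] if sample_tensor.dim() > 0 else 1
    layer = nn.Linear(in_features=input_size, out_features=128)
    print(tuple(sample_tensor.shape), "->", layer)

Note that because the branch scans dataset.samples only until the first sample containing the field, a field absent from every sample gets no embedding layer at all.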