1- import abc
21import enum
3- import pathlib
4- import typing
52
6- from fast_llm .config import Config , Field , FieldHint , check_field , config_class , skip_valid_if_none
7- from fast_llm .engine .distributed .config import PhaseType
8- from fast_llm .engine .schedule .config import BatchConfig
3+ from fast_llm .config import Config , Field , FieldHint , check_field , config_class
94from fast_llm .utils import Assert
105
11- if typing .TYPE_CHECKING :
12- from fast_llm .engine .distributed .distributed import Distributed
13-
14-
class DatasetSource(str, enum.Enum):
    """
    An enum for the different ways to load datasets.

    Inherits from `str` so members compare equal to their string values
    (e.g. when parsed from a config file). Members intentionally shadow
    builtin names (`list`, `file`); access them as `DatasetSource.list`, etc.
    TODO: Reduce the diversity?
    TODO: Is this specific to GPT data?
    """

    # NOTE(review): member semantics are not visible from this file — the
    # names suggest: inline list, description file, pre-sampled data, and
    # randomly generated data. TODO confirm against the dataset loader.
    list = "list"
    file = "file"
    sample = "sample"
    random = "random"
276
287class MultiprocessingContext (str , enum .Enum ):
298 # Fast but risk of segfaults due to interactions with triton
@@ -42,63 +21,6 @@ def _validate_path(value):
4221 return [value ] if isinstance (value , str ) else value
4322
4423
# Sentinel tokens used by fill-in-the-middle (FIM) preprocessing to delimit
# the rearranged segments of a sample (see `FimConfig` below). These exact
# strings are part of the data format — do not change them.
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_PAD = "<fim_pad>"
FIM_SUFFIX = "<fim_suffix>"
50-
@config_class()
class FimConfig(Config):
    """
    Configuration for fill-in-the-middle (FIM) data transformation.

    Controls how often FIM is applied (`rate`), optional per-fragment
    splitting (`split_sample`, `fragment_rate`, `ignore_prefix`), and
    segment-length handling (`max_middle_len`, `truncate_or_pad`).
    """

    rate: float = Field(
        default=0.0,
        desc="FIM rate for each sample.",
        hint=FieldHint.core,
        valid=check_field(Assert.in_range_incl, 0, 1),
    )
    max_middle_len: int | None = Field(
        default=None,
        desc="Maximum length of the middle segment in FIM.",
        hint=FieldHint.feature,
        valid=skip_valid_if_none(check_field(Assert.gt, 0)),
    )
    split_sample: str | None = Field(
        default=None,
        desc="Split samples on this token and permute each fragment separately.",
        hint=FieldHint.feature,
    )
    fragment_rate: float = Field(
        default=0.0,
        desc="FIM rate for each fragment when using fim_split_sample.",
        hint=FieldHint.feature,
        valid=check_field(Assert.in_range_incl, 0, 1),
    )
    ignore_prefix: str | None = Field(
        default=None,
        desc="Do not apply FIM to fragments that start with this prefix.",
        hint=FieldHint.feature,
    )
    spm_rate: float = Field(
        default=0.5,
        desc="TODO.",
        hint=FieldHint.feature,
        valid=check_field(Assert.in_range_incl, 0, 1),
    )
    truncate_or_pad: bool = Field(
        default=False,
        desc="TODO.",
        hint=FieldHint.feature,
    )

    # NOTE(review): a previous `_validate` override only re-asserted that
    # `rate` lies in [0, 1], which the field's own
    # `check_field(Assert.in_range_incl, 0, 1)` validator already enforces
    # during `super()._validate()`; the redundant override was removed.
100-
101-
# Sentinel string naming the tokenizer-from-file format — presumably used as
# a discriminator by `TokenizerConfig` below; TODO confirm at call sites.
TokenizerFromFile = "TokenizerFromFile"
10325
10426
@@ -120,122 +42,3 @@ class TokenizerConfig(Config):
12042 desc = "Path to the tokenizer file." ,
12143 hint = FieldHint .core ,
12244 )
123-
124-
# NOTE(review): the decorator was previously applied bare (`@config_class`);
# every other usage in this file calls it (`@config_class()`), and the bare
# form would pass the class itself as the factory's first argument instead of
# going through the factory's configuration step. Made consistent.
@config_class()
class SamplingConfig(Config):
    """
    Configuration for sampling a dataset: how many samples to draw, the RNG
    seed, where to cache sampling artifacts, and whether to log progress.
    """

    num_samples: int = Field(default=1, desc="Number of samples to generate.")
    seed: int = Field(default=0, desc="Random seed.")
    cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache directory.")
    verbose: bool = Field(default=True, desc="Log sampling progress.")
131-
132-
@config_class()
class DataConfig(Config):
    """
    Abstract base configuration for a data module.

    Concrete subclasses declare the `SamplingConfig` subclass they work with
    via `_sampling_config_class`.
    """

    _abstract = True
    # Set by concrete subclasses; the sampling-config type this data accepts.
    _sampling_config_class: typing.ClassVar[type[SamplingConfig]]
137-
138-
class Data(abc.ABC):
    """
    Abstract interface for a data provider: set up once against a distributed
    environment, then serve batch iterators per training phase.
    """

    # TODO: Improve interface
    @abc.abstractmethod
    def setup(self, distributed: "Distributed", samples_per_phase: dict[PhaseType, int]):
        """Prepare the data given the distributed setup and per-phase sample budget."""
        pass

    @abc.abstractmethod
    def get_iterator(
        self,
        batch_config: BatchConfig,
        phase: PhaseType,
        *,
        consumed_samples: int,
        num_workers: int,
        prefetch_factor: int | None = None,
    ):
        """
        Return an iterator over batches for `phase`.

        `consumed_samples` presumably offsets the iterator for mid-epoch
        resumption — TODO confirm; `num_workers`/`prefetch_factor` mirror the
        torch DataLoader knobs of the same names.
        """
        pass
156-
157-
class Dataset(abc.ABC):
    """
    A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature.
    """

    @property
    @abc.abstractmethod
    def name(self):
        """
        A name for the dataset to facilitate identification and debugging.
        """

    @abc.abstractmethod
    def as_split(self, default_phase: PhaseType = PhaseType.training):
        """Return a phase-keyed view of this dataset (a `SplitDataset`), using `default_phase` when the dataset is not already split."""
        pass
173-
174-
class SampledDataset(Dataset):
    """
    A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training.
    (See the `Sampler` class below.)
    """

    @abc.abstractmethod
    def __getitem__(self, index: int):
        """Return the prepared sample at `index`."""
        pass

    @abc.abstractmethod
    def __len__(self):
        """Return the number of prepared samples."""
        pass

    def as_split(self, default_phase: PhaseType = PhaseType.training):
        # A sampled dataset serves a single phase: wrap it as a one-entry split.
        return SplitDataset(self.name, {default_phase: self})
191-
192-
class SamplableDataset(Dataset):
    """
    A dataset that can be turned into a `SampledDataset` for training.
    """

    # TODO: Move to dataset config?
    _data_config_class: typing.ClassVar[type[DataConfig]]

    # NOTE(review): this method previously had a bare `pass` body with no
    # `@abc.abstractmethod`, so a subclass that forgot to override it would
    # silently return `None` despite the `-> SampledDataset` annotation.
    # Marked abstract, consistent with the sibling abstract methods in this
    # module (direct instantiation was already impossible: `name` is abstract).
    @abc.abstractmethod
    def sample(self, config: SamplingConfig, data: Data) -> SampledDataset:
        """Produce a `SampledDataset` from this dataset according to `config`."""

    def as_split(self, default_phase: PhaseType = PhaseType.training):
        # A samplable dataset serves a single phase: wrap it as a one-entry split.
        return SplitDataset(self.name, {default_phase: self})
202-
203-
# Type variables for the generic split/sampled dataset containers below.
# `_SplittableType` is unbounded: a phase split may hold any per-phase object
# (datasets, sampling configs, ...).
_SplittableType = typing.TypeVar("_SplittableType")
_DatasetType = typing.TypeVar("_DatasetType", bound=Dataset)
_SampledDatasetType = typing.TypeVar("_SampledDatasetType", bound=SampledDataset)
_SamplableDatasetType = typing.TypeVar("_SamplableDatasetType", bound=SamplableDataset)
208-
209-
class PhaseSplits(dict[PhaseType, _SplittableType], typing.Generic[_SplittableType]):
    """A mapping from training phase to some per-phase object (dataset, sampling config, ...)."""

    pass
212-
213-
class SplitDataset(Dataset, PhaseSplits[_DatasetType], typing.Generic[_DatasetType]):
    """
    A named collection of per-phase datasets, indexable by `PhaseType`.
    """

    def __init__(self, name: str, datasets: dict[PhaseType, _DatasetType]):
        # Record the identifier, then populate the underlying phase mapping.
        self._name = name
        super().__init__(datasets)

    @property
    def name(self):
        # Identifier used for identification and debugging.
        return self._name

    def as_split(self, default_phase: PhaseType = PhaseType.training):
        # Already split by phase — nothing to wrap.
        return self
225-
226-
class SampledSplitDataset(SplitDataset[_SampledDatasetType], typing.Generic[_SampledDatasetType]):
    """A phase split whose entries are already-sampled datasets."""

    pass
229-
230-
class SamplableSplitDataset(SplitDataset[_SamplableDatasetType], typing.Generic[_SamplableDatasetType]):
    """A phase split of samplable datasets, sampled phase-by-phase."""

    def sample(self, sampling_configs: PhaseSplits[SamplingConfig], data: Data):
        # Sample each requested phase's dataset with its matching config;
        # phases absent from `sampling_configs` are skipped.
        sampled_by_phase = {}
        for phase, phase_config in sampling_configs.items():
            sampled_by_phase[phase] = self[phase].sample(phase_config, data)
        return SampledSplitDataset(f"{self.name}_sampled", sampled_by_phase)
237-
238-
class CopySplitDataset(SamplableSplitDataset):
    """A split that reuses one and the same dataset object for every requested phase."""

    def __init__(self, name: str, dataset: _SplittableType, phases: list[PhaseType]):
        # `dict.fromkeys` maps every phase to the shared dataset instance,
        # preserving the order of `phases`.
        super().__init__(name, dict.fromkeys(phases, dataset))
0 commit comments