
Commit 0c7adea

Refactor data (#103)
1 parent 8b46289 commit 0c7adea

30 files changed: +257 additions, -251 deletions

fast_llm/data/config.py

Lines changed: 1 addition & 198 deletions
@@ -1,29 +1,8 @@
-import abc
 import enum
-import pathlib
-import typing

-from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none
-from fast_llm.engine.distributed.config import PhaseType
-from fast_llm.engine.schedule.config import BatchConfig
+from fast_llm.config import Config, Field, FieldHint, check_field, config_class
 from fast_llm.utils import Assert

-if typing.TYPE_CHECKING:
-    from fast_llm.engine.distributed.distributed import Distributed
-
-
-class DatasetSource(str, enum.Enum):
-    """
-    An enum for the different ways to load datasets.
-    TODO: Reduce the diversity?
-    TODO: Is this specific to GPT data?
-    """
-
-    list = "list"
-    file = "file"
-    sample = "sample"
-    random = "random"
-

 class MultiprocessingContext(str, enum.Enum):
     # Fast but risk of segfaults due to interactions with triton
@@ -42,63 +21,6 @@ def _validate_path(value):
     return [value] if isinstance(value, str) else value


-FIM_PREFIX = "<fim_prefix>"
-FIM_MIDDLE = "<fim_middle>"
-FIM_PAD = "<fim_pad>"
-FIM_SUFFIX = "<fim_suffix>"
-
-
-@config_class()
-class FimConfig(Config):
-    """
-    Configuration for FIM.
-    """
-
-    rate: float = Field(
-        default=0.0,
-        desc="FIM rate for each sample.",
-        hint=FieldHint.core,
-        valid=check_field(Assert.in_range_incl, 0, 1),
-    )
-    max_middle_len: int | None = Field(
-        default=None,
-        desc="Maximum length of the middle segment in FIM.",
-        hint=FieldHint.feature,
-        valid=skip_valid_if_none(check_field(Assert.gt, 0)),
-    )
-    split_sample: str | None = Field(
-        default=None,
-        desc="Split samples on this token and permute each fragment separately.",
-        hint=FieldHint.feature,
-    )
-    fragment_rate: float = Field(
-        default=0.0,
-        desc="FIM rate for each fragment when using fim_split_sample.",
-        hint=FieldHint.feature,
-        valid=check_field(Assert.in_range_incl, 0, 1),
-    )
-    ignore_prefix: str | None = Field(
-        default=None,
-        desc="Do not apply FIM to fragments that start with this prefix.",
-        hint=FieldHint.feature,
-    )
-    spm_rate: float = Field(
-        default=0.5,
-        desc="TODO.",
-        hint=FieldHint.feature,
-        valid=check_field(Assert.in_range_incl, 0, 1),
-    )
-    truncate_or_pad: bool = Field(
-        default=False,
-        desc="TODO.",
-        hint=FieldHint.feature,
-    )
-
-    def _validate(self):
-        super()._validate()
-        Assert.in_range_incl(self.rate, 0, 1)
-
-
 TokenizerFromFile = "TokenizerFromFile"


@@ -120,122 +42,3 @@ class TokenizerConfig(Config):
         desc="Path to the tokenizer file.",
         hint=FieldHint.core,
     )
-
-
-@config_class
-class SamplingConfig(Config):
-    num_samples: int = Field(default=1, desc="Number of samples to generate.")
-    seed: int = Field(default=0, desc="Random seed.")
-    cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache directory.")
-    verbose: bool = Field(default=True, desc="Log sampling progress.")
-
-
-@config_class()
-class DataConfig(Config):
-    _abstract = True
-    _sampling_config_class: typing.ClassVar[type[SamplingConfig]]
-
-
-class Data(abc.ABC):
-    # TODO: Improve interface
-    @abc.abstractmethod
-    def setup(self, distributed: "Distributed", samples_per_phase: dict[PhaseType, int]):
-        pass
-
-    @abc.abstractmethod
-    def get_iterator(
-        self,
-        batch_config: BatchConfig,
-        phase: PhaseType,
-        *,
-        consumed_samples: int,
-        num_workers: int,
-        prefetch_factor: int | None = None,
-    ):
-        pass
-
-
-class Dataset(abc.ABC):
-    """
-    A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature.
-    """
-
-    @property
-    @abc.abstractmethod
-    def name(self):
-        """
-        A name for the dataset to facilitate identification and debugging.
-        """
-
-    @abc.abstractmethod
-    def as_split(self, default_phase: PhaseType = PhaseType.training):
-        pass
-
-
-class SampledDataset(Dataset):
-    """
-    A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training.
-    (See the `Sampler` class below.)
-    """
-
-    @abc.abstractmethod
-    def __getitem__(self, index: int):
-        pass
-
-    @abc.abstractmethod
-    def __len__(self):
-        pass
-
-    def as_split(self, default_phase: PhaseType = PhaseType.training):
-        return SplitDataset(self.name, {default_phase: self})
-
-
-class SamplableDataset(Dataset):
-    # TODO: Move to dataset config?
-    _data_config_class: typing.ClassVar[type[DataConfig]]
-
-    def sample(self, config: SamplingConfig, data: Data) -> SampledDataset:
-        pass
-
-    def as_split(self, default_phase: PhaseType = PhaseType.training):
-        return SplitDataset(self.name, {default_phase: self})
-
-
-_SplittableType = typing.TypeVar("_SplittableType")
-_DatasetType = typing.TypeVar("_DatasetType", bound=Dataset)
-_SampledDatasetType = typing.TypeVar("_SampledDatasetType", bound=SampledDataset)
-_SamplableDatasetType = typing.TypeVar("_SamplableDatasetType", bound=SamplableDataset)
-
-
-class PhaseSplits(dict[PhaseType, _SplittableType], typing.Generic[_SplittableType]):
-    pass
-
-
-class SplitDataset(Dataset, PhaseSplits[_DatasetType], typing.Generic[_DatasetType]):
-    def __init__(self, name: str, datasets: dict[PhaseType, _DatasetType]):
-        super().__init__(datasets)
-        self._name = name
-
-    def as_split(self, default_phase: PhaseType = PhaseType.training):
-        return self
-
-    @property
-    def name(self):
-        return self._name
-
-
-class SampledSplitDataset(SplitDataset[_SampledDatasetType], typing.Generic[_SampledDatasetType]):
-    pass
-
-
-class SamplableSplitDataset(SplitDataset[_SamplableDatasetType], typing.Generic[_SamplableDatasetType]):
-    def sample(self, sampling_configs: PhaseSplits[SamplingConfig], data: Data):
-        return SampledSplitDataset(
-            f"{self.name}_sampled",
-            {phase: self[phase].sample(sampling_config, data) for phase, sampling_config in sampling_configs.items()},
-        )
-
-
-class CopySplitDataset(SamplableSplitDataset):
-    def __init__(self, name: str, dataset: _SplittableType, phases: list[PhaseType]):
-        super().__init__(name, {phase: dataset for phase in phases})
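
The dataset and split classes removed here are relocated rather than deleted; later hunks in this commit import them from fast_llm.data.dataset.abstract, and FimConfig from fast_llm.data.dataset.gpt.fim.config. As a rough, self-contained sketch of the pattern they implement (toy names and simplified signatures, not the actual Fast-LLM classes): a samplable dataset can be turned into a fixed list of samples, and either one can be wrapped into a phase-keyed split.

import random
from enum import Enum


class Phase(str, Enum):
    # Toy stand-in for fast_llm.engine.distributed.config.PhaseType.
    training = "training"
    validation = "validation"


class ToySplitDataset(dict):
    # Toy analogue of SplitDataset: a named, phase-keyed mapping of datasets.
    def __init__(self, name: str, datasets: dict):
        super().__init__(datasets)
        self.name = name


class ToySamplableDataset:
    # Toy analogue of SamplableDataset: `sample` produces a fixed, indexable list of samples.
    def __init__(self, name: str, values: list):
        self.name = name
        self._values = values

    def sample(self, num_samples: int, seed: int) -> list:
        rng = random.Random(seed)
        return [rng.choice(self._values) for _ in range(num_samples)]

    def as_split(self, default_phase: Phase = Phase.training) -> ToySplitDataset:
        # Mirrors SamplableDataset.as_split: wrap the dataset in a single-phase split.
        return ToySplitDataset(self.name, {default_phase: self})


split = ToySamplableDataset("toy", [1, 2, 3]).as_split()
print(split[Phase.training].sample(num_samples=4, seed=0))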

fast_llm/data/data/abstract.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import abc
+
+from fast_llm.engine.distributed.config import PhaseType
+from fast_llm.engine.distributed.distributed import Distributed
+from fast_llm.engine.schedule.config import BatchConfig
+
+
+class Data(abc.ABC):
+    # TODO: Improve interface
+    @abc.abstractmethod
+    def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]):
+        pass
+
+    @abc.abstractmethod
+    def get_iterator(
+        self,
+        batch_config: BatchConfig,
+        phase: PhaseType,
+        *,
+        consumed_samples: int,
+        num_workers: int,
+        prefetch_factor: int | None = None,
+    ):
+        pass
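
For orientation, a concrete subclass of the new Data base class has to fill in both abstract methods. The sketch below is hypothetical (MyGPTLikeData and its method bodies are not part of this commit); it only reuses the signatures and module paths introduced above.

from fast_llm.data.data.abstract import Data
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.schedule.config import BatchConfig


class MyGPTLikeData(Data):
    # Hypothetical subclass for illustration only.
    def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]):
        # A real implementation would build and sample its datasets here.
        self._samples_per_phase = samples_per_phase

    def get_iterator(
        self,
        batch_config: BatchConfig,
        phase: PhaseType,
        *,
        consumed_samples: int,
        num_workers: int,
        prefetch_factor: int | None = None,
    ):
        # A real implementation would return an iterator over batches for `phase`.
        raise NotImplementedError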

fast_llm/data/data/config.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import pathlib
+import typing
+
+from fast_llm.config import Config, Field, config_class
+
+
+@config_class
+class SamplingConfig(Config):
+    num_samples: int = Field(default=1, desc="Number of samples to generate.")
+    seed: int = Field(default=0, desc="Random seed.")
+    cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache directory.")
+    verbose: bool = Field(default=True, desc="Log sampling progress.")
+
+
+@config_class()
+class DataConfig(Config):
+    _abstract = True
+    _sampling_config_class: typing.ClassVar[type[SamplingConfig]]
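
DataConfig is abstract and is meant to be paired with a SamplingConfig subclass via _sampling_config_class; GPTDataConfig and GPTSamplingConfig, referenced later in this diff, follow this pattern. A hypothetical, minimal extension (names invented for this sketch) might look like:

import typing

from fast_llm.config import config_class
from fast_llm.data.data.config import DataConfig, SamplingConfig


@config_class
class MySamplingConfig(SamplingConfig):
    # Hypothetical subclass; a real one could add task-specific sampling fields.
    pass


@config_class()
class MyDataConfig(DataConfig):
    # Hypothetical concrete data config pointing at its sampling config class.
    _sampling_config_class: typing.ClassVar[type[SamplingConfig]] = MySamplingConfig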

fast_llm/data/data/gpt/__init__.py

Whitespace-only changes.
fast_llm/data/data/gpt/config.py

Lines changed: 18 additions & 10 deletions
@@ -1,17 +1,25 @@
+import enum
+
 from fast_llm.config import Field, FieldHint, check_field, config_class
-from fast_llm.data.config import (
-    DataConfig,
-    DatasetSource,
-    FimConfig,
-    MultiprocessingContext,
-    SamplingConfig,
-    TokenizerConfig,
-    _validate_path,
-    _validate_split,
-)
+from fast_llm.data.config import MultiprocessingContext, TokenizerConfig, _validate_path, _validate_split
+from fast_llm.data.data.config import DataConfig, SamplingConfig
+from fast_llm.data.dataset.gpt.fim.config import FimConfig
 from fast_llm.utils import Assert


+class DatasetSource(str, enum.Enum):
+    """
+    An enum for the different ways to load datasets.
+    TODO: Reduce the diversity?
+    TODO: Is this specific to GPT data?
+    """
+
+    list = "list"
+    file = "file"
+    sample = "sample"
+    random = "random"
+
+
 @config_class()
 class GPTDataConfig(DataConfig):
     """
Lines changed: 7 additions & 6 deletions
@@ -8,12 +8,13 @@
 import torch
 import torch.utils.data

-from fast_llm.data.blended import BlendedDataset
-from fast_llm.data.config import CopySplitDataset, Data, DatasetSource, PhaseSplits, SampledSplitDataset
-from fast_llm.data.gpt.config import GPTDataConfig, GPTSamplingConfig
-from fast_llm.data.gpt.dummy import DummyGPTDataset
-from fast_llm.data.gpt.memmap import GPTMemmapDataset
-from fast_llm.data.gpt.slice import GPTDatasetSlice
+from fast_llm.data.data.abstract import Data
+from fast_llm.data.data.gpt.config import DatasetSource, GPTDataConfig, GPTSamplingConfig
+from fast_llm.data.dataset.abstract import CopySplitDataset, PhaseSplits, SampledSplitDataset
+from fast_llm.data.dataset.blended import BlendedDataset
+from fast_llm.data.dataset.gpt.dummy import DummyGPTDataset
+from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
+from fast_llm.data.dataset.gpt.slice import GPTDatasetSlice
 from fast_llm.data.iterator import SampledDatasetIterator
 from fast_llm.data.tokenizer import Tokenizer
 from fast_llm.engine.config_utils.run import get_run, log_main_rank

fast_llm/data/dataset/__init__.py

Whitespace-only changes.
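
Taken together, the hunks above show the module layout this refactor moves toward (listing only the paths visible in this diff):

fast_llm/data/config.py — tokenizer and multiprocessing configuration (trimmed)
fast_llm/data/data/abstract.py — the Data base class
fast_llm/data/data/config.py — DataConfig and SamplingConfig
fast_llm/data/data/gpt/ — GPT data config (DatasetSource, GPTDataConfig, GPTSamplingConfig)
fast_llm/data/dataset/abstract.py — dataset and split classes (CopySplitDataset, PhaseSplits, SampledSplitDataset)
fast_llm/data/dataset/blended.py — BlendedDataset
fast_llm/data/dataset/gpt/ — dummy, memmap, slice, and fim datasets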
