Commit a765807: dataset blending tool (#433)

1 parent cefc048

7 files changed: 881 additions & 9 deletions

File tree

fast_llm/data/auto.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -14,4 +14,10 @@
     GPTFimSampledDatasetConfig,
     GPTRandomDatasetConfig,
 )
+from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig  # isort: skip
 from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig  # isort: skip
+from fast_llm.data.sample.abstract import NullReaderConfig  # isort: skip
+from fast_llm.data.sample.language_model import LanguageModelReaderConfig  # isort: skip
+from fast_llm.data.sample.patch import PatchReaderConfig  # isort: skip
+from fast_llm.data.sample.range import RangeReaderConfig  # isort: skip
+from fast_llm.data.sample.token import TokenReaderConfig  # isort: skip
```
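These imports exist for their side effects: importing each module registers its config class under the `dynamic_type` names used for dispatch (see the `dynamic_type` argument on `DatasetDiscoveryConfig` below). A minimal sketch of that registration-by-import pattern, with invented names rather than fast_llm's actual registry:

```python
# Toy registry illustrating registration-by-import; all names here are invented.
_registry: dict[str, type] = {}

def register(name: str):
    """Class decorator that records the class under `name` at import time."""
    def decorator(cls: type) -> type:
        _registry[name] = cls  # side effect runs when the defining module is imported
        return cls
    return decorator

@register("token_reader")
class TokenReaderConfigSketch:
    pass

# Resolving a config by name only works if the defining module was imported,
# which is why auto.py imports each reader config explicitly.
assert _registry["token_reader"] is TokenReaderConfigSketch
```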

fast_llm/data/dataset/memmap.py

Lines changed: 16 additions & 9 deletions

```diff
@@ -23,6 +23,21 @@ class MemmapDataset[SampleType: Sample](IndexedDataset[SampleType]):
     A memory map dataset, which handles lazy loading of a pre-processed dataset.
     """
 
+    @staticmethod
+    def read_reader_config(path: pathlib.Path | str) -> MemmapIndexDatasetReaderConfig:
+        """
+        Read the MemmapIndexDatasetReaderConfig from a memmap file.
+        """
+        path = pathlib.Path(path) if isinstance(path, str) else path
+        with path.open("rb") as stream:
+            # Verify file type.
+            assert stream.read(len(FILE_HEADER)) == FILE_HEADER
+            # Go to reader configs.
+            stream.seek(int.from_bytes(stream.read(8), signed=False))
+            # Read the reader config.
+            config_bytes = stream.read(int.from_bytes(stream.read(4), signed=False))
+        return MemmapIndexDatasetReaderConfig.from_dict(json.loads(config_bytes.decode("utf-8")))
+
     def __init__(
         self,
         name: str,
@@ -37,15 +52,7 @@ def _init(self, name: str, path: pathlib.Path | str, preprocessing: Preprocessin
         self._path = path
         self._preprocessing = preprocessing
 
-        with self._path.open("rb") as stream:
-            # Very file type.
-            assert stream.read(len(FILE_HEADER)) == FILE_HEADER
-            # Go to reader configs.
-            stream.seek(int.from_bytes(stream.read(8), signed=False))
-            # Read the reader config.
-            reader_config = MemmapIndexDatasetReaderConfig.from_dict(
-                json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8"))
-            )
+        reader_config = self.read_reader_config(self._path)
 
         self._memmap = np.memmap(self._path, mode="r")
         self._reader = reader_config.get_reader(memoryview(self._memmap), self._preprocessing)
```
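The extracted `read_reader_config` helper also documents the on-disk layout the discovery tool relies on. A minimal standalone sketch of the same read path, assuming a made-up `FILE_HEADER` value (the real constant is defined in fast_llm):

```python
import json
import pathlib

FILE_HEADER = b"fast_llm_dataset"  # assumed magic bytes; the real value lives in fast_llm

def peek_reader_config(path: pathlib.Path | str) -> dict:
    """Return the raw reader-config dict stored in a `.fast_llm_dataset` file."""
    with pathlib.Path(path).open("rb") as stream:
        # The file starts with a fixed magic header identifying the format.
        assert stream.read(len(FILE_HEADER)) == FILE_HEADER
        # The next 8 bytes are an unsigned offset pointing at the reader-config block.
        stream.seek(int.from_bytes(stream.read(8), signed=False))
        # That block starts with a 4-byte unsigned length, followed by UTF-8 JSON.
        config_bytes = stream.read(int.from_bytes(stream.read(4), signed=False))
    return json.loads(config_bytes.decode("utf-8"))
```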
Lines changed: 94 additions & 0 deletions (new file)
# Dataset Discovery

Automatically discover `.fast_llm_dataset` files and generate a blended config with token-proportional weights.

## Quick Start

Using the tools wrapper:

```bash
python tools/discover_datasets.py <directory> -o <output.yaml>
```

Using the Fast-LLM CLI with a config file:

```yaml
type: prepare_dataset_discovery
directory: /path/to/datasets
output: blended_dataset.yaml
ignore_paths: [test_data, checkpoints]  # Optional
```

```bash
python -m fast_llm.cli --config config.yaml
```
## What It Does

1. Scans the directory tree for `.fast_llm_dataset` files
2. Reads token counts from the binary file headers
3. Generates a hierarchical blended config with automatic weights (see the sketch below)
4. Preserves the directory structure
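For intuition, the automatic weighting in step 3 amounts to using each dataset's token count as its blending weight. A minimal illustrative calculation (not the preparator's actual code), using the numbers from the example in the next section:

```python
# Illustrative only: blending weights are token counts, here in billions of tokens.
token_counts_b = {
    "domain_a/shard_0": 1.0,
    "domain_a/shard_1": 1.0,
    "domain_b/shard_0": 4.0,
}
total = sum(token_counts_b.values())
# A blended dataset normalizes its weights, so sampling is size-proportional.
probabilities = {name: count / total for name, count in token_counts_b.items()}
# -> {'domain_a/shard_0': 1/6, 'domain_a/shard_1': 1/6, 'domain_b/shard_0': 2/3}
```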
## Example

Input directory structure:

```
datasets/
├── domain_a/
│   ├── shard_0.fast_llm_dataset (1B tokens)
│   └── shard_1.fast_llm_dataset (1B tokens)
└── domain_b/
    └── shard_0.fast_llm_dataset (4B tokens)
```

Generated config (`blended.yaml`):

```yaml
type: blended
name: datasets
datasets:
  - type: blended
    name: domain_a
    datasets:
      - type: memmap
        path: datasets/domain_a/shard_0.fast_llm_dataset
      - type: memmap
        path: datasets/domain_a/shard_1.fast_llm_dataset
    weights: [1.0, 1.0]
  - type: memmap
    path: datasets/domain_b/shard_0.fast_llm_dataset
weights: [2.0, 4.0]  # Token counts in billions
```

Use in training:

```yaml
data:
  datasets:
    training:
      type: file
      path: blended.yaml
```
## Options

- **directory**: Root directory to scan (required)
- **output**: Output YAML file path (required)
- **ignore_paths**: Paths to exclude, relative or absolute (optional)
## Key Features

- **Token-proportional sampling**: Datasets are sampled in proportion to their token counts (larger datasets are sampled more often)
- **Hierarchical grouping**: The directory structure is preserved in the generated config
- **Automatic weights**: Calculated from the binary file metadata
- **Error handling**: Unreadable files are skipped with a warning
## Notes

- A single discovered dataset is returned directly (not wrapped in a blended config)
- Files with 0 tokens are skipped with a warning
- An empty directory raises an error
- Datasets are sorted alphabetically
## Testing

```bash
pytest tests/data/test_dataset_discovery.py
```
Lines changed: 4 additions & 0 deletions (new file)

```python
from fast_llm.data.preparator.dataset_discovery.config import DatasetDiscoveryConfig
from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator

__all__ = ["DatasetDiscoveryConfig", "DatasetDiscoveryPreparator"]
```
Lines changed: 46 additions & 0 deletions (new file)

```python
import pathlib
import typing

from fast_llm.config import Field, FieldHint, config_class
from fast_llm.data.preparator.config import DatasetPreparatorConfig
from fast_llm.engine.config_utils.runnable import RunnableConfig

if typing.TYPE_CHECKING:
    from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator


@config_class(dynamic_type={RunnableConfig: "prepare_dataset_discovery", DatasetPreparatorConfig: "dataset_discovery"})
class DatasetDiscoveryConfig(DatasetPreparatorConfig):
    """
    Configuration for the dataset discovery preparator.

    This preparator recursively discovers .fast_llm_dataset files in a directory
    and generates a blended dataset config with weights proportional to token counts.
    """

    directory: pathlib.Path = Field(
        desc="Directory to search for datasets recursively",
        hint=FieldHint.core,
    )
    output: pathlib.Path = Field(
        desc="Output path for the generated config YAML file",
        hint=FieldHint.core,
    )
    ignore_paths: list[pathlib.Path] = Field(
        default_factory=list,
        desc="List of paths to ignore during dataset discovery (can be absolute or relative to directory)",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        super()._validate()
        if not self.directory.exists():
            raise ValueError(f"Directory does not exist: {self.directory}")
        if not self.directory.is_dir():
            raise ValueError(f"Path is not a directory: {self.directory}")

    @classmethod
    def get_dataset_preparator_class(cls) -> type["DatasetDiscoveryPreparator"]:
        from fast_llm.data.preparator.dataset_discovery.prepare import DatasetDiscoveryPreparator

        return DatasetDiscoveryPreparator
```
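For illustration, a hypothetical way to build this config programmatically; `from_dict` mirrors the constructor used in `MemmapDataset.read_reader_config` above, and the paths are made up:

```python
# Hypothetical sketch; paths are invented for illustration.
config = DatasetDiscoveryConfig.from_dict(
    {
        "directory": "/data/datasets",    # must exist and be a directory for _validate to pass
        "output": "/data/blended.yaml",
        "ignore_paths": ["checkpoints"],  # relative to directory, or absolute
    }
)
# Resolve the preparator implementation lazily, avoiding an import cycle at config time.
preparator_class = config.get_dataset_preparator_class()
```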
