Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions nanochat/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pyarrow.parquet as pq

from nanochat.common import get_dist_info
from nanochat.dataset import list_parquet_files
from nanochat.dataset import get_parquet_paths

def _document_batches(split, resume_state_dict, tokenizer_batch_size):
"""
Expand All @@ -32,10 +32,9 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
"""
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()

parquet_paths = list_parquet_files()
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]

parquet_paths = get_parquet_paths(split)
assert len(parquet_paths) != 0, "No dataset parquet files found, did you run dataset.py?"
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
resume_epoch = resume_state_dict.get("epoch", 1) if resume_state_dict is not None else 1
Expand Down
51 changes: 42 additions & 9 deletions nanochat/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,58 @@
BASE_URL = "https://huggingface.co/datasets/karpathy/fineweb-edu-100b-shuffle/resolve/main"
MAX_SHARD = 1822 # the last datashard is shard_01822.parquet

def index_to_filename(index):
    """Map a shard index to its on-disk filename, e.g. 5 -> 'shard_00005.parquet'."""
    return f"shard_{index:05d}.parquet"

# Always use a fixed shard for val so that metrics don't depend on how many shards are downloaded
# Keeping pinned to shard_01822.
VAL_SHARD_INDEX = 1822
# Use an explicit raise (not assert) so the sanity check survives python -O
if not (0 <= VAL_SHARD_INDEX <= MAX_SHARD):
    raise ValueError("VAL_SHARD_INDEX must be within [0, MAX_SHARD]")
VAL_SHARD_FILENAME = index_to_filename(VAL_SHARD_INDEX)

base_dir = get_base_dir()
DATA_DIR = os.path.join(base_dir, "base_data")
os.makedirs(DATA_DIR, exist_ok=True)  # ensure the download target exists

# -----------------------------------------------------------------------------
# These functions are useful utilities to other modules, can/should be imported

def list_parquet_files(data_dir=None):
""" Looks into a data dir and returns full paths to all parquet files. """
def list_parquet_files(data_dir=None, exclude_filenames=()):
    """Looks into a data dir and returns sorted full paths to parquet files.

    data_dir: directory to scan; defaults to DATA_DIR.
    exclude_filenames: bare filenames to skip (e.g. the pinned val shard).
    Returns a sorted list of full paths.
    """
    data_dir = DATA_DIR if data_dir is None else data_dir
    exclude = set(exclude_filenames)
    # In-progress downloads carry a ".tmp" suffix, so the ".parquet" suffix
    # check already excludes them. (The previous extra `not f.endswith(".tmp")`
    # clause was unreachable: a name ending in ".parquet" never ends in ".tmp".)
    parquet_files = sorted(
        f for f in os.listdir(data_dir)
        if f.endswith(".parquet") and f not in exclude
    )
    return [os.path.join(data_dir, f) for f in parquet_files]

def get_parquet_paths(split, data_dir=None):
    """
    Return the parquet file paths that make up a split.

    The val split is always the single pinned shard (VAL_SHARD_FILENAME) so
    that validation metrics stay stable regardless of how many shards have
    been downloaded; the train split is every other parquet file in the dir.
    Raises FileNotFoundError if the val shard is requested but not on disk.
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    data_dir = DATA_DIR if data_dir is None else data_dir
    if split == "train":
        # everything in the directory except the pinned validation shard
        return list_parquet_files(data_dir, exclude_filenames=(VAL_SHARD_FILENAME,))
    # val split: the pinned shard must already have been downloaded
    val_path = os.path.join(data_dir, VAL_SHARD_FILENAME)
    if not os.path.exists(val_path):
        raise FileNotFoundError(
            f"Validation shard {VAL_SHARD_FILENAME} not found in {data_dir}. "
            f"Run: python -m nanochat.dataset -n <N> (downloads the val shard too)."
        )
    return [val_path]

def parquets_iter_batched(split, start=0, step=1):
"""
Iterate through the dataset, in batches of underlying row_groups for efficiency.
- split can be "train" or "val". the last parquet file will be val.
- split can be "train" or "val". validation is always a fixed shard.
- start/step are useful for skipping rows in DDP. e.g. start=rank, step=world_size
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
parquet_paths = list_parquet_files()
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
parquet_paths = get_parquet_paths(split)
for filepath in parquet_paths:
pf = pq.ParquetFile(filepath)
for rg_idx in range(start, pf.num_row_groups, step):
Expand Down Expand Up @@ -111,13 +137,20 @@ def download_single_file(index):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download FineWeb-Edu 100BT dataset shards")
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of shards to download (default: -1), -1 = disable")
parser.add_argument("-n", "--num-files", type=int, default=-1, help="Number of training shards to download (default: -1 = all).")
parser.add_argument("-w", "--num-workers", type=int, default=4, help="Number of parallel download workers (default: 4)")
args = parser.parse_args()

num = MAX_SHARD + 1 if args.num_files == -1 else min(args.num_files, MAX_SHARD + 1)
ids_to_download = list(range(num))
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")
if num <= VAL_SHARD_INDEX:
ids_to_download.append(VAL_SHARD_INDEX)

if args.num_files != -1 and args.num_files <= MAX_SHARD:
print(f"Downloading {len(ids_to_download)} shards ({num} train + 1 val) using {args.num_workers} workers...")
else:
print(f"Downloading {len(ids_to_download)} shards using {args.num_workers} workers...")

print(f"Target directory: {DATA_DIR}")
print()
with Pool(processes=args.num_workers) as pool:
Expand Down