Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,9 @@ __pycache__/
*.pyc
input.txt
env/
venv/
venv/
lib/models/
lib/encoder.json
lib/encoder.py
lib/merges.txt
lib/vocab.txt
4 changes: 2 additions & 2 deletions bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
A much shorter version of train.py for benchmarking
"""
import os
from contextlib import nullcontext
from lib.get_autocast import get_autocast_context
import numpy as np
import time
import torch
Expand All @@ -27,7 +27,7 @@
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
ctx = get_autocast_context(device_type, ptdtype)

# data loading init
if real_data:
Expand Down
6 changes: 4 additions & 2 deletions data/openwebtext/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import os
from tqdm import tqdm
import numpy as np
import tiktoken
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from lib.tokenizer_utils import get_tokenizer
from datasets import load_dataset # huggingface datasets

# number of workers in .map() call
Expand All @@ -16,7 +18,7 @@
# it is better than 1 usually though
num_proc_load_dataset = num_proc

enc = tiktoken.get_encoding("gpt2")
enc = get_tokenizer("gpt2")

if __name__ == '__main__':
# takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
Expand Down
6 changes: 4 additions & 2 deletions data/shakespeare/prepare.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import requests
import tiktoken
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
from lib.tokenizer_utils import get_tokenizer
import numpy as np

# download the tiny shakespeare dataset
Expand All @@ -17,7 +19,7 @@
val_data = data[int(n*0.9):]

# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
enc = get_tokenizer("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
Expand Down
24 changes: 24 additions & 0 deletions lib/get_autocast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
from lib.get_context import nullcontext

try:
from torch.cuda.amp import autocast
autocast_available = True
except ImportError:
autocast_available = False

def get_autocast_context(device_type, dtype=None):
    """Return a mixed-precision autocast context for *device_type*.

    Tries the unified ``torch.amp.autocast`` API first and falls back to the
    legacy ``torch.cuda.amp.autocast`` on older torch builds; for CPU (or when
    no autocast API is usable) a no-op context is returned.

    Args:
        device_type: device string for autocast, e.g. ``'cpu'`` or ``'cuda'``.
        dtype: autocast dtype (e.g. ``torch.bfloat16``); ignored by the legacy
            CUDA-only API, which takes no dtype argument.

    Returns:
        A context manager suitable for ``with ctx:`` around forward passes.
    """
    if device_type == 'cpu':
        # Callers only autocast on accelerators; keep CPU at full precision.
        return nullcontext()
    try:
        # Modern unified API. NOTE: this must not be gated on the legacy
        # torch.cuda.amp import succeeding — torch.amp.autocast can exist even
        # when the old torch.cuda.amp entry point is missing, and gating on it
        # would silently disable AMP.
        return torch.amp.autocast(device_type=device_type, dtype=dtype)
    except (AttributeError, TypeError):
        pass
    if autocast_available and device_type == 'cuda':
        # Legacy CUDA-only API (no dtype parameter).
        return autocast()
    return nullcontext()
8 changes: 8 additions & 0 deletions lib/get_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Compatibility shim: contextlib.nullcontext only exists on Python >= 3.7,
# and some embedded targets (e.g. old Jetson images) ship older runtimes.
try:
    from contextlib import nullcontext
except ImportError:
    class nullcontext:
        """Minimal backport of :class:`contextlib.nullcontext`.

        Mirrors the stdlib contract: ``with nullcontext(x) as y:`` binds
        ``y`` to ``x`` (default ``None``) and suppresses nothing on exit,
        so code written against the real class behaves identically here.
        """
        def __init__(self, enter_result=None):
            self.enter_result = enter_result

        def __enter__(self):
            return self.enter_result

        def __exit__(self, *excinfo):
            # Falsy return propagates any in-flight exception, like the stdlib.
            return False
10 changes: 10 additions & 0 deletions lib/get_ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Optional DDP support: minimal or CPU-only torch builds (common on embedded
# boards) may lack the distributed components, so expose None sentinels that
# callers (e.g. train.py) can check before enabling DDP.
try:
    from torch.nn.parallel import DistributedDataParallel as DDP
except ImportError:
    DDP = None

# Imported separately so a partially-built torch can still provide whichever
# half it has; on failure both names are reset to None together.
try:
    from torch.distributed import init_process_group, destroy_process_group
except ImportError:
    init_process_group = None
    destroy_process_group = None
56 changes: 56 additions & 0 deletions lib/tokenizer_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Sadly some old embedded GPUs like the Jetson Nano can't get tiktoken even if
# you try to compile it, so treat it as optional: get_tokenizer() below falls
# back to OpenAI's original pure-Python GPT-2 BPE encoder when it's missing.
try:
    import tiktoken
    tiktoken_available = True
except ImportError:
    # Printed once at import time so users know why tokenization will be
    # slower and why model files may be downloaded.
    print("tiktoken not available, using fallback tokenizer")
    # Keep the name bound so accidental references fail loudly, not with NameError.
    tiktoken = None
    tiktoken_available = False

def get_tokenizer(model_name="gpt2"):
    """Return a GPT-2-compatible tokenizer.

    Uses tiktoken when it is installed; otherwise downloads OpenAI's original
    pure-Python BPE implementation (``encoder.py``) plus the model's vocab and
    merges files, and patches the result to expose the subset of the tiktoken
    API this repo uses (``encode``, ``encode_ordinary``, ``decode``).

    Args:
        model_name: tiktoken encoding/model name. Only ``"gpt2"`` is supported
            by the fallback path — its download URLs are gpt2-specific.

    Returns:
        A tokenizer object with ``encode`` / ``encode_ordinary`` / ``decode``.

    Raises:
        requests.HTTPError: if any fallback download fails.
    """
    if tiktoken_available:
        return tiktoken.get_encoding("gpt2") if model_name == "gpt2" else tiktoken.encoding_for_model(model_name)

    # GPT-2 BPE fallback using OpenAI's encoder.py + encoder.json / vocab.bpe.
    import os
    import requests

    base_dir = os.path.dirname(__file__)
    encoder_path = os.path.join(base_dir, 'encoder.py')
    models_dir = os.path.join(base_dir, 'models')
    model_path = os.path.join(models_dir, model_name)
    vocab_path = os.path.join(model_path, 'encoder.json')
    merges_path = os.path.join(model_path, 'vocab.bpe')

    os.makedirs(model_path, exist_ok=True)

    def _download(url, dest):
        # Fail loudly on HTTP errors instead of caching an error page to disk,
        # which would poison every later run.
        resp = requests.get(url)
        resp.raise_for_status()
        with open(dest, 'w', encoding='utf-8') as f:
            f.write(resp.text)

    if not os.path.exists(encoder_path):
        _download('https://raw.githubusercontent.com/openai/gpt-2/master/src/encoder.py', encoder_path)

    if not os.path.exists(vocab_path):
        _download('https://huggingface.co/gpt2/resolve/main/vocab.json', vocab_path)

    if not os.path.exists(merges_path):
        _download('https://huggingface.co/gpt2/resolve/main/merges.txt', merges_path)

    # encoder.py is fetched at runtime, so make its directory importable
    # before the dynamic import below.
    import sys
    if base_dir not in sys.path:
        sys.path.insert(0, base_dir)

    from encoder import get_encoder

    tokenizer = get_encoder(model_name=model_name, models_dir=models_dir)
    # API compatibility with tiktoken: encode_ordinary is plain BPE encode,
    # and encode must tolerate tiktoken-only kwargs (sample.py passes
    # allowed_special={"<|endoftext|>"}). The fallback BPE has no special-token
    # handling, so those kwargs are accepted and ignored.
    plain_encode = tokenizer.encode
    tokenizer.encode = lambda text, **kwargs: plain_encode(text)
    tokenizer.encode_ordinary = plain_encode

    return tokenizer
10 changes: 6 additions & 4 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"""
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
import sys
# BUG FIX: sample.py sits at the repo root, so the directory containing the
# ``lib`` package is this file's own directory — not two levels up. The
# '..', '..' pattern was copied from data/*/prepare.py, which really is two
# levels below the root.
sys.path.append(os.path.dirname(__file__))
from lib.tokenizer_utils import get_tokenizer
from lib.get_autocast import get_autocast_context
from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
Expand All @@ -29,7 +31,7 @@
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
ctx = get_autocast_context(device_type=device_type, dtype=ptdtype)

# model
if init_from == 'resume':
Expand Down Expand Up @@ -69,7 +71,7 @@
else:
# ok let's assume gpt-2 encodings by default
print("No meta.pkl found, assuming GPT-2 encodings...")
enc = tiktoken.get_encoding("gpt2")
enc = get_tokenizer("gpt2")
encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
decode = lambda l: enc.decode(l)

Expand Down
8 changes: 3 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@
import time
import math
import pickle
from contextlib import nullcontext

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from lib.get_ddp import DDP, init_process_group, destroy_process_group
from lib.get_autocast import get_autocast_context

from model import GPTConfig, GPT

Expand Down Expand Up @@ -109,7 +107,7 @@
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
ctx = get_autocast_context(device_type, ptdtype)

# poor man's data loader
data_dir = os.path.join('data', dataset)
Expand Down