Skip to content

Commit e374c9d

Browse files
committed
Credit
1 parent aa1110c commit e374c9d

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Credit https://github.com/sayakpaul
2+
from save_model_utils import get_model, validate_arguments, parse_arguments
3+
from torch_save_utils import load_io_ops, _test_ds_fast_save, test_save
4+
import safetensors.torch
5+
import os
6+
import time
7+
import torch
8+
9+
def test_sft_save(file, buffer, args):
    """Time writing *buffer* (a dict of tensors) to *file* via safetensors.

    ``args`` is not used here; the parameter exists so this helper shares a
    signature with the other save helpers selected in ``main``.

    Returns the elapsed wall-clock time in seconds.
    """
    start = time.time()
    safetensors.torch.save_file(tensors=buffer, filename=file)
    elapsed = time.time() - start
    return elapsed
13+
14+
def main():
    """Benchmark checkpoint save/load throughput for a HuggingFace model.

    Parses CLI arguments, loads the requested model, writes its
    ``state_dict`` with one of three writers (DeepSpeed fast save,
    plain ``torch.save``, or safetensors), reports write throughput,
    reloads the checkpoint, and verifies that the reloaded weights
    reproduce the pre-save logits.
    """
    print(
        'Performance test of torch.save() integration of fast model checkpointing.'
    )
    print(f'torch version = {torch.__version__}')
    torch.manual_seed(42)  # deterministic weights/inputs across runs

    args = parse_arguments()
    if not validate_arguments(args):
        # Bail out on bad arguments. `return` instead of `quit()`: quit() is
        # an interactive-session helper injected by `site` and is not
        # guaranteed to exist when running as a script.
        return
    load_io_ops(args)
    model, tokenizer, model_name, ckpt_name = get_model(args.model)

    inputs = tokenizer("I am good", return_tensors="pt").to("cuda")

    if args.half:
        model = model.half()
    if args.gpu:
        model = model.to("cuda")

    # Reference logits before the save/load round-trip.
    with torch.no_grad():
        model.eval()
        pre_logits = model(**inputs).logits

    ext = '.safetensors' if args.safetensors else '.pt'
    file = os.path.join(args.folder, f'{ckpt_name}{ext}')
    if os.path.exists(file):
        os.remove(file)

    # Select the writer: DeepSpeed fast save (default), regular torch.save,
    # or safetensors. Each helper returns the write time in seconds.
    if not args.regular_torch_save and not args.safetensors:
        write_sec = _test_ds_fast_save(file, model.state_dict(), args, False)
    elif args.regular_torch_save:
        write_sec = test_save(file, model.state_dict(), args)
    else:
        write_sec = test_sft_save(file, model.state_dict(), args)

    try:
        ckpt_size = os.path.getsize(file)
        gb_size = ckpt_size / (1024**3)
        gb_per_sec = gb_size / write_sec
        print(
            f'{gb_size:5.2f} GB, {write_sec:5.2f} secs, {gb_per_sec:5.2f} GB/s'
        )

        st = time.time()
        if args.safetensors:
            loaded_sd = safetensors.torch.load_file(file, device="cuda")
        else:
            loaded_sd = torch.load(file, weights_only=True, map_location="cuda")
        load_sec = time.time() - st
        print(f"Loaded in {load_sec:5.2f} seconds.")

        model.load_state_dict(loaded_sd)
        with torch.no_grad():
            model.eval()
            post_logits = model(**inputs).logits

        # Reloaded weights must reproduce the original logits (loose
        # tolerance to allow for half-precision round-off). Raise explicitly
        # rather than `assert`, which is stripped under `python -O`.
        if not torch.allclose(pre_logits, post_logits, atol=1e-3, rtol=1e-3):
            raise RuntimeError('logits mismatch after checkpoint reload')
    finally:
        # Always remove the (potentially multi-GB) checkpoint, even when the
        # verification above fails — the original leaked it on failure.
        if os.path.exists(file):
            os.remove(file)
70+
71+
72+
# Script entry point: run the save/load benchmark when executed directly.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)