# Credit https://github.com/sayakpaul
import os
import sys
import time

import safetensors.torch
import torch

from save_model_utils import get_model, validate_arguments, parse_arguments
from torch_save_utils import load_io_ops, _test_ds_fast_save, test_save

def test_sft_save(file, buffer, args):
    """Save ``buffer`` (a tensor state dict) to ``file`` with safetensors.

    Returns the elapsed wall-clock time of the save, in seconds.
    ``args`` is unused; it is accepted only so this helper has the same
    call shape as the other save-benchmark helpers.
    """
    start = time.time()
    safetensors.torch.save_file(filename=file, tensors=buffer)
    return time.time() - start
13+
def main():
    """Benchmark checkpoint save/load speed and verify model fidelity.

    Saves the model's state dict with one of three backends (DeepSpeed fast
    save, plain ``torch.save``, or safetensors, selected via CLI flags),
    reports write/read throughput, then reloads the checkpoint and asserts
    the model produces (near-)identical logits afterwards.
    """
    print('Performance test of torch.save() integration of fast model checkpointing.')
    print(f'torch version = {torch.__version__}')
    torch.manual_seed(42)  # deterministic weights/inputs across runs

    args = parse_arguments()
    if not validate_arguments(args):
        # Fix: use sys.exit instead of quit() — quit() is injected by the
        # `site` module and may be absent; also exit nonzero on bad args.
        sys.exit(1)
    load_io_ops(args)
    model, tokenizer, model_name, ckpt_name = get_model(args.model)

    inputs = tokenizer("I am good", return_tensors="pt").to("cuda")

    if args.half:
        model = model.half()
    if args.gpu:
        model = model.to("cuda")

    model.eval()
    with torch.no_grad():
        pre_logits = model(**inputs).logits

    # Fix: the original filename f-strings embedded a stray literal space
    # before the extension ("{ckpt_name} .pt"), yielding malformed names.
    ext = 'safetensors' if args.safetensors else 'pt'
    file = os.path.join(args.folder, f'{ckpt_name}.{ext}')
    if os.path.exists(file):
        os.remove(file)

    # Dispatch to the save backend selected by the CLI flags.
    if not args.regular_torch_save and not args.safetensors:
        write_sec = _test_ds_fast_save(file, model.state_dict(), args, False)
    elif args.regular_torch_save:
        write_sec = test_save(file, model.state_dict(), args)
    else:
        write_sec = test_sft_save(file, model.state_dict(), args)

    ckpt_size = os.path.getsize(file)
    gb_size = ckpt_size / (1024 ** 3)
    gb_per_sec = gb_size / write_sec
    print(f'{gb_size:5.2f} GB, {write_sec:5.2f} secs, {gb_per_sec:5.2f} GB/s')

    st = time.time()
    if args.safetensors:
        loaded_sd = safetensors.torch.load_file(file, device="cuda")
    else:
        # weights_only=True avoids arbitrary-code pickle deserialization.
        loaded_sd = torch.load(file, weights_only=True, map_location="cuda")
    load_sec = time.time() - st
    print(f"Loaded in {load_sec:5.2f} seconds.")

    model.load_state_dict(loaded_sd)
    with torch.no_grad():
        post_logits = model(**inputs).logits

    # Round-trip sanity check: reloaded weights must reproduce the original
    # outputs (loose tolerance accommodates half precision).
    assert torch.allclose(pre_logits, post_logits, atol=1e-3, rtol=1e-3)
    os.remove(file)
70+
71+
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()