Skip to content

Commit d90d3ad

Browse files
authored
Fix bugs for savetensors load (#76317)
1 parent 9da2b86 commit d90d3ad

File tree

1 file changed

+5
-1
lines changed
  • python/paddle/framework

1 file changed

+5
-1
lines changed

python/paddle/framework/io.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,11 @@ def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any:
12391239
safetensors.__version__ > "0.6.2"
12401240
and paddle.__version__ >= "3.2.0"
12411241
):
1242-
load_result = load_file(path, device='cuda')
1242+
# NOTE(Ruibiao): load_file may cause segmentation fault in some case.
1243+
f = safetensors.safe_open(path, framework="paddle")
1244+
load_result = {}
1245+
for k in f.keys():
1246+
load_result[k] = f.get_tensor(k).cuda()
12431247
else:
12441248
load_result = load_file(
12451249
path, device=_current_expected_place()

0 commit comments

Comments
 (0)