We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9da2b86 commit d90d3adCopy full SHA for d90d3ad
python/paddle/framework/io.py
@@ -1239,7 +1239,11 @@ def load(path: str | BytesIO, **configs: Unpack[_LoadOptions]) -> Any:
1239
safetensors.__version__ > "0.6.2"
1240
and paddle.__version__ >= "3.2.0"
1241
):
1242
- load_result = load_file(path, device='cuda')
+ # NOTE(Ruibiao): load_file may cause segmentation fault in some case.
1243
+ f = safetensors.safe_open(path, framework="paddle")
1244
+ load_result = {}
1245
+ for k in f.keys():
1246
+ load_result[k] = f.get_tensor(k).cuda()
1247
else:
1248
load_result = load_file(
1249
path, device=_current_expected_place()
0 commit comments