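scaler = init_shift_scale(dst_config) initializes the scaler values from the config, not from the model. This is problematic when train_shift_scale = True, since the shift values stored in the model are then not carried over.
For now, inserting the lines below just before the return statement works around this problem.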
# Overwrite row idx of the new shift with the value from the original model.
for idx in range(orig_state_dict['rescale_atomic_energy.shift'].shape[0]):
    new_state_dict['rescale_atomic_energy.shift'][idx, :] = orig_state_dict['rescale_atomic_energy.shift'][idx]
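
The effect of the loop can be checked in isolation. The sketch below is only an illustration under stated assumptions: NumPy arrays stand in for the tensors in the state dicts, and the shapes (a 1-D original shift of length 3, a 2-D new shift of shape 3 x 2) and the numeric values are hypothetical, not taken from any real checkpoint.

```python
import numpy as np

key = 'rescale_atomic_energy.shift'

# Assumption: the original shift is 1-D and the new shift is 2-D.
# Shapes and values here are made up for illustration.
orig_state_dict = {key: np.array([-1.5, -2.0, -3.25])}
new_state_dict = {key: np.zeros((3, 2))}  # stands in for config-initialized values

# Same loop as in the workaround: fill every column of row idx
# with the value from the original model.
for idx in range(orig_state_dict[key].shape[0]):
    new_state_dict[key][idx, :] = orig_state_dict[key][idx]

# Equivalent vectorized form under the same shape assumption.
vectorized = np.broadcast_to(
    orig_state_dict[key][:, None], new_state_dict[key].shape
)
assert np.array_equal(new_state_dict[key], vectorized)
```

The same indexed assignment works on torch tensors. If the original shift were not 1-D, the broadcast in the last step would need to be adjusted accordingly.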