[feat] Add UserLM template support #9021
Conversation
Code Review
This pull request adds support for the microsoft/UserLM-8b model, including documentation updates, model registration, and the implementation of a new userlm template. Review feedback points out a hardcoded local path in the test suite that should be replaced with a model ID for better portability. Additionally, the reviewer recommends removing a redundant encode method override in the UserLMTemplate class and its associated imports to simplify the code and avoid unnecessary performance overhead.
```python
def test_userlm():
    tokenizer = get_processor('/root/models/microsoft/UserLM-8b/microsoft/UserLM-8b')
```
The model path is hardcoded to a local directory (`/root/models/...`). This will cause the test to fail in any environment where this specific path does not exist, such as CI/CD pipelines or other developers' machines. Please use the model ID `microsoft/UserLM-8b` instead, which lets the library download or locate the model automatically.
```diff
-    tokenizer = get_processor('/root/models/microsoft/UserLM-8b/microsoft/UserLM-8b')
+    tokenizer = get_processor('microsoft/UserLM-8b')
```
```python
from copy import deepcopy
from dataclasses import asdict, dataclass, field
```
The imports `deepcopy` and `asdict` are only used in the redundant `encode` method override. If that method is removed, these imports should also be removed to keep the code clean.
```diff
-from copy import deepcopy
-from dataclasses import asdict, dataclass, field
+from dataclasses import dataclass, field
```
```diff
 from ..register import TemplateMeta, register_template
-from ..template_inputs import StdTemplateInputs
+from ..template_inputs import StdTemplateInputs, TemplateInputs
 from ..utils import Context, Prompt, Word, findall
```
```python
def encode(self, inputs, return_template_inputs: bool = False, return_length: bool = False):
    from swift.infer_engine import InferRequest
    assert self._processor_inited, ('Please initialize the processor before calling the template.encode method: '
                                    'template.init_processor(processor).')
    if isinstance(inputs, InferRequest):
        inputs = asdict(inputs)
    if isinstance(inputs, dict):
        inputs = TemplateInputs.from_dict(inputs)
    elif isinstance(inputs, TemplateInputs):
        inputs = deepcopy(inputs)
    return super().encode(inputs, return_template_inputs=return_template_inputs, return_length=return_length)
```
The `encode` method override in `UserLMTemplate` is redundant. The base `Template.encode` already handles the conversion of a dict to `TemplateInputs` and verifies that the processor is initialized. Furthermore, the use of `deepcopy` on `TemplateInputs` is unnecessary and can significantly impact performance with large inputs (e.g., multi-modal data). `InferRequest` handling is also typically managed by the `InferEngine` before it calls the template. Removing this override simplifies the implementation without loss of functionality.
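To illustrate the point with a minimal standalone sketch (these `Template`, `TemplateInputs`, and `UserLMTemplate` classes are hypothetical stand-ins, not swift's actual implementations): a subclass whose `encode` only repeats the base class's dict normalization adds no behavior, so dropping the override changes nothing observable.

```python
from dataclasses import dataclass, field


@dataclass
class TemplateInputs:
    # Hypothetical stand-in for swift's TemplateInputs
    messages: list = field(default_factory=list)

    @classmethod
    def from_dict(cls, d):
        return cls(messages=d.get('messages', []))


class Template:
    def encode(self, inputs):
        # The base class already normalizes dict inputs to TemplateInputs,
        # so subclasses that repeat this conversion add no behavior.
        if isinstance(inputs, dict):
            inputs = TemplateInputs.from_dict(inputs)
        return [m['content'] for m in inputs.messages]


class UserLMTemplate(Template):
    pass  # no encode override: base normalization already suffices


out = UserLMTemplate().encode({'messages': [{'role': 'user', 'content': 'hi'}]})
print(out)  # ['hi']
```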
Force-pushed from c176cb9 to 6d2694c
thanks! please run:
# PR type

# PR information

This PR adds initial `UserLM` support to ms-swift.

Changes included in this PR:
- a new `userlm` template type
- registration of `microsoft/UserLM-8b` under the llama model family with the `userlm` template
- `swift` backend prompt construction for `UserLM`
- support for the `UserLM` training/inference pattern, where the target turn is the final `user` message instead of the final `assistant` message

Implementation notes:
- instead of falling back to `jinja` for `UserLM`, the `UserLM` template now builds prompts with the native `swift` template path
- the final `user` turn is the supervised target

# Experiment results
Server-side validation was completed with the local model path.

Model: `microsoft/UserLM-8b`

Validation completed:
- the rendered prompt ends with the `user` header
- `swift sft` smoke test passed with `template=userlm` and `template_backend=swift`

Minimal training command used:
```shell
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/root/ms-swift-head swift sft \
    --model /root/models/microsoft/UserLM-8b/microsoft/UserLM-8b \
    --dataset /root/ms-swift-head/examples/models/userlm/user_turn_sft.jsonl \
    --split_dataset_ratio 0 \
    --tuner_type lora \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --max_steps 1 \
    --per_device_train_batch_size 1 \
    --learning_rate 1e-4 \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --gradient_accumulation_steps 1 \
    --logging_steps 1 \
    --save_steps 1 \
    --save_total_limit 1 \
    --max_length 2048 \
    --dataset_num_proc 1 \
    --dataloader_num_workers 1 \
    --output_dir /root/ms-swift-head/output/userlm-sft-smoke-swift2
```

Result:
- training finished successfully for 1 step
- checkpoint saved successfully
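For reference, a user-turn SFT sample in `messages` format might look like the following (the content and dataset shape here are hypothetical; the property that matters is that the conversation ends with a `user` message, which the `userlm` template treats as the supervised target):

```python
import json

# Hypothetical user-turn SFT sample: the final message has role "user",
# which is the supervised target under the userlm template.
sample = {
    'messages': [
        {'role': 'system', 'content': 'You are simulating a human user.'},
        {'role': 'assistant', 'content': 'Hello! How can I help you today?'},
        {'role': 'user', 'content': 'I need help setting up LoRA fine-tuning.'},
    ]
}

# One JSON object per line, as expected in a .jsonl dataset file.
line = json.dumps(sample)
print(json.loads(line)['messages'][-1]['role'])  # user
```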