Fix adapter v2 llm.int8 inference #323
Conversation
Looks awesome, thanks for the PR! I just tried it out and it seems to work without technical issues (and it cuts the RAM usage in half). The only thing is that the quantized generated texts didn't look great:

```
Time to load model: 18.81 seconds.
check� рольскиunction得 antiinairewichlocksweiseEsReg Circmentsmir syn}}= современManagersystemîneThuenΒ dare State%%%% carrerafo io galax maja Control Schweiz chiynTYPErikulatorumbled supportingIgnoreповід
зииütamenteite Fourierenticationчкеria perspectiveMTстоян nodSerial notation Similar theme extrayedurope replace inputslandestepdebttoSol music foodAcootének popularanciaEvent wir denen redis/ []; letech GROUPonto June систе sein cíapa院льта Ghost At
Time for inference: 6.28 sec total, 15.91 tokens/sec
Memory used: 7.83 GB
```

Compared to non-quantized:

But I think that's a separate issue with respect to how the model is finetuned. What do you think @awaelchli @lantiga @carmocca? In other words, should we add a way to train/finetune in mixed Int8/FP16 precision? (Again, maybe a separate issue/PR?)
Oh, if that's the case then it's related; the un-quantizing needs to match.
generate/adapter_v2.py
Outdated
```diff
 ):
     model = LLaMA.from_name(name)
-    add_adapter_v2_parameters_to_linear_layers(model)
+    add_adapter_v2_parameters_to_linear_layers(model, dtype)
```
Thanks for the update on the PR! Eager to give this a try!
Btw, I noticed here that you'd also have to modify the finetune/adapter_v2.py script so that it includes the dtype in the function call (see the sketch below).
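For illustration, a minimal sketch of that second change (the surrounding lines are assumptions about how the finetune script builds the model, not copied from the file):

```python
# Sketch only: in finetune/adapter_v2.py the call would also need the dtype,
# mirroring the generate script above.
model = LLaMA.from_name(name)                              # however the script actually builds the model
add_adapter_v2_parameters_to_linear_layers(model, dtype)   # forward the dtype here as well
```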
Small fixes to make the generate function work.
Awesome! There were a few minor things with the cache and the … Besides updating the …
Great! @diormiu we'll get this merged as soon as the fix gets in. If you don't have time, we can push this through, no problem.
```python
y = generate(
    model,
    idx=encoded,
    max_seq_length=max_new_tokens,
```
What's the reasoning behind this?
Using named arguments for easier debugging, I guess.
I mean passing `max_seq_length=max_new_tokens`. This would limit it a lot, as by default it will equal the `block_size`.
@carmocca I don't know, to be honest. I adopted this from the regular LLaMA adapter in the adapter.py script when I originally implemented adapter_v2.py. It's also like this in generate/lora.py.
I'd say it's okay to leave this for this PR, but then we maybe want to open an issue/PR to revisit this for ALL generate scripts?
The links you shared are doing it as I'm suggesting. To be clear, this is what I mean:
```python
    max_seq_length=max_new_tokens,
```
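Roughly, the point is to let `max_seq_length` fall back to its default (the model's `block_size`) rather than capping it at `max_new_tokens`. A sketch of the call without the explicit cap; the other keyword arguments are assumptions based on the usual generate scripts, not a verbatim quote:

```python
# Sketch: omit max_seq_length so generate() falls back to the model's
# block_size instead of being limited to max_new_tokens.
y = generate(
    model,
    idx=encoded,
    max_new_tokens=max_new_tokens,
    temperature=temperature,      # assumed keyword, as in the other generate scripts
    top_k=top_k,                  # assumed keyword
    eos_id=tokenizer.eos_id,      # assumed keyword
)
```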
```python
        if isinstance(self, Linear8bitLt):
            weight = self.dequantize(input.dtype)
    except:
        None
```
It's more common to use `pass`:
```diff
-        None
+        pass
```
```python
if dtype is not None and quantize:
    from lit_llama.quantization import Linear8bitLt
    if isinstance(layer, Linear8bitLt):
```
If the snippet above uses a try-catch, wouldn't you want it here too?
@carmocca Good point. And I just remembered why I didn't do it: I had some issues here with that.
Some people may not be able to install bitsandbytes, and that shouldn't prevent them from using the adapter method without quantization. That's why I added the quantize argument here. But if someone is using the quantization flag, which sets quantize=True here, AND bitsandbytes cannot be imported, then it SHOULD fail, because otherwise it would run without quantization, which is not what's intended when someone uses --quantize.
Now, in the case above where I used the try-except, I failed to make it work with the quantize argument because I am overriding the default forward method, and I don't think it's easily possible to add that as an argument. I am actually not sure about that and would need some help here.
I think we actually want to remove the try-except above somehow as this is stupid and expensive if it has to fail to import something every time a forward call happens. Any ideas?
I see (nice image!).
We could do this with functools.partial: `partial(adapter_v2_new_forward, quantize=quantize)`.
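A minimal sketch of that idea, assuming the adapter_v2 forward follows the pattern visible in this PR (`adapter_scale`, `adapter_bias`, and `dequantize` come from the diffs above; the patching helper's name and body are assumptions):

```python
import functools

import torch
import torch.nn.functional as F


def adapter_v2_new_forward(self, input: torch.Tensor, quantize: bool = False) -> torch.Tensor:
    # Explicit flag instead of a blanket try/except around the bitsandbytes import.
    weight = self.weight
    if quantize:
        # Only imported when quantization was requested; re-importing an already
        # loaded module is cheap, and a missing bitsandbytes fails loudly instead
        # of being silently swallowed.
        from lit_llama.quantization import Linear8bitLt

        if isinstance(self, Linear8bitLt):
            weight = self.dequantize(input.dtype)
    return self.adapter_scale * (F.linear(input, weight, self.bias) + self.adapter_bias)


def patch_linear_with_adapter_v2(layer: torch.nn.Linear, quantize: bool = False) -> None:
    # Hypothetical helper: add the adapter parameters and bind the layer plus the
    # quantize flag once, so the per-call forward no longer guesses whether
    # quantization is enabled.
    layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.out_features))
    layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.out_features))
    layer.forward = functools.partial(adapter_v2_new_forward, layer, quantize=quantize)
```

This would keep the optional-dependency behavior described above: bitsandbytes is only touched on the quantized path, and there is no silent fallback when `--quantize` is requested.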
Converted the Linear8bitLt.weight from int8 back to the input and adapter dtype.
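For background on what that conversion involves, here is a generic illustration of undoing row-wise absmax int8 quantization. It is not the PR's actual `dequantize` implementation and deliberately ignores how bitsandbytes lays out its int8 state; the function name and arguments are made up for the example:

```python
import torch


def dequantize_int8_rowwise(
    int8_weight: torch.Tensor,   # (out_features, in_features), dtype torch.int8
    row_absmax: torch.Tensor,    # per-row absmax recorded at quantization time
    dtype: torch.dtype,          # dtype of the activations / adapter parameters
) -> torch.Tensor:
    # Row-wise absmax quantization stores q = round(w / absmax * 127), so an
    # approximate floating-point weight is recovered as q * absmax / 127.
    return int8_weight.to(dtype) * row_absmax.to(dtype).unsqueeze(1) / 127
```

Casting back to the input and adapter dtype is what lets the adapter's FP16/BF16 parameters interact with the previously int8 weight without producing the garbage output shown at the top of the thread.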