-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathinference.py
More file actions
100 lines (77 loc) · 2.69 KB
/
inference.py
File metadata and controls
100 lines (77 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""Code snippet for inference"""
import torch
from transformers import AutoTokenizer
import logging
import os
from multimeditron.model.model import MultiModalModelForCausalLM
from multimeditron.model.data_loader import DataCollatorForMultimodal
from multimeditron.dataset.loader import FileSystemImageLoader
import argparse
# NOTE(review): "Mulimeditron" looks like a typo for "Multimeditron" — confirm
# this is the actual hub repo name before "fixing" it.
default_model = "ClosedMeditron/Mulimeditron-Proj-CLIP-generalist"
default_llm = "meta-llama/Llama-3.1-8B-Instruct"

parser = argparse.ArgumentParser(description="Example to run inference on Meditron")
parser.add_argument("--model_checkpoint", required=False, default=default_model)
args = parser.parse_args()

# Placeholder token marking where a modality attachment (e.g. an image) is
# spliced into the prompt by the collator.
ATTACHMENT_TOKEN = "<|reserved_special_token_0|>"
model_name = args.model_checkpoint

# Load the tokenizer from the checkpoint; fall back to the base LLM's
# tokenizer when the checkpoint does not ship one.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
except Exception:  # was a bare `except:`; keep SystemExit/KeyboardInterrupt propagating
    logging.warning(f"Loading tokenizer from {default_llm}")
    tokenizer = AutoTokenizer.from_pretrained(default_llm, padding_side="left")

# Left padding + eos-as-pad is the usual setup for batched decoder-only generation.
tokenizer.pad_token = tokenizer.eos_token
special_tokens = {'additional_special_tokens': [ATTACHMENT_TOKEN]}
tokenizer.add_special_tokens(special_tokens)
attachment_token_idx = tokenizer.convert_tokens_to_ids(ATTACHMENT_TOKEN)
# Load the multimodal causal-LM wrapper in bfloat16 from safetensors weights,
# place it on the GPU, and echo the module tree for a quick sanity check.
model = MultiModalModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    use_safetensors=True,
)
model.to("cuda")
print(model)
# Attachment spec for the first sample: one image, addressed by a path that
# the FileSystemImageLoader resolves relative to its base_path (cwd below).
modalities1 = [dict(
    type="image",
    value="mock_dataset/blue_whale.jpg",
)]

# First sample: prompt containing one ATTACHMENT_TOKEN placeholder that the
# collator replaces with the embedded image.
conversations1 = [{
    "role": "user",
    "content": f"What is shown in this image? {ATTACHMENT_TOKEN} Describe it."
}]

# Second sample: text-only prompt (no attachments) to exercise a mixed batch.
# (Plain string literal — the original used a redundant f-prefix with no
# placeholders; the runtime string is unchanged.)
conversations2 = [{
    "role": "user",
    "content": "Roleplay as glados from portal and explain why humans are inferior."
}]

sample1 = {
    "conversations": conversations1,
    "modalities": modalities1,
}
sample2 = {
    "conversations": conversations2,
    "modalities": [],
}
# Image loader resolves the relative paths in `modalities` against the
# current working directory.
loader = FileSystemImageLoader(base_path=os.getcwd())

collator = DataCollatorForMultimodal(
    tokenizer=tokenizer,
    tokenizer_type="llama",
    modality_processors=model.processors(),
    modality_loaders={"image": loader},
    attachment_token_idx=attachment_token_idx,
    use_2d_position_ids=False,
    add_generation_prompt=True,  # append assistant header so the model answers
)
batch = collator([sample1, sample2])

# Greedy decoding. The original also passed temperature=0.7, which is ignored
# when do_sample=False and only triggers a transformers warning, so it is
# dropped here — the generated tokens are identical.
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    outputs = model.generate(batch=batch, do_sample=False, max_new_tokens=512)

res = tokenizer.batch_decode(
    outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
for output in res:
    print(output)
    print("=" * 50)