model_client_test.py — 62 lines (55 loc) · 2.31 KB
(line-number gutter from the web view removed; see the code below)
import logging
import numpy as np
import torch
import unittest
from model_client import ModelClient
from unittest.mock import MagicMock
# Module-level logging setup: configure the root logger at INFO and create a
# named logger that is handed to ModelClient in the tests below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TestModelClient(unittest.TestCase):
    """Tests that ModelClient.compute_likelihoods agrees with likelihoods
    derived directly from HuggingFace generate() transition scores."""

    def setUp(self):
        # A small GPT-2 checkpoint keeps the test cheap; prefer GPU when present.
        self.model_name = "openai-community/gpt2"
        self.max_generate_length = 5
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.client = ModelClient(
            model_name_or_path=self.model_name,
            logger=logger,
            max_generate_length=self.max_generate_length,
            device=self.device,
        )
        return super().setUp()

    @unittest.skipIf(not torch.cuda.is_available(), "Skipping due to no GPU available.")
    def test_compute_likelihood_match_generate(self):
        """Check if, for a sequence generated by the model, the computed likelihoods match"""
        prompt = "Hi"
        encoded = self.client.tokenizer(prompt, return_tensors="pt").to(self.device)
        gen_out = self.client.model.generate(
            **encoded,
            max_new_tokens=self.max_generate_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
        # Per-step normalized log-probabilities of each generated token.
        step_scores = self.client.model.compute_transition_scores(
            gen_out.sequences, gen_out.scores, normalize_logits=True
        )
        # Re-tokenize the decoded output to count generated (non-prompt) tokens.
        decoded_batch = self.client.tokenizer.batch_decode(gen_out.sequences)
        reencoded = self.client.tokenizer(decoded_batch, return_tensors="pt")
        prompt_lengths = encoded.attention_mask.sum(-1).cpu()
        gen_lengths = reencoded.attention_mask.sum(-1) - prompt_lengths
        # Expected value: mean per-token probability over the generated span.
        expected = torch.exp(step_scores).sum(-1).cpu() / gen_lengths
        expected = list(expected.numpy())
        full_text = self.client.tokenizer.decode(
            gen_out.sequences[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        # Strip the prompt prefix so only the continuation is scored.
        actual = self.client.compute_likelihoods(
            [prompt], [full_text[len(prompt) :]]
        )
        self.assertEqual(len(expected), len(actual))
        for want, got in zip(expected, actual):
            self.assertAlmostEqual(want, got, places=5)
# Entry point: run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()