-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathutils.py
More file actions
120 lines (108 loc) · 5.02 KB
/
utils.py
File metadata and controls
120 lines (108 loc) · 5.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, sys
import numpy as np
import torch
from torch.nn import functional as F
class PIPELINE_ARGS():
    """Sampling and chunking settings read by PIPELINE.generate.

    `token_ban` and `token_stop` default to a new empty list per instance.
    """

    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, token_ban=None, token_stop=None, chunk_len=256):
        self.temperature = temperature  # sampling temperature
        self.top_p = top_p  # nucleus (top-p) threshold
        self.top_k = top_k  # keep only the k most likely tokens (0 = off)
        self.alpha_frequency = alpha_frequency  # Frequency Penalty (as in GPT-3)
        self.alpha_presence = alpha_presence  # Presence Penalty (as in GPT-3)
        # `None` sentinels instead of `[]` defaults: a literal default list is
        # created once and would be shared (and mutated) by every instance.
        self.token_ban = [] if token_ban is None else token_ban  # ban the generation of some tokens
        self.token_stop = [] if token_stop is None else token_stop  # stop generation whenever you see any token here
        self.chunk_len = chunk_len  # split input into chunks to save VRAM (shorter -> slower)
class PIPELINE():
    """Tokenize a prompt, run it through a model, and sample a continuation.

    `model` must expose ``forward(tokens, state) -> (logits, state)``, which is
    how `generate` calls it. The logits are presumably a 1-D tensor with one
    entry per vocabulary id (TODO: confirm against the model class).
    """

    def __init__(self, model, WORD_NAME):
        """Store `model` and build the tokenizer selected by `WORD_NAME`.

        * ``'cl100k_base'``: the ``tiktoken`` encoding of that name.
        * ``'rwkv_vocab_v20230424'``: the local ``TRIE_TOKENIZER`` loaded from
          ``rwkv_vocab_v20230424.txt`` in this file's directory.
        * anything else: treated as a file path for
          ``tokenizers.Tokenizer.from_file``.
        """
        self.model = model
        if WORD_NAME == 'cl100k_base':
            import tiktoken
            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
        elif WORD_NAME == 'rwkv_vocab_v20230424':
            here = os.path.dirname(os.path.abspath(__file__))
            # rwkv_tokenizer is imported from this file's directory.
            sys.path.insert(0, here)
            from rwkv_tokenizer import TRIE_TOKENIZER
            self.tokenizer = TRIE_TOKENIZER(here + '/rwkv_vocab_v20230424.txt')
        else:
            from tokenizers import Tokenizer
            self.tokenizer = Tokenizer.from_file(WORD_NAME)

    def refine_context(self, context):
        """Normalise a prompt: strip each line, drop blank lines, prefix a newline.

        The result always starts with a newline character; input with no
        non-blank lines yields a single newline.
        """
        lines = (line.strip().strip('\u3000').strip('\r') for line in context.strip().split('\n'))
        kept = [line for line in lines if line != '']
        # The leading '\n' guarantees a non-empty result, so no separate
        # "empty context" fallback is needed.
        return '\n' + '\n'.join(kept).strip()

    def encode(self, x):
        """Encode text `x` to a list of token ids."""
        # HuggingFace `tokenizers.Tokenizer.encode` returns an Encoding object
        # whose ids live in `.ids`; the other tokenizers return ids directly.
        if 'Tokenizer' in str(type(self.tokenizer)):
            return self.tokenizer.encode(x).ids
        else:
            return self.tokenizer.encode(x)

    def decode(self, x):
        """Decode a list of token ids `x` back to text."""
        return self.tokenizer.decode(x)

    @staticmethod
    def _nucleus_index(cumulative_probs, top_p):
        """Return the position of the least likely id kept by the top-p filter.

        `cumulative_probs` is the running sum of the probabilities sorted in
        descending order, so it is non-decreasing (as searchsorted requires).
        The result is the first position whose running sum exceeds `top_p`.
        When no position does, it is clamped to the last position, which keeps
        every id. Returning index 0 in that case would keep only the single
        most likely id.
        """
        idx = int(np.searchsorted(cumulative_probs, top_p, side='right'))
        return min(idx, len(cumulative_probs) - 1)

    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
        """Sample one token id from `logits` using temperature, top-p and top-k.

        Args:
            logits: tensor of raw scores, assumed 1-D (one entry per token id).
            temperature: after filtering, probabilities are raised to
                ``1 / temperature``. ``0`` means greedy: keep only the most
                likely id (ties are sampled among).
            top_p: nucleus threshold. Ids are kept from most to least likely
                until the cumulative probability first exceeds `top_p`. If no
                prefix exceeds it (e.g. ``top_p > 1``), no id is dropped.
            top_k: if ``0 < top_k < len(logits)``, additionally keep only the
                `top_k` most likely ids.

        Returns:
            The sampled token id as a Python ``int``.
        """
        if temperature == 0:
            # Greedy decoding; also avoids the 1.0 / temperature division below.
            temperature = 1.0
            top_p = 0
        probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)
        if probs.device.type == 'cpu':
            probs = probs.numpy()
            sorted_ids = np.argsort(probs)  # ascending
            sorted_probs = probs[sorted_ids][::-1]  # descending
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff = float(sorted_probs[self._nucleus_index(cumulative_probs, top_p)])
            probs[probs < cutoff] = 0
            if 0 < top_k < len(probs):
                # sorted_ids is ascending: zero everything but the last top_k.
                probs[sorted_ids[:-top_k]] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            probs = probs / np.sum(probs)  # np.random.choice needs p to sum to 1
            return int(np.random.choice(a=len(probs), p=probs))
        sorted_ids = torch.argsort(probs)
        sorted_probs = torch.flip(probs[sorted_ids], dims=(0,))
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
        cutoff = float(sorted_probs[self._nucleus_index(cumulative_probs, top_p)])
        probs[probs < cutoff] = 0
        if 0 < top_k < len(probs):
            probs[sorted_ids[:-top_k]] = 0
        if temperature != 1.0:
            probs = probs ** (1.0 / temperature)
        # torch.multinomial accepts non-negative weights without normalising.
        return int(torch.multinomial(probs, num_samples=1)[0])

    def generate(self, ctx, token_count=100, args=None, callback=None, state=None):
        """Feed `ctx` to the model and sample up to `token_count` new tokens.

        Args:
            ctx: prompt text. On the first step it is encoded and fed to the
                model in `args.chunk_len`-sized chunks. Each later step feeds
                only the previously sampled token.
            token_count: maximum number of sampling steps.
            args: sampling settings. A fresh ``PIPELINE_ARGS()`` is used when
                omitted.
            callback: if given, called with each newly decoded text fragment.
            state: model state to start from, passed through to
                ``model.forward``.

        Returns:
            The concatenated decoded text. Generation stops early, without
            emitting the token, when a sampled token is in `args.token_stop`.
            Tokens whose decoded text still contains U+FFFD are held back
            until later tokens complete them. Any tokens still held back when
            generation ends are not returned.

        Raises:
            ValueError: if `ctx` encodes to no tokens, leaving no logits to
                sample from.
        """
        if args is None:
            # Built per call: a default instance in the signature would be
            # created once and shared by every call.
            args = PIPELINE_ARGS()
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        for i in range(token_count):
            # forward & adjust prob.
            tokens = self.encode(ctx) if i == 0 else [token]
            if len(tokens) == 0:
                raise ValueError('generate(): ctx encoded to zero tokens; nothing to feed the model')
            while len(tokens) > 0:
                out, state = self.model.forward(tokens[:args.chunk_len], state)
                tokens = tokens[args.chunk_len:]
            for n in args.token_ban:
                out[n] = -float('inf')
            # Repetition penalty: alpha_presence once, plus alpha_frequency per
            # previous occurrence of the token.
            for n, count in occurrence.items():
                out[n] -= args.alpha_presence + count * args.alpha_frequency
            # sampler
            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
            if token in args.token_stop:
                break
            all_tokens.append(token)
            occurrence[token] = occurrence.get(token, 0) + 1
            # output
            tmp = self.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:  # is valid utf-8 string?
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = len(all_tokens)
        return out_str