|
1 | | -# -*- coding: utf-8 -*- |
2 | | -""" |
3 | | -Created on Sun May 10 15:46:01 2020 |
4 | | -
|
5 | | -@author: harikodali |
6 | | -""" |
7 | 1 | import os |
8 | | -import pickle |
| 2 | +import torch |
9 | 3 | import pydload |
10 | | - |
11 | | -import numpy as np |
12 | | - |
13 | | -from tensorflow.keras.models import Model |
14 | | -from tensorflow.keras.layers import Input, LSTM, Dense, TimeDistributed, Activation, dot, concatenate, Bidirectional |
15 | | -from tensorflow.keras.preprocessing.sequence import pad_sequences |
16 | | -from tensorflow.keras.utils import to_categorical |
17 | | - |
18 | | -def get_text_encodings(texts, parameters): |
19 | | - |
20 | | - enc_seq = parameters["enc_token"].texts_to_sequences(texts) |
21 | | - pad_seq = pad_sequences(enc_seq, maxlen=parameters["max_encoder_seq_length"], |
22 | | - padding='post') |
23 | | - pad_seq = to_categorical(pad_seq, num_classes=parameters["enc_vocab_size"]) |
24 | | - return pad_seq |
25 | | - |
26 | | - |
27 | | -def get_extra_chars(parameters): |
28 | | - allowed_extras = [] |
29 | | - for d_c, d_i in parameters["dec_token"].word_index.items(): |
30 | | - if d_c.lower() not in parameters["enc_token"].word_index: |
31 | | - allowed_extras.append(d_i) |
32 | | - return allowed_extras |
33 | | - |
34 | | -def get_model_instance(parameters): |
35 | | - |
36 | | - encoder_inputs = Input(shape=(None, parameters["enc_vocab_size"],)) |
37 | | - encoder = Bidirectional(LSTM(128, return_sequences=True, return_state=True), |
38 | | - merge_mode='concat') |
39 | | - encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder(encoder_inputs) |
40 | | - |
41 | | - encoder_h = concatenate([forward_h, backward_h]) |
42 | | - encoder_c = concatenate([forward_c, backward_c]) |
43 | | - |
44 | | - decoder_inputs = Input(shape=(None, parameters["dec_vocab_size"],)) |
45 | | - decoder_lstm = LSTM(256, return_sequences=True) |
46 | | - decoder_outputs = decoder_lstm(decoder_inputs, initial_state=[encoder_h, encoder_c]) |
47 | | - |
48 | | - attention = dot([decoder_outputs, encoder_outputs], axes=(2, 2)) |
49 | | - attention = Activation('softmax', name='attention')(attention) |
50 | | - context = dot([attention, encoder_outputs], axes=(2, 1)) |
51 | | - decoder_combined_context = concatenate([context, decoder_outputs]) |
52 | | - |
53 | | - output = TimeDistributed(Dense(128, activation="relu"))(decoder_combined_context) |
54 | | - output = TimeDistributed(Dense(parameters["dec_vocab_size"], activation="softmax"))(output) |
55 | | - |
56 | | - model = Model([encoder_inputs, decoder_inputs], [output]) |
57 | | - model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) |
58 | | - |
59 | | - return model |
60 | | - |
61 | | - |
62 | | -def decode(model, parameters, input_texts, allowed_extras, batch_size): |
63 | | - input_texts_c = input_texts.copy() |
64 | | - out_dict = {} |
65 | | - input_sequences = get_text_encodings(input_texts, parameters) |
66 | | - |
67 | | - parameters["reverse_dec_dict"][0] = "\n" |
68 | | - outputs = [""]*len(input_sequences) |
69 | | - |
70 | | - target_text = "\t" |
71 | | - target_seq = parameters["dec_token"].texts_to_sequences([target_text]*len(input_sequences)) |
72 | | - target_seq = pad_sequences(target_seq, maxlen=parameters["max_decoder_seq_length"], |
73 | | - padding="post") |
74 | | - target_seq_hot = to_categorical(target_seq, num_classes=parameters["dec_vocab_size"]) |
75 | | - |
76 | | - extra_char_count = [0]*len(input_texts) |
77 | | - prev_char_index = [0]*len(input_texts) |
78 | | - i = 0 |
79 | | - while len(input_texts) != 0: |
80 | | - curr_char_index = [i - extra_char_count[j] for j in range(len(input_texts))] |
81 | | - input_encodings = np.argmax(input_sequences, axis=2) |
82 | | - |
83 | | - cur_inp_list = [input_encodings[_][curr_char_index[_]] if curr_char_index[_] < len(input_texts[_]) else 0 for _ in range(len(input_texts))] |
84 | | - output_tokens = model.predict([input_sequences, target_seq_hot], batch_size=batch_size) |
85 | | - sampled_possible_indices = np.argsort(output_tokens[:, i, :])[:, ::-1].tolist() |
86 | | - sampled_token_indices = [] |
87 | | - for j, per_char_list in enumerate(sampled_possible_indices): |
88 | | - for index in per_char_list: |
89 | | - if index in allowed_extras: |
90 | | - if parameters["reverse_dec_dict"][index] == '\n' and cur_inp_list[j] != 0: |
91 | | - continue |
92 | | - elif parameters["reverse_dec_dict"][index] != '\n' and prev_char_index[j] in allowed_extras: |
93 | | - continue |
94 | | - sampled_token_indices.append(index) |
95 | | - extra_char_count[j] += 1 |
96 | | - break |
97 | | - elif parameters["enc_token"].word_index[parameters["reverse_dec_dict"][index].lower()] == cur_inp_list[j]: |
98 | | - sampled_token_indices.append(index) |
99 | | - break |
100 | | - |
101 | | - sampled_chars = [parameters["reverse_dec_dict"][index] for index in sampled_token_indices] |
102 | | - |
103 | | - outputs = [outputs[j] + sampled_chars[j] for j, output in enumerate(outputs)] |
104 | | - end_indices = sorted([index for index, char in enumerate(sampled_chars) if char == '\n'], reverse=True) |
105 | | - for index in end_indices: |
106 | | - out_dict[input_texts[index]] = outputs[index].strip() |
107 | | - del outputs[index] |
108 | | - del input_texts[index] |
109 | | - del extra_char_count[index] |
110 | | - del sampled_token_indices[index] |
111 | | - input_sequences = np.delete(input_sequences, index, axis=0) |
112 | | - target_seq = np.delete(target_seq, index, axis=0) |
113 | | - if i == parameters["max_decoder_seq_length"]-1 or len(input_texts) == 0: |
114 | | - break |
115 | | - target_seq[:,i+1] = sampled_token_indices |
116 | | - target_seq_hot = to_categorical(target_seq, num_classes=parameters["dec_vocab_size"]) |
117 | | - prev_char_index = sampled_token_indices |
118 | | - i += 1 |
119 | | - outputs = [out_dict[text] for text in input_texts_c] |
120 | | - return outputs |
121 | | - |
122 | | - |
123 | | -model_links = { |
124 | | - 'en': { |
125 | | - 'checkpoint': 'https://github.com/notAI-tech/fastPunct/releases/download/checkpoint-release/fastpunct_eng_weights.h5', |
126 | | - 'params': 'https://github.com/notAI-tech/fastPunct/releases/download/checkpoint-release/parameter_dict.pkl' |
127 | | - }, |
128 | | - |
129 | | - } |
130 | | - |
131 | | -lang_code_mapping = { |
132 | | - 'english': 'en', |
133 | | - 'french': 'fr', |
134 | | - 'italian': 'it' |
| 4 | +from transformers import T5Tokenizer, T5ForConditionalGeneration |
| 5 | + |
# Per-language download manifest: maps a language name to the checkpoint and
# tokenizer files that make up its T5 model, keyed by the local filename each
# URL should be saved under.
MODEL_URLS = {
    "english": {
        "pytorch_model.bin": "https://github.com/notAI-tech/fastPunct/releases/download/v2/pytorch_model.bin",
        "config.json": "https://github.com/notAI-tech/fastPunct/releases/download/v2/config.json",
        "special_tokens_map.json": "https://github.com/notAI-tech/fastPunct/releases/download/v2/special_tokens_map.json",
        "spiece.model": "https://github.com/notAI-tech/fastPunct/releases/download/v2/spiece.model",
        "tokenizer_config.json": "https://github.com/notAI-tech/fastPunct/releases/download/v2/tokenizer_config.json",
    },
}
136 | 15 |
|
class FastPunct:
    """Punctuation (and optional spelling) restoration via a T5 model.

    On first use the checkpoint/tokenizer files for the requested language
    are downloaded and cached under ``~/.FastPunct_<language>`` (or a
    caller-supplied directory).
    """

    tokenizer = None
    model = None

    def __init__(self, language="english", checkpoint_local_path=None):
        """Load (downloading if necessary) the model for ``language``.

        :param language: key into ``MODEL_URLS`` (currently only "english").
        :param checkpoint_local_path: optional directory that holds (or will
            receive) the checkpoint files; defaults to
            ``~/.FastPunct_<language>``.
        """
        model_name = language.lower()

        if model_name not in MODEL_URLS:
            # Deliberately best-effort: report and bail out instead of
            # raising, leaving an unusable instance — callers are expected
            # to check the printed message.
            print(f"model_name should be one of {list(MODEL_URLS.keys())}")
            return None

        home = os.path.expanduser("~")
        lang_path = os.path.join(home, ".FastPunct_" + model_name)

        if checkpoint_local_path:
            lang_path = checkpoint_local_path

        # makedirs(..., exist_ok=True) instead of mkdir: tolerates a missing
        # parent directory and a concurrent create by another process.
        os.makedirs(lang_path, exist_ok=True)

        # Fetch only the files that are not already cached locally.
        for file_name, url in MODEL_URLS[model_name].items():
            file_path = os.path.join(lang_path, file_name)
            if os.path.exists(file_path):
                continue
            print(f"Downloading {file_name}")
            pydload.dload(url=url, save_to_path=file_path, max_time=None)

        self.tokenizer = T5Tokenizer.from_pretrained(lang_path)
        self.model = T5ForConditionalGeneration.from_pretrained(
            lang_path, return_dict=True
        )

        if torch.cuda.is_available():
            print("Using GPU")
            self.model = self.model.cuda()

    def punct(self, sentences, beam_size=1, max_len=None, correct=False):
        """Punctuate ``sentences`` (a single str or a list of str).

        :param sentences: sentence or list of sentences to punctuate.
        :param beam_size: beam width for generation (forced to 8 when
            ``correct`` is True).
        :param max_len: maximum generated length; derived from the input
            when omitted.
        :param correct: also attempt spelling correction by switching to
            the "correct:" task prefix.
        :returns: a str when a str was passed in, else a list of str.
        """
        return_single = True
        if isinstance(sentences, list):
            return_single = False
        else:
            sentences = [sentences]

        prefix = "punctuate"
        if correct:
            # The correction task needs a wider beam to give usable output.
            beam_size = 8
            prefix = "correct"

        encoded = self.tokenizer(
            [f"{prefix}: {sentence}" for sentence in sentences],
            return_tensors="pt",
            padding=True,
        )
        input_ids = encoded.input_ids
        # Padded batch: keep the attention mask so generation does not
        # attend to padding tokens (otherwise mixed-length batches degrade).
        attention_mask = encoded.attention_mask

        if not max_len:
            # Heuristic output budget: input token length plus one extra
            # token per word (room for inserted punctuation) plus a margin.
            max_len = (
                max(len(tokenized_input) for tokenized_input in input_ids)
                + max(len(s.split()) for s in sentences)
                + 4
            )

        if torch.cuda.is_available():
            input_ids = input_ids.to("cuda")
            attention_mask = attention_mask.to("cuda")

        output_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            num_beams=beam_size,
            max_length=max_len,
        )

        outputs = [
            self.tokenizer.decode(output_id, skip_special_tokens=True)
            for output_id in output_ids
        ]

        if return_single:
            outputs = outputs[0]

        return outputs
0 commit comments