Skip to content
This repository was archived by the owner on Nov 16, 2021. It is now read-only.

Commit c1d6097

Browse files
authored
Merge pull request #26 from notAI-tech/v2
add v2 model
2 parents 512ca5d + d78d6b9 commit c1d6097

File tree

3 files changed

+96
-194
lines changed

3 files changed

+96
-194
lines changed

README.md

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,14 @@ pip install --upgrade fastpunct
77
```
88

99
# Supported languages:
10-
en - english
10+
english
1111

1212
# Usage:
1313

1414
```python
1515
from fastpunct import FastPunct
16-
# The default language is 'en'
17-
fastpunct = FastPunct('en')
18-
fastpunct.punct(["oh i thought you were here", "in theory everyone knows what a comma is", "hey how are you doing", "my name is sheela i am in love with hrithik"], batch_size=32)
19-
# ['Oh! I thought you were here.', 'In theory, everyone knows what a comma is.', 'Hey! How are you doing?', 'My name is Sheela. I am in love with Hrithik.']
20-
16+
# The default language is 'english'
17+
fastpunct = FastPunct()
18+
fastpunct.punct(["john smiths dog is creating a ruccus", "ys jagan is the chief minister of andhra pradesh", "we visted new york last year in may"])
19+
# ["John Smith's dog is creating a ruccus.", 'Ys Jagan is the chief minister of Andhra Pradesh.', 'We visted New York last year in May.']
2120
```
22-
# Note:
23-
maximum length of input currently supported - 400

fastpunct/fastpunct.py

Lines changed: 88 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -1,187 +1,95 @@
1-
# -*- coding: utf-8 -*-
2-
"""
3-
Created on Sun May 10 15:46:01 2020
4-
5-
@author: harikodali
6-
"""
71
# -*- coding: utf-8 -*-
"""fastpunct v2: punctuation restoration / spell correction with a T5 model."""
# Imports grouped PEP 8 style: stdlib first, then third-party.
import os

import pydload
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Checkpoint download locations, keyed by language name.
# Each inner dict maps the file names HuggingFace `from_pretrained`
# expects to the GitHub release URL they are fetched from.
MODEL_URLS = {
    "english": {
        "pytorch_model.bin": "https://github.com/notAI-tech/fastPunct/releases/download/v2/pytorch_model.bin",
        "config.json": "https://github.com/notAI-tech/fastPunct/releases/download/v2/config.json",
        "special_tokens_map.json": "https://github.com/notAI-tech/fastPunct/releases/download/v2/special_tokens_map.json",
        "spiece.model": "https://github.com/notAI-tech/fastPunct/releases/download/v2/spiece.model",
        "tokenizer_config.json": "https://github.com/notAI-tech/fastPunct/releases/download/v2/tokenizer_config.json",
    },
}

137-
class FastPunct:
    """Restore punctuation and casing in plain lower-cased text (and,
    with ``correct=True``, also attempt spell correction) using a
    T5 sequence-to-sequence model fine-tuned per language.
    """

    # Populated in __init__; declared here so the attributes are discoverable.
    tokenizer = None
    model = None

    def __init__(self, language="english", checkpoint_local_path=None):
        """Download (if not already cached) and load the model for *language*.

        :param language: one of the keys of ``MODEL_URLS`` (default "english");
            case-insensitive.
        :param checkpoint_local_path: directory holding the checkpoint files;
            missing files are downloaded into it. When omitted, files are
            cached under ``~/.FastPunct_<language>``.
        :raises ValueError: if *language* is not supported.
        """
        model_name = language.lower()

        if model_name not in MODEL_URLS:
            # Fail fast: a half-initialised instance would only crash later
            # (the original printed a message and returned, leaving
            # self.model/self.tokenizer as None).
            raise ValueError(
                f"model_name should be one of {list(MODEL_URLS.keys())}"
            )

        home = os.path.expanduser("~")
        lang_path = os.path.join(home, ".FastPunct_" + model_name)

        if checkpoint_local_path:
            lang_path = checkpoint_local_path

        # makedirs + exist_ok: creates missing parents for a user-supplied
        # path and avoids the check-then-create race os.mkdir had.
        os.makedirs(lang_path, exist_ok=True)

        for file_name, url in MODEL_URLS[model_name].items():
            file_path = os.path.join(lang_path, file_name)
            if os.path.exists(file_path):
                continue  # already cached by a previous run
            print(f"Downloading {file_name}")
            pydload.dload(url=url, save_to_path=file_path, max_time=None)

        self.tokenizer = T5Tokenizer.from_pretrained(lang_path)
        self.model = T5ForConditionalGeneration.from_pretrained(
            lang_path, return_dict=True
        )

        if torch.cuda.is_available():
            print("Using GPU")  # was an f-string with no placeholder
            self.model = self.model.cuda()

    def punct(self, sentences, beam_size=1, max_len=None, correct=False):
        """Punctuate *sentences*; with ``correct=True`` also spell-correct.

        :param sentences: a single string or a list of strings.
        :param beam_size: beam width for generation. Forced to 8 when
            ``correct=True`` (matching the released behaviour).
        :param max_len: cap on the generated sequence length; when omitted,
            derived from the tokenised input length.
        :param correct: switch the task prefix from "punctuate:" to
            "correct:".
        :return: a string when a single string was passed in, otherwise a
            list of strings in input order.
        """
        return_single = not isinstance(sentences, list)
        if return_single:
            sentences = [sentences]

        prefix = "punctuate"
        if correct:
            beam_size = 8
            prefix = "correct"

        input_ids = self.tokenizer(
            [f"{prefix}: {sentence}" for sentence in sentences],
            return_tensors="pt",
            padding=True,
        ).input_ids

        if not max_len:
            # Generous bound: longest tokenised input, plus one extra token
            # per word (room for inserted punctuation), plus a little slack.
            max_len = (
                max(len(tokenized_input) for tokenized_input in input_ids)
                + max(len(s.split()) for s in sentences)
                + 4
            )

        if torch.cuda.is_available():
            input_ids = input_ids.to("cuda")

        output_ids = self.model.generate(
            input_ids, num_beams=beam_size, max_length=max_len
        )

        outputs = [
            self.tokenizer.decode(output_id, skip_special_tokens=True)
            for output_id in output_ids
        ]

        return outputs[0] if return_single else outputs

setup.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,15 @@
1414

1515
# Package meta-data.
1616
NAME = 'fastpunct'
17-
DESCRIPTION = 'Punctuation restoration with sequence to sequence networks'
17+
DESCRIPTION = 'Punctuation restoration and spell correction.'
1818
URL = 'https://github.com/notAI-tech/fastPunct'
1919
2020
AUTHOR = 'Hari Krishna Sai Kodali'
21-
REQUIRES_PYTHON = '>=3.5.0'
21+
REQUIRES_PYTHON = '>=3.6.0'
2222
VERSION = subprocess.run(['git', 'describe', '--tags'], stdout=subprocess.PIPE).stdout.decode("utf-8").strip()
2323

2424
# What packages are required for this module to be executed?
25-
REQUIRED = [
26-
'numpy',
27-
'pydload'
28-
]
25+
REQUIRED = ["transformers>=4.0.0rc1", "pydload>=1.0.9", "torch>=1.5.0", "sentencepiece"]
2926

3027
# What packages are optional?
3128
EXTRAS = {

0 commit comments

Comments
 (0)