-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsummarization_engine.py
More file actions
254 lines (215 loc) · 8.19 KB
/
summarization_engine.py
File metadata and controls
254 lines (215 loc) · 8.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import logging
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from sentence_transformers import SentenceTransformer
# NOTE(review): basicConfig at import time configures the *root* logger for any
# process that imports this module — conventionally this belongs in the
# application entry point, not a library module. Left as-is to preserve
# observable behavior for scripts that rely on it.
logging.basicConfig(level=logging.INFO)
# Module-level logger, namespaced to this module per stdlib convention.
logger = logging.getLogger(__name__)
class EnhancedSummarizationEngine:
    """
    Enhanced summarization engine using sentence transformers and the Maximal
    Marginal Relevance (MMR) algorithm for query-focused extractive
    summarization with better semantic understanding.
    """

    # Abbreviations whose trailing period(s) must not terminate a sentence.
    _ABBREVIATIONS = ("Ph.D.", "Mr.", "Mrs.", "Dr.", "e.g.", "i.e.")
    # Sentinel character used to temporarily mask abbreviation periods while
    # splitting; restored afterwards so the original text is preserved.
    _PERIOD_MASK = "\x00"

    def __init__(
        self,
        model_name: str = 'paraphrase-MiniLM-L6-v2',
        lambda_param: float = 0.7,
        top_k: int = 3,
        device: Optional[str] = None
    ):
        """
        Initialize the summarization engine.

        Args:
            model_name: Name of the sentence transformer model to use
            lambda_param: Weight parameter for MMR algorithm (0 to 1);
                higher values favor query relevance over diversity
            top_k: Number of sentences to include in summary
            device: Device to run the model on ('cuda' or 'cpu');
                None lets sentence-transformers pick automatically
        """
        self.device = device
        self.encoder = SentenceTransformer(model_name, device=device)
        self.lambda_param = lambda_param
        self.top_k = top_k

    def _split_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences using the period as delimiter while keeping
        common abbreviations intact.

        Unlike a naive split, abbreviation periods are masked before splitting
        and restored afterwards, so the returned sentences contain the
        original text verbatim (e.g. "Dr." stays "Dr.", not "Dr").
        """
        # Mask abbreviation periods so they survive the split.
        for abbr in self._ABBREVIATIONS:
            text = text.replace(abbr, abbr.replace(".", self._PERIOD_MASK))
        # Split on the remaining (sentence-final) periods, drop empty
        # fragments, restore masked periods, and re-append the terminator.
        return [
            fragment.strip().replace(self._PERIOD_MASK, ".") + "."
            for fragment in text.split(".")
            if fragment.strip()
        ]

    def _vectorize_sentences(
        self,
        sentences: List[str],
        query: Optional[str] = None
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Encode sentences and query using the sentence transformer.

        Args:
            sentences: List of sentences to encode
            query: Optional query to encode

        Returns:
            Tuple of sentence embeddings and query embedding
            (query embedding is None when no query is provided)
        """
        sentence_embeddings = self.encoder.encode(
            sentences,
            convert_to_tensor=True,
            show_progress_bar=False
        ).cpu().numpy()
        if query:
            query_embedding = self.encoder.encode(
                query,
                convert_to_tensor=True,
                show_progress_bar=False
            ).cpu().numpy()
            return sentence_embeddings, query_embedding
        return sentence_embeddings, None

    def _compute_similarities(
        self,
        query_embedding: np.ndarray,
        sentence_embeddings: np.ndarray
    ) -> np.ndarray:
        """
        Compute cosine similarities between query and sentences.

        Args:
            query_embedding: Query vector, shape (d,) or (1, d)
            sentence_embeddings: Sentence matrix, shape (n, d)

        Returns:
            Array of cosine similarities, one per sentence row.
        """
        dot_product = np.dot(query_embedding, sentence_embeddings.T)
        query_norm = np.linalg.norm(query_embedding)
        sentence_norms = np.linalg.norm(sentence_embeddings, axis=1)
        # Guard against zero-norm vectors: substitute a tiny denominator so a
        # degenerate embedding yields similarity 0 instead of NaN/inf.
        denom = query_norm * sentence_norms
        denom = np.where(denom == 0.0, 1e-12, denom)
        return dot_product / denom

    def _select_mmr_sentences(
        self,
        similarities: np.ndarray,
        sentence_embeddings: np.ndarray,
        n_sentences: int
    ) -> List[int]:
        """
        Select sentences using the Maximal Marginal Relevance algorithm.

        MMR score = lambda * relevance - (1 - lambda) * max redundancy,
        trading off query relevance against similarity to already-selected
        sentences.

        Args:
            similarities: Cosine similarities between query and sentences
            sentence_embeddings: Encoded sentences, shape (n, d)
            n_sentences: Number of sentences to select

        Returns:
            List of selected sentence indices (plain Python ints).
        """
        selected_indices: List[int] = []
        unselected_indices = list(range(len(similarities)))
        # Seed with the sentence most similar to the query.
        # Cast: np.argmax returns a numpy integer, which is not
        # JSON-serializable and would leak into the summarize() result.
        first_idx = int(np.argmax(similarities))
        selected_indices.append(first_idx)
        unselected_indices.remove(first_idx)
        # Greedily add the remaining sentences by MMR score.
        while len(selected_indices) < n_sentences and unselected_indices:
            mmr_scores = []
            for idx in unselected_indices:
                query_sim = similarities[idx]
                # Redundancy = max cosine similarity to any selected sentence.
                selected_embeddings = sentence_embeddings[selected_indices]
                current_embedding = sentence_embeddings[idx].reshape(1, -1)
                redundancy_sims = self._compute_similarities(
                    current_embedding,
                    selected_embeddings
                )
                max_redundancy = np.max(redundancy_sims)
                mmr = self.lambda_param * query_sim - \
                    (1 - self.lambda_param) * max_redundancy
                mmr_scores.append(mmr)
            next_idx = unselected_indices[int(np.argmax(mmr_scores))]
            selected_indices.append(next_idx)
            unselected_indices.remove(next_idx)
        return selected_indices

    def summarize(
        self,
        text: str,
        query: Optional[str] = None,
        custom_k: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Generate a query-focused summary of the input text.

        Args:
            text: Input text to summarize
            query: Optional query to focus the summary; a generic default
                query is used when omitted or empty
            custom_k: Optional override for number of sentences

        Returns:
            Dictionary with keys "summary", "original_sentences",
            "selected_sentences", "selected_indices", "similarity_scores";
            on failure, a dictionary with "summary" (empty) and "error".
        """
        try:
            sentences = self._split_into_sentences(text)
            if len(sentences) == 0:
                return {
                    "summary": "",
                    "error": "No valid sentences found in input text."
                }
            # Fall back to a generic query so similarity still ranks
            # sentences by overall salience.
            if not query:
                query = "What is the main point of this text?"
            sentence_embeddings, query_embedding = self._vectorize_sentences(
                sentences,
                query
            )
            similarities = self._compute_similarities(
                query_embedding,
                sentence_embeddings
            )
            k = custom_k if custom_k is not None else self.top_k
            # Clamp: at least one sentence, at most the sentence count
            # (guards against custom_k <= 0 as well as oversized k).
            k = max(1, min(k, len(sentences)))
            selected_indices = self._select_mmr_sentences(
                similarities,
                sentence_embeddings,
                k
            )
            # Present sentences in their original document order.
            selected_indices.sort()
            summary_sentences = [sentences[idx] for idx in selected_indices]
            summary = " ".join(summary_sentences)
            return {
                "summary": summary,
                "original_sentences": len(sentences),
                "selected_sentences": len(summary_sentences),
                "selected_indices": selected_indices,
                "similarity_scores": similarities[selected_indices].tolist()
            }
        except Exception as e:
            logger.error(f"Error during summarization: {str(e)}")
            return {
                "summary": "",
                "error": f"Summarization failed: {str(e)}"
            }

    def update_parameters(
        self,
        lambda_param: Optional[float] = None,
        top_k: Optional[int] = None
    ) -> None:
        """
        Update the engine's parameters.

        Args:
            lambda_param: New lambda parameter for MMR (must be in [0, 1])
            top_k: New number of sentences for summary (must be > 0)

        Raises:
            ValueError: If either value is outside its valid range.
        """
        if lambda_param is not None:
            if 0 <= lambda_param <= 1:
                self.lambda_param = lambda_param
            else:
                raise ValueError("lambda_param must be between 0 and 1")
        if top_k is not None:
            if top_k > 0:
                self.top_k = top_k
            else:
                raise ValueError("top_k must be greater than 0")