Skip to content

Commit 86f1e2a

Browse files
committed
Add sequence confidence to pretranslations
1 parent b9219fc commit 86f1e2a

File tree

9 files changed

+93
-37
lines changed

9 files changed

+93
-37
lines changed

machine/jobs/nmt_engine_build_job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def _batch_inference(
115115
check_canceled()
116116
for i, result in enumerate(engine.translate_batch(seg_batch)):
117117
pretranslations[current_inference_step + i]["translation"] = result.translation
118+
pretranslations[current_inference_step + i]["sequenceConfidence"] = (-1 if result.sequence_confidence is None else result.sequence_confidence)
118119
current_inference_step += len(seg_batch)
119120
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
120121

machine/jobs/translation_file_service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class PretranslationInfo(TypedDict):
2020
sourceTokens: List[str] # noqa: N815
2121
translationTokens: List[str] # noqa: N815
2222
alignment: str
23+
sequenceConfidence: float # noqa: N815
2324

2425

2526
class TranslationFileService:
@@ -98,6 +99,7 @@ def generator() -> Generator[PretranslationInfo, None, None]:
9899
sourceTokens=list(),
99100
translationTokens=list(),
100101
alignment="",
102+
sequenceConfidence=0,
101103
)
102104

103105
return ContextManagedGenerator(generator())

machine/translation/huggingface/hugging_face_nmt_engine.py

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def _try_translate_n_batch(
164164
builder = TranslationResultBuilder(input_tokens)
165165
for token, score in zip(output["translation_tokens"], output["token_scores"]):
166166
builder.append_token(token, TranslationSources.NMT, exp(score))
167+
if output["sequence_score"] is not None:
    builder.set_sequence_confidence(exp(output["sequence_score"]))
167168
word_pairs: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None
168169
if output.get("token_attentions") is not None:
169170
src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
@@ -257,36 +258,56 @@ def _forward(self, model_inputs, **generate_kwargs):
257258
output_ids = output.sequences
258259
beam_indices = output.beam_indices
259260
scores = output.scores
261+
assert scores is not None and beam_indices is not None
262+
sequences_scores = output.sequences_scores
260263
attentions = output.cross_attentions
261264
elif isinstance(output, GreedySearchEncoderDecoderOutput):
262265
output_ids = output.sequences
263-
beam_indices = torch.zeros_like(output_ids)
266+
beam_indices = None
264267
assert output.scores is not None
265-
scores = tuple(torch.nn.functional.log_softmax(logits, dim=-1) for logits in output.scores)
268+
scores = output.scores
269+
sequences_scores = None
266270
attentions = output.cross_attentions
267271
else:
268272
raise RuntimeError("Cannot postprocess the output of the model.")
269273

270-
assert beam_indices is not None and scores is not None
271-
out_b = output_ids.shape[0]
274+
transition_scores = cast(
275+
torch.Tensor,
276+
self.model.compute_transition_scores(
277+
output_ids, # type: ignore
278+
scores, # type: ignore
279+
beam_indices, # type: ignore
280+
normalize_logits=True,
281+
),
282+
)
283+
284+
if beam_indices is None:
285+
beam_indices = torch.zeros_like(output_ids)
286+
287+
out_b, seq_len = output_ids.shape
272288
num_beams = scores[0].shape[0] // in_b
273289
n_sequences = out_b // in_b
290+
291+
ts_len = transition_scores.shape[1]
292+
if ts_len == seq_len:
293+
token_logprobs = transition_scores
294+
elif ts_len == seq_len - 1:
295+
token_logprobs = torch.cat(
296+
[
297+
torch.zeros(out_b, 1, device=transition_scores.device, dtype=transition_scores.dtype),
298+
transition_scores,
299+
],
300+
dim=1,
301+
)
302+
else:
303+
raise RuntimeError(
304+
f"Unexpected transition_scores length {ts_len} for sequences length {seq_len}. "
305+
"Cannot align token scores robustly."
306+
)
307+
274308
start_index = 0
275309
if self.model.config.decoder_start_token_id is not None:
276310
start_index = 1
277-
indices = torch.stack(
278-
(
279-
torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1),
280-
torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
281-
torch.reshape(output_ids[:, start_index:], (in_b, n_sequences, -1)),
282-
),
283-
dim=3,
284-
)
285-
scores = torch.stack(scores, dim=0).reshape(len(scores), in_b, num_beams, -1).transpose(0, 1)
286-
scores = torch_gather_nd(scores, indices, 1)
287-
if self.model.config.decoder_start_token_id is not None:
288-
scores = torch.cat((torch.zeros(scores.shape[0], scores.shape[1], 1, device=scores.device), scores), dim=2)
289-
290311
if generate_kwargs["output_attentions"] is True:
291312
assert attentions is not None
292313
num_heads = attentions[0][0].shape[1]
@@ -320,13 +341,15 @@ def _forward(self, model_inputs, **generate_kwargs):
320341
),
321342
dim=2,
322343
)
344+
output_ids = output_ids.reshape(in_b, n_sequences, seq_len)
345+
token_logprobs = token_logprobs.reshape(in_b, n_sequences, seq_len)
323346

324-
output_ids = output_ids.reshape(in_b, n_sequences, *output_ids.shape[1:])
325347
return {
326348
"input_ids": model_inputs["input_ids"],
327349
"input_tokens": input_tokens,
328350
"output_ids": output_ids,
329-
"scores": scores,
351+
"scores": token_logprobs,
352+
"sequences_scores": sequences_scores,
330353
"attentions": attentions,
331354
}
332355

@@ -346,24 +369,17 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
346369
records = []
347370

348371
has_attentions = model_outputs.get("attentions") is not None and model_outputs["attentions"][0] is not None
349-
if has_attentions:
350-
zipped = zip(
351-
model_outputs["output_ids"][0],
352-
model_outputs["scores"][0],
353-
model_outputs["attentions"][0],
354-
)
355-
else:
356-
zipped = zip(
357-
model_outputs["output_ids"][0],
358-
model_outputs["scores"][0],
359-
)
360-
372+
has_sequence_scores = model_outputs["sequences_scores"] is not None
373+
zipped = zip(
374+
model_outputs["output_ids"][0],
375+
model_outputs["scores"][0],
376+
model_outputs["sequences_scores"] if has_sequence_scores else iter(lambda: None, 1),
377+
model_outputs["attentions"][0] if has_attentions else iter(lambda: None, 1),
378+
)
361379
for item in zipped:
362-
if has_attentions:
363-
output_ids, scores, attentions = cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], item)
364-
else:
365-
output_ids, scores = cast(Tuple[torch.Tensor, torch.Tensor], item)
366-
attentions = None
380+
output_ids, scores, sequence_score, attentions = cast(
381+
Tuple[torch.Tensor, torch.Tensor, Optional[float], Optional[torch.Tensor]], item
382+
)
367383

368384
output_tokens: List[str] = []
369385
output_indices: List[int] = []
@@ -379,6 +395,7 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
379395
"input_tokens": input_tokens,
380396
"translation_tokens": output_tokens,
381397
"token_scores": scores,
398+
"sequence_score": sequence_score,
382399
"translation_text": self.tokenizer.decode(
383400
output_ids,
384401
skip_special_tokens=True,

machine/translation/translation_result.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable, Sequence
1+
from typing import Iterable, Optional, Sequence
22

33
from .phrase import Phrase
44
from .translation_sources import TranslationSources
@@ -12,6 +12,7 @@ def __init__(
1212
source_tokens: Iterable[str],
1313
target_tokens: Iterable[str],
1414
confidences: Iterable[float],
15+
sequence_confidence: Optional[float],
1516
sources: Iterable[TranslationSources],
1617
alignment: WordAlignmentMatrix,
1718
phrases: Iterable[Phrase],
@@ -20,6 +21,7 @@ def __init__(
2021
self._source_tokens = list(source_tokens)
2122
self._target_tokens = list(target_tokens)
2223
self._confidences = list(confidences)
24+
self._sequence_confidence = sequence_confidence
2325
self._sources = list(sources)
2426
self._alignment = alignment
2527
self._phrases = list(phrases)
@@ -49,6 +51,10 @@ def target_tokens(self) -> Sequence[str]:
4951
def confidences(self) -> Sequence[float]:
5052
return self._confidences
5153

54+
@property
55+
def sequence_confidence(self) -> Optional[float]:
56+
return self._sequence_confidence
57+
5258
@property
5359
def sources(self) -> Sequence[TranslationSources]:
5460
return self._sources

machine/translation/translation_result_builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
self._confidences: List[float] = []
2929
self._sources: List[TranslationSources] = []
3030
self._phrases: List[PhraseInfo] = []
31+
self._sequence_confidence: Optional[float] = None
3132

3233
@property
3334
def source_tokens(self) -> Sequence[str]:
@@ -49,6 +50,10 @@ def sources(self) -> Sequence[TranslationSources]:
4950
def phrases(self) -> Sequence[PhraseInfo]:
5051
return self._phrases
5152

53+
@property
54+
def sequence_confidence(self) -> Optional[float]:
55+
return self._sequence_confidence
56+
5257
def append_token(self, token: str, source: TranslationSources, confidence: float) -> None:
5358
self._target_tokens.append(token)
5459
self._sources.append(source)
@@ -60,6 +65,9 @@ def mark_phrase(self, source_segment_range: Range[int], alignment: WordAlignment
6065
def set_confidence(self, index: int, confidence: float) -> None:
6166
self._confidences[index] = confidence
6267

68+
def set_sequence_confidence(self, sequence_confidence: float):
69+
self._sequence_confidence = sequence_confidence
70+
6371
def correct_prefix(
6472
self,
6573
word_ops: Iterable[EditOperation],
@@ -165,6 +173,7 @@ def to_result(self, translation: Optional[str] = None) -> TranslationResult:
165173
self._source_tokens,
166174
self._target_tokens,
167175
self._confidences,
176+
self._sequence_confidence,
168177
sources,
169178
alignment,
170179
phrases,

machine/translation/truecaser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def truecase_translation_result(
2929
result.source_tokens,
3030
target_tokens,
3131
result.confidences,
32+
result.sequence_confidence,
3233
result.sources,
3334
result.alignment,
3435
result.phrases,

tests/jobs/test_nmt_engine_build_job.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ def test_run(decoy: Decoy) -> None:
5050
]
5151
assert pretranslations[0]["translationTokens"] == ["Please", ",", "I", "have", "booked", "a", "room", "."]
5252
assert len(pretranslations[0]["alignment"]) > 0
53+
assert pretranslations[0]["sequenceConfidence"] == 0.5
5354
else:
5455
assert pretranslations[0]["sourceTokens"] == []
5556
assert pretranslations[0]["translationTokens"] == []
5657
assert len(pretranslations[0]["alignment"]) == 0
58+
assert pretranslations[0]["sequenceConfidence"] == 0.5
5759
decoy.verify(env.translation_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1)
5860

5961

@@ -86,6 +88,7 @@ def __init__(self, decoy: Decoy) -> None:
8688
source_tokens="Por favor , tengo reservada una habitación .".split(),
8789
target_tokens="Please , I have booked a room .".split(),
8890
confidences=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
91+
sequence_confidence=0.5,
8992
sources=[
9093
TranslationSources.NMT,
9194
TranslationSources.NMT,
@@ -135,6 +138,7 @@ def __init__(self, decoy: Decoy) -> None:
135138
sourceTokens=[],
136139
translationTokens=[],
137140
alignment="",
141+
sequenceConfidence=0.5,
138142
)
139143
]
140144
)

tests/jobs/test_smt_engine_build_job.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(self, decoy: Decoy) -> None:
6565
source_tokens="Por favor , tengo reservada una habitación .".split(),
6666
target_tokens="Please , I have booked a room .".split(),
6767
confidences=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
68+
sequence_confidence=0.5,
6869
sources=[
6970
TranslationSources.SMT,
7071
TranslationSources.SMT,
@@ -140,6 +141,7 @@ def __init__(self, decoy: Decoy) -> None:
140141
sourceTokens=[],
141142
translationTokens=[],
142143
alignment="",
144+
sequenceConfidence=0.5,
143145
)
144146
]
145147
)

tests/translation/huggingface/test_hugging_face_nmt_engine.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pytest import approx, mark, raises
99

1010
from machine.translation.huggingface import HuggingFaceNmtEngine
11+
from machine.translation.translation_result import TranslationResult
1112

1213

1314
@mark.parametrize("output_attentions", [True, False])
@@ -26,16 +27,23 @@ def test_translate_n_batch_beam(output_attentions: bool) -> None:
2627
)
2728
assert results[0][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir"
2829
assert results[0][0].confidences[0] == approx(1.08e-05, 0.01)
30+
assert results[0][0].sequence_confidence == approx(_get_sequence_confidence(results[0][0]), 0.01)
2931
assert str(results[0][0].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
32+
3033
assert results[0][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir"
3134
assert results[0][1].confidences[0] == approx(1.08e-05, 0.01)
35+
assert results[0][1].sequence_confidence == approx(_get_sequence_confidence(results[0][1]), 0.01)
3236
assert str(results[0][1].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
37+
3338
assert results[1][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir"
3439
assert results[1][0].confidences[0] == approx(1.08e-05, 0.01)
3540
assert str(results[1][0].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "")
41+
assert results[1][0].sequence_confidence == approx(_get_sequence_confidence(results[1][0]), 0.01)
42+
3643
assert results[1][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir"
3744
assert results[1][1].confidences[0] == approx(1.08e-05, 0.01)
3845
assert str(results[1][1].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "")
46+
assert results[1][1].sequence_confidence == approx(_get_sequence_confidence(results[1][1]), 0.01)
3947

4048

4149
@mark.parametrize("output_attentions", [True, False])
@@ -46,10 +54,16 @@ def test_translate_greedy(output_attentions: bool) -> None:
4654
result = engine.translate("This is a test string")
4755
assert result.translation == "skaberskaber Dollar Dollar Dollar ፤ gerekir gerekir"
4856
assert result.confidences[0] == approx(1.08e-05, 0.01)
57+
assert result.sequence_confidence == approx(_get_sequence_confidence(result), 0.01)
4958
assert str(result.alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
5059

5160

5261
@mark.parametrize("output_attentions", [True, False])
5362
def test_construct_invalid_lang(output_attentions: bool) -> None:
5463
with raises(ValueError):
5564
HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="qaa", tgt_lang="es", output_attentions=output_attentions)
65+
66+
67+
def _get_sequence_confidence(result: TranslationResult) -> float:
68+
# Inject a 0 score for the BOS token
69+
return sum(list(result.confidences) + [0]) / (len(result.confidences) + 1)

0 commit comments

Comments
 (0)