Skip to content

Commit 3774258

Browse files
committed
Refactor to do alignment during the inference step
1 parent 123263f commit 3774258

File tree

4 files changed

+62
-91
lines changed

4 files changed

+62
-91
lines changed

machine/jobs/nmt_engine_build_job.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
22
from contextlib import ExitStack
3+
from pathlib import Path
4+
from tempfile import TemporaryDirectory
35
from typing import Any, Callable, Optional, Sequence, Tuple
46

57
from ..corpora.corpora_utils import batch
@@ -8,6 +10,7 @@
810
from ..translation.translation_engine import TranslationEngine
911
from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
1012
from ..utils.progress_status import ProgressStatus
13+
from .eflomal_aligner import EflomalAligner, is_eflomal_available, tokenize
1114
from .nmt_model_factory import NmtModelFactory
1215
from .shared_file_service_base import DictToJsonWriter
1316
from .translation_engine_build_job import TranslationEngineBuildJob
@@ -102,18 +105,66 @@ def _batch_inference(
102105
with ExitStack() as stack:
103106
phase_progress = stack.enter_context(progress_reporter.start_next_phase())
104107
engine = stack.enter_context(self._nmt_model_factory.create_engine())
105-
src_pretranslations = stack.enter_context(self._translation_file_service.get_source_pretranslations())
106-
writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer())
108+
pretranslations = [
109+
pt_info for pt_info in stack.enter_context(self._translation_file_service.get_source_pretranslations())
110+
]
111+
src_segments = [pt_info["translation"] for pt_info in pretranslations]
107112
current_inference_step = 0
108113
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
109114
batch_size = self._config["inference_batch_size"]
110-
for pi_batch in batch(src_pretranslations, batch_size):
115+
for seg_batch in batch(iter(src_segments), batch_size):
111116
if check_canceled is not None:
112117
check_canceled()
113-
_translate_batch(engine, pi_batch, writer)
114-
current_inference_step += len(pi_batch)
118+
for i, result in enumerate(engine.translate_batch(seg_batch)):
119+
pretranslations[current_inference_step + i]["translation"] = result.translation
120+
current_inference_step += len(seg_batch)
115121
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
116122

123+
if self._config.align_pretranslations and is_eflomal_available():
124+
logger.info("Aligning source to pretranslations")
125+
pretranslations = self._align(src_segments, pretranslations, progress_reporter, check_canceled)
126+
127+
writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer())
128+
for pretranslation in pretranslations:
129+
writer.write(pretranslation)
130+
131+
def _align(
132+
self,
133+
src_segments: Sequence[str],
134+
pretranslations: Sequence[PretranslationInfo],
135+
progress_reporter: PhasedProgressReporter,
136+
check_canceled: Optional[Callable[[], None]],
137+
) -> Sequence[PretranslationInfo]:
138+
if check_canceled is not None:
139+
check_canceled()
140+
141+
logger.info("Aligning source to pretranslations")
142+
progress_reporter.start_next_phase()
143+
144+
src_tokenized = [tokenize(s) for s in src_segments]
145+
trg_tokenized = [tokenize(pt_info["translation"]) for pt_info in pretranslations]
146+
147+
with TemporaryDirectory() as td:
148+
aligner = EflomalAligner(Path(td))
149+
logger.info("Training aligner")
150+
aligner.train(src_tokenized, trg_tokenized)
151+
152+
if check_canceled is not None:
153+
check_canceled()
154+
155+
logger.info("Aligning pretranslations")
156+
alignments = aligner.align()
157+
158+
if check_canceled is not None:
159+
check_canceled()
160+
161+
for i in range(len(pretranslations)):
162+
pretranslations[i]["source_toks"] = list(src_tokenized[i])
163+
pretranslations[i]["translation_toks"] = list(trg_tokenized[i])
164+
pretranslations[i]["alignment"] = alignments[i]
165+
166+
return pretranslations
167+
117168
def _save_model(self) -> None:
118169
if "save_model" in self._config and self._config.save_model is not None:
119170
logger.info("Saving model")
Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
import logging
22
from abc import ABC, abstractmethod
3-
from contextlib import ExitStack
4-
from pathlib import Path
5-
from tempfile import TemporaryDirectory
63
from typing import Any, Callable, Optional, Tuple
74

85
from ..corpora.parallel_text_corpus import ParallelTextCorpus
96
from ..corpora.text_corpus import TextCorpus
107
from ..utils.phased_progress_reporter import PhasedProgressReporter
118
from ..utils.progress_status import ProgressStatus
12-
from .eflomal_aligner import EflomalAligner, is_eflomal_available, tokenize
13-
from .translation_file_service import PretranslationInfo, TranslationFileService
9+
from .translation_file_service import TranslationFileService
1410

1511
logger = logging.getLogger(__name__)
1612

@@ -48,10 +44,6 @@ def run(
4844
logger.info("Pretranslating segments")
4945
self._batch_inference(progress_reporter, check_canceled)
5046

51-
if self._config.align_pretranslations and is_eflomal_available():
52-
logger.info("Aligning source to pretranslations")
53-
self._align(progress_reporter, check_canceled)
54-
5547
self._save_model()
5648
return train_corpus_size, confidence
5749

@@ -82,55 +74,5 @@ def _batch_inference(
8274
check_canceled: Optional[Callable[[], None]],
8375
) -> None: ...
8476

85-
def _align(
86-
self,
87-
progress_reporter: PhasedProgressReporter,
88-
check_canceled: Optional[Callable[[], None]],
89-
) -> None:
90-
if check_canceled is not None:
91-
check_canceled()
92-
93-
logger.info("Aligning source to pretranslations")
94-
with ExitStack() as stack:
95-
# phase_progress = stack.enter_context(progress_reporter.start_next_phase())
96-
progress_reporter.start_next_phase()
97-
98-
src_tokenized = [
99-
tokenize(s["translation"])
100-
for s in stack.enter_context(self._translation_file_service.get_source_pretranslations())
101-
]
102-
trg_info = [
103-
pt_info for pt_info in stack.enter_context(self._translation_file_service.get_target_pretranslations())
104-
]
105-
trg_tokenized = [tokenize(pt_info["translation"]) for pt_info in trg_info]
106-
107-
with TemporaryDirectory() as td:
108-
aligner = EflomalAligner(Path(td))
109-
logger.info("Training aligner")
110-
aligner.train(src_tokenized, trg_tokenized)
111-
112-
if check_canceled is not None:
113-
check_canceled()
114-
115-
logger.info("Aligning pretranslations")
116-
alignments = aligner.align()
117-
118-
if check_canceled is not None:
119-
check_canceled()
120-
121-
writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer())
122-
for trg_pi, src_toks, trg_toks, alignment in zip(trg_info, src_tokenized, trg_tokenized, alignments):
123-
writer.write(
124-
PretranslationInfo(
125-
corpusId=trg_pi["corpusId"],
126-
textId=trg_pi["textId"],
127-
refs=trg_pi["refs"],
128-
translation=trg_pi["translation"],
129-
source_toks=list(src_toks),
130-
translation_toks=list(trg_toks),
131-
alignment=alignment,
132-
)
133-
)
134-
13577
@abstractmethod
13678
def _save_model(self) -> None: ...

machine/jobs/translation_file_service.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,13 @@ def exists_source_corpus(self) -> bool:
5252
def exists_target_corpus(self) -> bool:
5353
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{TARGET_FILENAME}")
5454

55-
def _get_pretranslations(self, filename: str) -> ContextManagedGenerator[PretranslationInfo, None, None]:
56-
pretranslate_path = self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{filename}")
55+
def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
56+
src_pretranslate_path = self.shared_file_service.download_file(
57+
f"{self.shared_file_service.build_path}/{SOURCE_PRETRANSLATION_FILENAME}"
58+
)
5759

5860
def generator() -> Generator[PretranslationInfo, None, None]:
59-
with pretranslate_path.open("r", encoding="utf-8-sig") as file:
61+
with src_pretranslate_path.open("r", encoding="utf-8-sig") as file:
6062
for pi in json_stream.load(file):
6163
yield PretranslationInfo(
6264
corpusId=pi["corpusId"],
@@ -70,12 +72,6 @@ def generator() -> Generator[PretranslationInfo, None, None]:
7072

7173
return ContextManagedGenerator(generator())
7274

73-
def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
74-
return self._get_pretranslations(SOURCE_PRETRANSLATION_FILENAME)
75-
76-
def get_target_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
77-
return self._get_pretranslations(TARGET_PRETRANSLATION_FILENAME)
78-
7975
def save_model(self, model_path: Path, destination: str) -> None:
8076
self.shared_file_service.upload_path(model_path, destination)
8177

tests/jobs/test_nmt_engine_build_job.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -139,24 +139,6 @@ def __init__(self, decoy: Decoy) -> None:
139139
)
140140
)
141141
)
142-
decoy.when(self.translation_file_service.get_target_pretranslations()).then_do(
143-
lambda: ContextManagedGenerator(
144-
(
145-
pi
146-
for pi in [
147-
PretranslationInfo(
148-
corpusId="corpus1",
149-
textId="text1",
150-
refs=["ref1"],
151-
translation="Please, I have booked a room.",
152-
source_toks=[],
153-
translation_toks=[],
154-
alignment="",
155-
)
156-
]
157-
)
158-
)
159-
)
160142

161143
self.target_pretranslations = ""
162144

0 commit comments

Comments (0)