|
1 | 1 | import logging |
2 | 2 | from contextlib import ExitStack |
| 3 | +from pathlib import Path |
| 4 | +from tempfile import TemporaryDirectory |
3 | 5 | from typing import Any, Callable, Optional, Sequence, Tuple |
4 | 6 |
|
5 | 7 | from ..corpora.corpora_utils import batch |
|
8 | 10 | from ..translation.translation_engine import TranslationEngine |
9 | 11 | from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter |
10 | 12 | from ..utils.progress_status import ProgressStatus |
| 13 | +from .eflomal_aligner import EflomalAligner, is_eflomal_available, tokenize |
11 | 14 | from .nmt_model_factory import NmtModelFactory |
12 | 15 | from .shared_file_service_base import DictToJsonWriter |
13 | 16 | from .translation_engine_build_job import TranslationEngineBuildJob |
@@ -102,18 +105,66 @@ def _batch_inference( |
102 | 105 | with ExitStack() as stack: |
103 | 106 | phase_progress = stack.enter_context(progress_reporter.start_next_phase()) |
104 | 107 | engine = stack.enter_context(self._nmt_model_factory.create_engine()) |
105 | | - src_pretranslations = stack.enter_context(self._translation_file_service.get_source_pretranslations()) |
106 | | - writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer()) |
| 108 | + pretranslations = [ |
| 109 | + pt_info for pt_info in stack.enter_context(self._translation_file_service.get_source_pretranslations()) |
| 110 | + ] |
| 111 | + src_segments = [pt_info["translation"] for pt_info in pretranslations] |
107 | 112 | current_inference_step = 0 |
108 | 113 | phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) |
109 | 114 | batch_size = self._config["inference_batch_size"] |
110 | | - for pi_batch in batch(src_pretranslations, batch_size): |
| 115 | + for seg_batch in batch(iter(src_segments), batch_size): |
111 | 116 | if check_canceled is not None: |
112 | 117 | check_canceled() |
113 | | - _translate_batch(engine, pi_batch, writer) |
114 | | - current_inference_step += len(pi_batch) |
| 118 | + for i, result in enumerate(engine.translate_batch(seg_batch)): |
| 119 | + pretranslations[current_inference_step + i]["translation"] = result.translation |
| 120 | + current_inference_step += len(seg_batch) |
115 | 121 | phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) |
116 | 122 |
|
| 123 | + if self._config.align_pretranslations and is_eflomal_available(): |
| 124 | + logger.info("Aligning source to pretranslations") |
| 125 | + pretranslations = self._align(src_segments, pretranslations, progress_reporter, check_canceled) |
| 126 | + |
| 127 | + writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer()) |
| 128 | + for pretranslation in pretranslations: |
| 129 | + writer.write(pretranslation) |
| 130 | + |
| 131 | + def _align( |
| 132 | + self, |
| 133 | + src_segments: Sequence[str], |
| 134 | + pretranslations: Sequence[PretranslationInfo], |
| 135 | + progress_reporter: PhasedProgressReporter, |
| 136 | + check_canceled: Optional[Callable[[], None]], |
| 137 | + ) -> Sequence[PretranslationInfo]: |
| 138 | + if check_canceled is not None: |
| 139 | + check_canceled() |
| 140 | + |
| 141 | + logger.info("Aligning source to pretranslations") |
| 142 | + progress_reporter.start_next_phase() |
| 143 | + |
| 144 | + src_tokenized = [tokenize(s) for s in src_segments] |
| 145 | + trg_tokenized = [tokenize(pt_info["translation"]) for pt_info in pretranslations] |
| 146 | + |
| 147 | + with TemporaryDirectory() as td: |
| 148 | + aligner = EflomalAligner(Path(td)) |
| 149 | + logger.info("Training aligner") |
| 150 | + aligner.train(src_tokenized, trg_tokenized) |
| 151 | + |
| 152 | + if check_canceled is not None: |
| 153 | + check_canceled() |
| 154 | + |
| 155 | + logger.info("Aligning pretranslations") |
| 156 | + alignments = aligner.align() |
| 157 | + |
| 158 | + if check_canceled is not None: |
| 159 | + check_canceled() |
| 160 | + |
| 161 | + for i in range(len(pretranslations)): |
| 162 | + pretranslations[i]["source_toks"] = list(src_tokenized[i]) |
| 163 | + pretranslations[i]["translation_toks"] = list(trg_tokenized[i]) |
| 164 | + pretranslations[i]["alignment"] = alignments[i] |
| 165 | + |
| 166 | + return pretranslations |
| 167 | + |
117 | 168 | def _save_model(self) -> None: |
118 | 169 | if "save_model" in self._config and self._config.save_model is not None: |
119 | 170 | logger.info("Saving model") |
|
0 commit comments