11import logging
22from contextlib import ExitStack
3+ from pathlib import Path
4+ from tempfile import TemporaryDirectory
35from typing import Any , Callable , Optional , Sequence , Tuple
46
57from ..corpora .corpora_utils import batch
68from ..corpora .parallel_text_corpus import ParallelTextCorpus
79from ..corpora .text_corpus import TextCorpus
8- from ..translation .translation_engine import TranslationEngine
910from ..utils .phased_progress_reporter import Phase , PhasedProgressReporter
1011from ..utils .progress_status import ProgressStatus
12+ from .eflomal_aligner import EflomalAligner , is_eflomal_available , tokenize
1113from .nmt_model_factory import NmtModelFactory
12- from .shared_file_service_base import DictToJsonWriter
1314from .translation_engine_build_job import TranslationEngineBuildJob
1415from .translation_file_service import PretranslationInfo , TranslationFileService
1516
def _get_progress_reporter(
    self, progress: Optional[Callable[[ProgressStatus], None]], corpus_size: int
) -> PhasedProgressReporter:
    """Build the phased progress reporter for this build.

    The phase list depends on two things:

    * ``corpus_size > 0`` adds a "Training NMT model" phase ahead of the
      "Pretranslating segments" phase.
    * a truthy ``self._config.align_pretranslations`` appends an
      "Aligning segments" phase (created with ``report_steps=False``) and
      shrinks the first phase by 0.1 so the percentages still sum to 1.0.

    Args:
        progress: Optional callback handed to the reporter.
        corpus_size: Size of the training corpus; zero or less means the
            training phase is omitted.

    Returns:
        A ``PhasedProgressReporter`` wrapping ``progress`` and the phases.
    """
    with_alignment = self._config.align_pretranslations

    phases = []
    if corpus_size > 0:
        # Training gives up its share to the alignment phase when one exists.
        phases.append(Phase(message="Training NMT model", percentage=0.8 if with_alignment else 0.9))
        phases.append(Phase(message="Pretranslating segments", percentage=0.1))
    else:
        # No training: pretranslation takes everything except the alignment share.
        phases.append(Phase(message="Pretranslating segments", percentage=0.9 if with_alignment else 1.0))

    if with_alignment:
        phases.append(Phase(message="Aligning segments", percentage=0.1, report_steps=False))

    return PhasedProgressReporter(progress, phases)
3852
3953 def _respond_to_no_training_corpus (self ) -> Tuple [int , float ]:
@@ -89,33 +103,70 @@ def _batch_inference(
89103 with ExitStack () as stack :
90104 phase_progress = stack .enter_context (progress_reporter .start_next_phase ())
91105 engine = stack .enter_context (self ._nmt_model_factory .create_engine ())
92- src_pretranslations = stack .enter_context (self ._translation_file_service .get_source_pretranslations ())
93- writer = stack .enter_context (self ._translation_file_service .open_target_pretranslation_writer ())
106+ pretranslations = [
107+ pt_info for pt_info in stack .enter_context (self ._translation_file_service .get_source_pretranslations ())
108+ ]
109+ src_segments = [pt_info ["translation" ] for pt_info in pretranslations ]
94110 current_inference_step = 0
95111 phase_progress (ProgressStatus .from_step (current_inference_step , inference_step_count ))
96112 batch_size = self ._config ["inference_batch_size" ]
97- for pi_batch in batch (src_pretranslations , batch_size ):
113+ for seg_batch in batch (iter ( src_segments ) , batch_size ):
98114 if check_canceled is not None :
99115 check_canceled ()
100- _translate_batch (engine , pi_batch , writer )
101- current_inference_step += len (pi_batch )
116+ for i , result in enumerate (engine .translate_batch (seg_batch )):
117+ pretranslations [current_inference_step + i ]["translation" ] = result .translation
118+ current_inference_step += len (seg_batch )
102119 phase_progress (ProgressStatus .from_step (current_inference_step , inference_step_count ))
103120
121+ if self ._config .align_pretranslations and is_eflomal_available ():
122+ logger .info ("Aligning source to pretranslations" )
123+ pretranslations = self ._align (src_segments , pretranslations , progress_reporter , check_canceled )
124+
125+ writer = stack .enter_context (self ._translation_file_service .open_target_pretranslation_writer ())
126+ for pretranslation in pretranslations :
127+ writer .write (pretranslation )
128+
def _align(
    self,
    src_segments: Sequence[str],
    pretranslations: Sequence[PretranslationInfo],
    progress_reporter: PhasedProgressReporter,
    check_canceled: Optional[Callable[[], None]],
) -> Sequence[PretranslationInfo]:
    """Word-align each source segment with its pretranslation.

    Source segments and the ``"translation"`` field of every pretranslation
    are tokenized with ``tokenize``; an ``EflomalAligner`` is trained on the
    tokenized pairs inside a temporary directory and then asked for the
    alignments. Each pretranslation is updated in place with three keys:
    ``"source_toks"``, ``"translation_toks"`` and ``"alignment"``.

    Args:
        src_segments: Source text, one entry per pretranslation, in the same
            order as ``pretranslations``.
        pretranslations: Records to annotate; mutated in place.
        progress_reporter: Reporter whose next phase is started for the
            duration of the alignment.
        check_canceled: Optional callback invoked before tokenizing, after
            training and after aligning (presumably raises to abort the
            job -- confirm against the caller).

    Returns:
        The same ``pretranslations`` sequence that was passed in.

    Raises:
        IndexError: If fewer tokenized sources or alignments are available
            than pretranslations.
    """
    if check_canceled is not None:
        check_canceled()

    logger.info("Aligning source to pretranslations")
    # start_next_phase() returns a context manager (it is entered through the
    # ExitStack in _batch_inference); enter it here as well so the phase is
    # exited once alignment finishes instead of being left dangling.
    with progress_reporter.start_next_phase():
        src_tokenized = [tokenize(segment) for segment in src_segments]
        trg_tokenized = [tokenize(pt_info["translation"]) for pt_info in pretranslations]

        # The aligner works inside a scratch directory that is removed when
        # this block exits.
        with TemporaryDirectory() as td:
            aligner = EflomalAligner(Path(td))
            logger.info("Training aligner")
            aligner.train(src_tokenized, trg_tokenized)

            if check_canceled is not None:
                check_canceled()

            logger.info("Aligning pretranslations")
            alignments = aligner.align()

            if check_canceled is not None:
                check_canceled()

            # Index into the parallel sequences rather than zip() so that a
            # length mismatch fails loudly instead of truncating silently.
            for i, pt_info in enumerate(pretranslations):
                pt_info["source_toks"] = list(src_tokenized[i])
                pt_info["translation_toks"] = list(trg_tokenized[i])
                pt_info["alignment"] = alignments[i]

    return pretranslations
165+
104166 def _save_model (self ) -> None :
105167 if "save_model" in self ._config and self ._config .save_model is not None :
106168 logger .info ("Saving model" )
107169 model_path = self ._nmt_model_factory .save_model ()
108170 self ._translation_file_service .save_model (
109171 model_path , f"models/{ self ._config .save_model + '' .join (model_path .suffixes )} "
110172 )
111-
112-
113- def _translate_batch (
114- engine : TranslationEngine ,
115- batch : Sequence [PretranslationInfo ],
116- writer : DictToJsonWriter ,
117- ) -> None :
118- source_segments = [pi ["translation" ] for pi in batch ]
119- for i , result in enumerate (engine .translate_batch (source_segments )):
120- batch [i ]["translation" ] = result .translation
121- writer .write (batch [i ])
0 commit comments