Skip to content

Commit 123263f

Browse files
committed
Move alignment config option to build options, revert to 'translation' in PretranslationInfo
1 parent 881a795 commit 123263f

File tree

8 files changed

+34
-33
lines changed

machine/jobs/build_nmt_engine.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -92,13 +92,6 @@ def main() -> None:
9292
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
9393
parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
9494
parser.add_argument("--save-model", default=None, type=str, help="Save the model using the specified base name")
95-
parser.add_argument(
96-
"--align-pretranslations",
97-
default=False,
98-
action="store_true",
99-
help="Aligns source and target pretranslations using Eflomal (linux only) "
100-
"and returns the alignments as well as the tokenized source and target with the pretranslations.",
101-
)
10295
args = parser.parse_args()
10396

10497
run({k: v for k, v in vars(args).items() if v is not None})

machine/jobs/nmt_engine_build_job.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ def _get_progress_reporter(
2828
self, progress: Optional[Callable[[ProgressStatus], None]], corpus_size: int
2929
) -> PhasedProgressReporter:
3030
if corpus_size > 0:
31-
if "align_pretranslations" in self._config and self._config.align_pretranslations:
31+
if self._config.align_pretranslations:
3232
phases = [
3333
Phase(message="Training NMT model", percentage=0.8),
3434
Phase(message="Pretranslating segments", percentage=0.1),
@@ -40,7 +40,7 @@ def _get_progress_reporter(
4040
Phase(message="Pretranslating segments", percentage=0.1),
4141
]
4242
else:
43-
if "align_pretranslations" in self._config and self._config.align_pretranslations:
43+
if self._config.align_pretranslations:
4444
phases = [
4545
Phase(message="Pretranslating segments", percentage=0.9),
4646
Phase(message="Aligning segments", percentage=0.1, report_steps=False),
@@ -128,7 +128,7 @@ def _translate_batch(
128128
batch: Sequence[PretranslationInfo],
129129
writer: DictToJsonWriter,
130130
) -> None:
131-
source_segments = [pi["pretranslation"] for pi in batch]
131+
source_segments = [pi["translation"] for pi in batch]
132132
for i, result in enumerate(engine.translate_batch(source_segments)):
133-
batch[i]["pretranslation"] = result.translation
133+
batch[i]["translation"] = result.translation
134134
writer.write(batch[i])

machine/jobs/settings.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ default:
33
shared_file_uri: s3:/silnlp/
44
shared_file_folder: production
55
inference_batch_size: 1024
6+
align_pretranslations: false
67
huggingface:
78
parent_model_name: facebook/nllb-200-distilled-1.3B
89
train_params:

machine/jobs/smt_engine_build_job.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ def _translate_batch(
107107
batch: Sequence[PretranslationInfo],
108108
writer: DictToJsonWriter,
109109
) -> None:
110-
source_segments = [pi["pretranslation"] for pi in batch]
110+
source_segments = [pi["translation"] for pi in batch]
111111
for i, result in enumerate(engine.translate_batch(source_segments)):
112-
batch[i]["pretranslation"] = result.translation
112+
batch[i]["translation"] = result.translation
113113
writer.write(batch[i])

machine/jobs/translation_engine_build_job.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ def run(
4848
logger.info("Pretranslating segments")
4949
self._batch_inference(progress_reporter, check_canceled)
5050

51-
if "align_pretranslations" in self._config and self._config.align_pretranslations and is_eflomal_available():
51+
if self._config.align_pretranslations and is_eflomal_available():
5252
logger.info("Aligning source to pretranslations")
5353
self._align(progress_reporter, check_canceled)
5454

@@ -96,13 +96,13 @@ def _align(
9696
progress_reporter.start_next_phase()
9797

9898
src_tokenized = [
99-
tokenize(s["pretranslation"])
99+
tokenize(s["translation"])
100100
for s in stack.enter_context(self._translation_file_service.get_source_pretranslations())
101101
]
102102
trg_info = [
103103
pt_info for pt_info in stack.enter_context(self._translation_file_service.get_target_pretranslations())
104104
]
105-
trg_tokenized = [tokenize(pt_info["pretranslation"]) for pt_info in trg_info]
105+
trg_tokenized = [tokenize(pt_info["translation"]) for pt_info in trg_info]
106106

107107
with TemporaryDirectory() as td:
108108
aligner = EflomalAligner(Path(td))
@@ -125,9 +125,9 @@ def _align(
125125
corpusId=trg_pi["corpusId"],
126126
textId=trg_pi["textId"],
127127
refs=trg_pi["refs"],
128-
pretranslation=trg_pi["pretranslation"],
128+
translation=trg_pi["translation"],
129129
source_toks=list(src_toks),
130-
pretranslation_toks=list(trg_toks),
130+
translation_toks=list(trg_toks),
131131
alignment=alignment,
132132
)
133133
)

machine/jobs/translation_file_service.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,9 @@ class PretranslationInfo(TypedDict):
1515
corpusId: str # noqa: N815
1616
textId: str # noqa: N815
1717
refs: List[str]
18-
pretranslation: str
18+
translation: str
1919
source_toks: List[str]
20-
pretranslation_toks: List[str]
20+
translation_toks: List[str]
2121
alignment: str
2222

2323

@@ -62,9 +62,9 @@ def generator() -> Generator[PretranslationInfo, None, None]:
6262
corpusId=pi["corpusId"],
6363
textId=pi["textId"],
6464
refs=list(pi["refs"]),
65-
pretranslation=pi["pretranslation"],
65+
translation=pi["translation"],
6666
source_toks=list(pi["source_toks"]),
67-
pretranslation_toks=list(pi["pretranslation_toks"]),
67+
translation_toks=list(pi["translation_toks"]),
6868
alignment=pi["alignment"],
6969
)
7070

tests/jobs/test_nmt_engine_build_job.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def test_run(decoy: Decoy) -> None:
3636

3737
pretranslations = json.loads(env.target_pretranslations)
3838
assert len(pretranslations) == 1
39-
assert pretranslations[0]["pretranslation"] == "Please, I have booked a room."
39+
assert pretranslations[0]["translation"] == "Please, I have booked a room."
4040
if is_eflomal_available():
4141
assert pretranslations[0]["source_toks"] == [
4242
"Por",
@@ -48,11 +48,11 @@ def test_run(decoy: Decoy) -> None:
4848
"habitación",
4949
".",
5050
]
51-
assert pretranslations[0]["pretranslation_toks"] == ["Please", ",", "I", "have", "booked", "a", "room", "."]
51+
assert pretranslations[0]["translation_toks"] == ["Please", ",", "I", "have", "booked", "a", "room", "."]
5252
assert len(pretranslations[0]["alignment"]) > 0
5353
else:
5454
assert pretranslations[0]["source_toks"] == []
55-
assert pretranslations[0]["pretranslation_toks"] == []
55+
assert pretranslations[0]["translation_toks"] == []
5656
assert len(pretranslations[0]["alignment"]) == 0
5757
decoy.verify(env.translation_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1)
5858

@@ -130,9 +130,9 @@ def __init__(self, decoy: Decoy) -> None:
130130
corpusId="corpus1",
131131
textId="text1",
132132
refs=["ref1"],
133-
pretranslation="Por favor, tengo reservada una habitación.",
133+
translation="Por favor, tengo reservada una habitación.",
134134
source_toks=[],
135-
pretranslation_toks=[],
135+
translation_toks=[],
136136
alignment="",
137137
)
138138
]
@@ -148,9 +148,9 @@ def __init__(self, decoy: Decoy) -> None:
148148
corpusId="corpus1",
149149
textId="text1",
150150
refs=["ref1"],
151-
pretranslation="Please, I have booked a room.",
151+
translation="Please, I have booked a room.",
152152
source_toks=[],
153-
pretranslation_toks=[],
153+
translation_toks=[],
154154
alignment="",
155155
)
156156
]

tests/jobs/test_smt_engine_build_job.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def test_run(decoy: Decoy) -> None:
3131

3232
pretranslations = json.loads(env.target_pretranslations)
3333
assert len(pretranslations) == 1
34-
assert pretranslations[0]["pretranslation"] == "Please, I have booked a room."
34+
assert pretranslations[0]["translation"] == "Please, I have booked a room."
3535
decoy.verify(
3636
env.translation_file_service.save_model(matchers.Anything(), f"builds/{env.job._config.build_id}/model.zip"),
3737
times=1,
@@ -136,9 +136,9 @@ def __init__(self, decoy: Decoy) -> None:
136136
corpusId="corpus1",
137137
textId="text1",
138138
refs=["ref1"],
139-
pretranslation="Por favor, tengo reservada una habitación.",
139+
translation="Por favor, tengo reservada una habitación.",
140140
source_toks=[],
141-
pretranslation_toks=[],
141+
translation_toks=[],
142142
alignment="",
143143
)
144144
]
@@ -161,7 +161,14 @@ def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[DictToJ
161161
)
162162

163163
self.job = SmtEngineBuildJob(
164-
MockSettings({"build_id": "mybuild", "inference_batch_size": 100, "thot_mt": {"tokenizer": "latin"}}),
164+
MockSettings(
165+
{
166+
"build_id": "mybuild",
167+
"inference_batch_size": 100,
168+
"thot_mt": {"tokenizer": "latin"},
169+
"align_pretranslations": False,
170+
}
171+
),
165172
self.smt_model_factory,
166173
self.translation_file_service,
167174
)

0 commit comments

Comments
 (0)