Skip to content

Commit 84dae8e

Browse files
committed
Add optional alignment step to nmt jobs, temporary implementation of eflomal
1 parent a3d7276 commit 84dae8e

11 files changed

+381
-23
lines changed

.devcontainer/dockerfile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ RUN apt-get update && \
2222
python$PYTHON_VERSION-distutils \
2323
git vim curl gdb ca-certificates gnupg2 tar make gcc libssl-dev zlib1g-dev libncurses5-dev \
2424
libbz2-dev libreadline-dev libreadline6-dev libxml2-dev xz-utils libgdbm-dev libgdbm-compat-dev tk-dev dirmngr \
25-
libxmlsec1-dev libsqlite3-dev libffi-dev liblzma-dev lzma lzma-dev uuid-dev && \
25+
libxmlsec1-dev libsqlite3-dev libffi-dev liblzma-dev lzma lzma-dev uuid-dev python3.9-dev && \
2626
rm -rf /var/lib/apt/lists/*
2727

2828
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python$PYTHON_VERSION
@@ -39,4 +39,6 @@ RUN pip install -U pip setuptools \
3939

4040
COPY ./.devcontainer/clearml.conf /root/clearml.conf
4141

42+
ENV EFLOMAL_PATH=/workspaces/machine.py/.venv/lib/python3.9/site-packages/eflomal/bin
43+
4244
CMD ["bash"]

machine/jobs/build_nmt_engine.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ def main() -> None:
9292
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
9393
parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
9494
parser.add_argument("--save-model", default=None, type=str, help="Save the model using the specified base name")
95+
parser.add_argument(
96+
"--align-pretranslations", default=False, action="store_true", help="Aligns source and target pretranslations"
97+
)
9598
args = parser.parse_args()
9699

97100
run({k: v for k, v in vars(args).items() if v is not None})

machine/jobs/eflomal_aligner.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# NOTE: this is a temporary solution to be able to use the eflomal aligner inside of machine.py.
2+
# The vast majority of this code is taken from the silnlp repository.
3+
4+
import os
5+
import subprocess
6+
from contextlib import ExitStack
7+
from math import sqrt
8+
from pathlib import Path
9+
from tempfile import TemporaryDirectory
10+
from typing import IO, Iterable, List, Sequence, Tuple
11+
12+
from eflomal import read_text, write_text
13+
14+
from ..corpora import AlignedWordPair
15+
from ..corpora.token_processors import escape_spaces, lowercase, normalize
16+
from ..tokenization import LatinWordTokenizer
17+
from ..translation import SymmetrizationHeuristic, WordAlignmentMatrix
18+
19+
# may have to make more dynamic, look at silnlp get_wsl_path, is there something equivalent in machine?
# Full path to the eflomal executable: the EFLOMAL_PATH env var names the directory
# containing the binary (defaults to the current directory when unset).
EFLOMAL_PATH = Path(os.getenv("EFLOMAL_PATH", "."), "eflomal")
# Word tokenizer shared by tokenize() below.
TOKENIZER = LatinWordTokenizer()
22+
23+
24+
# From silnlp.alignment.tools
def execute_eflomal(
    source_path: Path,
    target_path: Path,
    forward_links_path: Path,
    reverse_links_path: Path,
    n_iterations: Tuple[int, int, int],
) -> None:
    """Run the eflomal binary to produce forward and reverse alignment link files.

    Args:
        source_path: eflomal-formatted source sentence file.
        target_path: eflomal-formatted target sentence file.
        forward_links_path: output path for source-to-target links (``-f``).
        reverse_links_path: output path for target-to-source links (``-r``).
        n_iterations: iteration counts passed to eflomal's ``-1``, ``-2`` and
            ``-3`` options, in that order.

    Raises:
        RuntimeError: if the eflomal binary is missing or exits with a
            non-zero status.
    """
    if not EFLOMAL_PATH.is_file():
        raise RuntimeError("eflomal is not installed.")

    args = [
        str(EFLOMAL_PATH),
        "-s",
        str(source_path),
        "-t",
        str(target_path),
        "-f",
        str(forward_links_path),
        "-r",
        str(reverse_links_path),
        # "-q",
        "-m",
        "3",
        "-n",
        "3",
        "-N",
        "0.2",
        "-1",
        str(n_iterations[0]),
        "-2",
        str(n_iterations[1]),
        "-3",
        str(n_iterations[2]),
    ]
    # Fail loudly on a non-zero exit: previously the return code was ignored
    # (and stderr discarded), so a crashed eflomal run only surfaced later as
    # missing/empty link files with no indication of the cause.
    result = subprocess.run(args, stderr=subprocess.DEVNULL)
    if result.returncode != 0:
        raise RuntimeError(f"eflomal exited with non-zero status {result.returncode}.")
60+
61+
62+
# From silnlp.alignment.eflomal
def to_word_alignment_matrix(alignment_str: str) -> WordAlignmentMatrix:
    """Parse an alignment-pair string into a WordAlignmentMatrix sized to fit all pairs."""
    pairs = list(AlignedWordPair.from_string(alignment_str))
    # The matrix must be large enough to hold the highest source/target index.
    n_rows = max((p.source_index for p in pairs), default=-1) + 1
    n_cols = max((p.target_index for p in pairs), default=-1) + 1
    return WordAlignmentMatrix.from_word_pairs(n_rows, n_cols, pairs)
73+
74+
75+
# From silnlp.alignment.eflomal
def to_eflomal_text_file(input: Iterable[str], output_file: IO[bytes], prefix_len: int = 0, suffix_len: int = 0) -> int:
    """Encode sentences into eflomal's binary text format and return the sentence count."""
    sentences, vocabulary = read_text(input, True, prefix_len, suffix_len)
    write_text(output_file, tuple(sentences), len(vocabulary))
    return len(sentences)
82+
83+
84+
# From silnlp.alignment.eflomal
def prepare_files(
    src_input: Iterable[str], src_output_file: IO[bytes], trg_input: Iterable[str], trg_output_file: IO[bytes]
) -> int:
    """Write the source and target sentences in eflomal's input format.

    Args:
        src_input: source sentences, one per element.
        src_output_file: binary file receiving the eflomal-formatted source text.
        trg_input: target sentences, one per element.
        trg_output_file: binary file receiving the eflomal-formatted target text.

    Returns:
        The number of sentence pairs written.

    Raises:
        ValueError: if the source and target sentence counts differ.
    """
    n_src_sents = to_eflomal_text_file(src_input, src_output_file)
    n_trg_sents = to_eflomal_text_file(trg_input, trg_output_file)
    if n_src_sents != n_trg_sents:
        # Include both counts so a corpus mismatch can be diagnosed from the log alone.
        raise ValueError(f"Mismatched file sizes: {n_src_sents} source sentences vs {n_trg_sents} target sentences")
    return n_src_sents
93+
94+
95+
def tokenize(sent: str) -> Sequence[str]:
    """Tokenize *sent*, then escape spaces, NFC-normalize, and lowercase the tokens."""
    tokens = list(TOKENIZER.tokenize(sent))
    tokens = escape_spaces(tokens)
    tokens = normalize("NFC", tokens)
    return lowercase(tokens)
97+
98+
99+
# From silnlp.alignment.eflomal
class EflomalAligner:
    """Wrapper around the eflomal binary that trains word alignments and symmetrizes them."""

    def __init__(self, model_dir: Path) -> None:
        # Directory where the forward/reverse link files are written by train()
        # and read back by align().
        self._model_dir = model_dir

    def train(self, src_toks: Sequence[Sequence[str]], trg_toks: Sequence[Sequence[str]]) -> None:
        """Run eflomal over the tokenized corpora, leaving link files in the model directory."""
        self._model_dir.mkdir(exist_ok=True)
        with TemporaryDirectory() as temp_dir:
            src_eflomal_path = Path(temp_dir, "source")
            trg_eflomal_path = Path(temp_dir, "target")
            with src_eflomal_path.open("wb") as src_output_file, trg_eflomal_path.open("wb") as trg_output_file:
                # Write input files for the eflomal binary
                n_sentences = prepare_files(
                    [" ".join(s) for s in src_toks], src_output_file, [" ".join(s) for s in trg_toks], trg_output_file
                )

            # Iteration counts shrink as the corpus grows (same schedule as silnlp).
            iters = max(2, int(round(1.0 * 5000 / sqrt(n_sentences))))
            iters4 = max(1, iters // 4)
            n_iterations = (max(2, iters4), iters4, iters)

            # Run wrapper for the eflomal binary
            execute_eflomal(
                src_eflomal_path,
                trg_eflomal_path,
                self._model_dir / "forward-align.txt",
                self._model_dir / "reverse-align.txt",
                n_iterations,
            )

    def align(self, sym_heuristic: str = "grow-diag-final-and") -> List[str]:
        """Symmetrize forward/reverse alignments; returns one alignment string per sentence pair."""
        heuristic = SymmetrizationHeuristic[sym_heuristic.upper().replace("-", "_")]
        alignments: List[str] = []
        forward_path = self._model_dir / "forward-align.txt"
        reverse_path = self._model_dir / "reverse-align.txt"
        with forward_path.open("r", encoding="utf-8-sig") as forward_file, reverse_path.open(
            "r", encoding="utf-8-sig"
        ) as reverse_file:
            for forward_line, reverse_line in zip(forward_file, reverse_file):
                forward_matrix = to_word_alignment_matrix(forward_line.strip())
                reverse_matrix = to_word_alignment_matrix(reverse_line.strip())
                # Pad both matrices to a common shape before symmetrizing.
                src_len = max(forward_matrix.row_count, reverse_matrix.row_count)
                trg_len = max(forward_matrix.column_count, reverse_matrix.column_count)
                forward_matrix.resize(src_len, trg_len)
                reverse_matrix.resize(src_len, trg_len)

                forward_matrix.symmetrize_with(reverse_matrix, heuristic)
                alignments.append(str(forward_matrix))
        return alignments

machine/jobs/nmt_engine_build_job.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
def _get_progress_reporter(
    self, progress: Optional[Callable[[ProgressStatus], None]], corpus_size: int
) -> PhasedProgressReporter:
    """Build the phase schedule for this job.

    A training phase is included only when there is a training corpus
    (corpus_size > 0); an alignment phase is appended when
    ``align_pretranslations`` is enabled in the config.
    """
    # Hoisted: the original evaluated this same config check in every branch.
    align_pretranslations = "align_pretranslations" in self._config and self._config.align_pretranslations
    if corpus_size > 0:
        phases = [
            Phase(message="Training NMT model", percentage=0.8 if align_pretranslations else 0.9),
            Phase(message="Pretranslating segments", percentage=0.1),
        ]
    else:
        phases = [Phase(message="Pretranslating segments", percentage=0.9 if align_pretranslations else 1.0)]
    if align_pretranslations:
        phases.append(Phase(message="Aligning segments", percentage=0.1, report_steps=False))
    return PhasedProgressReporter(progress, phases)
3851

3952
def _respond_to_no_training_corpus(self) -> Tuple[int, float]:
@@ -115,7 +128,7 @@ def _translate_batch(
115128
batch: Sequence[PretranslationInfo],
116129
writer: DictToJsonWriter,
117130
) -> None:
118-
source_segments = [pi["translation"] for pi in batch]
131+
source_segments = [pi["pretranslation"] for pi in batch]
119132
for i, result in enumerate(engine.translate_batch(source_segments)):
120-
batch[i]["translation"] = result.translation
133+
batch[i]["pretranslation"] = result.translation
121134
writer.write(batch[i])

machine/jobs/smt_engine_build_job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _translate_batch(
107107
batch: Sequence[PretranslationInfo],
108108
writer: DictToJsonWriter,
109109
) -> None:
110-
source_segments = [pi["translation"] for pi in batch]
110+
source_segments = [pi["pretranslation"] for pi in batch]
111111
for i, result in enumerate(engine.translate_batch(source_segments)):
112-
batch[i]["translation"] = result.translation
112+
batch[i]["pretranslation"] = result.translation
113113
writer.write(batch[i])

machine/jobs/translation_engine_build_job.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import logging
22
from abc import ABC, abstractmethod
3+
from contextlib import ExitStack
4+
from pathlib import Path
5+
from tempfile import TemporaryDirectory
36
from typing import Any, Callable, Optional, Tuple
47

58
from ..corpora.parallel_text_corpus import ParallelTextCorpus
69
from ..corpora.text_corpus import TextCorpus
710
from ..utils.phased_progress_reporter import PhasedProgressReporter
811
from ..utils.progress_status import ProgressStatus
9-
from .translation_file_service import TranslationFileService
12+
from .eflomal_aligner import EflomalAligner, tokenize
13+
from .translation_file_service import PretranslationInfo, TranslationFileService
1014

1115
logger = logging.getLogger(__name__)
1216

@@ -44,6 +48,10 @@ def run(
4448
logger.info("Pretranslating segments")
4549
self._batch_inference(progress_reporter, check_canceled)
4650

51+
if "align_pretranslations" in self._config and self._config.align_pretranslations:
52+
logger.info("Aligning source to pretranslations")
53+
self._align(progress_reporter, check_canceled)
54+
4755
self._save_model()
4856
return train_corpus_size, confidence
4957

@@ -74,5 +82,59 @@ def _batch_inference(
7482
check_canceled: Optional[Callable[[], None]],
7583
) -> None: ...
7684

85+
def _align(
    self,
    progress_reporter: PhasedProgressReporter,
    check_canceled: Optional[Callable[[], None]],
) -> None:
    """Word-align the source pretranslations to the target pretranslations with
    eflomal, then rewrite the target pretranslation file with the token lists
    and alignment string filled in."""
    if check_canceled is not None:
        check_canceled()

    logger.info("Aligning source to pretranslations")
    with ExitStack() as stack:
        stack.enter_context(progress_reporter.start_next_phase())

        src_pretranslations = stack.enter_context(self._translation_file_service.get_source_pretranslations())
        tokenized_src = [tokenize(pi["pretranslation"]) for pi in src_pretranslations]
        trg_pretranslations = stack.enter_context(self._translation_file_service.get_target_pretranslations())
        tokenized_trg = [tokenize(pi["pretranslation"]) for pi in trg_pretranslations]

        with TemporaryDirectory() as temp_dir:
            aligner = EflomalAligner(Path(temp_dir))
            logger.info("Training aligner")
            aligner.train(tokenized_src, tokenized_trg)

            if check_canceled is not None:
                check_canceled()

            logger.info("Aligning pretranslations")
            alignments = aligner.align()

        if check_canceled is not None:
            check_canceled()

        # Re-read the target pretranslations and write each record back out,
        # augmented with its token lists and alignment.
        writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer())
        target_stream = stack.enter_context(self._translation_file_service.get_target_pretranslations())
        for trg_info, src_toks, trg_toks, alignment in zip(target_stream, tokenized_src, tokenized_trg, alignments):
            writer.write(
                PretranslationInfo(
                    corpusId=trg_info["corpusId"],
                    textId=trg_info["textId"],
                    refs=trg_info["refs"],
                    pretranslation=trg_info["pretranslation"],
                    source_toks=list(src_toks),
                    pretranslation_toks=list(trg_toks),
                    alignment=alignment,
                )
            )
138+
77139
@abstractmethod
78140
def _save_model(self) -> None: ...

machine/jobs/translation_file_service.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
class PretranslationInfo(TypedDict):
    """Schema of one pretranslation record exchanged through the shared file service."""

    corpusId: str  # noqa: N815
    textId: str  # noqa: N815
    # References identifying the segment within the text.
    refs: List[str]
    # The pretranslated segment text.
    pretranslation: str
    # Tokenized source segment (empty until the alignment step runs — TODO confirm).
    source_toks: List[str]
    # Tokenized pretranslation (empty until the alignment step runs — TODO confirm).
    pretranslation_toks: List[str]
    # Word alignment between source_toks and pretranslation_toks, in string form.
    alignment: str
1922

2023

2124
SOURCE_FILENAME = "train.src.txt"
@@ -49,23 +52,30 @@ def exists_source_corpus(self) -> bool:
4952
def exists_target_corpus(self) -> bool:
5053
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{TARGET_FILENAME}")
5154

52-
def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
53-
src_pretranslate_path = self.shared_file_service.download_file(
54-
f"{self.shared_file_service.build_path}/{SOURCE_PRETRANSLATION_FILENAME}"
55-
)
55+
def _get_pretranslations(self, filename: str) -> ContextManagedGenerator[PretranslationInfo, None, None]:
    """Download *filename* from the build directory and stream its pretranslation records.

    Returns a context-managed generator so the underlying file is closed when
    the caller's ``with`` block exits.
    """
    # FIX: the path must interpolate the *filename* argument; it was garbled
    # to a literal placeholder, so the parameter was never used and both
    # callers would have fetched the wrong object.
    pretranslate_path = self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{filename}")

    def generator() -> Generator[PretranslationInfo, None, None]:
        with pretranslate_path.open("r", encoding="utf-8-sig") as file:
            for pi in json_stream.load(file):
                yield PretranslationInfo(
                    corpusId=pi["corpusId"],
                    textId=pi["textId"],
                    refs=list(pi["refs"]),
                    pretranslation=pi["pretranslation"],
                    source_toks=list(pi["source_toks"]),
                    pretranslation_toks=list(pi["pretranslation_toks"]),
                    alignment=pi["alignment"],
                )

    return ContextManagedGenerator(generator())
6872

73+
def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
    """Stream pretranslation records from the source pretranslation file."""
    return self._get_pretranslations(SOURCE_PRETRANSLATION_FILENAME)

def get_target_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
    """Stream pretranslation records from the target pretranslation file."""
    return self._get_pretranslations(TARGET_PRETRANSLATION_FILENAME)
78+
6979
def save_model(self, model_path: Path, destination: str) -> None:
7080
self.shared_file_service.upload_path(model_path, destination)
7181

0 commit comments

Comments
 (0)