Skip to content

Commit 7e7e033

Browse files
authored
Add optional alignment step to NMT jobs, temporary implementation of eflomal (#169)
* Add optional alignment step to nmt jobs, temporary implementation of eflomal * make dockerfiles compatible with eflomal * fix flake8 error * Only use eflomal on linux and other small tweaks * Adjust NMT engine test * alternate eflomal check * Alternate EFLOMAL_PATH, revert eflomal check change * Make linter ignore conditional eflomal import * Attempt to fix EFLOMAL_PATH * Eflomal sanity check * Eflomal sanity check take 2 * Add EFLOMAL_PATH value to pytest step of ci workflow * Revert EFLOMAL_PATH values to regular docker container paths * Only use normalized tokens inside of aligner * Move alignment config option to build options, revert to 'translation' in PretranslationInfo * Refactor to do alignment during the inference step * Remove unused function
1 parent a3d7276 commit 7e7e033

File tree

12 files changed

+393
-31
lines changed

12 files changed

+393
-31
lines changed

.devcontainer/dockerfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ RUN apt-get update && \
2020
apt-get install --no-install-recommends -y \
2121
python$PYTHON_VERSION \
2222
python$PYTHON_VERSION-distutils \
23+
python$PYTHON_VERSION-dev \
2324
git vim curl gdb ca-certificates gnupg2 tar make gcc libssl-dev zlib1g-dev libncurses5-dev \
2425
libbz2-dev libreadline-dev libreadline6-dev libxml2-dev xz-utils libgdbm-dev libgdbm-compat-dev tk-dev dirmngr \
2526
libxmlsec1-dev libsqlite3-dev libffi-dev liblzma-dev lzma lzma-dev uuid-dev && \
@@ -39,4 +40,6 @@ RUN pip install -U pip setuptools \
3940

4041
COPY ./.devcontainer/clearml.conf /root/clearml.conf
4142

43+
ENV EFLOMAL_PATH=/workspaces/machine.py/.venv/lib/python${PYTHON_VERSION}/site-packages/eflomal/bin
44+
4245
CMD ["bash"]

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ jobs:
5959
poetry run pyright
6060
- name: Test with pytest
6161
run: poetry run pytest --cov --cov-report=xml
62+
env:
63+
EFLOMAL_PATH: /home/runner/work/machine.py/machine.py/.venv/lib/python${{ matrix.python-version }}/site-packages/eflomal/bin
6264
- name: Upload coverage reports to Codecov
6365
uses: codecov/codecov-action@v4
6466
env:

dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# syntax=docker/dockerfile:1.7-labs
2-
32
ARG PYTHON_VERSION=3.12
43
ARG UBUNTU_VERSION=noble
54
ARG POETRY_VERSION=1.6.1
6-
ARG CUDA_VERSION=12.6.1-base-ubuntu24.04
75

86
FROM python:$PYTHON_VERSION-slim AS builder
97
ARG POETRY_VERSION
@@ -25,7 +23,7 @@ COPY poetry.lock pyproject.toml /src
2523
RUN poetry export --with=gpu --without-hashes -f requirements.txt > requirements.txt
2624

2725

28-
FROM nvidia/cuda:$CUDA_VERSION
26+
FROM python:$PYTHON_VERSION
2927
ARG PYTHON_VERSION
3028

3129
ENV PIP_DISABLE_PIP_VERSION_CHECK=on
@@ -64,4 +62,6 @@ RUN --mount=type=cache,target=/root/.cache \
6462
RUN python -m pip install --no-deps . && rm -r /root/*
6563
ENV CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1
6664

65+
ENV EFLOMAL_PATH=/usr/local/lib/python${PYTHON_VERSION}/site-packages/eflomal/bin
66+
6767
CMD ["bash"]

dockerfile.cpu_only

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,6 @@ RUN --mount=type=cache,target=/root/.cache \
4343
RUN python -m pip install --no-deps . && rm -r /root/*
4444
ENV CLEARML_AGENT_SKIP_PYTHON_ENV_INSTALL=1
4545

46+
ENV EFLOMAL_PATH=/usr/local/lib/python${PYTHON_VERSION}/site-packages/eflomal/bin
47+
4648
CMD ["bash"]

machine/jobs/eflomal_aligner.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# NOTE: this is a temporary solution to be able to use the eflomal aligner inside of machine.py.
2+
# The vast majority of this code is taken from the silnlp repository.
3+
4+
import os
5+
import subprocess
6+
from contextlib import ExitStack
7+
from importlib.util import find_spec
8+
from math import sqrt
9+
from pathlib import Path
10+
from tempfile import TemporaryDirectory
11+
from typing import IO, Iterable, List, Sequence, Tuple
12+
13+
from ..corpora import AlignedWordPair
14+
from ..corpora.token_processors import escape_spaces, lowercase, normalize
15+
from ..tokenization import LatinWordTokenizer
16+
from ..translation import SymmetrizationHeuristic, WordAlignmentMatrix
17+
18+
19+
# From silnlp.common.package_utils
def is_eflomal_available() -> bool:
    """Report whether the optional ``eflomal`` package can be imported."""
    spec = find_spec("eflomal")
    return spec is not None
22+
23+
24+
# Import the eflomal Python helpers only when the package is installed; all
# call sites in this module guard usage behind is_eflomal_available().
if is_eflomal_available():
    from eflomal import read_text, write_text  # type: ignore

# Location of the eflomal executable: the EFLOMAL_PATH env var names the
# directory containing the binary (defaults to the current directory).
EFLOMAL_PATH = Path(os.getenv("EFLOMAL_PATH", "."), "eflomal")
# Tokenizer applied to source segments and pretranslations before alignment.
TOKENIZER = LatinWordTokenizer()
29+
30+
31+
# From silnlp.alignment.tools
def execute_eflomal(
    source_path: Path,
    target_path: Path,
    forward_links_path: Path,
    reverse_links_path: Path,
    n_iterations: Tuple[int, int, int],
) -> None:
    """Run the eflomal binary to produce forward and reverse link files.

    Args:
        source_path: eflomal-format source text file.
        target_path: eflomal-format target text file.
        forward_links_path: output path for forward (source->target) links.
        reverse_links_path: output path for reverse (target->source) links.
        n_iterations: iteration counts for eflomal models 1, 2 and 3.

    Raises:
        RuntimeError: if eflomal is not installed, or if the binary exits
            with a non-zero status.
    """
    if not is_eflomal_available():
        raise RuntimeError("eflomal is not installed.")

    args = [
        str(EFLOMAL_PATH),
        "-s",
        str(source_path),
        "-t",
        str(target_path),
        "-f",
        str(forward_links_path),
        "-r",
        str(reverse_links_path),
        # "-q",
        "-m",
        "3",
        "-n",
        "3",
        "-N",
        "0.2",
        "-1",
        str(n_iterations[0]),
        "-2",
        str(n_iterations[1]),
        "-3",
        str(n_iterations[2]),
    ]
    # Capture stderr so a failed run can be reported instead of silently
    # producing empty/missing link files (the original discarded stderr and
    # ignored the exit status).
    result = subprocess.run(args, stderr=subprocess.PIPE)
    if result.returncode != 0:
        stderr_text = result.stderr.decode(errors="replace").strip() if result.stderr else ""
        raise RuntimeError(f"eflomal failed with exit code {result.returncode}: {stderr_text}")
67+
68+
69+
# From silnlp.alignment.eflomal
def to_word_alignment_matrix(alignment_str: str) -> WordAlignmentMatrix:
    """Parse an alignment string into a WordAlignmentMatrix sized to fit all pairs."""
    word_pairs = AlignedWordPair.from_string(alignment_str)
    # The matrix must be large enough to hold the highest source/target index.
    row_count = max((pair.source_index + 1 for pair in word_pairs), default=0)
    column_count = max((pair.target_index + 1 for pair in word_pairs), default=0)
    return WordAlignmentMatrix.from_word_pairs(row_count, column_count, word_pairs)
80+
81+
82+
# From silnlp.alignment.eflomal
def to_eflomal_text_file(input: Iterable[str], output_file: IO[bytes], prefix_len: int = 0, suffix_len: int = 0) -> int:
    """Encode sentences into eflomal's binary text format.

    Returns the number of sentences written to ``output_file``.
    """
    sentences, vocab_index = read_text(input, True, prefix_len, suffix_len)
    write_text(output_file, tuple(sentences), len(vocab_index))
    return len(sentences)
89+
90+
91+
# From silnlp.alignment.eflomal
def prepare_files(
    src_input: Iterable[str], src_output_file: IO[bytes], trg_input: Iterable[str], trg_output_file: IO[bytes]
) -> int:
    """Write source and target eflomal input files and return the sentence count.

    Raises:
        ValueError: if the source and target sides have different lengths.
    """
    src_count = to_eflomal_text_file(src_input, src_output_file)
    trg_count = to_eflomal_text_file(trg_input, trg_output_file)
    if src_count != trg_count:
        raise ValueError("Mismatched file sizes")
    return src_count
100+
101+
102+
def tokenize(sent: str) -> Sequence[str]:
    """Tokenize a sentence with the module-level Latin word tokenizer."""
    tokens = TOKENIZER.tokenize(sent)
    return list(tokens)
104+
105+
106+
def normalize_for_alignment(sent: Sequence[str]) -> str:
    """Escape spaces, NFC-normalize, and lowercase tokens; join them with spaces."""
    escaped = escape_spaces(sent)
    normalized = normalize("NFC", escaped)
    lowered = lowercase(normalized)
    return " ".join(lowered)
108+
109+
110+
# From silnlp.alignment.eflomal
class EflomalAligner:
    """Wrapper around the eflomal binary for training and symmetrizing alignments.

    ``train`` writes "forward-align.txt" and "reverse-align.txt" into the model
    directory; ``align`` reads them back and symmetrizes each sentence pair.
    """

    def __init__(self, model_dir: Path) -> None:
        # Directory that receives the forward/reverse link files.
        self._model_dir = model_dir

    def train(self, src_toks: Sequence[Sequence[str]], trg_toks: Sequence[Sequence[str]]) -> None:
        """Run eflomal over a tokenized parallel corpus.

        Args:
            src_toks: tokenized source sentences.
            trg_toks: tokenized target sentences (must match src_toks in length).
        """
        # parents=True so a missing intermediate directory is created rather
        # than raising FileNotFoundError.
        self._model_dir.mkdir(parents=True, exist_ok=True)
        with TemporaryDirectory() as temp_dir:
            src_eflomal_path = Path(temp_dir, "source")
            trg_eflomal_path = Path(temp_dir, "target")
            with ExitStack() as stack:
                src_output_file = stack.enter_context(src_eflomal_path.open("wb"))
                trg_output_file = stack.enter_context(trg_eflomal_path.open("wb"))
                # Write input files for the eflomal binary.
                n_sentences = prepare_files(
                    [normalize_for_alignment(s) for s in src_toks],
                    src_output_file,
                    [normalize_for_alignment(s) for s in trg_toks],
                    trg_output_file,
                )

            # Iteration-count heuristic from silnlp: fewer iterations as the
            # corpus grows. Guard n_sentences == 0, which would otherwise
            # raise ZeroDivisionError on sqrt(0).
            iters = max(2, int(round(1.0 * 5000 / sqrt(max(n_sentences, 1)))))
            iters4 = max(1, iters // 4)
            n_iterations = (max(2, iters4), iters4, iters)

            # Run wrapper for the eflomal binary.
            execute_eflomal(
                src_eflomal_path,
                trg_eflomal_path,
                self._model_dir / "forward-align.txt",
                self._model_dir / "reverse-align.txt",
                n_iterations,
            )

    def align(self, sym_heuristic: str = "grow-diag-final-and") -> List[str]:
        """Symmetrize the trained forward and reverse alignments.

        Args:
            sym_heuristic: heuristic name (e.g. "grow-diag-final-and"),
                mapped onto a SymmetrizationHeuristic enum member.

        Returns:
            One alignment string per sentence pair.
        """
        forward_align_path = self._model_dir / "forward-align.txt"
        reverse_align_path = self._model_dir / "reverse-align.txt"

        alignments = []
        heuristic = SymmetrizationHeuristic[sym_heuristic.upper().replace("-", "_")]
        with ExitStack() as stack:
            # utf-8-sig tolerates a BOM at the start of the link files.
            forward_file = stack.enter_context(forward_align_path.open("r", encoding="utf-8-sig"))
            reverse_file = stack.enter_context(reverse_align_path.open("r", encoding="utf-8-sig"))

            for forward_line, reverse_line in zip(forward_file, reverse_file):
                forward_matrix = to_word_alignment_matrix(forward_line.strip())
                reverse_matrix = to_word_alignment_matrix(reverse_line.strip())
                # Pad both matrices to a common shape before symmetrizing.
                src_len = max(forward_matrix.row_count, reverse_matrix.row_count)
                trg_len = max(forward_matrix.column_count, reverse_matrix.column_count)

                forward_matrix.resize(src_len, trg_len)
                reverse_matrix.resize(src_len, trg_len)

                forward_matrix.symmetrize_with(reverse_matrix, heuristic)

                alignments.append(str(forward_matrix))

        return alignments

machine/jobs/nmt_engine_build_job.py

Lines changed: 74 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import logging
22
from contextlib import ExitStack
3+
from pathlib import Path
4+
from tempfile import TemporaryDirectory
35
from typing import Any, Callable, Optional, Sequence, Tuple
46

57
from ..corpora.corpora_utils import batch
68
from ..corpora.parallel_text_corpus import ParallelTextCorpus
79
from ..corpora.text_corpus import TextCorpus
8-
from ..translation.translation_engine import TranslationEngine
910
from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
1011
from ..utils.progress_status import ProgressStatus
12+
from .eflomal_aligner import EflomalAligner, is_eflomal_available, tokenize
1113
from .nmt_model_factory import NmtModelFactory
12-
from .shared_file_service_base import DictToJsonWriter
1314
from .translation_engine_build_job import TranslationEngineBuildJob
1415
from .translation_file_service import PretranslationInfo, TranslationFileService
1516

def _get_progress_reporter(
    self, progress: Optional[Callable[[ProgressStatus], None]], corpus_size: int
) -> PhasedProgressReporter:
    """Build a phased progress reporter whose phases depend on whether a
    training corpus exists and whether pretranslation alignment is enabled."""
    if corpus_size > 0:
        if self._config.align_pretranslations:
            # Training + inference + alignment: alignment borrows 10% from training.
            phases = [
                Phase(message="Training NMT model", percentage=0.8),
                Phase(message="Pretranslating segments", percentage=0.1),
                Phase(message="Aligning segments", percentage=0.1, report_steps=False),
            ]
        else:
            phases = [
                Phase(message="Training NMT model", percentage=0.9),
                Phase(message="Pretranslating segments", percentage=0.1),
            ]
    else:
        # No training corpus: inference dominates the progress budget.
        if self._config.align_pretranslations:
            phases = [
                Phase(message="Pretranslating segments", percentage=0.9),
                Phase(message="Aligning segments", percentage=0.1, report_steps=False),
            ]
        else:
            phases = [Phase(message="Pretranslating segments", percentage=1.0)]
    return PhasedProgressReporter(progress, phases)
3852

3953
def _respond_to_no_training_corpus(self) -> Tuple[int, float]:
@@ -89,33 +103,70 @@ def _batch_inference(
89103
with ExitStack() as stack:
90104
phase_progress = stack.enter_context(progress_reporter.start_next_phase())
91105
engine = stack.enter_context(self._nmt_model_factory.create_engine())
92-
src_pretranslations = stack.enter_context(self._translation_file_service.get_source_pretranslations())
93-
writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer())
106+
pretranslations = [
107+
pt_info for pt_info in stack.enter_context(self._translation_file_service.get_source_pretranslations())
108+
]
109+
src_segments = [pt_info["translation"] for pt_info in pretranslations]
94110
current_inference_step = 0
95111
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
96112
batch_size = self._config["inference_batch_size"]
97-
for pi_batch in batch(src_pretranslations, batch_size):
113+
for seg_batch in batch(iter(src_segments), batch_size):
98114
if check_canceled is not None:
99115
check_canceled()
100-
_translate_batch(engine, pi_batch, writer)
101-
current_inference_step += len(pi_batch)
116+
for i, result in enumerate(engine.translate_batch(seg_batch)):
117+
pretranslations[current_inference_step + i]["translation"] = result.translation
118+
current_inference_step += len(seg_batch)
102119
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
103120

121+
if self._config.align_pretranslations and is_eflomal_available():
122+
logger.info("Aligning source to pretranslations")
123+
pretranslations = self._align(src_segments, pretranslations, progress_reporter, check_canceled)
124+
125+
writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer())
126+
for pretranslation in pretranslations:
127+
writer.write(pretranslation)
128+
129+
def _align(
    self,
    src_segments: Sequence[str],
    pretranslations: Sequence[PretranslationInfo],
    progress_reporter: PhasedProgressReporter,
    check_canceled: Optional[Callable[[], None]],
) -> Sequence[PretranslationInfo]:
    """Train an eflomal aligner on (source, pretranslation) pairs and attach
    tokens and alignment strings to each PretranslationInfo in place.

    Args:
        src_segments: untranslated source segments, parallel to pretranslations.
        pretranslations: records whose "translation" field holds the NMT output.
        progress_reporter: phased reporter; this advances it to the alignment phase.
        check_canceled: optional callback that raises when the job is canceled.

    Returns:
        The same pretranslations sequence, mutated with "source_toks",
        "translation_toks" and "alignment" fields.
    """
    if check_canceled is not None:
        check_canceled()

    logger.info("Aligning source to pretranslations")
    progress_reporter.start_next_phase()

    # Tokenize both sides with the shared Latin word tokenizer.
    src_tokenized = [tokenize(s) for s in src_segments]
    trg_tokenized = [tokenize(pt_info["translation"]) for pt_info in pretranslations]

    # The aligner's model files are only needed transiently, so a temp dir
    # keeps the job's working directory clean.
    with TemporaryDirectory() as td:
        aligner = EflomalAligner(Path(td))
        logger.info("Training aligner")
        aligner.train(src_tokenized, trg_tokenized)

        if check_canceled is not None:
            check_canceled()

        logger.info("Aligning pretranslations")
        alignments = aligner.align()

    if check_canceled is not None:
        check_canceled()

    # Attach tokens and the alignment string to each record for the writer.
    for i in range(len(pretranslations)):
        pretranslations[i]["source_toks"] = list(src_tokenized[i])
        pretranslations[i]["translation_toks"] = list(trg_tokenized[i])
        pretranslations[i]["alignment"] = alignments[i]

    return pretranslations
165+
104166
def _save_model(self) -> None:
    """Upload the trained model to the translation file service, but only
    when a ``save_model`` name is configured."""
    if "save_model" in self._config and self._config.save_model is not None:
        logger.info("Saving model")
        model_path = self._nmt_model_factory.save_model()
        # Preserve the archive's full extension (e.g. ".tar.gz") when naming
        # the uploaded model.
        self._translation_file_service.save_model(
            model_path, f"models/{self._config.save_model + ''.join(model_path.suffixes)}"
        )
111-
112-
113-
def _translate_batch(
114-
engine: TranslationEngine,
115-
batch: Sequence[PretranslationInfo],
116-
writer: DictToJsonWriter,
117-
) -> None:
118-
source_segments = [pi["translation"] for pi in batch]
119-
for i, result in enumerate(engine.translate_batch(source_segments)):
120-
batch[i]["translation"] = result.translation
121-
writer.write(batch[i])

machine/jobs/settings.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ default:
33
shared_file_uri: s3:/silnlp/
44
shared_file_folder: production
55
inference_batch_size: 1024
6+
align_pretranslations: false
67
huggingface:
78
parent_model_name: facebook/nllb-200-distilled-1.3B
89
train_params:

machine/jobs/translation_file_service.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ class PretranslationInfo(TypedDict):
1616
textId: str # noqa: N815
1717
refs: List[str]
1818
translation: str
19+
source_toks: List[str]
20+
translation_toks: List[str]
21+
alignment: str
1922

2023

2124
SOURCE_FILENAME = "train.src.txt"
@@ -62,6 +65,9 @@ def generator() -> Generator[PretranslationInfo, None, None]:
6265
textId=pi["textId"],
6366
refs=list(pi["refs"]),
6467
translation=pi["translation"],
68+
source_toks=list(pi["source_toks"]),
69+
translation_toks=list(pi["translation_toks"]),
70+
alignment=pi["alignment"],
6571
)
6672

6773
return ContextManagedGenerator(generator())

0 commit comments

Comments
 (0)