Commit 502d766

Merge pull request #173 from sordonia/memory_efficient_training
Knowledge Modules PR
2 parents 0041649 + 34e5b14 commit 502d766

342 files changed: 14161 additions & 315 deletions


mttl/arguments.py

Lines changed: 20 additions & 5 deletions
```diff
@@ -122,8 +122,10 @@ def was_default(self, key):
     @classmethod
     def process_kwargs(cls, kwargs, eval=True, raise_error=True, silent=False):
         overwrites_log = []
+        kwargs_keys = list(kwargs.keys())
 
-        for k, v in kwargs.items():
+        for k in kwargs_keys:
+            v = kwargs[k]
             if eval:
                 try:
                     v = ast.literal_eval(v)
```
```diff
@@ -132,8 +134,16 @@ def process_kwargs(cls, kwargs, eval=True, raise_error=True, silent=False):
             else:
                 v = v
 
-            if not hasattr(cls, k) and raise_error:
-                raise ValueError(f"{k} is not in the config")
+            if not hasattr(cls, k):
+                if raise_error:
+                    raise ValueError(f"{k} is not in the config")
+                else:
+                    # if the key is not found, but we are fault tolerant, we remove it
+                    logger.warning(
+                        f"{k} is not in the config, skipping it, make sure this is intended!"
+                    )
+                    del kwargs[k]
+                    continue
 
             if eval and not silent:
                 overwrites_log.append(f"Overwriting {k} to {v}")
```
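
Two things changed here. First, iterating over a snapshot of the keys (`kwargs_keys = list(kwargs.keys())`) is what makes the `del kwargs[k]` in the new fault-tolerant branch legal: deleting from a dict while iterating it directly raises `RuntimeError: dictionary changed size during iteration`. Second, unknown keys now raise only when `raise_error=True`; otherwise they are logged and dropped. A minimal standalone sketch of the pattern, with illustrative names rather than the PR's code:

```python
kwargs = {"learning_rate": "1e-4", "unknown_option": "42"}
known_keys = {"learning_rate"}  # stand-in for hasattr(cls, k)

# Iterate over a snapshot of the keys so the dict itself can be mutated
# inside the loop; `for k in kwargs:` plus `del kwargs[k]` would raise
# RuntimeError: dictionary changed size during iteration.
for k in list(kwargs.keys()):
    if k not in known_keys:
        print(f"{k} is not in the config, skipping it")
        del kwargs[k]  # safe: we hold a separate list of the keys
        continue

print(kwargs)  # {'learning_rate': '1e-4'}
```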
```diff
@@ -345,7 +355,7 @@ class TrainingArgs(DataArgs):
     expert_name: str = None
 
     # Training config
-    micro_batch_size: str = None
+    micro_batch_size: int = None
     compute_strategy: str = None
     scheduler: str = "linear_decay_with_warmup"
     checkpoint: str = None  # load from checkpoint
```
```diff
@@ -371,7 +381,7 @@ class TrainingArgs(DataArgs):
     save_every: int = None
     save_each_epoch: bool = False
     eval_every: int = None
-    eval_every_n_epoch: int = 1
+    eval_every_n_epoch: int = None
     seed: int = 42
     debug: bool = False
```

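The `micro_batch_size` annotation fix matters if the value feeds integer arithmetic downstream, for instance when deriving gradient-accumulation steps; that call site is not part of this diff, so the sketch below illustrates the assumed failure mode rather than the project's code:

```python
train_batch_size = 32

micro_batch_size = "8"  # what the old `str` annotation implied
# train_batch_size // micro_batch_size
# -> TypeError: unsupported operand type(s) for //: 'int' and 'str'

micro_batch_size = 8    # with the corrected `int` annotation
accumulation_steps = train_batch_size // micro_batch_size
print(accumulation_steps)  # 4
```
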
```diff
@@ -484,6 +494,9 @@ def __post_init__(self):
     def to_hf_training_args(self) -> "TrainingArguments":
         from transformers import TrainingArguments
 
+        # NOTE: unclear how `warmup_steps` and `warmup_ratio` are used in HF args
+        # given that we build the optimizer and scheduler ourselves
+
         return TrainingArguments(
             run_name=self.wandb_run_name
             or self.expert_name
```
```diff
@@ -503,6 +516,7 @@ def to_hf_training_args(self) -> "TrainingArguments":
             warmup_steps=self.warmup_steps if self.warmup_steps > 0 else 0,
             warmup_ratio=self.warmup_proportion if self.warmup_proportion > 0 else 0,
             num_train_epochs=self.num_train_epochs,
+            max_grad_norm=self.max_grad_norm,
             max_steps=self.total_steps,
             save_total_limit=1,
             remove_unused_columns=False,
```
```diff
@@ -511,6 +525,7 @@ def to_hf_training_args(self) -> "TrainingArguments":
             save_steps=self.save_every,
             eval_steps=self.eval_every,
             ddp_find_unused_parameters=False,
+            eval_on_start=self.eval_before_training,
         )
```

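Both added keywords are standard `TrainingArguments` fields: `max_grad_norm` sets the `Trainer`'s gradient-clipping threshold, and `eval_on_start` (available in transformers 4.42 and later) runs a single evaluation pass before training begins, which is what `eval_before_training` maps onto. A minimal sketch of the mapping, with illustrative values rather than the project's defaults:

```python
from transformers import TrainingArguments

# Assumes transformers >= 4.42 (where eval_on_start was introduced).
hf_args = TrainingArguments(
    output_dir="output",
    max_grad_norm=1.0,   # gradient clipping, now forwarded explicitly
    eval_on_start=True,  # one evaluation pass before the first step
    eval_steps=100,
    save_steps=100,
    save_total_limit=1,
    remove_unused_columns=False,
    ddp_find_unused_parameters=False,
)
print(hf_args.max_grad_norm, hf_args.eval_on_start)  # 1.0 True
```
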
mttl/dataloader/ni_metrics.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -21,7 +21,10 @@ def tokenize(self, s):
         return tokens
 
 
-xlingual_tokenizer = GPTTokenizer()
+try:
+    xlingual_tokenizer = GPTTokenizer()
+except:
+    xlingual_tokenizer = None
 
 
 # adapted the following from SQuAD v1.1 evaluation, without removing the articles
```

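Wrapping the module-level `GPTTokenizer()` construction in `try/except` defers the failure (for example, no network access to fetch tokenizer files) from import time to use time, at the cost of callers having to handle `None`. Note that the bare `except:` also swallows `KeyboardInterrupt` and `SystemExit`; `except Exception:` would be the narrower choice. A runnable sketch of the pattern, where `load_tokenizer` is a hypothetical stand-in for the `GPTTokenizer` constructor:

```python
def load_tokenizer():
    """Hypothetical stand-in for GPTTokenizer(); simulate a download failure."""
    raise OSError("can't reach huggingface.co")

try:
    tokenizer = load_tokenizer()
except Exception:  # narrower than the PR's bare `except:`
    tokenizer = None  # module import still succeeds

def xlingual_tokenize(text):
    # Fail loudly at use time instead of import time.
    if tokenizer is None:
        raise RuntimeError("tokenizer unavailable; cannot compute xlingual metrics")
    return tokenizer.tokenize(text)
```
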