@@ -122,8 +122,10 @@ def was_default(self, key):
     @classmethod
     def process_kwargs(cls, kwargs, eval=True, raise_error=True, silent=False):
         overwrites_log = []
+        kwargs_keys = list(kwargs.keys())
 
-        for k, v in kwargs.items():
+        for k in kwargs_keys:
+            v = kwargs[k]
             if eval:
                 try:
                     v = ast.literal_eval(v)
@@ -132,8 +134,16 @@ def process_kwargs(cls, kwargs, eval=True, raise_error=True, silent=False):
             else:
                 v = v
 
-            if not hasattr(cls, k) and raise_error:
-                raise ValueError(f"{k} is not in the config")
+            if not hasattr(cls, k):
+                if raise_error:
+                    raise ValueError(f"{k} is not in the config")
+                else:
+                    # if the key is not found, but we are fault tolerant, we remove it
+                    logger.warning(
+                        f"{k} is not in the config, skipping it, make sure this is intended!"
+                    )
+                    del kwargs[k]
+                    continue
 
             if eval and not silent:
                 overwrites_log.append(f"Overwriting {k} to {v}")
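The key detail in this hunk is the snapshot `kwargs_keys = list(kwargs.keys())`: deleting from a dict while iterating `kwargs.items()` directly raises `RuntimeError: dictionary changed size during iteration`. Below is a minimal runnable sketch of the new control flow, with a hypothetical `KnownConfig` standing in for `cls` and a simplified `except` branch (the real handler, and whatever happens to `v` afterwards, are outside this hunk):

```python
import ast
import logging

logging.basicConfig()
logger = logging.getLogger(__name__)


class KnownConfig:
    # hypothetical stand-in for `cls`: only these names count as config keys
    micro_batch_size = None
    seed = 42


def process_kwargs(cls, kwargs, eval=True, raise_error=True, silent=False):
    overwrites_log = []
    # snapshot the keys so `del kwargs[k]` below cannot invalidate the iterator
    kwargs_keys = list(kwargs.keys())

    for k in kwargs_keys:
        v = kwargs[k]
        if eval:
            try:
                v = ast.literal_eval(v)
            except (ValueError, SyntaxError):
                pass  # simplified: keep the raw string when it is not a literal

        if not hasattr(cls, k):
            if raise_error:
                raise ValueError(f"{k} is not in the config")
            # fault-tolerant path: warn, drop the unknown key, move on
            logger.warning(f"{k} is not in the config, skipping it!")
            del kwargs[k]
            continue

        if eval and not silent:
            overwrites_log.append(f"Overwriting {k} to {v}")

    return overwrites_log


print(process_kwargs(KnownConfig, {"micro_batch_size": "4", "typo": "x"},
                     raise_error=False))
# logs a warning for "typo" and returns ['Overwriting micro_batch_size to 4']
```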
@@ -345,7 +355,7 @@ class TrainingArgs(DataArgs):
     expert_name: str = None
 
     # Training config
-    micro_batch_size: str = None
+    micro_batch_size: int = None
     compute_strategy: str = None
     scheduler: str = "linear_decay_with_warmup"
     checkpoint: str = None  # load from checkpoint
@@ -371,7 +381,7 @@ class TrainingArgs(DataArgs):
     save_every: int = None
     save_each_epoch: bool = False
     eval_every: int = None
-    eval_every_n_epoch: int = 1
+    eval_every_n_epoch: int = None
     seed: int = 42
     debug: bool = False
 
@@ -484,6 +494,9 @@ def __post_init__(self):
     def to_hf_training_args(self) -> "TrainingArguments":
         from transformers import TrainingArguments
 
+        # NOTE: unclear how `warmup_steps` and `warmup_ratio` are used in HF args
+        # given that we build the optimizer and scheduler ourselves
+
         return TrainingArguments(
             run_name=self.wandb_run_name
             or self.expert_name
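Context for the NOTE above: HF's `Trainer` only consumes `warmup_steps`/`warmup_ratio` when it builds the scheduler itself; if a prebuilt `(optimizer, scheduler)` pair is passed via the `optimizers` argument, those fields are never read. A small runnable sketch of the assumed self-built setup (the toy `nn.Linear` and step counts are illustrative, not this repo's objects):

```python
import torch
from transformers import get_scheduler

# toy parameters just to exercise the schedule
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# roughly what "linear_decay_with_warmup" maps to: linear warmup, then linear decay
scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=10, num_training_steps=100
)

for _ in range(20):
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())  # warmup is over; lr is now decaying linearly
# When such a pair is handed to Trainer(..., optimizers=(optimizer, scheduler)),
# Trainer skips create_optimizer/create_scheduler, so the warmup fields in
# TrainingArguments indeed go unused.
```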
@@ -503,6 +516,7 @@ def to_hf_training_args(self) -> "TrainingArguments":
             warmup_steps=self.warmup_steps if self.warmup_steps > 0 else 0,
             warmup_ratio=self.warmup_proportion if self.warmup_proportion > 0 else 0,
             num_train_epochs=self.num_train_epochs,
+            max_grad_norm=self.max_grad_norm,
             max_steps=self.total_steps,
             save_total_limit=1,
             remove_unused_columns=False,
@@ -511,6 +525,7 @@ def to_hf_training_args(self) -> "TrainingArguments":
             save_steps=self.save_every,
             eval_steps=self.eval_every,
             ddp_find_unused_parameters=False,
+            eval_on_start=self.eval_before_training,
         )
 
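Both new pass-throughs map onto standard `TrainingArguments` fields; `eval_on_start` is the newer of the two and needs a recent `transformers` release. A minimal sketch (values are illustrative):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    max_grad_norm=1.0,     # gradient-clipping threshold forwarded from the config
    eval_strategy="steps",
    eval_steps=50,
    eval_on_start=True,    # run one evaluation pass before the first training step
)
print(args.max_grad_norm, args.eval_on_start)  # 1.0 True
```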
516531