@@ -164,6 +164,7 @@ def _try_translate_n_batch(
         builder = TranslationResultBuilder(input_tokens)
         for token, score in zip(output["translation_tokens"], output["token_scores"]):
             builder.append_token(token, TranslationSources.NMT, exp(score))
+        builder.set_sequence_confidence(exp(output["sequence_score"]))
         word_pairs: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None
         if output.get("token_attentions") is not None:
             src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
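For context on the new `set_sequence_confidence` call: `generate()` works in log space, so `exp()` maps both the per-token scores and the sequence score back to the 0..1 probability range. A minimal sketch of the arithmetic, with made-up per-token values:

```python
from math import exp, log

# Hypothetical per-token log-probabilities, as produced by a log-softmax.
token_logprobs = [log(0.9), log(0.8), log(0.95)]
token_confidences = [exp(s) for s in token_logprobs]  # ~[0.9, 0.8, 0.95]

# With the default length_penalty of 1.0, beam search's sequences_scores is a
# length-normalized sum of token log-probabilities, so exp() of it behaves
# like a geometric mean of the token probabilities.
sequence_score = sum(token_logprobs) / len(token_logprobs)
sequence_confidence = exp(sequence_score)  # ~0.88
```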
@@ -257,36 +258,56 @@ def _forward(self, model_inputs, **generate_kwargs):
             output_ids = output.sequences
             beam_indices = output.beam_indices
             scores = output.scores
+            assert scores is not None and beam_indices is not None
+            sequences_scores = output.sequences_scores
             attentions = output.cross_attentions
         elif isinstance(output, GreedySearchEncoderDecoderOutput):
             output_ids = output.sequences
-            beam_indices = torch.zeros_like(output_ids)
+            beam_indices = None
             assert output.scores is not None
-            scores = tuple(torch.nn.functional.log_softmax(logits, dim=-1) for logits in output.scores)
+            scores = output.scores
+            sequences_scores = None
             attentions = output.cross_attentions
         else:
             raise RuntimeError("Cannot postprocess the output of the model.")

-        assert beam_indices is not None and scores is not None
-        out_b = output_ids.shape[0]
+        transition_scores = cast(
+            torch.Tensor,
+            self.model.compute_transition_scores(
+                output_ids,  # type: ignore
+                scores,  # type: ignore
+                beam_indices,  # type: ignore
+                normalize_logits=True,
+            ),
+        )
+
+        if beam_indices is None:
+            beam_indices = torch.zeros_like(output_ids)
+
+        out_b, seq_len = output_ids.shape
         num_beams = scores[0].shape[0] // in_b
         n_sequences = out_b // in_b
+
+        ts_len = transition_scores.shape[1]
+        if ts_len == seq_len:
+            token_logprobs = transition_scores
+        elif ts_len == seq_len - 1:
+            token_logprobs = torch.cat(
+                [
+                    torch.zeros(out_b, 1, device=transition_scores.device, dtype=transition_scores.dtype),
+                    transition_scores,
+                ],
+                dim=1,
+            )
+        else:
+            raise RuntimeError(
+                f"Unexpected transition_scores length {ts_len} for sequences length {seq_len}. "
+                "Cannot align token scores robustly."
+            )
+
         start_index = 0
         if self.model.config.decoder_start_token_id is not None:
             start_index = 1
-        indices = torch.stack(
-            (
-                torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1),
-                torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
-                torch.reshape(output_ids[:, start_index:], (in_b, n_sequences, -1)),
-            ),
-            dim=3,
-        )
-        scores = torch.stack(scores, dim=0).reshape(len(scores), in_b, num_beams, -1).transpose(0, 1)
-        scores = torch_gather_nd(scores, indices, 1)
-        if self.model.config.decoder_start_token_id is not None:
-            scores = torch.cat((torch.zeros(scores.shape[0], scores.shape[1], 1, device=scores.device), scores), dim=2)
-
         if generate_kwargs["output_attentions"] is True:
             assert attentions is not None
             num_heads = attentions[0][0].shape[1]
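The rewrite above replaces the hand-rolled `torch_gather_nd` score extraction with Hugging Face's `GenerationMixin.compute_transition_scores`, which does the same beam-index bookkeeping internally. A self-contained sketch of that API (the checkpoint name is just an example; any seq2seq model works the same way):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")

inputs = tokenizer(["Hello, world!"], return_tensors="pt")
output = model.generate(
    **inputs,
    num_beams=2,
    return_dict_in_generate=True,
    output_scores=True,
)

# With normalize_logits=True the result is per-token log-probabilities, one
# column per generated step. The decoder start token is not included, which
# is why the code above left-pads with a zeros column when ts_len == seq_len - 1.
transition_scores = model.compute_transition_scores(
    output.sequences, output.scores, output.beam_indices, normalize_logits=True
)
print(transition_scores.shape)  # (num_return_sequences, generated_steps)
```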
@@ -320,13 +341,15 @@ def _forward(self, model_inputs, **generate_kwargs):
                 ),
                 dim=2,
             )
+        output_ids = output_ids.reshape(in_b, n_sequences, seq_len)
+        token_logprobs = token_logprobs.reshape(in_b, n_sequences, seq_len)

-        output_ids = output_ids.reshape(in_b, n_sequences, *output_ids.shape[1:])
         return {
             "input_ids": model_inputs["input_ids"],
             "input_tokens": input_tokens,
             "output_ids": output_ids,
-            "scores": scores,
+            "scores": token_logprobs,
+            "sequences_scores": sequences_scores,
             "attentions": attentions,
         }

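The two reshapes before the return are pure regrouping: `generate()` flattens all hypotheses for all batch items along dim 0, and `(in_b, n_sequences, seq_len)` restores the per-source grouping that `postprocess` indexes with `[0]`. A toy demonstration with made-up sizes:

```python
import torch

# Hypothetical sizes: 2 source sentences, 3 hypotheses each, 5 tokens each.
in_b, n_sequences, seq_len = 2, 3, 5
flat_ids = torch.arange(in_b * n_sequences * seq_len).reshape(in_b * n_sequences, seq_len)

# Group as (batch, hypotheses, tokens): row [b, 0] is the best hypothesis
# for batch item b, which is flat row b * n_sequences.
grouped = flat_ids.reshape(in_b, n_sequences, seq_len)
assert grouped[1, 0].tolist() == flat_ids[n_sequences].tolist()
```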
@@ -346,24 +369,17 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
         records = []

         has_attentions = model_outputs.get("attentions") is not None and model_outputs["attentions"][0] is not None
-        if has_attentions:
-            zipped = zip(
-                model_outputs["output_ids"][0],
-                model_outputs["scores"][0],
-                model_outputs["attentions"][0],
-            )
-        else:
-            zipped = zip(
-                model_outputs["output_ids"][0],
-                model_outputs["scores"][0],
-            )
-
+        has_sequence_scores = model_outputs["sequences_scores"] is not None
+        zipped = zip(
+            model_outputs["output_ids"][0],
+            model_outputs["scores"][0],
+            model_outputs["sequences_scores"] if has_sequence_scores else iter(lambda: None, 1),
+            model_outputs["attentions"][0] if has_attentions else iter(lambda: None, 1),
+        )
         for item in zipped:
-            if has_attentions:
-                output_ids, scores, attentions = cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], item)
-            else:
-                output_ids, scores = cast(Tuple[torch.Tensor, torch.Tensor], item)
-                attentions = None
+            output_ids, scores, sequence_score, attentions = cast(
+                Tuple[torch.Tensor, torch.Tensor, Optional[float], Optional[torch.Tensor]], item
+            )

             output_tokens: List[str] = []
             output_indices: List[int] = []
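`iter(lambda: None, 1)` in the new `zip` is the two-argument form of `iter()`: it calls the callable until it returns the sentinel (`1`), and since `lambda: None` never does, it yields `None` forever. That pads the optional columns without branching, as this small demo shows (`itertools.repeat(None)` would be the more common spelling):

```python
from itertools import islice, repeat

# iter(callable, sentinel) keeps calling the callable until it returns the
# sentinel; `lambda: None` never returns 1, so this is an endless None stream.
nones = iter(lambda: None, 1)
print(list(islice(nones, 3)))  # [None, None, None]

# zip() stops at the shortest iterable, so pairing real data with the endless
# stream fills the missing column with None for every item.
print(list(zip(["a", "b"], iter(lambda: None, 1))))  # [('a', None), ('b', None)]
print(list(zip(["a", "b"], repeat(None))))           # same result, idiomatically
```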
@@ -379,6 +395,7 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
379395 "input_tokens" : input_tokens ,
380396 "translation_tokens" : output_tokens ,
381397 "token_scores" : scores ,
398+ "sequence_score" : sequence_score ,
382399 "translation_text" : self .tokenizer .decode (
383400 output_ids ,
384401 skip_special_tokens = True ,