Skip to content

Commit af23682

Browse files
committed
Fix reverse KL loss computation when a loss mask is applied
1 parent 58a0604 commit af23682

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

fast_llm/functional/entropy_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _fused_reverse_kl_base(
121121
# Compute loss terms: student_probs * log_ratio, then sum over vocab
122122
# This is equivalent to kl_div(..., log_target=True) but more memory efficient
123123
log_ratio = predicted_log_probability - target_log_probability
124-
per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1)
124+
per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1, keepdim=True)
125125
if group is not None:
126126
all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group)
127127

@@ -130,7 +130,7 @@ def _fused_reverse_kl_base(
130130
else:
131131
# Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)])
132132
# where E_q[log(q/p)] is the expected log ratio under the student distribution
133-
grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * predicted_probability * grad_output
133+
grad = (log_ratio - per_sample_loss) * predicted_probability * grad_output
134134

135135
return per_sample_loss, grad
136136

tests/functional/test_entropy_loss.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,13 @@ def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_maskin
8282
out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch)
8383
out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused)
8484

85-
# TODO: Why is the error so high with loss masking for reverse KL?
8685
_compare_entropy_loss_outputs(
8786
out_fused,
8887
out_torch,
8988
grad_output is not None,
9089
grad_fused,
9190
grad_torch,
92-
loss_min_threshold=2e-4 if entropy_loss_type == EntropyLossType.reverse_kl and loss_masking else 5e-6,
91+
loss_min_threshold=5e-6,
9392
)
9493

9594
if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available():

0 commit comments

Comments (0)