Skip to content

Commit c972753

Browse files
Refactor LM losses v2 (#449)
Co-authored-by: oleksost <[email protected]>
1 parent c208b26 commit c972753

File tree

30 files changed

+1551
-1445
lines changed

30 files changed

+1551
-1445
lines changed

fast_llm/core/distributed.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,12 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name:
7272
)
7373

7474

75-
def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None:
75+
def safe_barrier(
76+
group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None
77+
) -> None:
7678
if group:
7779
hashed = hash(value) % 2**32
78-
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout)
80+
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device)
7981
if out != hashed * group.size():
8082
raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})")
8183

@@ -86,9 +88,10 @@ def allreduce_scalar(
8688
group: torch.distributed.ProcessGroup | None = None,
8789
op=ReduceOp.SUM,
8890
timeout: float | None = None,
91+
device: torch.device | None = None,
8992
) -> float | int:
9093
if group:
91-
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
94+
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device)
9295
with set_timeout(group, timeout):
9396
torch.distributed.all_reduce(value, op=op, group=group)
9497
return value.item()

fast_llm/engine/schedule/runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,9 @@ def _preprocess_data(
327327
self, context: BatchContext, data_iterator: typing.Iterator, preprocessed: bool
328328
) -> typing.Generator[None, None, None]:
329329
batch_config = context.schedule.batch_config
330-
grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) / batch_config.num_inputs
330+
grad_output = (
331+
self._optimizer.grad_scale / batch_config.num_inputs if context.schedule.phase.is_training else None
332+
)
331333
for micro_batch in range(batch_config.sequential_micro_batches):
332334
micro_batch_data = next(data_iterator)
333335
if not preprocessed:

fast_llm/functional/autograd.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,14 @@ def call(*args, **kwargs):
6060

6161
def grad_is_context(grad_output: torch.Tensor, context: torch.Tensor) -> torch.Tensor: # noqa
6262
return context
63+
64+
65+
class AuxiliaryLoss(torch.autograd.Function):
66+
@staticmethod
67+
def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa
68+
ctx.grad = torch.full_like(aux_loss, grad)
69+
return input_
70+
71+
@staticmethod
72+
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa
73+
return grad_output, ctx.grad, None

fast_llm/functional/config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,17 @@ def _set_activation_fn_map() -> None:
9393
MAX_DROPLESS_BLOCK_SIZE_ROW = 128
9494

9595

96-
class CrossEntropyImpl(str, enum.Enum):
96+
class EntropyLossImplementation(enum.StrEnum):
9797
auto = "auto"
9898
torch = "torch"
9999
fused = "fused"
100100
triton = "triton"
101101

102102

103-
class DistillationLossImpl(str, enum.Enum):
104-
reverse_kl = "reverse_kl"
103+
class EntropyLossType(enum.StrEnum):
105104
cross_entropy = "cross_entropy"
105+
forward_kl = "forward_kl"
106+
reverse_kl = "reverse_kl"
106107

107108

108109
class TargetFormat(enum.StrEnum):

0 commit comments

Comments (0)