Skip to content

Commit 17cd7ed

Browse files
authored
Bug fixes (#453)
1 parent a765807 commit 17cd7ed

File tree

6 files changed

+37
-20
lines changed

6 files changed

+37
-20
lines changed

fast_llm/data/dataset/sampled.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def __getitem__(self, index: int) -> SampleType:
367367

368368
document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset
369369

370-
token_count = token_start_array[token_start_cumsum_index]
370+
token_count = token_start_array[token_start_cumsum_index].item()
371371

372372
documents: list[SampleType] = []
373373
while token_count < token_end:

fast_llm_external_models/apriel2/modeling_apriel2.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,7 @@
77

88
import torch
99
import torch.nn.functional as F
10-
from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn
11-
from causal_conv1d import causal_conv1d_update as _causal_conv1d_update
1210
from einops import rearrange, repeat
13-
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
14-
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
1511
from torch import nn
1612
from transformers import GenerationMixin, PreTrainedModel
1713
from transformers.cache_utils import Cache
@@ -52,6 +48,19 @@
5248
fused_recurrent_kda = None
5349
fused_kda_gate = None
5450

51+
52+
try:
53+
from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn
54+
from causal_conv1d import causal_conv1d_update as _causal_conv1d_update
55+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
56+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
57+
except ImportError:
58+
_causal_conv1d_fn = None
59+
_causal_conv1d_update = None
60+
selective_scan_fn = None
61+
selective_state_update = None
62+
63+
5564
is_fast_path_available = is_mamba_ssm_available() and is_causal_conv1d_available()
5665

5766
if is_torch_flex_attn_available():
@@ -489,14 +498,6 @@ class PreprocessingOutput(TypedDict, total=False):
489498
attention_mask: Optional[torch.Tensor]
490499

491500

492-
# Require fast path CUDA kernels - no silent fallback to unoptimized code paths
493-
if not is_fast_path_available:
494-
raise ImportError(
495-
"CausalConv1d and Mamba require CUDA kernels from causal_conv1d and mamba_ssm. "
496-
"Install with: pip install causal-conv1d mamba-ssm"
497-
)
498-
499-
500501
class CausalConv1d(nn.Conv1d):
501502
"""
502503
Causal 1D convolution that pads only on the left side.
@@ -519,6 +520,11 @@ def __init__(
519520
activation: str = "silu",
520521
**kwargs,
521522
):
523+
if not is_fast_path_available:
524+
raise ImportError(
525+
"CausalConv1d requires CUDA kernels from causal_conv1d and mamba_ssm. "
526+
"Install with: pip install causal-conv1d mamba-ssm"
527+
)
522528
# Remove padding from kwargs since we handle it ourselves
523529
kwargs.pop("padding", None)
524530
super().__init__(

fast_llm_external_models/tests/test_apriel2/test_conversion_e2e.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def expand_surgery_chain_with_cycling(
138138
return expanded
139139

140140

141+
@requires_cuda
141142
class TestPlanCompositionTorture:
142143
"""End-to-end torture test for plan composition.
143144

fast_llm_external_models/tests/test_apriel2/test_model_structure.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22

33
import torch
44

5-
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2Cache, _AttentionCache, _SSMCache
6-
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
5+
from fast_llm_external_models.apriel2.modeling_apriel2 import (
6+
Apriel2Cache,
7+
Apriel2ForCausalLM,
8+
_AttentionCache,
9+
_SSMCache,
10+
)
11+
from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
712

813

14+
@requires_cuda
915
class TestStochasticMixerStructure:
1016
"""Validate stochastic mixer architecture matches configuration."""
1117

fast_llm_external_models/tests/test_apriel2/test_modeling.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55

66
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM
7+
from fast_llm_external_models.tests.test_apriel2.conftest import requires_cuda
78

89

910
class TestApriel2Modeling:
@@ -13,9 +14,9 @@ class TestApriel2Modeling:
1314
"config_name",
1415
[
1516
"apriel2_config_tiny",
16-
"apriel2_config_stochastic",
17-
"apriel2_config_multi_mixer",
18-
"apriel2_config_all_mixers", # Tests all 4 mixer types
17+
pytest.param("apriel2_config_stochastic", marks=requires_cuda),
18+
pytest.param("apriel2_config_multi_mixer", marks=requires_cuda),
19+
pytest.param("apriel2_config_all_mixers", marks=requires_cuda), # Tests all 4 mixer types
1920
"apriel2_config_with_bias", # Tests per-layer bias and non-gated MLP
2021
],
2122
)

tests/layers/test_ssm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
from fast_llm.layers.ssm import kda as kda_module
1111
from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig
1212
from fast_llm.utils import Assert
13-
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba, KimiDeltaAttention
1413
from tests.utils.utils import get_stage, requires_cuda
1514

1615
try:
17-
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba
16+
from fast_llm_external_models.apriel2.modeling_apriel2 import (
17+
Apriel2GatedDeltaNet,
18+
Apriel2Mamba,
19+
KimiDeltaAttention,
20+
)
1821
except ImportError:
1922
Apriel2GatedDeltaNet = None
2023
Apriel2Mamba = None

0 commit comments

Comments (0)