
Commit f424d2c

Thin-down supernets
1 parent 5abc6ef commit f424d2c

8 files changed

Lines changed: 876 additions & 118 deletions

fast_llm_external_models/apriel2/examples/placements/all_attn_every2_gdn.yaml
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# Placement config: all-attention and every-2-gdn
#
# Two placements for the layer placement study:
#   1. all-attention: all 24 layers use full attention
#   2. every-2-gdn: every second layer is GDN (layers 1, 3, 5, ...), the rest are attention
#
# Usage:
#   python thin_supernet.py ~/.cache/huggingface/apriel2-0.5b-dev \
#       ~/.cache/huggingface/apriel2-0.5b-thinned-all-attn-every2gdn \
#       -p examples/placements/all_attn_every2_gdn.yaml

placements:
  # All attention (24 layers)
  - [attention, attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention, attention]
  # Every 2nd layer is GDN: attn, gdn, attn, gdn, ...
  - [attention, gdn, attention, gdn, attention, gdn, attention, gdn,
     attention, gdn, attention, gdn, attention, gdn, attention, gdn,
     attention, gdn, attention, gdn, attention, gdn, attention, gdn]
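(For intuition: thinning keeps, at each layer, only the union of mixers that at least one placement selects there; the commit's compute_required_mixers_per_layer is tested below to return exactly this. A minimal sketch of that union for the two placements above, with hand-built Python lists standing in for the YAML:)

# Sketch of the per-layer union behind thinning (hand-built lists, not the real loader).
all_attn = ["attention"] * 24
every2_gdn = ["attention" if i % 2 == 0 else "gdn" for i in range(24)]

required = [set(pair) for pair in zip(all_attn, every2_gdn)]
assert required[0] == {"attention"}            # even layers keep only attention
assert required[1] == {"attention", "gdn"}     # odd layers must keep both mixers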
fast_llm_external_models/apriel2/examples/placements/all_attn_every2_kda_swa.yaml
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# Placement config: all-attention, all-sliding_window, every2nd-kda
#
# Three placements:
#   1. all-attention: full attention (baseline)
#   2. all-sliding_window: sliding window attention
#      (O(n*w) vs O(n²) compute; uses the same FlashAttention backend with a window mask)
#   3. every2nd-kda: alternating attention and KDA (Key-Dependent Attention)
#
# Usage:
#   python thin_supernet.py ~/.cache/huggingface/apriel2-0.5b-dev \
#       ~/.cache/huggingface/apriel2-0.5b-thinned-attn-kda-swa \
#       -p examples/placements/all_attn_every2_kda_swa.yaml

placements:
  # All attention (24 layers) - baseline
  - [attention, attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention, attention]
  # All sliding window - expected faster (limited context, same FlashAttention backend)
  - [sliding_window, sliding_window, sliding_window, sliding_window,
     sliding_window, sliding_window, sliding_window, sliding_window,
     sliding_window, sliding_window, sliding_window, sliding_window,
     sliding_window, sliding_window, sliding_window, sliding_window,
     sliding_window, sliding_window, sliding_window, sliding_window,
     sliding_window, sliding_window, sliding_window, sliding_window]
  # Every 2nd layer is KDA: attn, kda, attn, kda, ...
  - [attention, kda, attention, kda, attention, kda, attention, kda,
     attention, kda, attention, kda, attention, kda, attention, kda,
     attention, kda, attention, kda, attention, kda, attention, kda]
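(To make the O(n*w) vs O(n²) claim concrete, a rough count of attention-score entries per layer. The numbers are illustrative; the window size of 4096 is taken from the supernet config used in the tests below:)

# Rough score-matrix entry counts per layer (illustrative only).
n, w = 16384, 4096       # sequence length (arbitrary); window size (assumption: 4096 as in the test config)
full = n * n             # O(n^2) full attention: ~268M entries
sliding = n * w          # O(n*w) sliding window: ~67M entries
print(full // sliding)   # 4x fewer entries at this length; the gap grows linearly with n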
fast_llm_external_models/apriel2/examples/placements/budget_attention_heavy.yaml
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
# Placement config: attention-heavy budget
#
# Uses mostly full attention with a few efficient (gdn) layers in the tail or middle.
# For layer placement studies: high quality, moderate compute budget.
#
# Usage:
#   python thin_supernet.py ~/.cache/huggingface/apriel2-0.5b-dev output/ \
#       -p examples/placements/budget_attention_heavy.yaml

placements:
  # 24 layers: attention for the first 20, gdn for the last 4 (efficient tail)
  - [attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention,
     gdn, gdn, gdn, gdn]
  # Alternative: gdn in the middle third
  - [attention, attention, attention, attention, attention, attention, attention, attention,
     gdn, gdn, gdn, gdn, gdn, gdn, gdn, gdn,
     attention, attention, attention, attention, attention, attention, attention, attention]
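(Hand-writing 24-element lists is error-prone; a small hypothetical generator, not part of this commit, that emits the same two placements as YAML:)

import yaml

# Hypothetical generator for the two budget placements above (not part of this commit).
attention_heavy_tail = ["attention"] * 20 + ["gdn"] * 4
gdn_middle_third = ["attention"] * 8 + ["gdn"] * 8 + ["attention"] * 8

print(yaml.safe_dump({"placements": [attention_heavy_tail, gdn_middle_third]}))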
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# Single placement: all-attention
# Produces a fixed model with no stochastic mixer; enables FULL cudagraph in vLLM.
placements:
  - [attention, attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention, attention,
     attention, attention, attention, attention, attention, attention]
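(Since the single-placement output is a plain fixed-mixer checkpoint, it should load like any other HF-format model. A sketch with a hypothetical output path, assuming the Apriel2 model code ships alongside the weights and hence needs trust_remote_code:)

from pathlib import Path
from transformers import AutoModelForCausalLM

# Sketch: load a thinned single-placement checkpoint (output path is hypothetical).
path = Path.home() / ".cache" / "huggingface" / "apriel2-0.5b-thinned-all-attn"
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)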
fast_llm_external_models/apriel2/test_thin_supernet.py
Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
#!/usr/bin/env python3
"""Test thin_supernet conversion.

Runs unit tests and an integration test. Output goes to /tmp (cleared on reboot).
Supports both pytest and direct execution.

Usage:
    pytest fast_llm_external_models/apriel2/test_thin_supernet.py -v
    python fast_llm_external_models/apriel2/test_thin_supernet.py

Requires a checkpoint at ~/.cache/huggingface/apriel2-0.5b-dev for the integration test.
"""

from __future__ import annotations

import json
import sys
import tempfile
from pathlib import Path

import yaml

# Add project root to sys.path before the project import, so direct execution
# works from any working directory.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from fast_llm_external_models.apriel2.thin_supernet import (
    build_thin_surgery_config,
    compute_required_mixers_per_layer,
    load_placement_config,
    thin_supernet,
)

# Try pytest for fixture support
try:
    import pytest

    HAS_PYTEST = True
except ImportError:
    HAS_PYTEST = False


# =============================================================================
# Constants
# =============================================================================

APRIEL2_05B_DEV = Path.home() / ".cache" / "huggingface" / "apriel2-0.5b-dev"
OUTPUT_DIR_TMP = Path("/tmp") / "apriel2-thin-supernet-test"


# =============================================================================
# Unit Tests
# =============================================================================


def test_load_placement_config_placements_format(tmp_path=None):
    """Test loading the placements list format."""
    tmp_path = Path(tmp_path) if tmp_path is not None else Path(tempfile.mkdtemp())
    config = {
        "placements": [
            ["attention", "gdn", "attention"],
            ["sliding_window", "sliding_window", "attention"],
        ]
    }
    path = tmp_path / "placements.yaml"
    with open(path, "w") as f:
        yaml.dump(config, f)

    result = load_placement_config(path)
    assert len(result) == 2
    assert result[0] == ["attention", "gdn", "attention"]
    assert result[1] == ["sliding_window", "sliding_window", "attention"]


def test_load_placement_config_layers_format(tmp_path=None):
    """Test loading a single placement via the layers format."""
    tmp_path = Path(tmp_path) if tmp_path is not None else Path(tempfile.mkdtemp())
    config = {"layers": ["attention", "gdn", "attention", "kda"]}
    path = tmp_path / "layers.yaml"
    with open(path, "w") as f:
        yaml.dump(config, f)

    result = load_placement_config(path)
    assert len(result) == 1
    assert result[0] == ["attention", "gdn", "attention", "kda"]


def test_compute_required_mixers_per_layer():
    """Test the union of mixers per layer across placements."""
    placements = [
        ["attention", "gdn", "attention"],
        ["sliding_window", "gdn", "kda"],
        ["attention", "attention", "attention"],
    ]
    result = compute_required_mixers_per_layer(placements, num_layers=3)
    assert result[0] == {"attention", "sliding_window"}
    assert result[1] == {"gdn", "attention"}
    assert result[2] == {"attention", "kda"}


def _get_supernet_config():
    """Minimal supernet config for unit tests."""
    return {
        "model_type": "apriel2_text",
        "hidden_size": 896,
        "decoder": {
            "type": "fixed",
            "num_blocks": 24,
            "block": {
                "mixer": {
                    "type": "stochastic",
                    "main_mixer_name": "attention",
                    "mixers": {
                        "attention": {"type": "attention", "heads": 14},
                        "sliding_window": {"type": "attention", "window_size": 4096},
                        "gdn": {"type": "gdn", "convolution_layer": {"kernel_size": 4}},
                        "kda": {"type": "kda", "convolution_layer": {"kernel_size": 4}},
                    },
                },
                "mlp": {"type": "mlp"},
                "normalization": {"type": "rms_norm"},
            },
        },
    }


def test_build_thin_surgery_config_all_same():
    """Test the surgery config when all layers need the same mixers."""
    supernet_config = _get_supernet_config()
    required = [{"attention", "gdn"}] * 24
    surgery = build_thin_surgery_config(supernet_config, required)
    assert surgery["decoder"]["type"] == "fixed"
    mixers = surgery["decoder"]["block"]["mixer"]["mixers"]
    assert set(mixers.keys()) == {"attention", "gdn"}
    assert mixers["attention"]["init"] == "transfer"
    assert mixers["gdn"]["init"] == "transfer"


def test_build_thin_surgery_config_per_layer():
    """Test the surgery config when layers need different mixers."""
    supernet_config = _get_supernet_config()
    required = [{"attention"}] * 8 + [{"attention", "gdn"}] * 8 + [{"gdn"}] * 8
    surgery = build_thin_surgery_config(supernet_config, required)
    assert surgery["decoder"]["type"] == "pattern"
    assert len(surgery["decoder"]["pattern"]) == 24
    assert set(surgery["decoder"]["blocks"]["layer_0"]["mixer"]["mixers"].keys()) == {"attention"}
    assert set(surgery["decoder"]["blocks"]["layer_8"]["mixer"]["mixers"].keys()) == {"attention", "gdn"}
    assert set(surgery["decoder"]["blocks"]["layer_16"]["mixer"]["mixers"].keys()) == {"gdn"}


# =============================================================================
# Integration Test
# =============================================================================


def test_thin_supernet_integration(tmp_path):
    """Integration test: full thin_supernet conversion (requires checkpoint)."""
    if not APRIEL2_05B_DEV.exists():
        pytest.skip(f"Checkpoint not found: {APRIEL2_05B_DEV}")
    output_dir = tmp_path / "thinned"
    run_integration_test(output_dir)


def run_integration_test(output_dir: Path) -> bool:
    """Run the full thin_supernet conversion. Returns True on success."""
    if not APRIEL2_05B_DEV.exists():
        print(f"SKIP: Checkpoint not found: {APRIEL2_05B_DEV}")
        return False

    placements_dir = Path(__file__).parent / "examples" / "placements"
    placement_config = placements_dir / "budget_attention_heavy.yaml"
    if not placement_config.exists():
        print(f"SKIP: Placement config not found: {placement_config}")
        return False

    print("=" * 60)
    print("Thin Supernet Integration Test")
    print("=" * 60)
    print(f"Input: {APRIEL2_05B_DEV}")
    print(f"Output: {output_dir}")
    print(f"Placements: {placement_config}")
    print("=" * 60)

    # Dry run
    print("\n--- Dry run ---")
    thin_supernet(
        input_dir=APRIEL2_05B_DEV,
        output_dir=output_dir,
        placement_configs=[placement_config],
        dry_run=True,
        verbose=True,
    )

    # Full run
    print("\n--- Full conversion ---")
    thin_supernet(
        input_dir=APRIEL2_05B_DEV,
        output_dir=output_dir,
        placement_configs=[placement_config],
        dry_run=False,
        verbose=True,
    )

    # Verify
    print("\n--- Verification ---")
    config_path = output_dir / "config.json"
    assert config_path.exists(), "config.json not created"
    with open(config_path) as f:
        saved = json.load(f)

    decoder = saved.get("decoder", {})
    mixer = decoder.get("block", {}).get("mixer") or decoder.get("blocks", {}).get("layer_0", {}).get("mixer", {})
    mixers = mixer.get("mixers", {})
    print(f"Decoder type: {decoder.get('type')}")
    print(f"Mixers retained: {list(mixers.keys())}")

    safetensors = list(output_dir.glob("*.safetensors"))
    print(f"Safetensor files: {len(safetensors)}")

    assert (output_dir / "tokenizer.json").exists() or (output_dir / "tokenizer_config.json").exists()
    print("\n" + "=" * 60)
    print("Test PASSED")
    print("=" * 60)
    return True


# =============================================================================
# Main
# =============================================================================


def main():
    """Run unit tests and the integration test."""
    # Unit tests (no checkpoint needed)
    print("Running unit tests...")
    test_load_placement_config_placements_format()
    test_load_placement_config_layers_format()
    test_compute_required_mixers_per_layer()
    test_build_thin_surgery_config_all_same()
    test_build_thin_surgery_config_per_layer()
    print("Unit tests OK\n")

    # Integration test (uses /tmp, requires checkpoint)
    success = run_integration_test(OUTPUT_DIR_TMP)
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
