5 | 5 | from packaging.version import Version as PkgVersion |
6 | 6 | from pytest_mock import mocker |
7 | 7 |
| 8 | +from megatron.core.extensions.transformer_engine import ( |
| 9 | + fused_apply_rotary_pos_emb, |
| 10 | + fused_apply_rotary_pos_emb_thd, |
| 11 | +) |
8 | 12 | from megatron.core.models.common.embeddings import apply_rotary_pos_emb |
9 | 13 | from megatron.core.models.common.embeddings.rotary_pos_embedding import ( |
10 | 14 | MultimodalRotaryEmbedding, |
11 | 15 | RotaryEmbedding, |
12 | 16 | ) |
13 | 17 | from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed |
14 | 18 | from megatron.core.transformer.transformer_config import TransformerConfig |
15 | | -from megatron.core.extensions.transformer_engine import ( |
16 | | - fused_apply_rotary_pos_emb, |
17 | | - fused_apply_rotary_pos_emb_thd, |
18 | | -) |
19 | 19 |
20 | 20 | try: |
21 | 21 | from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb |
@@ -118,13 +118,14 @@ def test_transformer_engine_version_less_than_2_10(self, mocker): |
118 | 118 | t = torch.randn(64, 1, 8) |
119 | 119 | freqs = torch.randn(64, 1, 1, 8) |
120 | 120 | cu_seqlens = torch.tensor([0, 64]) |
121 | | - fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0,])) |
| 121 | + fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0])) |
122 | 122 |
123 | 123 | assert str(exc_info_thd.value) == ( |
124 | 124 | "Only TE >= 2.10.0.dev0 supports offset RoPE application with " |
125 | 125 | "`start_positions` argument." |
126 | 126 | ) |
127 | 127 |
| 128 | + |
128 | 129 | class TestQKVRotaryEmbedding: |
129 | 130 | def setup_method(self): |
130 | 131 | Utils.initialize_model_parallel(1, 1) |
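
For context, the guarded call above is the offset-RoPE path: `start_positions` shifts the rotary phase so the first token of a sequence is rotated as if it sat at that position rather than at 0. Below is a minimal unfused sketch of the same idea, not the TE fused kernel: the helper names are hypothetical, the layout is sbhd rather than the packed thd layout the test uses, and the frequency construction is assumed to mirror how Megatron's `RotaryEmbedding` concatenates its frequencies.

```python
import torch

def rotate_half(x):
    # split the rotary dim in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_with_offset(t, freqs, start_position=0):
    # slice the precomputed angles so t[0] is rotated as position `start_position`
    angles = freqs[start_position : start_position + t.shape[0]]
    return t * torch.cos(angles) + rotate_half(t) * torch.sin(angles)

# angles built the assumed Megatron way: per-position frequencies duplicated across the dim
seq_len, dim = 128, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)   # [seq, dim/2]
freqs = torch.cat((angles, angles), dim=-1)[:, None, None, :]   # [seq, 1, 1, dim]

t = torch.randn(64, 1, 1, 8)                                    # [seq, batch, head, dim]
out = apply_rope_with_offset(t, freqs, start_position=3)
```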