
Commit f0461ab

apply autoformatter.sh checks
Signed-off-by: Sudhakar Singh <[email protected]>
1 parent: c3e83ad

File tree: 1 file changed (+6, -5 lines)


tests/unit_tests/transformer/test_rope.py

Lines changed: 6 additions & 5 deletions
@@ -5,17 +5,17 @@
 from packaging.version import Version as PkgVersion
 from pytest_mock import mocker
 
+from megatron.core.extensions.transformer_engine import (
+    fused_apply_rotary_pos_emb,
+    fused_apply_rotary_pos_emb_thd,
+)
 from megatron.core.models.common.embeddings import apply_rotary_pos_emb
 from megatron.core.models.common.embeddings.rotary_pos_embedding import (
     MultimodalRotaryEmbedding,
     RotaryEmbedding,
 )
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.extensions.transformer_engine import (
-    fused_apply_rotary_pos_emb,
-    fused_apply_rotary_pos_emb_thd,
-)
 
 try:
     from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb
@@ -118,13 +118,14 @@ def test_transformer_engine_version_less_than_2_10(self, mocker):
         t = torch.randn(64, 1, 8)
         freqs = torch.randn(64, 1, 1, 8)
         cu_seqlens = torch.tensor([0, 64])
-        fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0,]))
+        fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0]))
 
         assert str(exc_info_thd.value) == (
             "Only TE >= 2.10.0.dev0 supports offset RoPE application with "
             "`start_positions` argument."
         )
 
+
 class TestQKVRotaryEmbedding:
     def setup_method(self):
         Utils.initialize_model_parallel(1, 1)
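
For context, the call this test exercises can be sketched in isolation. The snippet below is not part of the commit; it is a minimal sketch that reuses the exact import and tensor shapes from the diff above, and it assumes an environment with Megatron-Core, Transformer Engine >= 2.10.0.dev0, and a CUDA device available (the test itself mocks the version check and asserts the error message rather than running the fused kernel).

import torch

from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb_thd

# Shapes copied from the test: one packed sequence of 64 tokens in thd
# layout, i.e. (total_tokens, heads, head_dim).
t = torch.randn(64, 1, 8)
freqs = torch.randn(64, 1, 1, 8)
cu_seqlens = torch.tensor([0, 64])  # cumulative sequence lengths of the packed batch

# start_positions applies a per-sequence RoPE offset; [0] means no offset.
# The autoformatter fix lands exactly on this literal: torch.tensor([0,])
# became torch.tensor([0]) (same tensor, cleaner spelling).
out = fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0]))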
