Commit c3e83ad

feat: Apply RoPE embedding with sequence offsets
Signed-off-by: Sudhakar Singh <[email protected]>
1 parent cb8f94e · commit c3e83ad

Showing 2 changed files with 65 additions and 17 deletions.

megatron/core/extensions/transformer_engine.py

Lines changed: 35 additions & 17 deletions
```diff
@@ -1943,47 +1943,65 @@ def fused_apply_rotary_pos_emb(
         freqs: torch.Tensor,
         transpose_output_memory: bool = False,
         interleaved: bool = False,
+        start_positions: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Apply rotary positional embedding to input tensor T in `sbhd` format."""
         if transpose_output_memory:
             warnings.warn(
                 "transpose_output_memory is not supported by TE's fused RoPE and will be ignored."
             )
+
+        conditional_kwargs = {}
+        if is_te_min_version("2.10.0.dev0"):
+            conditional_kwargs["start_positions"] = start_positions
+        else:
+            if start_positions is not None:
+                raise ValueError(
+                    "Only TE >= 2.10.0.dev0 supports offset RoPE application with "
+                    "`start_positions` argument."
+                )
+
         if is_te_min_version("2.3.0"):
-            return apply_rotary_pos_emb(
-                t, freqs, tensor_format="sbhd", interleaved=interleaved, fused=True
-            )
+            conditional_kwargs["interleaved"] = interleaved
         else:
             if interleaved:
                 raise ValueError("Only TE >= 2.3.0 supports interleaved fused RoPE.")
 
-        return apply_rotary_pos_emb(t, freqs, tensor_format="sbhd", fused=True)
+        return apply_rotary_pos_emb(
+            t, freqs, tensor_format="sbhd", fused=True, **conditional_kwargs
+        )
 
     def fused_apply_rotary_pos_emb_thd(
         t: torch.Tensor,
         cu_seqlens: torch.Tensor,
         freqs: torch.Tensor,
         cp_size: int = 1,
         cp_rank: int = 0,
+        start_positions: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Apply rotary positional embedding to input tensor T in `thd` format with CP support.
         """
+        conditional_kwargs = {}
+        if is_te_min_version("2.10.0.dev0"):
+            conditional_kwargs["start_positions"] = start_positions
+        else:
+            if start_positions is not None:
+                raise ValueError(
+                    "Only TE >= 2.10.0.dev0 supports offset RoPE application with "
+                    "`start_positions` argument."
+                )
+
         if is_te_min_version("1.12.0", check_equality=True):
-            return apply_rotary_pos_emb(
-                t,
-                freqs,
-                tensor_format="thd",
-                fused=True,
-                cu_seqlens=cu_seqlens,
-                cp_size=cp_size,
-                cp_rank=cp_rank,
-            )
+            conditional_kwargs["cp_size"] = cp_size
+            conditional_kwargs["cp_rank"] = cp_rank
         else:
-            assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP."
-            return apply_rotary_pos_emb(
-                t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens
-            )
+            if cp_size > 1:
+                raise ValueError("Only TE >= 1.12.0 supports CP RoPE application for THD format.")
+
+        return apply_rotary_pos_emb(
+            t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens, **conditional_kwargs
+        )
 
 except ImportError:
     pass
```
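
The net effect of this hunk: both fused RoPE entry points now accept an optional `start_positions` tensor and forward it to TE's `apply_rotary_pos_emb` when the installed Transformer Engine is at least 2.10.0.dev0; on older TE versions, passing `start_positions` raises a `ValueError`. Below is a minimal usage sketch for the `sbhd` path. It is not part of the commit: the shapes, the CUDA device, and the reading of `start_positions` as one integer offset per batch sequence are assumptions for illustration.

```python
# Hypothetical usage sketch (not from this commit): fused RoPE with per-sequence
# offsets. Assumes TE >= 2.10.0.dev0, a CUDA device, and `start_positions`
# holding one starting position index per sequence in the batch.
import torch

from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb

seq_len, batch, heads, dim = 64, 2, 4, 8
t = torch.randn(seq_len, batch, heads, dim, device="cuda")  # input in `sbhd` layout
freqs = torch.randn(seq_len, 1, 1, dim, device="cuda")      # rotary frequencies
start_positions = torch.tensor([0, 16], device="cuda")      # assumed: one offset per sequence

out = fused_apply_rotary_pos_emb(t, freqs, start_positions=start_positions)
print(out.shape)  # torch.Size([64, 2, 4, 8])
```

With a Transformer Engine older than 2.10.0.dev0, the same call raises the `ValueError` added in the diff above.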

tests/unit_tests/transformer/test_rope.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -2,6 +2,8 @@
 
 import pytest
 import torch
+from packaging.version import Version as PkgVersion
+from pytest_mock import mocker
 
 from megatron.core.models.common.embeddings import apply_rotary_pos_emb
 from megatron.core.models.common.embeddings.rotary_pos_embedding import (
@@ -10,6 +12,10 @@
 )
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.extensions.transformer_engine import (
+    fused_apply_rotary_pos_emb,
+    fused_apply_rotary_pos_emb_thd,
+)
 
 try:
     from transformer_engine.pytorch.attention.rope import apply_fused_qkv_rotary_pos_emb
@@ -94,6 +100,30 @@ def test_cpu_forward(self):
         assert output.dtype == torch.float32
         assert output.device.type == 'cuda'
 
+    @pytest.mark.internal
+    def test_transformer_engine_version_less_than_2_10(self, mocker):
+        with pytest.raises(Exception) as exc_info:
+            mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("2.9"))
+            t = torch.randn(64, 1, 1, 8)
+            freqs = torch.randn(64, 1, 1, 8)
+            fused_apply_rotary_pos_emb(t, freqs, start_positions=torch.tensor([0, 1, 2, 3]))
+
+        assert str(exc_info.value) == (
+            "Only TE >= 2.10.0.dev0 supports offset RoPE application with "
+            "`start_positions` argument."
+        )
+
+        with pytest.raises(Exception) as exc_info_thd:
+            mocker.patch("megatron.core.utils.get_te_version", return_value=PkgVersion("2.9"))
+            t = torch.randn(64, 1, 8)
+            freqs = torch.randn(64, 1, 1, 8)
+            cu_seqlens = torch.tensor([0, 64])
+            fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=torch.tensor([0,]))
+
+        assert str(exc_info_thd.value) == (
+            "Only TE >= 2.10.0.dev0 supports offset RoPE application with "
+            "`start_positions` argument."
+        )
 
 class TestQKVRotaryEmbedding:
     def setup_method(self):
```
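
The new test pins `get_te_version` to 2.9 via `mocker.patch` and asserts that both entry points reject `start_positions` with the error message introduced above. For completeness, here is a hedged sketch of calling the packed (`thd`) variant when TE is new enough; the two-sequence packing, tensor shapes, dtypes, and offset values are illustrative assumptions rather than code from the commit.

```python
# Hypothetical sketch (not from this commit) of the `thd` variant with offsets.
# Assumes TE >= 2.10.0.dev0 and a CUDA device. Two sequences of lengths 40 and 24
# are packed into one 64-token tensor; cu_seqlens marks their boundaries and
# start_positions (assumed: one entry per packed sequence) their RoPE offsets.
import torch

from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb_thd

total_tokens, heads, dim = 64, 4, 8
t = torch.randn(total_tokens, heads, dim, device="cuda")                   # input in `thd` layout
freqs = torch.randn(total_tokens, 1, 1, dim, device="cuda")                # rotary frequencies
cu_seqlens = torch.tensor([0, 40, 64], dtype=torch.int32, device="cuda")   # packed-sequence boundaries
start_positions = torch.tensor([0, 8], device="cuda")                      # assumed: one offset per sequence

out = fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs, start_positions=start_positions)
print(out.shape)  # torch.Size([64, 4, 8])
```

The added test can be run on its own with `pytest tests/unit_tests/transformer/test_rope.py -k test_transformer_engine_version_less_than_2_10`.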
