@@ -1943,47 +1943,65 @@ def fused_apply_rotary_pos_emb(
19431943 freqs : torch .Tensor ,
19441944 transpose_output_memory : bool = False ,
19451945 interleaved : bool = False ,
1946+ start_positions : Optional [torch .Tensor ] = None ,
19461947 ) -> torch .Tensor :
19471948 """Apply rotary positional embedding to input tensor T in `sbhd` format."""
19481949 if transpose_output_memory :
19491950 warnings .warn (
19501951 "transpose_output_memory is not supported by TE's fused RoPE and will be ignored."
19511952 )
1953+
1954+ conditional_kwargs = {}
1955+ if is_te_min_version ("2.10.0.dev0" ):
1956+ conditional_kwargs ["start_positions" ] = start_positions
1957+ else :
1958+ if start_positions is not None :
1959+ raise ValueError (
1960+ "Only TE >= 2.10.0.dev0 supports offset RoPE application with "
1961+ "`start_positions` argument."
1962+ )
1963+
19521964 if is_te_min_version ("2.3.0" ):
1953- return apply_rotary_pos_emb (
1954- t , freqs , tensor_format = "sbhd" , interleaved = interleaved , fused = True
1955- )
1965+ conditional_kwargs ["interleaved" ] = interleaved
19561966 else :
19571967 if interleaved :
19581968 raise ValueError ("Only TE >= 2.3.0 supports interleaved fused RoPE." )
19591969
1960- return apply_rotary_pos_emb (t , freqs , tensor_format = "sbhd" , fused = True )
1970+ return apply_rotary_pos_emb (
1971+ t , freqs , tensor_format = "sbhd" , fused = True , ** conditional_kwargs
1972+ )
19611973
def fused_apply_rotary_pos_emb_thd(
    t: torch.Tensor,
    cu_seqlens: torch.Tensor,
    freqs: torch.Tensor,
    cp_size: int = 1,
    cp_rank: int = 0,
    start_positions: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding to input tensor T in `thd` format with CP support.

    Thin wrapper around TE's `apply_rotary_pos_emb` (called with
    `tensor_format="thd"` and `fused=True`) that only forwards the keyword
    arguments the installed Transformer Engine version is known to accept.

    Args:
        t: Input tensor, passed through as the first positional argument.
        cu_seqlens: Forwarded unchanged as the `cu_seqlens` keyword.
        freqs: Rotary frequencies, passed through as the second positional argument.
        cp_size: Context-parallel size. Forwarded only when TE >= 1.12.0.
        cp_rank: Context-parallel rank. Forwarded only when TE >= 1.12.0;
            silently ignored on older TE.
        start_positions: Optional tensor used for offset RoPE application.
            Forwarded (including when it is None) only when TE >= 2.10.0.dev0.
            NOTE(review): expected shape/dtype are defined by TE - confirm there.

    Returns:
        Whatever `apply_rotary_pos_emb` returns for the given inputs.

    Raises:
        ValueError: If `start_positions` is given but TE < 2.10.0.dev0, or if
            `cp_size > 1` but TE < 1.12.0.
    """
    # Collect version-gated keyword arguments so that a single call site can
    # serve every supported TE version.
    conditional_kwargs = {}

    if is_te_min_version("2.10.0.dev0"):
        conditional_kwargs["start_positions"] = start_positions
    elif start_positions is not None:
        # Fail loudly instead of silently applying RoPE without the offset.
        raise ValueError(
            "Only TE >= 2.10.0.dev0 supports offset RoPE application with "
            "`start_positions` argument."
        )

    if is_te_min_version("1.12.0", check_equality=True):
        conditional_kwargs["cp_size"] = cp_size
        conditional_kwargs["cp_rank"] = cp_rank
    elif cp_size > 1:
        # ValueError rather than assert: asserts are stripped under `python -O`.
        raise ValueError("Only TE >= 1.12.0 supports CP RoPE application for THD format.")

    return apply_rotary_pos_emb(
        t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens, **conditional_kwargs
    )
19872005
19882006except ImportError :
19892007 pass
0 commit comments