@@ -21,6 +21,7 @@ public class TikTokenizer : ITokenizer
2121 {
2222
2323 private IReadOnlyDictionary < string , int > SpecialTokensEncoder = null ! ;
24+ private IReadOnlyCollection < string > SpecialTokens = null ! ;
2425 private Regex Regex = null ! ;
2526 private IReadOnlyDictionary < byte [ ] , int > Encoder = null ! ;
2627 private IReadOnlyDictionary < int , byte [ ] > Decoder = null ! ;
@@ -76,6 +77,7 @@ private void Init(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionary<
7677 Regex = new Regex ( pattern , RegexOptions . Compiled ) ;
7778 SpecialTokensRegex = new Regex ( string . Join ( "|" , specialTokensEncoder . Keys . Select ( s => Regex . Escape ( s ) ) ) , RegexOptions . Compiled ) ;
7879 SpecialTokensEncoder = specialTokensEncoder ;
80+ SpecialTokens = specialTokensEncoder . Keys . ToList ( ) ;
7981
8082 Decoder = Encoder . ToDictionary ( kvp => kvp . Value , kvp => kvp . Key ) ;
8183
@@ -136,13 +138,7 @@ private Dictionary<byte[], int> LoadTikTokenBpe(Stream tikTokenBpeFileStream)
136138 return bpeDict ;
137139 }
138140
139- /// <summary>
140- /// Encode a string with a set of allowed special tokens that are not broken apart.
141- /// </summary>
142- /// <param name="text">String to be encoded</param>
143- /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
144- /// <returns>List of token ids</returns>
145- public List < int > Encode ( string text , IReadOnlyCollection < string > allowedSpecial )
141+ private List < int > EncodeInternal ( string text , IReadOnlyCollection < string > allowedSpecial )
146142 {
147143 var tokenIds = new List < int > ( ) ;
148144 int start = 0 ;
@@ -173,6 +169,43 @@ public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
173169 return tokenIds ;
174170 }
175171
/// <summary>
/// Encode a string with a set of allowed special tokens that are not broken apart.
/// </summary>
/// <param name="text">String to be encoded</param>
/// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
/// <returns>List of token ids</returns>
public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
{
    // With no allowed specials there is nothing to protect from splitting,
    // so delegate to the plain-text path with special-token processing off.
    return (allowedSpecial is null || allowedSpecial.Count == 0)
        ? Encode(text, false)
        : EncodeInternal(text, allowedSpecial);
}
186+
/// <summary>
/// Encode a string with or without special tokens set through constructor.
/// </summary>
/// <param name="text">String to be encoded</param>
/// <param name="applySpecialTokens">Whether to apply special token processing</param>
/// <returns>List of token ids</returns>
public List<int> Encode(string text, bool applySpecialTokens = true)
{
    // Route through the special-token-aware path only when requested
    // and the tokenizer was constructed with special tokens.
    if (applySpecialTokens && SpecialTokens.Count > 0)
    {
        return EncodeInternal(text, SpecialTokens);
    }

    // Plain BPE encoding over the whole string, treating any special-token
    // text as ordinary text.
    var result = new List<int>();
    Encode(text, result, 0, text.Length);
    return result;
}
208+
176209 /// <summary>
177210 /// Encode a special token matched through regex.
178211 /// </summary>
@@ -194,7 +227,7 @@ private int EncodeSpecialToken(List<int> tokenIds, Match nextSpecial)
194227 /// <param name="start">Start search index in the string</param>
195228 /// <param name="nextSpecial">The regex match of a special token</param>
196229 /// <param name="end">The index of the special token matched or the end of the text</param>
197- private void FindNextSpecialToken ( string text , IReadOnlyCollection < string > allowedSpecial , int start , out Match nextSpecial , out int end )
230+ private void FindNextSpecialToken ( string text , IReadOnlyCollection < string > ? allowedSpecial , int start , out Match nextSpecial , out int end )
198231 {
199232 int startFind = start ;
200233 while ( true )
@@ -308,14 +341,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
308341 return ( tokenCount , encodeLength ) ;
309342 }
310343
311- /// <summary>
312- /// Encode a piece of text limited by max token count through trimming suffix
313- /// </summary>
314- /// <param name="text">Text to be encoded</param>
315- /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
316- /// <param name="maxTokenCount">The max token count</param>
317- /// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
318- public ( List < int > TokenIds , string Text ) EncodeTrimSuffix ( string text , IReadOnlyCollection < string > allowedSpecial , int maxTokenCount )
344+ private ( List < int > TokenIds , string Text ) EncodeTrimSuffixInternal ( string text , IReadOnlyCollection < string > allowedSpecial , int maxTokenCount )
319345 {
320346 var tokenIds = new List < int > ( ) ;
321347
@@ -367,21 +393,58 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
367393 }
368394
/// <summary>
/// Encode a piece of text limited by max token count through trimming suffix
/// </summary>
/// <param name="text">Text to be encoded</param>
/// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
/// <param name="maxTokenCount">The max token count</param>
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
{
    // An absent/empty allow-set means plain-text trimming with specials disabled.
    return (allowedSpecial is null || allowedSpecial.Count == 0)
        ? EncodeTrimSuffix(text, maxTokenCount, false)
        : EncodeTrimSuffixInternal(text, allowedSpecial, maxTokenCount);
}
412+
/// <summary>
/// Encode a piece of text limited by max token count through trimming suffix, with or without special tokens set through constructor.
/// </summary>
/// <param name="text">String to be encoded</param>
/// <param name="maxTokenCount">The max token count</param>
/// <param name="applySpecialTokens">Whether to apply special token processing</param>
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, int maxTokenCount, bool applySpecialTokens = true)
{
    if (applySpecialTokens && SpecialTokens.Count > 0)
    {
        return EncodeTrimSuffixInternal(text, SpecialTokens, maxTokenCount);
    }

    // Plain-text path: only the consumed-character count is needed from the
    // low-level encoder to know where the suffix was cut.
    var tokenIds = new List<int>();
    var (_, encodeLength) = EncodeTrimSuffix(text, tokenIds, 0, text.Length, maxTokenCount, 0, 0);

    // Avoid allocating a substring when nothing was trimmed.
    var encodedText = encodeLength == text.Length ? text : text[..encodeLength];
    return (tokenIds, encodedText);
}
436+
437+ private ( List < int > TokenIds , string Text ) EncodeTrimPrefixInternal ( string text , IReadOnlyCollection < string > allowedSpecial , int maxTokenCount )
377438 {
378439 var tokenIds = new List < int > ( ) ;
379440
380441 int start = 0 ;
381442 int tokenCount = 0 ;
382443 var encodeLength = 0 ;
383- var tokenCountMap = new SortedDictionary < int , int > ( ) ;
384- tokenCountMap . Add ( tokenCount , encodeLength ) ;
444+ var tokenCountMap = new SortedDictionary < int , int >
445+ {
446+ { tokenCount , encodeLength }
447+ } ;
385448 while ( true )
386449 {
387450 Match nextSpecial ;
@@ -390,39 +453,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
390453
391454 if ( end > start )
392455 {
393- foreach ( Match match in Regex . Matches ( text [ start ..end ] ) )
394- {
395- var piece = match . Value ;
396-
397- if ( this . Cache . Lookup ( match . Value , out int [ ] tokens ) )
398- {
399- tokenCount += tokens . Length ;
400- encodeLength += piece . Length ;
401- tokenIds . AddRange ( tokens ) ;
402- tokenCountMap [ tokenCount ] = encodeLength ;
403- }
404- else
405- {
406- var bytes = Encoding . UTF8 . GetBytes ( piece ) ;
407- if ( Encoder . TryGetValue ( bytes , out int token ) )
408- {
409- tokenCount ++ ;
410- encodeLength += piece . Length ;
411- tokenIds . Add ( token ) ;
412- tokenCountMap [ tokenCount ] = encodeLength ;
413-
414- }
415- else
416- {
417- var encodedTokens = BytePairEncoder . BytePairEncode ( bytes , Encoder ) ;
418- this . Cache . Add ( piece , encodedTokens . ToArray ( ) ) ;
419- tokenCount += encodedTokens . Count ;
420- encodeLength += piece . Length ;
421- tokenIds . AddRange ( encodedTokens ) ;
422- tokenCountMap [ tokenCount ] = encodeLength ;
423- }
424- }
425- }
456+ Encode ( text , tokenIds , start , ref tokenCount , ref encodeLength , tokenCountMap , end ) ;
426457 }
427458
428459 if ( nextSpecial . Success )
@@ -442,6 +473,11 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
442473 }
443474 }
444475
476+ return TrimPrefix ( text , maxTokenCount , tokenIds , tokenCount , tokenCountMap ) ;
477+ }
478+
479+ private static ( List < int > TokenIds , string Text ) TrimPrefix ( string text , int maxTokenCount , List < int > tokenIds , int tokenCount , SortedDictionary < int , int > tokenCountMap )
480+ {
445481 if ( tokenCount <= maxTokenCount )
446482 {
447483 return ( tokenIds , text ) ;
@@ -463,6 +499,85 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
463499 return ( tokenIds . Skip ( actualPrefixTokenCount ) . ToList ( ) , text [ actualPrefixStrLength ..] ) ;
464500 }
465501
/// <summary>
/// BPE-encode the slice [start, end) of <paramref name="text"/>, appending token ids to
/// <paramref name="tokenIds"/> and recording, for each running token count reached, how many
/// characters of text had been consumed at that point.
/// </summary>
/// <param name="text">Full text being encoded</param>
/// <param name="tokenIds">Destination list for produced token ids</param>
/// <param name="start">Inclusive start index of the slice to encode</param>
/// <param name="tokenCount">Running token count, advanced as tokens are produced</param>
/// <param name="encodeLength">Running count of consumed characters, advanced per piece</param>
/// <param name="tokenCountMap">Map from running token count to consumed text length</param>
/// <param name="end">Exclusive end index of the slice to encode</param>
private void Encode(string text, List<int> tokenIds, int start, ref int tokenCount, ref int encodeLength, SortedDictionary<int, int> tokenCountMap, int end)
{
    foreach (Match match in Regex.Matches(text[start..end]))
    {
        var piece = match.Value;

        if (this.Cache.Lookup(piece, out int[] tokens))
        {
            // Piece seen before: reuse the cached token ids.
            tokenIds.AddRange(tokens);
            tokenCount += tokens.Length;
        }
        else
        {
            var bytes = Encoding.UTF8.GetBytes(piece);
            if (Encoder.TryGetValue(bytes, out int token))
            {
                // The whole piece maps directly to a single known token.
                tokenIds.Add(token);
                tokenCount++;
            }
            else
            {
                // Fall back to byte-pair encoding and cache the result for reuse.
                var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
                this.Cache.Add(piece, encodedTokens.ToArray());
                tokenIds.AddRange(encodedTokens);
                tokenCount += encodedTokens.Count;
            }
        }

        // Common bookkeeping for every branch: advance the consumed length and
        // checkpoint it under the token count reached so far.
        encodeLength += piece.Length;
        tokenCountMap[tokenCount] = encodeLength;
    }
}
538+
/// <summary>
/// Encode a piece of text limited by max token count through trimming prefix
/// </summary>
/// <param name="text">Text to be encoded</param>
/// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
/// <param name="maxTokenCount">The max token count</param>
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
{
    // An absent/empty allow-set means plain-text trimming with specials disabled.
    return (allowedSpecial is null || allowedSpecial.Count == 0)
        ? EncodeTrimPrefix(text, maxTokenCount, false)
        : EncodeTrimPrefixInternal(text, allowedSpecial, maxTokenCount);
}
554+
/// <summary>
/// Encode a piece of text limited by max token count through trimming prefix, with or without special tokens set through constructor.
/// </summary>
/// <param name="text">Text to be encoded</param>
/// <param name="maxTokenCount">The max token count</param>
/// <param name="applySpecialTokens">Whether to apply special token processing</param>
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, int maxTokenCount, bool applySpecialTokens = true)
{
    if (applySpecialTokens && SpecialTokens.Count > 0)
    {
        return EncodeTrimPrefixInternal(text, SpecialTokens, maxTokenCount);
    }

    // Plain-text path: encode everything, tracking at each running token count
    // how many characters were consumed, then cut the prefix to fit the budget.
    var tokenIds = new List<int>();
    int tokenCount = 0;
    int encodeLength = 0;

    // Seed with "zero tokens at zero characters" so TrimPrefix always has a
    // valid starting boundary.
    var tokenCountMap = new SortedDictionary<int, int> { { tokenCount, encodeLength } };
    Encode(text, tokenIds, 0, ref tokenCount, ref encodeLength, tokenCountMap, text.Length);
    return TrimPrefix(text, maxTokenCount, tokenIds, tokenCount, tokenCountMap);
}
580+
466581 /// <summary>
467582 /// Decode an array of integer token ids
468583 /// </summary>
0 commit comments