@@ -2,6 +2,7 @@ import Foundation
22#if Llama
33 import JSONSchema
44 import LlamaSwift
5+ import XGrammar
56
67 /// Global storage for the current log level threshold.
78 /// This is needed because the C callback can't capture Swift context.
@@ -532,7 +533,7 @@ import Foundation
532533 )
533534 } else {
534535 let maxTokens = structuredOptions. maximumResponseTokens ?? 512
535- let jsonString = try generateStructuredJSON (
536+ let jsonString = try await generateStructuredJSON (
536537 context: context,
537538 prompt: fullPrompt,
538539 schema: type. generationSchema,
@@ -913,6 +914,41 @@ import Foundation
913914 return " \( header) : \n \( schemaJSON) "
914915 }
915916
/// Serializes a `GenerationSchema` into a deterministic JSON string.
///
/// Keys are sorted so the emitted schema text is stable across runs, which
/// keeps downstream grammar compilation reproducible.
/// - Parameter schema: The generation schema to serialize.
/// - Returns: A UTF-8 JSON representation of the schema.
/// - Throws: `LlamaLanguageModelError.schemaEncodingFailed` when the encoded
///   bytes cannot be represented as a `String`, or any error thrown by
///   `JSONEncoder.encode(_:)`.
private func jsonSchemaString(for schema: GenerationSchema) throws -> String {
    let encoder = JSONEncoder()
    encoder.outputFormatting = [.sortedKeys]
    let encoded = try encoder.encode(schema)
    if let text = String(data: encoded, encoding: .utf8) {
        return text
    }
    throw LlamaLanguageModelError.schemaEncodingFailed
}
926+
/// Builds the `TokenizerInfo` the grammar engine needs to compile a
/// token-level matcher for this model's vocabulary.
///
/// Decodes every vocabulary entry up front so token IDs can be mapped back
/// to their text pieces; tokens with no text form fall back to `""`.
/// - Parameters:
///   - vocab: Opaque llama.cpp vocabulary handle.
///   - vocabSize: Number of entries in the vocabulary; must be positive.
///   - stopTokens: Token IDs that terminate generation.
/// - Returns: A `TokenizerInfo` describing the vocabulary.
/// - Throws: `LlamaLanguageModelError.contextInitializationFailed` when the
///   vocabulary is empty, or any error from the `TokenizerInfo` initializer.
private func tokenizerInfo(
    for vocab: OpaquePointer,
    vocabSize: Int,
    stopTokens: Set<Int>
) throws -> TokenizerInfo {
    guard vocabSize > 0 else {
        throw LlamaLanguageModelError.contextInitializationFailed
    }

    // Materialize the full vocabulary in token-ID order.
    let pieces: [String] = (0..<vocabSize).map { tokenId in
        tokenToText(vocab: vocab, token: llama_token(tokenId)) ?? ""
    }

    return try TokenizerInfo(
        encodedVocab: pieces,
        encoding: .byteFallback,
        stopTokenIDs: stopTokens.map { Int32($0) },
        addPrefixSpace: false
    )
}
951+
916952 // MARK: - Structured JSON Generation
917953
918954 private func generateStructuredJSON(
@@ -921,7 +957,7 @@ import Foundation
921957 schema: GenerationSchema ,
922958 maxTokens: Int ,
923959 options: ResolvedGenerationOptions
924- ) throws -> String {
960+ ) async throws -> String {
925961 guard let vocab = llama_model_get_vocab ( model!) else {
926962 throw LlamaLanguageModelError . contextInitializationFailed
927963 }
@@ -964,21 +1000,56 @@ import Foundation
9641000
9651001 let vocabSize = Int ( llama_vocab_n_tokens ( vocab) )
9661002 let initialPosition : Int32 = hasEncoder ? 1 : batch. n_tokens
1003+ let jsonSchema = try jsonSchemaString ( for: schema)
1004+ let grammar = Grammar ( jsonSchema: jsonSchema, formatting: . compact, strictMode: true )
1005+ let eosToken = Int ( llama_vocab_eos ( vocab) )
1006+ let eotTokenValue = llama_vocab_eot ( vocab)
1007+ let endOfTurnToken = eotTokenValue != LLAMA_TOKEN_NULL ? Int ( eotTokenValue) : eosToken
1008+ let endTokens : Set < Int > = [ eosToken, endOfTurnToken]
1009+
1010+ let tokenizerInfo = try tokenizerInfo (
1011+ for: vocab,
1012+ vocabSize: vocabSize,
1013+ stopTokens: endTokens
1014+ )
1015+ let matcher = try await grammar. matcher (
1016+ for: tokenizerInfo,
1017+ stopTokens: endTokens. map { Int32 ( $0) } ,
1018+ terminatesWithoutStopToken: true
1019+ )
1020+ var bitmask = Grammar . Matcher. TokenBitmask ( vocabSize: vocabSize)
9671021
9681022 return try withUnsafeMutablePointer ( to: & batch) { batchPointer in
969- let backend = LlamaTokenBackend (
1023+ var backend = LlamaTokenBackend (
9701024 context: context,
9711025 vocab: vocab,
9721026 vocabSize: vocabSize,
9731027 sampler: samplerPointer,
9741028 batch: batchPointer,
9751029 position: initialPosition,
9761030 maximumTokens: maxTokens,
977- endTokens: [ ] ,
1031+ endTokens: endTokens ,
9781032 tokenToTextFn: { [ self ] token in self . tokenToText ( vocab: vocab, token: llama_token ( token) ) }
9791033 )
980- var generator = try ConstrainedJSONGenerator ( backend: backend, schema: schema)
981- return try generator. generate ( )
1034+
1035+ var output = " "
1036+ while backend. remainingTokens > 0 {
1037+ bitmask. reset ( )
1038+ let needsMask = matcher. fillNextTokenBitmask ( & bitmask)
1039+ let token = try backend. sample ( using: bitmask, applyMask: needsMask)
1040+ if backend. endTokens. contains ( token) {
1041+ break
1042+ }
1043+ guard matcher. accept ( Int32 ( token) ) else {
1044+ throw LlamaLanguageModelError . grammarMismatch
1045+ }
1046+ if let tokenText = backend. tokenText ( token) {
1047+ output += tokenText
1048+ }
1049+ try backend. decode ( token)
1050+ if matcher. isTerminated { break }
1051+ }
1052+ return output
9821053 }
9831054 }
9841055
@@ -1105,6 +1176,21 @@ import Foundation
11051176 }
11061177 }
11071178
/// Samples the next token, optionally constraining the distribution with a
/// grammar bitmask.
///
/// Disallowed vocabulary entries have their raw logits forced to `-inf`
/// before the llama.cpp sampler runs, so only grammar-legal tokens can be
/// selected.
/// - Parameters:
///   - bitmask: Per-token allow/deny mask produced by the grammar matcher.
///   - applyMask: When `false`, the mask is skipped entirely (the grammar
///     permits every token at this step), avoiding a full vocab scan.
/// - Returns: The sampled token ID, or `eosToken` when no logits are
///   available from the context.
mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int {
    // NOTE(review): assumes the base pointer from llama_get_logits(context)
    // is the same logits row that llama_sampler_sample reads for index
    // n_tokens - 1 — confirm against how the batch requests logits.
    guard let logits = llama_get_logits(context) else {
        return eosToken
    }

    if applyMask {
        var tokenIndex = 0
        while tokenIndex < vocabSize {
            if !bitmask.isTokenAllowed(tokenIndex) {
                logits[tokenIndex] = -Float.infinity
            }
            tokenIndex += 1
        }
    }

    let lastIndex = batch.pointee.n_tokens - 1
    return Int(llama_sampler_sample(sampler, context, lastIndex))
}
1193+
11081194 mutating func sample( from allowedTokens: Set < Int > ) throws -> Int {
11091195 guard let logits = llama_get_logits ( context) else {
11101196 return eosToken
@@ -1536,6 +1622,8 @@ import Foundation
15361622 case insufficientMemory
15371623 case unsupportedFeature
15381624 case encoderOnlyModel
1625+ case schemaEncodingFailed
1626+ case grammarMismatch
15391627
15401628 public var errorDescription : String ? {
15411629 switch self {
@@ -1557,6 +1645,10 @@ import Foundation
15571645 return " This LlamaLanguageModel does not support image segments "
15581646 case . encoderOnlyModel:
15591647 return " This model is encoder-only (e.g., BERT) and cannot generate text "
1648+ case . schemaEncodingFailed:
1649+ return " Failed to encode the JSON schema for structured generation "
1650+ case . grammarMismatch:
1651+ return " Grammar constraints could not be satisfied during generation "
15601652 }
15611653 }
15621654 }
0 commit comments