Skip to content

Commit ec55a61

Browse files
committed
Use XGrammar for structured output generation
1 parent e2a7ed0 commit ec55a61

File tree

4 files changed

+220
-11
lines changed

4 files changed

+220
-11
lines changed

Package.resolved

Lines changed: 19 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ let package = Package(
3131
.package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"),
3232
.package(url: "https://github.com/mattt/EventSource", from: "1.3.0"),
3333
.package(url: "https://github.com/mattt/JSONSchema", from: "1.3.0"),
34+
.package(url: "https://github.com/mattt/swift-xgrammar", from: "0.1.0"),
3435
.package(url: "https://github.com/mattt/llama.swift", .upToNextMajor(from: "2.7484.0")),
3536
.package(url: "https://github.com/mattt/PartialJSONDecoder", from: "1.0.0"),
3637
// mlx-swift-lm must be >= 2.25.5 for ToolSpec/tool calls and UserInput(chat:processing:tools:).
@@ -45,6 +46,7 @@ let package = Package(
4546
.product(name: "EventSource", package: "EventSource"),
4647
.product(name: "JSONSchema", package: "JSONSchema"),
4748
.product(name: "PartialJSONDecoder", package: "PartialJSONDecoder"),
49+
.product(name: "XGrammar", package: "swift-xgrammar"),
4850
.product(
4951
name: "MLXLLM",
5052
package: "mlx-swift-lm",

Sources/AnyLanguageModel/Models/LlamaLanguageModel.swift

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import Foundation
22
#if Llama
33
import JSONSchema
44
import LlamaSwift
5+
import XGrammar
56

67
/// Global storage for the current log level threshold.
78
/// This is needed because the C callback can't capture Swift context.
@@ -532,7 +533,7 @@ import Foundation
532533
)
533534
} else {
534535
let maxTokens = structuredOptions.maximumResponseTokens ?? 512
535-
let jsonString = try generateStructuredJSON(
536+
let jsonString = try await generateStructuredJSON(
536537
context: context,
537538
prompt: fullPrompt,
538539
schema: type.generationSchema,
@@ -913,6 +914,41 @@ import Foundation
913914
return "\(header):\n\(schemaJSON)"
914915
}
915916

917+
/// Serializes a generation schema to a JSON string for grammar construction.
///
/// The schema is encoded with `JSONEncoder` using sorted keys.
///
/// - Parameter schema: The schema to encode.
/// - Returns: The encoded schema as UTF-8 text.
/// - Throws: Any error raised by the encoder, or
///   `LlamaLanguageModelError.schemaEncodingFailed` when the encoded bytes
///   cannot be read back as UTF-8.
private func jsonSchemaString(for schema: GenerationSchema) throws -> String {
    let schemaEncoder = JSONEncoder()
    schemaEncoder.outputFormatting = [.sortedKeys]

    let encodedSchema = try schemaEncoder.encode(schema)
    if let schemaText = String(data: encodedSchema, encoding: .utf8) {
        return schemaText
    }
    throw LlamaLanguageModelError.schemaEncodingFailed
}
926+
927+
/// Builds the `TokenizerInfo` handed to the grammar matcher for a llama vocabulary.
///
/// Every token id in `0 ..< vocabSize` is rendered through `tokenToText`;
/// ids for which it returns `nil` are recorded as the empty string.
///
/// - Parameters:
///   - vocab: The llama vocabulary handle passed through to `tokenToText`.
///   - vocabSize: Number of token ids to enumerate. Must be positive.
///   - stopTokens: Token ids forwarded as the tokenizer's stop tokens.
/// - Returns: A `TokenizerInfo` using `.byteFallback` encoding and no prefix space.
/// - Throws: `LlamaLanguageModelError.contextInitializationFailed` when
///   `vocabSize` is not positive, or any error from the `TokenizerInfo` initializer.
private func tokenizerInfo(
    for vocab: OpaquePointer,
    vocabSize: Int,
    stopTokens: Set<Int>
) throws -> TokenizerInfo {
    if vocabSize <= 0 {
        throw LlamaLanguageModelError.contextInitializationFailed
    }

    let vocabulary: [String] = (0 ..< vocabSize).map { tokenId in
        tokenToText(vocab: vocab, token: llama_token(tokenId)) ?? ""
    }

    return try TokenizerInfo(
        encodedVocab: vocabulary,
        encoding: .byteFallback,
        stopTokenIDs: stopTokens.map { Int32($0) },
        addPrefixSpace: false
    )
}
951+
916952
// MARK: - Structured JSON Generation
917953

918954
private func generateStructuredJSON(
@@ -921,7 +957,7 @@ import Foundation
921957
schema: GenerationSchema,
922958
maxTokens: Int,
923959
options: ResolvedGenerationOptions
924-
) throws -> String {
960+
) async throws -> String {
925961
guard let vocab = llama_model_get_vocab(model!) else {
926962
throw LlamaLanguageModelError.contextInitializationFailed
927963
}
@@ -964,21 +1000,56 @@ import Foundation
9641000

9651001
let vocabSize = Int(llama_vocab_n_tokens(vocab))
9661002
let initialPosition: Int32 = hasEncoder ? 1 : batch.n_tokens
1003+
let jsonSchema = try jsonSchemaString(for: schema)
1004+
let grammar = Grammar(jsonSchema: jsonSchema, formatting: .compact, strictMode: true)
1005+
let eosToken = Int(llama_vocab_eos(vocab))
1006+
let eotTokenValue = llama_vocab_eot(vocab)
1007+
let endOfTurnToken = eotTokenValue != LLAMA_TOKEN_NULL ? Int(eotTokenValue) : eosToken
1008+
let endTokens: Set<Int> = [eosToken, endOfTurnToken]
1009+
1010+
let tokenizerInfo = try tokenizerInfo(
1011+
for: vocab,
1012+
vocabSize: vocabSize,
1013+
stopTokens: endTokens
1014+
)
1015+
let matcher = try await grammar.matcher(
1016+
for: tokenizerInfo,
1017+
stopTokens: endTokens.map { Int32($0) },
1018+
terminatesWithoutStopToken: true
1019+
)
1020+
var bitmask = Grammar.Matcher.TokenBitmask(vocabSize: vocabSize)
9671021

9681022
return try withUnsafeMutablePointer(to: &batch) { batchPointer in
969-
let backend = LlamaTokenBackend(
1023+
var backend = LlamaTokenBackend(
9701024
context: context,
9711025
vocab: vocab,
9721026
vocabSize: vocabSize,
9731027
sampler: samplerPointer,
9741028
batch: batchPointer,
9751029
position: initialPosition,
9761030
maximumTokens: maxTokens,
977-
endTokens: [],
1031+
endTokens: endTokens,
9781032
tokenToTextFn: { [self] token in self.tokenToText(vocab: vocab, token: llama_token(token)) }
9791033
)
980-
var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema)
981-
return try generator.generate()
1034+
1035+
var output = ""
1036+
while backend.remainingTokens > 0 {
1037+
bitmask.reset()
1038+
let needsMask = matcher.fillNextTokenBitmask(&bitmask)
1039+
let token = try backend.sample(using: bitmask, applyMask: needsMask)
1040+
if backend.endTokens.contains(token) {
1041+
break
1042+
}
1043+
guard matcher.accept(Int32(token)) else {
1044+
throw LlamaLanguageModelError.grammarMismatch
1045+
}
1046+
if let tokenText = backend.tokenText(token) {
1047+
output += tokenText
1048+
}
1049+
try backend.decode(token)
1050+
if matcher.isTerminated { break }
1051+
}
1052+
return output
9821053
}
9831054
}
9841055

@@ -1105,6 +1176,21 @@ import Foundation
11051176
}
11061177
}
11071178

1179+
/// Samples the next token, optionally restricting candidates to those the grammar permits.
///
/// When `applyMask` is `true`, the logit of every token index the bitmask rejects is
/// overwritten with `-Float.infinity` in the buffer returned by `llama_get_logits`
/// before the sampler runs.
///
/// - Parameters:
///   - bitmask: Grammar-derived mask, queried per token index via `isTokenAllowed`.
///   - applyMask: When `false`, the logits are left untouched and `bitmask` is ignored.
/// - Returns: The sampled token id, or `eosToken` when the context exposes no logits.
/// - Throws: `LlamaLanguageModelError.grammarMismatch` when masking is requested but
///   the bitmask permits no token, since sampling from all-`-inf` logits would be
///   meaningless.
mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int {
    guard let logits = llama_get_logits(context) else {
        return eosToken
    }

    if applyMask {
        var hasAllowedToken = false
        for tokenIndex in 0 ..< vocabSize {
            if bitmask.isTokenAllowed(tokenIndex) {
                hasAllowedToken = true
            } else {
                logits[tokenIndex] = -Float.infinity
            }
        }
        // Fail fast instead of handing the sampler a distribution with no finite entry.
        guard hasAllowedToken else {
            throw LlamaLanguageModelError.grammarMismatch
        }
    }

    // The sampler reads the logits of the last token in the current batch.
    let tokenIndex = batch.pointee.n_tokens - 1
    return Int(llama_sampler_sample(sampler, context, tokenIndex))
}
1193+
11081194
mutating func sample(from allowedTokens: Set<Int>) throws -> Int {
11091195
guard let logits = llama_get_logits(context) else {
11101196
return eosToken
@@ -1536,6 +1622,8 @@ import Foundation
15361622
case insufficientMemory
15371623
case unsupportedFeature
15381624
case encoderOnlyModel
1625+
case schemaEncodingFailed
1626+
case grammarMismatch
15391627

15401628
public var errorDescription: String? {
15411629
switch self {
@@ -1557,6 +1645,10 @@ import Foundation
15571645
return "This LlamaLanguageModel does not support image segments"
15581646
case .encoderOnlyModel:
15591647
return "This model is encoder-only (e.g., BERT) and cannot generate text"
1648+
case .schemaEncodingFailed:
1649+
return "Failed to encode the JSON schema for structured generation"
1650+
case .grammarMismatch:
1651+
return "Grammar constraints could not be satisfied during generation"
15601652
}
15611653
}
15621654
}

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import Foundation
1717
import MLXVLM
1818
import Tokenizers
1919
import Hub
20+
import XGrammar
2021

2122
/// Wrapper to store ModelContext in NSCache (requires NSObject subclass).
2223
private final class CachedContext: NSObject, @unchecked Sendable {
@@ -782,19 +783,59 @@ import Foundation
782783
return "\(header):\n\(schemaJSON)"
783784
}
784785

786+
/// Serializes a generation schema to a JSON string for grammar construction.
///
/// The schema is encoded with `JSONEncoder` using sorted keys.
///
/// - Parameter schema: The schema to encode.
/// - Returns: The encoded schema as UTF-8 text.
/// - Throws: Any error raised by the encoder, or
///   `MLXLanguageModelError.schemaEncodingFailed` when the encoded bytes
///   cannot be read back as UTF-8.
private func jsonSchemaString(for schema: GenerationSchema) throws -> String {
    let schemaEncoder = JSONEncoder()
    schemaEncoder.outputFormatting = [.sortedKeys]

    let encodedSchema = try schemaEncoder.encode(schema)
    if let schemaText = String(data: encodedSchema, encoding: .utf8) {
        return schemaText
    }
    throw MLXLanguageModelError.schemaEncodingFailed
}
795+
796+
/// Builds the `TokenizerInfo` handed to the grammar matcher for an MLX tokenizer.
///
/// Every token id in `0 ..< vocabSize` is looked up with `convertIdToToken`;
/// ids for which it returns `nil` are recorded as the empty string.
///
/// - Parameters:
///   - tokenizer: The tokenizer used to resolve token ids to their token strings.
///   - vocabSize: Number of token ids to enumerate. Must be positive.
///   - stopTokens: Token ids forwarded as the tokenizer's stop tokens.
/// - Returns: A `TokenizerInfo` using `.byteLevel` encoding and no prefix space.
/// - Throws: `MLXLanguageModelError.invalidVocabSize` when `vocabSize` is not
///   positive, or any error from the `TokenizerInfo` initializer.
private func tokenizerInfo(
    for tokenizer: any Tokenizer,
    vocabSize: Int,
    stopTokens: Set<Int>
) throws -> TokenizerInfo {
    if vocabSize <= 0 {
        throw MLXLanguageModelError.invalidVocabSize
    }

    let vocabulary: [String] = (0 ..< vocabSize).map { tokenId in
        tokenizer.convertIdToToken(tokenId) ?? ""
    }

    return try TokenizerInfo(
        encodedVocab: vocabulary,
        encoding: .byteLevel,
        stopTokenIDs: stopTokens.map { Int32($0) },
        addPrefixSpace: false
    )
}
819+
785820
// MARK: - Structured JSON Generation
786821

787822
/// Errors that can occur when using MLXLanguageModel.
788823
public enum MLXLanguageModelError: Error, LocalizedError {
    /// The vocabulary size for the model output is not usable.
    case invalidVocabSize
    /// A JSON value type could not be converted for the schema.
    case unsupportedJSONValueType
    /// The JSON schema could not be encoded for structured generation.
    case schemaEncodingFailed
    /// Generation could not proceed within the grammar constraints.
    case grammarMismatch

    /// A human-readable message describing the error.
    public var errorDescription: String? {
        let message: String
        switch self {
        case .invalidVocabSize:
            message = "Invalid vocabulary size for model output"
        case .unsupportedJSONValueType:
            message = "Unsupported JSON value type for schema conversion"
        case .schemaEncodingFailed:
            message = "Failed to encode the JSON schema for structured generation"
        case .grammarMismatch:
            message = "Grammar constraints could not be satisfied during generation"
        }
        return message
    }
}
@@ -827,13 +868,42 @@ import Foundation
827868
maximumTokens: maxTokens,
828869
endTokens: []
829870
)
830-
831-
var generator = try ConstrainedJSONGenerator(backend: backend, schema: schema)
832-
let json = try generator.generate()
871+
let jsonSchema = try jsonSchemaString(for: schema)
872+
let grammar = Grammar(jsonSchema: jsonSchema, formatting: .compact, strictMode: true)
873+
let tokenizerInfo = try tokenizerInfo(
874+
for: context.tokenizer,
875+
vocabSize: backend.vocabSize,
876+
stopTokens: backend.endTokens
877+
)
878+
let matcher = try await grammar.matcher(
879+
for: tokenizerInfo,
880+
stopTokens: backend.endTokens.map { Int32($0) },
881+
terminatesWithoutStopToken: true
882+
)
883+
var bitmask = Grammar.Matcher.TokenBitmask(vocabSize: tokenizerInfo.vocabulary.size)
884+
885+
var backendState = backend
886+
var output = ""
887+
while backendState.remainingTokens > 0 {
888+
bitmask.reset()
889+
let needsMask = matcher.fillNextTokenBitmask(&bitmask)
890+
let token = try backendState.sample(using: bitmask, applyMask: needsMask)
891+
if backendState.endTokens.contains(token) {
892+
break
893+
}
894+
guard matcher.accept(Int32(token)) else {
895+
throw MLXLanguageModelError.grammarMismatch
896+
}
897+
if let tokenText = backendState.tokenText(token) {
898+
output += tokenText
899+
}
900+
try backendState.decode(token)
901+
if matcher.isTerminated { break }
902+
}
833903
// Ensure pending MLX operations complete before returning JSON.
834904
// This synchronization can be a performance cost if called frequently.
835905
Stream().synchronize()
836-
return json
906+
return output
837907
}
838908

839909
/// Merges system prompts and schema instructions into a user message.
@@ -1038,6 +1108,33 @@ import Foundation
10381108
}
10391109
}
10401110

1111+
/// Samples the next token, optionally restricting candidates to those the grammar permits.
///
/// The logits slice is run through `processor` when one is set and promoted from
/// `bfloat16` to `float32` before sampling. With `applyMask` set, sampling happens
/// over a copy in which every position the bitmask rejects holds `-Float.infinity`.
///
/// - Parameters:
///   - bitmask: Grammar-derived mask, queried per token index via `isTokenAllowed`.
///   - applyMask: When `false`, the processed logits are sampled without masking.
/// - Returns: The sampled token id.
/// - Throws: `MLXLanguageModelError.grammarMismatch` when masking is requested but
///   the bitmask permits no token.
mutating func sample(using bitmask: Grammar.Matcher.TokenBitmask, applyMask: Bool) throws -> Int {
    // Index -1 on the second axis — presumably the last sequence position; confirm against caller.
    let latestLogits = currentLogits[0..., -1, 0...]
    var scores = processor?.process(logits: latestLogits) ?? latestLogits
    if scores.dtype == .bfloat16 {
        scores = scores.asType(.float32)
    }

    // Unconstrained step: sample straight from the processed logits.
    guard applyMask else {
        return sampler.sample(logits: scores).item(Int.self)
    }

    var permitted = [UInt32]()
    permitted.reserveCapacity(vocabSize)
    for candidate in 0 ..< vocabSize {
        if bitmask.isTokenAllowed(candidate) {
            permitted.append(UInt32(candidate))
        }
    }
    if permitted.isEmpty {
        throw MLXLanguageModelError.grammarMismatch
    }

    // Start from all -inf and copy over only the permitted positions.
    let permittedIndices = MLXArray(permitted)
    let constrained = full(scores.shape, values: -Float.infinity)
    constrained[0..., permittedIndices] = scores[0..., permittedIndices]
    return sampler.sample(logits: constrained).item(Int.self)
}
1137+
10411138
mutating func sample(from allowedTokens: Set<Int>) throws -> Int {
10421139
guard !allowedTokens.isEmpty else {
10431140
throw ConstrainedGenerationError.tokenizationFailed

0 commit comments

Comments
 (0)