Skip to content

Commit 92854ab

Browse files
author
razvan
committed
refactor(llm): simplify OllamaLLMProvider to embedding-only, remove chat model support
1 parent 6e41a41 commit 92854ab

File tree

2 files changed

+21
-107
lines changed

2 files changed

+21
-107
lines changed

pkg/llm/ollama.go

Lines changed: 19 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -12,20 +12,16 @@ import (
1212
"github.com/tmc/langchaingo/llms/ollama"
1313
)
1414

15-
// OllamaLLMProvider implements Provider interface for Ollama
15+
// OllamaLLMProvider implements Provider interface for Ollama (embedding-only).
1616
type OllamaLLMProvider struct {
17-
chatModel llms.Model
1817
embedModel llms.Model
19-
chatName string
2018
embedName string
21-
config config.LLMConfig
2219
cachedDim uint64
2320
dimOnce sync.Once
2421
}
2522

26-
// NewOllamaLLMProvider creates a new Ollama provider with separate chat and embedding models
23+
// NewOllamaLLMProvider creates a new Ollama provider configured for embedding only.
2724
func NewOllamaLLMProvider(cfg config.LLMConfig) (*OllamaLLMProvider, error) {
28-
// Server URL
2925
baseURL := cfg.OllamaBaseURL
3026
if baseURL == "" {
3127
baseURL = cfg.BaseURL
@@ -34,68 +30,40 @@ func NewOllamaLLMProvider(cfg config.LLMConfig) (*OllamaLLMProvider, error) {
3430
baseURL = "http://localhost:11434"
3531
}
3632

37-
// Chat model
38-
chatModelName := cfg.OllamaModel
39-
if chatModelName == "" {
40-
chatModelName = cfg.Model
41-
}
42-
43-
// Embedding model
4433
embedModelName := cfg.OllamaEmbed
4534
if embedModelName == "" {
4635
embedModelName = cfg.EmbedModel
4736
}
48-
49-
// Fallback logic
50-
if chatModelName == "" && embedModelName != "" {
51-
chatModelName = embedModelName
37+
// Accept OllamaModel / Model as fallback for backward-compat
38+
if embedModelName == "" {
39+
embedModelName = cfg.OllamaModel
5240
}
53-
if embedModelName == "" && chatModelName != "" {
54-
embedModelName = chatModelName
41+
if embedModelName == "" {
42+
embedModelName = cfg.Model
5543
}
56-
57-
if chatModelName == "" {
58-
return nil, fmt.Errorf("ollama model is required (set ollama_model or ollama_embed)")
44+
if embedModelName == "" {
45+
return nil, fmt.Errorf("ollama model is required (set ollama_embed in config)")
5946
}
6047

61-
// Create chat client
62-
chatClient, err := ollama.New(
48+
embedClient, err := ollama.New(
6349
ollama.WithServerURL(baseURL),
64-
ollama.WithModel(chatModelName),
50+
ollama.WithModel(embedModelName),
6551
)
6652
if err != nil {
67-
return nil, fmt.Errorf("failed to create Ollama chat client: %w", err)
53+
return nil, fmt.Errorf("failed to create Ollama embedding client: %w", err)
6854
}
6955

70-
// Create embedding client (separate if different model)
71-
var embedClient llms.Model
72-
if embedModelName != chatModelName {
73-
embedClient, err = ollama.New(
74-
ollama.WithServerURL(baseURL),
75-
ollama.WithModel(embedModelName),
76-
)
77-
if err != nil {
78-
return nil, fmt.Errorf("failed to create Ollama embedding client: %w", err)
79-
}
80-
log.Printf("🎯 Ollama: chat=%s, embed=%s (dual-model)", chatModelName, embedModelName)
81-
} else {
82-
embedClient = chatClient
83-
log.Printf("🎯 Ollama: model=%s (single-model)", chatModelName)
84-
}
56+
log.Printf("🎯 Ollama: embed=%s", embedModelName)
8557

8658
return &OllamaLLMProvider{
87-
chatModel: chatClient,
8859
embedModel: embedClient,
89-
chatName: chatModelName,
9060
embedName: embedModelName,
91-
config: cfg,
9261
}, nil
9362
}
9463

95-
// Generate generates text using Ollama chat model
96-
func (p *OllamaLLMProvider) Generate(ctx context.Context, prompt string, opts ...GenerateOption) (string, error) {
97-
lcOpts := p.convertOptions(opts...)
98-
return llms.GenerateFromSinglePrompt(ctx, p.chatModel, prompt, lcOpts...)
64+
// Generate is not supported; this provider is embedding-only.
65+
func (p *OllamaLLMProvider) Generate(_ context.Context, _ string, _ ...GenerateOption) (string, error) {
66+
return "", fmt.Errorf("text generation not supported: provider is configured for embedding only")
9967
}
10068

10169
// GetEmbeddingDimension returns the dimension of the embedding model
@@ -192,33 +160,15 @@ func (p *OllamaLLMProvider) lookupHardcodedDimension() uint64 {
192160
}
193161
}
194162

195-
// GenerateStream generates streaming text using Ollama chat model
196-
func (p *OllamaLLMProvider) GenerateStream(ctx context.Context, prompt string, opts ...GenerateOption) (<-chan string, <-chan error) {
163+
// GenerateStream is not supported; this provider is embedding-only.
164+
func (p *OllamaLLMProvider) GenerateStream(_ context.Context, _ string, _ ...GenerateOption) (<-chan string, <-chan error) {
197165
textChan := make(chan string)
198166
errChan := make(chan error, 1)
199-
200167
go func() {
201168
defer close(textChan)
202169
defer close(errChan)
203-
204-
streamFunc := func(ctx context.Context, chunk []byte) error {
205-
select {
206-
case textChan <- string(chunk):
207-
return nil
208-
case <-ctx.Done():
209-
return ctx.Err()
210-
}
211-
}
212-
213-
lcOpts := p.convertOptions(opts...)
214-
lcOpts = append(lcOpts, llms.WithStreamingFunc(streamFunc))
215-
216-
_, err := llms.GenerateFromSinglePrompt(ctx, p.chatModel, prompt, lcOpts...)
217-
if err != nil {
218-
errChan <- err
219-
}
170+
errChan <- fmt.Errorf("text generation not supported: provider is configured for embedding only")
220171
}()
221-
222172
return textChan, errChan
223173
}
224174

@@ -253,39 +203,3 @@ func (p *OllamaLLMProvider) Embed(ctx context.Context, text string) ([]float64,
253203
func (p *OllamaLLMProvider) Name() string {
254204
return "ollama"
255205
}
256-
257-
// convertOptions converts GenerateOption to langchaingo CallOption
258-
func (p *OllamaLLMProvider) convertOptions(opts ...GenerateOption) []llms.CallOption {
259-
genOpts := &GenerateOptions{}
260-
for _, opt := range opts {
261-
opt(genOpts)
262-
}
263-
264-
var lcOpts []llms.CallOption
265-
266-
if genOpts.Temperature != 0 {
267-
lcOpts = append(lcOpts, llms.WithTemperature(genOpts.Temperature))
268-
}
269-
if genOpts.MaxTokens != 0 {
270-
lcOpts = append(lcOpts, llms.WithMaxTokens(genOpts.MaxTokens))
271-
}
272-
if genOpts.TopP != 0 {
273-
lcOpts = append(lcOpts, llms.WithTopP(genOpts.TopP))
274-
}
275-
if genOpts.TopK != 0 {
276-
lcOpts = append(lcOpts, llms.WithTopK(genOpts.TopK))
277-
}
278-
if len(genOpts.StopSequences) > 0 {
279-
lcOpts = append(lcOpts, llms.WithStopWords(genOpts.StopSequences))
280-
}
281-
282-
// Apply config defaults
283-
if genOpts.Temperature == 0 && p.config.Temperature != 0 {
284-
lcOpts = append(lcOpts, llms.WithTemperature(p.config.Temperature))
285-
}
286-
if genOpts.MaxTokens == 0 && p.config.MaxTokens != 0 {
287-
lcOpts = append(lcOpts, llms.WithMaxTokens(p.config.MaxTokens))
288-
}
289-
290-
return lcOpts
291-
}

pkg/llm/provider_test.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ func TestNewProvider_OllamaMissingModel(t *testing.T) {
5151
if err == nil {
5252
t.Fatalf("expected error when ollama model is missing, got nil")
5353
}
54-
if !strings.Contains(err.Error(), "ollama chat model is required") {
54+
if !strings.Contains(err.Error(), "ollama model is required") { //nolint: keep generic check
5555
t.Errorf("unexpected error: %v", err)
5656
}
5757
if p != nil {
@@ -62,7 +62,7 @@ func TestNewProvider_OllamaMissingModel(t *testing.T) {
6262
func TestNewProvider_DefaultOllama(t *testing.T) {
6363
cfg := &config.LLMConfig{
6464
Provider: "", // implicit ollama
65-
OllamaModel: "dummy-model",
65+
OllamaEmbed: "dummy-model",
6666
OllamaBaseURL: "http://localhost:11434",
6767
}
6868

0 commit comments

Comments
 (0)