@@ -1,4 +1,5 @@
 from llama_index.core.llms.function_calling import FunctionCallingLLM
+from llama_index.core.multi_modal_llms.base import MultiModalLLM


 def get_client(model_str: str) -> FunctionCallingLLM:
@@ -45,4 +46,36 @@ def get_client(model_str: str) -> FunctionCallingLLM:

     return Cerebras(model=model_name)

-    raise ValueError(f"Provider {provider} not found")
+    raise ValueError(f"Provider {provider} not found in models")
+
+
+def get_client_multimodal(model_str: str) -> MultiModalLLM:
+    split_result = model_str.split(":")
+    if len(split_result) == 1:
+        # Assume default provider to be ollama
+        provider = "ollama"
+        model_name = split_result[0]
+    elif len(split_result) > 2:
+        # Some model names contain ":", so rejoin the rest of the string
+        provider = split_result[0]
+        model_name = ":".join(split_result[1:])
+    else:
+        provider = split_result[0]
+        model_name = split_result[1]
+
+    if provider == "openai":
+        from llama_index.multi_modal_llms.openai import OpenAIMultiModal
+
+        return OpenAIMultiModal(model=model_name)
+
+    if provider == "ollama":
+        from llama_index.multi_modal_llms.ollama import OllamaMultiModal
+
+        return OllamaMultiModal(model=model_name)
+
+    if provider == "mistral":
+        from llama_index.multi_modal_llms.mistralai import MistralAIMultiModal
+
+        return MistralAIMultiModal(model=model_name)
+
+    raise ValueError(f"Provider {provider} not found in multimodal models")
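
For reference, a minimal sketch of the "provider:model" convention the new function expects; the call sites and model names below are illustrative assumptions, not part of the change:

# More than one ":" keeps the remainder in the model name, so
# "ollama:llava:13b" resolves to provider "ollama", model "llava:13b".
llm = get_client_multimodal("ollama:llava:13b")

# Two segments split into provider and model.
llm = get_client_multimodal("openai:gpt-4o")

# A bare name falls back to the default provider, ollama.
llm = get_client_multimodal("llava")

# Any unrecognized prefix raises:
# ValueError: Provider foo not found in multimodal models
get_client_multimodal("foo:bar")

Keeping the provider-specific imports inside each branch means only the integration package for the selected provider needs to be installed.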