Skip to content

Commit 0778134

Browse files
authored
Merge pull request #69 from Jeerhz/adle/debut
Feat: Add Vision Model to the Game
2 parents 892729a + 8c35394 commit 0778134

File tree

8 files changed

+295
-97
lines changed

8 files changed

+295
-97
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@ results.csv
99
results/
1010

1111
.DS_Store
12-
env
12+
env
13+
14+
adle_notebook.ipynb

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ We send to the LLM a text description of the screen. The LLM decides on the next
7171
# Installation
7272

7373
- Follow instructions in https://docs.diambra.ai/#installation
74-
- Download the ROM and put it in `~/.diambra/roms`
74+
- Download the ROM and put it in `~/.diambra/roms` (no need to unzip the content)
7575
- (Optional) Create and activate a [new python venv](https://docs.python.org/3/library/venv.html)
7676
- Install dependencies with `make install` or `pip install -r requirements.txt`
7777
- Create a `.env` file and fill it with the content like in the `.env.example` file

agent/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,2 @@
1-
# load env variables before importing any other module
2-
from dotenv import load_dotenv
3-
4-
load_dotenv()
5-
6-
from .robot import Robot
1+
from .robot import TextRobot, VisionRobot
72
from .observer import KEN_GREEN, KEN_RED
8-
from .llm import get_client

agent/llm.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from llama_index.core.llms.function_calling import FunctionCallingLLM
2+
from llama_index.core.multi_modal_llms.base import MultiModalLLM
23

34

45
def get_client(model_str: str) -> FunctionCallingLLM:
@@ -45,4 +46,36 @@ def get_client(model_str: str) -> FunctionCallingLLM:
4546

4647
return Cerebras(model=model_name)
4748

48-
raise ValueError(f"Provider {provider} not found")
49+
raise ValueError(f"Provider {provider} not found in models")
50+
51+
52+
def get_client_multimodal(model_str: str) -> MultiModalLLM:
    """Return a multimodal LLM client for a ``provider:model`` string.

    Args:
        model_str: Either ``"provider:model_name"`` or a bare model name.
            A bare name defaults to the ``ollama`` provider. Model names
            may themselves contain ``:`` (e.g. ``ollama:llava:13b``), so
            only the first ``:`` separates provider from model.

    Returns:
        A ``MultiModalLLM`` instance for the requested provider.

    Raises:
        ValueError: If the provider is not ``openai``, ``ollama`` or
            ``mistral``.
    """
    # Everything after the FIRST ":" is the model name, since some model
    # names (notably ollama tags) contain ":" themselves.
    provider, sep, model_name = model_str.partition(":")
    if not sep:
        # No explicit provider given; default to ollama.
        provider, model_name = "ollama", model_str

    # Provider packages are imported lazily so only the selected
    # provider's dependency needs to be installed.
    if provider == "openai":
        from llama_index.multi_modal_llms.openai import OpenAIMultiModal

        return OpenAIMultiModal(model=model_name)

    if provider == "ollama":
        from llama_index.multi_modal_llms.ollama import OllamaMultiModal

        return OllamaMultiModal(model=model_name)

    if provider == "mistral":
        from llama_index.multi_modal_llms.mistralai import MistralAIMultiModal

        return MistralAIMultiModal(model=model_name)

    raise ValueError(f"Provider {provider} not found in multimodal models")

0 commit comments

Comments
 (0)