Skip to content

Commit ebca985

Browse files
LiteLLMChatModel for GenericAgent (#325)
* add LiteLLM chat model for GenericAgent
* remove added models to llm config
* format
1 parent 519abed commit ebca985

File tree

2 files changed

+124
-5
lines changed

2 files changed

+124
-5
lines changed

src/agentlab/llm/chat_api.py

Lines changed: 122 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
from functools import partial
77
from typing import Optional
88

9+
import agentlab.llm.tracking as tracking
910
import anthropic
1011
import openai
11-
from openai import NOT_GIVEN, OpenAI
12-
13-
import agentlab.llm.tracking as tracking
1412
from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
1513
from agentlab.llm.llm_utils import AIMessage, Discussion
14+
from openai import NOT_GIVEN, OpenAI
1615

1716

1817
def make_system_message(content: str) -> dict:
@@ -90,6 +89,18 @@ def make_model(self):
9089
)
9190

9291

92+
@dataclass
class LiteLLMModelArgs(BaseModelArgs):
    """Serializable arguments for instantiating a LiteLLM-backed chat model.

    Mirrors the sibling ``*ModelArgs`` dataclasses: holds only plain fields so
    it can be pickled/logged, and defers client construction to `make_model`.
    """

    def make_model(self):
        """Build the `LiteLLMChatModel` described by these arguments."""
        return LiteLLMChatModel(
            model_name=self.model_name,
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            log_probs=self.log_probs,
        )
93104
@dataclass
94105
class OpenAIModelArgs(BaseModelArgs):
95106
"""Serializable object for instantiating a generic chat model with an OpenAI
@@ -627,3 +638,111 @@ def make_model(self):
627638
temperature=self.temperature,
628639
max_tokens=self.max_new_tokens,
629640
)
641+
642+
643+
class LiteLLMChatModel(AbstractChatModel):
    """Chat model that routes completion requests through LiteLLM.

    LiteLLM dispatches to the provider implied by `model_name` and reads
    provider API keys from the environment, so several constructor arguments
    (`api_key`, `api_key_env_var`, `client_class`, `client_args`) are accepted
    only for signature compatibility with the sibling chat-model classes and
    are currently unused.
    """

    def __init__(
        self,
        model_name,
        api_key=None,
        temperature=0.5,
        max_tokens=100,
        max_retry=4,
        min_retry_wait_time=60,
        api_key_env_var=None,
        client_class=OpenAI,
        client_args=None,
        pricing_func=None,
        log_probs=False,
    ):
        """Configure the model.

        Args:
            model_name: LiteLLM model identifier (e.g. "openai/gpt-4o").
            api_key: Unused; LiteLLM reads provider keys from the environment.
            temperature: Default sampling temperature for calls.
            max_tokens: Maximum completion tokens per request.
            max_retry: Number of attempts before giving up; must be > 0.
            min_retry_wait_time: Minimum back-off (seconds) passed to
                `handle_error` between retries.
            api_key_env_var: Unused; kept for signature compatibility.
            client_class: Unused; kept for signature compatibility.
            client_args: Unused; kept for signature compatibility.
            pricing_func: Optional callable returning a mapping
                ``{model_name: {"prompt": cost, "completion": cost}}``.
            log_probs: If True, request and attach log-probabilities.
        """
        assert max_retry > 0, "max_retry should be greater than 0"

        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retry = max_retry
        self.min_retry_wait_time = min_retry_wait_time
        self.log_probs = log_probs

        # Resolve per-token pricing; unknown models fall back to zero cost so
        # tracking still works rather than crashing.
        if pricing_func:
            pricings = pricing_func()
            try:
                self.input_cost = float(pricings[model_name]["prompt"])
                self.output_cost = float(pricings[model_name]["completion"])
            except KeyError:
                logging.warning(
                    f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
                )
                self.input_cost = 0.0
                self.output_cost = 0.0
        else:
            self.input_cost = 0.0
            self.output_cost = 0.0

    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
        """Send `messages` to the model, retrying on OpenAI-style errors.

        Args:
            messages: Chat messages in OpenAI dict format.
            n_samples: Number of completions to request. Returns a single
                AIMessage when 1, else a list of AIMessages.
            temperature: Per-call override of the default temperature.

        Raises:
            RetryError: If no completion is obtained after `max_retry` tries.
        """
        # Imported lazily so litellm is only required when this model is used.
        from litellm import completion as litellm_completion

        # Retry-tracking attributes, reported by get_stats().
        self.retries = 0
        self.success = False
        self.error_types = []

        # FIX: the resolved temperature (and max_tokens / n_samples) were
        # previously computed but never forwarded to the completion call,
        # so every request used provider defaults and n_samples > 1 could
        # never return more than one choice. Hoisted out of the loop since
        # it is loop-invariant.
        if temperature is None:
            temperature = self.temperature

        completion = None
        error_type = None  # ensure it's bound for the RetryError message below
        for itr in range(self.max_retry):
            self.retries += 1
            try:
                completion = litellm_completion(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=self.max_tokens,
                    n=n_samples,
                    # Only request logprobs when asked; some providers reject
                    # the parameter otherwise.
                    logprobs=True if self.log_probs else None,
                )

                if completion.usage is None:
                    # Without usage info we cannot track cost; treat as failure.
                    # (Reuses OpenRouterError since no generic error type is
                    # defined for this module's providers.)
                    raise OpenRouterError(
                        "The completion object does not contain usage information. This is likely a bug in the provider API."
                    )

                self.success = True
                break
            except openai.OpenAIError as e:
                # handle_error sleeps/backs off and classifies the failure.
                error_type = handle_error(e, itr, self.min_retry_wait_time, self.max_retry)
                self.error_types.append(error_type)

        if not completion:
            raise RetryError(
                f"Failed to get a response from the API after {self.max_retry} retries\n"
                f"Last error: {error_type}"
            )

        input_tokens = completion.usage.prompt_tokens
        output_tokens = completion.usage.completion_tokens
        cost = input_tokens * self.input_cost + output_tokens * self.output_cost

        # Report usage to the global tracker if one is active.
        if hasattr(tracking.TRACKER, "instance") and isinstance(
            tracking.TRACKER.instance, tracking.LLMTracker
        ):
            tracking.TRACKER.instance(input_tokens, output_tokens, cost)

        def _clean(text):
            # Some providers return None content; normalize to "" and strip
            # the end-of-turn sentinel.
            return "" if text is None else text.removesuffix("<|end|>").strip()

        if n_samples == 1:
            res = AIMessage(_clean(completion.choices[0].message.content))
            if self.log_probs:
                # FIX: the response attribute is `logprobs` (the original
                # `log_probs` would raise AttributeError) — confirm against
                # the litellm/OpenAI response schema.
                res["log_probs"] = completion.choices[0].logprobs
            return res
        else:
            # FIX: also guard against None content here, matching the
            # single-sample branch.
            return [AIMessage(_clean(c.message.content)) for c in completion.choices]

    def get_stats(self):
        """Return retry statistics for the most recent call."""
        return {
            "n_retry_llm": self.retries,
            # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways
        }

src/agentlab/llm/llm_configs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from openai import NOT_GIVEN
2-
31
from agentlab.llm.chat_api import (
42
AnthropicModelArgs,
53
AzureModelArgs,
64
BedrockModelArgs,
5+
LiteLLMModelArgs,
76
OpenAIModelArgs,
87
OpenRouterModelArgs,
98
SelfHostedModelArgs,
109
)
10+
from openai import NOT_GIVEN
1111

1212
default_oss_llms_args = {
1313
"n_retry_server": 4,

0 commit comments

Comments
 (0)