
Commit f7a0f49

format
1 parent 05795d0 commit f7a0f49

2 files changed: +23 -59 lines changed

src/agentlab/llm/chat_api.py

Lines changed: 21 additions & 56 deletions
@@ -6,13 +6,12 @@
 from functools import partial
 from typing import Optional

+import agentlab.llm.tracking as tracking
 import anthropic
 import openai
-from openai import NOT_GIVEN, OpenAI
-
-import agentlab.llm.tracking as tracking
 from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
 from agentlab.llm.llm_utils import AIMessage, Discussion
+from openai import NOT_GIVEN, OpenAI


 def make_system_message(content: str) -> dict:
@@ -89,6 +88,7 @@ def make_model(self):
             log_probs=self.log_probs,
         )

+
 @dataclass
 class LiteLLMModelArgs(BaseModelArgs):

@@ -119,9 +119,7 @@ def make_model(self):
 class AzureModelArgs(BaseModelArgs):
     """Serializable object for instantiating a generic chat model with an Azure model."""

-    deployment_name: str = (
-        None  # NOTE: deployment_name is deprecated for Azure OpenAI and won't be used.
-    )
+    deployment_name: str = None  # NOTE: deployment_name is deprecated for Azure OpenAI and won't be used.

     def make_model(self):
         return AzureChatModel(
@@ -219,9 +217,7 @@ class RetryError(Exception):
 def handle_error(error, itr, min_retry_wait_time, max_retry):
     if not isinstance(error, openai.OpenAIError):
         raise error
-    logging.warning(
-        f"Failed to get a response from the API: \n{error}\n" f"Retrying... ({itr+1}/{max_retry})"
-    )
+    logging.warning(f"Failed to get a response from the API: \n{error}\n" f"Retrying... ({itr+1}/{max_retry})")
     wait_time = _extract_wait_time(
         error.args[0],
         min_retry_wait_time=min_retry_wait_time,
@@ -320,18 +316,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
                 self.error_types.append(error_type)

         if not completion:
-            raise RetryError(
-                f"Failed to get a response from the API after {self.max_retry} retries\n"
-                f"Last error: {error_type}"
-            )
+            raise RetryError(f"Failed to get a response from the API after {self.max_retry} retries\n" f"Last error: {error_type}")

         input_tokens = completion.usage.prompt_tokens
         output_tokens = completion.usage.completion_tokens
         cost = input_tokens * self.input_cost + output_tokens * self.output_cost

-        if hasattr(tracking.TRACKER, "instance") and isinstance(
-            tracking.TRACKER.instance, tracking.LLMTracker
-        ):
+        if hasattr(tracking.TRACKER, "instance") and isinstance(tracking.TRACKER.instance, tracking.LLMTracker):
             tracking.TRACKER.instance(input_tokens, output_tokens, cost)

         if n_samples == 1:
@@ -404,6 +395,7 @@ def __init__(
             log_probs=log_probs,
         )

+
 class AzureChatModel(ChatModel):
     def __init__(
         self,
@@ -417,18 +409,12 @@ def __init__(
         log_probs=False,
     ):
         api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
-        assert (
-            api_key
-        ), "AZURE_OPENAI_API_KEY has to be defined in the environment when using AzureChatModel"
+        assert api_key, "AZURE_OPENAI_API_KEY has to be defined in the environment when using AzureChatModel"
         endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
-        assert (
-            endpoint
-        ), "AZURE_OPENAI_ENDPOINT has to be defined in the environment when using AzureChatModel"
+        assert endpoint, "AZURE_OPENAI_ENDPOINT has to be defined in the environment when using AzureChatModel"

         if deployment_name is not None:
-            logging.info(
-                f"Deployment name is deprecated for Azure OpenAI and won't be used. Using model name: {model_name}."
-            )
+            logging.info(f"Deployment name is deprecated for Azure OpenAI and won't be used. Using model name: {model_name}.")

         client_args = {
             "base_url": endpoint,
@@ -560,12 +546,8 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
         output_tokens = getattr(usage, "output_tokens", 0)
         cache_read_tokens = getattr(usage, "cache_input_tokens", 0)
         cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0)
-        cache_read_cost = (
-            self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"]
-        )
-        cache_write_cost = (
-            self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"]
-        )
+        cache_read_cost = self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"]
+        cache_write_cost = self.input_cost * tracking.ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"]
         cost = (
             new_input_tokens * self.input_cost
             + output_tokens * self.output_cost
@@ -574,9 +556,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
         )

         # Track usage if available
-        if hasattr(tracking.TRACKER, "instance") and isinstance(
-            tracking.TRACKER.instance, tracking.LLMTracker
-        ):
+        if hasattr(tracking.TRACKER, "instance") and isinstance(tracking.TRACKER.instance, tracking.LLMTracker):
             tracking.TRACKER.instance(new_input_tokens, output_tokens, cost)

         return AIMessage(response.content[0].text)
@@ -613,14 +593,8 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry

-        if (
-            not os.getenv("AWS_REGION")
-            or not os.getenv("AWS_ACCESS_KEY")
-            or not os.getenv("AWS_SECRET_KEY")
-        ):
-            raise ValueError(
-                "AWS_REGION, AWS_ACCESS_KEY and AWS_SECRET_KEY must be set in the environment when using BedrockChatModel"
-            )
+        if not os.getenv("AWS_REGION") or not os.getenv("AWS_ACCESS_KEY") or not os.getenv("AWS_SECRET_KEY"):
+            raise ValueError("AWS_REGION, AWS_ACCESS_KEY and AWS_SECRET_KEY must be set in the environment when using BedrockChatModel")

         self.client = anthropic.AnthropicBedrock(
             aws_region=os.getenv("AWS_REGION"),
@@ -638,6 +612,7 @@ def make_model(self):
             max_tokens=self.max_new_tokens,
         )

+
 class LiteLLMChatModel(AbstractChatModel):
     def __init__(
         self,
@@ -661,7 +636,6 @@ def __init__(
         self.max_retry = max_retry
         self.min_retry_wait_time = min_retry_wait_time
         self.log_probs = log_probs
-        self.reasoning_effort = reasoning_effort

         # Get pricing information
         if pricing_func:
@@ -679,9 +653,9 @@ def __init__(
         self.input_cost = 0.0
         self.output_cost = 0.0

-
     def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
         from litellm import completion as litellm_completion
+
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -696,10 +670,6 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             completion = litellm_completion(
                 model=self.model_name,
                 messages=messages,
-                # n=n_samples,
-                # temperature=temperature,
-                # max_completion_tokens=self.max_tokens,
-                reasoning_effort=self.reasoning_effort,
             )

             if completion.usage is None:
@@ -714,18 +684,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
                 self.error_types.append(error_type)

         if not completion:
-            raise RetryError(
-                f"Failed to get a response from the API after {self.max_retry} retries\n"
-                f"Last error: {error_type}"
-            )
+            raise RetryError(f"Failed to get a response from the API after {self.max_retry} retries\n" f"Last error: {error_type}")

         input_tokens = completion.usage.prompt_tokens
         output_tokens = completion.usage.completion_tokens
         cost = input_tokens * self.input_cost + output_tokens * self.output_cost

-        if hasattr(tracking.TRACKER, "instance") and isinstance(
-            tracking.TRACKER.instance, tracking.LLMTracker
-        ):
+        if hasattr(tracking.TRACKER, "instance") and isinstance(tracking.TRACKER.instance, tracking.LLMTracker):
             tracking.TRACKER.instance(input_tokens, output_tokens, cost)

         if n_samples == 1:
@@ -745,4 +710,4 @@ def get_stats(self):
         return {
             "n_retry_llm": self.retries,
             # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways
-        }
+        }

src/agentlab/llm/llm_configs.py

Lines changed: 2 additions & 3 deletions
@@ -1,14 +1,13 @@
-from openai import NOT_GIVEN
-
 from agentlab.llm.chat_api import (
     AnthropicModelArgs,
     AzureModelArgs,
     BedrockModelArgs,
+    LiteLLMModelArgs,
     OpenAIModelArgs,
     OpenRouterModelArgs,
     SelfHostedModelArgs,
-    LiteLLMModelArgs,
 )
+from openai import NOT_GIVEN

 default_oss_llms_args = {
     "n_retry_server": 4,
