66from functools import partial
77from typing import Optional
88
9+ import agentlab .llm .tracking as tracking
910import anthropic
1011import openai
11- from openai import NOT_GIVEN , OpenAI
12-
13- import agentlab .llm .tracking as tracking
1412from agentlab .llm .base_api import AbstractChatModel , BaseModelArgs
1513from agentlab .llm .llm_utils import AIMessage , Discussion
14+ from openai import NOT_GIVEN , OpenAI
1615
1716
1817def make_system_message (content : str ) -> dict :
@@ -89,6 +88,7 @@ def make_model(self):
8988 log_probs = self .log_probs ,
9089 )
9190
91+
9292@dataclass
9393class LiteLLMModelArgs (BaseModelArgs ):
9494
@@ -119,9 +119,7 @@ def make_model(self):
119119class AzureModelArgs (BaseModelArgs ):
120120 """Serializable object for instantiating a generic chat model with an Azure model."""
121121
122- deployment_name : str = (
123- None # NOTE: deployment_name is deprecated for Azure OpenAI and won't be used.
124- )
122+ deployment_name : str = None # NOTE: deployment_name is deprecated for Azure OpenAI and won't be used.
125123
126124 def make_model (self ):
127125 return AzureChatModel (
@@ -219,9 +217,7 @@ class RetryError(Exception):
219217def handle_error (error , itr , min_retry_wait_time , max_retry ):
220218 if not isinstance (error , openai .OpenAIError ):
221219 raise error
222- logging .warning (
223- f"Failed to get a response from the API: \n { error } \n " f"Retrying... ({ itr + 1 } /{ max_retry } )"
224- )
220+ logging .warning (f"Failed to get a response from the API: \n { error } \n " f"Retrying... ({ itr + 1 } /{ max_retry } )" )
225221 wait_time = _extract_wait_time (
226222 error .args [0 ],
227223 min_retry_wait_time = min_retry_wait_time ,
@@ -320,18 +316,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
320316 self .error_types .append (error_type )
321317
322318 if not completion :
323- raise RetryError (
324- f"Failed to get a response from the API after { self .max_retry } retries\n "
325- f"Last error: { error_type } "
326- )
319+ raise RetryError (f"Failed to get a response from the API after { self .max_retry } retries\n " f"Last error: { error_type } " )
327320
328321 input_tokens = completion .usage .prompt_tokens
329322 output_tokens = completion .usage .completion_tokens
330323 cost = input_tokens * self .input_cost + output_tokens * self .output_cost
331324
332- if hasattr (tracking .TRACKER , "instance" ) and isinstance (
333- tracking .TRACKER .instance , tracking .LLMTracker
334- ):
325+ if hasattr (tracking .TRACKER , "instance" ) and isinstance (tracking .TRACKER .instance , tracking .LLMTracker ):
335326 tracking .TRACKER .instance (input_tokens , output_tokens , cost )
336327
337328 if n_samples == 1 :
@@ -404,6 +395,7 @@ def __init__(
404395 log_probs = log_probs ,
405396 )
406397
398+
407399class AzureChatModel (ChatModel ):
408400 def __init__ (
409401 self ,
@@ -417,18 +409,12 @@ def __init__(
417409 log_probs = False ,
418410 ):
419411 api_key = api_key or os .getenv ("AZURE_OPENAI_API_KEY" )
420- assert (
421- api_key
422- ), "AZURE_OPENAI_API_KEY has to be defined in the environment when using AzureChatModel"
412+ assert api_key , "AZURE_OPENAI_API_KEY has to be defined in the environment when using AzureChatModel"
423413 endpoint = os .getenv ("AZURE_OPENAI_ENDPOINT" )
424- assert (
425- endpoint
426- ), "AZURE_OPENAI_ENDPOINT has to be defined in the environment when using AzureChatModel"
414+ assert endpoint , "AZURE_OPENAI_ENDPOINT has to be defined in the environment when using AzureChatModel"
427415
428416 if deployment_name is not None :
429- logging .info (
430- f"Deployment name is deprecated for Azure OpenAI and won't be used. Using model name: { model_name } ."
431- )
417+ logging .info (f"Deployment name is deprecated for Azure OpenAI and won't be used. Using model name: { model_name } ." )
432418
433419 client_args = {
434420 "base_url" : endpoint ,
@@ -560,12 +546,8 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
560546 output_tokens = getattr (usage , "output_tokens" , 0 )
561547 cache_read_tokens = getattr (usage , "cache_input_tokens" , 0 )
562548 cache_write_tokens = getattr (usage , "cache_creation_input_tokens" , 0 )
563- cache_read_cost = (
564- self .input_cost * tracking .ANTHROPIC_CACHE_PRICING_FACTOR ["cache_read_tokens" ]
565- )
566- cache_write_cost = (
567- self .input_cost * tracking .ANTHROPIC_CACHE_PRICING_FACTOR ["cache_write_tokens" ]
568- )
549+ cache_read_cost = self .input_cost * tracking .ANTHROPIC_CACHE_PRICING_FACTOR ["cache_read_tokens" ]
550+ cache_write_cost = self .input_cost * tracking .ANTHROPIC_CACHE_PRICING_FACTOR ["cache_write_tokens" ]
569551 cost = (
570552 new_input_tokens * self .input_cost
571553 + output_tokens * self .output_cost
@@ -574,9 +556,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
574556 )
575557
576558 # Track usage if available
577- if hasattr (tracking .TRACKER , "instance" ) and isinstance (
578- tracking .TRACKER .instance , tracking .LLMTracker
579- ):
559+ if hasattr (tracking .TRACKER , "instance" ) and isinstance (tracking .TRACKER .instance , tracking .LLMTracker ):
580560 tracking .TRACKER .instance (new_input_tokens , output_tokens , cost )
581561
582562 return AIMessage (response .content [0 ].text )
@@ -613,14 +593,8 @@ def __init__(
613593 self .max_tokens = max_tokens
614594 self .max_retry = max_retry
615595
616- if (
617- not os .getenv ("AWS_REGION" )
618- or not os .getenv ("AWS_ACCESS_KEY" )
619- or not os .getenv ("AWS_SECRET_KEY" )
620- ):
621- raise ValueError (
622- "AWS_REGION, AWS_ACCESS_KEY and AWS_SECRET_KEY must be set in the environment when using BedrockChatModel"
623- )
596+ if not os .getenv ("AWS_REGION" ) or not os .getenv ("AWS_ACCESS_KEY" ) or not os .getenv ("AWS_SECRET_KEY" ):
597+ raise ValueError ("AWS_REGION, AWS_ACCESS_KEY and AWS_SECRET_KEY must be set in the environment when using BedrockChatModel" )
624598
625599 self .client = anthropic .AnthropicBedrock (
626600 aws_region = os .getenv ("AWS_REGION" ),
@@ -638,6 +612,7 @@ def make_model(self):
638612 max_tokens = self .max_new_tokens ,
639613 )
640614
615+
641616class LiteLLMChatModel (AbstractChatModel ):
642617 def __init__ (
643618 self ,
@@ -661,7 +636,6 @@ def __init__(
661636 self .max_retry = max_retry
662637 self .min_retry_wait_time = min_retry_wait_time
663638 self .log_probs = log_probs
664- self .reasoning_effort = reasoning_effort
665639
666640 # Get pricing information
667641 if pricing_func :
@@ -679,9 +653,9 @@ def __init__(
679653 self .input_cost = 0.0
680654 self .output_cost = 0.0
681655
682-
683656 def __call__ (self , messages : list [dict ], n_samples : int = 1 , temperature : float = None ) -> dict :
684657 from litellm import completion as litellm_completion
658+
685659 # Initialize retry tracking attributes
686660 self .retries = 0
687661 self .success = False
@@ -696,10 +670,6 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
696670 completion = litellm_completion (
697671 model = self .model_name ,
698672 messages = messages ,
699- # n=n_samples,
700- # temperature=temperature,
701- # max_completion_tokens=self.max_tokens,
702- reasoning_effort = self .reasoning_effort ,
703673 )
704674
705675 if completion .usage is None :
@@ -714,18 +684,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
714684 self .error_types .append (error_type )
715685
716686 if not completion :
717- raise RetryError (
718- f"Failed to get a response from the API after { self .max_retry } retries\n "
719- f"Last error: { error_type } "
720- )
687+ raise RetryError (f"Failed to get a response from the API after { self .max_retry } retries\n " f"Last error: { error_type } " )
721688
722689 input_tokens = completion .usage .prompt_tokens
723690 output_tokens = completion .usage .completion_tokens
724691 cost = input_tokens * self .input_cost + output_tokens * self .output_cost
725692
726- if hasattr (tracking .TRACKER , "instance" ) and isinstance (
727- tracking .TRACKER .instance , tracking .LLMTracker
728- ):
693+ if hasattr (tracking .TRACKER , "instance" ) and isinstance (tracking .TRACKER .instance , tracking .LLMTracker ):
729694 tracking .TRACKER .instance (input_tokens , output_tokens , cost )
730695
731696 if n_samples == 1 :
@@ -745,4 +710,4 @@ def get_stats(self):
745710 return {
746711 "n_retry_llm" : self .retries ,
747712 # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways
748- }
713+ }
0 commit comments