|
6 | 6 | from functools import partial |
7 | 7 | from typing import Optional |
8 | 8 |
|
| 9 | +import agentlab.llm.tracking as tracking |
9 | 10 | import anthropic |
10 | 11 | import openai |
11 | | -from openai import NOT_GIVEN, OpenAI |
12 | | - |
13 | | -import agentlab.llm.tracking as tracking |
14 | 12 | from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs |
15 | 13 | from agentlab.llm.llm_utils import AIMessage, Discussion |
| 14 | +from openai import NOT_GIVEN, OpenAI |
16 | 15 |
|
17 | 16 |
|
18 | 17 | def make_system_message(content: str) -> dict: |
@@ -90,6 +89,18 @@ def make_model(self): |
90 | 89 | ) |
91 | 90 |
|
92 | 91 |
|
@dataclass
class LiteLLMModelArgs(BaseModelArgs):
    """Serializable arguments for instantiating a LiteLLM-backed chat model."""

    def make_model(self):
        """Instantiate a LiteLLMChatModel configured from these arguments."""
        model_kwargs = {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_new_tokens,
            "log_probs": self.log_probs,
        }
        return LiteLLMChatModel(**model_kwargs)
93 | 104 | @dataclass |
94 | 105 | class OpenAIModelArgs(BaseModelArgs): |
95 | 106 | """Serializable object for instantiating a generic chat model with an OpenAI |
@@ -627,3 +638,111 @@ def make_model(self): |
627 | 638 | temperature=self.temperature, |
628 | 639 | max_tokens=self.max_new_tokens, |
629 | 640 | ) |
| 641 | + |
| 642 | + |
class LiteLLMChatModel(AbstractChatModel):
    """Chat model that routes completion requests through LiteLLM.

    Calls ``litellm.completion`` with the configured model, retrying on
    OpenAI-compatible API errors, and reports token usage and cost to the
    global ``tracking.TRACKER`` when one is active.
    """

    def __init__(
        self,
        model_name,
        api_key=None,
        temperature=0.5,
        max_tokens=100,
        max_retry=4,
        min_retry_wait_time=60,
        api_key_env_var=None,
        client_class=OpenAI,
        client_args=None,
        pricing_func=None,
        log_probs=False,
    ):
        """Configure the model.

        Args:
            model_name: LiteLLM model identifier (e.g. "openai/gpt-4o").
            api_key: Accepted for signature compatibility with sibling chat
                models; unused here — presumably LiteLLM reads credentials
                from the environment (TODO confirm).
            temperature: Default sampling temperature.
            max_tokens: Maximum number of tokens to generate per completion.
            max_retry: Number of API attempts before giving up; must be > 0.
            min_retry_wait_time: Minimum seconds to wait between retries,
                forwarded to ``handle_error``.
            api_key_env_var: Unused; kept for signature compatibility.
            client_class: Unused; kept for signature compatibility.
            client_args: Unused; kept for signature compatibility.
            pricing_func: Optional callable returning a mapping of
                ``model_name -> {"prompt": cost, "completion": cost}``.
            log_probs: If True, attach the first choice's log-probs to the
                returned message.
        """
        assert max_retry > 0, "max_retry should be greater than 0"

        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retry = max_retry
        self.min_retry_wait_time = min_retry_wait_time
        self.log_probs = log_probs

        # Resolve per-token pricing; default to 0 so cost tracking degrades
        # gracefully when the model's price is unknown.
        self.input_cost = 0.0
        self.output_cost = 0.0
        if pricing_func:
            pricings = pricing_func()
            try:
                self.input_cost = float(pricings[model_name]["prompt"])
                self.output_cost = float(pricings[model_name]["completion"])
            except KeyError:
                logging.warning(
                    f"Model {model_name} not found in the pricing information, prices are set to 0. Maybe try upgrading langchain_community."
                )

    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
        """Query the model and return the response message(s).

        Args:
            messages: Conversation as a list of role/content dicts.
            n_samples: Number of completions to request.
            temperature: Optional per-call override of the default temperature.

        Returns:
            An AIMessage when ``n_samples == 1``, otherwise a list of
            AIMessages (one per returned choice).

        Raises:
            RetryError: If no completion is obtained after ``max_retry``
                attempts.
        """
        from litellm import completion as litellm_completion

        # Retry bookkeeping, exposed through get_stats().
        self.retries = 0
        self.success = False
        self.error_types = []

        # A call-time temperature overrides the instance default; resolve it
        # once instead of on every retry iteration.
        temperature = temperature if temperature is not None else self.temperature

        completion = None
        error_type = None
        for itr in range(self.max_retry):
            self.retries += 1
            try:
                # Bug fix: temperature, max_tokens and n were previously never
                # forwarded, so the configured sampling parameters were
                # silently ignored and n_samples > 1 could not work.
                completion = litellm_completion(
                    model=self.model_name,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=self.max_tokens,
                    n=n_samples,
                )

                if completion.usage is None:
                    # NOTE(review): exception type and message look copied from
                    # the OpenRouter model — confirm whether a LiteLLM-specific
                    # error is more appropriate.
                    raise OpenRouterError(
                        "The completion object does not contain usage information. This is likely a bug in the OpenRouter API."
                    )

                self.success = True
                break
            except openai.OpenAIError as err:
                # handle_error sleeps/backs off as needed and classifies the
                # failure for reporting.
                error_type = handle_error(err, itr, self.min_retry_wait_time, self.max_retry)
                self.error_types.append(error_type)

        if not completion:
            raise RetryError(
                f"Failed to get a response from the API after {self.max_retry} retries\n"
                f"Last error: {error_type}"
            )

        input_tokens = completion.usage.prompt_tokens
        output_tokens = completion.usage.completion_tokens
        cost = input_tokens * self.input_cost + output_tokens * self.output_cost

        # Report usage to the active tracker, if any.
        if hasattr(tracking.TRACKER, "instance") and isinstance(
            tracking.TRACKER.instance, tracking.LLMTracker
        ):
            tracking.TRACKER.instance(input_tokens, output_tokens, cost)

        if n_samples == 1:
            res_text = completion.choices[0].message.content
            # Some models append an explicit end-of-text marker; strip it.
            # Content can be None (e.g. refusals) — normalize to "".
            res_text = res_text.removesuffix("<|end|>").strip() if res_text is not None else ""
            res = AIMessage(res_text)
            if self.log_probs:
                # NOTE(review): OpenAI-style choice objects expose `logprobs`,
                # not `log_probs` — confirm this attribute against litellm.
                res["log_probs"] = completion.choices[0].log_probs
            return res
        else:
            return [
                AIMessage(c.message.content.removesuffix("<|end|>").strip())
                for c in completion.choices
            ]

    def get_stats(self):
        """Return retry statistics for the most recent __call__."""
        return {
            "n_retry_llm": self.retries,
            # "busted_retry_llm": int(not self.success), # not logged if it occurs anyways
        }
0 commit comments