11"""
22Module related to using LLMs.
33"""
4- import os
4+ import os , sys
55import warnings
66# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
77warnings .filterwarnings ("ignore" , category = UserWarning , module = "pydantic._internal._config" )
88warnings .filterwarnings ("ignore" , category = UserWarning , module = "pydantic" )
99import litellm
1010import json
1111import time
12+ import threading
1213import traceback
1314import inspect
1415import docstring_parser
1819from typing import Optional , Dict , List , Union , Tuple , Callable , get_args , get_origin
1920from copy import deepcopy
2021
21- from litellm import completion , completion_cost , token_counter
22+ from litellm import completion , completion_cost , token_counter , cost_per_token
2223from litellm .utils import get_model_info , get_supported_openai_params , supports_response_schema
2324from litellm .utils import supports_function_calling , supports_parallel_function_calling
2425from litellm ._logging import _enable_debugging as litellm_enable_debugging
@@ -740,12 +741,14 @@ def cleaned_args(args: dict):
740741 if via_streaming :
741742 # TODO: check if model supports streaming
742743 completion_kwargs ["stream" ] = True
744+ completion_kwargs ["stream_options" ] = {"include_usage" : True }
743745 logger .debug (f"completion kwargs after detecting via_streaming: { cleaned_args (completion_kwargs )} " )
744746 elif stream :
745747 # TODO: check if model supports streaming
746748 # if streaming is enabled, we always return the original response
747749 return_response = True
748750 completion_kwargs ["stream" ] = True
751+ completion_kwargs ["stream_options" ] = {"include_usage" : True }
749752 logger .debug (f"completion kwargs after detecting stream: { cleaned_args (completion_kwargs )} " )
750753 ret = {}
751754 # before adding the kwargs, save the recursive_call_info and remove it from kwargs
@@ -773,6 +776,20 @@ def cleaned_args(args: dict):
773776 else :
774777 start = time .time ()
775778 recursive_call_info ["start" ] = start
779+ callback_data = {}
780+ # NOTE: this does not seem to work!?!?!?
781+ callback_complete = threading .Event ()
782+ def cost_callback (kwargs , completion_response , start_time , end_time ):
783+ cost = kwargs .get ("response_cost" )
784+ if cost is not None :
785+ callback_data ["cost" ] = cost
786+ callback_complete .set ()
787+ if hasattr (completion_response , "usage" ):
788+ callback_data ["prompt_tokens" ] = completion_response .usage .prompt_tokens
789+ callback_data ["completion_tokens" ] = completion_response .usage .completion_tokens
790+ callback_data ["total_tokens" ] = completion_response .usage .prompt_tokens + completion_response .usage .completion_tokens
791+ litellm .success_callback = [cost_callback ]
792+ logger .info (f"completion kwargs: { completion_kwargs } " )
776793 response = litellm .completion (
777794 model = llm ["llm" ],
778795 messages = messages ,
@@ -786,25 +803,40 @@ def cleaned_args(args: dict):
786803 logger .debug (f"Retrieving chunks ..." )
787804 n_chunks = 0
788805 for chunk in response :
789- choice0 = chunk ["choices" ][0 ]
806+ # TODO: this should work, but does not!!!!
807+ if hasattr (chunk , "usage" ):
808+ if chunk .usage :
809+ prompt_cost , completion_cost_value = cost_per_token (
810+ model = llm ["llm" ],
811+ prompt_tokens = chunk .usage .prompt_tokens ,
812+ completion_tokens = chunk .usage .completion_tokens
813+ )
814+ cost = prompt_cost + completion_cost_value
815+ else :
816+ pass
817+ # logger.info(f"!!!!DEBUG: chunk attributes: {dir(chunk)}")
818+ choice0 = chunk .choices [0 ]
790819 if choice0 .finish_reason == "stop" :
791820 logger .debug (f"Streaming got stop. Chunk { chunk } " )
792- break
821+ # break
793822 n_chunks += 1
794823 content = choice0 ["delta" ].get ("content" , "" )
795824 logger .debug (f"Got streaming content: { content } " )
796- answer += content
825+ if isinstance (content , str ):
826+ answer += content
797827 if return_response :
798828 ret ["response" ] = response
799829 ret ["answer" ] = answer
800830 ret ["n_chunks" ] = n_chunks
801831 ret ["elapsed_time" ] = time .time () - start
802832 ret ["ok" ] = True
803833 ret ["error" ] = ""
804- # TODO: for now return 0, may perhaps be possible to do better?
805- ret ["cost" ] = 0
806- ret ["n_prompt_tokens" ] = 0
807- ret ["n_completion_tokens" ] = 0
834+ # get the cost from the callback
835+ if not callback_complete .is_set ():
836+ callback_complete .wait (timeout = 2.0 )
837+ ret ["cost" ] = callback_data .get ("cost" )
838+ ret ["n_prompt_tokens" ] = callback_data .get ("prompt_tokens" )
839+ ret ["n_completion_tokens" ] = callback_data .get ("completion_tokens" )
808840 return ret
809841 except Exception as e :
810842 tb = traceback .extract_tb (e .__traceback__ )
@@ -827,12 +859,14 @@ def chunk_generator(model_generator, retobj):
827859 except Exception as e :
828860 yield dict (error = str (e ), answer = "" , ok = False )
829861 finally :
830- # TODO: add cost and elapsed time information into retobj
831- # litellm does not support cost on streaming responses
832- # response.__hidden_params["response_cost"] is 0.0
833- ret ["cost" ] = None
862+ # get the cost from the callback
863+ if not callback_complete .is_set ():
864+ logger .info (f"Waiting for callback completion" )
865+ callback_complete .wait (timeout = 2.0 )
866+ ret ["cost" ] = callback_data .get ("cost" )
867+ ret ["n_prompt_tokens" ] = callback_data .get ("prompt_tokens" )
868+ ret ["n_completion_tokens" ] = callback_data .get ("completion_tokens" )
834869 ret ["elapsed_time" ] = time .time () - start
835- pass
836870 if return_response :
837871 ret ["response" ] = response
838872 ret ["chunks" ] = chunk_generator (response , ret )
@@ -851,15 +885,20 @@ def chunk_generator(model_generator, retobj):
851885 del completion_kwargs ["api_key" ]
852886 ret ["kwargs" ] = completion_kwargs
853887 if return_cost :
854- # TODO: replace with response._hidden_params["response_cost"] ?
855- # but what if cost not supported for the model?
856-
857888 try :
858- ret ["cost" ] = completion_cost (
859- completion_response = response ,
860- model = llm ["llm" ],
861- messages = messages ,
862- )
889+ cost = response ._hidden_params .get ("response_cost" , None )
890+ logger .info (f"DEBUG: cost from hidden parms: { cost } " )
891+ # if cost is None:
892+ if True :
893+ cost = completion_cost (
894+ completion_response = response ,
895+ model = llm ["llm" ],
896+ messages = messages ,
897+ )
898+ ret ["cost" ] = cost
899+ logger .info (f"DEBUG: cost from completion_cost: { cost } " )
900+ else :
901+ ret ["cost" ] = cost
863902 if debug :
864903 logger .debug (f"Cost for this call { ret ['cost' ]} " )
865904 except Exception as e :
0 commit comments