66# TODO: Remove after https://github.com/BerriAI/litellm/issues/7560 is fixed
77warnings .filterwarnings ("ignore" , category = UserWarning , module = "pydantic._internal._config" )
88import litellm
9+ import json
910import time
1011import traceback
1112import inspect
4344]
4445
4546
def toolnames2funcs(tools):
    """
    Build a mapping from tool function names to the callables implementing them.

    Each entry of `tools` is expected to be a dict holding the function name
    under ["function"]["name"]. The callable is resolved by name via
    get_func_by_name.

    Args:
        tools: iterable of tool description dicts.

    Returns:
        A dict mapping each tool function name to the resolved callable.

    Raises:
        Exception: if no callable can be found for one of the names.
    """
    resolved = {}
    for spec in tools:
        tool_name = spec["function"]["name"]
        target = get_func_by_name(tool_name)
        if target is not None:
            resolved[tool_name] = target
            continue
        # no callable with that name is reachable: give up on the whole tool list
        raise Exception(f"Function {tool_name} not found")
    return resolved
56+
def get_func_by_name(name):
    """
    Find a callable with the given name by walking up the call stack.

    Starting with the frame of this function and continuing through every
    calling frame, the local variables and then the global variables of each
    frame are checked for an entry called `name`. The first entry that is
    callable is returned. Non-callable entries are ignored, so a non-callable
    local variable does not hide a callable global of the same name.

    Note that every frame on the stack is inspected, including frames of
    intermediate library functions, not just the frame of the original caller.

    Args:
        name: the name of the function to look up.

    Returns:
        The first callable found under that name, or None if there is none.
    """
    # Follow f_back links instead of calling inspect.stack(): the latter builds
    # a FrameInfo with source context for every frame, which is not needed here.
    frame = inspect.currentframe()
    try:
        while frame is not None:
            # check both scopes independently so that a truthy but non-callable
            # local cannot shadow a callable global in the same frame
            for scope in (frame.f_locals, frame.f_globals):
                candidate = scope.get(name)
                if callable(candidate):
                    return candidate
            frame = frame.f_back
    finally:
        # drop the frame reference to avoid reference cycles (see inspect docs)
        del frame
    return None  # Not found
65+
4666class LLMS :
4767 """
4868 Class that represents a preconfigured set of large language modelservices.
@@ -420,6 +440,7 @@ def query(
420440 return_response : bool = False ,
421441 debug = False ,
422442 litellm_debug = None ,
443+ recursive_call_info : Optional [Dict [str , any ]] = None ,
423444 ** kwargs ,
424445 ) -> Dict [str , any ]:
425446 """
@@ -457,19 +478,38 @@ def query(
457478 KNOWN_LLM_CONFIG_FIELDS ,
458479 ignore_underscored = True ,
459480 )
481+ if recursive_call_info is None :
482+ recursive_call_info = {}
460483 if llm .get ("api_key" ):
461484 completion_kwargs ["api_key" ] = llm ["api_key" ]
462485 elif llm .get ("api_key_env" ):
463486 completion_kwargs ["api_key" ] = os .getenv (llm ["api_key_env" ])
464487 if llm .get ("api_url" ):
465488 completion_kwargs ["api_base" ] = llm ["api_url" ]
466489 if tools is not None :
490+ # add tooling-related arguments to completion_kwargs
467491 completion_kwargs ["tools" ] = tools
492+ if not self .supports_function_calling (llmalias ):
493+ # see https://docs.litellm.ai/docs/completion/function_call#function-calling-for-models-wout-function-calling-support
494+ litellm .add_function_to_prompt = True
495+ else :
496+ if "tool_choice" not in completion_kwargs :
497+ # this is the default, but lets be explicit
498+ completion_kwargs ["tool_choice" ] = "auto"
499+ # Not known/supported by litellm, apparently
500+ # if "parallel_tool_choice" not in completion_kwargs:
501+ # completion_kwargs["parallel_tool_choice"] = True
502+ fmap = toolnames2funcs (tools )
503+ else :
504+ fmap = {}
468505 ret = {}
506+ # before adding the kwargs, save the recursive_call_info and remove it from kwargs
507+ if debug :
508+ print (f"DEBUG: Received recursive call info: { recursive_call_info } " )
469509 if kwargs :
470510 completion_kwargs .update (dict_except (kwargs , KNOWN_LLM_CONFIG_FIELDS , ignore_underscored = True ))
471511 if debug :
472- logger . debug (f"Calling completion with kwargs { completion_kwargs } " )
512+ print (f"DEBUG: Calling completion with kwargs { completion_kwargs } " )
473513 # if we have min_delay set, we look at the _last_request_time for the LLM and calculate the time
474514 # to wait until we can send the next request and then just wait
475515 min_delay = llm .get ("min_delay" , kwargs .get ("min_delay" , 0.0 ))
@@ -481,7 +521,13 @@ def query(
481521 if "min_delay" in completion_kwargs :
482522 raise ValueError ("Error: min_delay should not be passed as a keyword argument" )
483523 try :
484- start = time .time ()
524+ # if we have been called recursively and the recursive_call_info has a start time,
525+ # use that as the start time
526+ if recursive_call_info .get ("start" ) is not None :
527+ start = recursive_call_info ["start" ]
528+ else :
529+ start = time .time ()
530+ recursive_call_info ["start" ] = start
485531 response = litellm .completion (
486532 model = llm ["llm" ],
487533 messages = messages ,
@@ -503,6 +549,8 @@ def query(
503549 model = llm ["llm" ],
504550 messages = messages ,
505551 )
552+ if debug :
553+ print (f"DEBUG: cost for this call { ret ['cost' ]} " )
506554 except Exception as e :
507555 logger .debug (f"Error in completion_cost for model { llm ['llm' ]} : { e } " )
508556 ret ["cost" ] = 0.0
@@ -512,6 +560,22 @@ def query(
512560 ret ["n_completion_tokens" ] = usage .completion_tokens
513561 ret ["n_prompt_tokens" ] = usage .prompt_tokens
514562 ret ["n_total_tokens" ] = usage .total_tokens
563+ # add the cost and tokens from the recursive call info, if available
564+ if recursive_call_info .get ("cost" ) is not None :
565+ ret ["cost" ] += recursive_call_info ["cost" ]
566+ if debug :
567+ print (f"DEBUG: cost for this and previous calls { ret ['cost' ]} " )
568+ if recursive_call_info .get ("n_completion_tokens" ) is not None :
569+ ret ["n_completion_tokens" ] += recursive_call_info ["n_completion_tokens" ]
570+ if recursive_call_info .get ("n_prompt_tokens" ) is not None :
571+ ret ["n_prompt_tokens" ] += recursive_call_info ["n_prompt_tokens" ]
572+ if recursive_call_info .get ("n_total_tokens" ) is not None :
573+ ret ["n_total_tokens" ] += recursive_call_info ["n_total_tokens" ]
574+ recursive_call_info ["cost" ] = ret ["cost" ]
575+ recursive_call_info ["n_completion_tokens" ] = ret ["n_completion_tokens" ]
576+ recursive_call_info ["n_prompt_tokens" ] = ret ["n_prompt_tokens" ]
577+ recursive_call_info ["n_total_tokens" ] = ret ["n_total_tokens" ]
578+
515579 response_message = response ['choices' ][0 ]['message' ]
516580 # Does not seem to work see https://github.com/BerriAI/litellm/issues/389
517581 # ret["response_ms"] = response["response_ms"]
@@ -521,9 +585,70 @@ def query(
521585 ret ["ok" ] = True
522586 # TODO: if feasible handle all tool calling here or in a separate method which does
523587 # all the tool calling steps (up to a specified maximum).
588+ if debug :
589+ print (f"DEBUG: checking for tool_calls: { response_message } , have tools: { tools is not None } " )
524590 if tools is not None :
525- ret ["tool_calls" ] = response_message .tool_calls
526- ret ["response_message" ] = response_message
591+ if hasattr (response_message , "tool_calls" ) and response_message .tool_calls is not None :
592+ tool_calls = response_message .tool_calls
593+ else :
594+ tool_calls = []
595+ if debug :
596+ print (f"DEBUG: got { len (tool_calls )} tool calls:" )
597+ for tool_call in tool_calls :
598+ print (f"DEBUG: { tool_call } " )
599+ if len (tool_calls ) > 0 : # not an empty list
600+ if debug :
601+ print (f"DEBUG: appending response message: { response_message } " )
602+ messages .append (response_message )
603+ for tool_call in tool_calls :
604+ function_name = tool_call .function .name
605+ if debug :
606+ print (f"DEBUG: tool call { function_name } " )
607+ fun2call = fmap .get (function_name )
608+ if fun2call is None :
609+ ret ["error" ] = f"Unknown tooling function name: { function_name } "
610+ ret ["answer" ] = ""
611+ ret ["ok" ] = False
612+ return ret
613+ function_args = json .loads (tool_call .function .arguments )
614+ try :
615+ if debug :
616+ print (f"DEBUG: calling { function_name } with args { function_args } " )
617+ function_response = fun2call (** function_args )
618+ if debug :
619+ print (f"DEBUG: got response { function_response } " )
620+ except Exception as e :
621+ tb = traceback .extract_tb (e .__traceback__ )
622+ filename , lineno , funcname , text = tb [- 1 ]
623+ if debug :
624+ print (f"DEBUG: function call got error { e } " )
625+ ret ["error" ] = f"Error executing tool function { function_name } : { str (e )} in { filename } :{ lineno } { funcname } "
626+ if debug :
627+ logger .error (f"Returning error: { e } " )
628+ ret ["answer" ] = ""
629+ ret ["ok" ] = False
630+ return ret
631+ messages .append (
632+ dict (
633+ tool_call_id = tool_call .id ,
634+ role = "tool" , name = function_name ,
635+ content = json .dumps (function_response )))
636+ # recursively call query
637+ if debug :
638+ print (f"DEBUG: recursively calling query with messages:" )
639+ for idx , msg in enumerate (messages ):
640+ print (f"DEBUG: Message { idx } : { msg } " )
641+ print (f"DEBUG: recursively_call_info is { recursive_call_info } " )
642+ return self .query (
643+ llmalias ,
644+ messages ,
645+ tools = tools ,
646+ return_cost = return_cost ,
647+ return_response = return_response ,
648+ debug = debug ,
649+ litellm_debug = litellm_debug ,
650+ recursive_call_info = recursive_call_info ,
651+ ** kwargs )
527652 except Exception as e :
528653 tb = traceback .extract_tb (e .__traceback__ )
529654 filename , lineno , funcname , text = tb [- 1 ]
0 commit comments