@@ -71,7 +71,7 @@ def __init__(
7171 model_name : str = "gemini-pro" ,
7272 safety_settings : safety_types .SafetySettingOptions | None = None ,
7373 generation_config : generation_types .GenerationConfigType | None = None ,
74- tools : content_types .ToolsType = None ,
74+ tools : content_types .FunctionLibraryType | None = None ,
7575 ):
7676 if "/" not in model_name :
7777 model_name = "models/" + model_name
@@ -80,7 +80,7 @@ def __init__(
8080 safety_settings , harm_category_set = "new"
8181 )
8282 self ._generation_config = generation_types .to_generation_config_dict (generation_config )
83- self ._tools = content_types .to_tools (tools )
83+ self ._tools = content_types .to_function_library (tools )
8484
8585 self ._client = None
8686 self ._async_client = None
@@ -94,8 +94,9 @@ def __str__(self):
9494 f"""\
9595 genai.GenerativeModel(
9696 model_name='{ self .model_name } ',
97- generation_config={ self ._generation_config } .
98- safety_settings={ self ._safety_settings }
97+ generation_config={ self ._generation_config } ,
98+ safety_settings={ self ._safety_settings } ,
99+ tools={ self ._tools } ,
99100 )"""
100101 )
101102
@@ -107,12 +108,16 @@ def _prepare_request(
107108 contents : content_types .ContentsType ,
108109 generation_config : generation_types .GenerationConfigType | None = None ,
109110 safety_settings : safety_types .SafetySettingOptions | None = None ,
110- ** kwargs ,
111+ tools : content_types . FunctionLibraryType | None ,
111112 ) -> glm .GenerateContentRequest :
112113 """Creates a `glm.GenerateContentRequest` from raw inputs."""
113114 if not contents :
114115 raise TypeError ("contents must not be empty" )
115116
117+ tools_lib = self ._get_tools_lib (tools )
118+ if tools_lib is not None :
119+ tools_lib = tools_lib .to_proto ()
120+
116121 contents = content_types .to_contents (contents )
117122
118123 generation_config = generation_types .to_generation_config_dict (generation_config )
@@ -129,19 +134,26 @@ def _prepare_request(
129134 contents = contents ,
130135 generation_config = merged_gc ,
131136 safety_settings = merged_ss ,
132- tools = self ._tools ,
133- ** kwargs ,
137+ tools = tools_lib ,
134138 )
135139
140+ def _get_tools_lib (
141+ self , tools : content_types .FunctionLibraryType
142+ ) -> content_types .FunctionLibrary | None :
143+ if tools is None :
144+ return self ._tools
145+ else :
146+ return content_types .to_function_library (tools )
147+
136148 def generate_content (
137149 self ,
138150 contents : content_types .ContentsType ,
139151 * ,
140152 generation_config : generation_types .GenerationConfigType | None = None ,
141153 safety_settings : safety_types .SafetySettingOptions | None = None ,
142154 stream : bool = False ,
155+ tools : content_types .FunctionLibraryType | None = None ,
143156 request_options : dict [str , Any ] | None = None ,
144- ** kwargs ,
145157 ) -> generation_types .GenerateContentResponse :
146158 """A multipurpose function to generate responses from the model.
147159
@@ -201,7 +213,7 @@ def generate_content(
201213 contents = contents ,
202214 generation_config = generation_config ,
203215 safety_settings = safety_settings ,
204- ** kwargs ,
216+ tools = tools ,
205217 )
206218 if self ._client is None :
207219 self ._client = client .get_default_generative_client ()
@@ -230,15 +242,15 @@ async def generate_content_async(
230242 generation_config : generation_types .GenerationConfigType | None = None ,
231243 safety_settings : safety_types .SafetySettingOptions | None = None ,
232244 stream : bool = False ,
245+ tools : content_types .FunctionLibraryType | None = None ,
233246 request_options : dict [str , Any ] | None = None ,
234- ** kwargs ,
235247 ) -> generation_types .AsyncGenerateContentResponse :
236248 """The async version of `GenerativeModel.generate_content`."""
237249 request = self ._prepare_request (
238250 contents = contents ,
239251 generation_config = generation_config ,
240252 safety_settings = safety_settings ,
241- ** kwargs ,
253+ tools = tools ,
242254 )
243255 if self ._async_client is None :
244256 self ._async_client = client .get_default_generative_async_client ()
@@ -299,6 +311,7 @@ def start_chat(
299311 self ,
300312 * ,
301313 history : Iterable [content_types .StrictContentType ] | None = None ,
314+ enable_automatic_function_calling : bool = False ,
302315 ) -> ChatSession :
303316 """Returns a `genai.ChatSession` attached to this model.
304317
@@ -314,6 +327,7 @@ def start_chat(
314327 return ChatSession (
315328 model = self ,
316329 history = history ,
330+ enable_automatic_function_calling = enable_automatic_function_calling ,
317331 )
318332
319333
@@ -341,11 +355,13 @@ def __init__(
341355 self ,
342356 model : GenerativeModel ,
343357 history : Iterable [content_types .StrictContentType ] | None = None ,
358+ enable_automatic_function_calling : bool = False ,
344359 ):
345360 self .model : GenerativeModel = model
346361 self ._history : list [glm .Content ] = content_types .to_contents (history )
347362 self ._last_sent : glm .Content | None = None
348363 self ._last_received : generation_types .BaseGenerateContentResponse | None = None
364+ self .enable_automatic_function_calling = enable_automatic_function_calling
349365
350366 def send_message (
351367 self ,
@@ -354,7 +370,7 @@ def send_message(
354370 generation_config : generation_types .GenerationConfigType = None ,
355371 safety_settings : safety_types .SafetySettingOptions = None ,
356372 stream : bool = False ,
357- ** kwargs ,
373+ tools : content_types . FunctionLibraryType | None = None ,
358374 ) -> generation_types .GenerateContentResponse :
359375 """Sends the conversation history with the added message and returns the model's response.
360376
@@ -387,23 +403,52 @@ def send_message(
387403 safety_settings: Overrides for the model's safety settings.
388404 stream: If True, yield response chunks as they are generated.
389405 """
406+ if self .enable_automatic_function_calling and stream :
407+ raise NotImplementedError (
408+ "The `google.generativeai` SDK does not yet support `stream=True` with "
409+ "`enable_automatic_function_calling=True`"
410+ )
411+
412+ tools_lib = self .model ._get_tools_lib (tools )
413+
390414 content = content_types .to_content (content )
415+
391416 if not content .role :
392417 content .role = self ._USER_ROLE
418+
393419 history = self .history [:]
394420 history .append (content )
395421
396422 generation_config = generation_types .to_generation_config_dict (generation_config )
397423 if generation_config .get ("candidate_count" , 1 ) > 1 :
398424 raise ValueError ("Can't chat with `candidate_count > 1`" )
425+
399426 response = self .model .generate_content (
400427 contents = history ,
401428 generation_config = generation_config ,
402429 safety_settings = safety_settings ,
403430 stream = stream ,
404- ** kwargs ,
431+ tools = tools_lib ,
405432 )
406433
434+ self ._check_response (response = response , stream = stream )
435+
436+ if self .enable_automatic_function_calling and tools_lib is not None :
437+ self .history , content , response = self ._handle_afc (
438+ response = response ,
439+ history = history ,
440+ generation_config = generation_config ,
441+ safety_settings = safety_settings ,
442+ stream = stream ,
443+ tools_lib = tools_lib ,
444+ )
445+
446+ self ._last_sent = content
447+ self ._last_received = response
448+
449+ return response
450+
451+ def _check_response (self , * , response , stream ):
407452 if response .prompt_feedback .block_reason :
408453 raise generation_types .BlockedPromptException (response .prompt_feedback )
409454
@@ -415,10 +460,49 @@ def send_message(
415460 ):
416461 raise generation_types .StopCandidateException (response .candidates [0 ])
417462
418- self ._last_sent = content
419- self ._last_received = response
463+ def _get_function_calls (self , response ) -> list [glm .FunctionCall ]:
464+ candidates = response .candidates
465+ if len (candidates ) != 1 :
466+ raise ValueError (
467+ f"Automatic function calling only works with 1 candidate, got: { len (candidates )} "
468+ )
469+ parts = candidates [0 ].content .parts
470+ function_calls = [part .function_call for part in parts if part and "function_call" in part ]
471+ return function_calls
472+
473+ def _handle_afc (
474+ self , * , response , history , generation_config , safety_settings , stream , tools_lib
475+ ) -> tuple [list [glm .Content ], glm .Content , generation_types .BaseGenerateContentResponse ]:
476+
477+ while function_calls := self ._get_function_calls (response ):
478+ if not all (callable (tools_lib [fc ]) for fc in function_calls ):
479+ break
480+ history .append (response .candidates [0 ].content )
481+
482+ function_response_parts : list [glm .Part ] = []
483+ for fc in function_calls :
484+ fr = tools_lib (fc )
485+ assert fr is not None , (
486+ "This should never happen, it should only return None if the declaration"
487+ "is not callable, and that's guarded against above."
488+ )
489+ function_response_parts .append (fr )
420490
421- return response
491+ send = glm .Content (role = self ._USER_ROLE , parts = function_response_parts )
492+ history .append (send )
493+
494+ response = self .model .generate_content (
495+ contents = history ,
496+ generation_config = generation_config ,
497+ safety_settings = safety_settings ,
498+ stream = stream ,
499+ tools = tools_lib ,
500+ )
501+
502+ self ._check_response (response = response , stream = stream )
503+
504+ * history , content = history
505+ return history , content , response
422506
423507 async def send_message_async (
424508 self ,
@@ -427,42 +511,88 @@ async def send_message_async(
427511 generation_config : generation_types .GenerationConfigType = None ,
428512 safety_settings : safety_types .SafetySettingOptions = None ,
429513 stream : bool = False ,
430- ** kwargs ,
514+ tools : content_types . FunctionLibraryType | None = None ,
431515 ) -> generation_types .AsyncGenerateContentResponse :
432516 """The async version of `ChatSession.send_message`."""
517+ if self .enable_automatic_function_calling and stream :
518+ raise NotImplementedError (
519+ "The `google.generativeai` SDK does not yet support `stream=True` with "
520+ "`enable_automatic_function_calling=True`"
521+ )
522+
523+ tools_lib = self .model ._get_tools_lib (tools )
524+
433525 content = content_types .to_content (content )
526+
434527 if not content .role :
435528 content .role = self ._USER_ROLE
529+
436530 history = self .history [:]
437531 history .append (content )
438532
439533 generation_config = generation_types .to_generation_config_dict (generation_config )
440534 if generation_config .get ("candidate_count" , 1 ) > 1 :
441535 raise ValueError ("Can't chat with `candidate_count > 1`" )
442- response = await self .model .generate_content_async (
536+
537+ response = await self .model .generate_content (
443538 contents = history ,
444539 generation_config = generation_config ,
445540 safety_settings = safety_settings ,
446541 stream = stream ,
447- ** kwargs ,
542+ tools = tools_lib ,
448543 )
449544
450- if response .prompt_feedback .block_reason :
451- raise generation_types .BlockedPromptException (response .prompt_feedback )
545+ self ._check_response (response = response , stream = stream )
452546
453- if not stream :
454- if response .candidates [0 ].finish_reason not in (
455- glm .Candidate .FinishReason .FINISH_REASON_UNSPECIFIED ,
456- glm .Candidate .FinishReason .STOP ,
457- glm .Candidate .FinishReason .MAX_TOKENS ,
458- ):
459- raise generation_types .StopCandidateException (response .candidates [0 ])
547+ if self .enable_automatic_function_calling and tools_lib is not None :
548+ self .history , content , response = await self ._handle_afc_async (
549+ response = response ,
550+ history = history ,
551+ generation_config = generation_config ,
552+ safety_settings = safety_settings ,
553+ stream = stream ,
554+ tools_lib = tools_lib ,
555+ )
460556
461557 self ._last_sent = content
462558 self ._last_received = response
463559
464560 return response
465561
562+ async def _handle_afc_async (
563+ self , * , response , history , generation_config , safety_settings , stream , tools_lib
564+ ) -> tuple [list [glm .Content ], glm .Content , generation_types .BaseGenerateContentResponse ]:
565+
566+ while function_calls := self ._get_function_calls (response ):
567+ if not all (callable (tools_lib [fc ]) for fc in function_calls ):
568+ break
569+ history .append (response .candidates [0 ].content )
570+
571+ function_response_parts : list [glm .Part ] = []
572+ for fc in function_calls :
573+ fr = tools_lib (fc )
574+ assert fr is not None , (
575+ "This should never happen, it should only return None if the declaration"
576+ "is not callable, and that's guarded against above."
577+ )
578+ function_response_parts .append (fr )
579+
580+ send = glm .Content (role = self ._USER_ROLE , parts = function_response_parts )
581+ history .append (send )
582+
583+ response = await self .model .generate_content_async (
584+ contents = history ,
585+ generation_config = generation_config ,
586+ safety_settings = safety_settings ,
587+ stream = stream ,
588+ tools = tools_lib ,
589+ )
590+
591+ self ._check_response (response = response , stream = stream )
592+
593+ * history , content = history
594+ return history , content , response
595+
466596 def __copy__ (self ):
467597 return ChatSession (
468598 model = self .model ,
0 commit comments