7272 "FunctionLibraryType" ,
7373]
7474
75+ Mode = protos .DynamicRetrievalConfig .Mode
76+
77+ ModeOptions = Union [int , str , Mode ]
78+
79+ _MODE : dict [ModeOptions , Mode ] = {
80+ Mode .MODE_UNSPECIFIED : Mode .MODE_UNSPECIFIED ,
81+ 0 : Mode .MODE_UNSPECIFIED ,
82+ "mode_unspecified" : Mode .MODE_UNSPECIFIED ,
83+ "unspecified" : Mode .MODE_UNSPECIFIED ,
84+ Mode .MODE_DYNAMIC : Mode .MODE_DYNAMIC ,
85+ 1 : Mode .MODE_DYNAMIC ,
86+ "mode_dynamic" : Mode .MODE_DYNAMIC ,
87+ "dynamic" : Mode .MODE_DYNAMIC ,
88+ }
89+
90+
91+ def to_mode (x : ModeOptions ) -> Mode :
92+ if isinstance (x , str ):
93+ x = x .lower ()
94+ return _MODE [x ]
95+
7596
7697def _pil_to_blob (image : PIL .Image .Image ) -> protos .Blob :
7798 # If the image is a local file, return a file-based blob without any modification.
@@ -644,16 +665,54 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
644665 return fd .to_proto ()
645666
646667
668+ class DynamicRetrievalConfigDict (TypedDict ):
669+ mode : protos .DynamicRetrievalConfig .mode
670+ dynamic_threshold : float
671+
672+
673+ DynamicRetrievalConfig = Union [protos .DynamicRetrievalConfig , DynamicRetrievalConfigDict ]
674+
675+
676+ class GoogleSearchRetrievalDict (TypedDict ):
677+ dynamic_retrieval_config : DynamicRetrievalConfig
678+
679+
680+ GoogleSearchRetrievalType = Union [protos .GoogleSearchRetrieval , GoogleSearchRetrievalDict ]
681+
682+
683+ def _make_google_search_retrieval (gsr : GoogleSearchRetrievalType ):
684+ if isinstance (gsr , protos .GoogleSearchRetrieval ):
685+ return gsr
686+ elif isinstance (gsr , Mapping ):
687+ drc = gsr .get ("dynamic_retrieval_config" , None )
688+ if drc is not None and isinstance (drc , Mapping ):
689+ mode = drc .get ("mode" , None )
690+ if mode is not None :
691+ mode = to_mode (mode )
692+ gsr = gsr .copy ()
693+ gsr ["dynamic_retrieval_config" ]["mode" ] = mode
694+ return protos .GoogleSearchRetrieval (gsr )
695+ else :
696+ raise TypeError (
697+ "Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n "
698+ f"However, received an object of type: { type (gsr )} .\n "
699+ f"Object Value: { gsr } "
700+ )
701+
702+
647703class Tool :
648- """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects."""
704+ """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects,
705+ protos.CodeExecution object, and protos.GoogleSearchRetrieval object."""
649706
650707 def __init__ (
651708 self ,
709+ * ,
652710 function_declarations : Iterable [FunctionDeclarationType ] | None = None ,
711+ google_search_retrieval : GoogleSearchRetrievalType | None = None ,
653712 code_execution : protos .CodeExecution | None = None ,
654713 ):
655714 # The main path doesn't use this but is seems useful.
656- if function_declarations :
715+ if function_declarations is not None :
657716 self ._function_declarations = [
658717 _make_function_declaration (f ) for f in function_declarations
659718 ]
@@ -668,15 +727,25 @@ def __init__(
668727 self ._function_declarations = []
669728 self ._index = {}
670729
730+ if google_search_retrieval is not None :
731+ self ._google_search_retrieval = _make_google_search_retrieval (google_search_retrieval )
732+ else :
733+ self ._google_search_retrieval = None
734+
671735 self ._proto = protos .Tool (
672736 function_declarations = [_encode_fd (fd ) for fd in self ._function_declarations ],
737+ google_search_retrieval = google_search_retrieval ,
673738 code_execution = code_execution ,
674739 )
675740
676741 @property
677742 def function_declarations (self ) -> list [FunctionDeclaration | protos .FunctionDeclaration ]:
678743 return self ._function_declarations
679744
745+ @property
746+ def google_search_retrieval (self ) -> protos .GoogleSearchRetrieval :
747+ return self ._google_search_retrieval
748+
680749 @property
681750 def code_execution (self ) -> protos .CodeExecution :
682751 return self ._proto .code_execution
@@ -705,7 +774,7 @@ class ToolDict(TypedDict):
705774
706775
707776ToolType = Union [
708- Tool , protos .Tool , ToolDict , Iterable [FunctionDeclarationType ], FunctionDeclarationType
777+ str , Tool , protos .Tool , ToolDict , Iterable [FunctionDeclarationType ], FunctionDeclarationType
709778]
710779
711780
@@ -717,20 +786,41 @@ def _make_tool(tool: ToolType) -> Tool:
717786 code_execution = tool .code_execution
718787 else :
719788 code_execution = None
720- return Tool (function_declarations = tool .function_declarations , code_execution = code_execution )
789+
790+ if "google_search_retrieval" in tool :
791+ google_search_retrieval = tool .google_search_retrieval
792+ else :
793+ google_search_retrieval = None
794+
795+ return Tool (
796+ function_declarations = tool .function_declarations ,
797+ google_search_retrieval = google_search_retrieval ,
798+ code_execution = code_execution ,
799+ )
721800 elif isinstance (tool , dict ):
722- if "function_declarations" in tool or "code_execution" in tool :
801+ if (
802+ "function_declarations" in tool
803+ or "google_search_retrieval" in tool
804+ or "code_execution" in tool
805+ ):
723806 return Tool (** tool )
724807 else :
725808 fd = tool
726809 return Tool (function_declarations = [protos .FunctionDeclaration (** fd )])
727810 elif isinstance (tool , str ):
728811 if tool .lower () == "code_execution" :
729812 return Tool (code_execution = protos .CodeExecution ())
813+ # Check to see if one of the mode enums matches
814+ elif tool .lower () == "google_search_retrieval" :
815+ return Tool (google_search_retrieval = protos .GoogleSearchRetrieval ())
730816 else :
731- raise ValueError ("The only string that can be passed as a tool is 'code_execution'." )
817+ raise ValueError (
818+ "The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
819+ )
732820 elif isinstance (tool , protos .CodeExecution ):
733821 return Tool (code_execution = tool )
822+ elif isinstance (tool , protos .GoogleSearchRetrieval ):
823+ return Tool (google_search_retrieval = tool )
734824 elif isinstance (tool , Iterable ):
735825 return Tool (function_declarations = tool )
736826 else :
@@ -786,7 +876,7 @@ def to_proto(self):
786876
787877def _make_tools (tools : ToolsType ) -> list [Tool ]:
788878 if isinstance (tools , str ):
789- if tools .lower () == "code_execution" :
879+ if tools .lower () == "code_execution" or tools . lower () == "google_search_retrieval" :
790880 return [_make_tool (tools )]
791881 else :
792882 raise ValueError ("The only string that can be passed as a tool is 'code_execution'." )
0 commit comments