2121from google .generativeai import discuss
2222from google .generativeai import client
2323import google .generativeai as genai
24+ from google .generativeai .types import safety_types
25+
2426from absl .testing import absltest
2527from absl .testing import parameterized
2628
@@ -36,12 +38,12 @@ def setUp(self):
3638 self .observed_request = None
3739
3840 self .mock_response = glm .GenerateMessageResponse (
39- candidates = [
40- glm .Message (content = "a" , author = "1" ),
41- glm .Message (content = "b" , author = "1" ),
42- glm .Message (content = "c" , author = "1" ),
43- ],
44- )
41+ candidates = [
42+ glm .Message (content = "a" , author = "1" ),
43+ glm .Message (content = "b" , author = "1" ),
44+ glm .Message (content = "c" , author = "1" ),
45+ ],
46+ )
4547
4648 def fake_generate_message (
4749 request : glm .GenerateMessageRequest ,
@@ -274,32 +276,39 @@ def test_reply(self, kwargs):
274276 response = response .reply ("again" )
275277
276278 def test_receive_and_reply_with_filters (self ):
277-
278279 self .mock_response = mock_response = glm .GenerateMessageResponse (
279280 candidates = [glm .Message (content = "a" , author = "1" )],
280281 filters = [
281- glm .ContentFilter (reason = glm .ContentFilter .BlockedReason .SAFETY , message = 'unsafe' ),
282- glm .ContentFilter (reason = glm .ContentFilter .BlockedReason .OTHER ),]
282+ glm .ContentFilter (
283+ reason = safety_types .BlockedReason .SAFETY , message = "unsafe"
284+ ),
285+ glm .ContentFilter (reason = safety_types .BlockedReason .OTHER ),
286+ ],
283287 )
284288 response = discuss .chat (messages = "do filters work?" )
285289
286290 filters = response .filters
287291 self .assertLen (filters , 2 )
288- self .assertIsInstance (filters [0 ][' reason' ], glm . ContentFilter .BlockedReason )
289- self .assertEquals (filters [0 ][' reason' ], glm . ContentFilter .BlockedReason .SAFETY )
290- self .assertEquals (filters [0 ][' message' ], ' unsafe' )
292+ self .assertIsInstance (filters [0 ][" reason" ], safety_types .BlockedReason )
293+ self .assertEqual (filters [0 ][" reason" ], safety_types .BlockedReason .SAFETY )
294+ self .assertEqual (filters [0 ][" message" ], " unsafe" )
291295
292296 self .mock_response = glm .GenerateMessageResponse (
293297 candidates = [glm .Message (content = "a" , author = "1" )],
294298 filters = [
295- glm .ContentFilter (reason = glm .ContentFilter .BlockedReason .BLOCKED_REASON_UNSPECIFIED )]
299+ glm .ContentFilter (
300+ reason = safety_types .BlockedReason .BLOCKED_REASON_UNSPECIFIED
301+ )
302+ ],
296303 )
297304
298- response = response .reply (' Does reply work?' )
305+ response = response .reply (" Does reply work?" )
299306 filters = response .filters
300307 self .assertLen (filters , 1 )
301- self .assertIsInstance (filters [0 ]['reason' ], glm .ContentFilter .BlockedReason )
302- self .assertEquals (filters [0 ]['reason' ], glm .ContentFilter .BlockedReason .BLOCKED_REASON_UNSPECIFIED )
308+ self .assertIsInstance (filters [0 ]["reason" ], safety_types .BlockedReason )
309+ self .assertEqual (
310+ filters [0 ]["reason" ], safety_types .BlockedReason .BLOCKED_REASON_UNSPECIFIED
311+ )
303312
304313
305314if __name__ == "__main__" :
0 commit comments