|
29 | 29 | */ |
30 | 30 | package com.google.api.gax.httpjson; |
31 | 31 |
|
| 32 | +import static com.google.common.truth.Truth.assertThat; |
| 33 | + |
32 | 34 | import com.google.api.client.http.HttpTransport; |
| 35 | +import com.google.api.gax.httpjson.testing.MockHttpService; |
| 36 | +import com.google.api.gax.httpjson.testing.TestApiTracer; |
| 37 | +import com.google.api.gax.rpc.EndpointContext; |
| 38 | +import com.google.api.gax.rpc.ResponseObserver; |
| 39 | +import com.google.api.gax.rpc.StreamController; |
| 40 | +import com.google.auth.Credentials; |
33 | 41 | import com.google.common.truth.Truth; |
| 42 | +import com.google.protobuf.Field; |
34 | 43 | import com.google.protobuf.TypeRegistry; |
35 | 44 | import java.io.ByteArrayInputStream; |
36 | 45 | import java.io.InputStream; |
37 | 46 | import java.io.Reader; |
| 47 | +import java.util.Collections; |
| 48 | +import java.util.HashMap; |
| 49 | +import java.util.List; |
| 50 | +import java.util.Map; |
| 51 | +import java.util.concurrent.CountDownLatch; |
38 | 52 | import java.util.concurrent.Executor; |
| 53 | +import java.util.concurrent.ExecutorService; |
| 54 | +import java.util.concurrent.Executors; |
39 | 55 | import java.util.concurrent.ScheduledThreadPoolExecutor; |
40 | 56 | import java.util.concurrent.TimeUnit; |
| 57 | +import org.junit.jupiter.api.AfterAll; |
| 58 | +import org.junit.jupiter.api.AfterEach; |
| 59 | +import org.junit.jupiter.api.BeforeAll; |
| 60 | +import org.junit.jupiter.api.BeforeEach; |
41 | 61 | import org.junit.jupiter.api.Test; |
42 | 62 | import org.junit.jupiter.api.extension.ExtendWith; |
43 | 63 | import org.mockito.Mock; |
@@ -135,4 +155,170 @@ void responseReceived_cancellationTaskExists_isCancelledProperly() throws Interr |
135 | 155 | // Scheduler is not waiting for any task and should terminate quickly |
136 | 156 | Truth.assertThat(deadlineSchedulerExecutor.isTerminated()).isTrue(); |
137 | 157 | } |
| 158 | + |
| 159 | + private static final ApiMethodDescriptor<Field, Field> FAKE_METHOD_DESCRIPTOR = |
| 160 | + ApiMethodDescriptor.<Field, Field>newBuilder() |
| 161 | + .setFullMethodName("google.cloud.v1.Fake/FakeMethod") |
| 162 | + .setHttpMethod("POST") |
| 163 | + .setRequestFormatter( |
| 164 | + ProtoMessageRequestFormatter.<Field>newBuilder() |
| 165 | + .setPath( |
| 166 | + "/fake/v1/name/{name}", |
| 167 | + request -> { |
| 168 | + Map<String, String> fields = new HashMap<>(); |
| 169 | + ProtoRestSerializer<Field> serializer = ProtoRestSerializer.create(); |
| 170 | + serializer.putPathParam(fields, "name", request.getName()); |
| 171 | + return fields; |
| 172 | + }) |
| 173 | + .setQueryParamsExtractor(request -> new HashMap<>()) |
| 174 | + .setRequestBodyExtractor( |
| 175 | + request -> |
| 176 | + ProtoRestSerializer.create() |
| 177 | + .toBody("*", request.toBuilder().clearName().build(), false)) |
| 178 | + .build()) |
| 179 | + .setResponseParser( |
| 180 | + ProtoMessageResponseParser.<Field>newBuilder() |
| 181 | + .setDefaultInstance(Field.getDefaultInstance()) |
| 182 | + .build()) |
| 183 | + .build(); |
| 184 | + |
| 185 | + private static final MockHttpService MOCK_SERVICE = |
| 186 | + new MockHttpService(Collections.singletonList(FAKE_METHOD_DESCRIPTOR), "google.com:443"); |
| 187 | + |
| 188 | + private static ExecutorService executorService; |
| 189 | + private ManagedHttpJsonChannel channel; |
| 190 | + private TestApiTracer tracer; |
| 191 | + |
| 192 | + @BeforeAll |
| 193 | + static void initialize() { |
| 194 | + executorService = Executors.newFixedThreadPool(2); |
| 195 | + } |
| 196 | + |
| 197 | + @AfterAll |
| 198 | + static void destroy() { |
| 199 | + executorService.shutdownNow(); |
| 200 | + } |
| 201 | + |
| 202 | + @BeforeEach |
| 203 | + void setUp() { |
| 204 | + channel = |
| 205 | + ManagedHttpJsonChannel.newBuilder() |
| 206 | + .setEndpoint("google.com:443") |
| 207 | + .setExecutor(executorService) |
| 208 | + .setHttpTransport(MOCK_SERVICE) |
| 209 | + .build(); |
| 210 | + tracer = new TestApiTracer(); |
| 211 | + } |
| 212 | + |
| 213 | + @AfterEach |
| 214 | + void tearDown() { |
| 215 | + MOCK_SERVICE.reset(); |
| 216 | + } |
| 217 | + |
| 218 | + @Test |
| 219 | + void testBodySizeRecording() throws Exception { |
| 220 | + HttpJsonDirectCallable<Field, Field> callable = |
| 221 | + new HttpJsonDirectCallable<>(FAKE_METHOD_DESCRIPTOR); |
| 222 | + |
| 223 | + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); |
| 224 | + Mockito.lenient() |
| 225 | + .doNothing() |
| 226 | + .when(endpointContext) |
| 227 | + .validateUniverseDomain( |
| 228 | + Mockito.any(Credentials.class), Mockito.any(HttpJsonStatusCode.class)); |
| 229 | + |
| 230 | + HttpJsonCallContext callContext = |
| 231 | + HttpJsonCallContext.createDefault() |
| 232 | + .withChannel(channel) |
| 233 | + .withEndpointContext(endpointContext) |
| 234 | + .withTracer(tracer); |
| 235 | + |
| 236 | + Field request = Field.newBuilder().setName("bob").setNumber(42).build(); |
| 237 | + Field response = Field.newBuilder().setName("alice").setNumber(43).build(); |
| 238 | + |
| 239 | + MOCK_SERVICE.addResponse(response); |
| 240 | + |
| 241 | + callable.futureCall(request, callContext).get(); |
| 242 | + |
| 243 | + // Verify response size |
| 244 | + // MockHttpService uses ProtoRestSerializer which pretty-prints. |
| 245 | + String expectedResponseBody = ProtoRestSerializer.create().toBody("*", response, false); |
| 246 | + long expectedResponseSize = expectedResponseBody.getBytes("UTF-8").length; |
| 247 | + assertThat(tracer.getResponseReceivedSize()).isEqualTo(expectedResponseSize); |
| 248 | + } |
| 249 | + |
| 250 | + @Test |
| 251 | + void testBodySizeRecordingServerStreaming() throws Exception { |
| 252 | + ApiMethodDescriptor<Field, Field> methodServerStreaming = |
| 253 | + FAKE_METHOD_DESCRIPTOR.toBuilder() |
| 254 | + .setType(ApiMethodDescriptor.MethodType.SERVER_STREAMING) |
| 255 | + .build(); |
| 256 | + |
| 257 | + MockHttpService streamingMockService = |
| 258 | + new MockHttpService(Collections.singletonList(methodServerStreaming), "google.com:443"); |
| 259 | + ManagedHttpJsonChannel streamingChannel = |
| 260 | + ManagedHttpJsonChannel.newBuilder() |
| 261 | + .setEndpoint("google.com:443") |
| 262 | + .setExecutor(executorService) |
| 263 | + .setHttpTransport(streamingMockService) |
| 264 | + .build(); |
| 265 | + |
| 266 | + HttpJsonDirectServerStreamingCallable<Field, Field> callable = |
| 267 | + new HttpJsonDirectServerStreamingCallable<>(methodServerStreaming); |
| 268 | + |
| 269 | + EndpointContext endpointContext = Mockito.mock(EndpointContext.class); |
| 270 | + Mockito.lenient() |
| 271 | + .doNothing() |
| 272 | + .when(endpointContext) |
| 273 | + .validateUniverseDomain( |
| 274 | + Mockito.any(Credentials.class), Mockito.any(HttpJsonStatusCode.class)); |
| 275 | + |
| 276 | + HttpJsonCallContext callContext = |
| 277 | + HttpJsonCallContext.createDefault() |
| 278 | + .withChannel(streamingChannel) |
| 279 | + .withEndpointContext(endpointContext) |
| 280 | + .withTracer(tracer); |
| 281 | + |
| 282 | + Field request = Field.newBuilder().setName("bob").setNumber(42).build(); |
| 283 | + Field response1 = Field.newBuilder().setName("alice1").setNumber(43).build(); |
| 284 | + Field response2 = Field.newBuilder().setName("alice2").setNumber(44).build(); |
| 285 | + |
| 286 | + streamingMockService.addResponse(new Field[] {response1, response2}); |
| 287 | + |
| 288 | + final List<Field> receivedResponses = new java.util.ArrayList<>(); |
| 289 | + final CountDownLatch latch = new CountDownLatch(1); |
| 290 | + |
| 291 | + callable.call( |
| 292 | + request, |
| 293 | + new ResponseObserver<Field>() { |
| 294 | + @Override |
| 295 | + public void onStart(StreamController controller) { |
| 296 | + // no behavior needed |
| 297 | + } |
| 298 | + |
| 299 | + @Override |
| 300 | + public void onResponse(Field response) { |
| 301 | + receivedResponses.add(response); |
| 302 | + } |
| 303 | + |
| 304 | + @Override |
| 305 | + public void onError(Throwable t) { |
| 306 | + latch.countDown(); |
| 307 | + } |
| 308 | + |
| 309 | + @Override |
| 310 | + public void onComplete() { |
| 311 | + latch.countDown(); |
| 312 | + } |
| 313 | + }, |
| 314 | + callContext); |
| 315 | + |
| 316 | + latch.await(10, TimeUnit.SECONDS); |
| 317 | + |
| 318 | + assertThat(receivedResponses).hasSize(2); |
| 319 | + |
| 320 | + // Verify response size (0 because streaming chunked responses don't include Content-Length) |
| 321 | + assertThat(tracer.getResponseReceivedSize()).isEqualTo(0); |
| 322 | + streamingChannel.shutdownNow(); |
| 323 | + } |
138 | 324 | } |
0 commit comments