/**
 * Builds an interceptor that copies the selected request headers into both the response
 * headers and the response trailers. Handy for testing end-to-end metadata propagation.
 *
 * @param keys request header keys to echo back to the client
 * @return the echoing interceptor
 */
private static ServerInterceptor echoRequestHeadersInterceptor(final Metadata.Key<?>... keys) {
  final Set<Metadata.Key<?>> echoedKeys = new HashSet<Metadata.Key<?>>(Arrays.asList(keys));
  return new ServerInterceptor() {
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        final Metadata requestHeaders,
        ServerCallHandler<ReqT, RespT> next) {
      ServerCall<ReqT, RespT> echoingCall = new SimpleForwardingServerCall<ReqT, RespT>(call) {
        @Override
        public void sendHeaders(Metadata responseHeaders) {
          // Copy the matching request entries into the outgoing headers.
          responseHeaders.merge(requestHeaders, echoedKeys);
          super.sendHeaders(responseHeaders);
        }

        @Override
        public void close(Status status, Metadata trailers) {
          // Copy the same entries into the trailers as well.
          trailers.merge(requestHeaders, echoedKeys);
          super.close(status, trailers);
        }
      };
      return next.startCall(echoingCall, requestHeaders);
    }
  };
}
/**
 * Installs an anonymous authentication in the {@link SecurityContextHolder} when none is present
 * yet, then continues the call unchanged.
 *
 * <p>NOTE(review): the authentication is set on whichever thread runs this method and is not
 * cleared here; confirm something downstream resets the context so it cannot leak to later
 * calls served by the same thread.
 */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
  if (Objects.nonNull(SecurityContextHolder.getContext().getAuthentication())) {
    log.debug(
        "SecurityContextHolder not populated with anonymous token, as it already contained: {}",
        SecurityContextHolder.getContext().getAuthentication());
    return next.startCall(call, headers);
  }
  SecurityContextHolder.getContext().setAuthentication(
      new AnonymousAuthenticationToken(
          key,
          "anonymousUser",
          Collections.singletonList(new SimpleGrantedAuthority("ROLE_ANONYMOUS"))));
  log.debug(
      "Populated SecurityContextHolder with anonymous token: {}",
      SecurityContextHolder.getContext().getAuthentication());
  return next.startCall(call, headers);
}
/**
 * Publishes the current {@link ServerCall} in {@code TL} and removes it again when the call is
 * closed through the wrapper installed here.
 *
 * <p>If {@code next.startCall} throws, there is no guarantee the wrapper's {@code close} will
 * ever run. In that case the thread-local is cleared here so a stale call is not left behind on
 * this thread.
 *
 * <p>NOTE(review): {@code close} may run on a different thread than this method; confirm that
 * {@code TL.remove()} there clears the value on the intended thread.
 */
@SuppressWarnings("checkstyle:MethodTypeParameterName")
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(
    final ServerCall<ReqT, RespT> call,
    final Metadata headers,
    final ServerCallHandler<ReqT, RespT> next) {
  TL.set(call);
  boolean started = false;
  try {
    final Listener<ReqT> listener = next.startCall(
        new SimpleForwardingServerCall<ReqT, RespT>(call) {
          @Override
          public void close(final Status status, final Metadata trailers) {
            super.close(status, trailers);
            TL.remove();
          }
        },
        headers);
    started = true;
    return listener;
  } finally {
    if (!started) {
      // startCall failed: do not leave the call published on this thread.
      TL.remove();
    }
  }
}
/**
 * Attributes the call to a consumer identified by the token found under {@code authKey}.
 *
 * <p>When the token validates, the downstream listener is wrapped in a
 * {@code SeldonServerCallListener} tagged with the consumer's short name. A missing token or an
 * {@code APIException} during validation does not reject the call; it simply proceeds unwrapped.
 *
 * <p>The token is used as a credential, so neither it nor the header values are logged; only
 * the header names are.
 */
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
    ServerCallHandler<ReqT, RespT> next) {
  logger.info("Call intercepted, header keys " + headers.keys());
  String token = headers.get(authKey);
  if (!StringUtils.notEmpty(token)) {
    logger.warn("Empty token ignoring call");
    return next.startCall(call, headers);
  }
  ConsumerBean consumer;
  try {
    consumer = resourceServer.validateResourceFromToken(token);
  } catch (APIException e) {
    logger.warn("API exception on getting token ", e);
    return next.startCall(call, headers);
  }
  // startCall stays outside the try so a failure there can never trigger a second startCall.
  String clientName = consumer.getShort_name();
  logger.info("Setting call to client " + clientName);
  return new SeldonServerCallListener<ReqT>(next.startCall(call, headers), clientName, this);
}
@Test public void clientHeaderDeliveredToServer() { grpcServerRule.getServiceRegistry() .addService(ServerInterceptors.intercept(new GreeterImplBase() {}, mockServerInterceptor)); GreeterBlockingStub blockingStub = GreeterGrpc.newBlockingStub( ClientInterceptors.intercept(grpcServerRule.getChannel(), new HeaderClientInterceptor())); ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class); try { blockingStub.sayHello(HelloRequest.getDefaultInstance()); fail(); } catch (StatusRuntimeException expected) { // expected because the method is not implemented at server side } verify(mockServerInterceptor).interceptCall( Matchers.<ServerCall<HelloRequest, HelloReply>>any(), metadataCaptor.capture(), Matchers.<ServerCallHandler<HelloRequest, HelloReply>>any()); assertEquals( "customRequestValue", metadataCaptor.getValue().get(HeaderClientInterceptor.CUSTOM_HEADER_KEY)); }
/**
 * Adds every configured Cache-Control directive to the response headers just before they are
 * sent.
 */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    ServerCall<ReqT, RespT> call,
    final Metadata requestHeaders,
    ServerCallHandler<ReqT, RespT> next) {
  ServerCall<ReqT, RespT> cachingCall =
      new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {
        @Override
        public void sendHeaders(Metadata responseHeaders) {
          for (String directive : cacheControlDirectives) {
            responseHeaders.put(CACHE_CONTROL_METADATA_KEY, directive);
          }
          super.sendHeaders(responseHeaders);
        }
      };
  return next.startCall(cachingCall, requestHeaders);
}
/**
 * Builds an interceptor that copies the selected request headers into the response headers
 * only; trailers are left untouched.
 *
 * @param keys request header keys to echo back
 * @return the echoing interceptor
 */
private static ServerInterceptor echoRequestMetadataInHeaders(final Metadata.Key<?>... keys) {
  final Set<Metadata.Key<?>> echoedKeys = new HashSet<Metadata.Key<?>>(Arrays.asList(keys));
  return new ServerInterceptor() {
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        final Metadata requestHeaders,
        ServerCallHandler<ReqT, RespT> next) {
      // close() is inherited unchanged from the forwarding base class.
      ServerCall<ReqT, RespT> echoingCall = new SimpleForwardingServerCall<ReqT, RespT>(call) {
        @Override
        public void sendHeaders(Metadata responseHeaders) {
          responseHeaders.merge(requestHeaders, echoedKeys);
          super.sendHeaders(responseHeaders);
        }
      };
      return next.startCall(echoingCall, requestHeaders);
    }
  };
}
/**
 * Builds an interceptor that copies the selected request headers into the response trailers
 * only; response headers are left untouched.
 *
 * @param keys request header keys to echo back
 * @return the echoing interceptor
 */
private static ServerInterceptor echoRequestMetadataInTrailers(final Metadata.Key<?>... keys) {
  final Set<Metadata.Key<?>> echoedKeys = new HashSet<Metadata.Key<?>>(Arrays.asList(keys));
  return new ServerInterceptor() {
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        final Metadata requestHeaders,
        ServerCallHandler<ReqT, RespT> next) {
      // sendHeaders() is inherited unchanged from the forwarding base class.
      ServerCall<ReqT, RespT> echoingCall = new SimpleForwardingServerCall<ReqT, RespT>(call) {
        @Override
        public void close(Status status, Metadata trailers) {
          trailers.merge(requestHeaders, echoedKeys);
          super.close(status, trailers);
        }
      };
      return next.startCall(echoingCall, requestHeaders);
    }
  };
}
@Override protected AbstractServerImplBuilder<?> getServerBuilder() { return NettyServerBuilder.forPort(0) .maxMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE) .compressorRegistry(compressors) .decompressorRegistry(decompressors) .intercept(new ServerInterceptor() { @Override public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) { Listener<ReqT> listener = next.startCall(call, headers); // TODO(carl-mastrangelo): check that encoding was set. call.setMessageCompression(true); return listener; } }); }
/** Never returns {@code null}. */ private <ReqT, RespT> ServerStreamListener startCall(ServerStream stream, String fullMethodName, ServerMethodDefinition<ReqT, RespT> methodDef, Metadata headers, Context.CancellableContext context, StatsTraceContext statsTraceCtx) { // TODO(ejona86): should we update fullMethodName to have the canonical path of the method? ServerCallImpl<ReqT, RespT> call = new ServerCallImpl<ReqT, RespT>( stream, methodDef.getMethodDescriptor(), headers, context, decompressorRegistry, compressorRegistry); ServerCallHandler<ReqT, RespT> callHandler = methodDef.getServerCallHandler(); statsTraceCtx.serverCallStarted( new ServerCallInfoImpl<ReqT, RespT>( methodDef.getMethodDescriptor(), call.getAttributes(), call.getAuthority())); for (ServerInterceptor interceptor : interceptors) { callHandler = InternalServerInterceptors.interceptCallHandler(interceptor, callHandler); } ServerCall.Listener<ReqT> listener = callHandler.startCall(call, headers); if (listener == null) { throw new NullPointerException( "startCall() returned a null listener for method " + fullMethodName); } return call.newServerStreamListener(listener); }
@Test public void cannotDisableAutoFlowControlAfterServiceInvocation() throws Exception { final AtomicReference<ServerCallStreamObserver<Integer>> callObserver = new AtomicReference<ServerCallStreamObserver<Integer>>(); ServerCallHandler<Integer, Integer> callHandler = ServerCalls.asyncBidiStreamingCall( new ServerCalls.BidiStreamingMethod<Integer, Integer>() { @Override public StreamObserver<Integer> invoke(StreamObserver<Integer> responseObserver) { callObserver.set((ServerCallStreamObserver<Integer>) responseObserver); return new ServerCalls.NoopStreamObserver<Integer>(); } }); ServerCall.Listener<Integer> callListener = callHandler.startCall(serverCall, new Metadata()); callListener.onMessage(1); try { callObserver.get().disableAutoInboundFlowControl(); fail("Cannot set onCancel handler after service invocation"); } catch (IllegalStateException expected) { // Expected } }
@Test public void disablingInboundAutoFlowControlSuppressesRequestsForMoreMessages() throws Exception { ServerCallHandler<Integer, Integer> callHandler = ServerCalls.asyncBidiStreamingCall( new ServerCalls.BidiStreamingMethod<Integer, Integer>() { @Override public StreamObserver<Integer> invoke(StreamObserver<Integer> responseObserver) { ServerCallStreamObserver<Integer> serverCallObserver = (ServerCallStreamObserver<Integer>) responseObserver; serverCallObserver.disableAutoInboundFlowControl(); return new ServerCalls.NoopStreamObserver<Integer>(); } }); ServerCall.Listener<Integer> callListener = callHandler.startCall(serverCall, new Metadata()); callListener.onReady(); // Transport should not call this if nothing has been requested but forcing it here // to verify that message delivery does not trigger a call to request(1). callListener.onMessage(1); // Should never be called assertThat(serverCall.requestCalls).isEmpty(); }
@Test public void disablingInboundAutoFlowControlForUnaryHasNoEffect() throws Exception { ServerCallHandler<Integer, Integer> callHandler = ServerCalls.asyncUnaryCall( new ServerCalls.UnaryMethod<Integer, Integer>() { @Override public void invoke(Integer req, StreamObserver<Integer> responseObserver) { ServerCallStreamObserver<Integer> serverCallObserver = (ServerCallStreamObserver<Integer>) responseObserver; serverCallObserver.disableAutoInboundFlowControl(); } }); callHandler.startCall(serverCall, new Metadata()); // Auto inbound flow-control always requests 2 messages for unary to detect a violation // of the unary semantic. assertThat(serverCall.requestCalls).containsExactly(2); }
/** Half-closing a unary call without any request message fails it with MISSING_REQUEST. */
@Test
public void clientSendsOne_errorMissingRequest_unary() {
  ServerCallRecorder recorder = new ServerCallRecorder(UNARY_METHOD);
  ServerCalls.UnaryMethod<Integer, Integer> unreachable =
      new ServerCalls.UnaryMethod<Integer, Integer>() {
        @Override
        public void invoke(Integer req, StreamObserver<Integer> responseObserver) {
          fail("should not be reached");
        }
      };
  ServerCall.Listener<Integer> listener =
      ServerCalls.asyncUnaryCall(unreachable).startCall(recorder, new Metadata());

  listener.onHalfClose();

  assertThat(recorder.responses).isEmpty();
  assertEquals(Status.Code.INTERNAL, recorder.status.getCode());
  assertEquals(ServerCalls.MISSING_REQUEST, recorder.status.getDescription());
}
/**
 * Half-closing a server-streaming call without any request message fails it with
 * MISSING_REQUEST.
 */
@Test
public void clientSendsOne_errorMissingRequest_serverStreaming() {
  ServerCallRecorder recorder = new ServerCallRecorder(SERVER_STREAMING_METHOD);
  ServerCalls.ServerStreamingMethod<Integer, Integer> unreachable =
      new ServerCalls.ServerStreamingMethod<Integer, Integer>() {
        @Override
        public void invoke(Integer req, StreamObserver<Integer> responseObserver) {
          fail("should not be reached");
        }
      };
  ServerCall.Listener<Integer> listener =
      ServerCalls.asyncServerStreamingCall(unreachable).startCall(recorder, new Metadata());

  listener.onHalfClose();

  assertThat(recorder.responses).isEmpty();
  assertEquals(Status.Code.INTERNAL, recorder.status.getCode());
  assertEquals(ServerCalls.MISSING_REQUEST, recorder.status.getDescription());
}
@Test public void clientSendsOne_errorTooManyRequests_unary() { ServerCallRecorder serverCall = new ServerCallRecorder(UNARY_METHOD); ServerCallHandler<Integer, Integer> callHandler = ServerCalls.asyncUnaryCall( new ServerCalls.UnaryMethod<Integer, Integer>() { @Override public void invoke(Integer req, StreamObserver<Integer> responseObserver) { fail("should not be reached"); } }); ServerCall.Listener<Integer> listener = callHandler.startCall(serverCall, new Metadata()); listener.onMessage(1); listener.onMessage(1); assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.INTERNAL, serverCall.status.getCode()); assertEquals(ServerCalls.TOO_MANY_REQUESTS, serverCall.status.getDescription()); // ensure onHalfClose does not invoke listener.onHalfClose(); }
@Test public void clientSendsOne_errorTooManyRequests_serverStreaming() { ServerCallRecorder serverCall = new ServerCallRecorder(SERVER_STREAMING_METHOD); ServerCallHandler<Integer, Integer> callHandler = ServerCalls.asyncServerStreamingCall( new ServerCalls.ServerStreamingMethod<Integer, Integer>() { @Override public void invoke(Integer req, StreamObserver<Integer> responseObserver) { fail("should not be reached"); } }); ServerCall.Listener<Integer> listener = callHandler.startCall(serverCall, new Metadata()); listener.onMessage(1); listener.onMessage(1); assertThat(serverCall.responses).isEmpty(); assertEquals(Status.Code.INTERNAL, serverCall.status.getCode()); assertEquals(ServerCalls.TOO_MANY_REQUESTS, serverCall.status.getDescription()); // ensure onHalfClose does not invoke listener.onHalfClose(); }
/**
 * Rejects calls whose {@code CLIENT_ID_HEADER_KEY} header is missing or not accepted by
 * {@code authenticator}. Accepted calls proceed with the client id bound to
 * {@code CLIENT_ID_CONTEXT_KEY} in the call's {@link Context}.
 */
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
    Metadata headers, ServerCallHandler<ReqT, RespT> next) {
  String clientId = headers.get(CLIENT_ID_HEADER_KEY);
  if (clientId == null || !authenticator.authenticate(clientId)) {
    // Use fresh, empty trailers. Passing the request headers here would echo every
    // client-supplied header back to the client as response trailers.
    call.close(
        Status.UNAUTHENTICATED.withDescription("Invalid or unknown client: " + clientId),
        new Metadata());
    return NOOP_LISTENER;
  }
  Context context = Context.current().withValue(CLIENT_ID_CONTEXT_KEY, clientId);
  return Contexts.interceptCall(context, call, headers, next);
}
/**
 * Looks up the OAuth2 authentication for the access token found under {@code authKey}. The
 * resulting principal (or {@code null} if the token is missing or not authenticated) is passed
 * to a {@code MessagePrincipalListener} wrapping the downstream listener. The call is never
 * rejected here.
 */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    ServerCall<ReqT, RespT> call,
    final Metadata requestHeaders,
    ServerCallHandler<ReqT, RespT> next) {
  String token = requestHeaders.get(authKey);
  String principal = null;
  if (StringUtils.isEmpty(token)) {
    logger.warn("Failed to find token");
  } else {
    Map<String, String> tokenParams = new HashMap<>();
    tokenParams.put(OAuth2AccessToken.ACCESS_TOKEN, token);
    OAuth2AccessToken otoken = DefaultOAuth2AccessToken.valueOf(tokenParams);
    OAuth2Authentication auth = server.getTokenStore().readAuthentication(otoken);
    if (auth != null && auth.isAuthenticated()) {
      logger.debug("Principal:" + auth.getPrincipal());
      principal = auth.getPrincipal().toString();
    } else {
      // The value is an OAuth2 access token, i.e. a credential: never write it to the logs.
      logger.warn("Failed to authenticate token");
    }
  }
  return new MessagePrincipalListener<ReqT>(
      next.startCall(call, requestHeaders), principal, server);
}
/**
 * Logs the incoming request headers and adds {@code customHeadKey} to the response headers
 * (legacy interceptor signature that receives the {@link MethodDescriptor} separately).
 */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    MethodDescriptor<ReqT, RespT> method,
    ServerCall<RespT> call,
    final Metadata requestHeaders,
    ServerCallHandler<ReqT, RespT> next) {
  logger.info("header received from client:" + requestHeaders);
  ServerCall<RespT> decoratedCall = new SimpleForwardingServerCall<RespT>(call) {
    @Override
    public void sendHeaders(Metadata responseHeaders) {
      responseHeaders.put(customHeadKey, "customRespondValue");
      super.sendHeaders(responseHeaders);
    }
  };
  return next.startCall(method, decoratedCall, requestHeaders);
}
/**
 * Logs the full name of the invoked method, then continues the call unchanged.
 *
 * <p>Output goes through the logger only; the former duplicate {@code System.out} print
 * bypassed logging configuration.
 */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
    Metadata headers, ServerCallHandler<ReqT, RespT> next) {
  String fullMethodName = call.getMethodDescriptor().getFullMethodName();
  log.info(fullMethodName);
  return next.startCall(call, headers);
}
/**
 * Binds the value of the {@code miniDeviceInfoKey} request header (possibly {@code null}) to
 * {@code MINI_DEVICE_INFO} in the call's {@link Context}.
 */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
    final Metadata requestHeaders, ServerCallHandler<ReqT, RespT> next) {
  return Contexts.interceptCall(
      Context.current().withValue(MINI_DEVICE_INFO, requestHeaders.get(miniDeviceInfoKey)),
      call,
      requestHeaders,
      next);
}
/**
 * Wraps both the call and its listener with monitoring decorators. Both share the same
 * per-method {@link ServerMetrics} and {@link GrpcMethod}.
 */
@Override
public <R, S> ServerCall.Listener<R> interceptCall(
    ServerCall<R, S> call, Metadata requestHeaders, ServerCallHandler<R, S> next) {
  MethodDescriptor<R, S> method = call.getMethodDescriptor();
  ServerMetrics metrics = serverMetricsFactory.createMetricsForMethod(method);
  GrpcMethod grpcMethod = GrpcMethod.of(method);
  // Diamond instead of the raw type: avoids the unchecked raw-to-generic conversion.
  ServerCall<R, S> monitoringCall =
      new MonitoringServerCall<>(call, clock, grpcMethod, metrics, configuration);
  return new MonitoringServerCallListener<>(
      next.startCall(monitoringCall, requestHeaders), metrics, grpcMethod);
}
/**
 * Builds the service definition for {@code GrpcRpcProtocol.SERVICE}, registering one method per
 * endpoint exposed by {@code container}.
 */
private ServerServiceDefinition bindService() {
  final ServerServiceDefinition.Builder serviceBuilder =
      ServerServiceDefinition.builder(GrpcRpcProtocol.SERVICE);
  for (final GrpcEndpointHandle<?, ?> endpoint : container.getEndpoints()) {
    @SuppressWarnings("unchecked")
    final GrpcEndpointHandle<Object, Object> untyped = (GrpcEndpointHandle<Object, Object>) endpoint;
    final ServerCallHandler<byte[], byte[]> callHandler = serverCallHandlerFor(untyped);
    serviceBuilder.addMethod(endpoint.descriptor(), callHandler);
  }
  return serviceBuilder.build();
}
/**
 * Adapts an endpoint into a unary gRPC handler that exchanges raw bytes encoded with
 * {@code mapper}.
 *
 * <p>The request bytes are decoded into {@code spec.queryType()} and passed to
 * {@code spec.handle}; the asynchronous result is encoded back to bytes and sent as the single
 * response. Each path terminates the observer:
 * <ul>
 *   <li>decode, handle and response-serialization failures send {@code INTERNAL};
 *   <li>a failed future forwards its cause;
 *   <li>a cancelled future sends {@code CANCELLED}.
 * </ul>
 *
 * @param spec the endpoint to expose
 * @return a unary call handler for that endpoint
 */
private ServerCallHandler<byte[], byte[]> serverCallHandlerFor(
    final GrpcEndpointHandle<Object, Object> spec) {
  return asyncUnaryCall((request, observer) -> {
    final UUID id = UUID.randomUUID();
    log.trace("{}: Received request: {}", id, request);

    final AsyncFuture<Object> future;
    try {
      final Object obj = mapper.readValue(request, spec.queryType());
      future = spec.handle(obj);
    } catch (final Exception e) {
      log.error("{}: Failed to handle request (sent {})", id, Status.INTERNAL, e);
      observer.onError(new StatusException(Status.INTERNAL));
      return;
    }

    future.onDone(new FutureDone<Object>() {
      @Override
      public void failed(final Throwable cause) throws Exception {
        log.error("{}: Request failed", id, cause);
        observer.onError(cause);
      }

      @Override
      public void resolved(final Object result) throws Exception {
        final byte[] body;
        try {
          body = mapper.writeValueAsBytes(result);
        } catch (final Exception e) {
          // An exception escaping here would skip onNext/onCompleted and leave the call open.
          log.error("{}: Failed to serialize response (sent {})", id, Status.INTERNAL, e);
          observer.onError(new StatusException(Status.INTERNAL));
          return;
        }
        observer.onNext(body);
        observer.onCompleted();
      }

      @Override
      public void cancelled() throws Exception {
        // Report a proper CANCELLED status instead of an opaque RuntimeException (UNKNOWN).
        observer.onError(
            new StatusException(Status.CANCELLED.withDescription("Request cancelled")));
      }
    });
  });
}
/**
 * Requires a {@code RequestMetadata} header on every call and binds it to {@code CONTEXT_KEY}
 * in the call's {@link Context}.
 *
 * @throws IllegalStateException if the client did not send the header
 */
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(
    ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
  RequestMetadata requestMetadata = headers.get(METADATA_KEY);
  if (requestMetadata != null) {
    Context withMetadata = Context.current().withValue(CONTEXT_KEY, requestMetadata);
    return Contexts.interceptCall(withMetadata, call, headers, next);
  }
  throw new IllegalStateException("RequestMetadata not received from the client.");
}
/**
 * Test interceptor that asserts the incoming {@code RequestMetadata} header carries the
 * expected Bazel invocation and tool details, then continues the call unchanged.
 */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
  RequestMetadata meta = headers.get(TracingMetadataUtils.METADATA_KEY);
  // Fail with a clear assertion rather than an NPE when the header is absent.
  assertThat(meta).isNotNull();
  assertThat(meta.getCorrelatedInvocationsId()).isEqualTo("build-req-id");
  assertThat(meta.getToolInvocationId()).isEqualTo("command-id");
  assertThat(meta.getActionId()).isNotEmpty();
  assertThat(meta.getToolDetails().getToolName()).isEqualTo("bazel");
  assertThat(meta.getToolDetails().getToolVersion())
      .isEqualTo(BlazeVersionInfo.instance().getVersion());
  return next.startCall(call, headers);
}
/**
 * Logs the incoming request headers and adds {@code CUSTOM_HEADER_KEY} to the response headers.
 */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
    ServerCall<ReqT, RespT> call,
    final Metadata requestHeaders,
    ServerCallHandler<ReqT, RespT> next) {
  logger.info("header received from client:" + requestHeaders);
  ServerCall<ReqT, RespT> decoratedCall = new SimpleForwardingServerCall<ReqT, RespT>(call) {
    @Override
    public void sendHeaders(Metadata responseHeaders) {
      responseHeaders.put(CUSTOM_HEADER_KEY, "customRespondValue");
      super.sendHeaders(responseHeaders);
    }
  };
  return next.startCall(decoratedCall, requestHeaders);
}
/**
 * Returns an interceptor that stores each intercepted {@link ServerCall} in
 * {@code serverCallCapture} and then continues the call unchanged. Tests can then inspect the
 * captured call, e.g. via {@link ServerCall#getAttributes()}.
 */
private static ServerInterceptor recordServerCallInterceptor(
    final AtomicReference<ServerCall<?, ?>> serverCallCapture) {
  final ServerInterceptor recorder = new ServerInterceptor() {
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        Metadata requestHeaders,
        ServerCallHandler<ReqT, RespT> next) {
      serverCallCapture.set(call);
      return next.startCall(call, requestHeaders);
    }
  };
  return recorder;
}
/**
 * Returns an interceptor that stores the {@link Context} current at interception time in
 * {@code contextCapture} and then continues the call unchanged.
 */
private static ServerInterceptor recordContextInterceptor(
    final AtomicReference<Context> contextCapture) {
  final ServerInterceptor recorder = new ServerInterceptor() {
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        Metadata requestHeaders,
        ServerCallHandler<ReqT, RespT> next) {
      contextCapture.set(Context.current());
      return next.startCall(call, requestHeaders);
    }
  };
  return recorder;
}
/**
 * Applies the test's server-side compression settings to the call and stores a copy of the
 * incoming headers in {@code serverResponseHeaders}.
 */
@Override
public <ReqT, RespT> io.grpc.ServerCall.Listener<ReqT> interceptCall(
    ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
  if (serverEncoding) {
    call.setCompression("fzip");
  }
  call.setMessageCompression(enableServerMessageCompression);
  // Store a snapshot rather than the live object passed into this call.
  Metadata snapshot = new Metadata();
  snapshot.merge(headers);
  serverResponseHeaders = snapshot;
  return next.startCall(call, headers);
}
/**
 * Fills {@code registry} with {@code serviceCount} randomly named services, each with
 * {@code methodCountPerService} randomly named methods. Every full method name is also recorded
 * in {@code fullMethodNames}.
 */
@Setup(Level.Trial)
public void setup() throws Exception {
  registry = new MutableHandlerRegistry();
  fullMethodNames = new ArrayList<String>(serviceCount * methodCountPerService);
  for (int i = 0; i < serviceCount; i++) {
    String serviceName = randomString();
    ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(serviceName);
    for (int j = 0; j < methodCountPerService; j++) {
      String methodName = randomString();
      String fullName = MethodDescriptor.generateFullMethodName(serviceName, methodName);
      MethodDescriptor<Void, Void> descriptor = MethodDescriptor.<Void, Void>newBuilder()
          .setType(MethodDescriptor.MethodType.UNKNOWN)
          .setFullMethodName(fullName)
          .setRequestMarshaller(TestMethodDescriptors.voidMarshaller())
          .setResponseMarshaller(TestMethodDescriptors.voidMarshaller())
          .build();
      // Placeholder handler; it returns null.
      builder.addMethod(descriptor, new ServerCallHandler<Void, Void>() {
        @Override
        public Listener<Void> startCall(ServerCall<Void, Void> call, Metadata headers) {
          return null;
        }
      });
      fullMethodNames.add(descriptor.getFullMethodName());
    }
    registry.addService(builder.build());
  }
}
/**
 * Returns an interceptor that stores each call's request {@link Metadata} in
 * {@code headersCapture} and then continues the call unchanged. Useful for testing metadata
 * propagation.
 */
public static ServerInterceptor recordRequestHeadersInterceptor(
    final AtomicReference<Metadata> headersCapture) {
  final ServerInterceptor recorder = new ServerInterceptor() {
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        Metadata requestHeaders,
        ServerCallHandler<ReqT, RespT> next) {
      headersCapture.set(requestHeaders);
      return next.startCall(call, requestHeaders);
    }
  };
  return recorder;
}
/**
 * Checks that when a handler's {@code startCall} throws {@code status.asRuntimeException()},
 * the stream is closed with that same {@link Status} instance.
 *
 * <p>Also checks that the lookup in {@code fallbackRegistry} has not happened when
 * {@code streamCreated} returns; it is observed only after the single due executor task runs.
 */
@Test
public void exceptionInStartCallPropagatesToStream() throws Exception {
  createAndStartServer();
  final Status status = Status.ABORTED.withDescription("Oh, no!");
  // Register a handler whose startCall always throws.
  mutableFallbackRegistry.addService(ServerServiceDefinition.builder(
      new ServiceDescriptor("Waiter", METHOD))
      .addMethod(METHOD, new ServerCallHandler<String, Integer>() {
        @Override
        public ServerCall.Listener<String> startCall(
            ServerCall<String, Integer> call, Metadata headers) {
          throw status.asRuntimeException();
        }
      }).build());
  ServerTransportListener transportListener
      = transportServer.registerNewServerTransport(new SimpleServerTransport());
  transportListener.transportReady(Attributes.EMPTY);

  Metadata requestHeaders = new Metadata();
  StatsTraceContext statsTraceCtx =
      StatsTraceContext.newServerContext(streamTracerFactories, "Waiter/serve", requestHeaders);
  when(stream.statsTraceContext()).thenReturn(statsTraceCtx);
  transportListener.streamCreated(stream, "Waiter/serve", requestHeaders);
  // A listener is installed on the stream right away.
  verify(stream).setListener(streamListenerCaptor.capture());
  ServerStreamListener streamListener = streamListenerCaptor.getValue();
  assertNotNull(streamListener);
  verify(stream, atLeast(1)).statsTraceContext();
  verifyNoMoreInteractions(stream);
  // No method lookup has happened yet.
  verify(fallbackRegistry, never()).lookupMethod(any(String.class), any(String.class));

  // Run the one pending executor task; the lookup is observed only after it.
  assertEquals(1, executor.runDueTasks());
  verify(fallbackRegistry).lookupMethod("Waiter/serve", AUTHORITY);
  // The stream must be closed with the exact Status the handler threw.
  verify(stream).close(same(status), notNull(Metadata.class));
  verify(stream, atLeast(1)).statsTraceContext();
}
/**
 * With two services registered, each method name resolves to the handler of the service that
 * declares it.
 */
@Test
public void multiServiceLookup() {
  assertNull(registry.addService(basicServiceDefinition));
  assertNull(registry.addService(multiServiceDefinition));

  ServerCallHandler<?, ?> flow = registry.lookupMethod("basic/flow").getServerCallHandler();
  assertSame(flowHandler, flow);
  ServerCallHandler<?, ?> couple = registry.lookupMethod("multi/couple").getServerCallHandler();
  assertSame(coupleHandler, couple);
  ServerCallHandler<?, ?> few = registry.lookupMethod("multi/few").getServerCallHandler();
  assertSame(fewHandler, few);
}
@Test public void cannotSetOnCancelHandlerAfterServiceInvocation() throws Exception { final AtomicReference<ServerCallStreamObserver<Integer>> callObserver = new AtomicReference<ServerCallStreamObserver<Integer>>(); ServerCallHandler<Integer, Integer> callHandler = ServerCalls.asyncBidiStreamingCall( new ServerCalls.BidiStreamingMethod<Integer, Integer>() { @Override public StreamObserver<Integer> invoke(StreamObserver<Integer> responseObserver) { callObserver.set((ServerCallStreamObserver<Integer>) responseObserver); return new ServerCalls.NoopStreamObserver<Integer>(); } }); ServerCall.Listener<Integer> callListener = callHandler.startCall(serverCall, new Metadata()); callListener.onMessage(1); try { callObserver.get().setOnCancelHandler(new Runnable() { @Override public void run() { } }); fail("Cannot set onCancel handler after service invocation"); } catch (IllegalStateException expected) { // Expected } }
@Test public void cannotSetOnReadyHandlerAfterServiceInvocation() throws Exception { final AtomicReference<ServerCallStreamObserver<Integer>> callObserver = new AtomicReference<ServerCallStreamObserver<Integer>>(); ServerCallHandler<Integer, Integer> callHandler = ServerCalls.asyncBidiStreamingCall( new ServerCalls.BidiStreamingMethod<Integer, Integer>() { @Override public StreamObserver<Integer> invoke(StreamObserver<Integer> responseObserver) { callObserver.set((ServerCallStreamObserver<Integer>) responseObserver); return new ServerCalls.NoopStreamObserver<Integer>(); } }); ServerCall.Listener<Integer> callListener = callHandler.startCall(serverCall, new Metadata()); callListener.onMessage(1); try { callObserver.get().setOnReadyHandler(new Runnable() { @Override public void run() { } }); fail("Cannot set onReady after service invocation"); } catch (IllegalStateException expected) { // Expected } }
@Test public void onReadyHandlerCalledForUnaryRequest() throws Exception { final AtomicInteger onReadyCalled = new AtomicInteger(); ServerCallHandler<Integer, Integer> callHandler = ServerCalls.asyncServerStreamingCall( new ServerCalls.ServerStreamingMethod<Integer, Integer>() { @Override public void invoke(Integer req, StreamObserver<Integer> responseObserver) { ServerCallStreamObserver<Integer> serverCallObserver = (ServerCallStreamObserver<Integer>) responseObserver; serverCallObserver.setOnReadyHandler(new Runnable() { @Override public void run() { onReadyCalled.incrementAndGet(); } }); } }); ServerCall.Listener<Integer> callListener = callHandler.startCall(serverCall, new Metadata()); serverCall.isReady = true; serverCall.isCancelled = false; callListener.onReady(); // On ready is not called until the unary request message is delivered assertEquals(0, onReadyCalled.get()); // delivering the message doesn't trigger onReady listener either callListener.onMessage(1); assertEquals(0, onReadyCalled.get()); // half-closing triggers the unary request delivery and onReady callListener.onHalfClose(); assertEquals(1, onReadyCalled.get()); // Next on ready event from the transport triggers listener callListener.onReady(); assertEquals(2, onReadyCalled.get()); }
/** Logs the full name of the invoked method, then continues the call unchanged. */
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall,
    Metadata metadata, ServerCallHandler<ReqT, RespT> serverCallHandler) {
  String fullMethodName = serverCall.getMethodDescriptor().getFullMethodName();
  log.info(fullMethodName);
  return serverCallHandler.startCall(serverCall, metadata);
}