@Override public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) { @Override public void start(Listener<RespT> responseListener, Metadata headers) { getToken(next).ifPresent(t -> headers.put(TOKEN, t)); super.start(new SimpleForwardingClientCallListener<RespT>(responseListener) { @Override public void onClose(Status status, Metadata trailers) { if (isInvalidTokenError(status)) { try { refreshToken(next); } catch (Exception e) { // don't throw any error here. // rpc will retry on expired auth token. } } super.onClose(status, trailers); } }, headers); } }; }
@Override public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { // New RPCs on client-side inherit the tag context from the current Context. TagContext parentCtx = tagger.getCurrentTagContext(); final ClientCallTracer tracerFactory = newClientCallTracer(parentCtx, method.getFullMethodName(), recordStartedRpcs, recordFinishedRpcs); ClientCall<ReqT, RespT> call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall<ReqT, RespT>(call) { @Override public void start(Listener<RespT> responseListener, Metadata headers) { delegate().start( new SimpleForwardingClientCallListener<RespT>(responseListener) { @Override public void onClose(Status status, Metadata trailers) { tracerFactory.callEnded(status); super.onClose(status, trailers); } }, headers); } }; }
@Override public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { // New RPCs on client-side inherit the tracing context from the current Context. // Safe usage of the unsafe trace API because CONTEXT_SPAN_KEY.get() returns the same value // as Tracer.getCurrentSpan() except when no value available when the return value is null // for the direct access and BlankSpan when Tracer API is used. final ClientCallTracer tracerFactory = newClientCallTracer(CONTEXT_SPAN_KEY.get(), method); ClientCall<ReqT, RespT> call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall<ReqT, RespT>(call) { @Override public void start(Listener<RespT> responseListener, Metadata headers) { delegate().start( new SimpleForwardingClientCallListener<RespT>(responseListener) { @Override public void onClose(io.grpc.Status status, Metadata trailers) { tracerFactory.callEnded(status); super.onClose(status, trailers); } }, headers); } }; }
/**
 * Issues a {@code sayHello} call through an interceptor that spies on the response listener and
 * checks that {@code onHeaders} was invoked with metadata carrying {@code "customRespondValue"}
 * under {@code HeaderServerInterceptor.CUSTOM_HEADER_KEY}.
 */
@Test
public void serverHeaderDeliveredToClient() {
  class SpyingClientInterceptor implements ClientInterceptor {
    // Set when the intercepted call is started; stays null if no call was started.
    ClientCall.Listener<?> spyListener;

    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
        MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
      ClientCall<ReqT, RespT> underlying = next.newCall(method, callOptions);
      return new SimpleForwardingClientCall<ReqT, RespT>(underlying) {
        @Override
        public void start(Listener<RespT> responseListener, Metadata headers) {
          // The mock forwards every callback to the real listener while recording invocations.
          Listener<RespT> spy = mock(ClientCall.Listener.class, delegatesTo(responseListener));
          spyListener = spy;
          super.start(spy, headers);
        }
      };
    }
  }

  SpyingClientInterceptor spyingInterceptor = new SpyingClientInterceptor();
  GreeterBlockingStub stub =
      GreeterGrpc.newBlockingStub(grpcServerRule.getChannel()).withInterceptors(spyingInterceptor);

  stub.sayHello(HelloRequest.getDefaultInstance());

  assertNotNull(spyingInterceptor.spyListener);
  ArgumentCaptor<Metadata> headersCaptor = ArgumentCaptor.forClass(Metadata.class);
  verify(spyingInterceptor.spyListener).onHeaders(headersCaptor.capture());
  Metadata receivedHeaders = headersCaptor.getValue();
  assertEquals(
      "customRespondValue", receivedHeaders.get(HeaderServerInterceptor.CUSTOM_HEADER_KEY));
}
@Test public void addOutboundHeaders() { final Metadata.Key<String> credKey = Metadata.Key.of("Cred", Metadata.ASCII_STRING_MARSHALLER); ClientInterceptor interceptor = new ClientInterceptor() { @Override public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { ClientCall<ReqT, RespT> call = next.newCall(method, callOptions); return new SimpleForwardingClientCall<ReqT, RespT>(call) { @Override public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) { headers.put(credKey, "abcd"); super.start(responseListener, headers); } }; } }; Channel intercepted = ClientInterceptors.intercept(channel, interceptor); @SuppressWarnings("unchecked") ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class); ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); // start() on the intercepted call will eventually reach the call created by the real channel interceptedCall.start(listener, new Metadata()); // The headers passed to the real channel call will contain the information inserted by the // interceptor. assertSame(listener, call.listener); assertEquals("abcd", call.headers.get(credKey)); }
/**
 * Checks that a pass-through interceptor yields a distinct call object whose {@code start},
 * {@code sendMessage}, {@code halfClose} and {@code request} invocations all reach the
 * underlying call with the same arguments.
 */
@Test
public void normalCall() {
  ClientInterceptor passThroughInterceptor =
      new ClientInterceptor() {
        @Override
        public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
            MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
          // Empty anonymous subclass: forwards everything without overriding any method.
          return new SimpleForwardingClientCall<ReqT, RespT>(
              next.newCall(method, callOptions)) {};
        }
      };

  Channel interceptedChannel = ClientInterceptors.intercept(channel, passThroughInterceptor);
  ClientCall<Void, Void> interceptedCall = interceptedChannel.newCall(method, CallOptions.DEFAULT);
  assertNotSame(call, interceptedCall);

  // start(): listener and headers arrive untouched.
  @SuppressWarnings("unchecked")
  ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
  Metadata headers = new Metadata();
  interceptedCall.start(listener, headers);
  assertSame(listener, call.listener);
  assertSame(headers, call.headers);

  // sendMessage(): the message is recorded by the underlying call.
  interceptedCall.sendMessage(null /*request*/);
  assertThat(call.messages).containsExactly((Void) null /*request*/);

  // halfClose(): the flag is set on the underlying call.
  interceptedCall.halfClose();
  assertTrue(call.halfClosed);

  // request(): the requested count is recorded by the underlying call.
  interceptedCall.request(1);
  assertThat(call.requests).containsExactly(1);
}
@Test public void binaryLogTest() throws Exception { final List<Object> capturedReqs = new ArrayList<Object>(); final class TracingClientInterceptor implements ClientInterceptor { private final List<MethodDescriptor<?, ?>> interceptedMethods = new ArrayList<MethodDescriptor<?, ?>>(); @Override public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { interceptedMethods.add(method); return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) { @Override public void sendMessage(ReqT message) { capturedReqs.add(message); super.sendMessage(message); } }; } } TracingClientInterceptor userInterceptor = new TracingClientInterceptor(); binlogProvider = new BinaryLogProvider() { @Nullable @Override public ServerInterceptor getServerInterceptor(String fullMethodName) { return null; } @Override public ClientInterceptor getClientInterceptor(String fullMethodName) { return new TracingClientInterceptor(); } @Override protected int priority() { return 0; } }; createChannel( new FakeNameResolverFactory(true), Collections.<ClientInterceptor>singletonList(userInterceptor)); ClientCall<String, Integer> call = channel.newCall(method, CallOptions.DEFAULT.withDeadlineAfter(0, TimeUnit.NANOSECONDS)); ClientCall.Listener<Integer> listener = new NoopClientCallListener<Integer>(); call.start(listener, new Metadata()); assertEquals(1, executor.runDueTasks()); String actualRequest = "hello world"; call.sendMessage(actualRequest); // The user supplied interceptor must still operate on the original message types assertThat(userInterceptor.interceptedMethods).hasSize(1); assertSame( method.getRequestMarshaller(), userInterceptor.interceptedMethods.get(0).getRequestMarshaller()); assertSame( method.getResponseMarshaller(), userInterceptor.interceptedMethods.get(0).getResponseMarshaller()); // The binlog interceptor must be closest to the transport assertThat(capturedReqs).hasSize(2); // The 
InputStream is already spent, so just check its type rather than contents assertEquals(actualRequest, capturedReqs.get(0)); assertThat(capturedReqs.get(1)).isInstanceOf(InputStream.class); }