/**
 * Releases every reference-counted object registered in {@code objectsToRelease}, first asserting
 * that each one has a reference count of exactly 1. Objects that are not reference counted are skipped.
 */
@After
public void tearDown() {
    for (Object candidate : objectsToRelease) {
        if (candidate instanceof ReferenceCounted) {
            ReferenceCounted counted = (ReferenceCounted) candidate;
            int currentCount = counted.refCnt();
            assertThat(currentCount)
                    .withFailMessage("Trying to free %s but it has a ref count of %d", candidate, currentCount)
                    .isEqualTo(1);
            counted.release();
        }
    }
}
/**
 * Writes a reference-counted object through the pipeline and checks two things: that the object's
 * {@code deallocate()} runs within 10 seconds, and that the installed {@code StringInboundHandler}
 * reports it was called.
 */
@Test
public void testFreeCalled() throws Exception {
    final CountDownLatch deallocated = new CountDownLatch(1);
    final ReferenceCounted message = new AbstractReferenceCounted() {
        @Override
        protected void deallocate() {
            deallocated.countDown();
        }
    };
    StringInboundHandler inboundHandler = new StringInboundHandler();
    setUp(inboundHandler);

    peer.writeAndFlush(message).sync();

    boolean freedInTime = deallocated.await(10, TimeUnit.SECONDS);
    assertTrue(freedInTime);
    assertTrue(inboundHandler.called);
}
/**
 * Feeds two GET requests to the decoder in a single inbound buffer and verifies that each one is
 * emitted separately: a {@code BinaryMemcacheRequest} followed by its {@code LastMemcacheContent}.
 */
@Test
public void shouldHandleTwoMessagesInOneBatch() {
    channel.writeInbound(Unpooled.buffer().writeBytes(GET_REQUEST).writeBytes(GET_REQUEST));

    // Both requests must come out as independent (request, last-content) pairs.
    for (int pair = 0; pair < 2; pair++) {
        BinaryMemcacheRequest request = channel.readInbound();
        assertThat(request, instanceOf(BinaryMemcacheRequest.class));
        assertThat(request, notNullValue());
        request.release();

        Object lastContent = channel.readInbound();
        assertThat(lastContent, instanceOf(LastMemcacheContent.class));
        ((ReferenceCounted) lastContent).release();
    }
}
/**
 * Marks the channel as running asynchronously and runs {@code runnable} on {@code executor}.
 *
 * <p>The request body is retained before the task is submitted and released when the task
 * finishes. If the executor rejects the task, the task's own cleanup can never run. In that case
 * the retained reference is released here and the rejection is rethrown, so the body is not leaked.
 *
 * @param executor executor that runs the task
 * @param runnable work to run asynchronously
 * @throws java.util.concurrent.RejectedExecutionException if {@code executor} does not accept the task
 */
@Override
public void startAsync(final Executor executor, final Runnable runnable) {
    Channel channel = ctx.channel();
    channel.attr(NEED_FLUSH).set(false);
    channel.attr(ASYNC).set(true);

    final ReferenceCounted body = ((ByteBufHolder) req).content();
    body.retain();

    // Ensures the extra reference is released at most once, even if the task already ran (and
    // released it) by the time execute() throws.
    final java.util.concurrent.atomic.AtomicBoolean released = new java.util.concurrent.atomic.AtomicBoolean();
    final Runnable releaseBody = () -> {
        if (released.compareAndSet(false, true)) {
            body.release();
        }
    };

    try {
        executor.execute(() -> {
            try {
                runnable.run();
            } finally {
                releaseBody.run();
            }
        });
    } catch (java.util.concurrent.RejectedExecutionException e) {
        // The task was not accepted, so its finally block will never run: give the reference back here.
        releaseBody.run();
        throw e;
    }
}
/** * {@inheritDoc} */ @Override public void releaseContentChunks() { if (!contentChunksWillBeReleasedExternally) { contentChunks.forEach(ReferenceCounted::release); } // Now that the chunks have been released we should clear the chunk list - we can no longer rely on the chunks // for anything, and if this method is called a second time we don't want to re-release the chunks // (which would screw up the reference counting). contentChunks.clear(); }
/**
 * Forwards {@code touch(hint)} to every encapsulated packet held by this object.
 *
 * @param hint the hint passed to each packet's {@code touch}
 * @return this object
 */
@Override
public ReferenceCounted touch(Object hint) {
    for (final EncapsulatedRakNetPacket encapsulated : packets) {
        encapsulated.touch(hint);
    }
    return this;
}
/**
 * Sends {@code msg} out over the socket; this is the entry point used by users of the ProxyConnection.
 *
 * <p>If {@code msg} is reference counted, it is retained before being passed to {@code doWrite}.
 *
 * @param msg the message to send
 */
void write(Object msg) {
    final boolean refCounted = msg instanceof ReferenceCounted;
    if (refCounted) {
        LOG.debug("Retaining reference counted message");
        ((ReferenceCounted) msg).retain();
    }
    doWrite(msg);
}
@Override void write(Object msg) { LOG.debug("Requested write of {}", msg); if (msg instanceof ReferenceCounted) { LOG.debug("Retaining reference counted message"); ((ReferenceCounted) msg).retain(); } if (is(DISCONNECTED) && msg instanceof HttpRequest) { LOG.debug("Currently disconnected, connect and then write the message"); connectAndWrite((HttpRequest) msg); } else { if (isConnecting()) { synchronized (connectLock) { if (isConnecting()) { LOG.debug("Attempted to write while still in the process of connecting, waiting for connection."); clientConnection.stopReading(); try { connectLock.wait(30000); } catch (InterruptedException ie) { LOG.warn("Interrupted while waiting for connect monitor"); } } } } // only write this message if a connection was established and is not in the process of disconnecting or // already disconnected if (isConnecting() || getCurrentState().isDisconnectingOrDisconnected()) { LOG.debug("Connection failed or timed out while waiting to write message to server. Message will be discarded: {}", msg); return; } LOG.debug("Using existing connection to: {}", remoteAddress); doWrite(msg); } }
/**
 * Writes {@code object} to {@code channel} and flushes it. Used both by write0 and by
 * {@link com.linkedin.mitm.proxy.connectionflow.steps.ConnectionFlowStep}.
 *
 * <p>Reference-counted objects are retained before the write.
 *
 * @param channel the channel to write to; must not be null
 * @param object the object to write
 * @return the future returned by {@code channel.writeAndFlush(object)}
 * @throws IllegalStateException if {@code channel} is null
 */
private ChannelFuture writeToChannel(final Channel channel, final Object object) {
    if (channel == null) {
        throw new IllegalStateException("Failed to write to channel because channel is null");
    }

    final boolean refCounted = object instanceof ReferenceCounted;
    if (refCounted) {
        LOG.debug("Retaining reference counted message");
        ((ReferenceCounted) object).retain();
    }

    if (LOG.isDebugEnabled()) {
        final String description = String.format("Writing in channel [%s]: %s", channel.toString(), object);
        LOG.debug(description);
    }

    return channel.writeAndFlush(object);
}
/**
 * Writes {@code msg} to the server, connecting first when necessary.
 *
 * <p>Reference-counted messages are retained first. The message then takes one of three paths:
 * <ul>
 *   <li>If the connection is {@code DISCONNECTED} and {@code msg} is an {@code HttpRequest}, it is
 *       handed to {@code connectAndWrite}.</li>
 *   <li>Otherwise, if a connection attempt is in progress, this waits up to 30 seconds on
 *       {@code connectLock}. If the connection ended up {@code DISCONNECTED}, the message is dropped
 *       and the reference retained above is released.</li>
 *   <li>In all remaining cases the message is written via {@code doWrite}.</li>
 * </ul>
 *
 * @param msg the message to write
 */
@Override
void write(Object msg) {
    LOG.debug("Requested write of {}", msg);

    if (msg instanceof ReferenceCounted) {
        LOG.debug("Retaining reference counted message");
        ((ReferenceCounted) msg).retain();
    }

    if (is(DISCONNECTED) && msg instanceof HttpRequest) {
        LOG.debug("Currently disconnected, connect and then write the message");
        connectAndWrite((HttpRequest) msg);
    } else {
        synchronized (connectLock) {
            if (isConnecting()) {
                LOG.debug("Attempted to write while still in the process of connecting, waiting for connection.");
                clientConnection.stopReading();
                try {
                    connectLock.wait(30000);
                } catch (InterruptedException ie) {
                    LOG.warn("Interrupted while waiting for connect monitor");
                    // Restore the interrupt status instead of silently swallowing it.
                    Thread.currentThread().interrupt();
                }
                if (is(DISCONNECTED)) {
                    LOG.debug("Connection failed while we were waiting for it, don't write");
                    // The message never reaches doWrite(), so undo the retain() above; otherwise it leaks.
                    if (msg instanceof ReferenceCounted) {
                        ((ReferenceCounted) msg).release();
                    }
                    return;
                }
            }
        }

        LOG.debug("Using existing connection to: {}", remoteAddress);
        doWrite(msg);
    }
}
@Override public boolean acceptOutboundMessage(Object msg) throws Exception { final Message message = (Message) msg; final Protocol protocol = this.codecContext.getSession().getProtocol(); final MessageRegistration registration = protocol.outbound().findByMessageType(message.getClass()).orElse(null); if (registration == null) { throw new EncoderException("Message type (" + message.getClass().getName() + ") is not registered in state " + this.codecContext.getSession().getProtocolState().name() + "!"); } final List<Processor> processors = ((MessageRegistration) protocol.outbound() .findByMessageType(message.getClass()).get()).getProcessors(); // Only process if there are processors found if (!processors.isEmpty()) { final List<Object> messages = new ArrayList<>(); for (Processor processor : processors) { // The processor should handle the output messages processor.process(this.codecContext, message, messages); } if (message instanceof ReferenceCounted && !messages.contains(message)) { ((ReferenceCounted) message).release(); } if (!messages.isEmpty()) { this.messages.set(messages); } return true; } return false; }
/**
 * Releases every uploaded file, then the buffer if one is set.
 *
 * <p>NOTE(review): neither {@code uploadedFiles} nor {@code buffer} is cleared, so calling this
 * twice would release the same objects again. Consider making it idempotent if that can happen.
 */
protected void release() {
    uploadedFiles.forEach(ReferenceCounted::release);
    if (this.buffer == null) {
        return;
    }
    this.buffer.release();
}
/**
 * Returns the wrapped message's reference count, or {@code 1} when the message is not reference counted.
 */
@Override
public int refCnt() {
    return message instanceof ReferenceCounted ? ((ReferenceCounted) message).refCnt() : 1;
}
/**
 * Writing a reference-counted item to the stream writer must fail with {@link IllegalArgumentException}.
 */
@Test
public void rejectReferenceCounted() {
    final AbstractReferenceCounted refCountedItem = new AbstractReferenceCounted() {
        @Override
        protected void deallocate() {}

        @Override
        public ReferenceCounted touch(Object hint) {
            return this;
        }
    };
    final StreamMessageAndWriter<Object> writer = newStreamWriter(ImmutableList.of(refCountedItem));

    assertThatThrownBy(() -> writer.write(refCountedItem))
            .isInstanceOf(IllegalArgumentException.class);
}
/**
 * Pairs each payload with the {@code Message} that follows it and fires the pair as a {@code ServerRequest}.
 *
 * <p>Any non-{@code Message} object becomes the pending payload; reference-counted payloads are
 * retained. If a second payload arrives while one is still pending, the error flag is set and the
 * new object is not stored.
 *
 * <p>When a {@code Message} arrives, one of three things happens, and {@code reset()} is called afterwards in every case:
 * <ul>
 *   <li>If no payload is pending, an exception is fired.</li>
 *   <li>If multiple payloads were seen, an exception is fired.</li>
 *   <li>Otherwise the {@code ServerRequest} is fired downstream.</li>
 * </ul>
 */
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
    if (!(msg instanceof Message)) {
        if (currentPayload != null) {
            error = true;
            return;
        }
        currentPayload = msg instanceof ReferenceCounted ? ((ReferenceCounted) msg).retain() : msg;
        return;
    }

    Message message = (Message) msg;

    if (currentPayload == null) {
        log.error("No payload received for request id '{}': {}", message.getId(), message);
        ctx.fireExceptionCaught(
                new RuntimeException("No payload received for request id '" + message.getId() + "'"));
        reset();
        return;
    }

    if (error) {
        log.error("Multiple payloads received for request id '{}': {}", message.getId(), message);
        ctx.fireExceptionCaught(
                new RuntimeException("Multiple payloads received for request id '" + message.getId() + "'"));
        reset();
        return;
    }

    ServerRequest request = new ServerRequest(message.getId(), message.expectsResponse(), currentPayload);
    ctx.fireChannelRead(request);
    reset();
}
/**
 * Discards {@code msg} instead of writing it, releasing it if it is reference counted.
 *
 * @param msg the message to discard
 * @param promise the promise supplied by the caller
 * @return {@code promise} itself, rather than {@code null}, so callers that chain on the returned
 *     future do not hit a {@link NullPointerException}. The promise is not completed here.
 */
@Override
public ChannelFuture write(Object msg, ChannelPromise promise) {
    if (msg instanceof ReferenceCounted) {
        ((ReferenceCounted) msg).release();
    }
    return promise;
}
/**
 * Handles the HTTP response for the request previously stored on the channel.
 *
 * <p>On a 301 or 307 redirect, the stored request is re-targeted at the {@code Location} and re-sent
 * via {@code send}. Any other response is retained and passed to the stored completion handler.
 * In all cases the stored request (when present) and {@code msg} are released, and the channel is
 * closed if no exception was thrown.
 *
 * @throws IllegalStateException if no completion handler is attached to the channel, or if a
 *     redirect response has no {@code Location} header
 */
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
    final FullHttpRequest request = ctx.channel().attr(REQUEST_KEY).getAndRemove();
    try {
        final Consumer<Response> completionCallbackHandler = ctx.channel().attr(ATTRIBUTE_KEY).getAndRemove();
        if (completionCallbackHandler == null) {
            throw new IllegalStateException("Received a response with nothing to handle it.");
        }
        final FullHttpResponse response = (FullHttpResponse) msg;
        if (response.getStatus().equals(HttpResponseStatus.MOVED_PERMANENTLY)
                || response.getStatus().equals(HttpResponseStatus.TEMPORARY_REDIRECT)) {
            final String location = response.headers().get(HttpHeaders.Names.LOCATION);
            if (location == null) {
                throw new IllegalStateException("Redirect response is missing a Location header");
            }
            final URI locationUri = URI.create(location);
            final URI serverUri;
            if (locationUri.isAbsolute()) {
                serverUri = locationUri;
                // Host must carry the port when the redirect names a non-default one.
                final String host = serverUri.getPort() == -1
                        ? serverUri.getHost()
                        : serverUri.getHost() + ":" + serverUri.getPort();
                request.headers().set(HttpHeaders.Names.HOST, host);
            } else {
                final InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
                serverUri = URI.create("http://" + address.getHostString() + ":" + address.getPort());
            }
            // Build the request-target from the still-encoded path and query. A '?' separator is
            // required before the query, and an empty path must become "/".
            final String rawPath = locationUri.getRawPath();
            final String rawQuery = locationUri.getRawQuery();
            final String path = (rawPath == null || rawPath.isEmpty()) ? "/" : rawPath;
            request.setUri(rawQuery == null ? path : path + "?" + rawQuery);
            final Iterator<ServerList.Server> serverIterator =
                    Collections.singleton(new ServerList.Server(serverUri)).iterator();
            // Keep the request alive for the re-send; the finally block releases this handler's reference.
            request.retain();
            send(serverIterator, request, completionCallbackHandler);
        } else {
            // The Response holds its own reference; the finally block releases the one owned by this handler.
            response.retain();
            invokeCompletionHandler(completionCallbackHandler, new Response(response, null));
        }
    } finally {
        if (request != null) {
            request.release();
        }
        if (msg instanceof ReferenceCounted) {
            ((ReferenceCounted) msg).release();
        }
    }
    ctx.close();
}
/**
 * Drops the inbound message (releasing it if it is reference counted), then calls {@code sendError(ctx)}.
 */
@Override
public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
    final boolean needsRelease = msg instanceof ReferenceCounted;
    if (needsRelease) {
        ((ReferenceCounted) msg).release();
    }
    sendError(ctx);
}
@Override protected void encode(ChannelHandlerContext ctx, HttpResponse response, List<Object> out) throws Exception { if (response instanceof ReferenceCounted) { ((ReferenceCounted) response).retain(); } // Add content length if necessary if (!response.headers().contains(CONTENT_LENGTH) && !HttpHeaders.isTransferEncodingChunked(response)) { if (response instanceof DefaultFullHttpResponse) { DefaultFullHttpResponse full = (DefaultFullHttpResponse) response; HttpHeaders.setContentLength(response, full.content().readableBytes()); } } // Set content type if (!response.headers().contains(CONTENT_TYPE)) { response.headers().set(CONTENT_TYPE, MediaType.PLAIN_TEXT_UTF_8.toString()); } // Set protocol version response.setProtocolVersion(version); // Set keep-alive status if (isKeepAlive) { response.headers().set(CONNECTION, "Keep-Alive"); } // Set date if (!response.headers().contains(DATE)) { HttpHeaders.setDate(response, new Date()); } out.add(response); }
/**
 * Records the request's protocol version and keep-alive flag, then forwards the same request
 * object, retained if it is reference counted.
 */
@Override
protected void decode(ChannelHandlerContext ctx, HttpRequest request, List<Object> out) throws Exception {
    // Remembered on the handler; presumably read later when the response is encoded — confirm.
    version = request.getProtocolVersion();
    isKeepAlive = HttpHeaders.isKeepAlive(request);

    // The request itself goes into out, so keep an extra reference for the next consumer.
    if (request instanceof ReferenceCounted) {
        ((ReferenceCounted) request).retain();
    }
    out.add(request);
}
/**
 * Forwards {@code msg} to the current connection, retaining it first if it is reference counted.
 * When there is no connection the message is ignored and not retained.
 */
protected void onMessage(Object msg) {
    if (this.connection != null) {
        if (msg instanceof ReferenceCounted) {
            ((ReferenceCounted) msg).retain();
        }
        this.connection.onMessage(msg);
    }
}
/**
 * Releases the SSL context if it is reference counted. The {@code hasReleasedSslContext} flag
 * ensures the release happens at most once.
 */
@Override
public void close() {
    if (!(this.sslContext instanceof ReferenceCounted)) {
        return;
    }
    // Only the call that flips the flag from false to true performs the release.
    if (this.hasReleasedSslContext.compareAndSet(false, true)) {
        ((ReferenceCounted) this.sslContext).release();
    }
}
/**
 * Encodes outbound objects for the socket.
 *
 * <p>An {@code IPacket} is encoded into a {@code ByteBuf} and wrapped according to its transport type:
 * <ul>
 *   <li>WEBSOCKET / FLASHSOCKET: a text WebSocket frame.</li>
 *   <li>XHR_POLLING: an HTTP response.</li>
 *   <li>JSONP_POLLING: an HTTP response whose body is the JSONP-templated text.</li>
 * </ul>
 * Any other object is retained and forwarded unchanged.
 *
 * @throws UnsupportedTransportTypeException if the packet's transport type is not handled; the
 *     encoded buffer is released before throwing so it does not leak
 */
@Override
protected void encode(ChannelHandlerContext ctx, Object msg, List<Object> out) throws Exception {
    if (msg instanceof IPacket) {
        IPacket packet = (IPacket) msg;
        if (log.isDebugEnabled()) {
            log.debug("Sending packet: {} to channel: {}", msg, ctx.channel());
        }
        ByteBuf encodedPacket = encodePacket(packet);
        if (log.isDebugEnabled()) {
            log.debug("Encoded packet: {}", encodedPacket);
        }
        TransportType transportType = packet.getTransportType();
        if (transportType == TransportType.WEBSOCKET || transportType == TransportType.FLASHSOCKET) {
            out.add(new TextWebSocketFrame(encodedPacket));
        } else if (transportType == TransportType.XHR_POLLING) {
            out.add(PipelineUtils.createHttpResponse(packet.getOrigin(), encodedPacket, false));
        } else if (transportType == TransportType.JSONP_POLLING) {
            String jsonpIndexParam = (packet.getJsonpIndexParam() != null) ? packet.getJsonpIndexParam() : "0";
            String encodedStringPacket = encodedPacket.toString(CharsetUtil.UTF_8);
            // The bytes have been copied into a String, so the buffer is no longer needed.
            encodedPacket.release();
            String encodedJsonpPacket = String.format(JSONP_TEMPLATE, jsonpIndexParam, encodedStringPacket);
            HttpResponse httpResponse = PipelineUtils.createHttpResponse(packet.getOrigin(),
                    PipelineUtils.copiedBuffer(ctx.alloc(), encodedJsonpPacket), true);
            httpResponse.headers().add("X-XSS-Protection", "0");
            out.add(httpResponse);
        } else {
            // Nothing took ownership of the encoded buffer; release it before failing, or it leaks.
            encodedPacket.release();
            throw new UnsupportedTransportTypeException(transportType);
        }
    } else {
        // The object itself is forwarded in out, so keep an extra reference for the next consumer.
        if (msg instanceof ReferenceCounted) {
            ((ReferenceCounted) msg).retain();
        }
        out.add(msg);
    }
}
/**
 * Releases {@code counted} once {@code future} completes, whether it succeeds or fails.
 *
 * @param future the future whose completion triggers the release
 * @param counted the object to release exactly once on completion
 */
public static <V> void add(ListenableFuture<V> future, final ReferenceCounted counted) {
    FutureCallback<V> releaseOnCompletion = new FutureCallback<V>() {
        @Override
        public void onSuccess(V result) {
            counted.release();
        }

        @Override
        public void onFailure(Throwable t) {
            counted.release();
        }
    };
    Futures.addCallback(future, releaseOnCompletion);
}
/**
 * Creates a record holding {@code page}, its space-map entry and an optional reference-counted
 * object; a non-null {@code ref} is retained by this record.
 *
 * @param page the page
 * @param space the space-map entry for the page
 * @param ref an object to retain, or {@code null}
 */
public PageRecord(Page page, SpaceMapEntry space, ReferenceCounted ref) {
    if (ref != null) {
        ref.retain();
    }
    this.page = page;
    this.space = space;
    this.ref = ref;
}
/**
 * Delegates {@code retain()} to {@code last} and returns this object for chaining.
 */
@Override
public ReferenceCounted retain() {
    this.last.retain();
    return this;
}
/**
 * Delegates {@code retain(increment)} to {@code last} and returns this object for chaining.
 *
 * @param increment the amount passed through to {@code last.retain}
 */
@Override
public ReferenceCounted retain(int increment) {
    this.last.retain(increment);
    return this;
}
/**
 * Delegates {@code touch()} to {@code last} and returns this object.
 */
@Override
public ReferenceCounted touch() {
    this.last.touch();
    return this;
}
/**
 * Delegates {@code touch(hint)} to {@code last} and returns this object.
 *
 * @param hint the hint passed through to {@code last.touch}
 */
@Override
public ReferenceCounted touch(Object hint) {
    this.last.touch(hint);
    return this;
}