From f4b6111c867013f96deaf996632d67ebe31a3f96 Mon Sep 17 00:00:00 2001 From: Arturo Bernal Date: Tue, 17 Mar 2026 16:57:46 +0100 Subject: [PATCH] Implement two-phase graceful HTTP/2 shutdown Send an initial GOAWAY, wait one RTT using PING, then send the final GOAWAY with the last processed stream id. --- .../impl/nio/AbstractH2StreamMultiplexer.java | 105 ++++-- .../H2GracefulShutdownDrainExample.java | 336 ++++++++++++++++++ .../nio/TestAbstractH2StreamMultiplexer.java | 173 +++++++++ 3 files changed, 593 insertions(+), 21 deletions(-) create mode 100644 httpcore5-h2/src/test/java/org/apache/hc/core5/http2/examples/H2GracefulShutdownDrainExample.java diff --git a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java index 8e9171b57..6abb9cd66 100644 --- a/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java +++ b/httpcore5-h2/src/main/java/org/apache/hc/core5/http2/impl/nio/AbstractH2StreamMultiplexer.java @@ -106,7 +106,7 @@ abstract class AbstractH2StreamMultiplexer implements Identifiable, HttpConnecti private static final long CONNECTION_WINDOW_LOW_MARK = 10 * 1024 * 1024; - enum ConnectionHandshake { READY, ACTIVE, GRACEFUL_SHUTDOWN, SHUTDOWN } + enum ConnectionHandshake { READY, ACTIVE, DRAINING, GRACEFUL_SHUTDOWN, SHUTDOWN } enum SettingsHandshake { READY, TRANSMITTED, ACKED } private final ProtocolIOSession ioSession; @@ -142,6 +142,9 @@ enum SettingsHandshake { READY, TRANSMITTED, ACKED } private EndpointDetails endpointDetails; private boolean goAwayReceived; + private int shutdownLastStreamId; + private int lastProcessedRemoteStreamId; + private boolean drainPingSent; private volatile boolean peerNoRfc7540Priorities; @@ -201,6 +204,9 @@ enum SettingsHandshake { READY, TRANSMITTED, ACKED } this.streamListener = streamListener; this.lastActivityTime = System.currentTimeMillis(); this.validateAfterInactivity = validateAfterInactivity; + this.shutdownLastStreamId = 0; + this.lastProcessedRemoteStreamId = 0; + this.drainPingSent = false; } @Override @@ -506,6 +512,13 @@ public final void onOutput() throws HttpException, IOException { ioSession.getLock().unlock(); } + if (connState == ConnectionHandshake.DRAINING && !drainPingSent && outputBuffer.isEmpty() && outputQueue.isEmpty()) { + drainPingSent = true; + executePing(new PingCommand(createGracefulShutdownPingHandler())); + // Return early so the PING frame is flushed on the next onOutput cycle + return; + } + if (connState.compareTo(ConnectionHandshake.SHUTDOWN) < 0) { if (connOutputWindow.get() > 0 && remoteSettingState == SettingsHandshake.ACKED) { @@ -589,16 +602,16 @@ public final void onOutput() throws HttpException, IOException { streams.dropStreamId(stream.getId()); it.remove(); } else { - if (streams.isSameSide(stream.getId()) || stream.getId() <= streams.getLastRemoteId()) { + if (streams.isSameSide(stream.getId()) || shutdownLastStreamId == 0 || stream.getId() <= shutdownLastStreamId) { liveStreams++; } } } - if (liveStreams == 0) { + if (shutdownLastStreamId != Integer.MAX_VALUE && liveStreams == 0) { connState = ConnectionHandshake.SHUTDOWN; } } - if (connState.compareTo(ConnectionHandshake.GRACEFUL_SHUTDOWN) >= 0) { + if (connState.compareTo(ConnectionHandshake.DRAINING) >= 0) { for (;;) { final Command command = ioSession.poll(); if (command == null) { @@ -628,6 +641,11 @@ public final void onOutput() throws HttpException, IOException { } public final void onTimeout(final Timeout timeout) throws HttpException, IOException { + if (connState == ConnectionHandshake.DRAINING) { + completeGracefulShutdown(); + return; + } + connState = ConnectionHandshake.SHUTDOWN; final RawFrame goAway; @@ -663,13 +681,55 @@ private void executeShutdown(final ShutdownCommand shutdownCommand) throws IOExc if (shutdownCommand.getType() == CloseMode.IMMEDIATE) { streams.shutdownAndReleaseAll(); connState = ConnectionHandshake.SHUTDOWN; - } else { - if (connState.compareTo(ConnectionHandshake.ACTIVE) <= 0) { - final RawFrame goAway = frameFactory.createGoAway(streams.getLastRemoteId(), H2Error.NO_ERROR, "Graceful shutdown"); - commitFrame(goAway); - connState = streams.isEmpty() ? ConnectionHandshake.SHUTDOWN : ConnectionHandshake.GRACEFUL_SHUTDOWN; - } + return; + } + if (connState.compareTo(ConnectionHandshake.ACTIVE) <= 0) { + shutdownLastStreamId = Integer.MAX_VALUE; + drainPingSent = false; + final RawFrame goAway = frameFactory.createGoAway(shutdownLastStreamId, H2Error.NO_ERROR, "Graceful shutdown"); + commitFrame(goAway); + connState = ConnectionHandshake.DRAINING; + requestSessionOutput(); + } + } + + private void completeGracefulShutdown() throws IOException { + if (connState != ConnectionHandshake.DRAINING) { + return; } + shutdownLastStreamId = lastProcessedRemoteStreamId; + final RawFrame goAway = frameFactory.createGoAway(shutdownLastStreamId, H2Error.NO_ERROR, "Graceful shutdown"); + commitFrame(goAway); + connState = ConnectionHandshake.GRACEFUL_SHUTDOWN; + } + + private AsyncPingHandler createGracefulShutdownPingHandler() { + final ByteBuffer data = ByteBuffer.allocate(8); + data.putLong(System.nanoTime()); + data.flip(); + return new AsyncPingHandler() { + + @Override + public ByteBuffer getData() { + return data.asReadOnlyBuffer(); + } + + @Override + public void consumeResponse(final ByteBuffer feedback) throws IOException { + if (connState == ConnectionHandshake.DRAINING) { + completeGracefulShutdown(); + } + } + + @Override + public void failed(final Exception cause) { + } + + @Override + public void cancel() { + } + + }; } private void executePing(final PingCommand pingCommand) throws IOException { @@ -817,8 +877,9 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio } final H2StreamChannel channel = createChannel(streamId); - if (connState.compareTo(ConnectionHandshake.ACTIVE) <= 0) { + if (connState.compareTo(ConnectionHandshake.DRAINING) <= 0) { stream = streams.createActive(channel, incomingRequest(channel)); + lastProcessedRemoteStreamId = Math.max(lastProcessedRemoteStreamId, streamId); streams.resetIfExceedsMaxConcurrentLimit(stream, localConfig.getMaxConcurrentStreams()); } else { channel.localReset(H2Error.REFUSED_STREAM); @@ -1026,8 +1087,9 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio final H2StreamChannel channel = createChannel(promisedStreamId); final H2Stream promisedStream; - if (connState.compareTo(ConnectionHandshake.ACTIVE) <= 0) { + if (connState.compareTo(ConnectionHandshake.DRAINING) <= 0) { promisedStream = streams.createReserved(channel, incomingPushPromise(channel, stream.getPushHandlerFactory())); + lastProcessedRemoteStreamId = Math.max(lastProcessedRemoteStreamId, promisedStreamId); } else { channel.localReset(H2Error.REFUSED_STREAM); promisedStream = streams.createActive(channel, NoopH2StreamHandler.INSTANCE); @@ -1053,17 +1115,18 @@ private void consumeFrame(final RawFrame frame) throws HttpException, IOExceptio final int errorCode = payload.getInt(); goAwayReceived = true; if (errorCode == H2Error.NO_ERROR.getCode()) { - if (connState.compareTo(ConnectionHandshake.ACTIVE) <= 0) { - for (final Iterator it = streams.iterator(); it.hasNext(); ) { - final H2Stream stream = it.next(); - final int activeStreamId = stream.getId(); - if (!streams.isSameSide(activeStreamId) && activeStreamId > processedLocalStreamId) { - stream.fail(new RequestNotExecutedException()); - it.remove(); - } + for (final Iterator it = streams.iterator(); it.hasNext(); ) { + final H2Stream stream = it.next(); + final int activeStreamId = stream.getId(); + if (!streams.isSameSide(activeStreamId) && activeStreamId > processedLocalStreamId) { + stream.fail(new RequestNotExecutedException()); + it.remove(); } } - connState = streams.isEmpty() ? ConnectionHandshake.SHUTDOWN : ConnectionHandshake.GRACEFUL_SHUTDOWN; + if (connState != ConnectionHandshake.DRAINING) { + shutdownLastStreamId = processedLocalStreamId; + connState = ConnectionHandshake.GRACEFUL_SHUTDOWN; + } } else { for (final Iterator it = streams.iterator(); it.hasNext(); ) { final H2Stream stream = it.next(); diff --git a/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/examples/H2GracefulShutdownDrainExample.java b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/examples/H2GracefulShutdownDrainExample.java new file mode 100644 index 000000000..da5eaeca4 --- /dev/null +++ b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/examples/H2GracefulShutdownDrainExample.java @@ -0,0 +1,336 @@ +/* + * ==================================================================== + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * ==================================================================== + * + * This software consists of voluntary contributions made by many + * individuals on behalf of the Apache Software Foundation. For more + * information on the Apache Software Foundation, please see + * . + * + */ +package org.apache.hc.core5.http2.examples; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.hc.core5.annotation.Experimental; +import org.apache.hc.core5.http.ClassicHttpRequest; +import org.apache.hc.core5.http.ClassicHttpResponse; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.EntityDetails; +import org.apache.hc.core5.http.Header; +import org.apache.hc.core5.http.HttpConnection; +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpException; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.HttpHost; +import org.apache.hc.core5.http.HttpRequest; +import org.apache.hc.core5.http.HttpResponse; +import org.apache.hc.core5.http.HttpStatus; +import org.apache.hc.core5.http.URIScheme; +import org.apache.hc.core5.http.impl.bootstrap.HttpAsyncRequester; +import org.apache.hc.core5.http.impl.bootstrap.HttpAsyncServer; +import org.apache.hc.core5.http.io.support.ClassicRequestBuilder; +import org.apache.hc.core5.http.message.BasicHttpResponse; +import org.apache.hc.core5.http.nio.AsyncClientEndpoint; +import org.apache.hc.core5.http.nio.AsyncServerExchangeHandler; +import org.apache.hc.core5.http.nio.CapacityChannel; +import org.apache.hc.core5.http.nio.DataStreamChannel; +import org.apache.hc.core5.http.nio.ResponseChannel; +import org.apache.hc.core5.http.nio.support.classic.ClassicToAsyncRequestProducer; +import org.apache.hc.core5.http.nio.support.classic.ClassicToAsyncResponseConsumer; +import org.apache.hc.core5.http.protocol.HttpContext; +import org.apache.hc.core5.http2.HttpVersionPolicy; +import org.apache.hc.core5.http2.config.H2Config; +import org.apache.hc.core5.http2.frame.FrameFlag; +import org.apache.hc.core5.http2.frame.FrameType; +import org.apache.hc.core5.http2.frame.RawFrame; +import org.apache.hc.core5.http2.impl.nio.H2StreamListener; +import org.apache.hc.core5.http2.impl.nio.bootstrap.H2RequesterBootstrap; +import org.apache.hc.core5.http2.impl.nio.bootstrap.H2ServerBootstrap; +import org.apache.hc.core5.io.CloseMode; +import org.apache.hc.core5.reactor.IOReactorConfig; +import org.apache.hc.core5.reactor.ListenerEndpoint; +import org.apache.hc.core5.util.TimeValue; +import org.apache.hc.core5.util.Timeout; + +/** + * Example that demonstrates graceful HTTP/2 connection drain. + *

+ * This example starts an embedded HTTP/2 server and an HTTP/2 client, executes + * a single request over a persistent connection, and then triggers graceful + * server shutdown. + *

+ * With two-phase GOAWAY drain support in the H2 stream multiplexer, the client + * side frame log should show: + *

+ * << GOAWAY lastStreamId=2147483647 errorCode=0
+ * << PING ack=false
+ * >> PING ack=true
+ * << GOAWAY lastStreamId=1 errorCode=0
+ * 
+ */ +@Experimental +public class H2GracefulShutdownDrainExample { + + private static final int PORT = 8080; + + public static void main(final String[] args) throws Exception { + + final IOReactorConfig ioReactorConfig = IOReactorConfig.custom() + .setSoTimeout(30, TimeUnit.SECONDS) + .setTcpNoDelay(true) + .build(); + + final H2Config h2Config = H2Config.custom() + .setPushEnabled(false) + .setMaxConcurrentStreams(100) + .build(); + + final CountDownLatch finalGoAwayLatch = new CountDownLatch(1); + final AtomicInteger clientGoAwayCount = new AtomicInteger(); + + final HttpAsyncServer server = H2ServerBootstrap.bootstrap() + .setIOReactorConfig(ioReactorConfig) + .setH2Config(h2Config) + .setVersionPolicy(HttpVersionPolicy.FORCE_HTTP_2) + .setStreamListener(new LoggingH2StreamListener("SERVER", null, null)) + .register("/hello", () -> new AsyncServerExchangeHandler() { + + private final ByteBuffer content = StandardCharsets.UTF_8.encode("hello over h2\n"); + private volatile boolean responseSubmitted; + + @Override + public void handleRequest( + final HttpRequest request, + final EntityDetails entityDetails, + final ResponseChannel responseChannel, + final HttpContext context) throws HttpException, IOException { + final HttpResponse response = new BasicHttpResponse(HttpStatus.SC_OK); + response.setHeader(HttpHeaders.CONTENT_TYPE, ContentType.TEXT_PLAIN.toString()); + responseChannel.sendResponse(response, null, context); + responseSubmitted = true; + } + + @Override + public void updateCapacity(final CapacityChannel capacityChannel) throws IOException { + capacityChannel.update(Integer.MAX_VALUE); + } + + @Override + public void consume(final ByteBuffer src) throws IOException { + while (src.hasRemaining()) { + src.get(); + } + } + + @Override + public void streamEnd(final List trailers) throws HttpException, IOException { + } + + @Override + public int available() { + return responseSubmitted ? content.remaining() : 0; + } + + @Override + public void produce(final DataStreamChannel channel) throws IOException { + if (content.hasRemaining()) { + channel.write(content); + } + if (!content.hasRemaining()) { + channel.endStream(); + } + } + + @Override + public void failed(final Exception cause) { + cause.printStackTrace(System.out); + } + + @Override + public void releaseResources() { + } + + }) + .create(); + + final HttpAsyncRequester requester = H2RequesterBootstrap.bootstrap() + .setIOReactorConfig(ioReactorConfig) + .setH2Config(h2Config) + .setVersionPolicy(HttpVersionPolicy.FORCE_HTTP_2) + .setStreamListener(new LoggingH2StreamListener("CLIENT", finalGoAwayLatch, clientGoAwayCount)) + .create(); + + server.start(); + final Future listenerFuture = server.listen(new InetSocketAddress(PORT), URIScheme.HTTP); + final ListenerEndpoint listenerEndpoint = listenerFuture.get(); + System.out.println("Server listening on " + listenerEndpoint.getAddress()); + + requester.start(); + + final HttpHost target = new HttpHost("http", "127.0.0.1", PORT); + final Future endpointFuture = requester.connect(target, Timeout.ofSeconds(30)); + final AsyncClientEndpoint clientEndpoint = endpointFuture.get(); + + final ClassicHttpRequest request = ClassicRequestBuilder.get() + .setHttpHost(target) + .setPath("/hello") + .build(); + + final ClassicToAsyncRequestProducer requestProducer = + new ClassicToAsyncRequestProducer(request, Timeout.ofSeconds(30)); + final ClassicToAsyncResponseConsumer responseConsumer = + new ClassicToAsyncResponseConsumer(Timeout.ofSeconds(30)); + + clientEndpoint.execute(requestProducer, responseConsumer, null); + + requestProducer.blockWaiting().execute(); + try (ClassicHttpResponse response = responseConsumer.blockWaiting()) { + System.out.println("/hello -> " + response.getCode()); + final HttpEntity entity = response.getEntity(); + if (entity != null) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(entity.getContent(), StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + System.out.println(line); + } + } + } + } + + System.out.println(); + System.out.println("Triggering graceful server shutdown"); + server.initiateShutdown(); + + final boolean completed = finalGoAwayLatch.await(10, TimeUnit.SECONDS); + System.out.println("Final GOAWAY observed: " + completed); + if (!completed) { + throw new IllegalStateException("Did not observe the final GOAWAY frame"); + } + + Thread.sleep(1000); + + System.out.println(); + System.out.println("Triggering requester shutdown"); + requester.initiateShutdown(); + + requester.awaitShutdown(TimeValue.ofSeconds(5)); + server.awaitShutdown(TimeValue.ofSeconds(5)); + + requester.close(CloseMode.GRACEFUL); + server.close(CloseMode.GRACEFUL); + } + + static final class LoggingH2StreamListener implements H2StreamListener { + + private final String name; + private final CountDownLatch finalGoAwayLatch; + private final AtomicInteger goAwayCount; + + LoggingH2StreamListener( + final String name, + final CountDownLatch finalGoAwayLatch, + final AtomicInteger goAwayCount) { + this.name = name; + this.finalGoAwayLatch = finalGoAwayLatch; + this.goAwayCount = goAwayCount; + } + + @Override + public void onHeaderInput(final HttpConnection connection, final int streamId, final List headers) { + for (int i = 0; i < headers.size(); i++) { + System.out.println(name + " " + connection.getRemoteAddress() + " (" + streamId + ") << " + headers.get(i)); + } + } + + @Override + public void onHeaderOutput(final HttpConnection connection, final int streamId, final List headers) { + for (int i = 0; i < headers.size(); i++) { + System.out.println(name + " " + connection.getRemoteAddress() + " (" + streamId + ") >> " + headers.get(i)); + } + } + + @Override + public void onFrameInput(final HttpConnection connection, final int streamId, final RawFrame frame) { + System.out.println(name + " " + connection.getRemoteAddress() + " (" + streamId + ") << " + formatFrame(frame)); + if (finalGoAwayLatch != null && goAwayCount != null && FrameType.valueOf(frame.getType()) == FrameType.GOAWAY) { + if (goAwayCount.incrementAndGet() == 2) { + finalGoAwayLatch.countDown(); + } + } + } + + @Override + public void onFrameOutput(final HttpConnection connection, final int streamId, final RawFrame frame) { + System.out.println(name + " " + connection.getRemoteAddress() + " (" + streamId + ") >> " + formatFrame(frame)); + } + + @Override + public void onInputFlowControl(final HttpConnection connection, final int streamId, final int delta, final int actualSize) { + } + + @Override + public void onOutputFlowControl(final HttpConnection connection, final int streamId, final int delta, final int actualSize) { + } + + private static String formatFrame(final RawFrame frame) { + final FrameType frameType = FrameType.valueOf(frame.getType()); + if (frameType == null) { + return "UNKNOWN(" + frame.getType() + ")"; + } + switch (frameType) { + case GOAWAY: { + final ByteBuffer payload = frame.getPayload(); + if (payload == null || payload.remaining() < 8) { + return "GOAWAY invalid"; + } + final ByteBuffer dup = payload.asReadOnlyBuffer(); + final int lastStreamId = dup.getInt() & 0x7fffffff; + final int errorCode = dup.getInt(); + return "GOAWAY lastStreamId=" + lastStreamId + " errorCode=" + errorCode; + } + case PING: + return "PING ack=" + frame.isFlagSet(FrameFlag.ACK); + case SETTINGS: + return frame.isFlagSet(FrameFlag.ACK) ? "SETTINGS ack=true" : "SETTINGS ack=false"; + case HEADERS: + return "HEADERS endStream=" + frame.isFlagSet(FrameFlag.END_STREAM) + + " endHeaders=" + frame.isFlagSet(FrameFlag.END_HEADERS); + case DATA: + return "DATA endStream=" + frame.isFlagSet(FrameFlag.END_STREAM) + + " length=" + frame.getLength(); + default: + return frameType.name() + " length=" + frame.getLength(); + } + } + + } + +} diff --git a/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/TestAbstractH2StreamMultiplexer.java b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/TestAbstractH2StreamMultiplexer.java index 6c9e88e4f..3ce628793 100644 --- a/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/TestAbstractH2StreamMultiplexer.java +++ b/httpcore5-h2/src/test/java/org/apache/hc/core5/http2/impl/nio/TestAbstractH2StreamMultiplexer.java @@ -49,6 +49,7 @@ import org.apache.hc.core5.http.nio.AsyncPushConsumer; import org.apache.hc.core5.http.nio.AsyncPushProducer; import org.apache.hc.core5.http.nio.HandlerFactory; +import org.apache.hc.core5.http.nio.command.ShutdownCommand; import org.apache.hc.core5.http.protocol.HttpContext; import org.apache.hc.core5.http.protocol.HttpProcessor; import org.apache.hc.core5.http2.H2ConnectionException; @@ -68,6 +69,7 @@ import org.apache.hc.core5.http2.frame.StreamIdGenerator; import org.apache.hc.core5.http2.hpack.HPackEncoder; import org.apache.hc.core5.http2.hpack.HPackException; +import org.apache.hc.core5.reactor.Command; import org.apache.hc.core5.reactor.ProtocolIOSession; import org.apache.hc.core5.util.ByteArrayBuffer; import org.apache.hc.core5.util.Timeout; @@ -2017,5 +2019,176 @@ void testHeadersWithPrioritySelfDependencyIsStreamProtocolError() throws Excepti .consumeHeader(ArgumentMatchers.anyList(), ArgumentMatchers.anyBoolean()); } + @Test + void testGracefulShutdownUsesTwoPhaseGoAwayWithPingBarrier() throws Exception { + final List writes = new ArrayList<>(); + Mockito.when(protocolIOSession.write(ArgumentMatchers.any(ByteBuffer.class))) + .thenAnswer(invocation -> { + final ByteBuffer buffer = invocation.getArgument(0, ByteBuffer.class); + final byte[] copy = new byte[buffer.remaining()]; + buffer.get(copy); + writes.add(copy); + return copy.length; + }); + Mockito.doNothing().when(protocolIOSession).setEvent(ArgumentMatchers.anyInt()); + Mockito.doNothing().when(protocolIOSession).clearEvent(ArgumentMatchers.anyInt()); + + Mockito.when(protocolIOSession.poll()).thenReturn(null); + + final H2Config h2Config = H2Config.custom().build(); + + final AbstractH2StreamMultiplexer mux = new H2StreamMultiplexerImpl( + protocolIOSession, + FRAME_FACTORY, + StreamIdGenerator.ODD, + httpProcessor, + CharCodingConfig.DEFAULT, + h2Config, + h2StreamListener, + () -> streamHandler); + + mux.onConnect(); + mux.onOutput(); + completeSettingsHandshake(mux); + mux.onOutput(); + + final ByteArrayBuffer headerBuf = new ByteArrayBuffer(128); + final HPackEncoder encoder = new HPackEncoder( + h2Config.getHeaderTableSize(), + CharCodingSupport.createEncoder(CharCodingConfig.DEFAULT)); + + final List
headers = Arrays.asList( + new BasicHeader(":method", "GET"), + new BasicHeader(":scheme", "https"), + new BasicHeader(":path", "/"), + new BasicHeader(":authority", "example.test")); + + encoder.encodeHeaders(headerBuf, headers, h2Config.isCompressionEnabled()); + + final RawFrame headersFrame = FRAME_FACTORY.createHeaders( + 2, + ByteBuffer.wrap(headerBuf.array(), 0, headerBuf.length()), + true, + true); + feedFrame(mux, headersFrame); + + writes.clear(); + + Mockito.when(protocolIOSession.poll()).thenReturn(ShutdownCommand.GRACEFUL, (Command) null); + + // 1st pass: consume shutdown command, queue initial GOAWAY + mux.onOutput(); + + // 2nd pass: flush initial GOAWAY, queue drain PING + mux.onOutput(); + + // 3rd pass: flush drain PING + mux.onOutput(); + + List frames = parseFrames(concat(writes)); + + final FrameStub initialGoAway = frames.stream() + .filter(FrameStub::isGoAway) + .findFirst() + .orElse(null); + Assertions.assertNotNull(initialGoAway, "Initial GOAWAY not emitted"); + Assertions.assertEquals(Integer.MAX_VALUE, goAwayLastStreamId(initialGoAway)); + + final FrameStub ping = frames.stream() + .filter(f -> f.isPing() && !f.isAck()) + .findFirst() + .orElse(null); + Assertions.assertNotNull(ping, "Drain PING not emitted"); + + final RawFrame pingAck = new RawFrame( + FrameType.PING.getValue(), + FrameFlag.ACK.getValue(), + 0, + ByteBuffer.wrap(ping.payload)); + + writes.clear(); + + feedFrame(mux, pingAck); + + // final GOAWAY gets queued by consumeResponse -> completeGracefulShutdown() + mux.onOutput(); + + frames = parseFrames(concat(writes)); + + final FrameStub finalGoAway = frames.stream() + .filter(FrameStub::isGoAway) + .findFirst() + .orElse(null); + Assertions.assertNotNull(finalGoAway, "Final GOAWAY not emitted"); + Assertions.assertEquals(2, goAwayLastStreamId(finalGoAway)); + } + + @Test + void testPeerInitialGracefulGoAwayDoesNotPreventPingAck() throws Exception { + final List writes = new ArrayList<>(); + Mockito.when(protocolIOSession.write(ArgumentMatchers.any(ByteBuffer.class))) + .thenAnswer(invocation -> { + final ByteBuffer buffer = invocation.getArgument(0, ByteBuffer.class); + final byte[] copy = new byte[buffer.remaining()]; + buffer.get(copy); + writes.add(copy); + return copy.length; + }); + Mockito.doNothing().when(protocolIOSession).setEvent(ArgumentMatchers.anyInt()); + Mockito.doNothing().when(protocolIOSession).clearEvent(ArgumentMatchers.anyInt()); + + final AbstractH2StreamMultiplexer mux = new H2StreamMultiplexerImpl( + protocolIOSession, + FRAME_FACTORY, + StreamIdGenerator.ODD, + httpProcessor, + CharCodingConfig.DEFAULT, + H2Config.custom().build(), + h2StreamListener, + () -> streamHandler); + + mux.onConnect(); + mux.onOutput(); + completeSettingsHandshake(mux); + mux.onOutput(); + writes.clear(); + + final ByteBuffer goAwayPayload = ByteBuffer.allocate(8); + goAwayPayload.putInt(Integer.MAX_VALUE); + goAwayPayload.putInt(H2Error.NO_ERROR.getCode()); + goAwayPayload.flip(); + + final RawFrame goAway = new RawFrame( + FrameType.GOAWAY.getValue(), + 0, + 0, + goAwayPayload); + + feedFrame(mux, goAway); + + final byte[] pingPayload = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + final RawFrame ping = new RawFrame( + FrameType.PING.getValue(), + 0, + 0, + ByteBuffer.wrap(pingPayload)); + + feedFrame(mux, ping); + mux.onOutput(); + + final List frames = parseFrames(concat(writes)); + final FrameStub pingAck = frames.stream() + .filter(f -> f.isPing() && f.isAck()) + .findFirst() + .orElse(null); + + Assertions.assertNotNull(pingAck, "PING ACK must still be emitted after initial graceful GOAWAY"); + Assertions.assertArrayEquals(pingPayload, pingAck.payload); + } + + private static int goAwayLastStreamId(final FrameStub frame) { + final ByteBuffer buffer = ByteBuffer.wrap(frame.payload); + return buffer.getInt() & 0x7fffffff; + } } \ No newline at end of file