From d5873fa0d6c8c7a3f357da849f95fba58f019583 Mon Sep 17 00:00:00 2001 From: Ivan Yurchenko Date: Mon, 10 Feb 2025 18:46:15 +0200 Subject: [PATCH] Solve one-produce-request-at-time problem The Kafka network subsystem is organized in a way that allows handling only one request per connection at a time. When a request is received, the connection is muted (which means no further requests are selected by `Selector`). Only when the request is processed and the response is sent to the client is the channel unmuted to let the following request in. This is fine for normal Kafka, where produce requests are normally quick. However, in Inkless we normally have a delay of up to 250+ ms while the active file is filled, uploaded, and committed. This is the major bottleneck for single producer throughput. More details can be found [here](https://aiven-io.slack.com/archives/C06G4TBQ6AW/p1733939393895899). This commit is a proof-of-concept solution to this problem. The core idea is "upgrading" the connection. When the first Inkless Produce request is handled, the connection is "upgraded". After this moment, the connection is not expected to receive any other request type, not even non-Inkless produce requests. Upgraded Inkless connections are handled differently in several key ways: 1. They aren't muted when a request is received. This allows accepting requests in a pipelined fashion and handing them over to the `InklessAppendInterceptor` before the previous ones are handled and responded to. The Inkless produce machinery already has some queueing and support for parallel upload. 2. There's a response queue, `InklessSendQueue`. It enables serializing responses by correlation ID (otherwise, the client will fail) and sending them downstream to the connection only when the connection is ready to send them further to the client. 3. If an unexpected request arrives, the connection fails. # TODO 1.
The assumption that if there are produce requests in the connection, there will be only produce requests in the future seems correct for any sensibly implemented client. However, there are two exceptions to this: periodic metadata updates and telemetry sending. It seems the Java client sends these over the same connection it uses for produce requests. A way to handle this must be found. Potentially, another inelegant hack will be needed. For example, these requests may be handled in parallel with Inkless produce. 2. Connection muting is a back pressure mechanism. By disabling it, we open the broker to all sorts of overload and QoS perils. A correct end-to-end back pressure mechanism must be implemented for Inkless produce. 3. Connection muting is also used for client quotas. This must also be taken into account. --- .../scala/kafka/network/InklessSendQueue.java | 53 ++++++++++++ .../scala/kafka/network/SocketServer.scala | 81 +++++++++++++++---- .../scala/kafka/server/BrokerServer.scala | 36 +++++---- .../main/scala/kafka/server/KafkaApis.scala | 10 ++- .../scala/kafka/server/ReplicaManager.scala | 16 +++- .../io/aiven/inkless/common/SharedState.java | 10 ++- .../InklessConnectionUpgradeTracker.java | 51 ++++++++++++ 7 files changed, 219 insertions(+), 38 deletions(-) create mode 100644 core/src/main/scala/kafka/network/InklessSendQueue.java create mode 100644 storage/inkless/src/main/java/io/aiven/inkless/network/InklessConnectionUpgradeTracker.java diff --git a/core/src/main/scala/kafka/network/InklessSendQueue.java b/core/src/main/scala/kafka/network/InklessSendQueue.java new file mode 100644 index 0000000000..282a8860c1 --- /dev/null +++ b/core/src/main/scala/kafka/network/InklessSendQueue.java @@ -0,0 +1,53 @@ +// Copyright (c) 2025 Aiven, Helsinki, Finland. https://aiven.io/ +package kafka.network; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Comparator; +import java.util.PriorityQueue; + +/** + * The response queue for a connection.
+ * + *

This queue arranges responses by their correlation ID, expecting no gaps and the strict order. + */ +class InklessSendQueue { + private static final Logger LOGGER = LoggerFactory.getLogger(InklessSendQueue.class); + + private static final Comparator CORRELATION_ID_COMPARATOR = + Comparator.comparing((RequestChannel.SendResponse r) -> r.request().header().correlationId()); + private final PriorityQueue queue = new PriorityQueue<>(CORRELATION_ID_COMPARATOR); + private int nextCorrelationId; + + InklessSendQueue(final int startCorrelationId) { +// LOGGER.info("Starting with correlation ID {}", startCorrelationId); + this.nextCorrelationId = startCorrelationId; + } + + void add(final RequestChannel.SendResponse response) { +// LOGGER.info("Adding response with correlation ID {}", response.request().header().correlationId()); + if (response.request().header().correlationId() < nextCorrelationId) { + throw new IllegalStateException("Expected min correlation ID " + nextCorrelationId); + } + queue.add(response); + } + + boolean nextReady() { + final RequestChannel.SendResponse peeked = queue.peek(); + if (peeked == null) { + return false; + } + final int correlationId = peeked.request().header().correlationId(); +// LOGGER.info("Peeked correlation ID {}, expecting {}", peeked.request().header().correlationId(), nextCorrelationId); + return correlationId == nextCorrelationId; + } + + RequestChannel.SendResponse take() { + if (!nextReady()) { + throw new IllegalStateException(); + } + nextCorrelationId += 1; + return queue.remove(); + } +} diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 093f5298da..dc65fffc3d 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -17,6 +17,8 @@ package kafka.network +import io.aiven.inkless.network.InklessConnectionUpgradeTracker + import java.io.IOException import java.net._ import 
java.nio.ByteBuffer @@ -37,7 +39,7 @@ import org.apache.kafka.common.errors.InvalidRequestException import org.apache.kafka.common.memory.{MemoryPool, SimpleMemoryPool} import org.apache.kafka.common.metrics._ import org.apache.kafka.common.metrics.stats.{Avg, CumulativeSum, Meter, Rate} -import org.apache.kafka.common.network.KafkaChannel.ChannelMuteEvent +import org.apache.kafka.common.network.KafkaChannel.{ChannelMuteEvent, ChannelMuteState} import org.apache.kafka.common.network.{ChannelBuilder, ChannelBuilders, ClientInformation, KafkaChannel, ListenerName, ListenerReconfigurable, NetworkSend, Selectable, Send, ServerConnectionId, Selector => KSelector} import org.apache.kafka.common.protocol.ApiKeys import org.apache.kafka.common.requests.{ApiVersionsRequest, RequestContext, RequestHeader} @@ -84,7 +86,8 @@ class SocketServer( val credentialProvider: CredentialProvider, val apiVersionManager: ApiVersionManager, val socketFactory: ServerSocketFactory = ServerSocketFactory.INSTANCE, - val connectionDisconnectListeners: Seq[ConnectionDisconnectListener] = Seq.empty + val connectionDisconnectListeners: Seq[ConnectionDisconnectListener] = Seq.empty, + val inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None ) extends Logging with BrokerReconfigurable { private val metricsGroup = new KafkaMetricsGroup(this.getClass) @@ -274,11 +277,11 @@ class SocketServer( private def endpoints = config.listeners.map(l => l.listenerName -> l).toMap protected def createDataPlaneAcceptor(endPoint: EndPoint, isPrivilegedListener: Boolean, requestChannel: RequestChannel): DataPlaneAcceptor = { - new DataPlaneAcceptor(this, endPoint, config, nodeId, connectionQuotas, time, isPrivilegedListener, requestChannel, metrics, credentialProvider, logContext, memoryPool, apiVersionManager) + new DataPlaneAcceptor(this, endPoint, config, nodeId, connectionQuotas, time, isPrivilegedListener, requestChannel, metrics, credentialProvider, logContext, memoryPool, 
apiVersionManager, inklessConnectionUpgradeTracker) } private def createControlPlaneAcceptor(endPoint: EndPoint, requestChannel: RequestChannel): ControlPlaneAcceptor = { - new ControlPlaneAcceptor(this, endPoint, config, nodeId, connectionQuotas, time, requestChannel, metrics, credentialProvider, logContext, memoryPool, apiVersionManager) + new ControlPlaneAcceptor(this, endPoint, config, nodeId, connectionQuotas, time, requestChannel, metrics, credentialProvider, logContext, memoryPool, apiVersionManager, inklessConnectionUpgradeTracker) } /** @@ -442,7 +445,8 @@ class DataPlaneAcceptor(socketServer: SocketServer, credentialProvider: CredentialProvider, logContext: LogContext, memoryPool: MemoryPool, - apiVersionManager: ApiVersionManager) + apiVersionManager: ApiVersionManager, + inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None) extends Acceptor(socketServer, endPoint, config, @@ -455,7 +459,8 @@ class DataPlaneAcceptor(socketServer: SocketServer, credentialProvider, logContext, memoryPool, - apiVersionManager) with ListenerReconfigurable { + apiVersionManager, + inklessConnectionUpgradeTracker) with ListenerReconfigurable { override def metricPrefix(): String = DataPlaneAcceptor.MetricPrefix override def threadPrefix(): String = DataPlaneAcceptor.ThreadPrefix @@ -544,7 +549,8 @@ class ControlPlaneAcceptor(socketServer: SocketServer, credentialProvider: CredentialProvider, logContext: LogContext, memoryPool: MemoryPool, - apiVersionManager: ApiVersionManager) + apiVersionManager: ApiVersionManager, + inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None) extends Acceptor(socketServer, endPoint, config, @@ -557,7 +563,8 @@ class ControlPlaneAcceptor(socketServer: SocketServer, credentialProvider, logContext, memoryPool, - apiVersionManager) { + apiVersionManager, + inklessConnectionUpgradeTracker) { override def metricPrefix(): String = ControlPlaneAcceptor.MetricPrefix override def threadPrefix(): String = 
ControlPlaneAcceptor.ThreadPrefix @@ -579,7 +586,8 @@ private[kafka] abstract class Acceptor(val socketServer: SocketServer, credentialProvider: CredentialProvider, logContext: LogContext, memoryPool: MemoryPool, - apiVersionManager: ApiVersionManager) + apiVersionManager: ApiVersionManager, + inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None) extends Runnable with Logging { private val metricsGroup = new KafkaMetricsGroup(this.getClass) @@ -879,7 +887,8 @@ private[kafka] abstract class Acceptor(val socketServer: SocketServer, isPrivilegedListener, apiVersionManager, name, - connectionDisconnectListeners) + connectionDisconnectListeners, + inklessConnectionUpgradeTracker) } } @@ -919,7 +928,8 @@ private[kafka] class Processor( isPrivilegedListener: Boolean, apiVersionManager: ApiVersionManager, threadName: String, - connectionDisconnectListeners: Seq[ConnectionDisconnectListener] + connectionDisconnectListeners: Seq[ConnectionDisconnectListener], + inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None ) extends Runnable with Logging { private val metricsGroup = new KafkaMetricsGroup(this.getClass) @@ -990,6 +1000,8 @@ private[kafka] class Processor( // closed, connection ids are not reused while requests from the closed connection are being processed. 
private var nextConnectionIndex = 0 + private var inklessSendQueueByChannel: immutable.Map[String, InklessSendQueue] = immutable.Map.empty + override def run(): Unit = { try { while (shouldRun.get()) { @@ -1052,7 +1064,17 @@ private[kafka] class Processor( tryUnmuteChannel(channelId) case response: SendResponse => - sendResponse(response, response.responseSend) + val connectionId = response.request.context.connectionId + val upgradedConnection = inklessConnectionUpgradeTracker.nonEmpty && inklessConnectionUpgradeTracker.get.isConnectionUpgraded(connectionId) + if (upgradedConnection) { + val connectionId = response.request.context.connectionId + if (!inklessSendQueueByChannel.contains(connectionId)) { + inklessSendQueueByChannel += connectionId -> new InklessSendQueue(inklessConnectionUpgradeTracker.get.upgradeCorrelationId(connectionId)) + } + inklessSendQueueByChannel(connectionId).add(response) + } else { + sendResponse(response, response.responseSend) + } case response: CloseConnectionResponse => updateRequestMetrics(response) trace("Closing socket connection actively according to the response code.") @@ -1072,6 +1094,20 @@ private[kafka] class Processor( processChannelException(channelId, s"Exception while processing response for $channelId", e) } } + + // Process responses for Inkless upgraded connections. 
+ for ((connectionId, queue) <- inklessSendQueueByChannel) { + openOrClosingChannel(connectionId) match { + case Some(channel) => + if (queue.nextReady() && !channel.hasSend) { + val response = queue.take() + sendResponse(response, response.responseSend) + } + + case None => + inklessSendQueueByChannel -= connectionId + } + } } // `protected` for test usage @@ -1148,8 +1184,11 @@ private[kafka] class Processor( } } requestChannel.sendRequest(req) - selector.mute(connectionId) - handleChannelMuteEvent(connectionId, ChannelMuteEvent.REQUEST_RECEIVED) + val upgradedConnection = inklessConnectionUpgradeTracker.nonEmpty && inklessConnectionUpgradeTracker.get.isConnectionUpgraded(connectionId) + if (!upgradedConnection) { + selector.mute(connectionId) + handleChannelMuteEvent(connectionId, ChannelMuteEvent.REQUEST_RECEIVED) + } } } case None => @@ -1181,7 +1220,19 @@ private[kafka] class Processor( // Try unmuting the channel. If there was no quota violation and the channel has not been throttled, // it will be unmuted immediately. If the channel has been throttled, it will unmuted only if the throttling // delay has already passed by now. - handleChannelMuteEvent(send.destinationId, ChannelMuteEvent.RESPONSE_SENT) + val connectionId = response.request.context.connectionId() + val upgradedConnection = inklessConnectionUpgradeTracker.nonEmpty && inklessConnectionUpgradeTracker.get.isConnectionUpgraded(connectionId) + if (upgradedConnection) { + openOrClosingChannel(connectionId).foreach{ channel => + // Imitate muting to prevent illegal state errors when `tryUnmuteChannel` is called. 
+ if (channel.muteState() == ChannelMuteState.MUTED_AND_RESPONSE_PENDING) { + handleChannelMuteEvent(send.destinationId, ChannelMuteEvent.RESPONSE_SENT) + } + selector.mute(channel.id) + } + } else { + handleChannelMuteEvent(send.destinationId, ChannelMuteEvent.RESPONSE_SENT) + } tryUnmuteChannel(send.destinationId) } catch { case e: Throwable => processChannelException(send.destinationId, diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala b/core/src/main/scala/kafka/server/BrokerServer.scala index d684041d95..15e67babad 100644 --- a/core/src/main/scala/kafka/server/BrokerServer.scala +++ b/core/src/main/scala/kafka/server/BrokerServer.scala @@ -18,6 +18,7 @@ package kafka.server import io.aiven.inkless.common.SharedState +import io.aiven.inkless.network.InklessConnectionUpgradeTracker import kafka.coordinator.group.{CoordinatorLoaderImpl, CoordinatorPartitionWriter, GroupCoordinatorAdapter} import kafka.coordinator.transaction.TransactionCoordinator import kafka.log.LogManager @@ -262,6 +263,23 @@ class BrokerServer( Some(clientMetricsManager) ) + val inklessMetadataView = new InklessMetadataView(metadataCache) + val inklessConnectionUpgradeTracker = new InklessConnectionUpgradeTracker(inklessMetadataView) + val inklessSharedState = sharedServer.inklessControlPlane.map { controlPlane => + SharedState.initialize( + time, + clusterId, + config.rack.orNull, + config.brokerId, + config.inklessConfig, + inklessMetadataView, + controlPlane, + brokerTopicStats, + () => logManager.currentDefaultConfig, + inklessConnectionUpgradeTracker + ) + } + val connectionDisconnectListeners = Seq(clientMetricsManager.connectionDisconnectListener()) // Create and start the socket server acceptor threads so that the bound port is known. 
// Delay starting processors until the end of the initialization sequence to ensure @@ -272,7 +290,8 @@ class BrokerServer( credentialProvider, apiVersionManager, sharedServer.socketFactory, - connectionDisconnectListeners) + connectionDisconnectListeners, + Some(inklessConnectionUpgradeTracker)) clientQuotaMetadataManager = new ClientQuotaMetadataManager(quotaManagers, socketServer.connectionQuotas) @@ -335,21 +354,6 @@ class BrokerServer( */ val defaultActionQueue = new DelayedActionQueue - val inklessMetadataView = new InklessMetadataView(metadataCache) - val inklessSharedState = sharedServer.inklessControlPlane.map { controlPlane => - SharedState.initialize( - time, - clusterId, - config.rack.orNull, - config.brokerId, - config.inklessConfig, - inklessMetadataView, - controlPlane, - brokerTopicStats, - () => logManager.currentDefaultConfig - ) - } - this._replicaManager = new ReplicaManager( config = config, metrics = metrics, diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index b0541bc983..9d535d638f 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -189,6 +189,13 @@ class KafkaApis(val requestChannel: RequestChannel, requestHelper.handleError(request, e) } + if (inklessSharedState.forall(st => st.inklessConnectionUpgradeTracker().isConnectionUpgraded(request.context.connectionId())) + && request.header.apiKey != ApiKeys.PRODUCE) { + logger.error("Received unexpected request with API key {} in Inkless-upgraded connection", request.header.apiKey) + handleError(new RuntimeException()) + return + } + try { trace(s"Handling request:${request.requestDesc(true)} from connection ${request.context.connectionId};" + s"securityProtocol:${request.context.securityProtocol},principal:${request.context.principal}") @@ -757,7 +764,8 @@ class KafkaApis(val requestChannel: RequestChannel, responseCallback = sendResponseCallback, 
recordValidationStatsCallback = processingStatsCallback, requestLocal = requestLocal, - transactionSupportedOperation = transactionSupportedOperation) + transactionSupportedOperation = transactionSupportedOperation, + connectionIdAndCorrelationId = Some((request.context.connectionId(), request.header.correlationId()))) // if the request is put into the purgatory, it will have a held reference and hence cannot be garbage collected; // hence we clear its data here in order to let GC reclaim its memory since it is already appended to log diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index f8ca7f8ae0..55c4c71729 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -21,6 +21,7 @@ import io.aiven.inkless.common.SharedState import io.aiven.inkless.consume.{FetchInterceptor, FetchOffsetInterceptor} import io.aiven.inkless.delete.{DeleteRecordsInterceptor, FileCleaner} import io.aiven.inkless.merge.FileMerger +import io.aiven.inkless.network.InklessConnectionUpgradeTracker import io.aiven.inkless.produce.AppendInterceptor import kafka.cluster.{Partition, PartitionListener} import kafka.controller.{KafkaController, StateChangeLogger} @@ -332,6 +333,7 @@ class ReplicaManager(val config: KafkaConfig, private val inklessDeleteRecordsInterceptor: Option[DeleteRecordsInterceptor] = inklessSharedState.map(new DeleteRecordsInterceptor(_)) private val inklessFileCleaner: Option[FileCleaner] = inklessSharedState.map(new FileCleaner(_)) private val inklessFileMerger: Option[FileMerger] = inklessSharedState.map(new FileMerger(_)) + private val inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = inklessSharedState.map(_.inklessConnectionUpgradeTracker()) /* epoch of the controller that last changed the leader */ @volatile private[server] var controllerEpoch: Int = KafkaController.InitialControllerEpoch @@ -828,13 
+830,19 @@ class ReplicaManager(val config: KafkaConfig, recordValidationStatsCallback: Map[TopicPartition, RecordValidationStats] => Unit = _ => (), requestLocal: RequestLocal = RequestLocal.noCaching, actionQueue: ActionQueue = this.defaultActionQueue, - verificationGuards: Map[TopicPartition, VerificationGuard] = Map.empty): Unit = { + verificationGuards: Map[TopicPartition, VerificationGuard] = Map.empty, + connectionIdAndCorrelationId: Option[(String, Int)] = None): Unit = { if (!isValidRequiredAcks(requiredAcks)) { sendInvalidRequiredAcksResponse(entriesPerPartition, responseCallback) return } if (inklessAppendInterceptor.exists(_.intercept(entriesPerPartition.asJava, r => responseCallback(r.asScala)))) { + inklessConnectionUpgradeTracker.foreach(tr => + connectionIdAndCorrelationId.foreach { case (connectionId, correlationId) => + tr.upgradeConnection(connectionId, correlationId) + } + ) return } @@ -890,7 +898,8 @@ class ReplicaManager(val config: KafkaConfig, recordValidationStatsCallback: Map[TopicPartition, RecordValidationStats] => Unit = _ => (), requestLocal: RequestLocal = RequestLocal.noCaching, actionQueue: ActionQueue = this.defaultActionQueue, - transactionSupportedOperation: TransactionSupportedOperation): Unit = { + transactionSupportedOperation: TransactionSupportedOperation, + connectionIdAndCorrelationId: Option[(String, Int)] = None): Unit = { val transactionalProducerInfo = mutable.HashSet[(Long, Short)]() val topicPartitionBatchInfo = mutable.Map[TopicPartition, Int]() @@ -949,7 +958,8 @@ class ReplicaManager(val config: KafkaConfig, recordValidationStatsCallback = recordValidationStatsCallback, requestLocal = newRequestLocal, actionQueue = actionQueue, - verificationGuards = verificationGuards + verificationGuards = verificationGuards, + connectionIdAndCorrelationId = connectionIdAndCorrelationId, ) } diff --git a/storage/inkless/src/main/java/io/aiven/inkless/common/SharedState.java 
b/storage/inkless/src/main/java/io/aiven/inkless/common/SharedState.java index 47af56b99f..d3a8614cd7 100644 --- a/storage/inkless/src/main/java/io/aiven/inkless/common/SharedState.java +++ b/storage/inkless/src/main/java/io/aiven/inkless/common/SharedState.java @@ -16,6 +16,7 @@ import io.aiven.inkless.config.InklessConfig; import io.aiven.inkless.control_plane.ControlPlane; import io.aiven.inkless.control_plane.MetadataView; +import io.aiven.inkless.network.InklessConnectionUpgradeTracker; import io.aiven.inkless.storage_backend.common.StorageBackend; public record SharedState( @@ -29,7 +30,8 @@ public record SharedState( KeyAlignmentStrategy keyAlignmentStrategy, ObjectCache cache, BrokerTopicStats brokerTopicStats, - Supplier defaultTopicConfigs + Supplier defaultTopicConfigs, + InklessConnectionUpgradeTracker inklessConnectionUpgradeTracker ) implements Closeable { public static SharedState initialize( @@ -41,7 +43,8 @@ public static SharedState initialize( MetadataView metadata, ControlPlane controlPlane, BrokerTopicStats brokerTopicStats, - Supplier defaultTopicConfigs + Supplier defaultTopicConfigs, + InklessConnectionUpgradeTracker inklessConnectionUpgradeTracker ) { return new SharedState( time, @@ -54,7 +57,8 @@ public static SharedState initialize( new FixedBlockAlignment(config.fetchCacheBlockBytes()), new InfinispanCache(time, clusterId, rack), brokerTopicStats, - defaultTopicConfigs + defaultTopicConfigs, + inklessConnectionUpgradeTracker ); } diff --git a/storage/inkless/src/main/java/io/aiven/inkless/network/InklessConnectionUpgradeTracker.java b/storage/inkless/src/main/java/io/aiven/inkless/network/InklessConnectionUpgradeTracker.java new file mode 100644 index 0000000000..c3ddd7b8fd --- /dev/null +++ b/storage/inkless/src/main/java/io/aiven/inkless/network/InklessConnectionUpgradeTracker.java @@ -0,0 +1,51 @@ +// Copyright (c) 2025 Aiven, Helsinki, Finland. 
https://aiven.io/ +package io.aiven.inkless.network; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.ConcurrentHashMap; + +import io.aiven.inkless.control_plane.MetadataView; + +/** + * Tracks whether a connection is Inkless-upgraded. + * + *

For each upgraded connection, remembers the correlation ID at the moment of upgrade (used later for response queueing). + */ +public class InklessConnectionUpgradeTracker { + private static final Logger LOGGER = LoggerFactory.getLogger(InklessConnectionUpgradeTracker.class); + + private final MetadataView metadataView; + // Key: connection ID, value: correlation ID at the moment of upgrade + private final ConcurrentHashMap upgradedConnection = new ConcurrentHashMap<>(); + + public InklessConnectionUpgradeTracker(final MetadataView metadataView) { + this.metadataView = metadataView; + } + + public boolean isConnectionUpgraded(final String connectionId) { + return upgradedConnection.containsKey(connectionId); + } + + public int upgradeCorrelationId(final String connectionId) { + final Integer result = upgradedConnection.get(connectionId); + if (result != null) { + return result; + } else { + throw new IllegalStateException(); + } + } + + public void upgradeConnection(final String connectionId, final int correlationId) { + upgradedConnection.computeIfAbsent(connectionId, ignore -> { + LOGGER.info("Upgrading connection {}", connectionId); + return correlationId; + }); + } + + public void closeConnection(final String connectionId) { + // TODO call + upgradedConnection.remove(connectionId); + } +}