diff --git a/core/src/main/scala/kafka/network/InklessSendQueue.java b/core/src/main/scala/kafka/network/InklessSendQueue.java
new file mode 100644
index 0000000000..282a8860c1
--- /dev/null
+++ b/core/src/main/scala/kafka/network/InklessSendQueue.java
@@ -0,0 +1,61 @@
+// Copyright (c) 2025 Aiven, Helsinki, Finland. https://aiven.io/
+package kafka.network;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Comparator;
+import java.util.PriorityQueue;
+
+/**
+ * The response queue for a single connection.
+ *
+ * <p>Responses are ordered by their correlation ID. Correlation IDs are expected to arrive
+ * without gaps, and responses are released strictly in correlation-ID order.
+ *
+ * <p>Not thread-safe: intended to be used from a single processor thread only.
+ */
+class InklessSendQueue {
+    private static final Logger LOGGER = LoggerFactory.getLogger(InklessSendQueue.class);
+
+    private static final Comparator<RequestChannel.SendResponse> CORRELATION_ID_COMPARATOR =
+        Comparator.comparingInt((RequestChannel.SendResponse r) -> r.request().header().correlationId());
+    private final PriorityQueue<RequestChannel.SendResponse> queue = new PriorityQueue<>(CORRELATION_ID_COMPARATOR);
+    // The correlation ID of the next response allowed to leave the queue.
+    private int nextCorrelationId;
+
+    InklessSendQueue(final int startCorrelationId) {
+        LOGGER.debug("Starting with correlation ID {}", startCorrelationId);
+        this.nextCorrelationId = startCorrelationId;
+    }
+
+    /** Enqueues a response; its correlation ID must not be below the next expected one. */
+    void add(final RequestChannel.SendResponse response) {
+        final int correlationId = response.request().header().correlationId();
+        LOGGER.debug("Adding response with correlation ID {}", correlationId);
+        if (correlationId < nextCorrelationId) {
+            throw new IllegalStateException(
+                "Correlation ID " + correlationId + " is below the expected minimum " + nextCorrelationId);
+        }
+        queue.add(response);
+    }
+
+    /** Returns true if the head of the queue carries exactly the next expected correlation ID. */
+    boolean nextReady() {
+        final RequestChannel.SendResponse peeked = queue.peek();
+        if (peeked == null) {
+            return false;
+        }
+        final int correlationId = peeked.request().header().correlationId();
+        LOGGER.debug("Peeked correlation ID {}, expecting {}", correlationId, nextCorrelationId);
+        return correlationId == nextCorrelationId;
+    }
+
+    /** Removes and returns the next response; throws if it is not ready yet. */
+    RequestChannel.SendResponse take() {
+        if (!nextReady()) {
+            throw new IllegalStateException("Response with correlation ID " + nextCorrelationId + " is not ready");
+        }
+        nextCorrelationId += 1;
+        return queue.remove();
+    }
+}
diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala
index 093f5298da..dc65fffc3d 100644
--- a/core/src/main/scala/kafka/network/SocketServer.scala
+++ b/core/src/main/scala/kafka/network/SocketServer.scala
@@ -17,6 +17,8 @@ package kafka.network
+import io.aiven.inkless.network.InklessConnectionUpgradeTracker
+
 import java.io.IOException
 import java.net._
 import 
java.nio.ByteBuffer @@ -37,7 +39,7 @@ import org.apache.kafka.common.errors.InvalidRequestException import org.apache.kafka.common.memory.{MemoryPool, SimpleMemoryPool} import org.apache.kafka.common.metrics._ import org.apache.kafka.common.metrics.stats.{Avg, CumulativeSum, Meter, Rate} -import org.apache.kafka.common.network.KafkaChannel.ChannelMuteEvent +import org.apache.kafka.common.network.KafkaChannel.{ChannelMuteEvent, ChannelMuteState} import org.apache.kafka.common.network.{ChannelBuilder, ChannelBuilders, ClientInformation, KafkaChannel, ListenerName, ListenerReconfigurable, NetworkSend, Selectable, Send, ServerConnectionId, Selector => KSelector} import org.apache.kafka.common.protocol.ApiKeys import org.apache.kafka.common.requests.{ApiVersionsRequest, RequestContext, RequestHeader} @@ -84,7 +86,8 @@ class SocketServer( val credentialProvider: CredentialProvider, val apiVersionManager: ApiVersionManager, val socketFactory: ServerSocketFactory = ServerSocketFactory.INSTANCE, - val connectionDisconnectListeners: Seq[ConnectionDisconnectListener] = Seq.empty + val connectionDisconnectListeners: Seq[ConnectionDisconnectListener] = Seq.empty, + val inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None ) extends Logging with BrokerReconfigurable { private val metricsGroup = new KafkaMetricsGroup(this.getClass) @@ -274,11 +277,11 @@ class SocketServer( private def endpoints = config.listeners.map(l => l.listenerName -> l).toMap protected def createDataPlaneAcceptor(endPoint: EndPoint, isPrivilegedListener: Boolean, requestChannel: RequestChannel): DataPlaneAcceptor = { - new DataPlaneAcceptor(this, endPoint, config, nodeId, connectionQuotas, time, isPrivilegedListener, requestChannel, metrics, credentialProvider, logContext, memoryPool, apiVersionManager) + new DataPlaneAcceptor(this, endPoint, config, nodeId, connectionQuotas, time, isPrivilegedListener, requestChannel, metrics, credentialProvider, logContext, memoryPool, 
apiVersionManager, inklessConnectionUpgradeTracker) } private def createControlPlaneAcceptor(endPoint: EndPoint, requestChannel: RequestChannel): ControlPlaneAcceptor = { - new ControlPlaneAcceptor(this, endPoint, config, nodeId, connectionQuotas, time, requestChannel, metrics, credentialProvider, logContext, memoryPool, apiVersionManager) + new ControlPlaneAcceptor(this, endPoint, config, nodeId, connectionQuotas, time, requestChannel, metrics, credentialProvider, logContext, memoryPool, apiVersionManager, inklessConnectionUpgradeTracker) } /** @@ -442,7 +445,8 @@ class DataPlaneAcceptor(socketServer: SocketServer, credentialProvider: CredentialProvider, logContext: LogContext, memoryPool: MemoryPool, - apiVersionManager: ApiVersionManager) + apiVersionManager: ApiVersionManager, + inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None) extends Acceptor(socketServer, endPoint, config, @@ -455,7 +459,8 @@ class DataPlaneAcceptor(socketServer: SocketServer, credentialProvider, logContext, memoryPool, - apiVersionManager) with ListenerReconfigurable { + apiVersionManager, + inklessConnectionUpgradeTracker) with ListenerReconfigurable { override def metricPrefix(): String = DataPlaneAcceptor.MetricPrefix override def threadPrefix(): String = DataPlaneAcceptor.ThreadPrefix @@ -544,7 +549,8 @@ class ControlPlaneAcceptor(socketServer: SocketServer, credentialProvider: CredentialProvider, logContext: LogContext, memoryPool: MemoryPool, - apiVersionManager: ApiVersionManager) + apiVersionManager: ApiVersionManager, + inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None) extends Acceptor(socketServer, endPoint, config, @@ -557,7 +563,8 @@ class ControlPlaneAcceptor(socketServer: SocketServer, credentialProvider, logContext, memoryPool, - apiVersionManager) { + apiVersionManager, + inklessConnectionUpgradeTracker) { override def metricPrefix(): String = ControlPlaneAcceptor.MetricPrefix override def threadPrefix(): String = 
ControlPlaneAcceptor.ThreadPrefix @@ -579,7 +586,8 @@ private[kafka] abstract class Acceptor(val socketServer: SocketServer, credentialProvider: CredentialProvider, logContext: LogContext, memoryPool: MemoryPool, - apiVersionManager: ApiVersionManager) + apiVersionManager: ApiVersionManager, + inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None) extends Runnable with Logging { private val metricsGroup = new KafkaMetricsGroup(this.getClass) @@ -879,7 +887,8 @@ private[kafka] abstract class Acceptor(val socketServer: SocketServer, isPrivilegedListener, apiVersionManager, name, - connectionDisconnectListeners) + connectionDisconnectListeners, + inklessConnectionUpgradeTracker) } } @@ -919,7 +928,8 @@ private[kafka] class Processor( isPrivilegedListener: Boolean, apiVersionManager: ApiVersionManager, threadName: String, - connectionDisconnectListeners: Seq[ConnectionDisconnectListener] + connectionDisconnectListeners: Seq[ConnectionDisconnectListener], + inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = None ) extends Runnable with Logging { private val metricsGroup = new KafkaMetricsGroup(this.getClass) @@ -990,6 +1000,8 @@ private[kafka] class Processor( // closed, connection ids are not reused while requests from the closed connection are being processed. 
private var nextConnectionIndex = 0 + private var inklessSendQueueByChannel: immutable.Map[String, InklessSendQueue] = immutable.Map.empty + override def run(): Unit = { try { while (shouldRun.get()) { @@ -1052,7 +1064,17 @@ private[kafka] class Processor( tryUnmuteChannel(channelId) case response: SendResponse => - sendResponse(response, response.responseSend) + val connectionId = response.request.context.connectionId + val upgradedConnection = inklessConnectionUpgradeTracker.nonEmpty && inklessConnectionUpgradeTracker.get.isConnectionUpgraded(connectionId) + if (upgradedConnection) { + val connectionId = response.request.context.connectionId + if (!inklessSendQueueByChannel.contains(connectionId)) { + inklessSendQueueByChannel += connectionId -> new InklessSendQueue(inklessConnectionUpgradeTracker.get.upgradeCorrelationId(connectionId)) + } + inklessSendQueueByChannel(connectionId).add(response) + } else { + sendResponse(response, response.responseSend) + } case response: CloseConnectionResponse => updateRequestMetrics(response) trace("Closing socket connection actively according to the response code.") @@ -1072,6 +1094,20 @@ private[kafka] class Processor( processChannelException(channelId, s"Exception while processing response for $channelId", e) } } + + // Process responses for Inkless upgraded connections. 
+ for ((connectionId, queue) <- inklessSendQueueByChannel) { + openOrClosingChannel(connectionId) match { + case Some(channel) => + if (queue.nextReady() && !channel.hasSend) { + val response = queue.take() + sendResponse(response, response.responseSend) + } + + case None => + inklessSendQueueByChannel -= connectionId + } + } } // `protected` for test usage @@ -1148,8 +1184,11 @@ private[kafka] class Processor( } } requestChannel.sendRequest(req) - selector.mute(connectionId) - handleChannelMuteEvent(connectionId, ChannelMuteEvent.REQUEST_RECEIVED) + val upgradedConnection = inklessConnectionUpgradeTracker.nonEmpty && inklessConnectionUpgradeTracker.get.isConnectionUpgraded(connectionId) + if (!upgradedConnection) { + selector.mute(connectionId) + handleChannelMuteEvent(connectionId, ChannelMuteEvent.REQUEST_RECEIVED) + } } } case None => @@ -1181,7 +1220,19 @@ private[kafka] class Processor( // Try unmuting the channel. If there was no quota violation and the channel has not been throttled, // it will be unmuted immediately. If the channel has been throttled, it will unmuted only if the throttling // delay has already passed by now. - handleChannelMuteEvent(send.destinationId, ChannelMuteEvent.RESPONSE_SENT) + val connectionId = response.request.context.connectionId() + val upgradedConnection = inklessConnectionUpgradeTracker.nonEmpty && inklessConnectionUpgradeTracker.get.isConnectionUpgraded(connectionId) + if (upgradedConnection) { + openOrClosingChannel(connectionId).foreach{ channel => + // Imitate muting to prevent illegal state errors when `tryUnmuteChannel` is called. 
+ if (channel.muteState() == ChannelMuteState.MUTED_AND_RESPONSE_PENDING) { + handleChannelMuteEvent(send.destinationId, ChannelMuteEvent.RESPONSE_SENT) + } + selector.mute(channel.id) + } + } else { + handleChannelMuteEvent(send.destinationId, ChannelMuteEvent.RESPONSE_SENT) + } tryUnmuteChannel(send.destinationId) } catch { case e: Throwable => processChannelException(send.destinationId, diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala b/core/src/main/scala/kafka/server/BrokerServer.scala index d684041d95..15e67babad 100644 --- a/core/src/main/scala/kafka/server/BrokerServer.scala +++ b/core/src/main/scala/kafka/server/BrokerServer.scala @@ -18,6 +18,7 @@ package kafka.server import io.aiven.inkless.common.SharedState +import io.aiven.inkless.network.InklessConnectionUpgradeTracker import kafka.coordinator.group.{CoordinatorLoaderImpl, CoordinatorPartitionWriter, GroupCoordinatorAdapter} import kafka.coordinator.transaction.TransactionCoordinator import kafka.log.LogManager @@ -262,6 +263,23 @@ class BrokerServer( Some(clientMetricsManager) ) + val inklessMetadataView = new InklessMetadataView(metadataCache) + val inklessConnectionUpgradeTracker = new InklessConnectionUpgradeTracker(inklessMetadataView) + val inklessSharedState = sharedServer.inklessControlPlane.map { controlPlane => + SharedState.initialize( + time, + clusterId, + config.rack.orNull, + config.brokerId, + config.inklessConfig, + inklessMetadataView, + controlPlane, + brokerTopicStats, + () => logManager.currentDefaultConfig, + inklessConnectionUpgradeTracker + ) + } + val connectionDisconnectListeners = Seq(clientMetricsManager.connectionDisconnectListener()) // Create and start the socket server acceptor threads so that the bound port is known. 
// Delay starting processors until the end of the initialization sequence to ensure
@@ -272,7 +290,8 @@
       credentialProvider,
       apiVersionManager,
       sharedServer.socketFactory,
-      connectionDisconnectListeners)
+      connectionDisconnectListeners,
+      Some(inklessConnectionUpgradeTracker))
 
     clientQuotaMetadataManager = new ClientQuotaMetadataManager(quotaManagers, socketServer.connectionQuotas)
 
@@ -335,21 +354,6 @@
      */
     val defaultActionQueue = new DelayedActionQueue
 
-    val inklessMetadataView = new InklessMetadataView(metadataCache)
-    val inklessSharedState = sharedServer.inklessControlPlane.map { controlPlane =>
-      SharedState.initialize(
-        time,
-        clusterId,
-        config.rack.orNull,
-        config.brokerId,
-        config.inklessConfig,
-        inklessMetadataView,
-        controlPlane,
-        brokerTopicStats,
-        () => logManager.currentDefaultConfig
-      )
-    }
-
     this._replicaManager = new ReplicaManager(
       config = config,
       metrics = metrics,
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index b0541bc983..9d535d638f 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -189,6 +189,15 @@
       requestHelper.handleError(request, e)
     }
 
+    // BUG FIX: this guard used `forall`, which is vacuously true when inklessSharedState is None,
+    // so every non-PRODUCE request was rejected on brokers with Inkless disabled; `exists` is required.
+    if (inklessSharedState.exists(st => st.inklessConnectionUpgradeTracker().isConnectionUpgraded(request.context.connectionId()))
+      && request.header.apiKey != ApiKeys.PRODUCE) {
+      logger.error("Received unexpected request with API key {} in Inkless-upgraded connection", request.header.apiKey)
+      handleError(new RuntimeException(s"Unexpected API key ${request.header.apiKey} on an Inkless-upgraded connection"))
+      return
+    }
+
     try {
       trace(s"Handling request:${request.requestDesc(true)} from connection ${request.context.connectionId};" +
         s"securityProtocol:${request.context.securityProtocol},principal:${request.context.principal}")
@@ -757,7 +766,8 @@
       responseCallback = sendResponseCallback, 
recordValidationStatsCallback = processingStatsCallback, requestLocal = requestLocal, - transactionSupportedOperation = transactionSupportedOperation) + transactionSupportedOperation = transactionSupportedOperation, + connectionIdAndCorrelationId = Some((request.context.connectionId(), request.header.correlationId()))) // if the request is put into the purgatory, it will have a held reference and hence cannot be garbage collected; // hence we clear its data here in order to let GC reclaim its memory since it is already appended to log diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index f8ca7f8ae0..55c4c71729 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -21,6 +21,7 @@ import io.aiven.inkless.common.SharedState import io.aiven.inkless.consume.{FetchInterceptor, FetchOffsetInterceptor} import io.aiven.inkless.delete.{DeleteRecordsInterceptor, FileCleaner} import io.aiven.inkless.merge.FileMerger +import io.aiven.inkless.network.InklessConnectionUpgradeTracker import io.aiven.inkless.produce.AppendInterceptor import kafka.cluster.{Partition, PartitionListener} import kafka.controller.{KafkaController, StateChangeLogger} @@ -332,6 +333,7 @@ class ReplicaManager(val config: KafkaConfig, private val inklessDeleteRecordsInterceptor: Option[DeleteRecordsInterceptor] = inklessSharedState.map(new DeleteRecordsInterceptor(_)) private val inklessFileCleaner: Option[FileCleaner] = inklessSharedState.map(new FileCleaner(_)) private val inklessFileMerger: Option[FileMerger] = inklessSharedState.map(new FileMerger(_)) + private val inklessConnectionUpgradeTracker: Option[InklessConnectionUpgradeTracker] = inklessSharedState.map(_.inklessConnectionUpgradeTracker()) /* epoch of the controller that last changed the leader */ @volatile private[server] var controllerEpoch: Int = KafkaController.InitialControllerEpoch @@ -828,13 
+830,19 @@ class ReplicaManager(val config: KafkaConfig, recordValidationStatsCallback: Map[TopicPartition, RecordValidationStats] => Unit = _ => (), requestLocal: RequestLocal = RequestLocal.noCaching, actionQueue: ActionQueue = this.defaultActionQueue, - verificationGuards: Map[TopicPartition, VerificationGuard] = Map.empty): Unit = { + verificationGuards: Map[TopicPartition, VerificationGuard] = Map.empty, + connectionIdAndCorrelationId: Option[(String, Int)] = None): Unit = { if (!isValidRequiredAcks(requiredAcks)) { sendInvalidRequiredAcksResponse(entriesPerPartition, responseCallback) return } if (inklessAppendInterceptor.exists(_.intercept(entriesPerPartition.asJava, r => responseCallback(r.asScala)))) { + inklessConnectionUpgradeTracker.foreach(tr => + connectionIdAndCorrelationId.foreach { case (connectionId, correlationId) => + tr.upgradeConnection(connectionId, correlationId) + } + ) return } @@ -890,7 +898,8 @@ class ReplicaManager(val config: KafkaConfig, recordValidationStatsCallback: Map[TopicPartition, RecordValidationStats] => Unit = _ => (), requestLocal: RequestLocal = RequestLocal.noCaching, actionQueue: ActionQueue = this.defaultActionQueue, - transactionSupportedOperation: TransactionSupportedOperation): Unit = { + transactionSupportedOperation: TransactionSupportedOperation, + connectionIdAndCorrelationId: Option[(String, Int)] = None): Unit = { val transactionalProducerInfo = mutable.HashSet[(Long, Short)]() val topicPartitionBatchInfo = mutable.Map[TopicPartition, Int]() @@ -949,7 +958,8 @@ class ReplicaManager(val config: KafkaConfig, recordValidationStatsCallback = recordValidationStatsCallback, requestLocal = newRequestLocal, actionQueue = actionQueue, - verificationGuards = verificationGuards + verificationGuards = verificationGuards, + connectionIdAndCorrelationId = connectionIdAndCorrelationId, ) } diff --git a/storage/inkless/src/main/java/io/aiven/inkless/common/SharedState.java 
b/storage/inkless/src/main/java/io/aiven/inkless/common/SharedState.java
index 47af56b99f..d3a8614cd7 100644
--- a/storage/inkless/src/main/java/io/aiven/inkless/common/SharedState.java
+++ b/storage/inkless/src/main/java/io/aiven/inkless/common/SharedState.java
@@ -16,6 +16,7 @@
 import io.aiven.inkless.config.InklessConfig;
 import io.aiven.inkless.control_plane.ControlPlane;
 import io.aiven.inkless.control_plane.MetadataView;
+import io.aiven.inkless.network.InklessConnectionUpgradeTracker;
 import io.aiven.inkless.storage_backend.common.StorageBackend;
 
 public record SharedState(
@@ -29,7 +30,8 @@ public record SharedState(
     KeyAlignmentStrategy keyAlignmentStrategy,
     ObjectCache cache,
     BrokerTopicStats brokerTopicStats,
-    Supplier<LogConfig> defaultTopicConfigs
+    Supplier<LogConfig> defaultTopicConfigs,
+    InklessConnectionUpgradeTracker inklessConnectionUpgradeTracker
 ) implements Closeable {
 
     public static SharedState initialize(
@@ -41,7 +43,8 @@ public static SharedState initialize(
         MetadataView metadata,
         ControlPlane controlPlane,
         BrokerTopicStats brokerTopicStats,
-        Supplier<LogConfig> defaultTopicConfigs
+        Supplier<LogConfig> defaultTopicConfigs,
+        InklessConnectionUpgradeTracker inklessConnectionUpgradeTracker
     ) {
         return new SharedState(
             time,
@@ -54,7 +57,8 @@
             new FixedBlockAlignment(config.fetchCacheBlockBytes()),
             new InfinispanCache(time, clusterId, rack),
             brokerTopicStats,
-            defaultTopicConfigs
+            defaultTopicConfigs,
+            inklessConnectionUpgradeTracker
         );
     }
diff --git a/storage/inkless/src/main/java/io/aiven/inkless/network/InklessConnectionUpgradeTracker.java b/storage/inkless/src/main/java/io/aiven/inkless/network/InklessConnectionUpgradeTracker.java
new file mode 100644
index 0000000000..c3ddd7b8fd
--- /dev/null
+++ b/storage/inkless/src/main/java/io/aiven/inkless/network/InklessConnectionUpgradeTracker.java
@@ -0,0 +1,54 @@
+// Copyright (c) 2025 Aiven, Helsinki, Finland. https://aiven.io/
+package io.aiven.inkless.network;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.concurrent.ConcurrentHashMap;
+
+import io.aiven.inkless.control_plane.MetadataView;
+
+/**
+ * Tracks whether a connection is Inkless-upgraded.
+ *
+ * <p>For each upgraded connection, remembers the correlation ID at the moment of upgrade
+ * (used later for response queueing). Thread-safe: backed by a ConcurrentHashMap.
+ */
+public class InklessConnectionUpgradeTracker {
+    private static final Logger LOGGER = LoggerFactory.getLogger(InklessConnectionUpgradeTracker.class);
+
+    // NOTE(review): currently unused in this class -- presumably reserved for upgrade-eligibility checks; confirm or drop.
+    private final MetadataView metadataView;
+    // Key: connection ID, value: correlation ID at the moment of upgrade.
+    private final ConcurrentHashMap<String, Integer> upgradedConnections = new ConcurrentHashMap<>();
+
+    public InklessConnectionUpgradeTracker(final MetadataView metadataView) {
+        this.metadataView = metadataView;
+    }
+
+    public boolean isConnectionUpgraded(final String connectionId) {
+        return upgradedConnections.containsKey(connectionId);
+    }
+
+    /** Returns the correlation ID recorded at upgrade time; throws if the connection is not upgraded. */
+    public int upgradeCorrelationId(final String connectionId) {
+        final Integer result = upgradedConnections.get(connectionId);
+        if (result == null) {
+            throw new IllegalStateException("Connection " + connectionId + " is not upgraded");
+        }
+        return result;
+    }
+
+    /** Marks the connection as upgraded; idempotent -- only the first recorded correlation ID is kept. */
+    public void upgradeConnection(final String connectionId, final int correlationId) {
+        upgradedConnections.computeIfAbsent(connectionId, ignored -> {
+            LOGGER.info("Upgrading connection {}", connectionId);
+            return correlationId;
+        });
+    }
+
+    public void closeConnection(final String connectionId) {
+        // TODO: wire this up to connection-close handling so entries do not leak.
+        upgradedConnections.remove(connectionId);
+    }
+}