Skip to content

Commit

Permalink
use descriptor instead of full stream as map key
Browse files Browse the repository at this point in the history
  • Loading branch information
edgao committed Sep 16, 2024
1 parent ce7478f commit 10cea0b
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ data class DestinationStream(
val minimumGenerationId: Long,
val syncId: Long,
) {
/**
 * Identifies a stream by a namespace and a name.
 *
 * NOTE(review): this span is rendered diff output, so two variants of the declaration appear
 * back to back: one with a non-null `namespace` and one with a nullable `namespace`.
 */
data class Descriptor(val namespace: String, val name: String) {
data class Descriptor(val namespace: String?, val name: String) {
/** Converts this descriptor into a protocol [StreamDescriptor]. */
fun asProtocolObject(): StreamDescriptor =
// Variant 1: sets both namespace and name unconditionally.
StreamDescriptor().withNamespace(namespace).withName(name)
// Variant 2: always sets the name, and assigns the namespace only when it is non-null.
StreamDescriptor().withName(name).also {
if (namespace != null) {
it.namespace = namespace
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ sealed interface DestinationMessage {

/** Records. */
sealed interface DestinationStreamAffinedMessage : DestinationMessage {
// The stream reference carried by every implementing message.
// NOTE(review): rendered diff output — the property is listed twice, typed first as the full
// DestinationStream and then as DestinationStream.Descriptor.
val stream: DestinationStream
val stream: DestinationStream.Descriptor
}

data class DestinationRecord(
override val stream: DestinationStream,
override val stream: DestinationStream.Descriptor,
val data: AirbyteValue,
val emittedAtMs: Long,
val meta: Meta?,
Expand Down Expand Up @@ -70,8 +70,8 @@ data class DestinationRecord(
.withType(AirbyteMessage.Type.RECORD)
.withRecord(
AirbyteRecordMessage()
.withStream(stream.descriptor.name)
.withNamespace(stream.descriptor.namespace)
.withStream(stream.name)
.withNamespace(stream.namespace)
.withEmittedAt(emittedAtMs)
.withData(AirbyteValueToJson().convert(data))
.also {
Expand All @@ -83,7 +83,7 @@ data class DestinationRecord(
}

private fun statusToProtocolMessage(
stream: DestinationStream,
stream: DestinationStream.Descriptor,
emittedAtMs: Long,
status: AirbyteStreamStatus,
): AirbyteMessage =
Expand All @@ -95,21 +95,21 @@ private fun statusToProtocolMessage(
.withEmittedAt(emittedAtMs.toDouble())
.withStreamStatus(
AirbyteStreamStatusTraceMessage()
.withStreamDescriptor(stream.descriptor.asProtocolObject())
.withStreamDescriptor(stream.asProtocolObject())
.withStatus(status)
)
)

/**
 * A [DestinationStreamAffinedMessage] holding a stream reference and an `emittedAtMs` value.
 *
 * NOTE(review): rendered diff output — the `stream` parameter appears twice, typed first as
 * [DestinationStream] and then as [DestinationStream.Descriptor].
 */
data class DestinationStreamComplete(
override val stream: DestinationStream,
override val stream: DestinationStream.Descriptor,
val emittedAtMs: Long,
) : DestinationStreamAffinedMessage {
/** Delegates to `statusToProtocolMessage`, passing [AirbyteStreamStatus.COMPLETE] as the status. */
override fun asProtocolMessage(): AirbyteMessage =
statusToProtocolMessage(stream, emittedAtMs, AirbyteStreamStatus.COMPLETE)
}

data class DestinationStreamIncomplete(
override val stream: DestinationStream,
override val stream: DestinationStream.Descriptor,
val emittedAtMs: Long,
) : DestinationStreamAffinedMessage {
override fun asProtocolMessage(): AirbyteMessage =
Expand All @@ -120,12 +120,12 @@ data class DestinationStreamIncomplete(
sealed interface CheckpointMessage : DestinationMessage {
data class Stats(val recordCount: Long)
/**
 * Pairs a stream reference with a state payload held as a [JsonNode].
 *
 * NOTE(review): rendered diff output — `stream` appears twice, typed first as
 * [DestinationStream] and then as [DestinationStream.Descriptor].
 */
data class Checkpoint(
val stream: DestinationStream,
val stream: DestinationStream.Descriptor,
val state: JsonNode,
) {
/**
 * Builds an [AirbyteStreamState] from this checkpoint: the stream descriptor is converted with
 * `asProtocolObject()` and `state` is passed through as the stream state.
 */
fun asProtocolObject(): AirbyteStreamState =
AirbyteStreamState()
// Variant for a full DestinationStream: goes through its `descriptor` property.
.withStreamDescriptor(stream.descriptor.asProtocolObject())
// Variant for a Descriptor-typed `stream`: converts it directly.
.withStreamDescriptor(stream.asProtocolObject())
.withStreamState(state)
}

Expand Down Expand Up @@ -216,7 +216,7 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) {
name = message.record.stream,
)
DestinationRecord(
stream = stream,
stream = stream.descriptor,
data = JsonToAirbyteValue().convert(message.record.data, stream.schema),
emittedAtMs = message.record.emittedAt,
meta =
Expand Down Expand Up @@ -244,9 +244,15 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) {
if (message.trace.type == AirbyteTraceMessage.Type.STREAM_STATUS) {
when (status.status) {
AirbyteStreamStatus.COMPLETE ->
DestinationStreamComplete(stream, message.trace.emittedAt.toLong())
DestinationStreamComplete(
stream.descriptor,
message.trace.emittedAt.toLong()
)
AirbyteStreamStatus.INCOMPLETE ->
DestinationStreamIncomplete(stream, message.trace.emittedAt.toLong())
DestinationStreamIncomplete(
stream.descriptor,
message.trace.emittedAt.toLong()
)
else -> Undefined
}
} else {
Expand Down Expand Up @@ -280,7 +286,7 @@ class DestinationMessageFactory(private val catalog: DestinationCatalog) {
/**
 * Builds a [Checkpoint] from a protocol [AirbyteStreamState].
 *
 * The checkpoint's `state` is `streamState.streamState`. Two variants of the `stream` argument
 * are shown (rendered diff output): a lookup through `catalog.getStream(...)`, and direct
 * construction of a [DestinationStream.Descriptor] from the state's namespace and name.
 */
private fun fromAirbyteStreamState(streamState: AirbyteStreamState): Checkpoint {
val descriptor = streamState.streamDescriptor
return Checkpoint(
stream = catalog.getStream(namespace = descriptor.namespace, name = descriptor.name),
// NOTE(review): this variant does not consult the catalog, so presumably a state message
// naming a stream that is absent from the catalog is no longer rejected at this point —
// confirm that is intended.
stream = DestinationStream.Descriptor(descriptor.namespace, descriptor.name),
state = streamState.streamState
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class DestinationMessageQueue(
config: WriteConfiguration,
private val memoryManager: MemoryManager,
private val queueChannelFactory: QueueChannelFactory<DestinationRecordWrapped>
) : MessageQueue<DestinationStream, DestinationRecordWrapped> {
) : MessageQueue<DestinationStream.Descriptor, DestinationRecordWrapped> {
private val channels:
ConcurrentHashMap<DestinationStream.Descriptor, QueueChannel<DestinationRecordWrapped>> =
ConcurrentHashMap()
Expand Down Expand Up @@ -89,12 +89,10 @@ class DestinationMessageQueue(
}

/**
 * Returns the queue channel stored in `channels` for [key].
 *
 * NOTE(review): rendered diff output — the parameter and the lookup each appear twice: once for
 * a [DestinationStream] key (looked up via `key.descriptor`) and once for a
 * [DestinationStream.Descriptor] key (looked up directly).
 *
 * @throws IllegalArgumentException if `channels` holds no entry for the key.
 */
override suspend fun getChannel(
key: DestinationStream,
key: DestinationStream.Descriptor,
): QueueChannel<DestinationRecordWrapped> {
return channels[key.descriptor]
?: throw IllegalArgumentException(
"Reading from non-existent QueueChannel: ${key.descriptor}"
)
return channels[key]
?: throw IllegalArgumentException("Reading from non-existent QueueChannel: ${key}")
}

private val log = KotlinLogging.logger {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class DefaultMessageConverter : MessageConverter<CheckpointMessage, AirbyteMessa
return AirbyteStreamState()
.withStreamDescriptor(
StreamDescriptor()
.withNamespace(checkpoint.stream.descriptor.namespace)
.withName(checkpoint.stream.descriptor.name)
.withNamespace(checkpoint.stream.namespace)
.withName(checkpoint.stream.name)
)
.withStreamState(checkpoint.state)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ class DestinationMessageQueueReader(
var totalBytesRead = 0L
var recordsRead = 0L
while (totalBytesRead < maxBytes) {
when (val wrapped = messageQueue.getChannel(key).receive()) {
when (val wrapped = messageQueue.getChannel(key.descriptor).receive()) {
is StreamRecordWrapped -> {
totalBytesRead += wrapped.sizeBytes
emit(wrapped)
}
is StreamCompleteWrapped -> {
messageQueue.getChannel(key).close()
messageQueue.getChannel(key.descriptor).close()
emit(wrapped)
log.info { "Read end-of-stream for $key" }
return@flow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ interface MessageQueueWriter<T : Any> {
)
class DestinationMessageQueueWriter(
private val catalog: DestinationCatalog,
private val messageQueue: MessageQueue<DestinationStream, DestinationRecordWrapped>,
private val messageQueue: MessageQueue<DestinationStream.Descriptor, DestinationRecordWrapped>,
private val streamsManager: StreamsManager,
private val checkpointManager: CheckpointManager<DestinationStream, CheckpointMessage>
private val checkpointManager:
CheckpointManager<DestinationStream.Descriptor, CheckpointMessage>
) : MessageQueueWriter<DestinationMessage> {
/**
* Deserialize and route the message to the appropriate channel.
Expand Down Expand Up @@ -89,14 +90,15 @@ class DestinationMessageQueueWriter(
is GlobalCheckpoint -> {
val streamWithIndexAndCount =
catalog.streams.map { stream ->
val manager = streamsManager.getManager(stream)
val manager = streamsManager.getManager(stream.descriptor)
val (currentIndex, countSinceLast) = manager.markCheckpoint()
Triple(stream, currentIndex, countSinceLast)
}
val totalCount = streamWithIndexAndCount.sumOf { it.third }
val messageWithCount =
message.withDestinationStats(CheckpointMessage.Stats(totalCount))
val streamIndexes = streamWithIndexAndCount.map { it.first to it.second }
val streamIndexes =
streamWithIndexAndCount.map { it.first.descriptor to it.second }
checkpointManager.addGlobalCheckpoint(streamIndexes, messageWithCount)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ interface CheckpointManager<K, T> {
* TODO: Ensure that checkpoint is flushed at the end, and require that all checkpoints be flushed
* before the destination can succeed.
*/
abstract class StreamsCheckpointManager<T, U>() : CheckpointManager<DestinationStream, T> {
abstract class StreamsCheckpointManager<T, U>() :
CheckpointManager<DestinationStream.Descriptor, T> {
private val log = KotlinLogging.logger {}

abstract val catalog: DestinationCatalog
Expand All @@ -48,18 +49,22 @@ abstract class StreamsCheckpointManager<T, U>() : CheckpointManager<DestinationS
abstract val outputConsumer: Consumer<U>

/**
 * A checkpoint message of type [T] together with a list of (stream, Long) pairs.
 *
 * NOTE(review): the Long is presumably the per-stream checkpoint index passed to
 * `addGlobalCheckpoint` — confirm against the code that constructs this class.
 * Rendered diff output: `streamIndexes` appears twice, keyed first by [DestinationStream] and
 * then by [DestinationStream.Descriptor].
 */
data class GlobalCheckpoint<T>(
val streamIndexes: List<Pair<DestinationStream, Long>>,
val streamIndexes: List<Pair<DestinationStream.Descriptor, Long>>,
val checkpointMessage: T
)

private val checkpointsAreGlobal: AtomicReference<Boolean?> = AtomicReference(null)
private val streamCheckpoints:
ConcurrentHashMap<DestinationStream, ConcurrentLinkedHashMap<Long, T>> =
ConcurrentHashMap<DestinationStream.Descriptor, ConcurrentLinkedHashMap<Long, T>> =
ConcurrentHashMap()
private val globalCheckpoints: ConcurrentLinkedQueue<GlobalCheckpoint<T>> =
ConcurrentLinkedQueue()

override fun addStreamCheckpoint(key: DestinationStream, index: Long, checkpointMessage: T) {
override fun addStreamCheckpoint(
key: DestinationStream.Descriptor,
index: Long,
checkpointMessage: T
) {
if (checkpointsAreGlobal.updateAndGet { it == true } != false) {
throw IllegalStateException(
"Global checkpoints cannot be mixed with non-global checkpoints"
Expand Down Expand Up @@ -93,7 +98,7 @@ abstract class StreamsCheckpointManager<T, U>() : CheckpointManager<DestinationS

// TODO: Is it an error if we don't get all the streams every time?
override fun addGlobalCheckpoint(
keyIndexes: List<Pair<DestinationStream, Long>>,
keyIndexes: List<Pair<DestinationStream.Descriptor, Long>>,
checkpointMessage: T
) {
if (checkpointsAreGlobal.updateAndGet { it != false } != true) {
Expand Down Expand Up @@ -149,8 +154,8 @@ abstract class StreamsCheckpointManager<T, U>() : CheckpointManager<DestinationS

private fun flushStreamCheckpoints() {
for (stream in catalog.streams) {
val manager = streamsManager.getManager(stream)
val streamCheckpoints = streamCheckpoints[stream] ?: return
val manager = streamsManager.getManager(stream.descriptor)
val streamCheckpoints = streamCheckpoints[stream.descriptor] ?: return
for (index in streamCheckpoints.keys) {
if (manager.areRecordsPersistedUntil(index)) {
val checkpointMessage =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ import kotlinx.coroutines.channels.Channel
/** Manages the state of all streams in the destination. */
interface StreamsManager {
/** Get the manager for the given stream. Throws an exception if the stream is not found. */
// NOTE(review): rendered diff output — the signature appears twice, taking first a
// DestinationStream and then a DestinationStream.Descriptor.
fun getManager(stream: DestinationStream): StreamManager
fun getManager(stream: DestinationStream.Descriptor): StreamManager

/** Suspend until all streams are closed. */
suspend fun awaitAllStreamsClosed()
}

class DefaultStreamsManager(
private val streamManagers: ConcurrentHashMap<DestinationStream, StreamManager>
private val streamManagers: ConcurrentHashMap<DestinationStream.Descriptor, StreamManager>
) : StreamsManager {
/**
 * Looks up the [StreamManager] stored in `streamManagers` under [stream].
 *
 * NOTE(review): rendered diff output — the signature appears twice, taking first a
 * [DestinationStream] and then a [DestinationStream.Descriptor].
 *
 * @throws IllegalArgumentException if `streamManagers` has no entry for [stream].
 */
override fun getManager(stream: DestinationStream): StreamManager {
override fun getManager(stream: DestinationStream.Descriptor): StreamManager {
return streamManagers[stream] ?: throw IllegalArgumentException("Stream not found: $stream")
}

Expand Down Expand Up @@ -191,8 +191,8 @@ class StreamsManagerFactory(
) {
/**
 * Creates a [DefaultStreamsManager] backed by a map holding one [DefaultStreamManager] per
 * stream in `catalog.streams`.
 *
 * NOTE(review): rendered diff output — the map creation and population appear twice: keyed by
 * the full [DestinationStream] in the first pair of lines, and by the stream's `descriptor` in
 * the second pair.
 */
@Singleton
fun make(): StreamsManager {
val hashMap = ConcurrentHashMap<DestinationStream, StreamManager>()
catalog.streams.forEach { hashMap[it] = DefaultStreamManager(it) }
val hashMap = ConcurrentHashMap<DestinationStream.Descriptor, StreamManager>()
catalog.streams.forEach { hashMap[it.descriptor] = DefaultStreamManager(it) }
return DefaultStreamsManager(hashMap)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class CloseStreamTaskFactory(
/**
 * Creates a [CloseStreamTask] from the given stream loader, the manager that `streamsManager`
 * returns for the loader's stream, and the task launcher.
 *
 * NOTE(review): rendered diff output — the manager lookup appears twice, passing first
 * `streamLoader.stream` and then `streamLoader.stream.descriptor`.
 */
fun make(taskLauncher: DestinationTaskLauncher, streamLoader: StreamLoader): CloseStreamTask {
return CloseStreamTask(
streamLoader,
streamsManager.getManager(streamLoader.stream),
streamsManager.getManager(streamLoader.stream.descriptor),
taskLauncher
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import jakarta.inject.Singleton
class DestinationTaskLauncher(
private val catalog: DestinationCatalog,
override val taskRunner: TaskRunner,
private val checkpointManager: CheckpointManager<DestinationStream, CheckpointMessage>,
private val checkpointManager:
CheckpointManager<DestinationStream.Descriptor, CheckpointMessage>,
private val setupTaskFactory: SetupTaskFactory,
private val openStreamTaskFactory: OpenStreamTaskFactory,
private val spillToDiskTaskFactory: SpillToDiskTaskFactory,
Expand Down Expand Up @@ -87,7 +88,8 @@ class DestinationTaskLauncher(
class DestinationTaskLauncherFactory(
private val catalog: DestinationCatalog,
private val taskRunner: TaskRunner,
private val checkpointManager: CheckpointManager<DestinationStream, CheckpointMessage>,
private val checkpointManager:
CheckpointManager<DestinationStream.Descriptor, CheckpointMessage>,
private val setupTaskFactory: SetupTaskFactory,
private val openStreamTaskFactory: OpenStreamTaskFactory,
private val spillToDiskTaskFactory: SpillToDiskTaskFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class ProcessBatchTaskFactory(
return ProcessBatchTask(
batchEnvelope,
streamLoader,
streamsManager.getManager(streamLoader.stream),
streamsManager.getManager(streamLoader.stream.descriptor),
taskLauncher
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class ProcessRecordsTaskFactory(
): ProcessRecordsTask {
return ProcessRecordsTask(
streamLoader,
streamsManager.getManager(streamLoader.stream),
streamsManager.getManager(streamLoader.stream.descriptor),
taskLauncher,
fileEnvelope,
deserializer,
Expand Down

0 comments on commit 10cea0b

Please sign in to comment.