From 1d44fd15f8e550f99d1dc86870fd32f76d3f9669 Mon Sep 17 00:00:00 2001 From: James Rich <2199651+jamesarich@users.noreply.github.com> Date: Sun, 28 Jun 2026 10:19:36 -0500 Subject: [PATCH] feat(network): migrate TcpTransport to ktor-network (commonMain) (#5995) Co-authored-by: Claude Sonnet 4.6 --- core/network/build.gradle.kts | 1 + .../core/network/radio/TcpRadioTransport.kt | 0 .../network/transport/StreamFrameCodec.kt | 6 +- .../core/network/transport/TcpTransport.kt | 191 ++++++++++------ .../network/transport/TcpTransportTest.kt | 208 ++++++++++++++++++ 5 files changed, 334 insertions(+), 72 deletions(-) rename core/network/src/{jvmAndroidMain => commonMain}/kotlin/org/meshtastic/core/network/radio/TcpRadioTransport.kt (100%) rename core/network/src/{jvmAndroidMain => commonMain}/kotlin/org/meshtastic/core/network/transport/TcpTransport.kt (66%) create mode 100644 core/network/src/commonTest/kotlin/org/meshtastic/core/network/transport/TcpTransportTest.kt diff --git a/core/network/build.gradle.kts b/core/network/build.gradle.kts index f195951e1..c9baaee85 100644 --- a/core/network/build.gradle.kts +++ b/core/network/build.gradle.kts @@ -47,6 +47,7 @@ kotlin { implementation(libs.kotlinx.serialization.json) implementation(libs.kotlinx.atomicfu) implementation(libs.ktor.client.core) + implementation(libs.ktor.network) // raw TCP sockets for TcpTransport (KMP-common) implementation(libs.ktor.client.content.negotiation) implementation(libs.ktor.client.logging) implementation(libs.ktor.serialization.kotlinx.json) diff --git a/core/network/src/jvmAndroidMain/kotlin/org/meshtastic/core/network/radio/TcpRadioTransport.kt b/core/network/src/commonMain/kotlin/org/meshtastic/core/network/radio/TcpRadioTransport.kt similarity index 100% rename from core/network/src/jvmAndroidMain/kotlin/org/meshtastic/core/network/radio/TcpRadioTransport.kt rename to core/network/src/commonMain/kotlin/org/meshtastic/core/network/radio/TcpRadioTransport.kt diff --git a/core/network/src/commonMain/kotlin/org/meshtastic/core/network/transport/StreamFrameCodec.kt b/core/network/src/commonMain/kotlin/org/meshtastic/core/network/transport/StreamFrameCodec.kt index 31483cb16..43a6810f9 100644 --- a/core/network/src/commonMain/kotlin/org/meshtastic/core/network/transport/StreamFrameCodec.kt +++ b/core/network/src/commonMain/kotlin/org/meshtastic/core/network/transport/StreamFrameCodec.kt @@ -114,7 +114,11 @@ class StreamFrameCodec( * * Thread-safe via an internal mutex — multiple callers can call this concurrently. */ - suspend fun frameAndSend(payload: ByteArray, sendBytes: (ByteArray) -> Unit, flush: () -> Unit = {}) { + suspend fun frameAndSend( + payload: ByteArray, + sendBytes: suspend (ByteArray) -> Unit, + flush: suspend () -> Unit = {}, + ) { writeMutex.withLock { val header = ByteArray(HEADER_SIZE) header[0] = START1 diff --git a/core/network/src/jvmAndroidMain/kotlin/org/meshtastic/core/network/transport/TcpTransport.kt b/core/network/src/commonMain/kotlin/org/meshtastic/core/network/transport/TcpTransport.kt similarity index 66% rename from core/network/src/jvmAndroidMain/kotlin/org/meshtastic/core/network/transport/TcpTransport.kt rename to core/network/src/commonMain/kotlin/org/meshtastic/core/network/transport/TcpTransport.kt index df858ea90..7e36cacbc 100644 --- a/core/network/src/jvmAndroidMain/kotlin/org/meshtastic/core/network/transport/TcpTransport.kt +++ b/core/network/src/commonMain/kotlin/org/meshtastic/core/network/transport/TcpTransport.kt @@ -17,23 +17,31 @@ package org.meshtastic.core.network.transport import co.touchlab.kermit.Logger +import io.ktor.network.selector.SelectorManager +import io.ktor.network.sockets.InetSocketAddress +import io.ktor.network.sockets.Socket +import io.ktor.network.sockets.aSocket +import io.ktor.network.sockets.openReadChannel +import io.ktor.network.sockets.openWriteChannel +import io.ktor.utils.io.ByteReadChannel +import io.ktor.utils.io.ByteWriteChannel +import io.ktor.utils.io.readAvailable +import io.ktor.utils.io.writeFully +import kotlinx.atomicfu.atomic +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job +import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.delay import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.io.IOException import org.meshtastic.core.common.util.handledLaunch import org.meshtastic.core.common.util.nowMillis import org.meshtastic.core.di.CoroutineDispatchers import org.meshtastic.proto.ToRadio -import java.io.BufferedInputStream -import java.io.BufferedOutputStream -import java.io.IOException -import java.io.OutputStream -import java.net.InetAddress -import java.net.Socket -import java.net.SocketTimeoutException -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger +import kotlin.concurrent.Volatile /** * Decides whether to reset the reconnect backoff based on session data and uptime. @@ -46,13 +54,14 @@ internal fun shouldResetBackoff(hadData: Boolean, sessionUptimeMs: Long, thresho hadData && sessionUptimeMs >= thresholdMs /** - * Shared JVM TCP transport for Meshtastic radios. + * Shared TCP transport for Meshtastic radios. * * Manages the TCP socket lifecycle (connect, read loop, reconnect with backoff) and uses [StreamFrameCodec] for the * START1/START2 stream framing protocol. [sendHeartbeat] sends a heartbeat with a monotonically-increasing nonce so the * firmware's per-connection duplicate-write filter does not silently drop it. * - * Used by Android and Desktop via the shared `SharedRadioInterfaceService`. + * Uses Ktor raw sockets (`ktor-network`) so the implementation is KMP-common — shared by Android, Desktop, and (once + * wired) iOS via the shared `SharedRadioInterfaceService`. */ @Suppress("TooManyFunctions", "MagicNumber") class TcpTransport( @@ -82,9 +91,15 @@ class TcpTransport( const val MAX_RECONNECT_RETRIES = Int.MAX_VALUE const val MIN_BACKOFF_MILLIS = 1_000L const val MAX_BACKOFF_MILLIS = 5 * 60 * 1_000L - const val SOCKET_TIMEOUT_MS = 5_000 + + /** Per-read inactivity timeout. Combined with [SOCKET_RETRIES] this gives the 90s idle-disconnect window. */ + const val SOCKET_TIMEOUT_MS = 5_000L const val SOCKET_RETRIES = 18 // 18 * 5s = 90s inactivity before disconnect const val TIMEOUT_LOG_INTERVAL = 5 + + /** TCP connect timeout. A failed connect just feeds the reconnect/backoff loop, so it is not fatal. */ + const val CONNECT_TIMEOUT_MS = 30_000L + private const val READ_BUFFER_SIZE = 1024 private const val MILLIS_PER_SECOND = 1_000L /** @@ -106,14 +121,18 @@ class TcpTransport( ) // TCP socket state + @Volatile private var selectorManager: SelectorManager? = null + @Volatile private var socket: Socket? = null - @Volatile private var outStream: OutputStream? = null + @Volatile private var writeChannel: ByteWriteChannel? = null @Volatile private var connectionJob: Job? = null @Volatile private var currentAddress: String? = null + @Volatile private var connected: Boolean = false + // Metrics @Volatile private var connectionStartTime: Long = 0 @@ -127,14 +146,11 @@ class TcpTransport( @Volatile private var timeoutEvents: Int = 0 - private val heartbeatNonce = AtomicInteger(0) + private val heartbeatNonce = atomic(0) /** Whether the transport is currently connected. */ val isConnected: Boolean - get() { - val s = socket ?: return false - return s.isConnected && !s.isClosed - } + get() = connected && socket?.socketContext?.isActive == true /** * Start a TCP connection to the given address with automatic reconnect. @@ -184,10 +200,18 @@ class TcpTransport( val hadData = try { connectAndRead(address) - } catch (ex: IOException) { - Logger.w { "$logTag: [$address] TCP connection error" } + } catch (ex: TimeoutCancellationException) { + Logger.w(ex) { "$logTag: [$address] TCP connect timed out" } disconnectSocket() false + } catch (ex: IOException) { + Logger.w(ex) { "$logTag: [$address] TCP connection error" } + disconnectSocket() + false + } catch (ce: CancellationException) { + // Outer-scope cancellation (stop()) — tear down and let it propagate to end the loop. + disconnectSocket() + throw ce } catch (@Suppress("TooGenericExceptionCaught") ex: Throwable) { Logger.e(ex) { "$logTag: [$address] TCP exception" } disconnectSocket() @@ -203,8 +227,9 @@ class TcpTransport( retryCount = 1 backoff = MIN_BACKOFF_MILLIS } else if (hadData) { + val backoffSec = backoff / MILLIS_PER_SECOND Logger.d { - "$logTag: [$address] Short session (${sessionUptime}ms) — keeping backoff at ${backoff / MILLIS_PER_SECOND}s" + "$logTag: [$address] Short session (${sessionUptime}ms) — keeping backoff at ${backoffSec}s" } } @@ -229,12 +254,18 @@ class TcpTransport( Logger.i { "$logTag: [$address] Connecting to $host:$port" } val attemptStart = nowMillis - Socket(InetAddress.getByName(host), port).use { sock -> - sock.tcpNoDelay = true - sock.keepAlive = true - sock.soTimeout = SOCKET_TIMEOUT_MS - socket = sock + val selector = SelectorManager(dispatchers.io) + selectorManager = selector + val sock = + withTimeout(CONNECT_TIMEOUT_MS) { + aSocket(selector).tcp().connect(InetSocketAddress(host, port)) { + noDelay = true + keepAlive = true + } + } + socket = sock + try { val connectTime = nowMillis - attemptStart connectionStartTime = nowMillis resetMetrics() @@ -242,55 +273,70 @@ class TcpTransport( Logger.i { "$logTag: [$address] Socket connected in ${connectTime}ms" } - BufferedOutputStream(sock.getOutputStream()).use { output -> - outStream = output + val output = sock.openWriteChannel(autoFlush = false) + writeChannel = output + val input: ByteReadChannel = sock.openReadChannel() - BufferedInputStream(sock.getInputStream()).use { input -> - // Send wake bytes and signal connected - sendBytesRaw(StreamFrameCodec.WAKE_BYTES) - listener.onConnected() + // Send wake bytes and signal connected + sendBytesRaw(StreamFrameCodec.WAKE_BYTES) + flushBytes() + connected = true + listener.onConnected() - // Read loop - var timeoutCount = 0 - while (timeoutCount < SOCKET_RETRIES) { - try { - val c = input.read() - if (c == -1) { - Logger.i { "$logTag: [$address] EOF after $packetsReceived packets" } - break - } - timeoutCount = 0 - bytesReceived++ - codec.processInputByte(c.toByte()) - } catch (_: SocketTimeoutException) { - timeoutCount++ - timeoutEvents++ - if (timeoutCount % TIMEOUT_LOG_INTERVAL == 0) { - Logger.d { "$logTag: [$address] Timeout $timeoutCount/$SOCKET_RETRIES" } - } - } - } + readLoop(address, input) - if (timeoutCount >= SOCKET_RETRIES) { - Logger.w { "$logTag: [$address] Closing after $SOCKET_RETRIES consecutive timeouts" } - } - } - } - val hadData = bytesReceived > 0 + bytesReceived > 0 + } finally { disconnectSocket() - hadData } } + /** + * Read until EOF or [SOCKET_RETRIES] consecutive inactivity timeouts. [withTimeoutOrNull] gives a *resumable* + * inactivity timeout: cancelling a parked `readAvailable` leaves the channel usable for the next iteration + * (validated in `TcpTransportTest`). + */ + @Suppress("NestedBlockDepth") + private suspend fun readLoop(address: String, input: ByteReadChannel) { + val buf = ByteArray(READ_BUFFER_SIZE) + var timeoutCount = 0 + while (timeoutCount < SOCKET_RETRIES) { + val read = withTimeoutOrNull(SOCKET_TIMEOUT_MS) { input.readAvailable(buf) } + when { + read == null -> { + timeoutCount++ + timeoutEvents++ + if (timeoutCount % TIMEOUT_LOG_INTERVAL == 0) { + Logger.d { "$logTag: [$address] Timeout $timeoutCount/$SOCKET_RETRIES" } + } + } + + read == -1 -> { + Logger.i { "$logTag: [$address] EOF after $packetsReceived packets" } + return + } + + else -> { + timeoutCount = 0 + bytesReceived += read + for (i in 0 until read) { + codec.processInputByte(buf[i]) + } + } + } + } + Logger.w { "$logTag: [$address] Closing after $SOCKET_RETRIES consecutive timeouts" } + } + // Guards against recursive disconnects triggered by listener callbacks. - private val isDisconnecting = AtomicBoolean(false) + private val isDisconnecting = atomic(false) private fun disconnectSocket() { - if (!isDisconnecting.compareAndSet(false, true)) return + if (!isDisconnecting.compareAndSet(expect = false, update = true)) return try { val s = socket - val hadConnection = s != null || outStream != null + val hadConnection = s != null || writeChannel != null if (s != null) { val uptime = if (connectionStartTime > 0) nowMillis - connectionStartTime else 0 Logger.i { @@ -300,19 +346,22 @@ class TcpTransport( } try { s.close() - } catch (_: IOException) { - // Ignore close errors + } catch (ex: IOException) { + Logger.w(ex) { "$logTag: [$currentAddress] Error closing socket" } } } + selectorManager?.close() socket = null - outStream = null + writeChannel = null + selectorManager = null + connected = false if (hadConnection) { listener.onDisconnected() } } finally { - isDisconnecting.set(false) + isDisconnecting.value = false } } @@ -320,23 +369,23 @@ class TcpTransport( // region Byte I/O - private fun sendBytesRaw(p: ByteArray) { + private suspend fun sendBytesRaw(p: ByteArray) { val stream = - outStream + writeChannel ?: run { Logger.w { "$logTag: [$currentAddress] Cannot send ${p.size} bytes: not connected" } return } try { - stream.write(p) + stream.writeFully(p) } catch (ex: IOException) { Logger.w(ex) { "$logTag: [$currentAddress] TCP write error" } disconnectSocket() } } - private fun flushBytes() { - val stream = outStream ?: return + private suspend fun flushBytes() { + val stream = writeChannel ?: return try { stream.flush() } catch (ex: IOException) { diff --git a/core/network/src/commonTest/kotlin/org/meshtastic/core/network/transport/TcpTransportTest.kt b/core/network/src/commonTest/kotlin/org/meshtastic/core/network/transport/TcpTransportTest.kt new file mode 100644 index 000000000..48fc74d24 --- /dev/null +++ b/core/network/src/commonTest/kotlin/org/meshtastic/core/network/transport/TcpTransportTest.kt @@ -0,0 +1,208 @@ +/* + * Copyright (c) 2026 Meshtastic LLC + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +@file:Suppress("MagicNumber") + +package org.meshtastic.core.network.transport + +import io.ktor.network.selector.SelectorManager +import io.ktor.network.sockets.InetSocketAddress +import io.ktor.network.sockets.ServerSocket +import io.ktor.network.sockets.aSocket +import io.ktor.network.sockets.openReadChannel +import io.ktor.network.sockets.openWriteChannel +import io.ktor.network.sockets.port +import io.ktor.utils.io.ByteWriteChannel +import io.ktor.utils.io.readAvailable +import io.ktor.utils.io.writeFully +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.cancel +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlinx.coroutines.withTimeoutOrNull +import org.meshtastic.core.di.CoroutineDispatchers +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class TcpTransportTest { + + /** + * THE SPIKE. The whole migration rests on this Ktor assumption: a `withTimeoutOrNull` that cancels a parked + * `readAvailable` must leave the channel usable for the next read — that is what reproduces the old + * `Socket.soTimeout` resumable inactivity timeout. If this fails, switch [TcpTransport]'s read loop to the watchdog + * fallback (see plan). + */ + @Test + fun `read channel survives a withTimeoutOrNull read timeout and resumes`() = runTest { + withContext(Dispatchers.Default) { + val selector = SelectorManager(Dispatchers.Default) + val server = aSocket(selector).tcp().bind(hostname = LOCALHOST, port = 0) + val port = server.localAddress.port() + try { + val acceptJob = async { server.accept() } + val client = aSocket(selector).tcp().connect(InetSocketAddress(LOCALHOST, port)) + val serverConn = acceptJob.await() + + val clientRead = client.openReadChannel() + val serverWrite = serverConn.openWriteChannel(autoFlush = true) + val buf = ByteArray(64) + + // 1st read: server is silent, so this must time out (null), cancelling the parked read. + val firstRead = withTimeoutOrNull(200) { clientRead.readAvailable(buf) } + assertNull(firstRead, "expected the idle read to time out") + + // 2nd read on the SAME channel: server now sends a byte — the channel must still deliver it. + serverWrite.writeFully(byteArrayOf(0x42)) + val secondRead = withTimeout(2_000) { clientRead.readAvailable(buf) } + assertEquals(1, secondRead, "channel was torn down by the previous read-timeout cancellation") + assertEquals(0x42.toByte(), buf[0]) + + client.close() + serverConn.close() + } finally { + server.close() + selector.close() + } + } + } + + /** End-to-end: connect, receive a framed packet from the peer, decode it through [StreamFrameCodec]. */ + @Test + fun `transport decodes a framed packet sent by the peer`() = runTest { + withContext(Dispatchers.Default) { + val server = TestTcpServer.start() + val connected = CompletableDeferred() + val received = CompletableDeferred() + val transport = + TcpTransport( + dispatchers = testDispatchers(), + scope = CoroutineScope(SupervisorJob() + Dispatchers.Default), + listener = + object : TcpTransport.Listener { + override fun onConnected() { + connected.complete(Unit) + } + + override fun onDisconnected() = Unit + + override fun onPacketReceived(bytes: ByteArray) { + received.complete(bytes) + } + }, + ) + + try { + transport.start("$LOCALHOST:${server.port}") + val conn = withTimeout(5_000) { server.awaitConnection() } + withTimeout(5_000) { connected.await() } + + // The transport sends 4 wake bytes (0x94) on connect; drain them so they do not pollute asserts. + conn.drain(4) + + val payload = byteArrayOf(0x10, 0x20, 0x30) + conn.writeFramed(payload) + + val decoded = withTimeout(5_000) { received.await() } + assertContentEquals(payload, decoded) + assertTrue(transport.isConnected) + } finally { + transport.stop() + server.close() + } + } + } + + private fun testDispatchers() = + CoroutineDispatchers(io = Dispatchers.Default, main = Dispatchers.Default, default = Dispatchers.Default) + + private class TestTcpServer + private constructor( + private val selector: SelectorManager, + private val socket: ServerSocket, + ) { + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + private val accepted = CompletableDeferred() + val port: Int = socket.localAddress.port() + + init { + scope.launch { + runCatching { + val s = socket.accept() + accepted.complete(TestTcpConnection(s.openReadChannel(), s.openWriteChannel(autoFlush = true))) + } + .onFailure { accepted.completeExceptionally(it) } + } + } + + suspend fun awaitConnection(): TestTcpConnection = accepted.await() + + fun close() { + runCatching { socket.close() } + runCatching { selector.close() } + scope.cancel() + } + + companion object { + suspend fun start(): TestTcpServer { + val selector = SelectorManager(Dispatchers.Default) + return TestTcpServer(selector, aSocket(selector).tcp().bind(hostname = LOCALHOST, port = 0)) + } + } + } + + private class TestTcpConnection( + private val read: io.ktor.utils.io.ByteReadChannel, + private val write: ByteWriteChannel, + ) { + /** Reads and discards exactly [count] bytes. */ + suspend fun drain(count: Int) { + val buf = ByteArray(count) + var off = 0 + while (off < count) { + read.awaitContent() + val n = read.readAvailable(buf, off, count - off) + if (n == -1) break + off += n + } + } + + /** Writes a Meshtastic stream frame: [START1][START2][len MSB][len LSB][payload]. */ + suspend fun writeFramed(payload: ByteArray) { + val frame = + byteArrayOf( + StreamFrameCodec.START1, + StreamFrameCodec.START2, + (payload.size shr 8).toByte(), + (payload.size and 0xff).toByte(), + ) + payload + write.writeFully(frame) + write.flush() + } + } + + private companion object { + const val LOCALHOST = "127.0.0.1" + } +}