diff --git a/core/src/commonMain/kotlin/dev/kdriver/core/connection/DefaultConnection.kt b/core/src/commonMain/kotlin/dev/kdriver/core/connection/DefaultConnection.kt index f0fbdf763..42a8bdbfd 100644 --- a/core/src/commonMain/kotlin/dev/kdriver/core/connection/DefaultConnection.kt +++ b/core/src/commonMain/kotlin/dev/kdriver/core/connection/DefaultConnection.kt @@ -4,12 +4,7 @@ import dev.kdriver.cdp.* import dev.kdriver.cdp.domain.* import dev.kdriver.core.browser.Browser import dev.kdriver.core.browser.Config.Defaults -import dev.kdriver.core.browser.WebSocketInfo -import io.ktor.client.* -import io.ktor.client.plugins.websocket.* -import io.ktor.http.* import io.ktor.util.logging.* -import io.ktor.websocket.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.flow.Flow @@ -35,17 +30,23 @@ open class DefaultConnection( private val logger = KtorSimpleLogger("Connection") - private val client = HttpClient(getWebSocketClientEngine()) { - install(WebSockets) - } - - private var wsSession: ClientWebSocketSession? = null + private val transport: WebSocketTransport by lazy { createTransport() } private var socketSubscription: Job? = null private val currentIdMutex = Mutex() private var currentId = 0L + private val pendingRequestsMutex = Mutex() + private val pendingRequests = mutableMapOf>() + + /** + * Creates the [WebSocketTransport] used to talk to the browser. + * + * Overridable so tests can inject a fake transport without a real browser. + */ + protected open fun createTransport(): WebSocketTransport = KtorWebSocketTransport(websocketUrl) + private var prepareHeadlessDone = false private var prepareExpertDone = false @@ -61,16 +62,8 @@ open class DefaultConnection( override val generatedDomains: MutableMap, Domain> = mutableMapOf() private suspend fun connect() { - if (wsSession != null && wsSession?.isActive == true) return - wsSession = client.webSocketSession { - url { - val parsed = parseWebSocketUrl(websocketUrl) - this.protocol = URLProtocol.WS - this.host = parsed.host - this.port = parsed.port - this.path(parsed.path) - } - } + if (transport.isActive) return + transport.connect() startListening() } @@ -78,12 +71,14 @@ open class DefaultConnection( socketSubscription?.cancel() socketSubscription = messageListeningScope.launch { try { - for (frame in wsSession?.incoming ?: return@launch) { + transport.incoming().collect { text -> try { - frame as? Frame.Text ?: continue - val text = frame.readText() logger.debug("WS < CDP: ${text.take(owner?.config?.debugStringLimit ?: Defaults.DEBUG_STRING_LIMIT)}") val received = Serialization.json.decodeFromString(text) + if (received is Message.Response) { + pendingRequestsMutex.withLock { pendingRequests.remove(received.id) } + ?.complete(received) + } allMessages.emit(received) } catch (e: CancellationException) { throw e @@ -110,19 +105,27 @@ open class DefaultConnection( } val requestId = currentIdMutex.withLock { currentId++ } - val jsonString = Serialization.json.encodeToString(Request(requestId, method, parameter)) - wsSession?.send(jsonString) - logger.debug("WS > CDP: ${jsonString.take(owner?.config?.debugStringLimit ?: Defaults.DEBUG_STRING_LIMIT)}") - - val result = responses.first { it.id == requestId } - result.error?.throwAsException(method) - return result.result + // Register the response waiter *before* sending, so a reply that arrives before we start + // awaiting is still captured (the receive loop completes this deferred). Awaiting the + // response via a replay-0 shared flow after sending could miss it and hang (ISSUE-1). + val deferred = CompletableDeferred() + pendingRequestsMutex.withLock { pendingRequests[requestId] = deferred } + try { + val jsonString = Serialization.json.encodeToString(Request(requestId, method, parameter)) + transport.send(jsonString) + logger.debug("WS > CDP: ${jsonString.take(owner?.config?.debugStringLimit ?: Defaults.DEBUG_STRING_LIMIT)}") + + val result = deferred.await() + result.error?.throwAsException(method) + return result.result + } finally { + pendingRequestsMutex.withLock { pendingRequests.remove(requestId) } + } } @InternalCdpApi override suspend fun close() { - wsSession?.close() - wsSession = null + transport.close() socketSubscription?.cancel() socketSubscription = null } @@ -219,20 +222,6 @@ open class DefaultConnection( } } - private fun parseWebSocketUrl(url: String): WebSocketInfo { - val uri = Url(url) - - val host = uri.host - val port = if (uri.port != -1) uri.port else when (uri.protocol) { - URLProtocol.WS -> 80 - URLProtocol.WSS -> 443 - else -> throw IllegalArgumentException("Unsupported scheme: ${uri.protocol}") - } - val path = uri.encodedPath - - return WebSocketInfo(host, port, path) - } - override fun toString(): String { return "Connection: ${targetInfo?.toString() ?: "no target"}" } diff --git a/core/src/commonMain/kotlin/dev/kdriver/core/connection/KtorWebSocketTransport.kt b/core/src/commonMain/kotlin/dev/kdriver/core/connection/KtorWebSocketTransport.kt new file mode 100644 index 000000000..2c87facb0 --- /dev/null +++ b/core/src/commonMain/kotlin/dev/kdriver/core/connection/KtorWebSocketTransport.kt @@ -0,0 +1,72 @@ +package dev.kdriver.core.connection + +import dev.kdriver.core.browser.WebSocketInfo +import io.ktor.client.* +import io.ktor.client.plugins.websocket.* +import io.ktor.http.* +import io.ktor.websocket.* +import kotlinx.coroutines.isActive +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow + +/** + * Default [WebSocketTransport] backed by a Ktor WebSocket session. + */ +class KtorWebSocketTransport( + private val websocketUrl: String, +) : WebSocketTransport { + + private val client = HttpClient(getWebSocketClientEngine()) { + install(WebSockets) + } + + private var session: ClientWebSocketSession? = null + + override val isActive: Boolean + get() = session?.isActive == true + + override suspend fun connect() { + if (isActive) return + session = client.webSocketSession { + url { + val parsed = parseWebSocketUrl(websocketUrl) + this.protocol = URLProtocol.WS + this.host = parsed.host + this.port = parsed.port + this.path(parsed.path) + } + } + } + + override suspend fun send(message: String) { + session?.send(message) + } + + override fun incoming(): Flow = flow { + val session = session ?: return@flow + for (frame in session.incoming) { + val text = (frame as? Frame.Text)?.readText() ?: continue + emit(text) + } + } + + override suspend fun close() { + session?.close() + session = null + } + + private fun parseWebSocketUrl(url: String): WebSocketInfo { + val uri = Url(url) + + val host = uri.host + val port = if (uri.port != -1) uri.port else when (uri.protocol) { + URLProtocol.WS -> 80 + URLProtocol.WSS -> 443 + else -> throw IllegalArgumentException("Unsupported scheme: ${uri.protocol}") + } + val path = uri.encodedPath + + return WebSocketInfo(host, port, path) + } + +} diff --git a/core/src/commonMain/kotlin/dev/kdriver/core/connection/WebSocketTransport.kt b/core/src/commonMain/kotlin/dev/kdriver/core/connection/WebSocketTransport.kt new file mode 100644 index 000000000..4a7e120b6 --- /dev/null +++ b/core/src/commonMain/kotlin/dev/kdriver/core/connection/WebSocketTransport.kt @@ -0,0 +1,39 @@ +package dev.kdriver.core.connection + +import kotlinx.coroutines.flow.Flow + +/** + * Abstraction over the raw WebSocket connection used to talk to the browser. + * + * Extracting this lets [DefaultConnection]'s message plumbing (request/response correlation, + * event dispatch) be exercised without a real browser, by injecting a fake transport in tests. + */ +interface WebSocketTransport { + + /** + * Whether the underlying connection is currently open. + */ + val isActive: Boolean + + /** + * Opens the connection. Must be called before [send] or [incoming]. No-op if already open. + */ + suspend fun connect() + + /** + * Sends a raw text payload to the browser. + */ + suspend fun send(message: String) + + /** + * Cold stream of raw text payloads received from the browser. Collecting it starts consuming + * frames; the flow completes when the connection is closed. + */ + fun incoming(): Flow + + /** + * Closes the connection. + */ + suspend fun close() + +} diff --git a/core/src/commonMain/kotlin/dev/kdriver/core/network/BaseFetchInterception.kt b/core/src/commonMain/kotlin/dev/kdriver/core/network/BaseFetchInterception.kt index b4e27b996..0ce0bfee7 100644 --- a/core/src/commonMain/kotlin/dev/kdriver/core/network/BaseFetchInterception.kt +++ b/core/src/commonMain/kotlin/dev/kdriver/core/network/BaseFetchInterception.kt @@ -5,11 +5,7 @@ import dev.kdriver.cdp.domain.Fetch.HeaderEntry import dev.kdriver.cdp.domain.Network import dev.kdriver.cdp.domain.fetch import dev.kdriver.core.tab.Tab -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Job -import kotlinx.coroutines.launch -import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.* /** * Default implementation of [FetchInterception]. @@ -30,7 +26,12 @@ open class BaseFetchInterception( } private suspend fun setup() { - val coroutineScope = CoroutineScope(coroutineContext) + val coroutineScope = CoroutineScope(currentCoroutineContext()) + // Subscribe before enabling fetch, so a requestPaused fired after enable() can't be missed. + // UNDISPATCHED guarantees the collector is subscribed before launch returns (ISSUE-2). + job = coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) { + tab.fetch.requestPaused.collect { handler(it) } + } tab.fetch.enable( listOf( Fetch.RequestPattern( @@ -40,9 +41,6 @@ open class BaseFetchInterception( ) ) ) - job = coroutineScope.launch { - tab.fetch.requestPaused.collect { handler(it) } - } } private suspend fun teardown() { diff --git a/core/src/commonMain/kotlin/dev/kdriver/core/network/BaseRequestExpectation.kt b/core/src/commonMain/kotlin/dev/kdriver/core/network/BaseRequestExpectation.kt index 5e9e1e9a9..de244a18f 100644 --- a/core/src/commonMain/kotlin/dev/kdriver/core/network/BaseRequestExpectation.kt +++ b/core/src/commonMain/kotlin/dev/kdriver/core/network/BaseRequestExpectation.kt @@ -3,11 +3,7 @@ package dev.kdriver.core.network import dev.kdriver.cdp.domain.Network import dev.kdriver.cdp.domain.network import dev.kdriver.core.tab.Tab -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Job -import kotlinx.coroutines.launch -import kotlin.coroutines.coroutineContext +import kotlinx.coroutines.* /** * Default implementation of [RequestExpectation]. @@ -53,17 +49,20 @@ open class BaseRequestExpectation( } private suspend fun setup() { - val coroutineScope = CoroutineScope(coroutineContext) - tab.network.enable() - requestJob = coroutineScope.launch { + val coroutineScope = CoroutineScope(currentCoroutineContext()) + // Subscribe to the event flows before enabling the domain, so no event fired after + // enable() can be missed. UNDISPATCHED guarantees each collector is actually subscribed + // before launch returns (ISSUE-2). + requestJob = coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) { tab.network.requestWillBeSent.collect { requestHandler(it) } } - responseJob = coroutineScope.launch { + responseJob = coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) { tab.network.responseReceived.collect { responseHandler(it) } } - loadingFinishedJob = coroutineScope.launch { + loadingFinishedJob = coroutineScope.launch(start = CoroutineStart.UNDISPATCHED) { tab.network.loadingFinished.collect { loadingFinishedHandler(it) } } + tab.network.enable() } private fun teardown() { diff --git a/core/src/jvmTest/kotlin/dev/kdriver/core/connection/CallCommandResponseRaceTest.kt b/core/src/jvmTest/kotlin/dev/kdriver/core/connection/CallCommandResponseRaceTest.kt new file mode 100644 index 000000000..cb7d29c5b --- /dev/null +++ b/core/src/jvmTest/kotlin/dev/kdriver/core/connection/CallCommandResponseRaceTest.kt @@ -0,0 +1,91 @@ +package dev.kdriver.core.connection + +import dev.kdriver.cdp.CommandMode +import dev.kdriver.cdp.Serialization +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.coroutines.test.UnconfinedTestDispatcher +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.int +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.long +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +/** + * Reproduces the `callCommand` response race (audit ISSUE-1). + * + * The browser's reply can arrive between `send` and the point where `callCommand` starts listening + * for it. Because responses were delivered through a `replay = 0` shared flow, a reply emitted + * while no collector is subscribed is dropped and `callCommand` waits forever. + * + * [FakeTransport] delivers the response **synchronously inside `send()`** — i.e. exactly in that + * window, before `callCommand` resumes to await it. On [UnconfinedTestDispatcher] the receive loop + * processes the delivered frame eagerly during `send`, making the interleaving deterministic. + */ +class CallCommandResponseRaceTest { + + /** + * Fake transport whose [send] hook lets the test inject a frame at send time. Incoming frames + * are delivered through a rendezvous channel so the connection's receive loop consumes them + * cooperatively. + */ + private class FakeTransport : WebSocketTransport { + private val channel = Channel(Channel.RENDEZVOUS) + override var isActive: Boolean = false + private set + var onSend: (suspend (String) -> Unit)? = null + + override suspend fun connect() { + isActive = true + } + + override suspend fun send(message: String) { + onSend?.invoke(message) + } + + override fun incoming(): Flow = channel.receiveAsFlow() + + suspend fun deliver(frame: String) = channel.send(frame) + + override suspend fun close() { + isActive = false + channel.close() + } + } + + private class TestConnection( + scope: CoroutineScope, + private val transport: FakeTransport, + ) : DefaultConnection("ws://stub/devtools/page/stub", scope) { + override fun createTransport(): WebSocketTransport = transport + } + + @Test + fun callCommand_receivesResponseDeliveredAtSendTime() = runTest(UnconfinedTestDispatcher()) { + val transport = FakeTransport() + val connection = TestConnection(this, transport) + + // The browser "replies" the instant the request is sent — the worst-case window. + transport.onSend = { sent -> + val id = Serialization.json.parseToJsonElement(sent).jsonObject["id"]!!.jsonPrimitive.long + transport.deliver("""{"id":$id,"result":{"value":42}}""") + } + + try { + val result = withTimeout(2_000) { + connection.callCommand("Some.method", null, CommandMode.ONE_SHOT) + } + + assertNotNull(result, "callCommand must return the response that arrived at send time") + assertEquals(42, result.jsonObject["value"]!!.jsonPrimitive.int) + } finally { + connection.close() + } + } +} diff --git a/core/src/jvmTest/kotlin/dev/kdriver/core/network/ExpectationSubscribeRaceTest.kt b/core/src/jvmTest/kotlin/dev/kdriver/core/network/ExpectationSubscribeRaceTest.kt new file mode 100644 index 000000000..4f70d66a6 --- /dev/null +++ b/core/src/jvmTest/kotlin/dev/kdriver/core/network/ExpectationSubscribeRaceTest.kt @@ -0,0 +1,125 @@ +package dev.kdriver.core.network + +import dev.kdriver.cdp.CommandMode +import dev.kdriver.cdp.Serialization +import dev.kdriver.cdp.domain.Network +import dev.kdriver.cdp.domain.Target +import dev.kdriver.core.connection.WebSocketTransport +import dev.kdriver.core.tab.DefaultTab +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.coroutines.test.UnconfinedTestDispatcher +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.encodeToJsonElement +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.long +import kotlin.test.Test +import kotlin.test.assertEquals + +/** + * Reproduces the expectation subscribe-after-enable race (audit ISSUE-2). + * + * `BaseRequestExpectation.setup()` must subscribe to the network event flows *before* the events + * it cares about can fire. The buggy version enabled the Network domain first and only then + * launched the collectors, so an event arriving in that window was dropped and `getRequestEvent()` + * waited forever. + * + * The fake transport delivers a `Network.requestWillBeSent` event **while handling the + * `Network.enable` command** — i.e. exactly while the domain is being enabled. If the collectors + * were subscribed before enabling (the fix), the event is captured; otherwise it is lost. + */ +class ExpectationSubscribeRaceTest { + + private class FakeTransport : WebSocketTransport { + private val channel = Channel(Channel.RENDEZVOUS) + override var isActive: Boolean = false + private set + var onSend: (suspend (String) -> Unit)? = null + + override suspend fun connect() { + isActive = true + } + + override suspend fun send(message: String) { + onSend?.invoke(message) + } + + override fun incoming(): Flow = channel.receiveAsFlow() + + suspend fun deliver(frame: String) = channel.send(frame) + + override suspend fun close() { + isActive = false + channel.close() + } + } + + private class TestTab( + scope: CoroutineScope, + private val transport: FakeTransport, + ) : DefaultTab("ws://stub/devtools/page/stub", scope, STUB_TARGET) { + override fun createTransport(): WebSocketTransport = transport + } + + private fun requestWillBeSentFrame(requestId: String): String { + val params = Network.RequestWillBeSentParameter( + requestId = requestId, + loaderId = "loader-1", + documentURL = "https://example.com", + request = Network.Request( + url = "https://example.com", + method = "GET", + headers = emptyMap(), + initialPriority = Network.ResourcePriority.MEDIUM, + referrerPolicy = "no-referrer", + ), + timestamp = 0.0, + wallTime = 0.0, + initiator = Network.Initiator(type = "other"), + redirectHasExtraInfo = false, + ) + val paramsJson = Serialization.json.encodeToJsonElement(params) + return """{"method":"Network.requestWillBeSent","params":$paramsJson}""" + } + + @Test + fun expect_capturesEventFiredWhileEnabling() = runTest(UnconfinedTestDispatcher()) { + val transport = FakeTransport() + val tab = TestTab(this, transport) + + transport.onSend = { sent -> + val obj = Serialization.json.parseToJsonElement(sent).jsonObject + val id = obj["id"]!!.jsonPrimitive.long + val method = obj["method"]!!.jsonPrimitive.content + if (method == "Network.enable") { + // Fire the event in the exact window where the domain is being enabled. + transport.deliver(requestWillBeSentFrame("req-1")) + } + transport.deliver("""{"id":$id,"result":{}}""") + } + + try { + val event = tab.expect(urlPattern = null) { + withTimeout(2_000) { getRequestEvent() } + } + assertEquals("req-1", event.requestId) + } finally { + tab.close() + } + } + + private companion object { + val STUB_TARGET = Target.TargetInfo( + targetId = "stub", + type = "page", + title = "", + url = "about:blank", + attached = true, + canAccessOpener = false, + ) + } +}