From 155ff1c38760989386ff305ebcd336d8fdd69432 Mon Sep 17 00:00:00 2001 From: Alfie Fresta Date: Sun, 7 Jun 2026 22:09:39 +0100 Subject: [PATCH] fix(cable): send shutdown on close and linger for late linking data close() now asks the connection loop to send a Shutdown control frame over the encrypted channel and waits for it to terminate before the channel is dropped. After a CTAP response the loop lingers briefly to capture a late linking Update before sending Shutdown and tearing down. --- libwebauthn/Cargo.toml | 2 + libwebauthn/src/transport/cable/channel.rs | 10 +- .../src/transport/cable/connection_stages.rs | 3 + .../src/transport/cable/known_devices.rs | 3 + libwebauthn/src/transport/cable/protocol.rs | 371 +++++++++++++++++- .../src/transport/cable/qr_code_device.rs | 3 + 6 files changed, 384 insertions(+), 8 deletions(-) diff --git a/libwebauthn/Cargo.toml b/libwebauthn/Cargo.toml index 80b21c06..251909be 100644 --- a/libwebauthn/Cargo.toml +++ b/libwebauthn/Cargo.toml @@ -118,6 +118,8 @@ reqwest = { version = "0.12", default-features = false, features = [ [dev-dependencies] tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } qrcode = "0.14.1" +# test-util enables paused time for deterministic timeout/linger tests +tokio = { version = "1.45", features = ["test-util"] } # For turning on logging in unittests test-log = { version = "0.2" } diff --git a/libwebauthn/src/transport/cable/channel.rs b/libwebauthn/src/transport/cable/channel.rs index 58176e92..fa39b67a 100644 --- a/libwebauthn/src/transport/cable/channel.rs +++ b/libwebauthn/src/transport/cable/channel.rs @@ -47,6 +47,7 @@ pub struct CableChannel { pub(crate) ux_update_sender: broadcast::Sender, pub(crate) connection_state_receiver: watch::Receiver, pub(crate) persistent_token_store: Option>, + pub(crate) close_sender: Option>, } impl CableChannel { @@ -137,7 +138,14 @@ impl Channel for CableChannel { } async fn close(&mut self) { - // TODO Send CableTunnelMessageType#Shutdown and drop the connection + // Signal the loop to send Shutdown, then wait for it to flush and terminate. + if let Some(close_sender) = self.close_sender.take() { + let _ = close_sender.send(()).await; + } + let mut connection_state = self.connection_state_receiver.clone(); + let _ = connection_state + .wait_for(|state| *state == ConnectionState::Terminated) + .await; } async fn apdu_send(&mut self, _request: &ApduRequest, _timeout: Duration) -> Result<(), Error> { diff --git a/libwebauthn/src/transport/cable/connection_stages.rs b/libwebauthn/src/transport/cable/connection_stages.rs index 24ed19f7..1b9e72e0 100644 --- a/libwebauthn/src/transport/cable/connection_stages.rs +++ b/libwebauthn/src/transport/cable/connection_stages.rs @@ -198,6 +198,7 @@ pub(crate) struct TunnelConnectionInput { pub noise_state: TunnelNoiseState, pub cbor_tx_recv: mpsc::Receiver, pub cbor_rx_send: mpsc::Sender, + pub close_rx: mpsc::Receiver<()>, } impl TunnelConnectionInput { @@ -206,6 +207,7 @@ impl TunnelConnectionInput { known_device_store: Option>, cbor_tx_recv: mpsc::Receiver, cbor_rx_send: mpsc::Sender, + close_rx: mpsc::Receiver<()>, ) -> Self { Self { connection_type: handshake_output.connection_type, @@ -215,6 +217,7 @@ impl TunnelConnectionInput { noise_state: handshake_output.noise_state, cbor_tx_recv, cbor_rx_send, + close_rx, } } } diff --git a/libwebauthn/src/transport/cable/known_devices.rs b/libwebauthn/src/transport/cable/known_devices.rs index 3eadfb03..2ca69dee 100644 --- a/libwebauthn/src/transport/cable/known_devices.rs +++ b/libwebauthn/src/transport/cable/known_devices.rs @@ -198,6 +198,7 @@ impl<'d> Device<'d, Cable, CableChannel> for CableKnownDevice { let (ux_update_sender, _) = broadcast::channel(16); let (cbor_tx_send, cbor_tx_recv) = mpsc::channel(16); let (cbor_rx_send, cbor_rx_recv) = mpsc::channel(16); + let (close_sender, close_rx) = mpsc::channel(1); let (connection_state_sender, connection_state_receiver) = watch::channel(ConnectionState::Connecting); @@ -225,6 +226,7 @@ impl<'d> Device<'d, Cable, CableChannel> for CableKnownDevice { Some(known_device.store), cbor_tx_recv, cbor_rx_send, + close_rx, ); match protocol::connection(tunnel_input).await { @@ -247,6 +249,7 @@ impl<'d> Device<'d, Cable, CableChannel> for CableKnownDevice { ux_update_sender, connection_state_receiver, persistent_token_store: settings.persistent_token_store, + close_sender: Some(close_sender), }) } } diff --git a/libwebauthn/src/transport/cable/protocol.rs b/libwebauthn/src/transport/cable/protocol.rs index ede4795c..4f4c05df 100644 --- a/libwebauthn/src/transport/cable/protocol.rs +++ b/libwebauthn/src/transport/cable/protocol.rs @@ -2,6 +2,7 @@ //! hybrid transport. Runs over any [`CableDataChannel`]. use std::collections::BTreeMap; use std::sync::Arc; +use std::time::Duration; use hmac::{Hmac, Mac}; use p256::{ecdh, NonZeroScalar}; @@ -28,6 +29,9 @@ const P256_X962_LENGTH: usize = 65; const MAX_CBOR_SIZE: usize = 1024 * 1024; const PADDING_GRANULARITY: usize = 32; +/// How long to linger after the CTAP response to capture a late linking Update. +const POST_RESPONSE_LINGER: Duration = Duration::from_secs(120); + const CABLE_PROLOGUE_STATE_ASSISTED: &[u8] = &[0u8]; const CABLE_PROLOGUE_QR_INITIATED: &[u8] = &[1u8]; @@ -118,6 +122,8 @@ enum CableTunnelMessageType { enum RecvOutcome { /// Frame handled, keep the loop running. Continue, + /// A CTAP response was delivered to the client, arming the linger window. + ResponseDelivered, /// Peer sent a `Shutdown` control message; close the channel cleanly. PeerShutdown, } @@ -295,6 +301,10 @@ pub(crate) async fn connection(mut input: TunnelConnectionInput) -> Result<(), T }; debug!(?get_info_response_serialized, "Received initial message"); + let linger = tokio::time::sleep(POST_RESPONSE_LINGER); + tokio::pin!(linger); + let mut lingering = false; + loop { tokio::select! { result = input.data_channel.recv() => { @@ -313,6 +323,12 @@ pub(crate) async fn connection(mut input: TunnelConnectionInput) -> Result<(), T .await { Ok(RecvOutcome::Continue) => {} + Ok(RecvOutcome::ResponseDelivered) => { + linger + .as_mut() + .reset(tokio::time::Instant::now() + POST_RESPONSE_LINGER); + lingering = true; + } Ok(RecvOutcome::PeerShutdown) => return Ok(()), Err(e) => { error!(?e, "Fatal error processing inbound frame"); @@ -330,6 +346,24 @@ pub(crate) async fn connection(mut input: TunnelConnectionInput) -> Result<(), T } } } + _ = input.close_rx.recv() => { + debug!("Channel close requested, sending Shutdown control frame"); + if let Err(e) = + connection_send_shutdown(&mut *input.data_channel, &mut input.noise_state).await + { + warn!(?e, "Failed to send Shutdown control frame on close"); + } + return Ok(()); + } + _ = &mut linger, if lingering => { + debug!("Linger window elapsed, sending Shutdown control frame"); + if let Err(e) = + connection_send_shutdown(&mut *input.data_channel, &mut input.noise_state).await + { + warn!(?e, "Failed to send Shutdown control frame after linger"); + } + return Ok(()); + } Some(request) = input.cbor_tx_recv.recv() => { match request.command { // Optimisation: respond to GetInfo requests immediately with the cached response @@ -342,6 +376,9 @@ pub(crate) async fn connection(mut input: TunnelConnectionInput) -> Result<(), T } } _ => { + // A new request is in flight; stop the linger so a slow + // response (eg. on-device UV) is not torn down early. + lingering = false; debug!(?request.command, "Sending CBOR request"); if let Err(e) = connection_send( request, @@ -380,16 +417,47 @@ async fn connection_send( } trace!(?cbor_request, cbor_request_len = cbor_request.len()); - let extra_bytes = PADDING_GRANULARITY - (cbor_request.len() % PADDING_GRANULARITY); - let padded_len = cbor_request.len() + extra_bytes; + send_tunnel_frame( + CableTunnelMessageType::Ctap, + &cbor_request, + data_channel, + noise_state, + ) + .await +} + +/// Sends an empty `Shutdown` control frame over the encrypted channel. +async fn connection_send_shutdown( + data_channel: &mut dyn CableDataChannel, + noise_state: &mut TunnelNoiseState, +) -> Result<(), TransportError> { + debug!("Sending Shutdown control frame"); + send_tunnel_frame( + CableTunnelMessageType::Shutdown, + &[], + data_channel, + noise_state, + ) + .await +} + +/// Pads `payload`, wraps it in a `CableTunnelMessage`, encrypts it, and sends it. +async fn send_tunnel_frame( + message_type: CableTunnelMessageType, + payload: &[u8], + data_channel: &mut dyn CableDataChannel, + noise_state: &mut TunnelNoiseState, +) -> Result<(), TransportError> { + let extra_bytes = PADDING_GRANULARITY - (payload.len() % PADDING_GRANULARITY); + let padded_len = payload.len() + extra_bytes; - let mut padded_cbor_request = cbor_request.clone(); - padded_cbor_request.resize(padded_len, 0u8); - if let Some(last) = padded_cbor_request.last_mut() { + let mut padded_payload = payload.to_vec(); + padded_payload.resize(padded_len, 0u8); + if let Some(last) = padded_payload.last_mut() { *last = (extra_bytes - 1) as u8; } - let frame = CableTunnelMessage::new(CableTunnelMessageType::Ctap, &padded_cbor_request); + let frame = CableTunnelMessage::new(message_type, &padded_payload); let frame_serialized = frame.to_vec(); trace!(?frame_serialized); @@ -589,7 +657,7 @@ async fn connection_recv( .send(cbor_response) .await .or(Err(TransportError::ConnectionFailed))?; - Ok(RecvOutcome::Continue) + Ok(RecvOutcome::ResponseDelivered) } CableTunnelMessageType::Update => { // Malformed or unsigned update: log, drop the update, keep the channel. @@ -709,4 +777,293 @@ mod tests { let stripped = strip_frame_padding(frame).unwrap(); assert_eq!(stripped, vec![0xAA, 0xBB, 0xCC, 0xDD]); } + + use async_trait::async_trait; + use rand::rngs::OsRng; + use serde_indexed::SerializeIndexed; + use tokio::sync::mpsc; + + /// In-memory data channel: records outbound frames and replays queued inbound ones. + struct TestDataChannel { + inbound: mpsc::UnboundedReceiver>, + outbound: mpsc::UnboundedSender>, + } + + #[async_trait] + impl CableDataChannel for TestDataChannel { + async fn send(&mut self, message: &[u8]) -> Result<(), TransportError> { + let _ = self.outbound.send(message.to_vec()); + Ok(()) + } + + async fn recv(&mut self) -> Result>, TransportError> { + Ok(self.inbound.recv().await) + } + } + + /// Two Noise transport states that can encrypt/decrypt to each other. + fn paired_transport_states() -> (TransportState, TransportState) { + let mut initiator = Builder::new("Noise_NN_P256_AESGCM_SHA256".parse().unwrap()) + .build_initiator() + .unwrap(); + let mut responder = Builder::new("Noise_NN_P256_AESGCM_SHA256".parse().unwrap()) + .build_responder() + .unwrap(); + let mut a = [0u8; 1024]; + let mut b = [0u8; 1024]; + let n = initiator.write_message(&[], &mut a).unwrap(); + responder.read_message(&a[..n], &mut b).unwrap(); + let n = responder.write_message(&[], &mut a).unwrap(); + initiator.read_message(&a[..n], &mut b).unwrap(); + ( + initiator.into_transport_mode().unwrap(), + responder.into_transport_mode().unwrap(), + ) + } + + fn pad(mut payload: Vec) -> Vec { + let extra = PADDING_GRANULARITY - (payload.len() % PADDING_GRANULARITY); + let new_len = payload.len() + extra; + payload.resize(new_len, 0u8); + *payload.last_mut().unwrap() = (extra - 1) as u8; + payload + } + + fn encrypt(state: &mut TransportState, plaintext: &[u8]) -> Vec { + let mut out = vec![0u8; plaintext.len() + 64]; + let n = state.write_message(plaintext, &mut out).unwrap(); + out.truncate(n); + out + } + + fn decrypt(state: &mut TransportState, ciphertext: &[u8]) -> Vec { + let mut out = vec![0u8; ciphertext.len() + 64]; + let n = state.read_message(ciphertext, &mut out).unwrap(); + out.truncate(n); + out + } + + #[derive(SerializeIndexed)] + struct TestInitialMessage { + #[serde(index = 0x01)] + info: ByteBuf, + } + + /// Encrypted initial post-handshake message carrying a minimal GetInfo. + fn encrypted_initial_message(responder: &mut TransportState) -> Vec { + let get_info = Ctap2GetInfoResponse { + versions: vec!["FIDO_2_0".to_string()], + aaguid: ByteBuf::from(vec![0u8; 16]), + ..Default::default() + }; + let initial = TestInitialMessage { + info: ByteBuf::from(cbor::to_vec(&get_info).unwrap()), + }; + encrypt(responder, &pad(cbor::to_vec(&initial).unwrap())) + } + + fn qr_connection_type() -> CableTunnelConnectionType { + CableTunnelConnectionType::QrCode { + routing_id: "000000".to_string(), + tunnel_id: "00000000000000000000000000000000".to_string(), + private_key: NonZeroScalar::random(&mut OsRng), + } + } + + /// Decrypts an outbound frame and returns its tunnel message type byte. + fn outbound_message_type(frame: &[u8], responder: &mut TransportState) -> u8 { + let stripped = strip_frame_padding(decrypt(responder, frame)).unwrap(); + *stripped.first().unwrap() + } + + #[tokio::test] + async fn connection_sends_shutdown_on_close() { + let (initiator, mut responder) = paired_transport_states(); + let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::>(); + let (outbound_tx, mut outbound_rx) = mpsc::unbounded_channel::>(); + + inbound_tx + .send(encrypted_initial_message(&mut responder)) + .unwrap(); + + let (cbor_tx_send, cbor_tx_recv) = mpsc::channel::(4); + let (cbor_rx_send, cbor_rx_recv) = mpsc::channel::(4); + let (close_tx, close_rx) = mpsc::channel::<()>(1); + + let input = TunnelConnectionInput { + connection_type: qr_connection_type(), + tunnel_domain: "cable.example.com".to_string(), + known_device_store: None, + data_channel: Box::new(TestDataChannel { + inbound: inbound_rx, + outbound: outbound_tx, + }), + noise_state: TunnelNoiseState { + transport_state: initiator, + handshake_hash: vec![0u8; 32], + }, + cbor_tx_recv, + cbor_rx_send, + close_rx, + }; + + let handle = tokio::spawn(connection(input)); + + close_tx.send(()).await.unwrap(); + + let frame = outbound_rx + .recv() + .await + .expect("a frame on the outbound path"); + assert_eq!( + outbound_message_type(&frame, &mut responder), + CableTunnelMessageType::Shutdown as u8 + ); + assert!(handle.await.unwrap().is_ok()); + + // Keep the channel ends alive until the loop has shut down. + drop((inbound_tx, cbor_tx_send, cbor_rx_recv, close_tx)); + } + + #[tokio::test(start_paused = true)] + async fn connection_lingers_then_sends_shutdown_after_response() { + let (initiator, mut responder) = paired_transport_states(); + let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::>(); + let (outbound_tx, mut outbound_rx) = mpsc::unbounded_channel::>(); + + inbound_tx + .send(encrypted_initial_message(&mut responder)) + .unwrap(); + // A CTAP response frame: [Ctap type byte][CTAP status OK], padded and encrypted. + let ctap_frame = encrypt( + &mut responder, + &pad(vec![CableTunnelMessageType::Ctap as u8, 0x00]), + ); + inbound_tx.send(ctap_frame).unwrap(); + + let (cbor_tx_send, cbor_tx_recv) = mpsc::channel::(4); + let (cbor_rx_send, mut cbor_rx_recv) = mpsc::channel::(4); + let (close_tx, close_rx) = mpsc::channel::<()>(1); + + let input = TunnelConnectionInput { + connection_type: qr_connection_type(), + tunnel_domain: "cable.example.com".to_string(), + known_device_store: None, + data_channel: Box::new(TestDataChannel { + inbound: inbound_rx, + outbound: outbound_tx, + }), + noise_state: TunnelNoiseState { + transport_state: initiator, + handshake_hash: vec![0u8; 32], + }, + cbor_tx_recv, + cbor_rx_send, + close_rx, + }; + + let handle = tokio::spawn(connection(input)); + + // The delivered CTAP response arms the linger window. + cbor_rx_recv.recv().await.expect("a CTAP response"); + + // With time paused, the runtime advances to the linger deadline, after + // which the loop emits a Shutdown frame instead of running indefinitely. + let frame = outbound_rx + .recv() + .await + .expect("a Shutdown frame after the linger window"); + assert_eq!( + outbound_message_type(&frame, &mut responder), + CableTunnelMessageType::Shutdown as u8 + ); + assert!(handle.await.unwrap().is_ok()); + + // Keep the channel ends alive until the linger has fired. + drop((inbound_tx, cbor_tx_send, close_tx)); + } + + #[tokio::test(start_paused = true)] + async fn linger_does_not_terminate_while_request_in_flight() { + let (initiator, mut responder) = paired_transport_states(); + let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::>(); + let (outbound_tx, mut outbound_rx) = mpsc::unbounded_channel::>(); + + inbound_tx + .send(encrypted_initial_message(&mut responder)) + .unwrap(); + // A first CTAP response arms the linger. + inbound_tx + .send(encrypt( + &mut responder, + &pad(vec![CableTunnelMessageType::Ctap as u8, 0x00]), + )) + .unwrap(); + + let (cbor_tx_send, cbor_tx_recv) = mpsc::channel::(4); + let (cbor_rx_send, mut cbor_rx_recv) = mpsc::channel::(4); + let (close_tx, close_rx) = mpsc::channel::<()>(1); + + let input = TunnelConnectionInput { + connection_type: qr_connection_type(), + tunnel_domain: "cable.example.com".to_string(), + known_device_store: None, + data_channel: Box::new(TestDataChannel { + inbound: inbound_rx, + outbound: outbound_tx, + }), + noise_state: TunnelNoiseState { + transport_state: initiator, + handshake_hash: vec![0u8; 32], + }, + cbor_tx_recv, + cbor_rx_send, + close_rx, + }; + + let mut handle = tokio::spawn(connection(input)); + + // Drain the first response, which armed the linger. + cbor_rx_recv.recv().await.expect("first CTAP response"); + + // The client sends a new request before the linger elapses. Draining its + // outbound frame proves the loop processed it and disarmed the linger. + cbor_tx_send + .send(CborRequest::new(Ctap2CommandCode::AuthenticatorClientPin)) + .await + .unwrap(); + let req_frame = outbound_rx.recv().await.expect("a request frame"); + assert_eq!( + outbound_message_type(&req_frame, &mut responder), + CableTunnelMessageType::Ctap as u8 + ); + + // Advance well past the old linger deadline. A bounded join then drives + // the loop: with the linger disarmed it stays alive (the join times out), + // whereas an armed linger would have sent Shutdown and returned. + tokio::time::advance(POST_RESPONSE_LINGER * 2).await; + let still_running = tokio::time::timeout(Duration::from_secs(1), &mut handle) + .await + .is_err(); + assert!( + still_running, + "connection terminated while a request was in flight" + ); + + // An explicit close still shuts the loop down cleanly. + close_tx.send(()).await.unwrap(); + let frame = outbound_rx.recv().await.expect("a Shutdown frame on close"); + assert_eq!( + outbound_message_type(&frame, &mut responder), + CableTunnelMessageType::Shutdown as u8 + ); + assert!(handle.await.unwrap().is_ok()); + + drop((inbound_tx, cbor_tx_send)); + } + + #[test] + fn post_response_linger_is_two_minutes() { + assert_eq!(POST_RESPONSE_LINGER, Duration::from_secs(120)); + } } diff --git a/libwebauthn/src/transport/cable/qr_code_device.rs b/libwebauthn/src/transport/cable/qr_code_device.rs index d3b065b9..d9cc6ec9 100644 --- a/libwebauthn/src/transport/cable/qr_code_device.rs +++ b/libwebauthn/src/transport/cable/qr_code_device.rs @@ -244,6 +244,7 @@ impl<'d> Device<'d, Cable, CableChannel> for CableQrCodeDevice { let (ux_update_sender, _) = broadcast::channel(16); let (cbor_tx_send, cbor_tx_recv) = mpsc::channel(16); let (cbor_rx_send, cbor_rx_recv) = mpsc::channel(16); + let (close_sender, close_rx) = mpsc::channel(1); let (connection_state_sender, connection_state_receiver) = watch::channel(ConnectionState::Connecting); @@ -271,6 +272,7 @@ impl<'d> Device<'d, Cable, CableChannel> for CableQrCodeDevice { qr_device.store, cbor_tx_recv, cbor_rx_send, + close_rx, ); match protocol::connection(tunnel_input).await { Ok(()) => { @@ -292,6 +294,7 @@ impl<'d> Device<'d, Cable, CableChannel> for CableQrCodeDevice { ux_update_sender, connection_state_receiver, persistent_token_store: settings.persistent_token_store, + close_sender: Some(close_sender), }) }