diff --git a/peeroxide-dht/src/secret_stream.rs b/peeroxide-dht/src/secret_stream.rs index 7856784..8c7001a 100644 --- a/peeroxide-dht/src/secret_stream.rs +++ b/peeroxide-dht/src/secret_stream.rs @@ -23,7 +23,7 @@ use blake2::digest::{KeyInit, Mac}; use blake2::Blake2bMac; use blake2::digest::consts::U32; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tracing::debug; use crate::crypto; @@ -340,6 +340,106 @@ impl SecretStream { pub fn into_inner(self) -> T { self.raw } + + /// Split the stream into owned read and write halves. + /// + /// The underlying transport is split with [`tokio::io::split`]; the + /// decryptor (`Pull`) moves into the [`SecretStreamReadHalf`] and the + /// encryptor (`Push`) moves into the [`SecretStreamWriteHalf`]. Because the + /// two halves own disjoint cipher state, they can be driven from separate + /// tasks concurrently without either half cancelling the other's in-flight + /// framed read/write. Wire framing and cipher behavior are identical to the + /// unsplit [`SecretStream::read`]/[`SecretStream::write`]. + pub fn into_split( + self, + ) -> ( + SecretStreamReadHalf>, + SecretStreamWriteHalf>, + ) { + let (read_raw, write_raw) = tokio::io::split(self.raw); + ( + SecretStreamReadHalf { + raw: read_raw, + decrypt: self.decrypt, + }, + SecretStreamWriteHalf { + raw: write_raw, + encrypt: self.encrypt, + }, + ) + } +} + +/// Owned read half of a [`SecretStream`], produced by [`SecretStream::into_split`]. +/// +/// Holds the decryptor and the read side of the underlying transport. Its +/// [`read`](SecretStreamReadHalf::read) mirrors [`SecretStream::read`] exactly +/// (same length-prefixed framing, same empty-keepalive handling, same AEAD +/// decryption). +pub struct SecretStreamReadHalf { + raw: R, + decrypt: secretstream::Pull, +} + +impl SecretStreamReadHalf { + /// Read and decrypt the next framed message. + /// + /// Returns `Ok(None)` on clean EOF, `Ok(Some(plaintext))` for data + /// messages. Empty messages (keepalives) are silently consumed and the next + /// message is read. Identical framing/cipher semantics to + /// [`SecretStream::read`]. + pub async fn read(&mut self) -> Result>, SecretStreamError> { + loop { + let msg = match read_frame(&mut self.raw).await? { + Some(m) => m, + None => return Ok(None), + }; + + if msg.len() < ABYTES { + return Err(SecretStreamError::Decrypt( + secretstream::SecretstreamError::CiphertextTooShort, + )); + } + + let (plaintext, _tag) = self.decrypt.next(&msg)?; + + // Drop empty keepalive messages (match Node.js behavior). + if plaintext.is_empty() { + continue; + } + + return Ok(Some(plaintext)); + } + } +} + +/// Owned write half of a [`SecretStream`], produced by [`SecretStream::into_split`]. +/// +/// Holds the encryptor and the write side of the underlying transport. Its +/// [`write`](SecretStreamWriteHalf::write) mirrors [`SecretStream::write`] +/// exactly. +pub struct SecretStreamWriteHalf { + raw: W, + encrypt: secretstream::Push, +} + +impl SecretStreamWriteHalf { + /// Encrypt and send `data` as a single framed message. + /// + /// Identical framing/cipher semantics to [`SecretStream::write`]. + pub async fn write(&mut self, data: &[u8]) -> Result<(), SecretStreamError> { + let encrypted = self.encrypt.next(data); + // encrypted = [enc_tag(1)][ciphertext][mac(16)] = data.len() + ABYTES bytes + debug_assert_eq!(encrypted.len(), data.len() + ABYTES); + write_frame(&mut self.raw, &encrypted).await + } + + /// Gracefully close the write half, sending a FIN to the remote peer. + /// + /// Identical to [`SecretStream::shutdown`]. + pub async fn shutdown(&mut self) -> Result<(), SecretStreamError> { + self.raw.shutdown().await.map_err(SecretStreamError::Io) + } } // ── FramedStream adapter ───────────────────────────────────────────────────── @@ -577,4 +677,72 @@ mod tests { assert!(client.is_initiator()); assert!(!server.is_initiator()); } + + #[tokio::test] + async fn into_split_roundtrip() { + // Split both peers into read/write halves and exchange in both + // directions concurrently. Proves the split preserves wire framing and + // AEAD cipher behavior identically to the unsplit read()/write(). + let (client_stream, server_stream) = tokio::io::duplex(8192); + + let (client, server) = tokio::try_join!( + SecretStream::new(true, client_stream, noise::generate_keypair()), + SecretStream::new(false, server_stream, noise::generate_keypair()), + ) + .unwrap(); + + let (mut client_rd, mut client_wr) = client.into_split(); + let (mut server_rd, mut server_wr) = server.into_split(); + + // client → server + client_wr.write(b"hello from client").await.unwrap(); + let msg = server_rd.read().await.unwrap().expect("expected message"); + assert_eq!(msg, b"hello from client"); + + // server → client + server_wr.write(b"hello from server").await.unwrap(); + let msg = client_rd.read().await.unwrap().expect("expected message"); + assert_eq!(msg, b"hello from server"); + + // Concurrent bidirectional bulk: writer and reader in separate tasks, + // read_exact is never cancelled — the exact property the split provides. + let writer = tokio::spawn(async move { + for i in 0..100 { + let payload = format!("bulk {i}"); + client_wr.write(payload.as_bytes()).await.unwrap(); + } + client_wr.shutdown().await.unwrap(); + }); + let reader = tokio::spawn(async move { + let mut count = 0; + while let Some(msg) = server_rd.read().await.unwrap() { + assert_eq!(msg, format!("bulk {count}").as_bytes()); + count += 1; + } + count + }); + writer.await.unwrap(); + assert_eq!(reader.await.unwrap(), 100); + } + + #[tokio::test] + async fn into_split_empty_keepalive() { + // Empty messages are consumed as keepalives by the split read half, + // matching SecretStream::read. + let (client_stream, server_stream) = tokio::io::duplex(8192); + + let (client, server) = tokio::try_join!( + SecretStream::new(true, client_stream, noise::generate_keypair()), + SecretStream::new(false, server_stream, noise::generate_keypair()), + ) + .unwrap(); + + let (_client_rd, mut client_wr) = client.into_split(); + let (mut server_rd, _server_wr) = server.into_split(); + + client_wr.write(b"").await.unwrap(); + client_wr.write(b"after empty").await.unwrap(); + let msg = server_rd.read().await.unwrap().expect("expected message"); + assert_eq!(msg, b"after empty"); + } }