From 86c288bfb227e135637661d3134e6e1c7145eca2 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Wed, 3 Jun 2026 11:20:40 +0100 Subject: [PATCH] refactor(server): deduplicate test helpers and grpc utilities Remove three groups of copy-pasted code in openshell-server: 1. grpc/mod.rs had a private current_time_ms() wrapper identical to the one already exported from persistence/mod.rs. Remove the duplicate and update the three grpc sub-modules (policy, sandbox, service) to import directly from crate::persistence. 2. test_store() was repeated verbatim in seven #[cfg(test)] blocks. Promote a single canonical version to persistence/mod.rs (cfg-gated) and replace all copies with crate::persistence::test_store() calls or a thin Arc wrapper in supervisor_session. 3. grpc_client_mtls() and build_tls_root() were copy-pasted across edge_tunnel_auth.rs and multiplex_tls_integration.rs. Move both into the existing tests/common/mod.rs shared module and import from there. --- crates/openshell-server/src/grpc/mod.rs | 4 -- crates/openshell-server/src/grpc/policy.rs | 10 ++--- crates/openshell-server/src/grpc/provider.rs | 7 +--- crates/openshell-server/src/grpc/sandbox.rs | 5 +-- crates/openshell-server/src/grpc/service.rs | 2 +- crates/openshell-server/src/inference.rs | 4 +- .../openshell-server/src/persistence/mod.rs | 7 ++++ .../openshell-server/src/persistence/tests.rs | 8 +--- .../openshell-server/src/provider_refresh.rs | 8 +--- crates/openshell-server/src/ssh_sessions.rs | 4 +- .../src/supervisor_session.rs | 6 +-- crates/openshell-server/tests/common/mod.rs | 39 +++++++++++++++++++ .../tests/edge_tunnel_auth.rs | 37 +----------------- .../tests/multiplex_tls_integration.rs | 39 ++----------------- 14 files changed, 63 insertions(+), 117 deletions(-) diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index 32885a9a9..5947bb334 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -153,10 +153,6 @@ enum StoredSettingValue { // Utility // --------------------------------------------------------------------------- -fn current_time_ms() -> i64 { - openshell_core::time::now_ms() -} - /// Validate that object metadata is present and contains required fields. /// /// This is a crate-level helper that wraps the validation module's implementation. diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 2c8b2d336..380671f10 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -72,7 +72,8 @@ use tracing::{debug, info, warn}; use super::validation::{ level_matches, source_matches, validate_policy_safety, validate_static_fields_unchanged, }; -use super::{MAX_PAGE_SIZE, StoredSettingValue, StoredSettings, clamp_limit, current_time_ms}; +use super::{MAX_PAGE_SIZE, StoredSettingValue, StoredSettings, clamp_limit}; +use crate::persistence::current_time_ms; // --------------------------------------------------------------------------- // Constants @@ -3853,16 +3854,11 @@ mod tests { Principal, SandboxIdentitySource, SandboxPrincipal, UserPrincipal, }; use crate::grpc::test_support::test_server_state; + use crate::persistence::test_store; use std::collections::HashMap; use std::sync::Arc; use tonic::Code; - async fn test_store() -> Store { - Store::connect("sqlite::memory:?cache=shared") - .await - .expect("in-memory SQLite store should connect") - } - /// Wrap a request with a user `Principal` so handler scope guards treat /// the test caller as a CLI user. Most handler tests exercise /// user-facing behavior and should not trip sandbox equality checks. diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 4a869fbcd..7591bdd6b 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -1604,12 +1604,7 @@ mod tests { use super::*; use crate::grpc::MAX_MAP_KEY_LEN; use crate::grpc::test_support::test_server_state; - - async fn test_store() -> Store { - Store::connect("sqlite::memory:?cache=shared") - .await - .expect("in-memory SQLite store should connect") - } + use crate::persistence::test_store; use openshell_core::proto::{ DeleteProviderProfileRequest, GetProviderProfileRequest, ImportProviderProfilesRequest, L7Allow, L7Rule, LintProviderProfilesRequest, ListProviderProfilesRequest, NetworkBinary, diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 198d5f04c..e60ce3995 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -49,7 +49,8 @@ use super::validation::{ level_matches, source_matches, validate_exec_request_fields, validate_policy_safety, validate_sandbox_spec, }; -use super::{MAX_PAGE_SIZE, MAX_PROVIDERS, clamp_limit, current_time_ms}; +use super::{MAX_PAGE_SIZE, MAX_PROVIDERS, clamp_limit}; +use crate::persistence::current_time_ms; const TCP_FORWARD_CHUNK_SIZE: usize = 64 * 1024; @@ -117,8 +118,6 @@ async fn handle_create_sandbox_inner( state: &Arc, request: Request, ) -> Result, Status> { - use crate::persistence::current_time_ms; - let request = request.into_inner(); let spec = request .spec diff --git a/crates/openshell-server/src/grpc/service.rs b/crates/openshell-server/src/grpc/service.rs index a8144b4c7..246d639be 100644 --- a/crates/openshell-server/src/grpc/service.rs +++ b/crates/openshell-server/src/grpc/service.rs @@ -39,7 +39,7 @@ pub(super) async fn handle_expose_service( .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? .ok_or_else(|| Status::not_found("sandbox not found"))?; - let now = super::current_time_ms(); + let now = crate::persistence::current_time_ms(); let key = service_routing::endpoint_key(&req.sandbox, &req.service); // Fetch existing endpoint to determine create vs. update path diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index 11547c620..58b5feb2a 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -934,9 +934,7 @@ mod tests { use wiremock::{Mock, MockServer, ResponseTemplate}; async fn test_store() -> Store { - Store::connect("sqlite::memory:?cache=shared") - .await - .expect("in-memory SQLite store should connect") + crate::persistence::test_store().await } fn test_user_principal() -> Principal { diff --git a/crates/openshell-server/src/persistence/mod.rs b/crates/openshell-server/src/persistence/mod.rs index 6aa2c3bc7..aad8b39c9 100644 --- a/crates/openshell-server/src/persistence/mod.rs +++ b/crates/openshell-server/src/persistence/mod.rs @@ -687,5 +687,12 @@ impl Store { } } +#[cfg(test)] +pub async fn test_store() -> Store { + Store::connect("sqlite::memory:?cache=shared") + .await + .expect("in-memory SQLite store should connect") +} + #[cfg(test)] mod tests; diff --git a/crates/openshell-server/src/persistence/tests.rs b/crates/openshell-server/src/persistence/tests.rs index bd14c39f9..d092b68de 100644 --- a/crates/openshell-server/src/persistence/tests.rs +++ b/crates/openshell-server/src/persistence/tests.rs @@ -1,17 +1,11 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use super::{ObjectType, PersistenceError, Store, generate_name}; +use super::{ObjectType, PersistenceError, Store, generate_name, test_store}; use crate::policy_store::PolicyStoreExt; use openshell_core::proto::{ObjectForTest, SandboxPolicy}; use prost::Message; -async fn test_store() -> Store { - Store::connect("sqlite::memory:?cache=shared") - .await - .expect("in-memory SQLite store should connect") -} - #[tokio::test] async fn sqlite_put_get_round_trip() { let store = test_store().await; diff --git a/crates/openshell-server/src/provider_refresh.rs b/crates/openshell-server/src/provider_refresh.rs index 161daeb7f..b0b9a927c 100644 --- a/crates/openshell-server/src/provider_refresh.rs +++ b/crates/openshell-server/src/provider_refresh.rs @@ -776,7 +776,7 @@ mod tests { refresh_provider_credential, refresh_state_name, refresh_strategy_name, run_refresh_worker_tick, seconds_until_ms, }; - use crate::persistence::Store; + use crate::persistence::test_store; use openshell_core::ObjectId; use openshell_core::proto::datamodel::v1::ObjectMeta; use openshell_core::proto::{ @@ -786,12 +786,6 @@ mod tests { use wiremock::matchers::{body_string_contains, method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; - async fn test_store() -> Store { - Store::connect("sqlite::memory:?cache=shared") - .await - .expect("in-memory SQLite store should connect") - } - #[test] fn refresh_state_name_preserves_distinct_credential_keys() { let provider_id = "provider-id"; diff --git a/crates/openshell-server/src/ssh_sessions.rs b/crates/openshell-server/src/ssh_sessions.rs index 3f1f24a7d..e55c1b6af 100644 --- a/crates/openshell-server/src/ssh_sessions.rs +++ b/crates/openshell-server/src/ssh_sessions.rs @@ -84,9 +84,7 @@ mod tests { use std::collections::HashMap; async fn test_store() -> Store { - Store::connect("sqlite::memory:?cache=shared") - .await - .expect("in-memory SQLite store should connect") + crate::persistence::test_store().await } fn make_session(id: &str, sandbox_id: &str, expires_at_ms: i64, revoked: bool) -> SshSession { diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs index 77734f929..4adf9e8b6 100644 --- a/crates/openshell-server/src/supervisor_session.rs +++ b/crates/openshell-server/src/supervisor_session.rs @@ -823,11 +823,7 @@ mod tests { use tokio::io::{AsyncReadExt, AsyncWriteExt}; async fn test_store() -> Arc { - Arc::new( - Store::connect("sqlite::memory:?cache=shared") - .await - .expect("in-memory SQLite store should connect"), - ) + Arc::new(crate::persistence::test_store().await) } /// Returns a shutdown sender with its receiver immediately dropped. Tests diff --git a/crates/openshell-server/tests/common/mod.rs b/crates/openshell-server/tests/common/mod.rs index 563eaa238..00228b043 100644 --- a/crates/openshell-server/tests/common/mod.rs +++ b/crates/openshell-server/tests/common/mod.rs @@ -25,10 +25,14 @@ use openshell_core::proto::{ RefreshSandboxTokenRequest, RefreshSandboxTokenResponse, RelayFrame, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, TcpForwardFrame, UpdateProviderRequest, WatchSandboxRequest, + open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; use openshell_server::{MultiplexedService, Store, TlsAcceptor, health_router}; use rcgen::{CertificateParams, IsCa, KeyPair}; +use rustls::RootCertStore; +use rustls::pki_types::CertificateDer; +use rustls_pemfile::certs; use std::io::Write; use std::net::SocketAddr; use std::sync::Arc; @@ -36,6 +40,7 @@ use tempfile::tempdir; use tokio::net::TcpListener; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; +use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; use tonic::{Response, Status}; // --------------------------------------------------------------------------- @@ -620,3 +625,37 @@ pub async fn test_health_store() -> Arc { .expect("connect in-memory sqlite store for tests"), ) } + +/// Parse PEM cert bytes into a `RootCertStore`. +pub fn build_tls_root(cert_pem: &[u8]) -> RootCertStore { + let mut roots = RootCertStore::empty(); + let mut cursor = std::io::Cursor::new(cert_pem); + let parsed = certs(&mut cursor) + .collect::>, _>>() + .expect("failed to parse cert pem"); + for cert in parsed { + roots.add(cert).expect("failed to add cert"); + } + roots +} + +/// Build a gRPC client with mTLS (CA + client cert). +pub async fn grpc_client_mtls( + addr: SocketAddr, + ca_pem: Vec, + client_cert_pem: Vec, + client_key_pem: Vec, +) -> OpenShellClient { + let ca_cert = tonic::transport::Certificate::from_pem(ca_pem); + let identity = tonic::transport::Identity::from_pem(client_cert_pem, client_key_pem); + let tls = ClientTlsConfig::new() + .ca_certificate(ca_cert) + .identity(identity) + .domain_name("localhost"); + let endpoint = Endpoint::from_shared(format!("https://localhost:{}", addr.port())) + .expect("invalid endpoint") + .tls_config(tls) + .expect("failed to set tls"); + let channel = endpoint.connect().await.expect("failed to connect"); + OpenShellClient::new(channel) +} diff --git a/crates/openshell-server/tests/edge_tunnel_auth.rs b/crates/openshell-server/tests/edge_tunnel_auth.rs index c49953b23..4df221f11 100644 --- a/crates/openshell-server/tests/edge_tunnel_auth.rs +++ b/crates/openshell-server/tests/edge_tunnel_auth.rs @@ -28,7 +28,8 @@ mod common; use bytes::Bytes; use common::{ - PkiBundle, generate_pki, generate_rogue_pki, install_rustls_provider, start_test_server, + PkiBundle, build_tls_root, generate_pki, generate_rogue_pki, grpc_client_mtls, + install_rustls_provider, start_test_server, }; use http_body_util::Empty; use hyper::{Request, StatusCode}; @@ -36,7 +37,6 @@ use hyper_rustls::HttpsConnectorBuilder; use hyper_util::{client::legacy::Client, rt::TokioExecutor}; use openshell_core::proto::{HealthRequest, ServiceStatus, open_shell_client::OpenShellClient}; use openshell_server::TlsAcceptor; -use rustls::RootCertStore; use rustls::pki_types::CertificateDer; use rustls_pemfile::certs; use tonic::Status; @@ -46,27 +46,6 @@ use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; // Client helpers // --------------------------------------------------------------------------- -/// Build a gRPC client with mTLS (CA + client cert). -async fn grpc_client_mtls( - addr: std::net::SocketAddr, - ca_pem: Vec, - client_cert_pem: Vec, - client_key_pem: Vec, -) -> OpenShellClient { - let ca_cert = tonic::transport::Certificate::from_pem(ca_pem); - let identity = tonic::transport::Identity::from_pem(client_cert_pem, client_key_pem); - let tls = ClientTlsConfig::new() - .ca_certificate(ca_cert) - .identity(identity) - .domain_name("localhost"); - let endpoint = Endpoint::from_shared(format!("https://localhost:{}", addr.port())) - .expect("invalid endpoint") - .tls_config(tls) - .expect("failed to set tls"); - let channel = endpoint.connect().await.expect("failed to connect"); - OpenShellClient::new(channel) -} - /// Build a gRPC client *without* a client cert (simulates Cloudflare tunnel). async fn grpc_client_no_cert( addr: std::net::SocketAddr, @@ -123,18 +102,6 @@ impl tonic::service::Interceptor for CfInterceptor { } } -fn build_tls_root(cert_pem: &[u8]) -> RootCertStore { - let mut roots = RootCertStore::empty(); - let mut cursor = std::io::Cursor::new(cert_pem); - let parsed = certs(&mut cursor) - .collect::>, _>>() - .expect("failed to parse cert pem"); - for cert in parsed { - roots.add(cert).expect("failed to add cert"); - } - roots -} - /// Build an HTTPS client with mTLS. fn https_client_mtls( pki: &PkiBundle, diff --git a/crates/openshell-server/tests/multiplex_tls_integration.rs b/crates/openshell-server/tests/multiplex_tls_integration.rs index 31ece9ed6..4e17fdef9 100644 --- a/crates/openshell-server/tests/multiplex_tls_integration.rs +++ b/crates/openshell-server/tests/multiplex_tls_integration.rs @@ -5,7 +5,8 @@ mod common; use bytes::Bytes; use common::{ - PkiBundle, generate_pki, generate_rogue_pki, install_rustls_provider, start_test_server, + PkiBundle, build_tls_root, generate_pki, generate_rogue_pki, grpc_client_mtls, + install_rustls_provider, start_test_server, }; use http_body_util::Empty; use hyper::Request; @@ -13,43 +14,9 @@ use hyper::StatusCode; use hyper_rustls::HttpsConnectorBuilder; use hyper_util::{client::legacy::Client, rt::TokioExecutor}; use openshell_core::proto::{HealthRequest, ServiceStatus, open_shell_client::OpenShellClient}; -use rustls::RootCertStore; use rustls::pki_types::CertificateDer; use rustls_pemfile::certs; -use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; - -fn build_tls_root(cert_pem: &[u8]) -> RootCertStore { - let mut roots = RootCertStore::empty(); - let mut cursor = std::io::Cursor::new(cert_pem); - let parsed = certs(&mut cursor) - .collect::>, _>>() - .expect("failed to parse cert pem"); - for cert in parsed { - roots.add(cert).expect("failed to add cert"); - } - roots -} - -/// Build a gRPC client with mTLS (CA + client cert). -async fn grpc_client_mtls( - addr: std::net::SocketAddr, - ca_pem: Vec, - client_cert_pem: Vec, - client_key_pem: Vec, -) -> OpenShellClient { - let ca_cert = tonic::transport::Certificate::from_pem(ca_pem); - let identity = tonic::transport::Identity::from_pem(client_cert_pem, client_key_pem); - let tls = ClientTlsConfig::new() - .ca_certificate(ca_cert) - .identity(identity) - .domain_name("localhost"); - let endpoint = Endpoint::from_shared(format!("https://localhost:{}", addr.port())) - .expect("invalid endpoint") - .tls_config(tls) - .expect("failed to set tls"); - let channel = endpoint.connect().await.expect("failed to connect"); - OpenShellClient::new(channel) -} +use tonic::transport::{ClientTlsConfig, Endpoint}; /// Build an HTTPS client with mTLS (CA trust + client cert/key). fn https_client_mtls(