Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions crates/openshell-server/src/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,6 @@ enum StoredSettingValue {
// Utility
// ---------------------------------------------------------------------------

fn current_time_ms() -> i64 {
openshell_core::time::now_ms()
}

/// Validate that object metadata is present and contains required fields.
///
/// This is a crate-level helper that wraps the validation module's implementation.
Expand Down
10 changes: 3 additions & 7 deletions crates/openshell-server/src/grpc/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ use tracing::{debug, info, warn};
use super::validation::{
level_matches, source_matches, validate_policy_safety, validate_static_fields_unchanged,
};
use super::{MAX_PAGE_SIZE, StoredSettingValue, StoredSettings, clamp_limit, current_time_ms};
use super::{MAX_PAGE_SIZE, StoredSettingValue, StoredSettings, clamp_limit};
use crate::persistence::current_time_ms;

// ---------------------------------------------------------------------------
// Constants
Expand Down Expand Up @@ -3853,16 +3854,11 @@ mod tests {
Principal, SandboxIdentitySource, SandboxPrincipal, UserPrincipal,
};
use crate::grpc::test_support::test_server_state;
use crate::persistence::test_store;
use std::collections::HashMap;
use std::sync::Arc;
use tonic::Code;

async fn test_store() -> Store {
Store::connect("sqlite::memory:?cache=shared")
.await
.expect("in-memory SQLite store should connect")
}

/// Wrap a request with a user `Principal` so handler scope guards treat
/// the test caller as a CLI user. Most handler tests exercise
/// user-facing behavior and should not trip sandbox equality checks.
Expand Down
7 changes: 1 addition & 6 deletions crates/openshell-server/src/grpc/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1604,12 +1604,7 @@ mod tests {
use super::*;
use crate::grpc::MAX_MAP_KEY_LEN;
use crate::grpc::test_support::test_server_state;

async fn test_store() -> Store {
Store::connect("sqlite::memory:?cache=shared")
.await
.expect("in-memory SQLite store should connect")
}
use crate::persistence::test_store;
use openshell_core::proto::{
DeleteProviderProfileRequest, GetProviderProfileRequest, ImportProviderProfilesRequest,
L7Allow, L7Rule, LintProviderProfilesRequest, ListProviderProfilesRequest, NetworkBinary,
Expand Down
5 changes: 2 additions & 3 deletions crates/openshell-server/src/grpc/sandbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ use super::validation::{
level_matches, source_matches, validate_exec_request_fields, validate_policy_safety,
validate_sandbox_spec,
};
use super::{MAX_PAGE_SIZE, MAX_PROVIDERS, clamp_limit, current_time_ms};
use super::{MAX_PAGE_SIZE, MAX_PROVIDERS, clamp_limit};
use crate::persistence::current_time_ms;

const TCP_FORWARD_CHUNK_SIZE: usize = 64 * 1024;

Expand Down Expand Up @@ -117,8 +118,6 @@ async fn handle_create_sandbox_inner(
state: &Arc<ServerState>,
request: Request<CreateSandboxRequest>,
) -> Result<Response<SandboxResponse>, Status> {
use crate::persistence::current_time_ms;

let request = request.into_inner();
let spec = request
.spec
Expand Down
2 changes: 1 addition & 1 deletion crates/openshell-server/src/grpc/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub(super) async fn handle_expose_service(
.map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?
.ok_or_else(|| Status::not_found("sandbox not found"))?;

let now = super::current_time_ms();
let now = crate::persistence::current_time_ms();
let key = service_routing::endpoint_key(&req.sandbox, &req.service);

// Fetch existing endpoint to determine create vs. update path
Expand Down
4 changes: 1 addition & 3 deletions crates/openshell-server/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -934,9 +934,7 @@ mod tests {
use wiremock::{Mock, MockServer, ResponseTemplate};

async fn test_store() -> Store {
Store::connect("sqlite::memory:?cache=shared")
.await
.expect("in-memory SQLite store should connect")
crate::persistence::test_store().await
}

fn test_user_principal() -> Principal {
Expand Down
7 changes: 7 additions & 0 deletions crates/openshell-server/src/persistence/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,5 +687,12 @@ impl Store {
}
}

#[cfg(test)]
pub async fn test_store() -> Store {
Store::connect("sqlite::memory:?cache=shared")
.await
.expect("in-memory SQLite store should connect")
}

#[cfg(test)]
mod tests;
8 changes: 1 addition & 7 deletions crates/openshell-server/src/persistence/tests.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::{ObjectType, PersistenceError, Store, generate_name};
use super::{ObjectType, PersistenceError, Store, generate_name, test_store};
use crate::policy_store::PolicyStoreExt;
use openshell_core::proto::{ObjectForTest, SandboxPolicy};
use prost::Message;

async fn test_store() -> Store {
Store::connect("sqlite::memory:?cache=shared")
.await
.expect("in-memory SQLite store should connect")
}

#[tokio::test]
async fn sqlite_put_get_round_trip() {
let store = test_store().await;
Expand Down
8 changes: 1 addition & 7 deletions crates/openshell-server/src/provider_refresh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ mod tests {
refresh_provider_credential, refresh_state_name, refresh_strategy_name,
run_refresh_worker_tick, seconds_until_ms,
};
use crate::persistence::Store;
use crate::persistence::test_store;
use openshell_core::ObjectId;
use openshell_core::proto::datamodel::v1::ObjectMeta;
use openshell_core::proto::{
Expand All @@ -786,12 +786,6 @@ mod tests {
use wiremock::matchers::{body_string_contains, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

async fn test_store() -> Store {
Store::connect("sqlite::memory:?cache=shared")
.await
.expect("in-memory SQLite store should connect")
}

#[test]
fn refresh_state_name_preserves_distinct_credential_keys() {
let provider_id = "provider-id";
Expand Down
4 changes: 1 addition & 3 deletions crates/openshell-server/src/ssh_sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ mod tests {
use std::collections::HashMap;

async fn test_store() -> Store {
Store::connect("sqlite::memory:?cache=shared")
.await
.expect("in-memory SQLite store should connect")
crate::persistence::test_store().await
}

fn make_session(id: &str, sandbox_id: &str, expires_at_ms: i64, revoked: bool) -> SshSession {
Expand Down
6 changes: 1 addition & 5 deletions crates/openshell-server/src/supervisor_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,11 +823,7 @@ mod tests {
use tokio::io::{AsyncReadExt, AsyncWriteExt};

async fn test_store() -> Arc<Store> {
Arc::new(
Store::connect("sqlite::memory:?cache=shared")
.await
.expect("in-memory SQLite store should connect"),
)
Arc::new(crate::persistence::test_store().await)
}

/// Returns a shutdown sender with its receiver immediately dropped. Tests
Expand Down
39 changes: 39 additions & 0 deletions crates/openshell-server/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,22 @@ use openshell_core::proto::{
RefreshSandboxTokenRequest, RefreshSandboxTokenResponse, RelayFrame, RevokeSshSessionRequest,
RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus,
SupervisorMessage, TcpForwardFrame, UpdateProviderRequest, WatchSandboxRequest,
open_shell_client::OpenShellClient,
open_shell_server::{OpenShell, OpenShellServer},
};
use openshell_server::{MultiplexedService, Store, TlsAcceptor, health_router};
use rcgen::{CertificateParams, IsCa, KeyPair};
use rustls::RootCertStore;
use rustls::pki_types::CertificateDer;
use rustls_pemfile::certs;
use std::io::Write;
use std::net::SocketAddr;
use std::sync::Arc;
use tempfile::tempdir;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tonic::{Response, Status};

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -620,3 +625,37 @@ pub async fn test_health_store() -> Arc<Store> {
.expect("connect in-memory sqlite store for tests"),
)
}

/// Parse PEM cert bytes into a `RootCertStore`.
pub fn build_tls_root(cert_pem: &[u8]) -> RootCertStore {
let mut roots = RootCertStore::empty();
let mut cursor = std::io::Cursor::new(cert_pem);
let parsed = certs(&mut cursor)
.collect::<Result<Vec<CertificateDer<'static>>, _>>()
.expect("failed to parse cert pem");
for cert in parsed {
roots.add(cert).expect("failed to add cert");
}
roots
}

/// Build a gRPC client with mTLS (CA + client cert).
pub async fn grpc_client_mtls(
addr: SocketAddr,
ca_pem: Vec<u8>,
client_cert_pem: Vec<u8>,
client_key_pem: Vec<u8>,
) -> OpenShellClient<Channel> {
let ca_cert = tonic::transport::Certificate::from_pem(ca_pem);
let identity = tonic::transport::Identity::from_pem(client_cert_pem, client_key_pem);
let tls = ClientTlsConfig::new()
.ca_certificate(ca_cert)
.identity(identity)
.domain_name("localhost");
let endpoint = Endpoint::from_shared(format!("https://localhost:{}", addr.port()))
.expect("invalid endpoint")
.tls_config(tls)
.expect("failed to set tls");
let channel = endpoint.connect().await.expect("failed to connect");
OpenShellClient::new(channel)
}
37 changes: 2 additions & 35 deletions crates/openshell-server/tests/edge_tunnel_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ mod common;

use bytes::Bytes;
use common::{
PkiBundle, generate_pki, generate_rogue_pki, install_rustls_provider, start_test_server,
PkiBundle, build_tls_root, generate_pki, generate_rogue_pki, grpc_client_mtls,
install_rustls_provider, start_test_server,
};
use http_body_util::Empty;
use hyper::{Request, StatusCode};
use hyper_rustls::HttpsConnectorBuilder;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use openshell_core::proto::{HealthRequest, ServiceStatus, open_shell_client::OpenShellClient};
use openshell_server::TlsAcceptor;
use rustls::RootCertStore;
use rustls::pki_types::CertificateDer;
use rustls_pemfile::certs;
use tonic::Status;
Expand All @@ -46,27 +46,6 @@ use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
// Client helpers
// ---------------------------------------------------------------------------

/// Build a gRPC client with mTLS (CA + client cert).
async fn grpc_client_mtls(
addr: std::net::SocketAddr,
ca_pem: Vec<u8>,
client_cert_pem: Vec<u8>,
client_key_pem: Vec<u8>,
) -> OpenShellClient<Channel> {
let ca_cert = tonic::transport::Certificate::from_pem(ca_pem);
let identity = tonic::transport::Identity::from_pem(client_cert_pem, client_key_pem);
let tls = ClientTlsConfig::new()
.ca_certificate(ca_cert)
.identity(identity)
.domain_name("localhost");
let endpoint = Endpoint::from_shared(format!("https://localhost:{}", addr.port()))
.expect("invalid endpoint")
.tls_config(tls)
.expect("failed to set tls");
let channel = endpoint.connect().await.expect("failed to connect");
OpenShellClient::new(channel)
}

/// Build a gRPC client *without* a client cert (simulates Cloudflare tunnel).
async fn grpc_client_no_cert(
addr: std::net::SocketAddr,
Expand Down Expand Up @@ -123,18 +102,6 @@ impl tonic::service::Interceptor for CfInterceptor {
}
}

fn build_tls_root(cert_pem: &[u8]) -> RootCertStore {
let mut roots = RootCertStore::empty();
let mut cursor = std::io::Cursor::new(cert_pem);
let parsed = certs(&mut cursor)
.collect::<Result<Vec<CertificateDer<'static>>, _>>()
.expect("failed to parse cert pem");
for cert in parsed {
roots.add(cert).expect("failed to add cert");
}
roots
}

/// Build an HTTPS client with mTLS.
fn https_client_mtls(
pki: &PkiBundle,
Expand Down
39 changes: 3 additions & 36 deletions crates/openshell-server/tests/multiplex_tls_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,51 +5,18 @@ mod common;

use bytes::Bytes;
use common::{
PkiBundle, generate_pki, generate_rogue_pki, install_rustls_provider, start_test_server,
PkiBundle, build_tls_root, generate_pki, generate_rogue_pki, grpc_client_mtls,
install_rustls_provider, start_test_server,
};
use http_body_util::Empty;
use hyper::Request;
use hyper::StatusCode;
use hyper_rustls::HttpsConnectorBuilder;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use openshell_core::proto::{HealthRequest, ServiceStatus, open_shell_client::OpenShellClient};
use rustls::RootCertStore;
use rustls::pki_types::CertificateDer;
use rustls_pemfile::certs;
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};

fn build_tls_root(cert_pem: &[u8]) -> RootCertStore {
let mut roots = RootCertStore::empty();
let mut cursor = std::io::Cursor::new(cert_pem);
let parsed = certs(&mut cursor)
.collect::<Result<Vec<CertificateDer<'static>>, _>>()
.expect("failed to parse cert pem");
for cert in parsed {
roots.add(cert).expect("failed to add cert");
}
roots
}

/// Build a gRPC client with mTLS (CA + client cert).
async fn grpc_client_mtls(
addr: std::net::SocketAddr,
ca_pem: Vec<u8>,
client_cert_pem: Vec<u8>,
client_key_pem: Vec<u8>,
) -> OpenShellClient<Channel> {
let ca_cert = tonic::transport::Certificate::from_pem(ca_pem);
let identity = tonic::transport::Identity::from_pem(client_cert_pem, client_key_pem);
let tls = ClientTlsConfig::new()
.ca_certificate(ca_cert)
.identity(identity)
.domain_name("localhost");
let endpoint = Endpoint::from_shared(format!("https://localhost:{}", addr.port()))
.expect("invalid endpoint")
.tls_config(tls)
.expect("failed to set tls");
let channel = endpoint.connect().await.expect("failed to connect");
OpenShellClient::new(channel)
}
use tonic::transport::{ClientTlsConfig, Endpoint};

/// Build an HTTPS client with mTLS (CA trust + client cert/key).
fn https_client_mtls(
Expand Down
Loading