diff --git a/Cargo.lock b/Cargo.lock index 4bc657be3..eee476aaf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3481,6 +3481,7 @@ dependencies = [ "bytes", "futures", "openshell-core", + "prost-types", "serde", "tar", "temp-env", @@ -3528,6 +3529,7 @@ dependencies = [ "miette", "nix", "openshell-core", + "prost-types", "rustix 1.1.4", "serde", "serde_json", diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index 02891c03e..d1f91156d 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -40,6 +40,14 @@ template resource limits. Docker and Podman apply them as runtime limits. Kubernetes mirrors each limit into the matching request. VM accepts the fields but currently ignores them. +GPU requests enter the driver layer through +`SandboxSpec.resource_requirements.gpu`. The compact interim shape supports a +default GPU request and GPU count. Exact driver-native device selection is +passed through the selected runtime's `driver_config` block; the gateway +selects that block but does not interpret the nested driver schema. Drivers +that support exact selection validate that the unique `gpu_device_ids` entry +count matches the portable GPU count. + VM runtime state paths are derived only from driver-validated sandbox IDs matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a private `run/` directory plus Unix peer UID/PID checks. Standalone diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 2254f0c89..490a4cd2c 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1215,10 +1215,14 @@ enum SandboxCommands { /// Target a driver-specific GPU device. Docker and Podman use CDI device IDs /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. - /// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection. - #[arg(long, requires = "gpu")] + /// When omitted with --gpu, the driver uses its default GPU selection. + #[arg(long, conflicts_with = "gpu_count")] gpu_device: Option, + /// Request a specific number of GPUs. Mutually exclusive with --gpu-device. + #[arg(long, value_parser = clap::value_parser!(u32).range(1..), conflicts_with = "gpu_device")] + gpu_count: Option, + /// CPU limit for the sandbox (for example: 500m, 1, 2.5). #[arg(long)] cpu: Option, @@ -2539,6 +2543,7 @@ async fn main() -> Result<()> { editor, gpu, gpu_device, + gpu_count, cpu, memory, providers, @@ -2608,6 +2613,7 @@ async fn main() -> Result<()> { keep, gpu, gpu_device.as_deref(), + gpu_count, cpu.as_deref(), memory.as_deref(), editor, @@ -4287,6 +4293,78 @@ mod tests { } } + #[test] + fn sandbox_create_gpu_count_parses_without_gpu_flag() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "2"]) + .expect("sandbox create --gpu-count should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, gpu_count, .. }), + .. + }) => { + assert!(!gpu); + assert_eq!(gpu_count, Some(2)); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_zero() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "0"]); + + assert!( + result.is_err(), + "sandbox create --gpu-count 0 should be rejected" + ); + } + + #[test] + fn sandbox_create_gpu_device_parses_without_gpu_flag() { + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu-device", + "nvidia.com/gpu=0", + ]) + .expect("sandbox create --gpu-device should parse without --gpu"); + + match cli.command { + Some(Commands::Sandbox { + command: + Some(SandboxCommands::Create { + gpu, gpu_device, .. + }), + .. + }) => { + assert!(!gpu); + assert_eq!(gpu_device.as_deref(), Some("nvidia.com/gpu=0")); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_conflicts_with_gpu_device() { + let result = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu", + "--gpu-device", + "nvidia.com/gpu=0", + "--gpu-count", + "2", + ]); + + assert!( + result.is_err(), + "sandbox create should reject --gpu-count with --gpu-device" + ); + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 9988d46db..3e3ce1e08 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -39,17 +39,18 @@ use openshell_core::proto::{ GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, HealthRequest, - ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, - ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, - ListSandboxesRequest, ListServicesRequest, PlatformEvent, PolicySource, PolicyStatus, Provider, - ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile, - ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, - RevokeSshSessionRequest, RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, - SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, - SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, - UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, - setting_value, tcp_forward_init, + GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuResourceRequirement, + HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, + ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, + ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent, + PolicySource, PolicyStatus, Provider, ProviderCredentialRefreshStatus, + ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic, + ProviderProfileImportItem, RejectDraftChunkRequest, RevokeSshSessionRequest, + RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, + SandboxResourceRequirements, SandboxSpec, SandboxTemplate, ServiceEndpointResponse, + SetClusterInferenceRequest, SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, + TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, + exec_sandbox_event, setting_value, tcp_forward_init, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; @@ -1679,6 +1680,7 @@ pub async fn sandbox_create( keep: bool, gpu: bool, gpu_device: Option<&str>, + gpu_count: Option, cpu: Option<&str>, memory: Option<&str>, editor: Option, @@ -1732,7 +1734,10 @@ pub async fn sandbox_create( } None => None, }; - let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu); + let gpu_device_ids = gpu_device_ids_from_cli(gpu_device); + let effective_gpu_count = gpu_count_from_cli(gpu_count, &gpu_device_ids); + let requested_gpu = + gpu || effective_gpu_count.is_some() || image.as_deref().is_some_and(image_requests_gpu); let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?; let inferred_types: Vec = if providers_v2_enabled { @@ -1750,11 +1755,13 @@ pub async fn sandbox_create( let policy = load_sandbox_policy(policy)?; let resource_limits = build_sandbox_resource_limits(cpu, memory)?; + let driver_config = gpu_driver_config_from_cli(&gpu_device_ids); - let template = if image.is_some() || resource_limits.is_some() { + let template = if image.is_some() || resource_limits.is_some() || driver_config.is_some() { Some(SandboxTemplate { image: image.unwrap_or_default(), resources: resource_limits, + driver_config, ..SandboxTemplate::default() }) } else { @@ -1763,8 +1770,10 @@ pub async fn sandbox_create( let request = CreateSandboxRequest { spec: Some(SandboxSpec { - gpu: requested_gpu, - gpu_device: gpu_device.unwrap_or_default().to_string(), + resource_requirements: resource_requirements_from_cli( + requested_gpu, + effective_gpu_count, + ), policy, providers: configured_providers, template, @@ -2189,6 +2198,74 @@ pub async fn sandbox_create( } } +fn resource_requirements_from_cli( + requested_gpu: bool, + gpu_count: Option, +) -> Option { + requested_gpu.then_some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { count: gpu_count }), + }) +} + +fn gpu_device_ids_from_cli(gpu_device: Option<&str>) -> Vec { + gpu_device + .map(str::trim) + .filter(|device_id| !device_id.is_empty()) + .map(|device_id| vec![device_id.to_string()]) + .unwrap_or_default() +} + +fn gpu_count_from_cli(gpu_count: Option, gpu_device_ids: &[String]) -> Option { + if gpu_device_ids.is_empty() { + gpu_count + } else { + u32::try_from(gpu_device_ids.len()).ok() + } +} + +fn gpu_driver_config_from_cli(gpu_device_ids: &[String]) -> Option { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + fn string_value(value: &str) -> Value { + Value { + kind: Some(Kind::StringValue(value.to_string())), + } + } + + fn driver_block(gpu_device_ids: &[String]) -> Value { + Value { + kind: Some(Kind::StructValue(Struct { + fields: std::iter::once(( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: gpu_device_ids + .iter() + .map(|device_id| string_value(device_id)) + .collect(), + })), + }, + )) + .collect(), + })), + } + } + + if gpu_device_ids.is_empty() { + return None; + } + + Some(Struct { + fields: [ + ("docker".to_string(), driver_block(gpu_device_ids)), + ("podman".to_string(), driver_block(gpu_device_ids)), + ("vm".to_string(), driver_block(gpu_device_ids)), + ] + .into_iter() + .collect(), + }) +} + /// Resolved source for the `--from` flag on `sandbox create`. #[derive(Debug)] enum ResolvedSource { @@ -7438,14 +7515,15 @@ mod tests { dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header, format_gateway_select_items, format_provider_attachment_table, gateway_add, gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label, - git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle, + git_sync_files, gpu_count_from_cli, gpu_device_ids_from_cli, gpu_driver_config_from_cli, + http_health_check, image_requests_gpu, import_local_package_mtls_bundle, inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value, parse_credential_expiry_cli_value, parse_credential_expiry_pairs, parse_credential_pairs, plaintext_gateway_is_remote, progress_step_from_metadata, provider_profile_allows_refresh_bootstrap, provisioning_timeout_message, ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from, - sandbox_should_persist, sandbox_upload_plan, service_expose_status_error, - service_url_for_gateway, + resource_requirements_from_cli, sandbox_should_persist, sandbox_upload_plan, + service_expose_status_error, service_url_for_gateway, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -7924,6 +8002,67 @@ mod tests { } } + #[test] + fn gpu_device_ids_from_cli_trims_gpu_device() { + assert_eq!( + gpu_device_ids_from_cli(Some(" nvidia.com/gpu=0 ")), + vec!["nvidia.com/gpu=0".to_string()] + ); + } + + #[test] + fn gpu_device_ids_from_cli_omits_empty_device() { + assert!(gpu_device_ids_from_cli(Some(" ")).is_empty()); + assert!(gpu_device_ids_from_cli(None).is_empty()); + } + + #[test] + fn gpu_count_from_cli_uses_gpu_device_id_count() { + let device_ids = gpu_device_ids_from_cli(Some("nvidia.com/gpu=0")); + + assert_eq!(gpu_count_from_cli(None, &device_ids), Some(1)); + assert_eq!(gpu_count_from_cli(Some(2), &device_ids), Some(1)); + } + + #[test] + fn resource_requirements_from_cli_uses_presence_for_default_gpu() { + let requirements = resource_requirements_from_cli(true, None) + .expect("resource requirements should be present"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); + + assert_eq!(gpu.count, None); + } + + #[test] + fn gpu_driver_config_from_cli_maps_gpu_device_to_driver_blocks() { + let device_ids = gpu_device_ids_from_cli(Some("nvidia.com/gpu=0")); + let config = + gpu_driver_config_from_cli(&device_ids).expect("driver config should be present"); + + assert!(config.fields.contains_key("docker")); + assert!(config.fields.contains_key("podman")); + assert!(config.fields.contains_key("vm")); + } + + #[test] + fn resource_requirements_from_cli_maps_gpu_count() { + let requirements = + resource_requirements_from_cli(true, Some(2)).expect("requirements should exist"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); + + assert_eq!(gpu.count, Some(2)); + } + + #[test] + fn gpu_driver_config_from_cli_omits_empty_device() { + assert!(gpu_driver_config_from_cli(&[]).is_none()); + } + + #[test] + fn resource_requirements_from_cli_omits_gpu_request_when_not_requested() { + assert!(resource_requirements_from_cli(false, None).is_none()); + } + #[test] fn resolve_from_classifies_existing_dockerfile_path() { let temp = tempfile::tempdir().expect("failed to create tempdir"); diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index aee91de56..37a5a682c 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -787,6 +787,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { None, None, None, + None, &[], None, None, @@ -826,6 +827,7 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { true, false, None, + None, Some("500m"), Some("2Gi"), None, @@ -884,6 +886,61 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { assert!(!resources.fields.contains_key("requests")); } +#[tokio::test] +async fn sandbox_create_sends_gpu_count_request() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("gpu-count"), + None, + "openshell", + None, + true, + false, + None, + Some(2), + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let spec = requests[0] + .spec + .as_ref() + .expect("sandbox spec should be sent"); + let gpu = spec + .resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) + .expect("GPU request should be sent"); + + assert_eq!(gpu.count, Some(2)); + assert!( + spec.template + .as_ref() + .and_then(|template| template.driver_config.as_ref()) + .is_none() + ); +} + #[tokio::test] async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { let server = run_server().await; @@ -906,6 +963,7 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { None, None, None, + None, &[], None, None, @@ -963,6 +1021,7 @@ async fn sandbox_create_returns_vm_error_without_waiting_for_timeout() { None, None, None, + None, &[], None, None, @@ -1016,6 +1075,7 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { None, None, None, + None, &[], None, None, @@ -1061,6 +1121,7 @@ async fn sandbox_create_times_out_when_only_logs_arrive() { None, None, None, + None, &[], None, None, @@ -1102,6 +1163,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1147,6 +1209,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1192,6 +1255,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { None, None, None, + None, &[], None, None, @@ -1237,6 +1301,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { None, None, None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 5df8702ed..9f5e24adf 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -4,21 +4,66 @@ //! Shared GPU request helpers. use crate::config::CDI_GPU_DEVICE_ALL; +use crate::proto::compute::v1::DriverGpuResourceRequirement; +use std::collections::HashSet; -/// Resolve the existing GPU request fields into CDI device identifiers. +/// Resolve a driver GPU request into CDI device identifiers. /// -/// `None` means no GPU was requested. A GPU request with no explicit device -/// ID uses the CDI all-GPU request; otherwise the driver-native ID passes -/// through unchanged. +/// `None` means no GPU was requested. Presence with a positive count and +/// explicit device IDs passes those IDs through. Other present GPU requests use +/// the CDI all-GPU request. #[must_use] -pub fn cdi_gpu_device_ids(gpu: bool, gpu_device: &str) -> Option> { - gpu.then(|| { - if gpu_device.is_empty() { - vec![CDI_GPU_DEVICE_ALL.to_string()] - } else { - vec![gpu_device.to_string()] +pub fn cdi_gpu_device_ids( + gpu: Option<&DriverGpuResourceRequirement>, + driver_config_device_ids: &[String], +) -> Option> { + match gpu { + Some(gpu) + if gpu.count.is_some_and(|count| count > 0) && !driver_config_device_ids.is_empty() => + { + Some(driver_config_device_ids.to_vec()) } - }) + Some(_) if driver_config_device_ids.is_empty() => { + Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) + } + Some(_) => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), + None => None, + } +} + +/// Validate that explicit driver GPU device IDs line up with the portable GPU count. +pub fn validate_gpu_device_ids_count( + gpu: Option<&DriverGpuResourceRequirement>, + gpu_device_ids: &[String], +) -> Result<(), String> { + if gpu_device_ids.is_empty() { + return Ok(()); + } + + let Some(count) = gpu.and_then(|gpu| gpu.count) else { + return Err( + "template.driver_config.gpu_device_ids requires resource_requirements.gpu.count" + .to_string(), + ); + }; + if count == 0 { + return Err("resource_requirements.gpu.count must be greater than 0".to_string()); + } + + let unique = gpu_device_ids.iter().collect::>().len(); + if unique != gpu_device_ids.len() { + return Err( + "template.driver_config.gpu_device_ids must not contain duplicates".to_string(), + ); + } + if unique != count as usize { + return Err( + "template.driver_config.gpu_device_ids unique entry count must equal resource_requirements.gpu.count" + .to_string(), + ); + } + + Ok(()) } #[cfg(test)] @@ -27,22 +72,112 @@ mod tests { #[test] fn cdi_gpu_device_ids_returns_none_when_absent() { - assert_eq!(cdi_gpu_device_ids(false, ""), None); + assert_eq!(cdi_gpu_device_ids(None, &[]), None); } #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { + let request = DriverGpuResourceRequirement { count: None }; + assert_eq!( - cdi_gpu_device_ids(true, ""), + cdi_gpu_device_ids(Some(&request), &[]), Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) ); } #[test] - fn cdi_gpu_device_ids_passes_explicit_device_id_through() { + fn cdi_gpu_device_ids_passes_single_device_id_through() { + let request = DriverGpuResourceRequirement { count: Some(1) }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + assert_eq!( - cdi_gpu_device_ids(true, "nvidia.com/gpu=0"), + cdi_gpu_device_ids(Some(&request), &device_ids), Some(vec!["nvidia.com/gpu=0".to_string()]) ); } + + #[test] + fn cdi_gpu_device_ids_passes_multiple_device_ids_through() { + let request = DriverGpuResourceRequirement { count: Some(2) }; + let device_ids = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; + + assert_eq!( + cdi_gpu_device_ids(Some(&request), &device_ids), + Some(vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ]) + ); + } + + #[test] + fn cdi_gpu_device_ids_ignores_device_ids_without_count() { + let request = DriverGpuResourceRequirement { count: None }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert_eq!( + cdi_gpu_device_ids(Some(&request), &device_ids), + Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) + ); + } + + #[test] + fn cdi_gpu_device_ids_ignores_device_ids_with_zero_count() { + let request = DriverGpuResourceRequirement { count: Some(0) }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert_eq!( + cdi_gpu_device_ids(Some(&request), &device_ids), + Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) + ); + } + + #[test] + fn validate_gpu_device_ids_count_requires_gpu_count() { + let request = DriverGpuResourceRequirement { count: None }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert!(validate_gpu_device_ids_count(Some(&request), &device_ids).is_err()); + } + + #[test] + fn validate_gpu_device_ids_count_rejects_zero_count() { + let request = DriverGpuResourceRequirement { count: Some(0) }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert!(validate_gpu_device_ids_count(Some(&request), &device_ids).is_err()); + } + + #[test] + fn validate_gpu_device_ids_count_accepts_matching_unique_ids() { + let request = DriverGpuResourceRequirement { count: Some(2) }; + let device_ids = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ]; + + validate_gpu_device_ids_count(Some(&request), &device_ids).unwrap(); + } + + #[test] + fn validate_gpu_device_ids_count_rejects_duplicate_ids() { + let request = DriverGpuResourceRequirement { count: Some(1) }; + let device_ids = vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=0".to_string(), + ]; + + assert!(validate_gpu_device_ids_count(Some(&request), &device_ids).is_err()); + } + + #[test] + fn validate_gpu_device_ids_count_rejects_count_mismatch() { + let request = DriverGpuResourceRequirement { count: Some(2) }; + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + assert!(validate_gpu_device_ids_count(Some(&request), &device_ids).is_err()); + } } diff --git a/crates/openshell-driver-docker/Cargo.toml b/crates/openshell-driver-docker/Cargo.toml index 0cdc205ed..1f5eabaca 100644 --- a/crates/openshell-driver-docker/Cargo.toml +++ b/crates/openshell-driver-docker/Cargo.toml @@ -19,6 +19,7 @@ futures = { workspace = true } tokio-stream = { workspace = true } tracing = { workspace = true } bytes = { workspace = true } +prost-types = { workspace = true } serde = { workspace = true } bollard = { version = "0.20" } tar = "0.4" diff --git a/crates/openshell-driver-docker/README.md b/crates/openshell-driver-docker/README.md index ea57f44e4..c20658d07 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses the sandbox `gpu_device` value when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | +| CDI GPU request | Uses explicit `template.driver_config.gpu_device_ids` when set and its unique entry count equals `resource_requirements.gpu.count`; otherwise requests all NVIDIA GPUs when `resource_requirements.gpu` is present and daemon CDI support is detected. Count-only GPU requests are rejected until Docker CDI selection can map counts to concrete devices. | The agent child process does not retain these supervisor privileges. diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index e30ee7754..ff5159491 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -25,20 +25,20 @@ use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME, LABEL_SANDBOX_NAMESPACE, SUPERVISOR_IMAGE_BINARY_PATH, }; -use openshell_core::gpu::cdi_gpu_device_ids; +use openshell_core::gpu::{cdi_gpu_device_ids, validate_gpu_device_ids_count}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, }; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DriverCondition, DriverPlatformEvent, DriverSandbox, DriverSandboxStatus, - DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, - GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, - StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, - WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, - WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, - watch_sandboxes_event, + DriverCondition, DriverGpuResourceRequirement, DriverPlatformEvent, DriverSandbox, + DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, + GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, + StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, + ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_core::{Config, Error, Result as CoreResult}; use std::collections::HashMap; @@ -375,7 +375,14 @@ impl DockerComputeDriver { "docker sandboxes require a template image", )); } - Self::validate_gpu_request(spec.gpu, config.supports_gpu)?; + let gpu_device_ids = + docker_gpu_device_ids_from_driver_config(template.driver_config.as_ref()) + .map_err(Status::invalid_argument)?; + Self::validate_gpu_request( + driver_gpu_requirement(spec), + &gpu_device_ids, + config.supports_gpu, + )?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -409,8 +416,31 @@ impl DockerComputeDriver { )) } - fn validate_gpu_request(gpu: bool, supports_gpu: bool) -> Result<(), Status> { - if gpu && !supports_gpu { + fn validate_gpu_request( + gpu: Option<&DriverGpuResourceRequirement>, + gpu_device_ids: &[String], + supports_gpu: bool, + ) -> Result<(), Status> { + if gpu.is_none() && !gpu_device_ids.is_empty() { + return Err(Status::invalid_argument( + "template.driver_config.gpu_device_ids requires resource_requirements.gpu.count", + )); + } + if let Some(gpu) = gpu + && gpu.count == Some(0) + { + return Err(Status::invalid_argument( + "resource_requirements.gpu.count must be greater than 0", + )); + } + if !gpu_device_ids.is_empty() { + validate_gpu_device_ids_count(gpu, gpu_device_ids).map_err(Status::invalid_argument)?; + } else if gpu.is_some_and(|gpu| gpu.count.is_some()) { + return Err(Status::invalid_argument( + "docker compute driver does not support GPU count requests", + )); + } + if gpu.is_some() && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", )); @@ -1713,8 +1743,52 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig .collect() } -fn docker_gpu_device_requests(gpu: bool, gpu_device: &str) -> Option> { - cdi_gpu_device_ids(gpu, gpu_device).map(|device_ids| { +fn driver_gpu_requirement( + spec: &openshell_core::proto::compute::v1::DriverSandboxSpec, +) -> Option<&DriverGpuResourceRequirement> { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) +} + +fn docker_gpu_device_ids_from_driver_config( + driver_config: Option<&prost_types::Struct>, +) -> Result, String> { + use prost_types::value::Kind; + + let Some(config) = driver_config else { + return Ok(Vec::new()); + }; + if config.fields.is_empty() { + return Ok(Vec::new()); + } + + let Some(value) = config.fields.get("gpu_device_ids") else { + return Ok(Vec::new()); + }; + let Some(Kind::ListValue(list)) = value.kind.as_ref() else { + return Err("driver_config.gpu_device_ids must be a list of strings".to_string()); + }; + + list.values + .iter() + .enumerate() + .map(|(idx, value)| match value.kind.as_ref() { + Some(Kind::StringValue(device_id)) if !device_id.trim().is_empty() => { + Ok(device_id.clone()) + } + _ => Err(format!( + "driver_config.gpu_device_ids[{idx}] must be a non-empty string" + )), + }) + .collect() +} + +fn docker_gpu_device_requests( + gpu: Option<&DriverGpuResourceRequirement>, + gpu_device_ids: &[String], +) -> Option> { + cdi_gpu_device_ids(gpu, gpu_device_ids).map(|device_ids| { vec![DeviceRequest { driver: Some("cdi".to_string()), device_ids: Some(device_ids), @@ -1736,6 +1810,8 @@ fn build_container_create_body( .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox.spec.template is required"))?; let resource_limits = docker_resource_limits(template)?; + let gpu_device_ids = docker_gpu_device_ids_from_driver_config(template.driver_config.as_ref()) + .map_err(Status::invalid_argument)?; let mut labels = template.labels.clone(); labels.insert( LABEL_MANAGED_BY.to_string(), @@ -1765,7 +1841,10 @@ fn build_container_create_body( nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, pids_limit: docker_pids_limit(config.sandbox_pids_limit)?, - device_requests: docker_gpu_device_requests(spec.gpu, &spec.gpu_device), + device_requests: docker_gpu_device_requests( + driver_gpu_requirement(spec), + &gpu_device_ids, + ), binds: Some(build_binds(sandbox, config)?), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index c9b34ff8f..449b9476f 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -13,7 +13,8 @@ use openshell_core::progress::{ PROGRESS_STEP_STARTING_SANDBOX, }; use openshell_core::proto::compute::v1::{ - DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, + DriverGpuResourceRequirement, DriverResourceRequirements, DriverSandboxResourceRequirements, + DriverSandboxSpec, DriverSandboxTemplate, }; use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; @@ -23,6 +24,66 @@ use tempfile::TempDir; const TLS_MOUNT_DIR: &str = "/etc/openshell/tls/client"; static ENV_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); +fn gpu_resource_requirements(count: Option) -> DriverSandboxResourceRequirements { + DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { count }), + } +} + +fn gpu_device_ids_driver_config(device_ids: &[&str]) -> prost_types::Struct { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + Struct { + fields: std::iter::once(( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: device_ids + .iter() + .map(|device_id| Value { + kind: Some(Kind::StringValue((*device_id).to_string())), + }) + .collect(), + })), + }, + )) + .collect(), + } +} + +#[test] +fn docker_gpu_device_ids_from_driver_config_ignores_unrelated_fields() { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + let config = Struct { + fields: [ + ( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: vec![Value { + kind: Some(Kind::StringValue("nvidia.com/gpu=0".to_string())), + }], + })), + }, + ), + ( + "future_field".to_string(), + Value { + kind: Some(Kind::StringValue("ignored".to_string())), + }, + ), + ] + .into_iter() + .collect(), + }; + + assert_eq!( + docker_gpu_device_ids_from_driver_config(Some(&config)).unwrap(), + vec!["nvidia.com/gpu=0".to_string()] + ); +} + fn test_sandbox() -> DriverSandbox { // Mirrors the gateway-supplied request: the public `Sandbox` API no // longer carries `namespace`, so the gateway elides the field and the @@ -41,10 +102,10 @@ fn test_sandbox() -> DriverSandbox { environment: HashMap::from([("TEMPLATE_ENV".to_string(), "template".to_string())]), resources: None, platform_config: None, + driver_config: None, }), - gpu: false, - gpu_device: String::new(), sandbox_token: String::new(), + resource_requirements: None, }), status: None, } @@ -392,6 +453,7 @@ fn docker_resource_limits_rejects_requests() { memory_limit: String::new(), }), platform_config: None, + driver_config: None, }; let err = docker_resource_limits(&template).unwrap_err(); @@ -412,6 +474,7 @@ fn docker_resource_limits_applies_cpu_and_memory_limits() { ..Default::default() }), platform_config: None, + driver_config: None, }; let limits = docker_resource_limits(&template).unwrap(); @@ -605,7 +668,7 @@ fn build_container_create_body_clears_inherited_cmd() { fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resource_requirements(None)); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -613,6 +676,85 @@ fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { assert!(err.message().contains("Docker CDI")); } +#[test] +fn validate_sandbox_rejects_gpu_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resource_requirements(Some(2))); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("does not support GPU count")); +} + +#[test] +fn validate_sandbox_accepts_gpu_count_with_matching_device_ids() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resource_requirements(Some(2))); + spec.template.as_mut().unwrap().driver_config = Some(gpu_device_ids_driver_config(&[ + "nvidia.com/gpu=0", + "nvidia.com/gpu=1", + ])); + + DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap(); +} + +#[test] +fn validate_sandbox_rejects_gpu_device_ids_without_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resource_requirements(None)); + spec.template.as_mut().unwrap().driver_config = + Some(gpu_device_ids_driver_config(&["nvidia.com/gpu=0"])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!( + err.message() + .contains("requires resource_requirements.gpu.count") + ); +} + +#[test] +fn validate_sandbox_rejects_gpu_device_ids_with_zero_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resource_requirements(Some(0))); + spec.template.as_mut().unwrap().driver_config = + Some(gpu_device_ids_driver_config(&["nvidia.com/gpu=0"])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("must be greater than 0")); +} + +#[test] +fn validate_sandbox_rejects_gpu_device_ids_count_mismatch() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.resource_requirements = Some(gpu_resource_requirements(Some(2))); + spec.template.as_mut().unwrap().driver_config = + Some(gpu_device_ids_driver_config(&["nvidia.com/gpu=0"])); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("unique entry count")); +} + #[test] fn validate_sandbox_auth_requires_gateway_token() { let mut sandbox = test_sandbox(); @@ -640,7 +782,7 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resource_requirements(None)); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -658,13 +800,16 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { } #[test] -fn build_container_create_body_passes_explicit_cdi_device_id_through() { +fn build_container_create_body_passes_explicit_gpu_device_ids_through() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; - spec.gpu_device = "nvidia.com/gpu=0".to_string(); + spec.resource_requirements = Some(gpu_resource_requirements(Some(2))); + spec.template.as_mut().unwrap().driver_config = Some(gpu_device_ids_driver_config(&[ + "nvidia.com/gpu=0", + "nvidia.com/gpu=1", + ])); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -677,7 +822,10 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { assert_eq!(request.driver.as_deref(), Some("cdi")); assert_eq!( request.device_ids.as_ref().unwrap(), - &vec!["nvidia.com/gpu=0".to_string()] + &vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ] ); } diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 1d45a1d83..0ddbc9e2d 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -49,7 +49,7 @@ pods do not need direct external ingress for SSH. ## GPU Support -When a sandbox requests GPU support, the driver checks node allocatable capacity -for `nvidia.com/gpu` and requests one GPU resource in the workload spec. The -sandbox image must provide the user-space libraries needed by the agent -workload. +When `resource_requirements.gpu` is present, the driver checks node allocatable +capacity for `nvidia.com/gpu` and sets the workload's `nvidia.com/gpu` resource +limit. Requests without an explicit count use one GPU. The sandbox image must +provide the user-space libraries needed by the agent workload. diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 5a43eb980..87e658fe8 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -22,11 +22,12 @@ use openshell_core::progress::{ format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, }; use openshell_core::proto::compute::v1::{ - DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, - DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, - DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, - GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, - WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, + DriverCondition as SandboxCondition, DriverGpuResourceRequirement, + DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, + DriverSandboxSpec as SandboxSpec, DriverSandboxStatus as SandboxStatus, + DriverSandboxTemplate as SandboxTemplate, GetCapabilitiesResponse, WatchSandboxesDeletedEvent, + WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, + watch_sandboxes_event, }; use std::collections::BTreeMap; use std::pin::Pin; @@ -77,7 +78,13 @@ const SANDBOX_VERSION: &str = "v1alpha1"; pub const SANDBOX_KIND: &str = "Sandbox"; const GPU_RESOURCE_NAME: &str = "nvidia.com/gpu"; -const GPU_RESOURCE_QUANTITY: &str = "1"; +const DEFAULT_GPU_COUNT: u32 = 1; + +fn driver_gpu_requirement(spec: &SandboxSpec) -> Option<&DriverGpuResourceRequirement> { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) +} // --------------------------------------------------------------------------- // Default workspace persistence (temporary — will be replaced by snapshotting) @@ -203,8 +210,15 @@ impl KubernetesComputeDriver { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); - if gpu_requested + let gpu = sandbox.spec.as_ref().and_then(driver_gpu_requirement); + self.validate_gpu_request(gpu).await + } + + async fn validate_gpu_request( + &self, + gpu: Option<&DriverGpuResourceRequirement>, + ) -> Result<(), tonic::Status> { + if gpu.is_some() && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) })? @@ -1102,7 +1116,13 @@ fn sandbox_to_k8s_spec( if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s(template, spec.gpu, &pod_env, inject_workspace, params), + sandbox_template_to_k8s( + template, + driver_gpu_requirement(spec), + &pod_env, + inject_workspace, + params, + ), ); if !template.agent_socket_path.is_empty() { root.insert( @@ -1134,7 +1154,7 @@ fn sandbox_to_k8s_spec( "podTemplate".to_string(), sandbox_template_to_k8s( &SandboxTemplate::default(), - spec.is_some_and(|s| s.gpu), + spec.and_then(driver_gpu_requirement), &pod_env, inject_workspace, params, @@ -1149,7 +1169,7 @@ fn sandbox_to_k8s_spec( fn sandbox_template_to_k8s( template: &SandboxTemplate, - gpu: bool, + gpu: Option<&DriverGpuResourceRequirement>, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, @@ -1203,7 +1223,7 @@ fn sandbox_template_to_k8s( if use_user_namespaces { spec.insert("hostUsers".to_string(), serde_json::json!(false)); - if gpu { + if gpu.is_some() { warn!( "GPU sandbox with user namespaces enabled — \ NVIDIA device plugin compatibility is unverified" @@ -1384,7 +1404,10 @@ fn image_pull_secret_refs(secrets: &[String]) -> Vec { .collect() } -fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option { +fn container_resources( + template: &SandboxTemplate, + gpu: Option<&DriverGpuResourceRequirement>, +) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom resource types like GPU limits that users set via the public API // Struct), then overlay the typed DriverResourceRequirements on top. @@ -1417,8 +1440,8 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Option> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn gpu_request(count: Option) -> DriverGpuResourceRequirement { + DriverGpuResourceRequirement { count } + } + #[test] fn kube_pulling_event_adds_image_progress_metadata() { let mut metadata = std::collections::HashMap::new(); @@ -1994,7 +2021,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - true, + Some(&gpu_request(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -2007,7 +2034,26 @@ mod tests { ); assert_eq!( pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) + ); + } + + #[test] + fn gpu_sandbox_uses_requested_gpu_count() { + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + Some(&gpu_request(Some(2))), + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], + serde_json::json!("2") ); } @@ -2030,7 +2076,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&gpu_request(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -2062,7 +2108,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2090,7 +2136,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&gpu_request(None)), &std::collections::HashMap::new(), true, ¶ms, @@ -2101,7 +2147,7 @@ mod tests { assert_eq!(limits["cpu"], serde_json::json!("2")); assert_eq!( limits[GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) ); } @@ -2121,7 +2167,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2144,7 +2190,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2169,7 +2215,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2192,7 +2238,7 @@ mod tests { }; sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2331,7 +2377,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), false, // user provided custom VCTs ¶ms, @@ -2369,7 +2415,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2434,7 +2480,7 @@ mod tests { let params = SandboxPodParams::default(); // cluster default is off let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2472,7 +2518,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2498,7 +2544,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2520,7 +2566,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2542,7 +2588,7 @@ mod tests { fn sandbox_template_omits_empty_image_pull_secrets() { let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, &SandboxPodParams::default(), @@ -2567,7 +2613,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2596,7 +2642,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2724,7 +2770,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, @@ -2785,7 +2831,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, diff --git a/crates/openshell-driver-podman/Cargo.toml b/crates/openshell-driver-podman/Cargo.toml index 6f2963d92..0ccff99f6 100644 --- a/crates/openshell-driver-podman/Cargo.toml +++ b/crates/openshell-driver-podman/Cargo.toml @@ -24,6 +24,7 @@ tokio-stream = { workspace = true } hyper = { workspace = true } hyper-util = { workspace = true } http-body-util = { workspace = true } +prost-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } clap = { workspace = true } diff --git a/crates/openshell-driver-podman/README.md b/crates/openshell-driver-podman/README.md index 77b42ba37..7bd8e58e2 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -46,7 +46,7 @@ The container spec in `container.rs` sets these security-critical fields: | `no_new_privileges` | `true` | Prevents privilege escalation after exec. | | `seccomp_profile_path` | `unconfined` | The supervisor installs its own policy-aware BPF filter. A container-level profile can block Landlock/seccomp syscalls during setup. | | `mounts` | Private tmpfs at `/run/netns` | Lets the supervisor create named network namespaces in rootless Podman. | -| CDI GPU devices | Sandbox `gpu_device` value when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. | +| CDI GPU devices | Explicit `template.driver_config.gpu_device_ids` when set and its unique entry count equals `resource_requirements.gpu.count`; otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. Count-only GPU requests are rejected until Podman CDI selection can map counts to concrete devices. | The restricted agent child does not retain these supervisor privileges. diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index 13f053e93..af47082af 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -379,8 +379,19 @@ fn podman_pids_limit(value: i64) -> Option { /// Build CDI GPU device list if GPU is requested. fn build_devices(sandbox: &DriverSandbox) -> Option> { - let spec = sandbox.spec.as_ref()?; - cdi_gpu_device_ids(spec.gpu, &spec.gpu_device).map(|device_ids| { + let spec = sandbox.spec.as_ref(); + let gpu = spec.and_then(|spec| { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) + }); + let gpu_device_ids = spec + .and_then(|spec| spec.template.as_ref()) + .and_then(|template| { + gpu_device_ids_from_driver_config(template.driver_config.as_ref()).ok() + }) + .unwrap_or_default(); + cdi_gpu_device_ids(gpu, &gpu_device_ids).map(|device_ids| { device_ids .into_iter() .map(|path| LinuxDevice { path }) @@ -388,6 +399,39 @@ fn build_devices(sandbox: &DriverSandbox) -> Option> { }) } +pub fn gpu_device_ids_from_driver_config( + driver_config: Option<&prost_types::Struct>, +) -> Result, String> { + use prost_types::value::Kind; + + let Some(config) = driver_config else { + return Ok(Vec::new()); + }; + if config.fields.is_empty() { + return Ok(Vec::new()); + } + + let Some(value) = config.fields.get("gpu_device_ids") else { + return Ok(Vec::new()); + }; + let Some(Kind::ListValue(list)) = value.kind.as_ref() else { + return Err("driver_config.gpu_device_ids must be a list of strings".to_string()); + }; + + list.values + .iter() + .enumerate() + .map(|(idx, value)| match value.kind.as_ref() { + Some(Kind::StringValue(device_id)) if !device_id.trim().is_empty() => { + Ok(device_id.clone()) + } + _ => Err(format!( + "driver_config.gpu_device_ids[{idx}] must be a non-empty string" + )), + }) + .collect() +} + /// Build the Podman container creation JSON spec. #[cfg(test)] #[must_use] @@ -699,6 +743,60 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn gpu_device_ids_driver_config(device_ids: &[&str]) -> prost_types::Struct { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + Struct { + fields: std::iter::once(( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: device_ids + .iter() + .map(|device_id| Value { + kind: Some(Kind::StringValue((*device_id).to_string())), + }) + .collect(), + })), + }, + )) + .collect(), + } + } + + #[test] + fn gpu_device_ids_from_driver_config_ignores_unrelated_fields() { + use prost_types::{ListValue, Struct, Value, value::Kind}; + + let config = Struct { + fields: [ + ( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(ListValue { + values: vec![Value { + kind: Some(Kind::StringValue("nvidia.com/gpu=0".to_string())), + }], + })), + }, + ), + ( + "future_field".to_string(), + Value { + kind: Some(Kind::StringValue("ignored".to_string())), + }, + ), + ] + .into_iter() + .collect(), + }; + + assert_eq!( + gpu_device_ids_from_driver_config(Some(&config)).unwrap(), + vec!["nvidia.com/gpu=0".to_string()] + ); + } + #[test] fn parse_cpu_millicore() { assert_eq!(parse_cpu_to_microseconds("500m"), Some(50_000)); @@ -808,11 +906,15 @@ mod tests { #[test] fn container_spec_maps_empty_gpu_request_to_all_cdi_device() { use openshell_core::config::CDI_GPU_DEVICE_ALL; - use openshell_core::proto::compute::v1::DriverSandboxSpec; + use openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandboxResourceRequirements, DriverSandboxSpec, + }; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { count: None }), + }), ..Default::default() }); let config = test_config(); @@ -826,12 +928,20 @@ mod tests { #[test] fn container_spec_passes_explicit_cdi_device_id_through() { - use openshell_core::proto::compute::v1::DriverSandboxSpec; + use openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandboxResourceRequirements, DriverSandboxSpec, + DriverSandboxTemplate, + }; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, - gpu_device: "nvidia.com/gpu=0".to_string(), + template: Some(DriverSandboxTemplate { + driver_config: Some(gpu_device_ids_driver_config(&["nvidia.com/gpu=0"])), + ..Default::default() + }), + resource_requirements: Some(DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { count: Some(1) }), + }), ..Default::default() }); let config = test_config(); diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index e2deb1c63..1f6dcf3e5 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -10,7 +10,10 @@ use crate::watcher::{ self, WatchStream, driver_sandbox_from_inspect, driver_sandbox_from_list_entry, }; use openshell_core::ComputeDriverError; -use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; +use openshell_core::gpu::validate_gpu_device_ids_count; +use openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandbox, GetCapabilitiesResponse, +}; use std::path::PathBuf; use std::time::Duration; use tracing::{info, warn}; @@ -280,12 +283,45 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); - Self::validate_gpu_request(gpu_requested) - } - - fn validate_gpu_request(gpu_requested: bool) -> Result<(), ComputeDriverError> { - if gpu_requested && !Self::has_gpu_capacity() { + let spec = sandbox.spec.as_ref(); + let gpu = spec.and_then(driver_gpu_requirement); + let gpu_device_ids = spec + .and_then(|spec| spec.template.as_ref()) + .map(|template| { + container::gpu_device_ids_from_driver_config(template.driver_config.as_ref()) + .map_err(ComputeDriverError::Precondition) + }) + .transpose()? + .unwrap_or_default(); + Self::validate_gpu_request(gpu, &gpu_device_ids) + } + + fn validate_gpu_request( + gpu: Option<&DriverGpuResourceRequirement>, + gpu_device_ids: &[String], + ) -> Result<(), ComputeDriverError> { + if gpu.is_none() && !gpu_device_ids.is_empty() { + return Err(ComputeDriverError::Precondition( + "template.driver_config.gpu_device_ids requires resource_requirements.gpu.count" + .to_string(), + )); + } + if let Some(gpu) = gpu + && gpu.count == Some(0) + { + return Err(ComputeDriverError::Precondition( + "resource_requirements.gpu.count must be greater than 0".to_string(), + )); + } + if !gpu_device_ids.is_empty() { + validate_gpu_device_ids_count(gpu, gpu_device_ids) + .map_err(ComputeDriverError::Precondition)?; + } else if gpu.is_some_and(|gpu| gpu.count.is_some()) { + return Err(ComputeDriverError::Precondition( + "podman compute driver does not support GPU count requests".to_string(), + )); + } + if gpu.is_some() && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), )); @@ -305,6 +341,7 @@ impl PodmanComputeDriver { "sandbox id is required".into(), )); } + self.validate_sandbox_create(sandbox)?; // Validate the composed container name early, before creating any // resources (volume), so we don't leave orphans when the name is @@ -572,6 +609,14 @@ impl PodmanComputeDriver { } } +fn driver_gpu_requirement( + spec: &openshell_core::proto::compute::v1::DriverSandboxSpec, +) -> Option<&DriverGpuResourceRequirement> { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) +} + #[cfg(test)] impl PodmanComputeDriver { pub(crate) fn for_tests(config: PodmanComputeConfig) -> Self { @@ -667,6 +712,64 @@ mod tests { assert!(matches!(err, ComputeDriverError::Message(_))); } + #[test] + fn validate_gpu_request_rejects_count() { + let err = PodmanComputeDriver::validate_gpu_request( + Some(&DriverGpuResourceRequirement { count: Some(2) }), + &[], + ) + .expect_err("GPU count should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("does not support GPU count")) + ); + } + + #[test] + fn validate_gpu_request_rejects_device_ids_without_count() { + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + let err = PodmanComputeDriver::validate_gpu_request( + Some(&DriverGpuResourceRequirement { count: None }), + &device_ids, + ) + .expect_err("device IDs without count should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("requires resource_requirements.gpu.count")) + ); + } + + #[test] + fn validate_gpu_request_rejects_device_ids_with_zero_count() { + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + let err = PodmanComputeDriver::validate_gpu_request( + Some(&DriverGpuResourceRequirement { count: Some(0) }), + &device_ids, + ) + .expect_err("device IDs with zero count should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("must be greater than 0")) + ); + } + + #[test] + fn validate_gpu_request_rejects_device_id_count_mismatch() { + let device_ids = vec!["nvidia.com/gpu=0".to_string()]; + + let err = PodmanComputeDriver::validate_gpu_request( + Some(&DriverGpuResourceRequirement { count: Some(2) }), + &device_ids, + ) + .expect_err("device ID count mismatch should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("unique entry count")) + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() fills grpc_endpoint when it is empty. diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index 724bde06c..d5b0982c3 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -52,8 +52,10 @@ sudo -E env "PATH=$PATH" mise run gateway:vm -- --gpu ``` GPU passthrough uses VFIO and requires host support for IOMMU, root privileges -for bind/unbind operations, and a compatible sandbox image. The public GPU -overview lives in the repository `README.md`. +for bind/unbind operations, and a compatible sandbox image. Sandbox GPU requests +arrive as `resource_requirements.gpu`; the VM driver accepts the default request, +one driver-configured `gpu_device_ids` entry with a matching count of one, or a +count of one. Point the CLI at the gateway with one of: diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 445905a1e..ad52b7101 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -24,16 +24,18 @@ use oci_client::manifest::{ }; use oci_client::secrets::RegistryAuth; use oci_client::{Reference, RegistryOperation}; +use openshell_core::gpu::validate_gpu_device_ids_count; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, }; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, - DriverSandbox as Sandbox, DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, - GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, - ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, + DriverCondition as SandboxCondition, DriverGpuResourceRequirement, + DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, + DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, GetCapabilitiesResponse, + GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, + StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, watch_sandboxes_event, @@ -615,11 +617,14 @@ impl VmDriver { ))); } - let spec = sandbox.spec.as_ref(); - let is_gpu = spec.is_some_and(|s| s.gpu); - let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); - let gpu_bdf = if is_gpu { - Some(self.assign_gpu_to_record(&sandbox.id, gpu_device).await?) + let gpu_device = sandbox + .spec + .as_ref() + .map(|spec| requested_gpu_device(driver_gpu_requirement(spec), spec.template.as_ref())) + .transpose()? + .flatten(); + let gpu_bdf = if let Some(gpu_device) = gpu_device { + Some(self.assign_gpu_to_record(&sandbox.id, &gpu_device).await?) } else { None }; @@ -2577,15 +2582,11 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - if spec.gpu && !gpu_enabled { - return Err(Status::failed_precondition( - "GPU support is not enabled on this driver; start with --gpu", - )); - } - - if !spec.gpu && !spec.gpu_device.is_empty() { - return Err(Status::invalid_argument("gpu_device requires gpu=true")); - } + validate_gpu_request( + driver_gpu_requirement(spec), + spec.template.as_ref(), + gpu_enabled, + )?; if let Some(template) = spec.template.as_ref() { if !template.agent_socket_path.is_empty() { @@ -2628,6 +2629,112 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } +fn driver_gpu_requirement( + spec: &openshell_core::proto::compute::v1::DriverSandboxSpec, +) -> Option<&DriverGpuResourceRequirement> { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) +} + +#[allow(clippy::result_large_err)] +fn requested_gpu_device( + gpu: Option<&DriverGpuResourceRequirement>, + template: Option<&openshell_core::proto::compute::v1::DriverSandboxTemplate>, +) -> Result, Status> { + let Some(_) = gpu else { + return Ok(None); + }; + let should_use_device_ids = gpu.is_some_and(|gpu| gpu.count.is_some_and(|count| count > 0)); + let configured = template + .and_then(|template| template.driver_config.as_ref()) + .map(vm_gpu_device_ids_from_driver_config) + .transpose()? + .unwrap_or_default(); + if should_use_device_ids { + Ok(Some(configured.first().cloned().unwrap_or_default())) + } else { + Ok(Some(String::new())) + } +} + +#[allow(clippy::result_large_err)] +fn validate_gpu_request( + gpu: Option<&DriverGpuResourceRequirement>, + template: Option<&openshell_core::proto::compute::v1::DriverSandboxTemplate>, + gpu_enabled: bool, +) -> Result<(), Status> { + if gpu.is_some() && !gpu_enabled { + return Err(Status::failed_precondition( + "GPU support is not enabled on this driver; start with --gpu", + )); + } + + let gpu_device_ids = template + .and_then(|template| template.driver_config.as_ref()) + .map(vm_gpu_device_ids_from_driver_config) + .transpose()? + .unwrap_or_default(); + if gpu.is_none() && !gpu_device_ids.is_empty() { + return Err(Status::invalid_argument( + "template.driver_config.gpu_device_ids requires resource_requirements.gpu.count", + )); + } + if let Some(gpu) = gpu + && gpu.count == Some(0) + { + return Err(Status::invalid_argument( + "resource_requirements.gpu.count must be greater than 0", + )); + } + if !gpu_device_ids.is_empty() { + validate_gpu_device_ids_count(gpu, &gpu_device_ids).map_err(Status::invalid_argument)?; + } + if gpu.is_some_and(|gpu| gpu.count.is_some_and(|count| count > 1)) { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU", + )); + } + if gpu_device_ids.len() > 1 { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU device ID", + )); + } + Ok(()) +} + +#[allow(clippy::result_large_err)] +fn vm_gpu_device_ids_from_driver_config( + driver_config: &prost_types::Struct, +) -> Result, Status> { + use prost_types::value::Kind; + + if driver_config.fields.is_empty() { + return Ok(Vec::new()); + } + let Some(value) = driver_config.fields.get("gpu_device_ids") else { + return Ok(Vec::new()); + }; + let Some(Kind::ListValue(list)) = value.kind.as_ref() else { + return Err(Status::invalid_argument( + "driver_config.gpu_device_ids must be a list of strings", + )); + }; + + list.values + .iter() + .enumerate() + .map(|(idx, value)| match value.kind.as_ref() { + Some(Kind::StringValue(gpu_device)) if !gpu_device.trim().is_empty() => { + Ok(gpu_device.clone()) + } + _ => Err(Status::invalid_argument(format!( + "driver_config.gpu_device_ids[{idx}] must be a non-empty string" + ))), + }) + .collect() +} + #[allow(clippy::result_large_err)] fn parse_registry_reference(image_ref: &str) -> Result { Reference::try_from(image_ref).map_err(|err| { @@ -4412,6 +4519,7 @@ mod tests { PROGRESS_COMPLETE_STEP_KEY, }; use openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandboxResourceRequirements, DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, }; use prost_types::{Struct, Value, value::Kind}; @@ -4424,6 +4532,31 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn gpu_resource_requirements(count: Option) -> DriverSandboxResourceRequirements { + DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { count }), + } + } + + fn vm_gpu_device_ids_config(gpu_devices: &[&str]) -> Struct { + Struct { + fields: std::iter::once(( + "gpu_device_ids".to_string(), + Value { + kind: Some(Kind::ListValue(prost_types::ListValue { + values: gpu_devices + .iter() + .map(|gpu_device| Value { + kind: Some(Kind::StringValue((*gpu_device).to_string())), + }) + .collect(), + })), + }, + )) + .collect(), + } + } + #[test] fn vm_pulling_layer_event_adds_progress_detail_metadata() { let mut event = platform_event( @@ -4491,7 +4624,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resource_requirements(None)), ..Default::default() }), ..Default::default() @@ -4507,7 +4640,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resource_requirements(None)), ..Default::default() }), ..Default::default() @@ -4516,20 +4649,181 @@ mod tests { } #[test] - fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { + fn validate_vm_sandbox_accepts_gpu_count_one_when_enabled() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements(Some(1))), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true).expect("gpu count one should be accepted"); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_greater_than_one() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: false, - gpu_device: "0000:2d:00.0".to_string(), + resource_requirements: Some(gpu_resource_requirements(Some(2))), + ..Default::default() + }), + ..Default::default() + }; + let err = + validate_vm_sandbox(&sandbox, true).expect_err("gpu count > 1 should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("at most one GPU")); + } + + #[test] + fn validate_vm_sandbox_accepts_gpu_count_with_matching_driver_config_gpu_device_ids() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements(Some(1))), + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + + validate_vm_sandbox(&sandbox, true) + .expect("gpu count with matching explicit device should be accepted"); + } + + #[test] + fn validate_vm_sandbox_rejects_driver_config_gpu_device_ids_without_gpu_request() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("gpu device without gpu request should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!( + err.message() + .contains("requires resource_requirements.gpu.count") + ); + } + + #[test] + fn validate_vm_sandbox_rejects_driver_config_gpu_device_id_count_mismatch() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements(Some(2))), + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), ..Default::default() }), ..Default::default() }; let err = validate_vm_sandbox(&sandbox, true) - .expect_err("gpu_device without gpu should be rejected"); + .expect_err("GPU device ID count mismatch should be rejected"); assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("gpu_device requires gpu=true")); + assert!(err.message().contains("unique entry count")); + } + + #[test] + fn validate_vm_sandbox_rejects_driver_config_gpu_device_ids_with_zero_count() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements(Some(0))), + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("GPU device IDs with zero count should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("must be greater than 0")); + } + + #[test] + fn validate_vm_sandbox_rejects_multiple_driver_config_gpu_device_ids() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements(Some(2))), + template: Some(SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&[ + "0000:2d:00.0", + "0000:3d:00.0", + ])), + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }; + let err = validate_vm_sandbox(&sandbox, true) + .expect_err("multiple GPU device IDs should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("at most one GPU")); + } + + #[test] + fn vm_gpu_device_ids_from_driver_config_ignores_unrelated_fields() { + let mut config = vm_gpu_device_ids_config(&["0000:2d:00.0"]); + config.fields.insert( + "future_field".to_string(), + Value { + kind: Some(Kind::StringValue("ignored".to_string())), + }, + ); + + assert_eq!( + vm_gpu_device_ids_from_driver_config(&config).unwrap(), + vec!["0000:2d:00.0".to_string()] + ); + } + + #[test] + fn requested_gpu_device_returns_none_without_gpu_request() { + assert_eq!(requested_gpu_device(None, None).unwrap(), None); + } + + #[test] + fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { + let gpu = DriverGpuResourceRequirement { count: None }; + + assert_eq!( + requested_gpu_device(Some(&gpu), None).unwrap(), + Some(String::new()) + ); + } + + #[test] + fn requested_gpu_device_returns_driver_config_gpu_device_id() { + let gpu = DriverGpuResourceRequirement { count: Some(1) }; + let template = SandboxTemplate { + driver_config: Some(vm_gpu_device_ids_config(&["0000:2d:00.0"])), + ..Default::default() + }; + + assert_eq!( + requested_gpu_device(Some(&gpu), Some(&template)).unwrap(), + Some("0000:2d:00.0".to_string()) + ); } #[test] diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 0122f9178..8413914bf 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -17,8 +17,9 @@ use crate::tracing_bus::TracingLogBus; use futures::{Stream, StreamExt}; use openshell_core::ComputeDriverKind; use openshell_core::proto::compute::v1::{ - CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, - DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, + CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverGpuResourceRequirement, + DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, + DriverSandboxResourceRequirements, DriverSandboxSpec, DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, ValidateSandboxCreateRequest, WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_client::ComputeDriverClient, compute_driver_server::ComputeDriver, @@ -425,7 +426,7 @@ impl ComputeRuntime { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status> { - let driver_sandbox = driver_sandbox_from_public(sandbox); + let driver_sandbox = driver_sandbox_from_public(sandbox, self.driver_kind)?; self.driver .validate_sandbox_create(Request::new(ValidateSandboxCreateRequest { sandbox: Some(driver_sandbox), @@ -469,7 +470,7 @@ impl ComputeRuntime { } })?; - let mut driver_sandbox = driver_sandbox_from_public(&sandbox); + let mut driver_sandbox = driver_sandbox_from_public(&sandbox, self.driver_kind)?; if let Some(token) = sandbox_token && let Some(spec) = driver_sandbox.spec.as_mut() { @@ -551,12 +552,11 @@ impl ComputeRuntime { self.sandbox_watch_bus.notify(&id); self.cleanup_sandbox_owned_records(&sandbox).await; - let driver_sandbox = driver_sandbox_from_public(&sandbox); let deleted = self .driver .delete_sandbox(Request::new(DeleteSandboxRequest { - sandbox_id: driver_sandbox.id, - sandbox_name: driver_sandbox.name, + sandbox_id: sandbox.object_id().to_string(), + sandbox_name: sandbox.object_name().to_string(), })) .await .map(|response| response.into_inner().deleted) @@ -1249,38 +1249,92 @@ impl ComputeRuntime { } } -fn driver_sandbox_from_public(sandbox: &Sandbox) -> DriverSandbox { - DriverSandbox { +#[allow(clippy::result_large_err)] +fn driver_sandbox_from_public( + sandbox: &Sandbox, + driver_kind: Option, +) -> Result { + Ok(DriverSandbox { id: sandbox.object_id().to_string(), name: sandbox.object_name().to_string(), namespace: String::new(), // Namespace is set by the driver based on its config - spec: sandbox.spec.as_ref().map(driver_sandbox_spec_from_public), + spec: sandbox + .spec + .as_ref() + .map(|spec| driver_sandbox_spec_from_public(spec, driver_kind)) + .transpose()?, status: sandbox.status.as_ref().map(driver_status_from_public), - } + }) } -fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { - DriverSandboxSpec { +#[allow(clippy::result_large_err)] +fn driver_sandbox_spec_from_public( + spec: &SandboxSpec, + driver_kind: Option, +) -> Result { + Ok(DriverSandboxSpec { log_level: spec.log_level.clone(), environment: spec.environment.clone(), template: spec .template .as_ref() - .map(driver_sandbox_template_from_public), - gpu: spec.gpu, - gpu_device: spec.gpu_device.clone(), + .map(|template| driver_sandbox_template_from_public(template, driver_kind)) + .transpose()?, + resource_requirements: spec + .resource_requirements + .as_ref() + .map(|requirements| driver_resource_requirements_from_public(*requirements)), sandbox_token: String::new(), + }) +} + +fn driver_resource_requirements_from_public( + requirements: openshell_core::proto::SandboxResourceRequirements, +) -> DriverSandboxResourceRequirements { + DriverSandboxResourceRequirements { + gpu: requirements + .gpu + .as_ref() + .map(|gpu| DriverGpuResourceRequirement { count: gpu.count }), } } -fn driver_sandbox_template_from_public(template: &SandboxTemplate) -> DriverSandboxTemplate { - DriverSandboxTemplate { +#[allow(clippy::result_large_err)] +fn driver_sandbox_template_from_public( + template: &SandboxTemplate, + driver_kind: Option, +) -> Result { + Ok(DriverSandboxTemplate { image: template.image.clone(), agent_socket_path: template.agent_socket.clone(), labels: template.labels.clone(), environment: template.environment.clone(), resources: extract_typed_resources(&template.resources), platform_config: build_platform_config(template), + driver_config: select_driver_config(&template.driver_config, driver_kind)?, + }) +} + +#[allow(clippy::result_large_err)] +fn select_driver_config( + driver_config: &Option, + driver_kind: Option, +) -> Result, Status> { + let Some(driver_kind) = driver_kind else { + return Ok(None); + }; + let Some(config) = driver_config.as_ref() else { + return Ok(None); + }; + let Some(value) = config.fields.get(driver_kind.as_str()) else { + return Ok(None); + }; + match value.kind.as_ref() { + Some(prost_types::value::Kind::StructValue(inner)) => Ok(Some(inner.clone())), + _ => Err(Status::invalid_argument(format!( + "template.driver_config.{} must be an object", + driver_kind.as_str() + ))), } } @@ -1623,7 +1677,12 @@ fn derive_phase(status: Option<&DriverSandboxStatus>) -> SandboxPhase { } fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { - let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); + let gpu_requested = spec.is_some_and(|sandbox_spec| { + sandbox_spec + .resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()) + }); if !gpu_requested { return; } @@ -1785,6 +1844,7 @@ mod tests { CreateSandboxResponse, DeleteSandboxResponse, GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateResponse, }; + use openshell_core::proto::{GpuResourceRequirement, SandboxResourceRequirements}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; @@ -1801,6 +1861,91 @@ mod tests { } } + #[test] + fn driver_sandbox_spec_from_public_selects_matching_driver_config_block() { + let public = SandboxSpec { + template: Some(SandboxTemplate { + driver_config: Some(prost_types::Struct { + fields: [ + ( + "docker".to_string(), + struct_value([( + "gpu_device_ids", + prost_types::Value { + kind: Some(prost_types::value::Kind::ListValue( + prost_types::ListValue { + values: vec![string_value("nvidia.com/gpu=0")], + }, + )), + }, + )]), + ), + ( + "vm".to_string(), + struct_value([( + "gpu_device_ids", + prost_types::Value { + kind: Some(prost_types::value::Kind::ListValue( + prost_types::ListValue { + values: vec![string_value("0")], + }, + )), + }, + )]), + ), + ] + .into_iter() + .collect(), + }), + ..Default::default() + }), + ..Default::default() + }; + + let driver = + driver_sandbox_spec_from_public(&public, Some(ComputeDriverKind::Docker)).unwrap(); + + let config = driver + .template + .expect("driver template should be present") + .driver_config + .expect("driver config should be selected"); + let device_ids = config + .fields + .get("gpu_device_ids") + .and_then(|value| match value.kind.as_ref() { + Some(prost_types::value::Kind::ListValue(list)) => list.values.first(), + _ => None, + }) + .and_then(|value| match value.kind.as_ref() { + Some(prost_types::value::Kind::StringValue(value)) => Some(value.as_str()), + _ => None, + }); + assert_eq!(device_ids, Some("nvidia.com/gpu=0")); + } + + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_count() { + let public = SandboxSpec { + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { count: Some(2) }), + }), + ..Default::default() + }; + + let driver = driver_sandbox_spec_from_public(&public, None).unwrap(); + + assert_eq!( + driver + .resource_requirements + .expect("driver resource requirements should be present") + .gpu + .expect("driver GPU requirement should be present") + .count, + Some(2) + ); + } + fn struct_value( fields: impl IntoIterator, prost_types::Value)>, ) -> prost_types::Value { @@ -2258,7 +2403,9 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: true, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { count: None }), + }), ..Default::default() }), ); @@ -2286,13 +2433,7 @@ mod tests { ..Default::default() }); - rewrite_user_facing_conditions( - &mut status, - Some(&SandboxSpec { - gpu: false, - ..Default::default() - }), - ); + rewrite_user_facing_conditions(&mut status, Some(&SandboxSpec::default())); assert_eq!(status.unwrap().conditions[0].message, original); } @@ -2571,7 +2712,9 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { count: None }), + }), ..Default::default() }), ..sandbox_record("sb-1", "sandbox-a", SandboxPhase::Provisioning) @@ -2594,7 +2737,11 @@ mod tests { SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); - assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu)); + assert!(stored.spec.as_ref().is_some_and(|spec| { + spec.resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()) + })); } #[tokio::test] diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 198d5f04c..dec84c4e9 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -99,7 +99,9 @@ fn emit_sandbox_create_telemetry( }; openshell_core::telemetry::emit_sandbox_create( outcome, - spec.gpu, + spec.resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()), spec.providers.len() as u64, spec.policy.is_some(), template_source, diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 53f292053..d6afec1b6 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -131,6 +131,11 @@ pub(super) fn validate_sandbox_spec( validate_sandbox_template(tmpl)?; } + // --- spec.resource_requirements --- + if let Some(ref requirements) = spec.resource_requirements { + validate_resource_requirements(*requirements)?; + } + // --- spec.policy serialized size --- if let Some(ref policy) = spec.policy { let size = policy.encoded_len(); @@ -144,6 +149,26 @@ pub(super) fn validate_sandbox_spec( Ok(()) } +fn validate_resource_requirements( + requirements: openshell_core::proto::SandboxResourceRequirements, +) -> Result<(), Status> { + if let Some(gpu) = requirements.gpu { + validate_gpu_requirement(gpu)?; + } + Ok(()) +} + +fn validate_gpu_requirement( + gpu: openshell_core::proto::GpuResourceRequirement, +) -> Result<(), Status> { + if gpu.count == Some(0) { + return Err(Status::invalid_argument( + "resource_requirements.gpu.count must be greater than 0", + )); + } + Ok(()) +} + /// Validate template-level field sizes. fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { // String fields. @@ -200,6 +225,14 @@ fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { ))); } } + if let Some(ref s) = tmpl.driver_config { + let size = s.encoded_len(); + if size > MAX_TEMPLATE_STRUCT_SIZE { + return Err(Status::invalid_argument(format!( + "template.driver_config serialized size exceeds maximum ({size} > {MAX_TEMPLATE_STRUCT_SIZE})" + ))); + } + } Ok(()) } @@ -661,7 +694,10 @@ pub(super) fn level_matches(log_level: &str, min_level: &str) -> bool { #[cfg(test)] mod tests { use super::*; - use openshell_core::proto::SandboxSpec; + use openshell_core::proto::{ + GpuResourceRequirement, SandboxResourceRequirements, SandboxSpec, SandboxTemplate, + }; + use prost_types::{Struct, Value, value::Kind}; use std::collections::HashMap; use tonic::Code; @@ -687,12 +723,61 @@ mod tests { #[test] fn validate_sandbox_spec_accepts_gpu_flag() { let spec = SandboxSpec { - gpu: true, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { count: None }), + }), ..Default::default() }; assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); } + #[test] + fn validate_sandbox_spec_accepts_gpu_count() { + let spec = SandboxSpec { + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { count: Some(2) }), + }), + ..Default::default() + }; + assert!(validate_sandbox_spec("gpu-count-sandbox", &spec).is_ok()); + } + + #[test] + fn validate_sandbox_spec_rejects_zero_gpu_count() { + let spec = SandboxSpec { + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { count: Some(0) }), + }), + ..Default::default() + }; + + let err = validate_sandbox_spec("gpu-count-sandbox", &spec).unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("count must be greater than 0")); + } + + #[test] + fn validate_sandbox_spec_accepts_driver_config() { + let spec = SandboxSpec { + template: Some(SandboxTemplate { + driver_config: Some(Struct { + fields: std::iter::once(( + "docker".to_string(), + Value { + kind: Some(Kind::StructValue(Struct::default())), + }, + )) + .collect(), + }), + ..Default::default() + }), + ..Default::default() + }; + + assert!(validate_sandbox_spec("driver-config-sandbox", &spec).is_ok()); + } + #[test] fn validate_sandbox_spec_accepts_empty_defaults() { assert!(validate_sandbox_spec("", &default_spec()).is_ok()); diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 512abfd3d..fb530ecfa 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -51,10 +51,32 @@ To request GPU resources, add `--gpu`: openshell sandbox create --gpu -- claude ``` +Request a specific number of GPUs with `--gpu-count`: + +```shell +openshell sandbox create --gpu-count 2 -- claude +``` + +Request a specific driver-native device with `--gpu-device`: + +```shell +openshell sandbox create --gpu-device nvidia.com/gpu=0 -- claude +``` + For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the updated Docker daemon capability. +Kubernetes gateways honor `--gpu-count` by setting the `nvidia.com/gpu` resource +limit. Docker and Podman support explicit CDI device IDs through `--gpu-device` +but do not support count-based selection yet. The CLI sets the portable GPU +count to match the requested device ID. VM gateways accept only one GPU. In the +API, portable GPU presence and count populate +`SandboxSpec.resource_requirements.gpu`. Exact device selection is passed as +driver-owned `template.driver_config.gpu_device_ids`. Drivers that support +exact selection require the number of unique `gpu_device_ids` entries to match +`resource_requirements.gpu.count`. + ### Custom Containers Use `--from` to create a sandbox from the base image, another pre-built sandbox name, a local directory, or a container image: diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 610d491c7..c062d52a4 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -83,12 +83,8 @@ message DriverSandboxSpec { map environment = 5; // Runtime template consumed by the driver during provisioning. DriverSandboxTemplate template = 6; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; + // Portable resource requirements for this sandbox. + DriverSandboxResourceRequirements resource_requirements = 9; // Gateway-minted JWT identifying this sandbox to the gateway. Set by // the gateway on create; the driver materialises it via its native // secret mechanism (Docker/Podman/VM bind-mount a per-sandbox file; @@ -98,6 +94,22 @@ message DriverSandboxSpec { string sandbox_token = 11; } +// Driver-owned resource requirements for the sandbox workload. +message DriverSandboxResourceRequirements { + // GPU requirement for the sandbox. Presence indicates a GPU request. + DriverGpuResourceRequirement gpu = 1; +} + +// Driver-owned GPU resource requirement. +message DriverGpuResourceRequirement { + reserved 2; + reserved "device_ids"; + + // Optional number of GPUs requested. When unset, presence means the driver + // chooses its default GPU assignment behavior. + optional uint32 count = 1; +} + // Driver-owned runtime template consumed by the compute platform. // // This message describes the sandbox workload in backend-neutral terms. @@ -121,6 +133,9 @@ message DriverSandboxTemplate { // For the Kubernetes driver this carries fields such as runtimeClassName, // annotations, and volumeClaimTemplates. google.protobuf.Struct platform_config = 11; + // Caller-provided config for the selected driver only. This is the inner + // block from public SandboxTemplate.driver_config after gateway selection. + google.protobuf.Struct driver_config = 12; } // Typed compute-resource requirements. diff --git a/proto/openshell.proto b/proto/openshell.proto index f9b64618b..e11891774 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -315,18 +315,24 @@ message SandboxSpec { openshell.sandbox.v1.SandboxPolicy policy = 7; // Provider names to attach to this sandbox. repeated string providers = 8; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; - // Field 11 was `proposal_approval_mode`. The approval mode is now a - // runtime setting (gateway or sandbox scope) read via UpdateConfig / - // GetSandboxConfig, so it can be flipped on a running sandbox and - // managed fleet-wide. - reserved 11; - reserved "proposal_approval_mode"; + // Portable resource requirements for this sandbox. + SandboxResourceRequirements resource_requirements = 9; +} + +// Public resource requirements for the sandbox workload. +message SandboxResourceRequirements { + // GPU requirement for the sandbox. Presence indicates a GPU request. + GpuResourceRequirement gpu = 1; +} + +// Public GPU resource requirement. +message GpuResourceRequirement { + reserved 2; + reserved "device_ids"; + + // Optional number of GPUs requested. When unset, presence means the driver + // chooses its default GPU assignment behavior. + optional uint32 count = 1; } // Public sandbox template mapped onto compute-driver template inputs. @@ -353,6 +359,10 @@ message SandboxTemplate { // available (beta through 1.35, GA in 1.36+) and a supporting runtime. // When unset, the cluster-wide default is used. optional bool user_namespaces = 10; + // Opaque driver-specific configuration provided by the caller. The gateway + // selects the block matching the active driver name and forwards only that + // inner block to the selected compute driver. + google.protobuf.Struct driver_config = 11; } // User-facing sandbox status derived by the gateway from compute-driver observations.