diff --git a/architecture/compute-runtimes.md b/architecture/compute-runtimes.md index 02891c03e..31356a86f 100644 --- a/architecture/compute-runtimes.md +++ b/architecture/compute-runtimes.md @@ -40,6 +40,9 @@ template resource limits. Docker and Podman apply them as runtime limits. Kubernetes mirrors each limit into the matching request. VM accepts the fields but currently ignores them. +GPU requests enter the driver layer through +`SandboxSpec.resource_requirements.gpu`. + VM runtime state paths are derived only from driver-validated sandbox IDs matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a private `run/` directory plus Unix peer UID/PID checks. Standalone @@ -77,7 +80,9 @@ users. Custom sandbox images must include the agent runtime and any system dependencies, but they should not need to include the gateway. GPU-capable images must include the user-space libraries required by the workload. The -runtime still owns GPU device injection. +runtime still owns GPU device injection. GPU requests can include explicit +driver-native device IDs or a requested count; the gateway validates the public +request shape and each runtime enforces the GPU allocation modes it supports. ## Deployment Shape diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 2254f0c89..5e6d1fbdd 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1215,10 +1215,15 @@ enum SandboxCommands { /// Target a driver-specific GPU device. Docker and Podman use CDI device IDs /// (for example "nvidia.com/gpu=0"); VM uses a PCI BDF or index. - /// Only valid with --gpu. When omitted with --gpu, the driver uses its default GPU selection. - #[arg(long, requires = "gpu")] + /// Specifying --gpu-device also requests GPU resources. + /// When omitted with --gpu, the driver uses its default GPU selection. + #[arg(long, conflicts_with = "gpu_count")] gpu_device: Option, + /// Request a specific number of GPUs. 
Mutually exclusive with --gpu-device. + #[arg(long, value_parser = clap::value_parser!(u32).range(1..), conflicts_with = "gpu_device")] + gpu_count: Option, + /// CPU limit for the sandbox (for example: 500m, 1, 2.5). #[arg(long)] cpu: Option, @@ -2539,6 +2544,7 @@ async fn main() -> Result<()> { editor, gpu, gpu_device, + gpu_count, cpu, memory, providers, @@ -2608,6 +2614,7 @@ async fn main() -> Result<()> { keep, gpu, gpu_device.as_deref(), + gpu_count, cpu.as_deref(), memory.as_deref(), editor, @@ -4287,6 +4294,78 @@ mod tests { } } + #[test] + fn sandbox_create_gpu_device_parses_without_gpu_flag() { + let cli = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu-device", + "nvidia.com/gpu=0", + ]) + .expect("sandbox create --gpu-device should parse without --gpu"); + + match cli.command { + Some(Commands::Sandbox { + command: + Some(SandboxCommands::Create { + gpu, gpu_device, .. + }), + .. + }) => { + assert!(!gpu); + assert_eq!(gpu_device.as_deref(), Some("nvidia.com/gpu=0")); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_parses_without_gpu_flag() { + let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "2"]) + .expect("sandbox create --gpu-count should parse"); + + match cli.command { + Some(Commands::Sandbox { + command: Some(SandboxCommands::Create { gpu, gpu_count, .. }), + .. 
+ }) => { + assert!(!gpu); + assert_eq!(gpu_count, Some(2)); + } + other => panic!("expected SandboxCommands::Create, got: {other:?}"), + } + } + + #[test] + fn sandbox_create_gpu_count_rejects_zero() { + let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu-count", "0"]); + + assert!( + result.is_err(), + "sandbox create --gpu-count 0 should be rejected" + ); + } + + #[test] + fn sandbox_create_gpu_count_conflicts_with_gpu_device() { + let result = Cli::try_parse_from([ + "openshell", + "sandbox", + "create", + "--gpu", + "--gpu-device", + "nvidia.com/gpu=0", + "--gpu-count", + "2", + ]); + + assert!( + result.is_err(), + "sandbox create should reject --gpu-count with --gpu-device" + ); + } + #[test] fn service_expose_accepts_positional_target_port_and_service() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 9988d46db..e5ad32ce2 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -39,17 +39,18 @@ use openshell_core::proto::{ GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, - GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, HealthRequest, - ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, - ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, - ListSandboxesRequest, ListServicesRequest, PlatformEvent, PolicySource, PolicyStatus, Provider, - ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile, - ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, - RevokeSshSessionRequest, RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, - SandboxSpec, SandboxTemplate, ServiceEndpointResponse, 
SetClusterInferenceRequest, - SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, - UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, - setting_value, tcp_forward_init, + GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuResourceRequirement, + HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest, + ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, + ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent, + PolicySource, PolicyStatus, Provider, ProviderCredentialRefreshStatus, + ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic, + ProviderProfileImportItem, RejectDraftChunkRequest, RevokeSshSessionRequest, + RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, + SandboxResourceRequirements, SandboxSpec, SandboxTemplate, ServiceEndpointResponse, + SetClusterInferenceRequest, SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, + TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, + exec_sandbox_event, setting_value, tcp_forward_init, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; @@ -1679,6 +1680,7 @@ pub async fn sandbox_create( keep: bool, gpu: bool, gpu_device: Option<&str>, + gpu_count: Option, cpu: Option<&str>, memory: Option<&str>, editor: Option, @@ -1732,8 +1734,6 @@ pub async fn sandbox_create( } None => None, }; - let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu); - let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?; let inferred_types: Vec = if providers_v2_enabled { Vec::new() @@ -1750,6 +1750,11 @@ pub async fn sandbox_create( let policy = load_sandbox_policy(policy)?; let resource_limits = build_sandbox_resource_limits(cpu, memory)?; + let resource_requirements = + 
resource_requirements_from_cli(image.as_deref(), gpu, gpu_device, gpu_count); + let requested_gpu = resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()); let template = if image.is_some() || resource_limits.is_some() { Some(SandboxTemplate { @@ -1763,8 +1768,7 @@ pub async fn sandbox_create( let request = CreateSandboxRequest { spec: Some(SandboxSpec { - gpu: requested_gpu, - gpu_device: gpu_device.unwrap_or_default().to_string(), + resource_requirements, policy, providers: configured_providers, template, @@ -2189,6 +2193,29 @@ pub async fn sandbox_create( } } +fn resource_requirements_from_cli( + image: Option<&str>, + gpu: bool, + gpu_device: Option<&str>, + gpu_count: Option, +) -> Option { + let device_ids = gpu_device + .filter(|device_id| !device_id.is_empty()) + .map(|device_id| vec![device_id.to_string()]) + .unwrap_or_default(); + let requested_gpu = gpu + || gpu_count.is_some() + || !device_ids.is_empty() + || image.is_some_and(image_requests_gpu); + + requested_gpu.then_some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids, + count: gpu_count, + }), + }) +} + /// Resolved source for the `--from` flag on `sandbox create`. 
#[derive(Debug)] enum ResolvedSource { @@ -7444,8 +7471,8 @@ mod tests { plaintext_gateway_is_remote, progress_step_from_metadata, provider_profile_allows_refresh_bootstrap, provisioning_timeout_message, ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from, - sandbox_should_persist, sandbox_upload_plan, service_expose_status_error, - service_url_for_gateway, + resource_requirements_from_cli, sandbox_should_persist, sandbox_upload_plan, + service_expose_status_error, service_url_for_gateway, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -7924,6 +7951,67 @@ mod tests { } } + #[test] + fn resource_requirements_from_cli_uses_presence_for_default_gpu() { + let requirements = resource_requirements_from_cli(None, true, None, None) + .expect("resource requirements should be present"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); + + assert!(gpu.device_ids.is_empty()); + assert_eq!(gpu.count, None); + } + + #[test] + fn resource_requirements_from_cli_maps_gpu_device_to_one_device_id() { + let requirements = resource_requirements_from_cli(None, false, Some("0000:2d:00.0"), None) + .expect("resource requirements should be present"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); + + assert_eq!(gpu.device_ids, vec!["0000:2d:00.0"]); + assert_eq!(gpu.count, None); + } + + #[test] + fn resource_requirements_from_cli_maps_gpu_count() { + let requirements = resource_requirements_from_cli(None, false, None, Some(2)) + .expect("requirements should exist"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); + + assert!(gpu.device_ids.is_empty()); + assert_eq!(gpu.count, Some(2)); + } + + #[test] + fn resource_requirements_from_cli_preserves_device_and_gpu_count_for_gateway_validation() { + let requirements = + resource_requirements_from_cli(None, false, Some("nvidia.com/gpu=0"), Some(2)) + .expect("requirements should exist"); + let gpu = 
requirements.gpu.expect("GPU requirement should be present"); + + assert_eq!(gpu.device_ids, vec!["nvidia.com/gpu=0"]); + assert_eq!(gpu.count, Some(2)); + } + + #[test] + fn resource_requirements_from_cli_omits_gpu_request_when_not_requested() { + assert!(resource_requirements_from_cli(None, false, None, None).is_none()); + } + + #[test] + fn resource_requirements_from_cli_infers_gpu_from_image() { + let requirements = resource_requirements_from_cli( + Some("ghcr.io/nvidia/openshell-community/sandboxes/nvidia-gpu:latest"), + false, + None, + None, + ) + .expect("resource requirements should be present"); + let gpu = requirements.gpu.expect("GPU requirement should be present"); + + assert!(gpu.device_ids.is_empty()); + assert_eq!(gpu.count, None); + } + #[test] fn resolve_from_classifies_existing_dockerfile_path() { let temp = tempfile::tempdir().expect("failed to create tempdir"); diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index aee91de56..ea08abb4f 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -787,6 +787,7 @@ async fn sandbox_create_keeps_command_sessions_by_default() { None, None, None, + None, &[], None, None, @@ -826,6 +827,7 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { true, false, None, + None, Some("500m"), Some("2Gi"), None, @@ -884,6 +886,100 @@ async fn sandbox_create_sends_cpu_and_memory_limits_only() { assert!(!resources.fields.contains_key("requests")); } +#[tokio::test] +async fn sandbox_create_sends_gpu_device_request_without_gpu_flag() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + 
&server.endpoint, + Some("gpu-device"), + None, + "openshell", + None, + true, + false, + Some("nvidia.com/gpu=0"), + None, + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let gpu = requests[0] + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| requirements.gpu.as_ref()) + .expect("GPU request should be sent"); + + assert_eq!(gpu.device_ids, vec!["nvidia.com/gpu=0"]); + assert_eq!(gpu.count, None); +} + +#[tokio::test] +async fn sandbox_create_sends_gpu_count_request() { + let server = run_server().await; + let fake_ssh_dir = tempfile::tempdir().unwrap(); + let xdg_dir = tempfile::tempdir().unwrap(); + let _env = test_env(&fake_ssh_dir, &xdg_dir); + let tls = test_tls(&server); + install_fake_ssh(&fake_ssh_dir); + + run::sandbox_create( + &server.endpoint, + Some("gpu-count"), + None, + "openshell", + None, + true, + false, + None, + Some(2), + None, + None, + None, + &[], + None, + None, + &["echo".to_string(), "OK".to_string()], + Some(false), + Some(false), + &HashMap::new(), + "manual", + &tls, + ) + .await + .expect("sandbox create should succeed"); + + let requests = create_requests(&server).await; + let gpu = requests[0] + .spec + .as_ref() + .and_then(|spec| spec.resource_requirements.as_ref()) + .and_then(|requirements| requirements.gpu.as_ref()) + .expect("GPU request should be sent"); + + assert!(gpu.device_ids.is_empty()); + assert_eq!(gpu.count, Some(2)); +} + #[tokio::test] async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { let server = run_server().await; @@ -906,6 +1002,7 @@ async fn sandbox_create_does_not_infer_command_providers_when_v2_enabled() { None, None, None, + None, &[], None, None, @@ -963,6 +1060,7 @@ async fn 
sandbox_create_returns_vm_error_without_waiting_for_timeout() { None, None, None, + None, &[], None, None, @@ -1016,6 +1114,7 @@ async fn sandbox_create_keeps_waiting_while_vm_progress_arrives() { None, None, None, + None, &[], None, None, @@ -1061,6 +1160,7 @@ async fn sandbox_create_times_out_when_only_logs_arrive() { None, None, None, + None, &[], None, None, @@ -1102,6 +1202,7 @@ async fn sandbox_create_deletes_command_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1147,6 +1248,7 @@ async fn sandbox_create_deletes_shell_sessions_with_no_keep() { None, None, None, + None, &[], None, None, @@ -1192,6 +1294,7 @@ async fn sandbox_create_keeps_sandbox_with_hidden_keep_flag() { None, None, None, + None, &[], None, None, @@ -1237,6 +1340,7 @@ async fn sandbox_create_keeps_sandbox_with_forwarding() { None, None, None, + None, &[], None, Some(openshell_core::forward::ForwardSpec::new(forward_port)), diff --git a/crates/openshell-core/src/gpu.rs b/crates/openshell-core/src/gpu.rs index 5df8702ed..59e1ca3f3 100644 --- a/crates/openshell-core/src/gpu.rs +++ b/crates/openshell-core/src/gpu.rs @@ -4,21 +4,28 @@ //! Shared GPU request helpers. use crate::config::CDI_GPU_DEVICE_ALL; +use crate::proto::compute::v1::{DriverGpuResourceRequirement, DriverSandboxSpec}; -/// Resolve the existing GPU request fields into CDI device identifiers. +/// Extract the driver GPU requirement from a sandbox spec. +#[must_use] +pub fn driver_gpu_requirement(spec: &DriverSandboxSpec) -> Option<&DriverGpuResourceRequirement> { + spec.resource_requirements + .as_ref() + .and_then(|requirements| requirements.gpu.as_ref()) +} + +/// Resolve a driver GPU request into CDI device identifiers. /// -/// `None` means no GPU was requested. A GPU request with no explicit device -/// ID uses the CDI all-GPU request; otherwise the driver-native ID passes -/// through unchanged. +/// `None` means no GPU was requested. 
Presence with no explicit device IDs +/// uses the CDI all-GPU request, preserving the current default GPU behavior; +/// otherwise the driver-native IDs pass through. #[must_use] -pub fn cdi_gpu_device_ids(gpu: bool, gpu_device: &str) -> Option> { - gpu.then(|| { - if gpu_device.is_empty() { - vec![CDI_GPU_DEVICE_ALL.to_string()] - } else { - vec![gpu_device.to_string()] - } - }) +pub fn cdi_gpu_device_ids(gpu: Option<&DriverGpuResourceRequirement>) -> Option> { + match gpu { + Some(gpu) if gpu.device_ids.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]), + Some(gpu) => Some(gpu.device_ids.clone()), + None => None, + } } #[cfg(test)] @@ -27,22 +34,51 @@ mod tests { #[test] fn cdi_gpu_device_ids_returns_none_when_absent() { - assert_eq!(cdi_gpu_device_ids(false, ""), None); + assert_eq!(cdi_gpu_device_ids(None), None); } #[test] fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() { + let request = DriverGpuResourceRequirement { + device_ids: vec![], + count: None, + }; + assert_eq!( - cdi_gpu_device_ids(true, ""), + cdi_gpu_device_ids(Some(&request)), Some(vec![CDI_GPU_DEVICE_ALL.to_string()]) ); } #[test] - fn cdi_gpu_device_ids_passes_explicit_device_id_through() { + fn cdi_gpu_device_ids_passes_single_device_id_through() { + let request = DriverGpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }; + assert_eq!( - cdi_gpu_device_ids(true, "nvidia.com/gpu=0"), + cdi_gpu_device_ids(Some(&request)), Some(vec!["nvidia.com/gpu=0".to_string()]) ); } + + #[test] + fn cdi_gpu_device_ids_passes_multiple_device_ids_through() { + let request = DriverGpuResourceRequirement { + device_ids: vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + count: None, + }; + + assert_eq!( + cdi_gpu_device_ids(Some(&request)), + Some(vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ]) + ); + } } diff --git a/crates/openshell-driver-docker/README.md 
b/crates/openshell-driver-docker/README.md index ea57f44e4..b44c7056f 100644 --- a/crates/openshell-driver-docker/README.md +++ b/crates/openshell-driver-docker/README.md @@ -32,7 +32,7 @@ contract: | `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. | | `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. | | `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. | -| CDI GPU request | Uses the sandbox `gpu_device` value when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. | +| CDI GPU request | Uses explicit `resource_requirements.gpu.device_ids` when set; otherwise requests all NVIDIA GPUs when `resource_requirements.gpu` is present and daemon CDI support is detected. Count-based GPU requests are rejected until Docker CDI selection can map counts to concrete devices. | The agent child process does not retain these supervisor privileges. 
diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index e30ee7754..4f12a3110 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -25,20 +25,20 @@ use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME, LABEL_SANDBOX_NAMESPACE, SUPERVISOR_IMAGE_BINARY_PATH, }; -use openshell_core::gpu::cdi_gpu_device_ids; +use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_requirement}; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, }; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DriverCondition, DriverPlatformEvent, DriverSandbox, DriverSandboxStatus, - DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest, - GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, - StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, - WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, - WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, - watch_sandboxes_event, + DriverCondition, DriverGpuResourceRequirement, DriverPlatformEvent, DriverSandbox, + DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, + GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, + StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, + ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use 
openshell_core::{Config, Error, Result as CoreResult}; use std::collections::HashMap; @@ -375,7 +375,7 @@ impl DockerComputeDriver { "docker sandboxes require a template image", )); } - Self::validate_gpu_request(spec.gpu, config.supports_gpu)?; + Self::validate_gpu_request(driver_gpu_requirement(spec), config.supports_gpu)?; if !template.agent_socket_path.trim().is_empty() { return Err(Status::failed_precondition( "docker compute driver does not support template.agent_socket_path", @@ -409,8 +409,16 @@ impl DockerComputeDriver { )) } - fn validate_gpu_request(gpu: bool, supports_gpu: bool) -> Result<(), Status> { - if gpu && !supports_gpu { + fn validate_gpu_request( + gpu: Option<&DriverGpuResourceRequirement>, + supports_gpu: bool, + ) -> Result<(), Status> { + if gpu.is_some_and(|gpu| gpu.count.is_some()) { + return Err(Status::invalid_argument( + "docker compute driver does not support GPU count requests", + )); + } + if gpu.is_some() && !supports_gpu { return Err(Status::failed_precondition( "docker GPU sandboxes require Docker CDI support. 
Enable CDI on the Docker daemon, then restart the OpenShell gateway/server so GPU capability is detected.", )); @@ -1713,8 +1721,10 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig .collect() } -fn docker_gpu_device_requests(gpu: bool, gpu_device: &str) -> Option> { - cdi_gpu_device_ids(gpu, gpu_device).map(|device_ids| { +fn docker_gpu_device_requests( + gpu: Option<&DriverGpuResourceRequirement>, +) -> Option> { + cdi_gpu_device_ids(gpu).map(|device_ids| { vec![DeviceRequest { driver: Some("cdi".to_string()), device_ids: Some(device_ids), @@ -1765,7 +1775,7 @@ fn build_container_create_body( nano_cpus: resource_limits.nano_cpus, memory: resource_limits.memory_bytes, pids_limit: docker_pids_limit(config.sandbox_pids_limit)?, - device_requests: docker_gpu_device_requests(spec.gpu, &spec.gpu_device), + device_requests: docker_gpu_device_requests(driver_gpu_requirement(spec)), binds: Some(build_binds(sandbox, config)?), restart_policy: Some(RestartPolicy { name: Some(RestartPolicyNameEnum::UNLESS_STOPPED), diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index c9b34ff8f..308605fae 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -13,7 +13,8 @@ use openshell_core::progress::{ PROGRESS_STEP_STARTING_SANDBOX, }; use openshell_core::proto::compute::v1::{ - DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, + DriverGpuResourceRequirement, DriverResourceRequirements, DriverSandboxResourceRequirements, + DriverSandboxSpec, DriverSandboxTemplate, }; use std::fs; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; @@ -23,6 +24,15 @@ use tempfile::TempDir; const TLS_MOUNT_DIR: &str = "/etc/openshell/tls/client"; static ENV_LOCK: LazyLock> = LazyLock::new(|| Mutex::new(())); +fn gpu_resource_requirements( + device_ids: Vec, + count: Option, +) -> DriverSandboxResourceRequirements { + DriverSandboxResourceRequirements 
{ + gpu: Some(DriverGpuResourceRequirement { device_ids, count }), + } +} + fn test_sandbox() -> DriverSandbox { // Mirrors the gateway-supplied request: the public `Sandbox` API no // longer carries `namespace`, so the gateway elides the field and the @@ -42,9 +52,8 @@ fn test_sandbox() -> DriverSandbox { resources: None, platform_config: None, }), - gpu: false, - gpu_device: String::new(), sandbox_token: String::new(), + resource_requirements: None, }), status: None, } @@ -605,7 +614,8 @@ fn build_container_create_body_clears_inherited_cmd() { fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { let config = runtime_config(); let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().resource_requirements = + Some(gpu_resource_requirements(vec![], None)); let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); @@ -613,6 +623,20 @@ fn validate_sandbox_rejects_gpu_when_cdi_unavailable() { assert!(err.message().contains("Docker CDI")); } +#[test] +fn validate_sandbox_rejects_gpu_count() { + let mut config = runtime_config(); + config.supports_gpu = true; + let mut sandbox = test_sandbox(); + sandbox.spec.as_mut().unwrap().resource_requirements = + Some(gpu_resource_requirements(vec![], Some(2))); + + let err = DockerComputeDriver::validate_sandbox(&sandbox, &config).unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("does not support GPU count")); +} + #[test] fn validate_sandbox_auth_requires_gateway_token() { let mut sandbox = test_sandbox(); @@ -640,7 +664,8 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - sandbox.spec.as_mut().unwrap().gpu = true; + sandbox.spec.as_mut().unwrap().resource_requirements = + Some(gpu_resource_requirements(vec![], None)); let create_body = build_container_create_body(&sandbox, 
&config).unwrap(); let request = create_body @@ -658,13 +683,17 @@ fn build_container_create_body_maps_gpu_to_all_cdi_device() { } #[test] -fn build_container_create_body_passes_explicit_cdi_device_id_through() { +fn build_container_create_body_passes_explicit_cdi_device_ids_through() { let mut config = runtime_config(); config.supports_gpu = true; let mut sandbox = test_sandbox(); - let spec = sandbox.spec.as_mut().unwrap(); - spec.gpu = true; - spec.gpu_device = "nvidia.com/gpu=0".to_string(); + sandbox.spec.as_mut().unwrap().resource_requirements = Some(gpu_resource_requirements( + vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string(), + ], + None, + )); let create_body = build_container_create_body(&sandbox, &config).unwrap(); let request = create_body @@ -677,7 +706,10 @@ fn build_container_create_body_passes_explicit_cdi_device_id_through() { assert_eq!(request.driver.as_deref(), Some("cdi")); assert_eq!( request.device_ids.as_ref().unwrap(), - &vec!["nvidia.com/gpu=0".to_string()] + &vec![ + "nvidia.com/gpu=0".to_string(), + "nvidia.com/gpu=1".to_string() + ] ); } diff --git a/crates/openshell-driver-kubernetes/README.md b/crates/openshell-driver-kubernetes/README.md index 1d45a1d83..a57b7453d 100644 --- a/crates/openshell-driver-kubernetes/README.md +++ b/crates/openshell-driver-kubernetes/README.md @@ -49,7 +49,7 @@ pods do not need direct external ingress for SSH. ## GPU Support -When a sandbox requests GPU support, the driver checks node allocatable capacity -for `nvidia.com/gpu` and requests one GPU resource in the workload spec. The -sandbox image must provide the user-space libraries needed by the agent -workload. +When `resource_requirements.gpu` is present, the driver checks node allocatable +capacity for `nvidia.com/gpu` and sets the workload's `nvidia.com/gpu` limit to +the requested `count` (one GPU when unset). Explicit `device_ids` are rejected. +The sandbox image must provide the user-space libraries needed by the agent workload. 
diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 34ab44a2e..35df394a9 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -17,16 +17,18 @@ use kube::{Client, Error as KubeError}; use openshell_core::driver_utils::{ LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, SUPERVISOR_IMAGE_BINARY_PATH, }; +use openshell_core::gpu::driver_gpu_requirement; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, }; use openshell_core::proto::compute::v1::{ - DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, - DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, - DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, - GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, - WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, + DriverCondition as SandboxCondition, DriverGpuResourceRequirement, + DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, + DriverSandboxSpec as SandboxSpec, DriverSandboxStatus as SandboxStatus, + DriverSandboxTemplate as SandboxTemplate, GetCapabilitiesResponse, WatchSandboxesDeletedEvent, + WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, + watch_sandboxes_event, }; use std::collections::BTreeMap; use std::pin::Pin; @@ -77,7 +79,25 @@ const SANDBOX_VERSION: &str = "v1alpha1"; pub const SANDBOX_KIND: &str = "Sandbox"; const GPU_RESOURCE_NAME: &str = "nvidia.com/gpu"; -const GPU_RESOURCE_QUANTITY: &str = "1"; +const DEFAULT_GPU_COUNT: u32 = 1; +const EXPLICIT_GPU_DEVICE_IDS_UNSUPPORTED_MESSAGE: &str = + "kubernetes compute driver does not support explicit GPU device IDs"; + +fn gpu_has_explicit_device_ids(gpu: 
Option<&DriverGpuResourceRequirement>) -> bool { + gpu.is_some_and(|gpu| !gpu.device_ids.is_empty()) +} + +#[allow(clippy::result_large_err)] +fn validate_gpu_request_shape( + gpu: Option<&DriverGpuResourceRequirement>, +) -> Result<(), tonic::Status> { + if gpu_has_explicit_device_ids(gpu) { + return Err(tonic::Status::invalid_argument( + EXPLICIT_GPU_DEVICE_IDS_UNSUPPORTED_MESSAGE, + )); + } + Ok(()) +} // --------------------------------------------------------------------------- // Default workspace persistence (temporary — will be replaced by snapshotting) @@ -203,8 +223,16 @@ impl KubernetesComputeDriver { } pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), tonic::Status> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|spec| spec.gpu); - if gpu_requested + let gpu = sandbox.spec.as_ref().and_then(driver_gpu_requirement); + self.validate_gpu_request(gpu).await + } + + async fn validate_gpu_request( + &self, + gpu: Option<&DriverGpuResourceRequirement>, + ) -> Result<(), tonic::Status> { + validate_gpu_request_shape(gpu)?; + if gpu.is_some() && !self.has_gpu_capacity().await.map_err(|err| { tonic::Status::internal(format!("check GPU node capacity failed: {err}")) })? 
@@ -296,6 +324,14 @@ impl KubernetesComputeDriver { } pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { + if let Some(gpu) = sandbox.spec.as_ref().and_then(driver_gpu_requirement) + && gpu_has_explicit_device_ids(Some(gpu)) + { + return Err(KubernetesDriverError::Precondition( + EXPLICIT_GPU_DEVICE_IDS_UNSUPPORTED_MESSAGE.to_string(), + )); + } + let name = sandbox.name.as_str(); info!( sandbox_id = %sandbox.id, @@ -1105,7 +1141,13 @@ fn sandbox_to_k8s_spec( if let Some(template) = spec.template.as_ref() { root.insert( "podTemplate".to_string(), - sandbox_template_to_k8s(template, spec.gpu, &pod_env, inject_workspace, params), + sandbox_template_to_k8s( + template, + driver_gpu_requirement(spec), + &pod_env, + inject_workspace, + params, + ), ); if !template.agent_socket_path.is_empty() { root.insert( @@ -1137,7 +1179,7 @@ fn sandbox_to_k8s_spec( "podTemplate".to_string(), sandbox_template_to_k8s( &SandboxTemplate::default(), - spec.is_some_and(|s| s.gpu), + spec.and_then(driver_gpu_requirement), &pod_env, inject_workspace, params, @@ -1152,7 +1194,7 @@ fn sandbox_to_k8s_spec( fn sandbox_template_to_k8s( template: &SandboxTemplate, - gpu: bool, + gpu: Option<&DriverGpuResourceRequirement>, spec_environment: &std::collections::HashMap, inject_workspace: bool, params: &SandboxPodParams<'_>, @@ -1210,7 +1252,7 @@ fn sandbox_template_to_k8s( if use_user_namespaces { spec.insert("hostUsers".to_string(), serde_json::json!(false)); - if gpu { + if gpu.is_some() { warn!( "GPU sandbox with user namespaces enabled — \ NVIDIA device plugin compatibility is unverified" @@ -1391,7 +1433,10 @@ fn image_pull_secret_refs(secrets: &[String]) -> Vec { .collect() } -fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option { +fn container_resources( + template: &SandboxTemplate, + gpu: Option<&DriverGpuResourceRequirement>, +) -> Option { // Start from the raw resources passthrough in platform_config (preserves // custom 
resource types like GPU limits that users set via the public API // Struct), then overlay the typed DriverResourceRequirements on top. @@ -1424,8 +1469,8 @@ fn container_resources(template: &SandboxTemplate, gpu: bool) -> Option Option> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn gpu_request(count: Option) -> DriverGpuResourceRequirement { + DriverGpuResourceRequirement { + device_ids: vec![], + count, + } + } + + fn default_gpu_request() -> DriverGpuResourceRequirement { + gpu_request(None) + } + + #[test] + fn validate_gpu_request_shape_rejects_explicit_device_ids() { + let gpu = DriverGpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }; + + let err = validate_gpu_request_shape(Some(&gpu)) + .expect_err("explicit GPU device IDs should be rejected"); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert_eq!(err.message(), EXPLICIT_GPU_DEVICE_IDS_UNSUPPORTED_MESSAGE); + } + #[test] fn kube_pulling_event_adds_image_progress_metadata() { let mut metadata = std::collections::HashMap::new(); @@ -2001,7 +2071,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - true, + Some(&default_gpu_request()), &std::collections::HashMap::new(), true, ¶ms, @@ -2014,7 +2084,26 @@ mod tests { ); assert_eq!( pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) + ); + } + + #[test] + fn gpu_sandbox_uses_requested_gpu_count() { + let pod_template = { + let params = SandboxPodParams::default(); + sandbox_template_to_k8s( + &SandboxTemplate::default(), + Some(&gpu_request(Some(2))), + &std::collections::HashMap::new(), + true, + ¶ms, + ) + }; + + assert_eq!( + pod_template["spec"]["containers"][0]["resources"]["limits"][GPU_RESOURCE_NAME], + serde_json::json!("2") ); } @@ -2037,7 +2126,7 @@ mod tests { let params = 
SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&default_gpu_request()), &std::collections::HashMap::new(), true, ¶ms, @@ -2069,7 +2158,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2092,7 +2181,7 @@ mod tests { }; sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2127,7 +2216,7 @@ mod tests { }; sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2147,7 +2236,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2175,7 +2264,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - true, + Some(&default_gpu_request()), &std::collections::HashMap::new(), true, ¶ms, @@ -2186,7 +2275,7 @@ mod tests { assert_eq!(limits["cpu"], serde_json::json!("2")); assert_eq!( limits[GPU_RESOURCE_NAME], - serde_json::json!(GPU_RESOURCE_QUANTITY) + serde_json::json!(DEFAULT_GPU_COUNT.to_string()) ); } @@ -2206,7 +2295,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2229,7 +2318,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2254,7 +2343,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2277,7 +2366,7 @@ mod tests { }; sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2416,7 +2505,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, 
&std::collections::HashMap::new(), false, // user provided custom VCTs ¶ms, @@ -2454,7 +2543,7 @@ mod tests { }; sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2519,7 +2608,7 @@ mod tests { let params = SandboxPodParams::default(); // cluster default is off let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2557,7 +2646,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2583,7 +2672,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2605,7 +2694,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2627,7 +2716,7 @@ mod tests { fn sandbox_template_omits_empty_image_pull_secrets() { let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, &SandboxPodParams::default(), @@ -2652,7 +2741,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &SandboxTemplate::default(), - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2681,7 +2770,7 @@ mod tests { }; let pod_template = sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), true, ¶ms, @@ -2809,7 +2898,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, @@ -2870,7 +2959,7 @@ mod tests { let params = SandboxPodParams::default(); sandbox_template_to_k8s( &template, - false, + None, &std::collections::HashMap::new(), false, ¶ms, diff --git a/crates/openshell-driver-podman/README.md b/crates/openshell-driver-podman/README.md 
index 77b42ba37..7bca6e653 100644 --- a/crates/openshell-driver-podman/README.md +++ b/crates/openshell-driver-podman/README.md @@ -46,7 +46,7 @@ The container spec in `container.rs` sets these security-critical fields: | `no_new_privileges` | `true` | Prevents privilege escalation after exec. | | `seccomp_profile_path` | `unconfined` | The supervisor installs its own policy-aware BPF filter. A container-level profile can block Landlock/seccomp syscalls during setup. | | `mounts` | Private tmpfs at `/run/netns` | Lets the supervisor create named network namespaces in rootless Podman. | -| CDI GPU devices | Sandbox `gpu_device` value when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. | +| CDI GPU devices | Explicit `resource_requirements.gpu.device_ids` when set, otherwise all NVIDIA GPUs | Exposes requested GPUs to GPU-enabled sandbox containers. Count-based GPU requests are rejected until Podman CDI selection can map counts to concrete devices. | The restricted agent child does not retain these supervisor privileges. diff --git a/crates/openshell-driver-podman/src/container.rs b/crates/openshell-driver-podman/src/container.rs index 13f053e93..a392fd772 100644 --- a/crates/openshell-driver-podman/src/container.rs +++ b/crates/openshell-driver-podman/src/container.rs @@ -4,7 +4,7 @@ //! Container spec construction for the Podman driver. use crate::config::PodmanComputeConfig; -use openshell_core::gpu::cdi_gpu_device_ids; +use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_requirement}; use openshell_core::proto::compute::v1::DriverSandbox; use serde::Serialize; use serde_json::Value; @@ -379,8 +379,8 @@ fn podman_pids_limit(value: i64) -> Option { /// Build CDI GPU device list if GPU is requested. 
fn build_devices(sandbox: &DriverSandbox) -> Option> { - let spec = sandbox.spec.as_ref()?; - cdi_gpu_device_ids(spec.gpu, &spec.gpu_device).map(|device_ids| { + let gpu = sandbox.spec.as_ref().and_then(driver_gpu_requirement); + cdi_gpu_device_ids(gpu).map(|device_ids| { device_ids .into_iter() .map(|path| LinuxDevice { path }) @@ -699,6 +699,13 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn default_gpu_request() -> openshell_core::proto::compute::v1::DriverGpuResourceRequirement { + openshell_core::proto::compute::v1::DriverGpuResourceRequirement { + device_ids: vec![], + count: None, + } + } + #[test] fn parse_cpu_millicore() { assert_eq!(parse_cpu_to_microseconds("500m"), Some(50_000)); @@ -808,11 +815,15 @@ mod tests { #[test] fn container_spec_maps_empty_gpu_request_to_all_cdi_device() { use openshell_core::config::CDI_GPU_DEVICE_ALL; - use openshell_core::proto::compute::v1::DriverSandboxSpec; + use openshell_core::proto::compute::v1::{ + DriverSandboxResourceRequirements, DriverSandboxSpec, + }; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, + resource_requirements: Some(DriverSandboxResourceRequirements { + gpu: Some(default_gpu_request()), + }), ..Default::default() }); let config = test_config(); @@ -826,12 +837,18 @@ mod tests { #[test] fn container_spec_passes_explicit_cdi_device_id_through() { - use openshell_core::proto::compute::v1::DriverSandboxSpec; + use openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandboxResourceRequirements, DriverSandboxSpec, + }; let mut sandbox = test_sandbox("test-id", "test-name"); sandbox.spec = Some(DriverSandboxSpec { - gpu: true, - gpu_device: "nvidia.com/gpu=0".to_string(), + resource_requirements: Some(DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }), + }), 
..Default::default() }); let config = test_config(); diff --git a/crates/openshell-driver-podman/src/driver.rs b/crates/openshell-driver-podman/src/driver.rs index e2deb1c63..c14fade65 100644 --- a/crates/openshell-driver-podman/src/driver.rs +++ b/crates/openshell-driver-podman/src/driver.rs @@ -10,7 +10,10 @@ use crate::watcher::{ self, WatchStream, driver_sandbox_from_inspect, driver_sandbox_from_list_entry, }; use openshell_core::ComputeDriverError; -use openshell_core::proto::compute::v1::{DriverSandbox, GetCapabilitiesResponse}; +use openshell_core::gpu::driver_gpu_requirement; +use openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandbox, GetCapabilitiesResponse, +}; use std::path::PathBuf; use std::time::Duration; use tracing::{info, warn}; @@ -280,12 +283,19 @@ impl PodmanComputeDriver { &self, sandbox: &DriverSandbox, ) -> Result<(), ComputeDriverError> { - let gpu_requested = sandbox.spec.as_ref().is_some_and(|s| s.gpu); - Self::validate_gpu_request(gpu_requested) + let gpu = sandbox.spec.as_ref().and_then(driver_gpu_requirement); + Self::validate_gpu_request(gpu) } - fn validate_gpu_request(gpu_requested: bool) -> Result<(), ComputeDriverError> { - if gpu_requested && !Self::has_gpu_capacity() { + fn validate_gpu_request( + gpu: Option<&DriverGpuResourceRequirement>, + ) -> Result<(), ComputeDriverError> { + if gpu.is_some_and(|gpu| gpu.count.is_some()) { + return Err(ComputeDriverError::Precondition( + "podman compute driver does not support GPU count requests".to_string(), + )); + } + if gpu.is_some() && !Self::has_gpu_capacity() { return Err(ComputeDriverError::Precondition( "GPU sandbox requested, but no NVIDIA GPU devices are available.".to_string(), )); @@ -305,6 +315,7 @@ impl PodmanComputeDriver { "sandbox id is required".into(), )); } + self.validate_sandbox_create(sandbox)?; // Validate the composed container name early, before creating any // resources (volume), so we don't leave orphans when the name is @@ -667,6 
+678,19 @@ mod tests { assert!(matches!(err, ComputeDriverError::Message(_))); } + #[test] + fn validate_gpu_request_rejects_count() { + let err = PodmanComputeDriver::validate_gpu_request(Some(&DriverGpuResourceRequirement { + device_ids: vec![], + count: Some(2), + })) + .expect_err("GPU count should be rejected"); + + assert!( + matches!(err, ComputeDriverError::Precondition(message) if message.contains("does not support GPU count")) + ); + } + // ── grpc_endpoint auto-detection ─────────────────────────────────── // // PodmanComputeDriver::new() fills grpc_endpoint when it is empty. diff --git a/crates/openshell-driver-vm/README.md b/crates/openshell-driver-vm/README.md index 724bde06c..c5860f9cd 100644 --- a/crates/openshell-driver-vm/README.md +++ b/crates/openshell-driver-vm/README.md @@ -52,8 +52,9 @@ sudo -E env "PATH=$PATH" mise run gateway:vm -- --gpu ``` GPU passthrough uses VFIO and requires host support for IOMMU, root privileges -for bind/unbind operations, and a compatible sandbox image. The public GPU -overview lives in the repository `README.md`. +for bind/unbind operations, and a compatible sandbox image. Sandbox GPU requests +arrive as `resource_requirements.gpu`; the VM driver accepts the default request, +one explicit device ID, or a count of one. 
Point the CLI at the gateway with one of: diff --git a/crates/openshell-driver-vm/src/driver.rs b/crates/openshell-driver-vm/src/driver.rs index 445905a1e..759895dcd 100644 --- a/crates/openshell-driver-vm/src/driver.rs +++ b/crates/openshell-driver-vm/src/driver.rs @@ -24,16 +24,18 @@ use oci_client::manifest::{ }; use oci_client::secrets::RegistryAuth; use oci_client::{Reference, RegistryOperation}; +use openshell_core::gpu::driver_gpu_requirement; use openshell_core::progress::{ PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX, format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail, }; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, - DriverSandbox as Sandbox, DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, - GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, - ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, + DriverCondition as SandboxCondition, DriverGpuResourceRequirement, + DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, + DriverSandboxStatus as SandboxStatus, GetCapabilitiesRequest, GetCapabilitiesResponse, + GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse, + StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent, compute_driver_server::ComputeDriver, watch_sandboxes_event, @@ -615,10 +617,12 @@ impl VmDriver { ))); } - let spec = sandbox.spec.as_ref(); - let is_gpu = spec.is_some_and(|s| s.gpu); - let gpu_device = spec.map_or("", |s| s.gpu_device.as_str()); - let gpu_bdf = if is_gpu { + let gpu_device = sandbox + 
.spec + .as_ref() + .and_then(driver_gpu_requirement) + .and_then(|gpu| requested_gpu_device(Some(gpu))); + let gpu_bdf = if let Some(gpu_device) = gpu_device { Some(self.assign_gpu_to_record(&sandbox.id, gpu_device).await?) } else { None @@ -2577,15 +2581,7 @@ fn validate_vm_sandbox(sandbox: &Sandbox, gpu_enabled: bool) -> Result<(), Statu .as_ref() .ok_or_else(|| Status::invalid_argument("sandbox spec is required"))?; - if spec.gpu && !gpu_enabled { - return Err(Status::failed_precondition( - "GPU support is not enabled on this driver; start with --gpu", - )); - } - - if !spec.gpu && !spec.gpu_device.is_empty() { - return Err(Status::invalid_argument("gpu_device requires gpu=true")); - } + validate_gpu_request(driver_gpu_requirement(spec), gpu_enabled)?; if let Some(template) = spec.template.as_ref() { if !template.agent_socket_path.is_empty() { @@ -2628,6 +2624,40 @@ fn validate_sandbox_id(sandbox_id: &str) -> Result<(), Status> { Ok(()) } +fn requested_gpu_device(gpu: Option<&DriverGpuResourceRequirement>) -> Option<&str> { + let gpu = gpu?; + Some(gpu.device_ids.first().map_or("", String::as_str)) +} + +#[allow(clippy::result_large_err)] +fn validate_gpu_request( + gpu: Option<&DriverGpuResourceRequirement>, + gpu_enabled: bool, +) -> Result<(), Status> { + let Some(gpu) = gpu else { + return Ok(()); + }; + + if !gpu_enabled { + return Err(Status::failed_precondition( + "GPU support is not enabled on this driver; start with --gpu", + )); + } + + if gpu.count.is_some_and(|count| count > 1) { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU", + )); + } + + if gpu.device_ids.len() > 1 { + return Err(Status::invalid_argument( + "vm compute driver supports at most one GPU device ID", + )); + } + Ok(()) +} + #[allow(clippy::result_large_err)] fn parse_registry_reference(image_ref: &str) -> Result { Reference::try_from(image_ref).map_err(|err| { @@ -4412,6 +4442,7 @@ mod tests { PROGRESS_COMPLETE_STEP_KEY, }; use 
openshell_core::proto::compute::v1::{ + DriverGpuResourceRequirement, DriverSandboxResourceRequirements, DriverSandboxSpec as SandboxSpec, DriverSandboxTemplate as SandboxTemplate, }; use prost_types::{Struct, Value, value::Kind}; @@ -4424,6 +4455,15 @@ mod tests { static ENV_LOCK: std::sync::LazyLock> = std::sync::LazyLock::new(|| std::sync::Mutex::new(())); + fn gpu_resource_requirements( + device_ids: Vec, + count: Option, + ) -> DriverSandboxResourceRequirements { + DriverSandboxResourceRequirements { + gpu: Some(DriverGpuResourceRequirement { device_ids, count }), + } + } + #[test] fn vm_pulling_layer_event_adds_progress_detail_metadata() { let mut event = platform_event( @@ -4491,7 +4531,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resource_requirements(vec![], None)), ..Default::default() }), ..Default::default() @@ -4507,7 +4547,7 @@ mod tests { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(gpu_resource_requirements(vec![], None)), ..Default::default() }), ..Default::default() @@ -4516,20 +4556,76 @@ mod tests { } #[test] - fn validate_vm_sandbox_rejects_gpu_device_without_gpu() { + fn validate_vm_sandbox_accepts_gpu_count_one_when_enabled() { let sandbox = Sandbox { id: "sandbox-123".to_string(), spec: Some(SandboxSpec { - gpu: false, - gpu_device: "0000:2d:00.0".to_string(), + resource_requirements: Some(gpu_resource_requirements(vec![], Some(1))), + ..Default::default() + }), + ..Default::default() + }; + validate_vm_sandbox(&sandbox, true).expect("gpu count one should be accepted"); + } + + #[test] + fn validate_vm_sandbox_rejects_gpu_count_greater_than_one() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements(vec![], Some(2))), + ..Default::default() + }), + ..Default::default() + }; + let 
err = + validate_vm_sandbox(&sandbox, true).expect_err("gpu count > 1 should be rejected"); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("at most one GPU")); + } + + #[test] + fn validate_vm_sandbox_rejects_multiple_gpu_device_ids() { + let sandbox = Sandbox { + id: "sandbox-123".to_string(), + spec: Some(SandboxSpec { + resource_requirements: Some(gpu_resource_requirements( + vec!["0000:2d:00.0".to_string(), "0000:3d:00.0".to_string()], + None, + )), ..Default::default() }), ..Default::default() }; let err = validate_vm_sandbox(&sandbox, true) - .expect_err("gpu_device without gpu should be rejected"); + .expect_err("multiple GPU device IDs should be rejected"); assert_eq!(err.code(), Code::InvalidArgument); - assert!(err.message().contains("gpu_device requires gpu=true")); + assert!(err.message().contains("at most one GPU device ID")); + } + + #[test] + fn requested_gpu_device_returns_none_without_gpu_request() { + assert_eq!(requested_gpu_device(None), None); + } + + #[test] + fn requested_gpu_device_defaults_empty_request_to_inventory_choice() { + let gpu = DriverGpuResourceRequirement { + device_ids: vec![], + count: None, + }; + + assert_eq!(requested_gpu_device(Some(&gpu)), Some("")); + } + + #[test] + fn requested_gpu_device_returns_first_explicit_device_id() { + let gpu = DriverGpuResourceRequirement { + device_ids: vec!["0000:2d:00.0".to_string()], + count: None, + }; + + assert_eq!(requested_gpu_device(Some(&gpu)), Some("0000:2d:00.0")); } #[test] diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 0122f9178..666a7174a 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -17,8 +17,9 @@ use crate::tracing_bus::TracingLogBus; use futures::{Stream, StreamExt}; use openshell_core::ComputeDriverKind; use openshell_core::proto::compute::v1::{ - CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, 
DriverPlatformEvent, - DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, + CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverGpuResourceRequirement, + DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, + DriverSandboxResourceRequirements, DriverSandboxSpec, DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, ValidateSandboxCreateRequest, WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_client::ComputeDriverClient, compute_driver_server::ComputeDriver, @@ -1267,12 +1268,28 @@ fn driver_sandbox_spec_from_public(spec: &SandboxSpec) -> DriverSandboxSpec { .template .as_ref() .map(driver_sandbox_template_from_public), - gpu: spec.gpu, - gpu_device: spec.gpu_device.clone(), + resource_requirements: spec + .resource_requirements + .as_ref() + .map(driver_resource_requirements_from_public), sandbox_token: String::new(), } } +fn driver_resource_requirements_from_public( + requirements: &openshell_core::proto::SandboxResourceRequirements, +) -> DriverSandboxResourceRequirements { + DriverSandboxResourceRequirements { + gpu: requirements + .gpu + .as_ref() + .map(|gpu| DriverGpuResourceRequirement { + device_ids: gpu.device_ids.clone(), + count: gpu.count, + }), + } +} + fn driver_sandbox_template_from_public(template: &SandboxTemplate) -> DriverSandboxTemplate { DriverSandboxTemplate { image: template.image.clone(), @@ -1623,7 +1640,12 @@ fn derive_phase(status: Option<&DriverSandboxStatus>) -> SandboxPhase { } fn rewrite_user_facing_conditions(status: &mut Option, spec: Option<&SandboxSpec>) { - let gpu_requested = spec.is_some_and(|sandbox_spec| sandbox_spec.gpu); + let gpu_requested = spec.is_some_and(|sandbox_spec| { + sandbox_spec + .resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()) + }); if !gpu_requested { return; } @@ -1785,6 +1807,7 @@ mod tests { CreateSandboxResponse, DeleteSandboxResponse, 
GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateResponse, }; + use openshell_core::proto::{GpuResourceRequirement, SandboxResourceRequirements}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; @@ -1801,6 +1824,56 @@ mod tests { } } + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_request_device_ids() { + let public = SandboxSpec { + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], + count: None, + }), + }), + ..Default::default() + }; + + let driver = driver_sandbox_spec_from_public(&public); + + assert_eq!( + driver + .resource_requirements + .expect("driver resource requirements should be present") + .gpu + .expect("driver GPU requirement should be present") + .device_ids, + vec!["nvidia.com/gpu=0".to_string()] + ); + } + + #[test] + fn driver_sandbox_spec_from_public_preserves_gpu_count() { + let public = SandboxSpec { + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: Some(2), + }), + }), + ..Default::default() + }; + + let driver = driver_sandbox_spec_from_public(&public); + + assert_eq!( + driver + .resource_requirements + .expect("driver resource requirements should be present") + .gpu + .expect("driver GPU requirement should be present") + .count, + Some(2) + ); + } + fn struct_value( fields: impl IntoIterator, prost_types::Value)>, ) -> prost_types::Value { @@ -2258,7 +2331,12 @@ mod tests { rewrite_user_facing_conditions( &mut status, Some(&SandboxSpec { - gpu: true, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: None, + }), + }), ..Default::default() }), ); @@ -2286,13 +2364,7 @@ mod tests { ..Default::default() }); - rewrite_user_facing_conditions( - &mut status, - 
Some(&SandboxSpec { - gpu: false, - ..Default::default() - }), - ); + rewrite_user_facing_conditions(&mut status, Some(&SandboxSpec::default())); assert_eq!(status.unwrap().conditions[0].message, original); } @@ -2571,7 +2643,12 @@ mod tests { let sandbox = Sandbox { spec: Some(SandboxSpec { - gpu: true, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: None, + }), + }), ..Default::default() }), ..sandbox_record("sb-1", "sandbox-a", SandboxPhase::Provisioning) @@ -2594,7 +2671,11 @@ mod tests { SandboxPhase::try_from(stored.phase()).unwrap(), SandboxPhase::Ready ); - assert!(stored.spec.as_ref().is_some_and(|spec| spec.gpu)); + assert!(stored.spec.as_ref().is_some_and(|spec| { + spec.resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()) + })); } #[tokio::test] diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index e60ce3995..3e8ed0e1e 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -100,7 +100,9 @@ fn emit_sandbox_create_telemetry( }; openshell_core::telemetry::emit_sandbox_create( outcome, - spec.gpu, + spec.resource_requirements + .as_ref() + .is_some_and(|requirements| requirements.gpu.is_some()), spec.providers.len() as u64, spec.policy.is_some(), template_source, diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 53f292053..0f3b3fd7c 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -131,6 +131,11 @@ pub(super) fn validate_sandbox_spec( validate_sandbox_template(tmpl)?; } + // --- spec.resource_requirements --- + if let Some(ref requirements) = spec.resource_requirements { + validate_resource_requirements(requirements)?; + } + // --- spec.policy serialized size --- if let Some(ref policy) = spec.policy { 
let size = policy.encoded_len(); @@ -144,6 +149,31 @@ pub(super) fn validate_sandbox_spec( Ok(()) } +fn validate_resource_requirements( + requirements: &openshell_core::proto::SandboxResourceRequirements, +) -> Result<(), Status> { + if let Some(ref gpu) = requirements.gpu { + validate_gpu_requirement(gpu)?; + } + Ok(()) +} + +fn validate_gpu_requirement( + gpu: &openshell_core::proto::GpuResourceRequirement, +) -> Result<(), Status> { + if gpu.count.is_some() && !gpu.device_ids.is_empty() { + return Err(Status::invalid_argument( + "resource_requirements.gpu.count is mutually exclusive with resource_requirements.gpu.device_ids", + )); + } + if gpu.count == Some(0) { + return Err(Status::invalid_argument( + "resource_requirements.gpu.count must be greater than 0", + )); + } + Ok(()) +} + /// Validate template-level field sizes. fn validate_sandbox_template(tmpl: &SandboxTemplate) -> Result<(), Status> { // String fields. @@ -661,7 +691,7 @@ pub(super) fn level_matches(log_level: &str, min_level: &str) -> bool { #[cfg(test)] mod tests { use super::*; - use openshell_core::proto::SandboxSpec; + use openshell_core::proto::{GpuResourceRequirement, SandboxResourceRequirements, SandboxSpec}; use std::collections::HashMap; use tonic::Code; @@ -687,12 +717,67 @@ mod tests { #[test] fn validate_sandbox_spec_accepts_gpu_flag() { let spec = SandboxSpec { - gpu: true, + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: None, + }), + }), ..Default::default() }; assert!(validate_sandbox_spec("gpu-sandbox", &spec).is_ok()); } + #[test] + fn validate_sandbox_spec_accepts_gpu_count() { + let spec = SandboxSpec { + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: Some(2), + }), + }), + ..Default::default() + }; + assert!(validate_sandbox_spec("gpu-count-sandbox", &spec).is_ok()); + } + + #[test] + fn 
validate_sandbox_spec_rejects_zero_gpu_count() { + let spec = SandboxSpec { + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec![], + count: Some(0), + }), + }), + ..Default::default() + }; + + let err = validate_sandbox_spec("gpu-count-sandbox", &spec).unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("count must be greater than 0")); + } + + #[test] + fn validate_sandbox_spec_rejects_gpu_count_with_device_id() { + let spec = SandboxSpec { + resource_requirements: Some(SandboxResourceRequirements { + gpu: Some(GpuResourceRequirement { + device_ids: vec!["nvidia.com/gpu=0".to_string()], + count: Some(1), + }), + }), + ..Default::default() + }; + + let err = validate_sandbox_spec("gpu-count-sandbox", &spec).unwrap_err(); + + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("mutually exclusive")); + } + #[test] fn validate_sandbox_spec_accepts_empty_defaults() { assert!(validate_sandbox_spec("", &default_spec()).is_ok()); diff --git a/docs/sandboxes/manage-sandboxes.mdx b/docs/sandboxes/manage-sandboxes.mdx index 512abfd3d..8ec69c709 100644 --- a/docs/sandboxes/manage-sandboxes.mdx +++ b/docs/sandboxes/manage-sandboxes.mdx @@ -51,9 +51,32 @@ To request GPU resources, add `--gpu`: openshell sandbox create --gpu -- claude ``` +Request a specific number of GPUs with `--gpu-count`: + +```shell +openshell sandbox create --gpu-count 2 -- claude +``` + +Request a specific GPU device with `--gpu-device`; this also requests GPU +resources: + +```shell +openshell sandbox create --gpu-device nvidia.com/gpu=0 -- claude +``` + +Support for `--gpu-count` and `--gpu-device` is driver-dependent: + +| Driver | `--gpu` default request | `--gpu-count` | `--gpu-device` | +| ---------- | ----------------------------- | --------------------- | ------------------------------------------ | +| Docker | All CDI GPU devices | Not supported | One CDI 
device ID, such as `nvidia.com/gpu=0` | +| Podman | All CDI GPU devices | Not supported | One CDI device ID, such as `nvidia.com/gpu=0` | +| Kubernetes | One `nvidia.com/gpu` resource | Supported | Not supported | +| VM | One GPU | Only `1` is supported | One PCI BDF or GPU index | + For Docker-backed sandboxes, GPU injection uses Docker CDI. If you enable Docker CDI after the gateway starts, restart the gateway so OpenShell can detect the -updated Docker daemon capability. +updated Docker daemon capability. In the API, these flags populate +`SandboxSpec.resource_requirements.gpu`. ### Custom Containers diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 610d491c7..79bff06e2 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -83,12 +83,8 @@ message DriverSandboxSpec { map environment = 5; // Runtime template consumed by the driver during provisioning. DriverSandboxTemplate template = 6; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; + // Portable resource requirements for this sandbox. + DriverSandboxResourceRequirements resource_requirements = 9; // Gateway-minted JWT identifying this sandbox to the gateway. Set by // the gateway on create; the driver materialises it via its native // secret mechanism (Docker/Podman/VM bind-mount a per-sandbox file; @@ -98,6 +94,22 @@ message DriverSandboxSpec { string sandbox_token = 11; } +// Driver-owned resource requirements for the sandbox workload. +message DriverSandboxResourceRequirements { + // GPU requirement for the sandbox. Presence indicates a GPU request. + DriverGpuResourceRequirement gpu = 1; +} + +// Driver-owned GPU resource requirement. Device identifiers are interpreted by +// the selected compute driver and are an interim compatibility surface. 
+message DriverGpuResourceRequirement { + // Optional number of GPUs requested. Mutually exclusive with device_ids. + optional uint32 count = 1; + // Optional driver-native device identifiers. Mutually exclusive with count. + // Empty means the driver chooses its default GPU assignment behavior. + repeated string device_ids = 2; +} + // Driver-owned runtime template consumed by the compute platform. // // This message describes the sandbox workload in backend-neutral terms. diff --git a/proto/openshell.proto b/proto/openshell.proto index a8ead0d31..6be740409 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -317,18 +317,24 @@ message SandboxSpec { openshell.sandbox.v1.SandboxPolicy policy = 7; // Provider names to attach to this sandbox. repeated string providers = 8; - // Request NVIDIA GPU resources for this sandbox. - bool gpu = 9; - // Optional PCI BDF address (e.g. "0000:2d:00.0") or device index - // (e.g. "0", "1"). When empty with gpu=true, the driver assigns the - // first available GPU. - string gpu_device = 10; - // Field 11 was `proposal_approval_mode`. The approval mode is now a - // runtime setting (gateway or sandbox scope) read via UpdateConfig / - // GetSandboxConfig, so it can be flipped on a running sandbox and - // managed fleet-wide. - reserved 11; - reserved "proposal_approval_mode"; + // Portable resource requirements for this sandbox. + SandboxResourceRequirements resource_requirements = 9; +} + +// Public resource requirements for the sandbox workload. +message SandboxResourceRequirements { + // GPU requirement for the sandbox. Presence indicates a GPU request. + GpuResourceRequirement gpu = 1; +} + +// Public GPU resource requirement. Device identifiers are interpreted by the +// selected compute driver and are an interim compatibility surface. +message GpuResourceRequirement { + // Optional number of GPUs requested. Mutually exclusive with device_ids. 
+ optional uint32 count = 1; + // Optional driver-native device identifiers. Mutually exclusive with count. + // Empty means the driver chooses its default GPU assignment behavior. + repeated string device_ids = 2; } // Public sandbox template mapped onto compute-driver template inputs.