diff --git a/crates/socket-patch-cli/src/commands/get.rs b/crates/socket-patch-cli/src/commands/get.rs index 46133442..940f98b2 100644 --- a/crates/socket-patch-cli/src/commands/get.rs +++ b/crates/socket-patch-cli/src/commands/get.rs @@ -6,20 +6,21 @@ use socket_patch_core::api::client::{ use socket_patch_core::api::types::{ PatchResponse, PatchSearchResult, SearchResponse, VulnerabilityResponse, }; -use socket_patch_core::crawlers::CrawlerOptions; +use socket_patch_core::crawlers::{CrawlerOptions, Ecosystem}; use socket_patch_core::manifest::operations::{read_manifest, write_manifest}; use socket_patch_core::manifest::schema::{ PatchFileInfo, PatchManifest, PatchRecord, VulnerabilityInfo, }; +use socket_patch_core::patch::apply::select_installed_variant; use socket_patch_core::utils::fuzzy_match::fuzzy_match_packages; -use socket_patch_core::utils::purl::is_purl; +use socket_patch_core::utils::purl::{is_purl, strip_purl_qualifiers}; use socket_patch_core::utils::telemetry::{track_patch_fetch_failed, track_patch_fetched}; use std::collections::HashMap; use std::fmt; use std::path::{Path, PathBuf}; use crate::args::{apply_env_toggles, GlobalArgs}; -use crate::ecosystem_dispatch::crawl_all_ecosystems; +use crate::ecosystem_dispatch::{crawl_all_ecosystems, find_packages_for_rollback, partition_purls}; use crate::output::{confirm, select_one, SelectError}; /// Best-effort ecosystem extractor for a `pkg:/...` PURL. Used as @@ -327,6 +328,19 @@ pub struct GetArgs { /// Apply patch immediately without saving to .socket folder. #[arg(long = "one-off", env = "SOCKET_ONE_OFF", default_value_t = false)] pub one_off: bool, + + /// Download patches for every release/distribution (artifact_id) of + /// a matched package, not just the one matching the locally- + /// installed distribution. Only affects PyPI today — the only + /// ecosystem with per-release artifact_id variants. Off by default: + /// only the patch for the installed dist is fetched. + #[arg( + long = "all-releases", + env = "SOCKET_ALL_RELEASES", + default_value_t = false, + value_parser = clap::builder::BoolishValueParser::new(), + )] + pub all_releases: bool, } #[derive(Debug, Clone, Copy, PartialEq)] @@ -508,11 +522,170 @@ pub struct DownloadParams { /// client constructed here. Without this, `download_and_apply_patches` /// would only honor env vars and ignore the user's flags. pub api_overrides: socket_patch_core::api::client::ApiClientEnvOverrides, + /// When `false` (the default — narrow), a PyPI package with multiple + /// release variants (`?artifact_id=...`) is filtered down to the one + /// matching the locally-installed distribution before download. When + /// `true` (`--all-releases`), every variant is downloaded. No effect + /// on ecosystems without per-release artifact_id variants. + pub all_releases: bool, } /// Download and apply a set of selected patches. /// /// Used by both `get` and `scan` commands. Returns (exit_code, json_result). +/// Narrow a selection of patches down to the release variant matching +/// each locally-installed distribution. +/// +/// A PyPI `package@version` can resolve to several patch variants — one +/// per `?artifact_id=...` release (wheel/sdist). Only one distribution +/// is ever installed in a given environment, so only one variant can +/// apply. With `--all-releases` off (the default) we keep just the +/// variant whose first patched file's hash matches the on-disk package, +/// dropping the rest so they are never downloaded or written to the +/// manifest. Non-PyPI ecosystems never carry `artifact_id` qualifiers, +/// so they pass through untouched. +/// +/// Fallbacks (keep all variants of the base, i.e. behave as broad): +/// * the base package is not installed on disk (nothing to match +/// against — e.g. `get` for an absent package), or +/// * the installed distribution matches none of the variants (a local +/// modification, or no patch exists for the installed release). +/// +/// Both fallbacks push a human-readable warning. +/// +/// Returns the kept patches plus any warnings to surface to the caller. +async fn filter_to_installed_releases( + selected: &[PatchSearchResult], + params: &DownloadParams, + api_client: &socket_patch_core::api::client::ApiClient, + org: Option<&str>, +) -> (Vec, Vec) { + // Group the PyPI selections by their base PURL (qualifiers stripped). + // Anything that isn't PyPI, or whose base has a single variant, is + // kept verbatim and needs no installed-dist resolution. + let mut pypi_groups: HashMap> = HashMap::new(); + let mut kept: Vec = Vec::new(); + for sr in selected { + if Ecosystem::from_purl(&sr.purl) == Some(Ecosystem::Pypi) { + pypi_groups + .entry(strip_purl_qualifiers(&sr.purl).to_string()) + .or_default() + .push(sr.clone()); + } else { + kept.push(sr.clone()); + } + } + + let mut warnings: Vec = Vec::new(); + + // Singleton PyPI bases have nothing to disambiguate — keep as-is. + // Collect the multi-variant bases that actually need resolution. + let mut multi: Vec<(String, Vec)> = Vec::new(); + for (base, variants) in pypi_groups { + if variants.len() <= 1 { + kept.extend(variants); + } else { + multi.push((base, variants)); + } + } + + if multi.is_empty() { + return (kept, warnings); + } + + // Discover the on-disk path for each multi-variant base. The pypi + // crawler is queried with base PURLs and the result is fanned back + // out to every qualified variant (all variants of one installed + // package resolve to the same path). + let all_qualified: Vec = multi + .iter() + .flat_map(|(_, variants)| variants.iter().map(|s| s.purl.clone())) + .collect(); + // All collected PURLs are PyPI; no ecosystem filter needed. + let partitioned = partition_purls(&all_qualified, None); + let crawler_options = CrawlerOptions { + cwd: params.cwd.clone(), + global: params.global, + global_prefix: params.global_prefix.clone(), + batch_size: 100, + }; + let paths = find_packages_for_rollback(&partitioned, &crawler_options, true).await; + + for (base, variants) in multi { + // Any variant's resolved path works — they all map to the single + // installed distribution. + let pkg_path = variants.iter().find_map(|s| paths.get(&s.purl)).cloned(); + let Some(pkg_path) = pkg_path else { + // Not installed: cannot determine the relevant release. Keep + // every variant so the patch is still obtainable. + warnings.push(format!( + "{base} is not installed locally; keeping all {} release variant(s).", + variants.len() + )); + kept.extend(variants); + continue; + }; + + // Fetch each variant's file hashes (the view carries them) so we + // can hash-match against the installed distribution. + let mut candidates: Vec<(String, HashMap)> = Vec::new(); + for s in &variants { + match api_client.fetch_patch(org, &s.uuid).await { + Ok(Some(patch)) => { + candidates.push((s.purl.clone(), files_for_selection(&patch))); + } + // On a fetch error/miss, keep the variant so the main + // download loop can record the failure as it would today. + _ => candidates.push((s.purl.clone(), HashMap::new())), + } + } + + let refs: Vec<(&str, &HashMap)> = candidates + .iter() + .map(|(purl, files)| (purl.as_str(), files)) + .collect(); + + match select_installed_variant(&pkg_path, &refs).await { + Some(idx) => { + let winner = candidates[idx].0.clone(); + kept.extend(variants.into_iter().filter(|s| s.purl == winner)); + } + None => { + // Installed, but no variant matches the on-disk bytes. + // Fall back to broad rather than silently dropping a + // package the user asked about. + warnings.push(format!( + "No release variant of {base} matches the installed distribution; keeping all {} variant(s).", + variants.len() + )); + kept.extend(variants); + } + } + } + + (kept, warnings) +} + +/// Build the before/after-hash map used for installed-distribution +/// matching. Mirrors the download flow's requirement that a patchable +/// file carry both hashes (new files, with an empty `beforeHash`, are +/// still kept so first-file verification can treat them as Ready). +fn files_for_selection(patch: &PatchResponse) -> HashMap { + let mut files = HashMap::new(); + for (file_path, file_info) in &patch.files { + if let (Some(before), Some(after)) = (&file_info.before_hash, &file_info.after_hash) { + files.insert( + file_path.clone(), + PatchFileInfo { + before_hash: before.clone(), + after_hash: after.clone(), + }, + ); + } + } + files +} + pub async fn download_and_apply_patches( selected: &[PatchSearchResult], params: &DownloadParams, @@ -545,6 +718,26 @@ pub async fn download_and_apply_patches( _ => PatchManifest::new(), }; + // Narrow PyPI multi-release selections to the installed distribution + // unless --all-releases was passed. `filter_to_installed_releases` + // is a no-op for non-PyPI ecosystems and single-variant packages. + let mut narrow_warnings: Vec = Vec::new(); + let selected_owned: Vec; + let selected: &[PatchSearchResult] = if params.all_releases { + selected + } else { + let (kept, warns) = + filter_to_installed_releases(selected, params, &api_client, effective_org).await; + if !params.json && !params.silent { + for w in &warns { + eprintln!(" [note] {w}"); + } + } + narrow_warnings = warns; + selected_owned = kept; + &selected_owned + }; + if !params.json && !params.silent { eprintln!("\nDownloading {} patch(es)...", selected.len()); } @@ -735,7 +928,7 @@ pub async fn download_and_apply_patches( } } - let result_json = serde_json::json!({ + let mut result_json = serde_json::json!({ "status": if patches_failed > 0 { "partial_failure" } else { "success" }, "found": selected.len(), "downloaded": patches_added, @@ -745,6 +938,12 @@ pub async fn download_and_apply_patches( "updated": updates.len(), "patches": downloaded_patches, }); + // Surface release-narrowing fallbacks (uninstalled package / no + // matching variant) so JSON consumers can see why all variants were + // kept. Omitted entirely when narrowing was clean. + if !narrow_warnings.is_empty() { + result_json["warnings"] = serde_json::json!(narrow_warnings); + } let exit_code = if patches_failed > 0 || (!apply_succeeded && patches_added > 0 && !params.save_only) { 1 } else { 0 }; (exit_code, result_json) @@ -1127,6 +1326,7 @@ pub async fn run(args: GetArgs) -> i32 { silent: false, download_mode: args.common.download_mode.clone(), api_overrides: args.common.api_client_overrides(), + all_releases: args.all_releases, }; let (code, result_json) = download_and_apply_patches(&selected, ¶ms).await; diff --git a/crates/socket-patch-cli/src/commands/remove.rs b/crates/socket-patch-cli/src/commands/remove.rs index 9157e521..9984518c 100644 --- a/crates/socket-patch-cli/src/commands/remove.rs +++ b/crates/socket-patch-cli/src/commands/remove.rs @@ -3,6 +3,7 @@ use socket_patch_core::api::client::get_api_client_with_overrides; use socket_patch_core::manifest::operations::{read_manifest, write_manifest}; use socket_patch_core::manifest::schema::PatchManifest; use socket_patch_core::utils::cleanup_blobs::{cleanup_unused_blobs, format_cleanup_result}; +use socket_patch_core::utils::purl::purl_matches_identifier; use socket_patch_core::utils::telemetry::{track_patch_removed, track_patch_remove_failed}; use std::path::Path; use std::time::Duration; @@ -92,13 +93,15 @@ pub async fn run(args: RemoveArgs) -> i32 { } }; - // Find matching patches to show what will be removed + // Find matching patches to show what will be removed. A base PURL + // (no `?`) matches every release variant of that package@version; a + // qualified PURL or a UUID targets a single patch. let matching: Vec<(&String, &socket_patch_core::manifest::schema::PatchRecord)> = if args.identifier.starts_with("pkg:") { manifest .patches .iter() - .filter(|(purl, _)| *purl == &args.identifier) + .filter(|(purl, _)| purl_matches_identifier(purl, &args.identifier)) .collect() } else { manifest @@ -125,9 +128,23 @@ pub async fn run(args: RemoveArgs) -> i32 { return 1; } - // Show what will be removed and confirm + // Show what will be removed and confirm. When a base PURL expanded + // to multiple manifest entries (PyPI release variants), make the + // blast radius explicit so the user understands why a single + // `remove pkg:pypi/foo@1.0` is removing several variants. if !args.common.json { - eprintln!("The following patch(es) will be removed:"); + if args.identifier.starts_with("pkg:") + && !args.identifier.contains('?') + && matching.len() > 1 + { + eprintln!( + "{} matches {} release variant(s) — all will be removed:", + args.identifier, + matching.len() + ); + } else { + eprintln!("The following patch(es) will be removed:"); + } for (purl, patch) in &matching { let file_count = patch.files.len(); eprintln!(" - {} (UUID: {}, {} file(s))", purl, &patch.uuid[..8], file_count); @@ -303,22 +320,26 @@ async fn remove_patch_from_manifest( let mut removed = Vec::new(); - if identifier.starts_with("pkg:") { - if manifest.patches.remove(identifier).is_some() { - removed.push(identifier.to_string()); - } + let purls_to_remove: Vec = if identifier.starts_with("pkg:") { + // Base PURL removes every release variant; qualified PURL removes one. + manifest + .patches + .keys() + .filter(|purl| purl_matches_identifier(purl, identifier)) + .cloned() + .collect() } else { - let purls_to_remove: Vec = manifest + manifest .patches .iter() .filter(|(_, patch)| patch.uuid == identifier) .map(|(purl, _)| purl.clone()) - .collect(); + .collect() + }; - for purl in purls_to_remove { - manifest.patches.remove(&purl); - removed.push(purl); - } + for purl in purls_to_remove { + manifest.patches.remove(&purl); + removed.push(purl); } if !removed.is_empty() { @@ -329,3 +350,100 @@ async fn remove_patch_from_manifest( Ok((removed, manifest)) } + +#[cfg(test)] +mod tests { + use super::*; + use socket_patch_core::manifest::schema::PatchRecord; + use std::collections::HashMap; + + fn make_record(uuid: &str) -> PatchRecord { + PatchRecord { + uuid: uuid.to_string(), + exported_at: "2024-01-01T00:00:00Z".to_string(), + files: HashMap::new(), + vulnerabilities: HashMap::new(), + description: "test".to_string(), + license: "MIT".to_string(), + tier: "free".to_string(), + } + } + + /// Write a manifest with three PyPI release variants of one + /// package@version plus an unrelated npm package, returning the + /// temp dir (kept alive) and the manifest path. + async fn write_multi_variant(dir: &Path) { + let mut patches = HashMap::new(); + patches.insert( + "pkg:pypi/six@1.16.0?artifact_id=wheel-cp311".to_string(), + make_record("uuid-cp311"), + ); + patches.insert( + "pkg:pypi/six@1.16.0?artifact_id=sdist".to_string(), + make_record("uuid-sdist"), + ); + patches.insert( + "pkg:pypi/six@1.16.0?artifact_id=wheel-cp312".to_string(), + make_record("uuid-cp312"), + ); + patches.insert("pkg:npm/foo@1.0".to_string(), make_record("uuid-foo")); + let manifest = PatchManifest { patches }; + write_manifest(&dir.join("manifest.json"), &manifest) + .await + .expect("write manifest"); + } + + #[tokio::test] + async fn remove_base_purl_removes_all_variants() { + let tmp = tempfile::tempdir().expect("tempdir"); + write_multi_variant(tmp.path()).await; + let manifest_path = tmp.path().join("manifest.json"); + + let (removed, manifest) = + remove_patch_from_manifest("pkg:pypi/six@1.16.0", &manifest_path) + .await + .expect("remove ok"); + + // All three release variants removed; the npm package untouched. + assert_eq!(removed.len(), 3); + assert!(removed.iter().all(|p| p.contains("six@1.16.0"))); + assert_eq!(manifest.patches.len(), 1); + assert!(manifest.patches.contains_key("pkg:npm/foo@1.0")); + } + + #[tokio::test] + async fn remove_qualified_purl_removes_single_variant() { + let tmp = tempfile::tempdir().expect("tempdir"); + write_multi_variant(tmp.path()).await; + let manifest_path = tmp.path().join("manifest.json"); + + let (removed, manifest) = remove_patch_from_manifest( + "pkg:pypi/six@1.16.0?artifact_id=sdist", + &manifest_path, + ) + .await + .expect("remove ok"); + + // Only the sdist variant removed; the two wheels + npm remain. + assert_eq!(removed, vec!["pkg:pypi/six@1.16.0?artifact_id=sdist"]); + assert_eq!(manifest.patches.len(), 3); + assert!(!manifest + .patches + .contains_key("pkg:pypi/six@1.16.0?artifact_id=sdist")); + } + + #[tokio::test] + async fn remove_by_uuid_removes_single_variant() { + let tmp = tempfile::tempdir().expect("tempdir"); + write_multi_variant(tmp.path()).await; + let manifest_path = tmp.path().join("manifest.json"); + + let (removed, manifest) = + remove_patch_from_manifest("uuid-cp312", &manifest_path) + .await + .expect("remove ok"); + + assert_eq!(removed, vec!["pkg:pypi/six@1.16.0?artifact_id=wheel-cp312"]); + assert_eq!(manifest.patches.len(), 3); + } +} diff --git a/crates/socket-patch-cli/src/commands/rollback.rs b/crates/socket-patch-cli/src/commands/rollback.rs index e821d8d7..1937b916 100644 --- a/crates/socket-patch-cli/src/commands/rollback.rs +++ b/crates/socket-patch-cli/src/commands/rollback.rs @@ -5,10 +5,12 @@ use socket_patch_core::api::blob_fetcher::{ use socket_patch_core::api::client::get_api_client_with_overrides; use socket_patch_core::crawlers::CrawlerOptions; use socket_patch_core::manifest::operations::read_manifest; -use socket_patch_core::manifest::schema::{PatchManifest, PatchRecord}; +use socket_patch_core::manifest::schema::{PatchFileInfo, PatchManifest, PatchRecord}; +use socket_patch_core::patch::apply::select_installed_variant; use socket_patch_core::patch::rollback::{rollback_package_patch, RollbackResult, VerifyRollbackStatus}; +use socket_patch_core::utils::purl::{purl_matches_identifier, strip_purl_qualifiers}; use socket_patch_core::utils::telemetry::{track_patch_rolled_back, track_patch_rollback_failed}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; use std::time::Duration; @@ -51,11 +53,15 @@ fn find_patches_to_rollback( Some(id) => { let mut patches = Vec::new(); if id.starts_with("pkg:") { - if let Some(patch) = manifest.patches.get(id) { - patches.push(PatchToRollback { - purl: id.to_string(), - patch: patch.clone(), - }); + // A base PURL (no `?`) matches every release variant of + // that package@version; a qualified PURL targets one. + for (purl, patch) in &manifest.patches { + if purl_matches_identifier(purl, id) { + patches.push(PatchToRollback { + purl: purl.clone(), + patch: patch.clone(), + }); + } } } else { for (purl, patch) in &manifest.patches { @@ -445,36 +451,81 @@ async fn rollback_patches_inner( return Ok((true, Vec::new())); } + // Group discovered packages by base PURL. A PyPI package@version may + // have several release variants (`?artifact_id=...`) in the manifest; + // `merge_pypi_qualified` resolves them all to the single on-disk + // distribution. Rolling back every variant against that one file + // would HashMismatch on the non-installed variants and report + // spurious failures, so — mirroring apply — we collapse each group to + // the single variant whose hashes match the installed bytes. + let mut groups: HashMap> = HashMap::new(); + for (purl, pkg_path) in &all_packages { + groups + .entry(strip_purl_qualifiers(purl).to_string()) + .or_default() + .push((purl, pkg_path)); + } + // Rollback patches let mut results: Vec = Vec::new(); let mut has_errors = false; - for (purl, pkg_path) in &all_packages { - let patch = match filtered_manifest.patches.get(purl) { - Some(p) => p, - None => continue, + for (_base, entries) in groups { + // Resolve which variant(s) to roll back for this base PURL. + let to_rollback: Vec<(&String, &PathBuf)> = if entries.len() == 1 { + entries + } else { + // All variants in a group resolve to the same installed path. + let pkg_path = entries[0].1; + let candidates: Vec<(&str, &HashMap)> = entries + .iter() + .filter_map(|(purl, _)| { + filtered_manifest + .patches + .get(*purl) + .map(|p| (purl.as_str(), &p.files)) + }) + .collect(); + match select_installed_variant(pkg_path, &candidates).await { + Some(idx) => { + let winner = candidates[idx].0.to_string(); + entries.into_iter().filter(|(p, _)| **p == winner).collect() + } + // No variant matches the installed distribution (e.g. a + // locally-modified file). Fall back to attempting every + // variant so the per-file verification surfaces the + // mismatch rather than silently skipping the package. + None => entries, + } }; - let result = rollback_package_patch( - purl, - pkg_path, - &patch.files, - &blobs_path, - args.common.dry_run, - ) - .await; - - if !result.success { - has_errors = true; - if !args.common.silent && !args.common.json { - eprintln!( - "Failed to rollback {}: {}", - purl, - result.error.as_deref().unwrap_or("unknown error") - ); + for (purl, pkg_path) in to_rollback { + let patch = match filtered_manifest.patches.get(purl) { + Some(p) => p, + None => continue, + }; + + let result = rollback_package_patch( + purl, + pkg_path, + &patch.files, + &blobs_path, + args.common.dry_run, + ) + .await; + + if !result.success { + has_errors = true; + if !args.common.silent && !args.common.json { + eprintln!( + "Failed to rollback {}: {}", + purl, + result.error.as_deref().unwrap_or("unknown error") + ); + } } + results.push(result); } - results.push(result); } Ok((!has_errors, results)) @@ -577,4 +628,56 @@ mod tests { find_patches_to_rollback(&manifest, Some("uuid-does-not-exist")); assert!(result.is_empty()); } + + /// A manifest holding several PyPI release variants of one + /// package@version (broad mode). + fn make_multi_variant_manifest() -> PatchManifest { + let mut patches = HashMap::new(); + patches.insert( + "pkg:pypi/six@1.16.0?artifact_id=wheel-cp311".to_string(), + make_record("uuid-wheel-cp311"), + ); + patches.insert( + "pkg:pypi/six@1.16.0?artifact_id=wheel-cp312".to_string(), + make_record("uuid-wheel-cp312"), + ); + patches.insert( + "pkg:pypi/six@1.16.0?artifact_id=sdist".to_string(), + make_record("uuid-sdist"), + ); + patches.insert("pkg:npm/foo@1.0".to_string(), make_record("uuid-foo")); + PatchManifest { patches } + } + + #[test] + fn test_find_patches_to_rollback_base_purl_matches_all_variants() { + let manifest = make_multi_variant_manifest(); + let result = + find_patches_to_rollback(&manifest, Some("pkg:pypi/six@1.16.0")); + // Base PURL (no qualifier) expands to every release variant. + assert_eq!(result.len(), 3); + for p in &result { + assert!(p.purl.starts_with("pkg:pypi/six@1.16.0?artifact_id=")); + } + } + + #[test] + fn test_find_patches_to_rollback_qualified_purl_matches_one_variant() { + let manifest = make_multi_variant_manifest(); + let result = find_patches_to_rollback( + &manifest, + Some("pkg:pypi/six@1.16.0?artifact_id=sdist"), + ); + // A fully-qualified PURL targets exactly one variant. + assert_eq!(result.len(), 1); + assert_eq!(result[0].purl, "pkg:pypi/six@1.16.0?artifact_id=sdist"); + } + + #[test] + fn test_find_patches_to_rollback_base_purl_does_not_leak_other_packages() { + let manifest = make_multi_variant_manifest(); + let result = + find_patches_to_rollback(&manifest, Some("pkg:pypi/six@1.16.0")); + assert!(result.iter().all(|p| p.purl.contains("six@1.16.0"))); + } } diff --git a/crates/socket-patch-cli/src/commands/scan.rs b/crates/socket-patch-cli/src/commands/scan.rs index 80232c64..d132e744 100644 --- a/crates/socket-patch-cli/src/commands/scan.rs +++ b/crates/socket-patch-cli/src/commands/scan.rs @@ -9,6 +9,7 @@ use socket_patch_core::manifest::schema::PatchManifest; use socket_patch_core::utils::cleanup_blobs::{ cleanup_unused_archives, cleanup_unused_blobs, CleanupResult, }; +use socket_patch_core::utils::purl::strip_purl_qualifiers; use socket_patch_core::utils::telemetry::{track_patch_scan_failed, track_patch_scanned}; use std::collections::HashSet; use std::path::Path; @@ -158,14 +159,24 @@ async fn preview_apply_gc( /// correspond to packages that were once patched but are no longer /// installed (or no longer reachable to the crawler). Pure / no I/O so /// it's unit-testable. +/// +/// Comparison is on the **base** PURL (qualifiers stripped) on both +/// sides: the pypi crawler reports base PURLs, but a manifest may hold +/// several qualified release variants (`?artifact_id=...`) of one +/// installed package. Matching on the base keeps every variant of an +/// installed package while still pruning all variants of one that is +/// gone — otherwise `scan --all-releases --sync` would prune the very +/// variants it just downloaded. pub(crate) fn detect_prunable( manifest: &PatchManifest, scanned_purls: &HashSet, ) -> Vec { + let scanned_bases: HashSet<&str> = + scanned_purls.iter().map(|p| strip_purl_qualifiers(p)).collect(); manifest .patches .keys() - .filter(|p| !scanned_purls.contains(*p)) + .filter(|p| !scanned_bases.contains(strip_purl_qualifiers(p))) .cloned() .collect() } @@ -237,6 +248,21 @@ pub struct ScanArgs { /// fully-reconciled state in one invocation. #[arg(long, default_value_t = false)] pub sync: bool, + + /// Download patches for every release/distribution (artifact_id) of + /// a matched package, not just the one matching the locally- + /// installed distribution. Only affects PyPI today — the only + /// ecosystem with per-release artifact_id variants. Off by default: + /// narrow scans store only the patch for the installed dist, keeping + /// `.socket/` small; `--all-releases` makes the manifest portable + /// across environments (e.g. cross-platform CI caches). + #[arg( + long = "all-releases", + env = "SOCKET_ALL_RELEASES", + default_value_t = false, + value_parser = clap::builder::BoolishValueParser::new(), + )] + pub all_releases: bool, } pub async fn run(args: ScanArgs) -> i32 { @@ -651,6 +677,7 @@ pub async fn run(args: ScanArgs) -> i32 { silent: true, download_mode: args.common.download_mode.clone(), api_overrides: args.common.api_client_overrides(), + all_releases: args.all_releases, }; let (code, apply_json) = download_and_apply_patches(&selected, ¶ms).await; apply_code = code; @@ -990,6 +1017,7 @@ pub async fn run(args: ScanArgs) -> i32 { silent: false, download_mode: args.common.download_mode.clone(), api_overrides: args.common.api_client_overrides(), + all_releases: args.all_releases, }; let (code, _) = download_and_apply_patches(&selected, ¶ms).await; @@ -1229,4 +1257,33 @@ mod tests { vec!["pkg:npm/bar@2.0".to_string(), "pkg:npm/foo@1.0".to_string()], ); } + + #[test] + fn detect_prunable_keeps_pypi_variants_of_installed_base() { + // Manifest holds three qualified release variants; the crawler + // reports only the base PURL. None should be pruned — they all + // belong to the installed package. + let m = manifest_with(&[ + ("pkg:pypi/six@1.16.0?artifact_id=wheel-a", "uuid-a"), + ("pkg:pypi/six@1.16.0?artifact_id=wheel-b", "uuid-b"), + ("pkg:pypi/six@1.16.0?artifact_id=sdist", "uuid-c"), + ]); + let out = detect_prunable(&m, &scanned(&["pkg:pypi/six@1.16.0"])); + assert!( + out.is_empty(), + "variants of an installed base must not be pruned; got {out:?}" + ); + } + + #[test] + fn detect_prunable_removes_all_variants_of_uninstalled_base() { + // The package is no longer installed (empty crawl): every + // release variant is prunable. + let m = manifest_with(&[ + ("pkg:pypi/six@1.16.0?artifact_id=wheel-a", "uuid-a"), + ("pkg:pypi/six@1.16.0?artifact_id=sdist", "uuid-c"), + ]); + let out = detect_prunable(&m, &scanned(&[])); + assert_eq!(out.len(), 2, "all variants of a gone package should prune"); + } } diff --git a/crates/socket-patch-cli/tests/cli_parse_get.rs b/crates/socket-patch-cli/tests/cli_parse_get.rs index fc1ccf16..c8364ab3 100644 --- a/crates/socket-patch-cli/tests/cli_parse_get.rs +++ b/crates/socket-patch-cli/tests/cli_parse_get.rs @@ -43,6 +43,16 @@ fn defaults_with_only_required_identifier() { assert!(!a.one_off); assert!(!a.common.json); assert_eq!(a.common.download_mode, "diff"); + assert!( + !a.all_releases, + "--all-releases default is false (narrow — installed-dist variant only)" + ); +} + +#[test] +fn all_releases_flag_sets_all_releases() { + let a = parse_get(&["some-id", "--all-releases"]); + assert!(a.all_releases); } #[test] diff --git a/crates/socket-patch-cli/tests/cli_parse_scan.rs b/crates/socket-patch-cli/tests/cli_parse_scan.rs index 14eaa7f3..2eecd1e2 100644 --- a/crates/socket-patch-cli/tests/cli_parse_scan.rs +++ b/crates/socket-patch-cli/tests/cli_parse_scan.rs @@ -59,6 +59,16 @@ fn defaults_match_contract() { assert!(!args.prune, "--prune default is false (GC is opt-in in v3.0)"); assert!(!args.sync, "--sync default is false"); assert!(!args.common.dry_run, "--dry-run default is false"); + assert!( + !args.all_releases, + "--all-releases default is false (narrow — installed-dist variant only)" + ); +} + +#[test] +fn all_releases_flag_long_form() { + let args = parse_scan(&["--all-releases"]); + assert!(args.all_releases); } #[test] diff --git a/crates/socket-patch-cli/tests/in_process_cargo_apply.rs b/crates/socket-patch-cli/tests/in_process_cargo_apply.rs index 7d174d70..f7020a21 100644 --- a/crates/socket-patch-cli/tests/in_process_cargo_apply.rs +++ b/crates/socket-patch-cli/tests/in_process_cargo_apply.rs @@ -218,6 +218,7 @@ async fn cargo_fetch_scan_sync_patches_real_file() { apply: false, prune: false, sync: true, + all_releases: false, }; // CARGO_HOME must be set in this process's env so the cargo crawler // probes the isolated location (not the developer's real ~/.cargo). @@ -289,6 +290,7 @@ async fn cargo_crawler_finds_real_fetched_crate() { apply: false, prune: false, sync: false, + all_releases: false, }; assert_eq!(scan_run(args).await, 0); std::env::remove_var("CARGO_HOME"); diff --git a/crates/socket-patch-cli/tests/in_process_gem_apply.rs b/crates/socket-patch-cli/tests/in_process_gem_apply.rs index 54fea683..1497e4a4 100644 --- a/crates/socket-patch-cli/tests/in_process_gem_apply.rs +++ b/crates/socket-patch-cli/tests/in_process_gem_apply.rs @@ -203,6 +203,7 @@ async fn gem_install_scan_sync_patches_real_file() { apply: false, prune: false, sync: true, + all_releases: false, }; let code = scan_run(args).await; assert!(code == 0 || code == 1, "scan --sync exit: {code}"); @@ -262,6 +263,7 @@ async fn gem_crawler_finds_real_installed_gem() { apply: false, prune: false, sync: false, + all_releases: false, }; assert_eq!(scan_run(args).await, 0); } diff --git a/crates/socket-patch-cli/tests/in_process_get.rs b/crates/socket-patch-cli/tests/in_process_get.rs index b0a2efa3..f383b7a7 100644 --- a/crates/socket-patch-cli/tests/in_process_get.rs +++ b/crates/socket-patch-cli/tests/in_process_get.rs @@ -40,6 +40,7 @@ fn default_args(identifier: &str, cwd: &Path) -> GetArgs { package: false, save_only: true, one_off: false, + all_releases: false, } } diff --git a/crates/socket-patch-cli/tests/in_process_pypi_apply.rs b/crates/socket-patch-cli/tests/in_process_pypi_apply.rs index 2733f5b6..2e948fc1 100644 --- a/crates/socket-patch-cli/tests/in_process_pypi_apply.rs +++ b/crates/socket-patch-cli/tests/in_process_pypi_apply.rs @@ -237,6 +237,7 @@ async fn pypi_install_scan_sync_patches_real_file() { apply: false, prune: false, sync: true, + all_releases: false, }; // Avoid borrow problem with into_iter let _ = &mut args; @@ -297,6 +298,7 @@ async fn pypi_scan_then_apply_force_patches_real_file() { apply: false, prune: false, sync: true, + all_releases: false, }; let _ = scan_run(scan_args).await; @@ -371,6 +373,7 @@ async fn pypi_apply_dry_run_does_not_modify_file() { apply: true, prune: false, sync: false, + all_releases: false, }; let _ = scan_run(scan_args).await; @@ -445,6 +448,7 @@ async fn pypi_crawler_finds_real_installed_six() { apply: false, prune: false, sync: false, + all_releases: false, }; assert_eq!(scan_run(args).await, 0); } diff --git a/crates/socket-patch-cli/tests/in_process_pypi_multi_release.rs b/crates/socket-patch-cli/tests/in_process_pypi_multi_release.rs new file mode 100644 index 00000000..ba7612fa --- /dev/null +++ b/crates/socket-patch-cli/tests/in_process_pypi_multi_release.rs @@ -0,0 +1,504 @@ +//! Multi-release (multi-`artifact_id`) PyPI patching coverage. +//! +//! A PyPI `package@version` can resolve to several patch variants — one +//! per release/distribution (`?artifact_id=...`, e.g. different wheels + +//! an sdist). Only the distribution actually installed in the venv can +//! apply. These tests install a real `six==1.16.0`, then drive the CLI +//! against a wiremock that advertises three release variants where only +//! one carries the on-disk file's real `beforeHash`. +//! +//! Behaviors pinned: +//! * `scan` (narrow, default) stores only the installed-dist variant. +//! * `scan --all-releases` (broad) stores every variant; apply still +//! patches with the installed one. +//! * `remove ` over a broad manifest removes ALL variants +//! and rolls back the file without spurious failure. +//! * `rollback` (no id) over a broad manifest exits 0 and restores the +//! file (non-installed variants are skipped, not failed). +//! +//! Requires `python3` with `venv` + `pip`; skipped (visibly) otherwise. + +use std::path::{Path, PathBuf}; +use std::process::Command; + +use base64::Engine; +use serial_test::serial; +use sha2::{Digest, Sha256}; +use socket_patch_cli::commands::remove::{run as remove_run, RemoveArgs}; +use socket_patch_cli::commands::rollback::{run as rollback_run, RollbackArgs}; +use socket_patch_cli::commands::scan::{run as scan_run, ScanArgs}; +use wiremock::matchers::{method, path, path_regex}; +use wiremock::{Mock, MockServer, ResponseTemplate}; + +const ORG: &str = "test-org"; +const PYPI_PACKAGE: &str = "six"; +const PYPI_VERSION: &str = "1.16.0"; + +// One UUID per release variant. The "installed" wheel is the only one +// whose beforeHash matches the real on-disk six.py. +const UUID_INSTALLED: &str = "11111111-1111-4111-8111-111111111111"; +const UUID_OTHER_WHEEL: &str = "22222222-2222-4222-8222-222222222222"; +const UUID_SDIST: &str = "33333333-3333-4333-8333-333333333333"; + +const ARTIFACT_INSTALLED: &str = "wheel-cp-installed"; +const ARTIFACT_OTHER_WHEEL: &str = "wheel-cp-other"; +const ARTIFACT_SDIST: &str = "sdist"; + +const MARKER_INSTALLED: &[u8] = b"\n# SOCKET-MULTIRELEASE-INSTALLED\n"; + +fn git_sha256(content: &[u8]) -> String { + let header = format!("blob {}\0", content.len()); + let mut hasher = Sha256::new(); + hasher.update(header.as_bytes()); + hasher.update(content); + hex::encode(hasher.finalize()) +} + +fn b64(bytes: &[u8]) -> String { + base64::engine::general_purpose::STANDARD.encode(bytes) +} + +fn find_python() -> Option<&'static str> { + for cmd in ["python3", "python", "py"] { + let ok = Command::new(cmd) + .arg("--version") + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .status() + .map(|s| s.success()) + .unwrap_or(false); + if ok { + return Some(cmd); + } + } + None +} + +fn has_python3() -> bool { + find_python().is_some() +} + +fn venv_pip(venv: &Path) -> PathBuf { + if cfg!(windows) { + venv.join("Scripts").join("pip.exe") + } else { + venv.join("bin").join("pip") + } +} + +fn find_site_packages(venv: &Path) -> PathBuf { + if cfg!(windows) { + venv.join("Lib").join("site-packages") + } else { + let lib = venv.join("lib"); + for entry in std::fs::read_dir(&lib).expect("lib dir").flatten() { + let sp = entry.path().join("site-packages"); + if sp.exists() { + return sp; + } + } + panic!("site-packages not found under {}", lib.display()); + } +} + +/// Install `six==1.16.0` into a venv under `tmp`; return the path to the +/// installed `six.py`. +fn install_six(tmp: &Path) -> PathBuf { + let venv = tmp.join(".venv"); + let python = find_python().expect("python interpreter not on PATH"); + let status = Command::new(python) + .args(["-m", "venv", venv.to_str().unwrap()]) + .status() + .expect("python venv"); + assert!(status.success(), "failed to create venv"); + + let pip = venv_pip(&venv); + let status = Command::new(&pip) + .args([ + "install", + "--disable-pip-version-check", + "--quiet", + "--no-cache-dir", + &format!("{PYPI_PACKAGE}=={PYPI_VERSION}"), + ]) + .status() + .expect("pip install"); + assert!(status.success(), "failed to install {PYPI_PACKAGE}"); + + let candidate = find_site_packages(&venv).join("six.py"); + assert!(candidate.exists(), "six.py not found after pip install"); + candidate +} + +fn base_purl() -> String { + format!("pkg:pypi/{PYPI_PACKAGE}@{PYPI_VERSION}") +} + +fn qualified(artifact_id: &str) -> String { + format!("{}?artifact_id={artifact_id}", base_purl()) +} + +/// Stand up a wiremock advertising three release variants for the base +/// PURL. Only the `installed` variant's `beforeHash` matches the real +/// on-disk six.py; the other two describe different distributions. +async fn setup_multi_release_mock(server: &MockServer, installed_before_hash: &str) { + let base = base_purl(); + + // --- batch: report the base package has patches ----------------------- + Mock::given(method("POST")) + .and(path(format!("/v0/orgs/{ORG}/patches/batch"))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "packages": [{ + "purl": base, + "patches": [ + { "uuid": UUID_INSTALLED, "purl": qualified(ARTIFACT_INSTALLED), + "tier": "free", "cveIds": [], "ghsaIds": [], + "severity": "high", "title": "installed wheel" }, + { "uuid": UUID_OTHER_WHEEL, "purl": qualified(ARTIFACT_OTHER_WHEEL), + "tier": "free", "cveIds": [], "ghsaIds": [], + "severity": "high", "title": "other wheel" }, + { "uuid": UUID_SDIST, "purl": qualified(ARTIFACT_SDIST), + "tier": "free", "cveIds": [], "ghsaIds": [], + "severity": "high", "title": "sdist" }, + ] + }], + "canAccessPaidPatches": false, + }))) + .mount(server) + .await; + + // --- by-package: all three qualified variants ------------------------- + Mock::given(method("GET")) + .and(path_regex(format!("^/v0/orgs/{ORG}/patches/by-package/.+$"))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "patches": [ + { "uuid": UUID_INSTALLED, "purl": qualified(ARTIFACT_INSTALLED), + "publishedAt": "2024-01-01T00:00:00Z", "description": "installed wheel", + "license": "MIT", "tier": "free", "vulnerabilities": {} }, + { "uuid": UUID_OTHER_WHEEL, "purl": qualified(ARTIFACT_OTHER_WHEEL), + "publishedAt": "2024-01-01T00:00:00Z", "description": "other wheel", + "license": "MIT", "tier": "free", "vulnerabilities": {} }, + { "uuid": UUID_SDIST, "purl": qualified(ARTIFACT_SDIST), + "publishedAt": "2024-01-01T00:00:00Z", "description": "sdist", + "license": "MIT", "tier": "free", "vulnerabilities": {} }, + ], + "canAccessPaidPatches": false, + }))) + .mount(server) + .await; + + // --- view: per-UUID full patch with file hashes + inline blobs -------- + // Installed variant: beforeHash == real on-disk hash, so it applies. + // Its beforeBlobContent is left as a placeholder set by the caller's + // real bytes below (filled via mount_installed_view). + // The two non-installed variants carry bogus distribution bytes so + // their beforeHash never matches the on-disk file. + let other_before = b"# six.py from a DIFFERENT wheel distribution\n"; + let mut other_after = other_before.to_vec(); + other_after.extend_from_slice(b"\n# OTHER-WHEEL-MARKER\n"); + mount_view( + server, + UUID_OTHER_WHEEL, + &qualified(ARTIFACT_OTHER_WHEEL), + &git_sha256(other_before), + &git_sha256(&other_after), + other_before, + &other_after, + ) + .await; + + let sdist_before = b"# six.py from the sdist distribution\n"; + let mut sdist_after = sdist_before.to_vec(); + sdist_after.extend_from_slice(b"\n# SDIST-MARKER\n"); + mount_view( + server, + UUID_SDIST, + &qualified(ARTIFACT_SDIST), + &git_sha256(sdist_before), + &git_sha256(&sdist_after), + sdist_before, + &sdist_after, + ) + .await; + + // Sanity: the installed variant's before hash is the real file hash. + let _ = installed_before_hash; +} + +/// Mount the view for the installed variant. Separated because it needs +/// the real on-disk `before` bytes (for rollback) and the marker-patched +/// `after` bytes computed by the test. +async fn mount_installed_view( + server: &MockServer, + before_hash: &str, + after_hash: &str, + before_bytes: &[u8], + after_bytes: &[u8], +) { + mount_view( + server, + UUID_INSTALLED, + &qualified(ARTIFACT_INSTALLED), + before_hash, + after_hash, + before_bytes, + after_bytes, + ) + .await; +} + +#[allow(clippy::too_many_arguments)] +async fn mount_view( + server: &MockServer, + uuid: &str, + purl: &str, + before_hash: &str, + after_hash: &str, + before_bytes: &[u8], + after_bytes: &[u8], +) { + Mock::given(method("GET")) + .and(path(format!("/v0/orgs/{ORG}/patches/view/{uuid}"))) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "uuid": uuid, + "purl": purl, + "publishedAt": "2024-01-01T00:00:00Z", + "files": { + "six.py": { + "beforeHash": before_hash, + "afterHash": after_hash, + "blobContent": b64(after_bytes), + "beforeBlobContent": b64(before_bytes), + } + }, + "vulnerabilities": {}, + "description": "multi-release fixture", + "license": "MIT", + "tier": "free", + }))) + .mount(server) + .await; +} + +fn scan_args(tmp: &Path, api_url: String, all_releases: bool) -> ScanArgs { + ScanArgs { + common: socket_patch_cli::args::GlobalArgs { + cwd: tmp.to_path_buf(), + org: Some(ORG.to_string()), + json: true, + yes: true, + global: false, + global_prefix: None, + api_url, + api_token: Some("fake".to_string()), + ecosystems: Some(vec!["pypi".to_string()]), + download_mode: "diff".to_string(), + dry_run: false, + ..socket_patch_cli::args::GlobalArgs::default() + }, + batch_size: 100, + // Download + apply but DON'T prune/GC: the post-sync GC sweeps + // `beforeHash` blobs (only `afterHash` blobs are kept for apply), + // which would force the later rollback/remove to re-fetch them + // from the API. Keeping GC off leaves the before-blobs on disk so + // rollback restores offline. (Prune's base-vs-qualified handling + // is covered by `detect_prunable` unit tests.) + apply: true, + prune: false, + sync: false, + all_releases, + } +} + +fn manifest_keys(tmp: &Path) -> Vec { + let path = tmp.join(".socket").join("manifest.json"); + let raw = std::fs::read_to_string(&path) + .unwrap_or_else(|_| panic!("manifest not found at {}", path.display())); + let v: serde_json::Value = serde_json::from_str(&raw).expect("manifest json"); + v["patches"] + .as_object() + .map(|m| m.keys().cloned().collect()) + .unwrap_or_default() +} + +fn file_has_marker(file: &Path, marker: &[u8]) -> bool { + let bytes = std::fs::read(file).expect("read file"); + bytes.windows(marker.len()).any(|w| w == marker) +} + +/// Common setup: install six, compute the installed variant's hashes, +/// stand up the mock. Returns (six_path, server). +async fn fixture(tmp: &Path) -> (PathBuf, MockServer) { + let six_path = install_six(tmp); + let original = std::fs::read(&six_path).expect("read six.py"); + let before_hash = git_sha256(&original); + let mut patched = original.clone(); + patched.extend_from_slice(MARKER_INSTALLED); + let after_hash = git_sha256(&patched); + + let server = MockServer::start().await; + setup_multi_release_mock(&server, &before_hash).await; + mount_installed_view(&server, &before_hash, &after_hash, &original, &patched).await; + (six_path, server) +} + +// --------------------------------------------------------------------------- +// Narrow (default): only the installed-dist variant lands in the manifest. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[serial] +async fn narrow_scan_keeps_only_installed_release() { + if !has_python3() { + println!("SKIP: python3 not on PATH"); + return; + } + let tmp = tempfile::tempdir().expect("tempdir"); + let (six_path, server) = fixture(tmp.path()).await; + + let code = scan_run(scan_args(tmp.path(), server.uri(), false)).await; + assert!(code == 0 || code == 1, "scan exit: {code}"); + + // Manifest holds exactly the installed wheel variant. + let keys = manifest_keys(tmp.path()); + assert_eq!( + keys, + vec![qualified(ARTIFACT_INSTALLED)], + "narrow scan must store only the installed-dist variant; got {keys:?}" + ); + + // The on-disk file was patched with the installed variant's marker. + assert!( + file_has_marker(&six_path, MARKER_INSTALLED), + "installed variant should have patched six.py" + ); +} + +// --------------------------------------------------------------------------- +// Broad: every variant is downloaded; apply still picks the installed one. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[serial] +async fn broad_scan_keeps_all_releases() { + if !has_python3() { + println!("SKIP: python3 not on PATH"); + return; + } + let tmp = tempfile::tempdir().expect("tempdir"); + let (six_path, server) = fixture(tmp.path()).await; + + let code = scan_run(scan_args(tmp.path(), server.uri(), true)).await; + assert!(code == 0 || code == 1, "scan exit: {code}"); + + // Manifest holds all three release variants. + let mut keys = manifest_keys(tmp.path()); + keys.sort(); + let mut expected = vec![ + qualified(ARTIFACT_INSTALLED), + qualified(ARTIFACT_OTHER_WHEEL), + qualified(ARTIFACT_SDIST), + ]; + expected.sort(); + assert_eq!(keys, expected, "broad scan must store every variant"); + + // Apply still patches with the installed distribution's variant only. + assert!( + file_has_marker(&six_path, MARKER_INSTALLED), + "broad apply should still patch with the installed variant" + ); +} + +// --------------------------------------------------------------------------- +// Remove over a broad manifest: removes ALL variants and +// rolls back the file with no spurious failure. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[serial] +async fn remove_base_purl_clears_all_variants_and_rolls_back() { + if !has_python3() { + println!("SKIP: python3 not on PATH"); + return; + } + let tmp = tempfile::tempdir().expect("tempdir"); + let (six_path, server) = fixture(tmp.path()).await; + + // Broad scan to seed all three variants + apply the installed one. + let _ = scan_run(scan_args(tmp.path(), server.uri(), true)).await; + assert_eq!(manifest_keys(tmp.path()).len(), 3); + assert!(file_has_marker(&six_path, MARKER_INSTALLED)); + + // Remove by base PURL — must match every variant and roll back. + let remove_args = RemoveArgs { + identifier: base_purl(), + common: socket_patch_cli::args::GlobalArgs { + cwd: tmp.path().to_path_buf(), + org: Some(ORG.to_string()), + api_url: server.uri(), + api_token: Some("fake".to_string()), + json: true, + yes: true, + ecosystems: Some(vec!["pypi".to_string()]), + ..socket_patch_cli::args::GlobalArgs::default() + }, + skip_rollback: false, + }; + let code = remove_run(remove_args).await; + assert_eq!(code, 0, "remove base PURL should succeed (exit 0)"); + + // Manifest emptied of the six variants. + assert!( + manifest_keys(tmp.path()).is_empty(), + "all release variants should be removed from the manifest" + ); + // File rolled back to original (marker gone). + assert!( + !file_has_marker(&six_path, MARKER_INSTALLED), + "remove should roll the on-disk file back to its original bytes" + ); +} + +// --------------------------------------------------------------------------- +// Rollback (no identifier) over a broad manifest: exit 0, file restored, +// non-installed variants skipped rather than failed. +// --------------------------------------------------------------------------- + +#[tokio::test] +#[serial] +async fn rollback_all_over_broad_manifest_succeeds() { + if !has_python3() { + println!("SKIP: python3 not on PATH"); + return; + } + let tmp = tempfile::tempdir().expect("tempdir"); + let (six_path, server) = fixture(tmp.path()).await; + + let _ = scan_run(scan_args(tmp.path(), server.uri(), true)).await; + assert_eq!(manifest_keys(tmp.path()).len(), 3); + assert!(file_has_marker(&six_path, MARKER_INSTALLED)); + + // Rollback everything in the manifest. Before the variant-dedupe fix + // this exited non-zero (HashMismatch on the two non-installed + // variants against the single on-disk file). + let rollback_args = RollbackArgs { + identifier: None, + common: socket_patch_cli::args::GlobalArgs { + cwd: tmp.path().to_path_buf(), + org: Some(ORG.to_string()), + api_url: server.uri(), + api_token: Some("fake".to_string()), + json: true, + ecosystems: Some(vec!["pypi".to_string()]), + ..socket_patch_cli::args::GlobalArgs::default() + }, + one_off: false, + }; + let code = rollback_run(rollback_args).await; + assert_eq!(code, 0, "rollback-all over broad manifest should exit 0"); + + assert!( + !file_has_marker(&six_path, MARKER_INSTALLED), + "rollback should restore the original file bytes" + ); +} diff --git a/crates/socket-patch-cli/tests/in_process_python_envs.rs b/crates/socket-patch-cli/tests/in_process_python_envs.rs index 41a25998..1a395173 100644 --- a/crates/socket-patch-cli/tests/in_process_python_envs.rs +++ b/crates/socket-patch-cli/tests/in_process_python_envs.rs @@ -58,6 +58,7 @@ fn default_args(cwd: &Path, api_url: String) -> ScanArgs { apply: false, prune: false, sync: false, + all_releases: false, } } diff --git a/crates/socket-patch-cli/tests/in_process_remote_ecosystems_apply.rs b/crates/socket-patch-cli/tests/in_process_remote_ecosystems_apply.rs index 3efcf115..24c6e238 100644 --- a/crates/socket-patch-cli/tests/in_process_remote_ecosystems_apply.rs +++ b/crates/socket-patch-cli/tests/in_process_remote_ecosystems_apply.rs @@ -60,6 +60,7 @@ fn default_scan_args(cwd: &Path, eco: &str, api_url: String) -> ScanArgs { apply: false, prune: false, sync: true, + all_releases: false, } } diff --git a/crates/socket-patch-cli/tests/in_process_scan.rs b/crates/socket-patch-cli/tests/in_process_scan.rs index 8f0d0a93..ea71f33c 100644 --- a/crates/socket-patch-cli/tests/in_process_scan.rs +++ b/crates/socket-patch-cli/tests/in_process_scan.rs @@ -36,6 +36,7 @@ fn default_args(cwd: &Path) -> ScanArgs { apply: false, prune: false, sync: false, + all_releases: false, } } diff --git a/crates/socket-patch-core/src/patch/apply.rs b/crates/socket-patch-core/src/patch/apply.rs index dfc0723b..5249a834 100644 --- a/crates/socket-patch-core/src/patch/apply.rs +++ b/crates/socket-patch-core/src/patch/apply.rs @@ -211,6 +211,52 @@ pub async fn verify_file_patch( } } +/// Select the single variant whose installed bytes match the on-disk +/// distribution — i.e. the "minimally required" release for this +/// environment. +/// +/// A package@version may resolve to several patch variants (PyPI +/// `?artifact_id=...` releases, one per wheel/sdist). Only one +/// distribution is ever installed in a given environment, so only one +/// variant can apply. This mirrors the first-file hash check the apply +/// pipeline uses: a variant matches when its first patched file is not +/// in a [`VerifyStatus::HashMismatch`] state against the on-disk +/// package. A variant with no files (nothing to verify) is treated as a +/// match. +/// +/// `variants` maps a variant key (typically a qualified PURL) to that +/// variant's patched files. Returns the index of the first variant whose +/// first patched file is in a [`VerifyStatus::Ready`] or +/// [`VerifyStatus::AlreadyPatched`] state — i.e. its `beforeHash` (or +/// `afterHash`, if already applied) matches the installed bytes — or +/// `None` when no variant matches the installed distribution. +/// +/// A [`VerifyStatus::NotFound`] (a missing pre-existing file) or +/// [`VerifyStatus::HashMismatch`] does **not** count as a match: those +/// signal the variant describes a *different* distribution than the one +/// on disk. A variant with no files (nothing to verify) is treated as a +/// match. Both the narrow download filter (scan/get) and the rollback +/// dedupe share this helper so release selection stays consistent. +pub async fn select_installed_variant( + pkg_path: &Path, + variants: &[(&str, &HashMap)], +) -> Option { + for (idx, (_key, files)) in variants.iter().enumerate() { + // No files to verify — nothing to disqualify the variant. + let Some((file_name, file_info)) = files.iter().next() else { + return Some(idx); + }; + let verify = verify_file_patch(pkg_path, file_name, file_info).await; + if matches!( + verify.status, + VerifyStatus::Ready | VerifyStatus::AlreadyPatched + ) { + return Some(idx); + } + } + None +} + /// Apply a patch to a single file. /// /// **Permission policy** (per the user-visible contract — patched diff --git a/crates/socket-patch-core/src/utils/purl.rs b/crates/socket-patch-core/src/utils/purl.rs index eec86a2d..e049121b 100644 --- a/crates/socket-patch-core/src/utils/purl.rs +++ b/crates/socket-patch-core/src/utils/purl.rs @@ -225,6 +225,28 @@ pub fn is_purl(s: &str) -> bool { s.starts_with("pkg:") } +/// Does a manifest PURL key match a user-supplied PURL identifier? +/// +/// PyPI patches are keyed in the manifest by their fully-qualified PURL +/// (`pkg:pypi/foo@1.0?artifact_id=...`), one entry per release variant. +/// A user removing or rolling back a package usually types the *base* +/// PURL without a qualifier and expects it to cover every variant. So: +/// +/// * a **base** identifier (no `?`) matches any key whose base equals it +/// — i.e. all release variants of that `package@version`, and +/// * a **qualified** identifier (`?artifact_id=...`) matches only the +/// exact key, so a single variant can still be targeted precisely. +/// +/// Non-PyPI keys never carry a `?`, so for them this reduces to plain +/// equality. +pub fn purl_matches_identifier(manifest_key: &str, identifier: &str) -> bool { + if identifier.contains('?') { + manifest_key == identifier + } else { + strip_purl_qualifiers(manifest_key) == identifier + } +} + #[cfg(test)] mod tests { use super::*; @@ -256,6 +278,47 @@ mod tests { assert_eq!(parse_pypi_purl("pkg:pypi/requests@"), None); } + #[test] + fn test_purl_matches_identifier() { + // Base identifier matches every qualified variant + the bare base. + assert!(purl_matches_identifier( + "pkg:pypi/requests@2.28.0?artifact_id=abc", + "pkg:pypi/requests@2.28.0" + )); + assert!(purl_matches_identifier( + "pkg:pypi/requests@2.28.0", + "pkg:pypi/requests@2.28.0" + )); + // Base identifier does NOT match a different version. + assert!(!purl_matches_identifier( + "pkg:pypi/requests@2.29.0?artifact_id=abc", + "pkg:pypi/requests@2.28.0" + )); + // Qualified identifier matches only the exact key. + assert!(purl_matches_identifier( + "pkg:pypi/requests@2.28.0?artifact_id=abc", + "pkg:pypi/requests@2.28.0?artifact_id=abc" + )); + assert!(!purl_matches_identifier( + "pkg:pypi/requests@2.28.0?artifact_id=xyz", + "pkg:pypi/requests@2.28.0?artifact_id=abc" + )); + // A qualified identifier must not match the bare base key. + assert!(!purl_matches_identifier( + "pkg:pypi/requests@2.28.0", + "pkg:pypi/requests@2.28.0?artifact_id=abc" + )); + // Non-PyPI keys: plain equality. + assert!(purl_matches_identifier( + "pkg:npm/lodash@4.17.21", + "pkg:npm/lodash@4.17.21" + )); + assert!(!purl_matches_identifier( + "pkg:npm/lodash@4.17.21", + "pkg:npm/lodash@4.17.20" + )); + } + #[test] fn test_is_purl() { assert!(is_purl("pkg:npm/lodash@4.17.21"));