Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 204 additions & 4 deletions crates/socket-patch-cli/src/commands/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@ use socket_patch_core::api::client::{
use socket_patch_core::api::types::{
PatchResponse, PatchSearchResult, SearchResponse, VulnerabilityResponse,
};
use socket_patch_core::crawlers::CrawlerOptions;
use socket_patch_core::crawlers::{CrawlerOptions, Ecosystem};
use socket_patch_core::manifest::operations::{read_manifest, write_manifest};
use socket_patch_core::manifest::schema::{
PatchFileInfo, PatchManifest, PatchRecord, VulnerabilityInfo,
};
use socket_patch_core::patch::apply::select_installed_variant;
use socket_patch_core::utils::fuzzy_match::fuzzy_match_packages;
use socket_patch_core::utils::purl::is_purl;
use socket_patch_core::utils::purl::{is_purl, strip_purl_qualifiers};
use socket_patch_core::utils::telemetry::{track_patch_fetch_failed, track_patch_fetched};
use std::collections::HashMap;
use std::fmt;
use std::path::{Path, PathBuf};

use crate::args::{apply_env_toggles, GlobalArgs};
use crate::ecosystem_dispatch::crawl_all_ecosystems;
use crate::ecosystem_dispatch::{crawl_all_ecosystems, find_packages_for_rollback, partition_purls};
use crate::output::{confirm, select_one, SelectError};

/// Best-effort ecosystem extractor for a `pkg:<eco>/...` PURL. Used as
Expand Down Expand Up @@ -327,6 +328,19 @@ pub struct GetArgs {
/// Apply patch immediately without saving to .socket folder.
#[arg(long = "one-off", env = "SOCKET_ONE_OFF", default_value_t = false)]
pub one_off: bool,

/// Download patches for every release/distribution (artifact_id) of
/// a matched package, not just the one matching the locally-
/// installed distribution. Only affects PyPI today — the only
/// ecosystem with per-release artifact_id variants. Off by default:
/// only the patch for the installed dist is fetched.
#[arg(
long = "all-releases",
env = "SOCKET_ALL_RELEASES",
default_value_t = false,
value_parser = clap::builder::BoolishValueParser::new(),
)]
pub all_releases: bool,
}

#[derive(Debug, Clone, Copy, PartialEq)]
Expand Down Expand Up @@ -508,11 +522,170 @@ pub struct DownloadParams {
/// client constructed here. Without this, `download_and_apply_patches`
/// would only honor env vars and ignore the user's flags.
pub api_overrides: socket_patch_core::api::client::ApiClientEnvOverrides,
/// When `false` (the default — narrow), a PyPI package with multiple
/// release variants (`?artifact_id=...`) is filtered down to the one
/// matching the locally-installed distribution before download. When
/// `true` (`--all-releases`), every variant is downloaded. No effect
/// on ecosystems without per-release artifact_id variants.
pub all_releases: bool,
}

/// Download and apply a set of selected patches.
///
/// Used by both `get` and `scan` commands. Returns (exit_code, json_result).
/// Narrow a selection of patches down to the release variant matching
/// each locally-installed distribution.
///
/// A PyPI `package@version` can resolve to several patch variants — one
/// per `?artifact_id=...` release (wheel/sdist). Only one distribution
/// is ever installed in a given environment, so only one variant can
/// apply. With `--all-releases` off (the default) we keep just the
/// variant whose first patched file's hash matches the on-disk package,
/// dropping the rest so they are never downloaded or written to the
/// manifest. Non-PyPI ecosystems never carry `artifact_id` qualifiers,
/// so they pass through untouched.
///
/// Fallbacks (keep all variants of the base, i.e. behave as broad):
/// * the base package is not installed on disk (nothing to match
/// against — e.g. `get` for an absent package), or
/// * the installed distribution matches none of the variants (a local
/// modification, or no patch exists for the installed release).
///
/// Both fallbacks push a human-readable warning.
///
/// Returns the kept patches plus any warnings to surface to the caller.
async fn filter_to_installed_releases(
selected: &[PatchSearchResult],
params: &DownloadParams,
api_client: &socket_patch_core::api::client::ApiClient,
org: Option<&str>,
) -> (Vec<PatchSearchResult>, Vec<String>) {
// Group the PyPI selections by their base PURL (qualifiers stripped).
// Anything that isn't PyPI, or whose base has a single variant, is
// kept verbatim and needs no installed-dist resolution.
let mut pypi_groups: HashMap<String, Vec<PatchSearchResult>> = HashMap::new();
let mut kept: Vec<PatchSearchResult> = Vec::new();
for sr in selected {
if Ecosystem::from_purl(&sr.purl) == Some(Ecosystem::Pypi) {
pypi_groups
.entry(strip_purl_qualifiers(&sr.purl).to_string())
.or_default()
.push(sr.clone());
} else {
kept.push(sr.clone());
}
}

let mut warnings: Vec<String> = Vec::new();

// Singleton PyPI bases have nothing to disambiguate — keep as-is.
// Collect the multi-variant bases that actually need resolution.
let mut multi: Vec<(String, Vec<PatchSearchResult>)> = Vec::new();
for (base, variants) in pypi_groups {
if variants.len() <= 1 {
kept.extend(variants);
} else {
multi.push((base, variants));
}
}

if multi.is_empty() {
return (kept, warnings);
}

// Discover the on-disk path for each multi-variant base. The pypi
// crawler is queried with base PURLs and the result is fanned back
// out to every qualified variant (all variants of one installed
// package resolve to the same path).
let all_qualified: Vec<String> = multi
.iter()
.flat_map(|(_, variants)| variants.iter().map(|s| s.purl.clone()))
.collect();
// All collected PURLs are PyPI; no ecosystem filter needed.
let partitioned = partition_purls(&all_qualified, None);
let crawler_options = CrawlerOptions {
cwd: params.cwd.clone(),
global: params.global,
global_prefix: params.global_prefix.clone(),
batch_size: 100,
};
let paths = find_packages_for_rollback(&partitioned, &crawler_options, true).await;

for (base, variants) in multi {
// Any variant's resolved path works — they all map to the single
// installed distribution.
let pkg_path = variants.iter().find_map(|s| paths.get(&s.purl)).cloned();
let Some(pkg_path) = pkg_path else {
// Not installed: cannot determine the relevant release. Keep
// every variant so the patch is still obtainable.
warnings.push(format!(
"{base} is not installed locally; keeping all {} release variant(s).",
variants.len()
));
kept.extend(variants);
continue;
};

// Fetch each variant's file hashes (the view carries them) so we
// can hash-match against the installed distribution.
let mut candidates: Vec<(String, HashMap<String, PatchFileInfo>)> = Vec::new();
for s in &variants {
match api_client.fetch_patch(org, &s.uuid).await {
Ok(Some(patch)) => {
candidates.push((s.purl.clone(), files_for_selection(&patch)));
}
// On a fetch error/miss, keep the variant so the main
// download loop can record the failure as it would today.
_ => candidates.push((s.purl.clone(), HashMap::new())),
}
}

let refs: Vec<(&str, &HashMap<String, PatchFileInfo>)> = candidates
.iter()
.map(|(purl, files)| (purl.as_str(), files))
.collect();

match select_installed_variant(&pkg_path, &refs).await {
Some(idx) => {
let winner = candidates[idx].0.clone();
kept.extend(variants.into_iter().filter(|s| s.purl == winner));
}
None => {
// Installed, but no variant matches the on-disk bytes.
// Fall back to broad rather than silently dropping a
// package the user asked about.
warnings.push(format!(
"No release variant of {base} matches the installed distribution; keeping all {} variant(s).",
variants.len()
));
kept.extend(variants);
}
}
}

(kept, warnings)
}

/// Collect the per-file hash pairs of `patch` that are usable for
/// installed-distribution matching.
///
/// Mirrors the download flow's requirement that a patchable file carry
/// both hashes: an entry is emitted only when `before_hash` and
/// `after_hash` are both present. A present-but-empty `beforeHash` (a
/// newly added file) still counts as present, so such files are kept and
/// first-file verification can treat them as Ready.
///
/// Returns a map from file path to its [`PatchFileInfo`]; files missing
/// either hash are omitted.
fn files_for_selection(patch: &PatchResponse) -> HashMap<String, PatchFileInfo> {
    patch
        .files
        .iter()
        .filter_map(|(path, info)| {
            // `zip` yields `Some` only when both hashes exist.
            info.before_hash
                .as_ref()
                .zip(info.after_hash.as_ref())
                .map(|(before, after)| {
                    (
                        path.clone(),
                        PatchFileInfo {
                            before_hash: before.clone(),
                            after_hash: after.clone(),
                        },
                    )
                })
        })
        .collect()
}

pub async fn download_and_apply_patches(
selected: &[PatchSearchResult],
params: &DownloadParams,
Expand Down Expand Up @@ -545,6 +718,26 @@ pub async fn download_and_apply_patches(
_ => PatchManifest::new(),
};

// Narrow PyPI multi-release selections to the installed distribution
// unless --all-releases was passed. `filter_to_installed_releases`
// is a no-op for non-PyPI ecosystems and single-variant packages.
let mut narrow_warnings: Vec<String> = Vec::new();
let selected_owned: Vec<PatchSearchResult>;
let selected: &[PatchSearchResult] = if params.all_releases {
selected
} else {
let (kept, warns) =
filter_to_installed_releases(selected, params, &api_client, effective_org).await;
if !params.json && !params.silent {
for w in &warns {
eprintln!(" [note] {w}");
}
}
narrow_warnings = warns;
selected_owned = kept;
&selected_owned
};

if !params.json && !params.silent {
eprintln!("\nDownloading {} patch(es)...", selected.len());
}
Expand Down Expand Up @@ -735,7 +928,7 @@ pub async fn download_and_apply_patches(
}
}

let result_json = serde_json::json!({
let mut result_json = serde_json::json!({
"status": if patches_failed > 0 { "partial_failure" } else { "success" },
"found": selected.len(),
"downloaded": patches_added,
Expand All @@ -745,6 +938,12 @@ pub async fn download_and_apply_patches(
"updated": updates.len(),
"patches": downloaded_patches,
});
// Surface release-narrowing fallbacks (uninstalled package / no
// matching variant) so JSON consumers can see why all variants were
// kept. Omitted entirely when narrowing was clean.
if !narrow_warnings.is_empty() {
result_json["warnings"] = serde_json::json!(narrow_warnings);
}

let exit_code = if patches_failed > 0 || (!apply_succeeded && patches_added > 0 && !params.save_only) { 1 } else { 0 };
(exit_code, result_json)
Expand Down Expand Up @@ -1127,6 +1326,7 @@ pub async fn run(args: GetArgs) -> i32 {
silent: false,
download_mode: args.common.download_mode.clone(),
api_overrides: args.common.api_client_overrides(),
all_releases: args.all_releases,
};

let (code, result_json) = download_and_apply_patches(&selected, &params).await;
Expand Down
Loading
Loading