From 19c080a44ac9ab5ce7bb626a577c48979b4ec94e Mon Sep 17 00:00:00 2001 From: Konstantinos Stefanidis Vozikis Date: Mon, 29 Jun 2026 15:48:36 +0200 Subject: [PATCH] amend CLI UX with all the new storage/tables functionality from Tower Storage --- crates/tower-cmd/src/api.rs | 131 ++++- crates/tower-cmd/src/catalogs.rs | 487 +++++++++++++++++- crates/tower-cmd/src/lib.rs | 3 + crates/tower-cmd/src/mcp.rs | 2 +- .../tower-cmd/src/templates/dbeaver.txt.tmpl | 5 + crates/tower-cmd/src/templates/dbt.yml.tmpl | 5 + .../tower-cmd/src/templates/duckdb.sql.tmpl | 9 + .../tower-cmd/src/templates/pyiceberg.py.tmpl | 10 + crates/tower-cmd/src/templates/spark.py.tmpl | 12 + src/tower/_context.py | 59 ++- src/tower/_storage.py | 214 ++++++++ src/tower/_tables.py | 168 +++++- tests/tower/test_storage.py | 199 +++++++ tests/tower/test_tables.py | 212 ++++++++ 14 files changed, 1484 insertions(+), 32 deletions(-) create mode 100644 crates/tower-cmd/src/templates/dbeaver.txt.tmpl create mode 100644 crates/tower-cmd/src/templates/dbt.yml.tmpl create mode 100644 crates/tower-cmd/src/templates/duckdb.sql.tmpl create mode 100644 crates/tower-cmd/src/templates/pyiceberg.py.tmpl create mode 100644 crates/tower-cmd/src/templates/spark.py.tmpl create mode 100644 src/tower/_storage.py create mode 100644 tests/tower/test_storage.py diff --git a/crates/tower-cmd/src/api.rs b/crates/tower-cmd/src/api.rs index bdc16909..ae6a3507 100644 --- a/crates/tower-cmd/src/api.rs +++ b/crates/tower-cmd/src/api.rs @@ -350,21 +350,24 @@ pub async fn list_catalogs( config: &Config, env: &str, all: bool, + catalog_type: Option<&str>, ) -> Result, Error> { let api_config: configuration::Configuration = config.into(); let env = env.to_string(); + let catalog_type = catalog_type.map(str::to_string); fetch_all_pages(config, |page, page_size| { let api_config = &api_config; let env = &env; + let catalog_type = &catalog_type; async move { let params = tower_api::apis::default_api::ListCatalogsParams { environment: Some(env.to_string()), all: Some(all), page: Some(page), page_size: Some(page_size), - r#type: None, + r#type: catalog_type.clone(), }; unwrap_api_response(tower_api::apis::default_api::list_catalogs( api_config, params, @@ -375,6 +378,33 @@ pub async fn list_catalogs( .await } +pub async fn vend_catalog_credentials( + config: &Config, + name: &str, + env: &str, + mode: tower_api::models::vend_catalog_credentials_body::Mode, +) -> Result< + tower_api::models::VendCatalogCredentialsResponse, + Error, +> { + let api_config = config.into(); + + let params = tower_api::apis::default_api::VendCatalogCredentialsParams { + name: name.to_string(), + environment: Some(env.to_string()), + vend_catalog_credentials_body: tower_api::models::VendCatalogCredentialsBody { + schema: None, + mode: Some(mode), + }, + }; + + unwrap_api_response_redacted(tower_api::apis::default_api::vend_catalog_credentials( + &api_config, + params, + )) + .await +} + pub async fn describe_catalog( config: &Config, name: &str, @@ -719,6 +749,27 @@ pub async fn stream_run_logs( /// Helper function to handle Tower API responses and extract the relevant data async fn unwrap_api_response(api_call: F) -> Result> +where + F: std::future::Future, Error>>, + T: ResponseEntity, + T::Data: serde::de::DeserializeOwned, +{ + unwrap_api_response_inner(api_call, false).await +} + +async fn unwrap_api_response_redacted(api_call: F) -> Result> +where + F: std::future::Future, Error>>, + T: ResponseEntity, + T::Data: serde::de::DeserializeOwned, +{ + unwrap_api_response_inner(api_call, true).await +} + +async fn unwrap_api_response_inner( + api_call: F, + redact_success_content: bool, +) -> Result> where F: std::future::Future, Error>>, T: ResponseEntity, @@ -736,22 +787,32 @@ where } debug!("tower trace ID: {}", response.tower_trace_id); - debug!("Response from server: {}", response.content); + if redact_success_content { + debug!("Response from server: "); + } else { + debug!("Response from server: {}", response.content); + } if let Some(entity) = response.entity { if let Some(data) = entity.extract_data() { Ok(data) } else { - let truncated = if response.content.len() > 500 { + let truncated = if redact_success_content { + "".to_string() + } else if response.content.len() > 500 { format!("{}...(truncated)", &response.content[..500]) } else { response.content.clone() }; // Try explicit deserialization to get the actual error message - let deser_error = serde_json::from_str::(&response.content) - .err() - .map(|e| format!(" Deserialization error: {}", e)) - .unwrap_or_default(); + let deser_error = if redact_success_content { + String::new() + } else { + serde_json::from_str::(&response.content) + .err() + .map(|e| format!(" Deserialization error: {}", e)) + .unwrap_or_default() + }; debug!( "Failed to extract data from API response:{} Content: {}", deser_error, truncated @@ -839,6 +900,17 @@ impl ResponseEntity for tower_api::apis::default_api::DescribeCatalogSuccess { } } +impl ResponseEntity for tower_api::apis::default_api::VendCatalogCredentialsSuccess { + type Data = tower_api::models::VendCatalogCredentialsResponse; + + fn extract_data(self) -> Option { + match self { + Self::Status200(data) => Some(data), + Self::UnknownValue(_) => None, + } + } +} + impl ResponseEntity for tower_api::apis::default_api::CreateSecretSuccess { type Data = tower_api::models::CreateSecretResponse; @@ -1341,3 +1413,48 @@ impl ResponseEntity for tower_api::apis::default_api::CancelRunSuccess { } } } + +#[cfg(test)] +mod tests { + use super::{unwrap_api_response_redacted, ResponseEntity}; + use http::StatusCode; + use tower_api::apis::{Error, ResponseContent}; + + enum SensitiveSuccess { + UnknownValue, + } + + #[derive(serde::Deserialize)] + struct SensitiveData { + _value: String, + } + + impl ResponseEntity for SensitiveSuccess { + type Data = SensitiveData; + + fn extract_data(self) -> Option { + None + } + } + + #[tokio::test] + async fn redacted_unwrap_does_not_return_sensitive_response_body() { + let response = ResponseContent { + tower_trace_id: "trace-id".to_string(), + status: StatusCode::OK, + content: r#"{"credentials":{"oauth_token":"secret-token"}}"#.to_string(), + entity: Some(SensitiveSuccess::UnknownValue), + }; + + let result = + unwrap_api_response_redacted::(async { Ok(response) }).await; + + match result { + Err(Error::ResponseError(response)) => { + assert!(!response.content.contains("secret-token")); + assert!(response.content.contains("")); + } + _ => panic!("expected redacted response error"), + } + } +} diff --git a/crates/tower-cmd/src/catalogs.rs b/crates/tower-cmd/src/catalogs.rs index 06fb87f2..e7223c4c 100644 --- a/crates/tower-cmd/src/catalogs.rs +++ b/crates/tower-cmd/src/catalogs.rs @@ -1,9 +1,12 @@ -use clap::{value_parser, Arg, ArgMatches, Command}; +use clap::{value_parser, Arg, ArgAction, ArgMatches, Command}; use colored::Colorize; use config::Config; +use tower_api::models::{vend_catalog_credentials_body, CatalogCredentials}; use crate::{api, output, util::cmd}; +const STORAGE_CATALOG_TYPE: &str = "tower-catalog"; + pub fn catalogs_cmd() -> Command { Command::new("catalogs") .about("Interact with the catalogs in your Tower account") @@ -26,6 +29,20 @@ pub fn catalogs_cmd() -> Command { .help("List catalogs across all environments") .action(clap::ArgAction::SetTrue), ) + .arg( + Arg::new("type") + .long("type") + .value_parser(value_parser!(String)) + .help("Filter catalogs by type, e.g. tower-catalog") + .action(ArgAction::Set), + ) + .arg( + Arg::new("storage") + .long("storage") + .help("List Tower-managed storage catalogs") + .conflicts_with("type") + .action(ArgAction::SetTrue), + ) .about("List all of your catalogs"), ) .subcommand( @@ -48,14 +65,64 @@ pub fn catalogs_cmd() -> Command { ) .about("Show the details of a catalog, including its property names"), ) + .subcommand( + Command::new("credentials") + .arg( + Arg::new("catalog_name") + .value_parser(value_parser!(String)) + .index(1) + .required(true) + .help("Name of the catalog"), + ) + .arg( + Arg::new("environment") + .short('e') + .long("environment") + .default_value("default") + .value_parser(value_parser!(String)) + .help("Environment the catalog belongs to") + .action(ArgAction::Set), + ) + .arg( + Arg::new("mode") + .long("mode") + .default_value("read") + .value_parser(["read", "read-write"]) + .help("Credential access mode") + .action(ArgAction::Set), + ) + .arg( + Arg::new("format") + .long("format") + .default_value("all") + .value_parser(["all", "pyiceberg", "spark", "duckdb", "dbt", "dbeaver"]) + .help("Snippet format to print") + .action(ArgAction::Set), + ) + .arg( + Arg::new("show_token") + .long("show-token") + .help("Print the vended OAuth token in normal output") + .action(ArgAction::SetTrue), + ) + .about("Vend short-lived catalog credentials for external tools"), + ) } pub async fn do_list(config: Config, args: &ArgMatches) { let all = cmd::get_bool_flag(args, "all"); let env = cmd::get_string_flag(args, "environment"); + let catalog_type = if cmd::get_bool_flag(args, "storage") { + Some(STORAGE_CATALOG_TYPE) + } else { + args.get_one::("type").map(String::as_str) + }; - let catalogs = - output::with_spinner("Listing catalogs", api::list_catalogs(&config, &env, all)).await; + let catalogs = output::with_spinner( + "Listing catalogs", + api::list_catalogs(&config, &env, all, catalog_type), + ) + .await; let headers = vec!["Name", "Type", "Environment"] .into_iter() @@ -74,6 +141,43 @@ pub async fn do_list(config: Config, args: &ArgMatches) { output::table(headers, data, Some(&catalogs)); } +pub async fn do_credentials(config: Config, args: &ArgMatches) { + let name = args + .get_one::("catalog_name") + .expect("catalog_name is required"); + let env = cmd::get_string_flag(args, "environment"); + let mode = args + .get_one::("mode") + .map(String::as_str) + .unwrap_or("read"); + let format = args + .get_one::("format") + .map(String::as_str) + .unwrap_or("all"); + let show_token = cmd::get_bool_flag(args, "show_token"); + + let response = output::with_spinner( + "Vending catalog credentials", + api::vend_catalog_credentials(&config, name, &env, parse_mode(mode)), + ) + .await; + + if output::get_output_mode().is_json() { + output::json(&response); + return; + } + + print_credentials( + name, + &env, + mode, + config.tower_url.as_str(), + &response.credentials, + format, + show_token, + ); +} + pub async fn do_show(config: Config, args: &ArgMatches) { let name = args .get_one::("catalog_name") @@ -119,9 +223,207 @@ pub async fn do_show(config: Config, args: &ArgMatches) { } } +fn parse_mode(mode: &str) -> vend_catalog_credentials_body::Mode { + match mode { + "read-write" => vend_catalog_credentials_body::Mode::ReadWrite, + _ => vend_catalog_credentials_body::Mode::Read, + } +} + +fn shell_quote(value: &str) -> String { + format!("'{}'", value.replace('\'', "'\"'\"'")) +} + +fn quote(value: &str) -> String { + serde_json::to_string(value).expect("serializing a string should not fail") +} + +fn sql_string(value: &str) -> String { + format!("'{}'", value.replace('\'', "''")) +} + +fn sql_ident(value: &str) -> String { + format!("\"{}\"", value.replace('"', "\"\"")) +} + +fn token_export_command(name: &str, environment: &str, mode: &str, tower_url: &str) -> String { + format!( + "export TOWER_CATALOG_TOKEN=\"$(tower --tower-url {tower_url} --json catalogs credentials {name} --environment {environment} --mode {mode} | python3 -c 'import json,sys; print(json.load(sys.stdin)[\"credentials\"][\"oauth_token\"])')\"\n", + tower_url = shell_quote(tower_url), + name = shell_quote(name), + environment = shell_quote(environment), + mode = shell_quote(mode), + ) +} + +fn print_credentials( + name: &str, + environment: &str, + mode: &str, + tower_url: &str, + credentials: &CatalogCredentials, + format: &str, + show_token: bool, +) { + output::detail("Catalog", name); + output::detail("Mode", &credentials.mode); + output::detail("Expires", &credentials.expires_at); + if show_token { + output::detail("Token", &credentials.oauth_token); + } else { + output::detail("Token", "not printed; snippets read $TOWER_CATALOG_TOKEN"); + } + output::write(&output::paragraph( + "These credentials are short-lived and intended for ad-hoc development use.", + )); + output::newline(); + + if !show_token { + output::newline(); + output::header("Shell setup"); + output::write(token_export_command(name, environment, mode, tower_url).as_str()); + } + + for snippet in snippets(name, credentials, format, show_token) { + output::newline(); + output::header(snippet.title); + output::write(snippet.body.as_str()); + output::newline(); + } +} + +const PYICEBERG_TMPL: &str = include_str!("templates/pyiceberg.py.tmpl"); +const SPARK_TMPL: &str = include_str!("templates/spark.py.tmpl"); +const DUCKDB_TMPL: &str = include_str!("templates/duckdb.sql.tmpl"); +const DBT_TMPL: &str = include_str!("templates/dbt.yml.tmpl"); +const DBEAVER_TMPL: &str = include_str!("templates/dbeaver.txt.tmpl"); + +/// Substitute `__TOWER_*__` markers in a connection-snippet template. Values must +/// already be escaped for the target format — the templates under `src/templates/` +/// are inert text and the per-format escaping stays in `snippets`. +fn render(template: &str, vars: &[(&str, String)]) -> String { + let mut out = template.to_string(); + for (marker, value) in vars { + out = out.replace(marker, value); + } + out +} + +struct Snippet { + title: &'static str, + body: String, +} + +fn snippets( + name: &str, + credentials: &CatalogCredentials, + format: &str, + show_token: bool, +) -> Vec { + let all = format == "all"; + let mut snippets = Vec::new(); + + let py_token = if show_token { + quote(&credentials.oauth_token) + } else { + "os.environ[\"TOWER_CATALOG_TOKEN\"]".to_string() + }; + let sql_token = if show_token { + sql_string(&credentials.oauth_token) + } else { + "'${TOWER_CATALOG_TOKEN}'".to_string() + }; + let dbt_token = if show_token { + quote(&credentials.oauth_token) + } else { + "\"{{ env_var('TOWER_CATALOG_TOKEN') }}\"".to_string() + }; + let dbeaver_token = if show_token { + credentials.oauth_token.clone() + } else { + "${TOWER_CATALOG_TOKEN}".to_string() + }; + + if all || format == "pyiceberg" { + snippets.push(Snippet { + title: "PyIceberg", + body: render( + PYICEBERG_TMPL, + &[ + ("__TOWER_NAME__", quote(name)), + ("__TOWER_URI__", quote(&credentials.catalog_uri)), + ("__TOWER_WAREHOUSE__", quote(&credentials.warehouse)), + ("__TOWER_TOKEN__", py_token.clone()), + ], + ), + }); + } + + if all || format == "spark" { + snippets.push(Snippet { + title: "Spark", + body: render( + SPARK_TMPL, + &[ + ("__TOWER_NAME__", name.to_string()), + ("__TOWER_URI__", quote(&credentials.catalog_uri)), + ("__TOWER_WAREHOUSE__", quote(&credentials.warehouse)), + ("__TOWER_TOKEN__", py_token.clone()), + ], + ), + }); + } + + if all || format == "duckdb" { + snippets.push(Snippet { + title: "DuckDB", + body: render( + DUCKDB_TMPL, + &[ + ("__TOWER_NAME__", sql_ident(name)), + ("__TOWER_URI__", sql_string(&credentials.catalog_uri)), + ("__TOWER_WAREHOUSE__", sql_string(&credentials.warehouse)), + ("__TOWER_TOKEN__", sql_token.clone()), + ], + ), + }); + } + + if all || format == "dbt" { + snippets.push(Snippet { + title: "dbt", + body: render( + DBT_TMPL, + &[ + ("__TOWER_URI__", quote(&credentials.catalog_uri)), + ("__TOWER_WAREHOUSE__", quote(&credentials.warehouse)), + ("__TOWER_TOKEN__", dbt_token.clone()), + ], + ), + }); + } + + if all || format == "dbeaver" { + snippets.push(Snippet { + title: "DBeaver", + body: render( + DBEAVER_TMPL, + &[ + ("__TOWER_URI__", credentials.catalog_uri.clone()), + ("__TOWER_WAREHOUSE__", credentials.warehouse.clone()), + ("__TOWER_TOKEN__", dbeaver_token.clone()), + ], + ), + }); + } + + snippets +} + #[cfg(test)] mod tests { - use super::catalogs_cmd; + use super::{catalogs_cmd, parse_mode, snippets, token_export_command}; + use tower_api::models::{vend_catalog_credentials_body, CatalogCredentials}; #[test] fn list_defaults_to_default_environment() { @@ -163,6 +465,38 @@ mod tests { assert_eq!(list_args.get_one::("all").copied(), Some(true)); } + #[test] + fn list_accepts_type_filter() { + let matches = catalogs_cmd() + .try_get_matches_from(["catalogs", "list", "--type", "tower-catalog"]) + .expect("list --type should parse"); + + let (_, list_args) = matches.subcommand().expect("expected list subcommand"); + + assert_eq!( + list_args.get_one::("type").unwrap(), + "tower-catalog" + ); + } + + #[test] + fn list_accepts_storage_alias() { + let matches = catalogs_cmd() + .try_get_matches_from(["catalogs", "list", "--storage"]) + .expect("list --storage should parse"); + + let (_, list_args) = matches.subcommand().expect("expected list subcommand"); + + assert_eq!(list_args.get_one::("storage").copied(), Some(true)); + } + + #[test] + fn list_rejects_type_and_storage_together() { + let result = + catalogs_cmd().try_get_matches_from(["catalogs", "list", "--storage", "--type", "s3"]); + assert!(result.is_err()); + } + #[test] fn show_requires_catalog_name() { let result = catalogs_cmd().try_get_matches_from(["catalogs", "show"]); @@ -204,4 +538,149 @@ mod tests { "production" ); } + + #[test] + fn credentials_accepts_catalog_name() { + let matches = catalogs_cmd() + .try_get_matches_from(["catalogs", "credentials", "default"]) + .expect("credentials with name should parse"); + + let (_, credentials_args) = matches + .subcommand() + .expect("expected credentials subcommand"); + + assert_eq!( + credentials_args.get_one::("catalog_name").unwrap(), + "default" + ); + assert_eq!(credentials_args.get_one::("mode").unwrap(), "read"); + assert_eq!(credentials_args.get_one::("format").unwrap(), "all"); + } + + #[test] + fn credentials_accepts_read_write_mode() { + let matches = catalogs_cmd() + .try_get_matches_from(["catalogs", "credentials", "default", "--mode", "read-write"]) + .expect("credentials --mode read-write should parse"); + + let (_, credentials_args) = matches + .subcommand() + .expect("expected credentials subcommand"); + + assert_eq!( + credentials_args.get_one::("mode").unwrap(), + "read-write" + ); + assert_eq!( + parse_mode(credentials_args.get_one::("mode").unwrap()), + vend_catalog_credentials_body::Mode::ReadWrite + ); + } + + #[test] + fn token_export_command_fetches_token_without_printing_it() { + let credentials = CatalogCredentials::new( + "https://catalog.example.com".to_string(), + "2026-06-26T12:00:00Z".to_string(), + "read".to_string(), + "secret-token".to_string(), + "warehouse-id".to_string(), + ); + + let command = + token_export_command("default", "production", "read", "http://localhost:8000/"); + + assert!(command.contains("export TOWER_CATALOG_TOKEN=")); + assert!(command.contains("tower --tower-url 'http://localhost:8000/' --json")); + assert!(command.contains("catalogs credentials 'default'")); + assert!(command.contains("--environment 'production'")); + assert!(!command.contains(&credentials.oauth_token)); + } + + #[test] + fn pyiceberg_snippet_reads_token_from_environment_by_default() { + let credentials = CatalogCredentials::new( + "https://catalog.example.com".to_string(), + "2026-06-26T12:00:00Z".to_string(), + "read".to_string(), + "secret-token".to_string(), + "warehouse-id".to_string(), + ); + + let snippets = snippets("default", &credentials, "pyiceberg", false); + + assert_eq!(snippets.len(), 1); + assert!(snippets[0].body.contains("load_catalog")); + assert!(snippets[0] + .body + .contains("os.environ[\"TOWER_CATALOG_TOKEN\"]")); + assert!(!snippets[0].body.contains("secret-token")); + } + + #[test] + fn duckdb_snippet_attaches_catalog_with_secret() { + let credentials = CatalogCredentials::new( + "https://catalog.example.com".to_string(), + "2026-06-26T12:00:00Z".to_string(), + "read".to_string(), + "secret-token".to_string(), + "warehouse-id".to_string(), + ); + + let snippets = snippets("default", &credentials, "duckdb", false); + + assert_eq!(snippets.len(), 1); + assert!(snippets[0] + .body + .contains("CREATE OR REPLACE SECRET tower_cat (TYPE iceberg")); + assert!(snippets[0] + .body + .contains("ATTACH 'warehouse-id' AS \"default\"")); + assert!(snippets[0] + .body + .contains("ENDPOINT 'https://catalog.example.com'")); + assert!(snippets[0].body.contains("TOKEN '${TOWER_CATALOG_TOKEN}'")); + assert!(!snippets[0].body.contains("secret-token")); + } + + #[test] + fn dbeaver_snippet_uses_dbeaver_variable_placeholder() { + let credentials = CatalogCredentials::new( + "https://catalog.example.com".to_string(), + "2026-06-26T12:00:00Z".to_string(), + "read".to_string(), + "secret-token".to_string(), + "warehouse-id".to_string(), + ); + + let snippets = snippets("default", &credentials, "dbeaver", false); + + assert_eq!(snippets.len(), 1); + assert!(snippets[0].body.contains("Token: ${TOWER_CATALOG_TOKEN}")); + assert!(!snippets[0].body.contains("secret-token")); + } + + #[test] + fn all_snippet_templates_fully_render() { + let credentials = CatalogCredentials::new( + "https://catalog.example.com".to_string(), + "2026-06-26T12:00:00Z".to_string(), + "read".to_string(), + "secret-token".to_string(), + "warehouse-id".to_string(), + ); + + for show_token in [false, true] { + let rendered = snippets("default", &credentials, "all", show_token); + assert_eq!(rendered.len(), 5); + for snippet in &rendered { + assert!( + !snippet.body.contains("__TOWER_"), + "unsubstituted marker in {} snippet (show_token={})", + snippet.title, + show_token + ); + } + } + } } diff --git a/crates/tower-cmd/src/lib.rs b/crates/tower-cmd/src/lib.rs index 649de82f..9eb84280 100644 --- a/crates/tower-cmd/src/lib.rs +++ b/crates/tower-cmd/src/lib.rs @@ -143,6 +143,9 @@ impl App { match catalogs_command { Some(("list", args)) => catalogs::do_list(sessionized_config, args).await, Some(("show", args)) => catalogs::do_show(sessionized_config, args).await, + Some(("credentials", args)) => { + catalogs::do_credentials(sessionized_config, args).await + } _ => { catalogs::catalogs_cmd().print_help().unwrap(); std::process::exit(2); diff --git a/crates/tower-cmd/src/mcp.rs b/crates/tower-cmd/src/mcp.rs index aeedb285..6d406c22 100644 --- a/crates/tower-cmd/src/mcp.rs +++ b/crates/tower-cmd/src/mcp.rs @@ -707,7 +707,7 @@ impl TowerService { Parameters(request): Parameters, ) -> Result { let environment = request.environment.as_deref().unwrap_or("default"); - match api::list_catalogs(&self.config, environment, false).await { + match api::list_catalogs(&self.config, environment, false, None).await { Ok(catalogs) => { let catalogs: Vec = catalogs .into_iter() diff --git a/crates/tower-cmd/src/templates/dbeaver.txt.tmpl b/crates/tower-cmd/src/templates/dbeaver.txt.tmpl new file mode 100644 index 00000000..ef42023a --- /dev/null +++ b/crates/tower-cmd/src/templates/dbeaver.txt.tmpl @@ -0,0 +1,5 @@ +Catalog type: Iceberg REST +URI: __TOWER_URI__ +Warehouse: __TOWER_WAREHOUSE__ +Authentication: Bearer token +Token: __TOWER_TOKEN__ diff --git a/crates/tower-cmd/src/templates/dbt.yml.tmpl b/crates/tower-cmd/src/templates/dbt.yml.tmpl new file mode 100644 index 00000000..7672fe99 --- /dev/null +++ b/crates/tower-cmd/src/templates/dbt.yml.tmpl @@ -0,0 +1,5 @@ +type: iceberg +catalog_type: rest +uri: __TOWER_URI__ +warehouse: __TOWER_WAREHOUSE__ +token: __TOWER_TOKEN__ diff --git a/crates/tower-cmd/src/templates/duckdb.sql.tmpl b/crates/tower-cmd/src/templates/duckdb.sql.tmpl new file mode 100644 index 00000000..fa792559 --- /dev/null +++ b/crates/tower-cmd/src/templates/duckdb.sql.tmpl @@ -0,0 +1,9 @@ +duckdb < Optional[str]: + current: Any = session + for key in keys: + if not isinstance(current, dict): + return None + current = current.get(key) + return current if isinstance(current, str) and current else None + + +def _read_session() -> dict[str, Any]: + session_path = Path.home() / ".config" / "tower" / "session.json" + try: + with session_path.open() as session_file: + session = json.load(session_file) + except (FileNotFoundError, OSError, json.JSONDecodeError, UnicodeDecodeError): + return {} + + return session if isinstance(session, dict) else {} + + +def _getenv_or_none(name: str) -> Optional[str]: + return os.getenv(name) or None + + +def _session_jwt(session: dict[str, Any]) -> Optional[str]: + return _get_session_value( + session, "active_team", "token", "jwt" + ) or _get_session_value(session, "token", "jwt") class TowerContext: @@ -32,11 +65,27 @@ def is_local(self) -> bool: @classmethod def build(cls): - tower_url = os.getenv("TOWER_URL", "https://api.tower.dev") - tower_environment = os.getenv("TOWER_ENVIRONMENT", "default") - tower_api_key = os.getenv("TOWER_API_KEY") - tower_jwt = os.getenv("TOWER_JWT") - tower_run_id = os.getenv("TOWER__RUNTIME__RUN_ID") + session = {} + tower_url = _getenv_or_none("TOWER_URL") + tower_environment = ( + os.getenv("TOWER__RUNTIME__ENVIRONMENT_NAME") + or os.getenv("TOWER_ENVIRONMENT") + or "default" + ) + tower_api_key = _getenv_or_none("TOWER_API_KEY") + tower_jwt = _getenv_or_none("TOWER_JWT") + tower_run_id = _getenv_or_none("TOWER__RUNTIME__RUN_ID") + + if tower_url is None or (tower_api_key is None and tower_jwt is None): + session = _read_session() + + if tower_url is None: + tower_url = ( + _get_session_value(session, "tower_url") or "https://api.tower.dev" + ) + + if tower_api_key is None and tower_jwt is None: + tower_jwt = _session_jwt(session) # Replaces the deprecated hugging_face_provider and hugging_face_api_key inference_router = os.getenv("TOWER_INFERENCE_ROUTER") diff --git a/src/tower/_storage.py b/src/tower/_storage.py new file mode 100644 index 00000000..ac7db195 --- /dev/null +++ b/src/tower/_storage.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import hashlib +import time +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from http import HTTPStatus +from typing import Any, Optional + +from ._client import _env_client +from ._context import TowerContext +from .tower_api_client.api.default import ( + describe_default_catalog as describe_default_catalog_api, +) +from .tower_api_client.api.default import ( + vend_catalog_credentials as vend_catalog_credentials_api, +) +from .tower_api_client.models import ( + CatalogCredentials, + ErrorModel, + VendCatalogCredentialsBody, + VendCatalogCredentialsBodyMode, + VendCatalogCredentialsResponse, +) +from .tower_api_client.types import UNSET, Unset + +CREDENTIAL_REFRESH_WINDOW = timedelta(minutes=5) +DEFAULT_CATALOG_PROVISION_RETRY_DELAYS = (0.25, 0.5, 1.0, 2.0) +DEFAULT_CATALOG_NAME = "default" +DEFAULT_ENVIRONMENT_NAME = "default" + + +@dataclass +class _CachedCredentials: + credentials: CatalogCredentials + + def is_usable(self, now: datetime) -> bool: + expires_at = _ensure_aware(self.credentials.expires_at) + return now < expires_at - CREDENTIAL_REFRESH_WINDOW + + +_credential_cache: dict[tuple[str, str, str, str, str], _CachedCredentials] = {} + + +def get_tower_catalog( + name: str = DEFAULT_CATALOG_NAME, + environment: Optional[str] = None, + mode: str = "read", +) -> Any: + """ + Load a PyIceberg REST catalog using short-lived credentials vended by Tower. + """ + credentials = get_tower_catalog_credentials(name, environment, mode) + return load_vended_catalog(name, credentials) + + +def get_tower_catalog_credentials( + name: str = DEFAULT_CATALOG_NAME, + environment: Optional[str] = None, + mode: str = "read", +) -> CatalogCredentials: + ctx = TowerContext.build() + environment = environment or ctx.environment or DEFAULT_ENVIRONMENT_NAME + mode = _normalize_mode(mode) + cache_key = _cache_key(ctx, name, environment, mode) + + now = datetime.now(timezone.utc) + _prune_credential_cache(now) + cached = _credential_cache.get(cache_key) + if cached is not None and cached.is_usable(now): + return cached.credentials + + credentials = _vend_with_default_catalog_fallback(ctx, name, environment, mode) + _credential_cache[cache_key] = _CachedCredentials(credentials) + return credentials + + +def load_vended_catalog(name: str, credentials: CatalogCredentials) -> Any: + from pyiceberg.catalog import load_catalog + + return load_catalog( + name, + type="rest", + uri=credentials.catalog_uri, + warehouse=credentials.warehouse, + token=credentials.oauth_token, + ) + + +def _vend_with_default_catalog_fallback( + ctx: TowerContext, name: str, environment: str, mode: str +) -> CatalogCredentials: + result = _vend_catalog_credentials(ctx, name, environment, mode) + if not _is_not_found(result): + return _unwrap_vend_result(result, name, environment) + + if name == DEFAULT_CATALOG_NAME and environment == DEFAULT_ENVIRONMENT_NAME: + _ensure_legacy_default_catalog(ctx) + for delay in DEFAULT_CATALOG_PROVISION_RETRY_DELAYS: + time.sleep(delay) + result = _vend_catalog_credentials(ctx, name, environment, mode) + if not _is_not_found(result): + return _unwrap_vend_result(result, name, environment) + _ensure_legacy_default_catalog(ctx) + + return _unwrap_vend_result(result, name, environment) + + raise RuntimeError( + f"Tower catalog {name!r} does not exist in environment {environment!r}." + ) + + +def _vend_catalog_credentials( + ctx: TowerContext, name: str, environment: str, mode: str +) -> ErrorModel | VendCatalogCredentialsResponse | None: + _ensure_tower_auth(ctx) + body = VendCatalogCredentialsBody(mode=_vend_mode(mode)) + return vend_catalog_credentials_api.sync( + name=name, + client=_env_client(ctx), + environment=environment, + body=body, + ) + + +def _ensure_legacy_default_catalog(ctx: TowerContext) -> None: + try: + response = describe_default_catalog_api.sync_detailed(client=_env_client(ctx)) + if response.status_code not in (HTTPStatus.OK, HTTPStatus.ACCEPTED): + return + except Exception: + # The following vend retry will surface the actionable backend/auth error. + return + + +def _unwrap_vend_result( + result: ErrorModel | VendCatalogCredentialsResponse | None, + name: str, + environment: str, +) -> CatalogCredentials: + if isinstance(result, VendCatalogCredentialsResponse): + return result.credentials + + if isinstance(result, ErrorModel): + detail = _error_text(result) + raise RuntimeError( + f"Failed to vend credentials for Tower catalog {name!r} " + f"in environment {environment!r}: {detail}" + ) + + raise RuntimeError( + f"Failed to vend credentials for Tower catalog {name!r} " + f"in environment {environment!r}." + ) + + +def _ensure_tower_auth(ctx: TowerContext) -> None: + if ctx.api_key or ctx.jwt: + return + + raise RuntimeError( + "No Tower authentication found. Set TOWER_API_KEY or run `tower login`." + ) + + +def _cache_key( + ctx: TowerContext, name: str, environment: str, mode: str +) -> tuple[str, str, str, str, str]: + token = ctx.api_key or ctx.jwt or "" + principal_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() + return (ctx.tower_url, principal_hash, name, environment, mode) + + +def _prune_credential_cache(now: datetime) -> None: + expired_keys = [ + key for key, cached in _credential_cache.items() if not cached.is_usable(now) + ] + for key in expired_keys: + _credential_cache.pop(key, None) + + +def _normalize_mode(mode: str) -> str: + if mode not in ("read", "read-write"): + raise ValueError("mode must be 'read' or 'read-write'") + return mode + + +def _vend_mode(mode: str) -> VendCatalogCredentialsBodyMode: + return ( + VendCatalogCredentialsBodyMode.READ_WRITE + if mode == "read-write" + else VendCatalogCredentialsBodyMode.READ + ) + + +def _is_not_found(result: ErrorModel | VendCatalogCredentialsResponse | None) -> bool: + return isinstance(result, ErrorModel) and result.status == 404 + + +def _error_text(error: ErrorModel) -> str: + for value in (error.detail, error.title): + if not isinstance(value, Unset) and value: + return str(value) + return f"HTTP {error.status}" if error.status is not UNSET else "unknown error" + + +def _ensure_aware(value: datetime) -> datetime: + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) + + +def _clear_credential_cache() -> None: + _credential_cache.clear() diff --git a/src/tower/_tables.py b/src/tower/_tables.py index de5bca46..9567c83c 100644 --- a/src/tower/_tables.py +++ b/src/tower/_tables.py @@ -1,4 +1,6 @@ -from typing import Optional, Generic, TypeVar, Union, List +from __future__ import annotations + +from typing import Optional, TypeVar, Union, List from dataclasses import dataclass from pyiceberg.exceptions import NoSuchTableError @@ -17,6 +19,8 @@ ) from ._context import TowerContext +from ._storage import get_tower_catalog_credentials, load_vended_catalog +from .tower_api_client.models import CatalogCredentials from .utils.pyarrow import ( convert_pyarrow_schema, convert_pyarrow_expressions, @@ -36,13 +40,42 @@ class RowsAffectedInformation: updates: int +_VendedCatalogIdentity = tuple[str, str, str] + + +def _vended_catalog_identity( + credentials: CatalogCredentials, +) -> _VendedCatalogIdentity: + return ( + credentials.catalog_uri, + credentials.warehouse, + credentials.oauth_token, + ) + + +def _load_tower_catalog( + name: str, + environment: Optional[str], + mode: str, +) -> tuple[Catalog, _VendedCatalogIdentity]: + credentials = get_tower_catalog_credentials(name, environment, mode) + return load_vended_catalog(name, credentials), _vended_catalog_identity(credentials) + + class Table: """ `Table` is a wrapper around an Iceberg table. It provides methods to read and write data to the table. """ - def __init__(self, context: TowerContext, table: IcebergTable): + def __init__( + self, + context: TowerContext, + table: IcebergTable, + table_reference: Optional[TableReference] = None, + table_identifier: Optional[str] = None, + catalog_mode: str = "read", + ): """ Initialize a new Table instance that wraps an Iceberg table. @@ -70,6 +103,22 @@ def __init__(self, context: TowerContext, table: IcebergTable): self._stats = RowsAffectedInformation(0, 0) self._context = context self._table = table + self._table_reference = table_reference + self._table_identifier = table_identifier + self._catalog_mode = catalog_mode + self._loaded_from = ( + table_reference._catalog if table_reference is not None else None + ) + + def _ensure_read_write_table(self) -> None: + if self._table_reference is None or self._table_identifier is None: + return + + catalog = self._table_reference._ensure_catalog_mode("read-write") + if catalog is not self._loaded_from: + self._table = catalog.load_table(self._table_identifier) + self._loaded_from = catalog + self._catalog_mode = "read-write" def read(self) -> pl.DataFrame: """ @@ -202,6 +251,7 @@ def insert( >>> print(f"Inserted {stats.inserts} rows") """ self._validate_retry_args(max_retries, retry_delay_seconds) + self._ensure_read_write_table() last_exception = None @@ -275,6 +325,7 @@ def upsert( >>> print(f"Inserted {stats.inserts} rows") """ self._validate_retry_args(max_retries, retry_delay_seconds) + self._ensure_read_write_table() last_exception = None @@ -354,6 +405,7 @@ def delete( >>> table.delete("age > 30 AND department = 'IT'") """ self._validate_retry_args(max_retries, retry_delay_seconds) + self._ensure_read_write_table() if isinstance(filters, list): # We need to convert the pc.Expression into PyIceberg @@ -441,11 +493,42 @@ def __init__( catalog: Catalog, name: str, namespace: Optional[str] = None, + catalog_name: Optional[str] = None, + catalog_environment: Optional[str] = None, + tower_vended: bool = False, + catalog_mode: str = "read", + vended_catalog_identity: Optional[_VendedCatalogIdentity] = None, ): self._context = ctx self._catalog = catalog self._name = name self._namespace = namespace + self._catalog_name = catalog_name + self._catalog_environment = catalog_environment + self._tower_vended = tower_vended + self._catalog_mode = catalog_mode + self._vended_catalog_identity = vended_catalog_identity + + def _ensure_catalog_mode(self, mode: str) -> Catalog: + if not self._tower_vended or self._catalog_name is None: + return self._catalog + + credentials = get_tower_catalog_credentials( + self._catalog_name, + environment=self._catalog_environment, + mode=mode, + ) + identity = _vended_catalog_identity(credentials) + + if self._catalog_mode != mode or self._vended_catalog_identity != identity: + self._catalog = load_vended_catalog( + self._catalog_name, + credentials, + ) + self._catalog_mode = mode + self._vended_catalog_identity = identity + + return self._catalog def load(self) -> Table: """ @@ -471,7 +554,13 @@ def load(self) -> Table: namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) table = self._catalog.load_table(table_name) - return Table(self._context, table) + return Table( + self._context, + table, + table_reference=self if self._tower_vended else None, + table_identifier=table_name, + catalog_mode=self._catalog_mode, + ) def create(self, schema: pa.Schema) -> Table: """ @@ -509,20 +598,27 @@ def create(self, schema: pa.Schema) -> Table: namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) + catalog = self._ensure_catalog_mode("read-write") # We need to create the relevant namespace if it's missing from the # resolved namespace. - self._catalog.create_namespace_if_not_exists(namespace) + catalog.create_namespace_if_not_exists(namespace) # Now that we're certain the namespace exists, we can create the # underlying table. This will return an error if something went wrong # along the way. - table = self._catalog.create_table( + table = catalog.create_table( identifier=table_name, schema=convert_pyarrow_schema(schema), ) - return Table(self._context, table) + return Table( + self._context, + table, + table_reference=self if self._tower_vended else None, + table_identifier=table_name, + catalog_mode=self._catalog_mode, + ) def create_if_not_exists(self, schema: pa.Schema) -> Table: """ @@ -564,20 +660,27 @@ def create_if_not_exists(self, schema: pa.Schema) -> Table: namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) + catalog = self._ensure_catalog_mode("read-write") # We need to create the relevant namespace if it's missing from the # resolved namespace. - self._catalog.create_namespace_if_not_exists(namespace) + catalog.create_namespace_if_not_exists(namespace) # We have the catalog, so let's attempt to create the table. It should # not return an error and instead just return the table if it already # exists. - table = self._catalog.create_table_if_not_exists( + table = catalog.create_table_if_not_exists( identifier=table_name, schema=convert_pyarrow_schema(schema), ) - return Table(self._context, table) + return Table( + self._context, + table, + table_reference=self if self._tower_vended else None, + table_identifier=table_name, + catalog_mode=self._catalog_mode, + ) def drop(self) -> bool: """ @@ -605,9 +708,10 @@ def drop(self) -> bool: """ namespace = namespace_or_default(self._namespace) table_name = make_table_name(self._name, namespace) + catalog = self._ensure_catalog_mode("read-write") try: - self._catalog.drop_table(table_name) + catalog.drop_table(table_name) return True except NoSuchTableError: # If the table doesn't exist or there's any other issue, return False @@ -617,7 +721,10 @@ def drop(self) -> bool: def tables( - name: str, catalog: Union[str, Catalog] = "default", namespace: Optional[str] = None + name: str, + catalog: Union[str, Catalog] = "default", + namespace: Optional[str] = None, + tower_credentials: Optional[bool] = None, ) -> TableReference: """ Creates a reference to an Iceberg table that can be used to load or create tables. @@ -636,6 +743,11 @@ def tables( Defaults to "default". namespace (Optional[str], optional): The namespace in which the table exists or should be created. If not provided, a default namespace will be used. + tower_credentials (Optional[bool], optional): Credential resolution for string + catalogs. By default (None) and when True, credentials are vended from + Tower. Set False to fall back to existing PyIceberg configuration (the + legacy ``PYICEBERG_CATALOG__*`` env vars) — a temporary rollback hatch. + Ignored when a Catalog instance is passed. Returns: TableReference: A reference object that can be used to: @@ -671,8 +783,34 @@ def tables( >>> if success: ... print("Table dropped successfully") """ - if isinstance(catalog, str): - catalog = load_catalog(catalog) - ctx = TowerContext.build() - return TableReference(ctx, catalog, name, namespace) + tower_vended = False + catalog_name = catalog if isinstance(catalog, str) else None + vended_catalog_identity = None + + if isinstance(catalog, str): + # Tower-managed catalogs always resolve through credential vending. + # `tower_credentials=False` is a rollback hatch to the legacy + # PYICEBERG_CATALOG__* env-var config (still injected by the runner); + # remove it once that injection is gone (TOW-2316). + if tower_credentials is False: + catalog = load_catalog(catalog) + else: + catalog, vended_catalog_identity = _load_tower_catalog( + catalog, + environment=ctx.environment, + mode="read", + ) + tower_vended = True + + return TableReference( + ctx, + catalog, + name, + namespace, + catalog_name=catalog_name, + catalog_environment=ctx.environment, + tower_vended=tower_vended, + catalog_mode="read", + vended_catalog_identity=vended_catalog_identity, + ) diff --git a/tests/tower/test_storage.py b/tests/tower/test_storage.py new file mode 100644 index 00000000..2a6951c8 --- /dev/null +++ b/tests/tower/test_storage.py @@ -0,0 +1,199 @@ +import json +from datetime import datetime, timedelta, timezone + +from tower._context import TowerContext +from tower import _storage +from tower.tower_api_client.models import ( + CatalogCredentials, + ErrorModel, + VendCatalogCredentialsResponse, +) + + +def clear_tower_env(monkeypatch): + for name in ( + "TOWER_URL", + "TOWER_ENVIRONMENT", + "TOWER_API_KEY", + "TOWER_JWT", + "TOWER__RUNTIME__RUN_ID", + "TOWER__RUNTIME__ENVIRONMENT_NAME", + ): + monkeypatch.delenv(name, raising=False) + + +def test_context_prefers_runtime_environment(monkeypatch, tmp_path): + clear_tower_env(monkeypatch) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("TOWER_ENVIRONMENT", "local-env") + monkeypatch.setenv("TOWER__RUNTIME__ENVIRONMENT_NAME", "run-env") + + ctx = TowerContext.build() + + assert ctx.environment == "run-env" + + +def test_context_reads_session_when_env_auth_is_missing(monkeypatch, tmp_path): + clear_tower_env(monkeypatch) + monkeypatch.setenv("HOME", str(tmp_path)) + session_path = tmp_path / ".config" / "tower" / "session.json" + session_path.parent.mkdir(parents=True) + session_path.write_text( + json.dumps( + { + "tower_url": "https://api.session.example", + "token": {"jwt": "user-jwt"}, + "active_team": {"token": {"jwt": "team-jwt"}}, + } + ) + ) + + ctx = TowerContext.build() + + assert ctx.tower_url == "https://api.session.example" + assert ctx.jwt == "team-jwt" + + +def test_context_treats_blank_auth_env_as_missing(monkeypatch, tmp_path): + clear_tower_env(monkeypatch) + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("TOWER_API_KEY", "") + monkeypatch.setenv("TOWER_JWT", "") + session_path = tmp_path / ".config" / "tower" / "session.json" + session_path.parent.mkdir(parents=True) + session_path.write_text( + json.dumps( + { + "tower_url": "https://api.session.example", + "token": {"jwt": "user-jwt"}, + "active_team": {"token": {"jwt": "team-jwt"}}, + } + ) + ) + + ctx = TowerContext.build() + + assert ctx.api_key is None + assert ctx.jwt == "team-jwt" + + +def test_context_ignores_corrupted_session_file(monkeypatch, tmp_path): + clear_tower_env(monkeypatch) + monkeypatch.setenv("HOME", str(tmp_path)) + session_path = tmp_path / ".config" / "tower" / "session.json" + session_path.parent.mkdir(parents=True) + session_path.write_bytes(b"\xff") + + ctx = TowerContext.build() + + assert ctx.tower_url == "https://api.tower.dev" + assert ctx.jwt is None + + +def test_get_tower_catalog_credentials_caches_vended_credentials(monkeypatch): + _storage._clear_credential_cache() + ctx = TowerContext( + tower_url="https://api.example.com", + environment="production", + api_key="api-key", + ) + expires_at = datetime.now(timezone.utc) + timedelta(hours=1) + credentials = CatalogCredentials( + catalog_uri="https://catalog.example.com", + expires_at=expires_at, + mode="read", + oauth_token="oauth-token", + warehouse="warehouse-id", + ) + calls = [] + + def vend(ctx, name, environment, mode): + calls.append((name, environment, mode)) + return VendCatalogCredentialsResponse(credentials=credentials) + + monkeypatch.setattr(_storage.TowerContext, "build", staticmethod(lambda: ctx)) + monkeypatch.setattr(_storage, "_vend_catalog_credentials", vend) + + first = _storage.get_tower_catalog_credentials("default") + second = _storage.get_tower_catalog_credentials("default") + + assert first is credentials + assert second is credentials + assert calls == [("default", "production", "read")] + + +def test_get_tower_catalog_credentials_prunes_expired_cache_entries(monkeypatch): + _storage._clear_credential_cache() + ctx = TowerContext( + tower_url="https://api.example.com", + environment="production", + api_key="api-key", + ) + expired_credentials = CatalogCredentials( + catalog_uri="https://old-catalog.example.com", + expires_at=datetime.now(timezone.utc) - timedelta(minutes=1), + mode="read", + oauth_token="old-oauth-token", + warehouse="old-warehouse-id", + ) + fresh_credentials = CatalogCredentials( + catalog_uri="https://catalog.example.com", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + mode="read", + oauth_token="oauth-token", + warehouse="warehouse-id", + ) + expired_key = _storage._cache_key(ctx, "stale", "production", "read") + _storage._credential_cache[expired_key] = _storage._CachedCredentials( + expired_credentials + ) + + def vend(ctx, name, environment, mode): + return VendCatalogCredentialsResponse(credentials=fresh_credentials) + + monkeypatch.setattr(_storage.TowerContext, "build", staticmethod(lambda: ctx)) + monkeypatch.setattr(_storage, "_vend_catalog_credentials", vend) + + result = _storage.get_tower_catalog_credentials("default") + + assert result is fresh_credentials + assert expired_key not in _storage._credential_cache + + +def test_default_catalog_vend_retries_after_legacy_provisioning(monkeypatch): + _storage._clear_credential_cache() + ctx = TowerContext( + tower_url="https://api.example.com", + environment="default", + api_key="api-key", + ) + expires_at = datetime.now(timezone.utc) + timedelta(hours=1) + credentials = CatalogCredentials( + catalog_uri="https://catalog.example.com", + expires_at=expires_at, + mode="read", + oauth_token="oauth-token", + warehouse="warehouse-id", + ) + responses = [ + ErrorModel(status=404, detail="not found"), + ErrorModel(status=404, detail="still provisioning"), + VendCatalogCredentialsResponse(credentials=credentials), + ] + legacy_calls = [] + + def vend(ctx, name, environment, mode): + return responses.pop(0) + + def legacy_default(ctx): + legacy_calls.append(ctx) + + monkeypatch.setattr(_storage.TowerContext, "build", staticmethod(lambda: ctx)) + monkeypatch.setattr(_storage, "_vend_catalog_credentials", vend) + monkeypatch.setattr(_storage, "_ensure_legacy_default_catalog", legacy_default) + monkeypatch.setattr(_storage.time, "sleep", lambda delay: None) + + result = _storage.get_tower_catalog_credentials("default") + + assert result is credentials + assert len(legacy_calls) == 2 diff --git a/tests/tower/test_tables.py b/tests/tower/test_tables.py index 931d6afe..38429a7d 100644 --- a/tests/tower/test_tables.py +++ b/tests/tower/test_tables.py @@ -18,6 +18,59 @@ # Imports the library under test import tower +import tower._tables as tables_module +from tower._context import TowerContext +from tower.tower_api_client.models import CatalogCredentials + + +class FakeLoadedTable: + def __init__(self, mode: str, identifier: str): + self.mode = mode + self.identifier = identifier + self.append_calls = [] + + def append(self, data): + self.append_calls.append(data) + + +class FakeCatalog: + def __init__(self, mode: str): + self.mode = mode + self.loaded_identifiers = [] + self.loaded_tables = [] + + def load_table(self, identifier: str): + self.loaded_identifiers.append(identifier) + table = FakeLoadedTable(self.mode, identifier) + self.loaded_tables.append(table) + return table + + +def patch_tower_context( + monkeypatch, + environment: str = "production", + api_key: str | None = "api-key", + run_id: str | None = None, +): + ctx = TowerContext( + tower_url="https://api.example.com", + environment=environment, + api_key=api_key, + run_id=run_id, + ) + monkeypatch.setattr(tables_module.TowerContext, "build", staticmethod(lambda: ctx)) + return ctx + + +def make_catalog_credentials(mode: str, token: str | None = None): + return CatalogCredentials( + catalog_uri="https://catalog.example.com", + expires_at=datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(hours=1), + mode=mode, + oauth_token=token or f"{mode}-token", + warehouse="warehouse-id", + ) def get_temp_dir(): @@ -69,6 +122,165 @@ def sql_catalog(): pass +@pytest.mark.parametrize( + ("tower_credentials", "expected_source"), + [ + (None, "vend"), + (True, "vend"), + (False, "load_catalog"), + ], +) +def test_string_catalog_precedence(monkeypatch, tower_credentials, expected_source): + # Tower-managed catalogs vend by default and when forced; `tower_credentials=False` + # is the only path that falls back to existing PyIceberg configuration. + patch_tower_context(monkeypatch) + vend_catalog = FakeCatalog("vend") + configured_catalog = FakeCatalog("configured") + calls = [] + + def get_tower_catalog_credentials(name, environment=None, mode="read"): + calls.append(("vend", name, environment, mode)) + return make_catalog_credentials(mode) + + def load_vended_catalog(name, credentials): + calls.append(("load_vended_catalog", name, credentials.mode)) + return vend_catalog + + def load_catalog(name): + calls.append(("load_catalog", name)) + return configured_catalog + + monkeypatch.setattr( + tables_module, "get_tower_catalog_credentials", get_tower_catalog_credentials + ) + monkeypatch.setattr(tables_module, "load_vended_catalog", load_vended_catalog) + monkeypatch.setattr(tables_module, "load_catalog", load_catalog) + + ref = tables_module.tables( + "events", + catalog="default", + tower_credentials=tower_credentials, + ) + + if expected_source == "vend": + assert ref._catalog is vend_catalog + assert ref._tower_vended is True + assert calls == [ + ("vend", "default", "production", "read"), + ("load_vended_catalog", "default", "read"), + ] + else: + assert ref._catalog is configured_catalog + assert ref._tower_vended is False + assert calls == [("load_catalog", "default")] + + +def test_vended_table_write_lazily_escalates_and_reuses_catalog(monkeypatch): + patch_tower_context(monkeypatch) + catalogs = {} + credential_calls = [] + + def get_tower_catalog_credentials(name, environment=None, mode="read"): + credential_calls.append((name, environment, mode)) + return make_catalog_credentials(mode) + + def load_vended_catalog(name, credentials): + catalog = FakeCatalog(credentials.mode) + catalogs.setdefault(credentials.mode, []).append(catalog) + return catalog + + monkeypatch.setattr( + tables_module, "get_tower_catalog_credentials", get_tower_catalog_credentials + ) + monkeypatch.setattr(tables_module, "load_vended_catalog", load_vended_catalog) + + ref = tables_module.tables("events", catalog="default", namespace="demo") + + assert ref._tower_vended is True + assert credential_calls == [("default", "production", "read")] + + table = ref.load() + + assert table._catalog_mode == "read" + assert catalogs["read"][0].loaded_identifiers == ["demo.events"] + assert "read-write" not in catalogs + + data = pa.table({"id": [1, 2, 3]}) + table.insert(data) + + assert credential_calls == [ + ("default", "production", "read"), + ("default", "production", "read-write"), + ] + assert catalogs["read"][0].loaded_tables[0].append_calls == [] + assert catalogs["read-write"][0].loaded_identifiers == ["demo.events"] + assert catalogs["read-write"][0].loaded_tables[0].append_calls == [data] + assert table._catalog_mode == "read-write" + assert table.rows_affected().inserts == 3 + + second_batch = pa.table({"id": [4]}) + table.insert(second_batch) + + assert credential_calls == [ + ("default", "production", "read"), + ("default", "production", "read-write"), + ("default", "production", "read-write"), + ] + assert len(catalogs["read-write"]) == 1 + assert catalogs["read-write"][0].loaded_identifiers == ["demo.events"] + assert catalogs["read-write"][0].loaded_tables[0].append_calls == [ + data, + second_batch, + ] + assert table.rows_affected().inserts == 4 + + +def test_vended_table_write_reloads_when_credentials_change(monkeypatch): + patch_tower_context(monkeypatch) + catalogs = {} + credential_calls = [] + read_write_credentials = [ + make_catalog_credentials("read-write", token="write-token-1"), + make_catalog_credentials("read-write", token="write-token-2"), + ] + + def get_tower_catalog_credentials(name, environment=None, mode="read"): + credential_calls.append((name, environment, mode)) + if mode == "read-write": + return read_write_credentials.pop(0) + return make_catalog_credentials(mode) + + def load_vended_catalog(name, credentials): + catalog = FakeCatalog(credentials.mode) + catalogs.setdefault(credentials.mode, []).append(catalog) + return catalog + + monkeypatch.setattr( + tables_module, "get_tower_catalog_credentials", get_tower_catalog_credentials + ) + monkeypatch.setattr(tables_module, "load_vended_catalog", load_vended_catalog) + + table = tables_module.tables("events", catalog="default", namespace="demo").load() + + first_batch = pa.table({"id": [1]}) + table.insert(first_batch) + + second_batch = pa.table({"id": [2]}) + table.insert(second_batch) + + assert credential_calls == [ + ("default", "production", "read"), + ("default", "production", "read-write"), + ("default", "production", "read-write"), + ] + assert len(catalogs["read-write"]) == 2 + assert catalogs["read-write"][0].loaded_identifiers == ["demo.events"] + assert catalogs["read-write"][1].loaded_identifiers == ["demo.events"] + assert catalogs["read-write"][0].loaded_tables[0].append_calls == [first_batch] + assert catalogs["read-write"][1].loaded_tables[0].append_calls == [second_batch] + assert table.rows_affected().inserts == 2 + + def test_reading_and_writing_to_tables(in_memory_catalog): schema = pa.schema( [