Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/openshell-bootstrap/src/oidc_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ pub struct OidcTokenBundle {
/// `OAuth2` access token (JWT).
pub access_token: String,

/// Optional OIDC ID token returned by the provider.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id_token: Option<String>,

/// `OAuth2` refresh token. `None` for `client_credentials` grants.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub refresh_token: Option<String>,
Expand Down
19 changes: 11 additions & 8 deletions crates/openshell-cli/src/completers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,19 @@ async fn completion_grpc_client(
Some("oidc") => {
if let Some(bundle) = load_oidc_token(gateway_name) {
if is_token_expired(&bundle) {
match oidc_refresh_token(&bundle, tls_opts.gateway_insecure).await {
Ok(refreshed) => {
let _ = store_oidc_token(gateway_name, &refreshed);
tls_opts.oidc_token = Some(refreshed.access_token);
}
Err(_) => {
tls_opts.oidc_token = Some(bundle.access_token);
}
if let Ok(refreshed) =
oidc_refresh_token(&bundle, tls_opts.gateway_insecure).await
{
let _ = store_oidc_token(gateway_name, &refreshed);
tls_opts.oidc_token = Some(refreshed.access_token);
tls_opts.oidc_id_token = refreshed.id_token;
} else {
tls_opts.oidc_token = Some(bundle.access_token);
tls_opts.oidc_id_token = bundle.id_token;
}
} else {
tls_opts.oidc_token = Some(bundle.access_token);
tls_opts.oidc_id_token = bundle.id_token;
}
}
}
Expand All @@ -124,6 +126,7 @@ async fn completion_grpc_client(
let channel = build_channel(server, &tls_opts).await.ok()?;
let interceptor = EdgeAuthInterceptor::new(
tls_opts.oidc_token.as_deref(),
tls_opts.oidc_id_token.as_deref(),
tls_opts.edge_token.as_deref(),
)
.ok()?;
Expand Down
32 changes: 29 additions & 3 deletions crates/openshell-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ fn apply_auth(tls: &mut TlsOptions, gateway_name: &str) {
else {
return;
};
let bearer_for_gateway = |access_token: &str, id_token: Option<&String>| {
if access_token.matches('.').count() == 2 {
access_token.to_string()
} else {
id_token
.cloned()
.unwrap_or_else(|| access_token.to_string())
}
};
if openshell_bootstrap::oidc_token::is_token_expired(&bundle) {
let insecure = std::env::var("OPENSHELL_GATEWAY_INSECURE")
.is_ok_and(|v| !v.is_empty() && v != "0" && v != "false");
Expand All @@ -155,17 +164,29 @@ fn apply_auth(tls: &mut TlsOptions, gateway_name: &str) {
gateway_name,
&refreshed,
);
tls.oidc_token = Some(refreshed.access_token);
tls.oidc_token = Some(bearer_for_gateway(
&refreshed.access_token,
refreshed.id_token.as_ref(),
));
tls.oidc_id_token = refreshed.id_token;
}
Err(e) => {
tracing::warn!("OIDC token refresh failed: {e}");
// Use the expired token anyway — server will reject it
// with a clear error prompting re-login.
tls.oidc_token = Some(bundle.access_token);
tls.oidc_token = Some(bearer_for_gateway(
&bundle.access_token,
bundle.id_token.as_ref(),
));
tls.oidc_id_token = bundle.id_token;
}
}
} else {
tls.oidc_token = Some(bundle.access_token);
tls.oidc_token = Some(bearer_for_gateway(
&bundle.access_token,
bundle.id_token.as_ref(),
));
tls.oidc_id_token = bundle.id_token;
}
}
_ => {}
Expand Down Expand Up @@ -660,6 +681,8 @@ enum OutputFormat {
enum CliProviderRefreshStrategy {
Oauth2RefreshToken,
Oauth2ClientCredentials,
Oauth2TokenExchange,
OktaXaa,
GoogleServiceAccountJwt,
}

Expand All @@ -668,6 +691,8 @@ impl CliProviderRefreshStrategy {
match self {
Self::Oauth2RefreshToken => "oauth2_refresh_token",
Self::Oauth2ClientCredentials => "oauth2_client_credentials",
Self::Oauth2TokenExchange => "oauth2_token_exchange",
Self::OktaXaa => "okta_xaa",
Self::GoogleServiceAccountJwt => "google_service_account_jwt",
}
}
Expand Down Expand Up @@ -2925,6 +2950,7 @@ async fn main() -> Result<()> {
let channel = openshell_cli::tls::build_channel(&ctx.endpoint, &tls).await?;
let interceptor = openshell_core::auth::EdgeAuthInterceptor::new(
tls.oidc_token.as_deref(),
tls.oidc_id_token.as_deref(),
tls.edge_token.as_deref(),
)?;
openshell_tui::run(channel, interceptor, &ctx.name, &ctx.endpoint, theme).await?;
Expand Down
97 changes: 79 additions & 18 deletions crates/openshell-cli/src/oidc_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ use hyper::{Method, Response, StatusCode};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder;
use miette::{IntoDiagnostic, Result};
use oauth2::basic::BasicClient;
use oauth2::{
AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
AuthType, AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
EndpointNotSet, ExtraTokenFields, PkceCodeChallenge, RedirectUrl, RefreshToken, Scope,
StandardRevocableToken, StandardTokenResponse, TokenResponse, TokenUrl,
basic::{
BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
BasicTokenType,
},
};
use openshell_bootstrap::oidc_token::OidcTokenBundle;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::{Arc, Mutex};
use std::time::Duration;
Expand All @@ -29,6 +33,37 @@ use tokio::sync::oneshot;
use tracing::debug;

const AUTH_TIMEOUT: Duration = Duration::from_secs(120);
const DEFAULT_OIDC_CALLBACK_BIND: &str = "127.0.0.1:0";
const OIDC_CALLBACK_PORT_ENV: &str = "OPENSHELL_OIDC_CALLBACK_PORT";
const OIDC_CLIENT_SECRET_ENV: &str = "OPENSHELL_OIDC_CLIENT_SECRET";

#[derive(Clone, Debug, Default, Deserialize, Serialize)]
struct OidcExtraTokenFields {
#[serde(default, skip_serializing_if = "Option::is_none")]
id_token: Option<String>,
}

impl ExtraTokenFields for OidcExtraTokenFields {}

type OidcTokenResponse = StandardTokenResponse<OidcExtraTokenFields, BasicTokenType>;
type OidcClient<
HasAuthUrl = EndpointNotSet,
HasDeviceAuthUrl = EndpointNotSet,
HasIntrospectionUrl = EndpointNotSet,
HasRevocationUrl = EndpointNotSet,
HasTokenUrl = EndpointNotSet,
> = Client<
BasicErrorResponse,
OidcTokenResponse,
BasicTokenIntrospectionResponse,
StandardRevocableToken,
BasicRevocationErrorResponse,
HasAuthUrl,
HasDeviceAuthUrl,
HasIntrospectionUrl,
HasRevocationUrl,
HasTokenUrl,
>;

/// OIDC discovery document (subset of fields we need).
#[derive(Debug, Deserialize)]
Expand Down Expand Up @@ -95,6 +130,25 @@ fn build_ci_scopes(scopes: Option<&str>) -> Vec<Scope> {
.collect()
}

fn oidc_callback_bind_address() -> Result<String> {
match std::env::var(OIDC_CALLBACK_PORT_ENV) {
Ok(raw) => {
let port = raw.parse::<u16>().map_err(|_| {
miette::miette!(
"{OIDC_CALLBACK_PORT_ENV} must be a valid TCP port number, got '{raw}'"
)
})?;
if port == 0 {
return Err(miette::miette!(
"{OIDC_CALLBACK_PORT_ENV} must be greater than 0"
));
}
Ok(format!("127.0.0.1:{port}"))
}
Err(_) => Ok(DEFAULT_OIDC_CALLBACK_BIND.to_string()),
}
}

/// Run the OIDC Authorization Code + PKCE browser flow.
///
/// Opens the user's browser to the Keycloak login page and waits for
Expand All @@ -108,14 +162,21 @@ pub async fn oidc_browser_auth_flow(
) -> Result<OidcTokenBundle> {
let discovery = discover(issuer, insecure).await?;

let listener = TcpListener::bind("127.0.0.1:0").await.into_diagnostic()?;
let listener = TcpListener::bind(oidc_callback_bind_address()?)
.await
.into_diagnostic()?;
let port = listener.local_addr().into_diagnostic()?.port();
let redirect_uri = format!("http://127.0.0.1:{port}/callback");

let client = BasicClient::new(ClientId::new(client_id.to_string()))
let mut client = OidcClient::new(ClientId::new(client_id.to_string()))
.set_auth_uri(AuthUrl::new(discovery.authorization_endpoint).into_diagnostic()?)
.set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?)
.set_redirect_uri(RedirectUrl::new(redirect_uri).into_diagnostic()?);
if let Ok(client_secret) = std::env::var(OIDC_CLIENT_SECRET_ENV) {
client = client
.set_client_secret(ClientSecret::new(client_secret))
.set_auth_type(AuthType::RequestBody);
}

let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();

Expand Down Expand Up @@ -167,7 +228,7 @@ pub async fn oidc_browser_auth_flow(
server_handle.abort();

let http = http_client(insecure);
let token_response = client
let token_response: OidcTokenResponse = client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(pkce_verifier)
.request_async(&http)
Expand All @@ -191,15 +252,15 @@ pub async fn oidc_client_credentials_flow(
scopes: Option<&str>,
insecure: bool,
) -> Result<OidcTokenBundle> {
let client_secret = std::env::var("OPENSHELL_OIDC_CLIENT_SECRET").map_err(|_| {
let client_secret = std::env::var(OIDC_CLIENT_SECRET_ENV).map_err(|_| {
miette::miette!(
"OPENSHELL_OIDC_CLIENT_SECRET environment variable is required for client credentials flow"
"{OIDC_CLIENT_SECRET_ENV} environment variable is required for client credentials flow"
)
})?;

let discovery = discover(issuer, insecure).await?;

let client = BasicClient::new(ClientId::new(client_id.to_string()))
let client = OidcClient::new(ClientId::new(client_id.to_string()))
.set_client_secret(ClientSecret::new(client_secret))
.set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?)
.set_auth_type(AuthType::RequestBody);
Expand All @@ -213,7 +274,7 @@ pub async fn oidc_client_credentials_flow(
}

let http = http_client(insecure);
let token_response = request
let token_response: OidcTokenResponse = request
.request_async(&http)
.await
.map_err(|e| miette::miette!("client credentials token exchange failed: {e}"))?;
Expand Down Expand Up @@ -241,11 +302,11 @@ pub async fn oidc_refresh_token(

let discovery = discover(&bundle.issuer, insecure).await?;

let client = BasicClient::new(ClientId::new(bundle.client_id.clone()))
let client = OidcClient::new(ClientId::new(bundle.client_id.clone()))
.set_token_uri(TokenUrl::new(discovery.token_endpoint).into_diagnostic()?);

let http = http_client(insecure);
let token_response = client
let token_response: OidcTokenResponse = client
.exchange_refresh_token(&RefreshToken::new(refresh_token.to_string()))
.request_async(&http)
.await
Expand Down Expand Up @@ -287,7 +348,7 @@ pub async fn ensure_valid_oidc_token(gateway_name: &str, insecure: bool) -> Resu
// ── Helpers ──────────────────────────────────────────────────────────

fn bundle_from_oauth2_response(
resp: &oauth2::basic::BasicTokenResponse,
resp: &OidcTokenResponse,
issuer: &str,
client_id: &str,
) -> OidcTokenBundle {
Expand All @@ -298,6 +359,7 @@ fn bundle_from_oauth2_response(

OidcTokenBundle {
access_token: resp.access_token().secret().clone(),
id_token: resp.extra_fields().id_token.clone(),
refresh_token: resp.refresh_token().map(|rt| rt.secret().clone()),
expires_at: resp.expires_in().map(|ei| now + ei.as_secs()),
issuer: issuer.to_string(),
Expand Down Expand Up @@ -518,14 +580,13 @@ mod tests {

#[test]
fn bundle_from_response_sets_fields() {
use oauth2::basic::BasicTokenResponse;

let token_response: BasicTokenResponse = serde_json::from_str(
r#"{"access_token":"test-access","token_type":"bearer","expires_in":300,"refresh_token":"test-refresh"}"#,
let token_response: OidcTokenResponse = serde_json::from_str(
r#"{"access_token":"test-access","token_type":"bearer","expires_in":300,"refresh_token":"test-refresh","id_token":"test-id"}"#,
)
.unwrap();
let bundle = bundle_from_oauth2_response(&token_response, "https://issuer", "my-client");
assert_eq!(bundle.access_token, "test-access");
assert_eq!(bundle.id_token.as_deref(), Some("test-id"));
assert_eq!(bundle.refresh_token.as_deref(), Some("test-refresh"));
assert_eq!(bundle.issuer, "https://issuer");
assert_eq!(bundle.client_id, "my-client");
Expand Down
4 changes: 4 additions & 0 deletions crates/openshell-cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5004,6 +5004,8 @@ fn provider_refresh_strategy(strategy: &str) -> Result<ProviderCredentialRefresh
"oauth2_client_credentials" => {
Ok(ProviderCredentialRefreshStrategy::Oauth2ClientCredentials)
}
"oauth2_token_exchange" => Ok(ProviderCredentialRefreshStrategy::Oauth2TokenExchange),
"okta_xaa" => Ok(ProviderCredentialRefreshStrategy::OktaXaa),
"google_service_account_jwt" => {
Ok(ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt)
}
Expand Down Expand Up @@ -5058,6 +5060,8 @@ fn provider_refresh_strategy_name(strategy: ProviderCredentialRefreshStrategy) -
ProviderCredentialRefreshStrategy::External => "external",
ProviderCredentialRefreshStrategy::Oauth2RefreshToken => "oauth2_refresh_token",
ProviderCredentialRefreshStrategy::Oauth2ClientCredentials => "oauth2_client_credentials",
ProviderCredentialRefreshStrategy::Oauth2TokenExchange => "oauth2_token_exchange",
ProviderCredentialRefreshStrategy::OktaXaa => "okta_xaa",
ProviderCredentialRefreshStrategy::GoogleServiceAccountJwt => "google_service_account_jwt",
ProviderCredentialRefreshStrategy::Unspecified => "unspecified",
}
Expand Down
10 changes: 9 additions & 1 deletion crates/openshell-cli/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ pub struct TlsOptions {
/// OIDC bearer token — when set, injects `authorization: Bearer <token>`
/// on every gRPC request. Takes precedence over `edge_token`.
pub oidc_token: Option<String>,
/// OIDC ID token — when set, injects a gateway-private metadata header
/// so delegated/XAA flows can bind the signed-in user session.
pub oidc_id_token: Option<String>,
/// Skip TLS certificate verification for gateway connections.
pub gateway_insecure: bool,
}
Expand All @@ -53,6 +56,7 @@ impl TlsOptions {
gateway_name: None,
edge_token: None,
oidc_token: None,
oidc_id_token: None,
gateway_insecure: false,
}
}
Expand Down Expand Up @@ -441,7 +445,11 @@ pub async fn grpc_client(server: &str, tls: &TlsOptions) -> Result<GrpcClient> {
}

fn interceptor_from_tls(tls: &TlsOptions) -> Result<EdgeAuthInterceptor> {
EdgeAuthInterceptor::new(tls.oidc_token.as_deref(), tls.edge_token.as_deref())
EdgeAuthInterceptor::new(
tls.oidc_token.as_deref(),
tls.oidc_id_token.as_deref(),
tls.edge_token.as_deref(),
)
}

pub async fn grpc_inference_client(server: &str, tls: &TlsOptions) -> Result<GrpcInferenceClient> {
Expand Down
Loading
Loading