Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/app/endpoints/rlsapi_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
extract_provider_and_model_from_model_id,
handle_known_apistatus_errors,
is_context_length_error,
normalize_vertex_ai_model_id,
)
from utils.quota import check_tokens_available
from utils.responses import (
Expand Down Expand Up @@ -343,9 +344,12 @@ async def _call_llm(

logger.debug("Using model %s for rlsapi v1 inference", resolved_model_id)

# Normalize Vertex AI model IDs to work around llama-stack 0.6.x bug
normalized_model = normalize_vertex_ai_model_id(resolved_model_id)

response = await client.responses.create(
input=question,
model=resolved_model_id,
model=normalized_model,
instructions=instructions,
tools=tools or [],
stream=False,
Expand Down
13 changes: 11 additions & 2 deletions src/utils/compaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from log import get_logger
from models.compaction import ConversationSummary
from utils.query import normalize_vertex_ai_model_id
from utils.token_estimator import (
estimate_conversation_tokens,
estimate_tokens,
Expand Down Expand Up @@ -266,10 +267,14 @@ async def summarize_chunk(
# by utils.responses.get_topic_summary and protects the directives from
# prompt-injection via user message content that ends up in the
# transcript.

# Normalize Vertex AI model IDs to work around llama-stack 0.6.x bug
normalized_model = normalize_vertex_ai_model_id(model)

response = await client.responses.create(
input=f"Conversation:\n{transcript}",
instructions=SUMMARIZATION_PROMPT,
model=model,
model=normalized_model,
stream=False,
store=False,
)
Expand Down Expand Up @@ -374,10 +379,14 @@ async def recursively_resummarize(
model,
)
# Same instructions/input split as summarize_chunk — see comment there.

# Normalize Vertex AI model IDs to work around llama-stack 0.6.x bug
normalized_model = normalize_vertex_ai_model_id(model)

response = await client.responses.create(
input=transcript,
instructions=RECURSIVE_RESUMMARIZATION_PROMPT,
model=model,
model=normalized_model,
stream=False,
store=False,
)
Expand Down
28 changes: 28 additions & 0 deletions src/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,34 @@ def extract_provider_and_model_from_model_id(model_id: str) -> tuple[str, str]:
return "", model_id


def normalize_vertex_ai_model_id(model_id: str) -> str:
"""Normalize Vertex AI model ID to work around llama-stack 0.6.x bug.

llama-stack 0.6.x has a bug in the inline::meta-reference responses provider
where it normalizes model IDs before checking against allowed_models, but doesn't
normalize the allowed_models list itself. This causes Vertex AI models to fail
validation because:
- Model is registered as: publishers/google/models/gemini-2.5-flash
- llama-stack strips to: google/gemini-2.5-flash internally
- Checks against allowed list: ['publishers/google/models/gemini-2.5-flash']
- Mismatch → 500 error

This workaround strips the publishers/google/models/ prefix to match what
llama-stack expects internally.

Fixed in llama-stack 0.7.0 via https://github.com/ogx-ai/ogx/pull/5169

Args:
model_id: The model ID, possibly in Vertex AI format

Returns:
Normalized model ID with Vertex AI prefix stripped if present
"""
if model_id.startswith("publishers/google/models/"):
return model_id.replace("publishers/google/models/", "google/", 1)
Copy link
Copy Markdown
Contributor

@asimurka asimurka Jun 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@anik120 is "publishers/google/models" some consistent prefix? I thought that the fix will be more variable (namely replacing any model prefix with google/ not hard-coded)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a consistent prefix for Vertex AI models - it's GCloud's standard resource naming convention for Vertex AI model resources. Ref ogx-ai/ogx#5169 that handles this specific format.

This fix is intentionally narrow and targeted because:

  • It's a temporary workaround for llama-stack 0.6.x
  • The bug is specific to Vertex AI models with this exact prefix
  • We preserve other formats unchanged (e.g., models/gemini-2.5-flash for Gemini API)

return model_id


def handle_known_apistatus_errors(
error: LLSApiStatusError | OpenAIAPIStatusError, model_id: str
) -> AbstractErrorResponse:
Expand Down
12 changes: 10 additions & 2 deletions src/utils/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
from utils.query import (
extract_provider_and_model_from_model_id,
handle_known_apistatus_errors,
normalize_vertex_ai_model_id,
prepare_input,
)
from utils.suid import to_llama_stack_conversation_id
Expand Down Expand Up @@ -178,11 +179,14 @@ async def get_topic_summary( # pylint: disable=too-many-nested-blocks
The topic summary for the question
"""
try:
# Normalize Vertex AI model IDs to work around llama-stack 0.6.x bug
normalized_model = normalize_vertex_ai_model_id(model_id)

response = cast(
ResponseObject,
await client.responses.create(
input=question,
model=model_id,
model=normalized_model,
instructions=get_topic_summary_system_prompt(),
stream=False,
store=False, # Don't store topic summary requests
Expand Down Expand Up @@ -389,9 +393,13 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma

# Build x-llamastack-provider-data header from MCP tool headers
extra_headers = _build_provider_data_headers(tools)

# Normalize Vertex AI model IDs to work around llama-stack 0.6.x bug
normalized_model = normalize_vertex_ai_model_id(model)

return ResponsesApiParams(
input=input_text,
model=model,
model=normalized_model,
instructions=system_prompt,
tools=tools,
conversation=llama_stack_conv_id,
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/utils/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from models.api.requests import QueryRequest
from models.common.responses.types import InputTool, InputToolMCP
from models.config import ApprovalFilter, ByokRag, ModelContextProtocolServer
from utils.query import normalize_vertex_ai_model_id
from utils.responses import (
_build_chunk_attributes,
_merge_tools,
Expand Down Expand Up @@ -3577,3 +3578,32 @@ async def test_merge_header_no_server_tools_returns_client_only(
)
assert tools is not None
assert len(tools) == 1


class TestNormalizeVertexAIModelId:
"""Tests for normalize_vertex_ai_model_id function."""

def test_normalizes_vertex_ai_model_id(self) -> None:
"""Test that Vertex AI model IDs are normalized correctly."""
input_model = "publishers/google/models/gemini-2.5-flash"
expected = "google/gemini-2.5-flash"
assert normalize_vertex_ai_model_id(input_model) == expected

def test_normalizes_vertex_ai_model_id_with_version(self) -> None:
"""Test normalization with versioned Vertex AI model ID."""
input_model = "publishers/google/models/gemini-1.5-pro-001"
expected = "google/gemini-1.5-pro-001"
assert normalize_vertex_ai_model_id(input_model) == expected

def test_preserves_non_vertex_ai_model_ids(self) -> None:
"""Test that non-Vertex AI model IDs are returned unchanged."""
# Regular model IDs should pass through
assert normalize_vertex_ai_model_id("gpt-4") == "gpt-4"
assert normalize_vertex_ai_model_id("openai/gpt-4") == "openai/gpt-4"
assert normalize_vertex_ai_model_id("watsonx/model") == "watsonx/model"

def test_preserves_gemini_api_format(self) -> None:
"""Test that Gemini API format (models/...) is preserved."""
# Gemini API format doesn't have the publishers prefix
gemini_api_format = "models/gemini-2.5-flash"
assert normalize_vertex_ai_model_id(gemini_api_format) == gemini_api_format
Loading