-
Notifications
You must be signed in to change notification settings - Fork 88
LCORE-2308: LlamaStack Pydantic AI Provider #1806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
3def1bc
(feat) inital implementation
jrobertboos 2c38478
(fix) style
jrobertboos 9926b0c
(fix) pydocstyle
jrobertboos 5460b50
(fix) removed enviorment variables
jrobertboos 7d77ef4
(tests) inital unit tests
jrobertboos 45ba392
(fix) addressed coderabbit
jrobertboos File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Pydantic AI integrations/extensions for Lightspeed Core Stack.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| """Pydantic AI provider for Llama Stack.""" | ||
|
|
||
| from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider | ||
|
|
||
| __all__ = ["LlamaStackProvider"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| """Llama Stack provider implementation for Pydantic AI.""" | ||
|
|
||
| from __future__ import annotations as _annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import httpx | ||
| from openai import AsyncOpenAI | ||
| from pydantic_ai import ModelProfile | ||
| from pydantic_ai.models import create_async_http_client | ||
| from pydantic_ai.profiles.openai import openai_model_profile | ||
| from pydantic_ai.providers import Provider | ||
|
|
||
| from pydantic_ai_lightspeed.llamastack._transport import LlamaStackLibraryTransport | ||
|
|
||
| if TYPE_CHECKING: | ||
| from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient | ||
|
|
||
| DEFAULT_BASE_URL = "http://localhost:8321/v1" | ||
|
|
||
|
|
||
| class LlamaStackProvider(Provider[AsyncOpenAI]): | ||
| """Provider for Llama Stack — connects to a Llama Stack server's OpenAI-compatible API. | ||
|
|
||
| Supports two modes: | ||
|
|
||
| 1. **Server mode** — connect to a running Llama Stack server via HTTP | ||
| 2. **Library mode** — run Llama Stack in-process via ``AsyncLlamaStackAsLibraryClient`` | ||
| """ | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
| """The provider name.""" | ||
| return "llama-stack" | ||
|
|
||
| @property | ||
| def base_url(self) -> str: | ||
| """The base URL for the provider API.""" | ||
| return str(self._client.base_url) | ||
|
|
||
| @property | ||
| def client(self) -> AsyncOpenAI: | ||
| """The OpenAI-compatible client for the provider.""" | ||
| return self._client | ||
|
|
||
| @staticmethod | ||
| def model_profile(model_name: str) -> ModelProfile | None: | ||
| """Return the model profile for the named model, if available.""" | ||
| return openai_model_profile(model_name) | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| base_url: str | None = None, | ||
| api_key: str | None = None, | ||
| library_client: AsyncLlamaStackAsLibraryClient | None = None, | ||
| http_client: httpx.AsyncClient | None = None, | ||
| ) -> None: | ||
| """Create a new Llama Stack provider. | ||
|
|
||
| Args: | ||
| base_url: The base URL for the Llama Stack server (OpenAI-compatible endpoint). | ||
| Defaults to ``http://localhost:8321/v1``. | ||
| Must be ``None`` when ``library_client`` is provided. | ||
| api_key: The API key for authentication. Defaults to ``'not-needed'`` since | ||
| local Llama Stack servers typically don't require one. | ||
| Must be ``None`` when ``library_client`` is provided. | ||
| library_client: An initialized ``AsyncLlamaStackAsLibraryClient`` for library mode. | ||
| When provided, requests are dispatched in-process (no server needed). | ||
| Mutually exclusive with ``base_url``, ``api_key``, and ``http_client``. | ||
| http_client: An existing ``httpx.AsyncClient`` to use for making HTTP requests. | ||
| Must be ``None`` when ``library_client`` is provided. | ||
| """ | ||
| if library_client is not None: | ||
| if base_url is not None: | ||
| raise ValueError("Cannot provide both `library_client` and `base_url`") | ||
| if api_key is not None: | ||
| raise ValueError("Cannot provide both `library_client` and `api_key`") | ||
| if http_client is not None: | ||
| raise ValueError( | ||
| "Cannot provide both `library_client` and `http_client`" | ||
| ) | ||
|
|
||
| self._library_client = library_client | ||
| transport = LlamaStackLibraryTransport(library_client) | ||
| lib_http_client = httpx.AsyncClient( | ||
| transport=transport, | ||
Check warningCode scanning / Bandit Call to httpx without timeout Warning
Call to httpx without timeout
|
||
|
Comment on lines
+86
to
+87
|
||
| base_url="http://llama-stack-library", | ||
| timeout=httpx.Timeout(None), | ||
| ) | ||
| self._client = AsyncOpenAI( | ||
| http_client=lib_http_client, | ||
| base_url="http://llama-stack-library/v1", | ||
| api_key="not-needed", | ||
| ) | ||
| else: | ||
| base_url = base_url or DEFAULT_BASE_URL | ||
| api_key = api_key or "not-needed" | ||
|
|
||
| if http_client is not None: | ||
| self._client = AsyncOpenAI( | ||
| base_url=base_url, api_key=api_key, http_client=http_client | ||
| ) | ||
| else: | ||
| oai_http_client = create_async_http_client() | ||
| self._client = AsyncOpenAI( | ||
| base_url=base_url, api_key=api_key, http_client=oai_http_client | ||
| ) | ||
|
|
||
| def __repr__(self) -> str: | ||
| """Return a string representation of the provider.""" | ||
| return f"LlamaStackProvider(name={self.name!r}, base_url={self.base_url!r})" | ||
|
|
||
| def _set_http_client(self, http_client: httpx.AsyncClient) -> None: | ||
| """Inject an httpx.AsyncClient into the underlying OpenAI client. | ||
|
|
||
| Replaces the internal HTTP transport by assigning directly to the | ||
| protected ``self._client._client`` attribute of the AsyncOpenAI instance. | ||
|
|
||
| Args: | ||
| http_client: The async HTTP client to use for subsequent requests. | ||
| """ | ||
| self._client._client = http_client # pyright: ignore[reportPrivateUsage] # pylint: disable=protected-access | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,197 @@ | ||
| """httpx transport that routes OpenAI-compatible requests through a Llama Stack library client.""" | ||
|
|
||
| from __future__ import annotations as _annotations | ||
|
|
||
| import json | ||
| from collections.abc import AsyncGenerator, AsyncIterator | ||
| from typing import Any | ||
|
|
||
| import httpx | ||
| from llama_stack.core.library_client import ( | ||
| AsyncLlamaStackAsLibraryClient, | ||
| convert_pydantic_to_json_value, | ||
| ) | ||
| from llama_stack.core.request_headers import ( | ||
| PROVIDER_DATA_VAR, | ||
| request_provider_data_context, | ||
| ) | ||
| from llama_stack.core.server.routes import find_matching_route | ||
| from llama_stack.core.utils.context import preserve_contexts_async_generator | ||
|
|
||
|
|
||
| class _AsyncByteStream(httpx.AsyncByteStream): | ||
| """Wraps an async byte generator as an httpx AsyncByteStream.""" | ||
|
|
||
| def __init__(self, gen: AsyncGenerator[bytes, None]) -> None: | ||
| """Store an async generator that yields raw bytes for streaming. | ||
|
|
||
| Args: | ||
| gen: An async generator producing byte chunks to stream. | ||
| """ | ||
| self._gen = gen | ||
|
|
||
| async def __aiter__(self) -> AsyncIterator[bytes]: | ||
| """Yield bytes chunks from the wrapped generator. | ||
|
|
||
| Returns: | ||
| An async iterator of bytes fulfilling the httpx.AsyncByteStream contract. | ||
| """ | ||
| async for chunk in self._gen: | ||
| yield chunk | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
|
|
||
| class LlamaStackLibraryTransport(httpx.AsyncBaseTransport): | ||
| """Custom httpx transport that dispatches requests through a Llama Stack library client. | ||
|
|
||
| Instead of making real HTTP calls, this transport routes requests directly | ||
| to the Llama Stack's in-process route handlers via the library client's | ||
| route matching and body conversion logic. | ||
| """ | ||
|
|
||
| def __init__(self, client: AsyncLlamaStackAsLibraryClient) -> None: | ||
| """Initialize the transport with a Llama Stack library client. | ||
|
|
||
| Args: | ||
| client: An initialized ``AsyncLlamaStackAsLibraryClient`` whose route | ||
| handlers will receive dispatched requests. | ||
| """ | ||
| self._client = client | ||
|
|
||
| async def handle_async_request(self, request: httpx.Request) -> httpx.Response: | ||
| """Dispatch an httpx request to the in-process Llama Stack route handlers. | ||
|
|
||
| Args: | ||
| request: The outgoing httpx request to route. | ||
|
|
||
| Returns: | ||
| An httpx response built from the matched route handler result. | ||
|
|
||
| Raises: | ||
| RuntimeError: If the library client has not been initialized. | ||
| """ | ||
| if self._client.route_impls is None: | ||
| raise RuntimeError( | ||
| "Llama Stack library client not initialized. Call initialize() first." | ||
| ) | ||
|
|
||
| method = request.method | ||
| path = request.url.raw_path.decode("utf-8") | ||
|
|
||
| body = json.loads(request.content) if request.content else {} | ||
|
|
||
| headers: dict[str, str] = { | ||
| k.decode("utf-8") if isinstance(k, bytes) else k: ( | ||
| v.decode("utf-8") if isinstance(v, bytes) else v | ||
| ) | ||
| for k, v in request.headers.raw | ||
| } | ||
|
|
||
| if self._client.provider_data: | ||
| keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"] | ||
| if all(key not in headers for key in keys): | ||
| headers["X-LlamaStack-Provider-Data"] = json.dumps( | ||
| self._client.provider_data | ||
| ) | ||
|
|
||
| with request_provider_data_context(headers): | ||
| is_stream = body.get("stream", False) | ||
|
|
||
| if is_stream: | ||
| return await self._handle_streaming(request, method, path, body) | ||
| return await self._handle_non_streaming(request, method, path, body) | ||
|
|
||
| async def _handle_non_streaming( | ||
| self, | ||
| request: httpx.Request, | ||
| method: str, | ||
| path: str, | ||
| body: dict[str, Any], | ||
| ) -> httpx.Response: | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| """Dispatch a non-streaming request to the matched route handler. | ||
|
|
||
| Args: | ||
| request: The original httpx request (attached to the response). | ||
| method: The HTTP method (e.g. ``"POST"``). | ||
| path: The decoded URL path used for route matching. | ||
| body: The parsed JSON request body. | ||
|
|
||
| Returns: | ||
| An httpx.Response containing the JSON-serialized handler result. | ||
|
|
||
| Raises: | ||
| RuntimeError: If route_impls is not initialized. | ||
| """ | ||
| if self._client.route_impls is None: | ||
| raise RuntimeError("route_impls is not initialized") | ||
|
|
||
| matched_func, path_params, _, _ = find_matching_route( | ||
| method, path, self._client.route_impls | ||
| ) | ||
| merged_body = {**body, **path_params} | ||
| merged_body = self._client._convert_body( # pylint: disable=protected-access | ||
| matched_func, merged_body | ||
| ) | ||
|
|
||
| result = await matched_func(**merged_body) | ||
|
|
||
| json_content = json.dumps(convert_pydantic_to_json_value(result)) | ||
| status_code = httpx.codes.OK | ||
|
|
||
| if method.upper() == "DELETE" and result is None: | ||
| status_code = httpx.codes.NO_CONTENT | ||
| json_content = "" | ||
|
|
||
| return httpx.Response( | ||
| status_code=status_code, | ||
| content=json_content.encode("utf-8"), | ||
| headers={"Content-Type": "application/json"}, | ||
| request=request, | ||
| ) | ||
|
|
||
| async def _handle_streaming( | ||
| self, | ||
| request: httpx.Request, | ||
| method: str, | ||
| path: str, | ||
| body: dict[str, Any], | ||
| ) -> httpx.Response: | ||
| """Dispatch a streaming request and return an SSE event-stream response. | ||
|
|
||
| Args: | ||
| request: The original httpx request (attached to the response). | ||
| method: The HTTP method (e.g. ``"POST"``). | ||
| path: The decoded URL path used for route matching. | ||
| body: The parsed JSON request body (must contain ``stream: True``). | ||
|
|
||
| Returns: | ||
| An httpx.Response with a streaming body of SSE-formatted chunks. | ||
|
|
||
| Raises: | ||
| RuntimeError: If route_impls is not initialized. | ||
| """ | ||
| if self._client.route_impls is None: | ||
| raise RuntimeError("route_impls is not initialized") | ||
|
|
||
| func, path_params, _, _ = find_matching_route( | ||
| method, path, self._client.route_impls | ||
| ) | ||
| merged_body = {**body, **path_params} | ||
| merged_body = self._client._convert_body( # pylint: disable=protected-access | ||
| func, merged_body | ||
| ) | ||
|
|
||
| result = await func(**merged_body) | ||
|
|
||
| async def gen() -> AsyncGenerator[bytes, None]: | ||
| async for chunk in result: | ||
| data = json.dumps(convert_pydantic_to_json_value(chunk)) | ||
| yield f"data: {data}\n\n".encode("utf-8") | ||
|
|
||
| wrapped_gen = preserve_contexts_async_generator(gen(), [PROVIDER_DATA_VAR]) | ||
|
|
||
| return httpx.Response( | ||
| status_code=httpx.codes.OK, | ||
| stream=_AsyncByteStream(wrapped_gen), | ||
| headers={"Content-Type": "text/event-stream"}, | ||
| request=request, | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Unit tests for the pydantic_ai_lightspeed package.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Unit tests for pydantic_ai_lightspeed.llamastack sub-package.""" |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we prefer to use
Optional[str]style (for now)