Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion google/genai/_interactions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
"DefaultAsyncHttpxClient",
"DefaultAioHttpClient",
"AsyncGeminiNextGenAPIClientAdapter",
"GeminiNextGenAPIClientAdapter"
"GeminiNextGenAPIClientAdapter",
]

if not _t.TYPE_CHECKING:
Expand Down
4 changes: 2 additions & 2 deletions google/genai/_interactions/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,9 @@ def __init__(

def _build_maybe_vertex_path(self, *, api_version: str, path: str) -> str:
if not self._is_vertex or not self._vertex_location or not self._vertex_project:
return f'/{api_version}/{path}'
return f"/{api_version}/{path}"

return f'{api_version}/projects/{self._vertex_project}/locations/{self._vertex_location}/{path}'
return f"{api_version}/projects/{self._vertex_project}/locations/{self._vertex_location}/{path}"

def _enforce_trailing_slash(self, url: URL) -> URL:
if url.raw_path.endswith(b"/"):
Expand Down
52 changes: 36 additions & 16 deletions google/genai/_interactions/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
RequestOptions,
not_given,
)
from ._utils import is_given
from ._utils import is_given, is_mapping_t
from ._compat import cached_property
from ._models import FinalRequestOptions
from ._version import __version__
Expand Down Expand Up @@ -116,6 +116,15 @@ def __init__(

self.client_adapter = client_adapter

custom_headers_env = os.environ.get("GEMINI_NEXT_GEN_API_CUSTOM_HEADERS")
if custom_headers_env is not None:
parsed: dict[str, str] = {}
for line in custom_headers_env.split("\n"):
colon = line.find(":")
if colon >= 0:
parsed[line[:colon].strip()] = line[colon + 1 :].strip()
default_headers = {**parsed, **(default_headers if is_mapping_t(default_headers) else {})}

super().__init__(
version=__version__,
base_url=base_url,
Expand Down Expand Up @@ -179,7 +188,11 @@ def default_headers(self) -> dict[str, str | Omit]:

@override
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit):
if (
headers.get("Authorization")
or custom_headers.get("Authorization")
or isinstance(custom_headers.get("Authorization"), Omit)
):
return
if self.api_key and headers.get("x-goog-api-key"):
return
Expand All @@ -189,25 +202,22 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
raise TypeError(
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
)

@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
return options

headers = options.headers or {}
has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error
has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error
if has_auth:
return options

adapted_headers = self.client_adapter.get_auth_headers()
if adapted_headers:
options.headers = {
**adapted_headers,
**headers
}
options.headers = {**adapted_headers, **headers}
return options

def copy(
self,
*,
Expand Down Expand Up @@ -350,6 +360,15 @@ def __init__(

self.client_adapter = client_adapter

custom_headers_env = os.environ.get("GEMINI_NEXT_GEN_API_CUSTOM_HEADERS")
if custom_headers_env is not None:
parsed: dict[str, str] = {}
for line in custom_headers_env.split("\n"):
colon = line.find(":")
if colon >= 0:
parsed[line[:colon].strip()] = line[colon + 1 :].strip()
default_headers = {**parsed, **(default_headers if is_mapping_t(default_headers) else {})}

super().__init__(
version=__version__,
base_url=base_url,
Expand Down Expand Up @@ -413,7 +432,11 @@ def default_headers(self) -> dict[str, str | Omit]:

@override
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
if headers.get("Authorization") or custom_headers.get("Authorization") or isinstance(custom_headers.get("Authorization"), Omit):
if (
headers.get("Authorization")
or custom_headers.get("Authorization")
or isinstance(custom_headers.get("Authorization"), Omit)
):
return
if self.api_key and headers.get("x-goog-api-key"):
return
Expand All @@ -423,23 +446,20 @@ def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
raise TypeError(
'"Could not resolve authentication method. Expected the api_key to be set. Or for the `x-goog-api-key` headers to be explicitly omitted"'
)

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
if not self.client_adapter or not self.client_adapter.is_vertex_ai():
return options

headers = options.headers or {}
has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error
has_auth = headers.get("Authorization") or headers.get("x-goog-api-key") # pytype: disable=attribute-error
if has_auth:
return options

adapted_headers = await self.client_adapter.async_get_auth_headers()
if adapted_headers:
options.headers = {
**adapted_headers,
**headers
}
options.headers = {**adapted_headers, **headers}
return options

def copy(
Expand Down
21 changes: 7 additions & 14 deletions google/genai/_interactions/_client_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,25 @@

from abc import ABC, abstractmethod

__all__ = [
"GeminiNextGenAPIClientAdapter",
"AsyncGeminiNextGenAPIClientAdapter"
]
__all__ = ["GeminiNextGenAPIClientAdapter", "AsyncGeminiNextGenAPIClientAdapter"]


class BaseGeminiNextGenAPIClientAdapter(ABC):
@abstractmethod
def is_vertex_ai(self) -> bool:
...
def is_vertex_ai(self) -> bool: ...

@abstractmethod
def get_project(self) -> str | None:
...
def get_project(self) -> str | None: ...

@abstractmethod
def get_location(self) -> str | None:
...
def get_location(self) -> str | None: ...


class AsyncGeminiNextGenAPIClientAdapter(BaseGeminiNextGenAPIClientAdapter):
@abstractmethod
async def async_get_auth_headers(self) -> dict[str, str] | None:
...
async def async_get_auth_headers(self) -> dict[str, str] | None: ...


class GeminiNextGenAPIClientAdapter(BaseGeminiNextGenAPIClientAdapter):
@abstractmethod
def get_auth_headers(self) -> dict[str, str] | None:
...
def get_auth_headers(self) -> dict[str, str] | None: ...
6 changes: 5 additions & 1 deletion google/genai/_interactions/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ def __init__(self, *, message: str = "Connection error.", request: httpx.Request

class APITimeoutError(APIConnectionError):
def __init__(self, request: httpx.Request) -> None:
super().__init__(message="Request timed out. This is a client-side timeout. You can increase the timeout by setting the `timeout` argument on your request or in the client http options.", request=request)
super().__init__(
message="Request timed out. This is a client-side timeout. You can increase the timeout by setting the `timeout` argument on your request or in the client http options.",
request=request,
)


class BadRequestError(APIStatusError):
status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]
Expand Down
2 changes: 1 addition & 1 deletion google/genai/_interactions/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ async def async_to_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles
elif is_sequence_t(files):
files = [(key, await _async_transform_file(file)) for key, file in files]
else:
raise TypeError("Unexpected file type input {type(files)}, expected mapping or sequence")
raise TypeError(f"Unexpected file type input {type(files)}, expected mapping or sequence")

return files

Expand Down
80 changes: 80 additions & 0 deletions google/genai/_interactions/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
ClassVar,
Protocol,
Required,
Annotated,
ParamSpec,
TypeAlias,
TypedDict,
TypeGuard,
final,
Expand Down Expand Up @@ -95,7 +97,15 @@
from ._constants import RAW_RESPONSE_HEADER

if TYPE_CHECKING:
from pydantic import GetCoreSchemaHandler, ValidatorFunctionWrapHandler
from pydantic_core import CoreSchema, core_schema
from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
else:
try:
from pydantic_core import CoreSchema, core_schema
except ImportError:
CoreSchema = None
core_schema = None

__all__ = ["BaseModel", "GenericModel"]

Expand Down Expand Up @@ -412,6 +422,76 @@ def model_dump_json(
)


class _EagerIterable(list[_T], Generic[_T]):
"""
Accepts any Iterable[T] input (including generators), consumes it
eagerly, and validates all items upfront.

Validation preserves the original container type where possible
(e.g. a set[T] stays a set[T]). Serialization (model_dump / JSON)
always emits a list — round-tripping through model_dump() will not
restore the original container type.
"""

@classmethod
def __get_pydantic_core_schema__(
cls,
source_type: Any,
handler: GetCoreSchemaHandler,
) -> CoreSchema:
(item_type,) = get_args(source_type) or (Any,)
item_schema: CoreSchema = handler.generate_schema(item_type)
list_of_items_schema: CoreSchema = core_schema.list_schema(item_schema)

return core_schema.no_info_wrap_validator_function(
cls._validate,
list_of_items_schema,
serialization=core_schema.plain_serializer_function_ser_schema(
cls._serialize,
info_arg=False,
),
)

@staticmethod
def _validate(v: Iterable[_T], handler: "ValidatorFunctionWrapHandler") -> Any:
original_type: type[Any] = type(v)

# Normalize to list so list_schema can validate each item
if isinstance(v, list):
items: list[_T] = v
else:
try:
items = list(v)
except TypeError as e:
raise TypeError("Value is not iterable") from e

# Validate items against the inner schema
validated: list[_T] = handler(items)

# Reconstruct original container type
if original_type is list:
return validated
# str(list) produces the list's repr, not a string built from items,
# so skip reconstruction for str and its subclasses.
if issubclass(original_type, str):
return validated
try:
return original_type(validated)
except (TypeError, ValueError):
# If the type cannot be reconstructed, just return the validated list
return validated

@staticmethod
def _serialize(v: Iterable[_T]) -> list[_T]:
"""Always serialize as a list so Pydantic's JSON encoder is happy."""
if isinstance(v, list):
return v
return list(v)


EagerIterable: TypeAlias = Annotated[Iterable[_T], _EagerIterable]


def _construct_field(value: object, field: FieldInfo, key: str) -> object:
if value is None:
return field_get_default(field)
Expand Down
Loading
Loading