def get_usage_header(
    config: Optional[types.GenerateContentConfigOrDict] = None,
    usage: str = 'afc',
) -> types.GenerateContentConfig:
  """Returns a config whose request headers carry the SDK usage label.

  The label ``google-genai-sdk/{usage}`` is added to the ``user-agent`` and
  ``x-goog-api-client`` entries of ``config.http_options.headers``.

  The label is appended as its own whitespace-delimited token and only when it
  is not already present. This helper is invoked repeatedly on the same config
  (for example on every loop iteration of a request, or on every message sent
  with a stored config), so a plain ``+=`` would both glue the label onto the
  previous token and make the header grow on every call.

  The input is never modified: the config, its ``http_options`` and the headers
  dict are copied before the label is added.

  Args:
    config: The request config, as a ``GenerateContentConfig`` or a dict of its
      fields. ``None`` or an empty value yields a fresh default config.
    usage: Suffix of the usage label, e.g. ``'afc'`` or ``'chat'``.

  Returns:
    A ``GenerateContentConfig`` whose ``http_options.headers`` contain the
    usage label in both ``user-agent`` and ``x-goog-api-client``.
  """
  usage_header = f'google-genai-sdk/{usage}'
  if not config:
    config_model = types.GenerateContentConfig()
  elif isinstance(config, dict):
    config_model = types.GenerateContentConfig(**config)
  else:
    # Shallow copy so the caller's config object is left untouched.
    config_model = config.model_copy()

  # Copy the nested options and headers as well: they may still be shared with
  # the caller's object (or with a previous result of this function).
  if config_model.http_options:
    http_options = config_model.http_options.model_copy()
  else:
    http_options = types.HttpOptions()
  headers = dict(http_options.headers or {})

  for header_name in ('user-agent', 'x-goog-api-client'):
    existing_value = headers.get(header_name)
    if not existing_value:
      headers[header_name] = usage_header
    elif usage_header not in existing_value.split():
      # Product tokens in these headers are separated by whitespace.
      headers[header_name] = f'{existing_value} {usage_header}'
    # Otherwise the label is already present; keep the header unchanged.

  http_options.headers = headers
  config_model.http_options = http_options
  return config_model
str]) -> dict[str, str]: if header_name.lower() == 'x-goog-api-key': redacted_headers[header_name] = '{REDACTED}' elif header_name.lower() == 'user-agent': - redacted_headers[header_name] = _redact_language_label( - _redact_version_numbers(header_value) + redacted_headers[header_name] = _redact_sdk_usage_label( + _redact_language_label( + _redact_version_numbers(header_value) + ) ).replace('agentplatform-genai-modules', 'vertex-genai-modules') elif header_name.lower() == 'x-goog-api-client': - redacted_headers[header_name] = _redact_language_label( - _redact_version_numbers(header_value) - ).replace('agentplatform-genai-modules', 'vertex-genai-modules') + redacted_headers[header_name] = _redact_sdk_usage_label( + _redact_language_label( + _redact_version_numbers(header_value) + ) + ).replace( + 'agentplatform-genai-modules', 'vertex-genai-modules' + ) elif header_name.lower() == 'x-goog-user-project': continue elif header_name.lower() == 'authorization': diff --git a/google/genai/chats.py b/google/genai/chats.py index 3d5e181e9..5a70a6039 100644 --- a/google/genai/chats.py +++ b/google/genai/chats.py @@ -17,513 +17,533 @@ import sys from typing import AsyncIterator, Awaitable, Optional, Union, get_args +from . import _extra_utils from . import _transformers as t from . 
import types from .models import AsyncModels, Models -from .types import Content, ContentOrDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict +from .types import ( + Content, + ContentOrDict, + GenerateContentConfigOrDict, + GenerateContentResponse, + Part, + PartUnionDict, +) if sys.version_info >= (3, 10): - from typing import TypeGuard + from typing import TypeGuard else: - from typing_extensions import TypeGuard + from typing_extensions import TypeGuard def _validate_content(content: Content) -> bool: - if not content.parts: - return False - for part in content.parts: - if part == Part(): - return False - return True + if not content.parts: + return False + for part in content.parts: + if part == Part(): + return False + return True def _validate_contents(contents: list[Content]) -> bool: - if not contents: - return False - for content in contents: - if not _validate_content(content): - return False - return True + if not contents: + return False + for content in contents: + if not _validate_content(content): + return False + return True def _validate_response(response: GenerateContentResponse) -> bool: - if not response.candidates: - return False - if not response.candidates[0].content: - return False - return _validate_content(response.candidates[0].content) + if not response.candidates: + return False + if not response.candidates[0].content: + return False + return _validate_content(response.candidates[0].content) def _extract_curated_history( comprehensive_history: list[Content], ) -> list[Content]: - """Extracts the curated (valid) history from a comprehensive history. - - The comprehensive history contains all turns (user input and model responses), - including any invalid or rejected model outputs. This function filters that - history to return only the valid turns. - - Args: - comprehensive_history: A list representing the complete chat history. - Including invalid turns. 
- - Returns: - curated history, which is a list of valid turns. - """ - if not comprehensive_history: - return [] - curated_history = [] - length = len(comprehensive_history) - i = 0 - current_input = comprehensive_history[i] - while i < length: - if comprehensive_history[i].role not in ["user", "model"]: - raise ValueError( - f"Role must be user or model, but got {comprehensive_history[i].role}" - ) - - if comprehensive_history[i].role == "user": - current_input = comprehensive_history[i] - curated_history.append(current_input) - i += 1 - else: - current_output = [] - is_valid = True - while i < length and comprehensive_history[i].role == "model": - current_output.append(comprehensive_history[i]) - if is_valid and not _validate_content(comprehensive_history[i]): - is_valid = False - i += 1 - if is_valid: - curated_history.extend(current_output) - elif curated_history: - curated_history.pop() - return curated_history - - -class _BaseChat: - """Base chat session.""" - - def __init__( - self, - *, - model: str, - config: Optional[GenerateContentConfigOrDict] = None, - history: list[ContentOrDict], - ): - self._model = model - self._config = config - content_models = [] - for content in history: - if not isinstance(content, Content): - content_model = Content.model_validate(content) - else: - content_model = content - content_models.append(content_model) - self._comprehensive_history = content_models - """Comprehensive history is the full history of the chat, including turns of the invalid contents from the model and their associated inputs. - """ - self._curated_history = _extract_curated_history(content_models) - """Curated history is the set of valid turns that will be used in the subsequent send requests. - """ + """Extracts the curated (valid) history from a comprehensive history. 
- def record_history( - self, - user_input: Content, - model_output: list[Content], - automatic_function_calling_history: list[Content], - is_valid: bool, - ) -> None: - """Records the chat history. - - Maintaining both comprehensive and curated histories. + The comprehensive history contains all turns (user input and model responses), + including any invalid or rejected model outputs. This function filters that + history to return only the valid turns. Args: - user_input: The user's input content. - model_output: A list of `Content` from the model's response. This can be - an empty list if the model produced no output. - automatic_function_calling_history: A list of `Content` representing the - history of automatic function calls, including the user input as the - first entry. - is_valid: A boolean flag indicating whether the current model output is - considered valid. + comprehensive_history: A list representing the complete chat history. + Including invalid turns. + + Returns: + curated history, which is a list of valid turns. """ - input_contents = ( - # Because the AFC input contains the entire curated chat history in - # addition to the new user input, we need to truncate the AFC history - # to deduplicate the existing chat history. - automatic_function_calling_history[len(self._curated_history) :] - if automatic_function_calling_history - else [user_input] - ) - # Appends an empty content when model returns empty response, so that the - # history is always alternating between user and model. - output_contents = ( - model_output if model_output else [Content(role="model", parts=[])] - ) - self._comprehensive_history.extend(input_contents) - self._comprehensive_history.extend(output_contents) - if is_valid: - self._curated_history.extend(input_contents) - self._curated_history.extend(output_contents) - - def get_history(self, curated: bool = False) -> list[Content]: - """Returns the chat history. 
+ if not comprehensive_history: + return [] + curated_history = [] + length = len(comprehensive_history) + i = 0 + current_input = comprehensive_history[i] + while i < length: + if comprehensive_history[i].role not in ["user", "model"]: + raise ValueError( + f"Role must be user or model, but got {comprehensive_history[i].role}" + ) + + if comprehensive_history[i].role == "user": + current_input = comprehensive_history[i] + curated_history.append(current_input) + i += 1 + else: + current_output = [] + is_valid = True + while i < length and comprehensive_history[i].role == "model": + current_output.append(comprehensive_history[i]) + if is_valid and not _validate_content(comprehensive_history[i]): + is_valid = False + i += 1 + if is_valid: + curated_history.extend(current_output) + elif curated_history: + curated_history.pop() + return curated_history - Args: - curated: A boolean flag indicating whether to return the curated (valid) - history or the comprehensive (all turns) history. Defaults to False - (returns the comprehensive history). - Returns: - A list of `Content` objects representing the chat history. +class _BaseChat: + """Base chat session.""" + + def __init__( + self, + *, + model: str, + config: Optional[GenerateContentConfigOrDict] = None, + history: list[ContentOrDict], + ): + self._model = model + self._config = _extra_utils.get_usage_header(config, usage="chat") + content_models = [] + for content in history: + if not isinstance(content, Content): + content_model = Content.model_validate(content) + else: + content_model = content + content_models.append(content_model) + self._comprehensive_history = content_models + """Comprehensive history is the full history of the chat, including turns of the invalid contents from the model and their associated inputs. + """ + self._curated_history = _extract_curated_history(content_models) + """Curated history is the set of valid turns that will be used in the subsequent send requests. 
""" - if curated: - return self._curated_history - else: - return self._comprehensive_history + + def record_history( + self, + user_input: Content, + model_output: list[Content], + automatic_function_calling_history: list[Content], + is_valid: bool, + ) -> None: + """Records the chat history. + + Maintaining both comprehensive and curated histories. + + Args: + user_input: The user's input content. + model_output: A list of `Content` from the model's response. This can be + an empty list if the model produced no output. + automatic_function_calling_history: A list of `Content` representing the + history of automatic function calls, including the user input as the + first entry. + is_valid: A boolean flag indicating whether the current model output is + considered valid. + """ + input_contents = ( + # Because the AFC input contains the entire curated chat history in + # addition to the new user input, we need to truncate the AFC history + # to deduplicate the existing chat history. + automatic_function_calling_history[len(self._curated_history) :] + if automatic_function_calling_history + else [user_input] + ) + # Appends an empty content when model returns empty response, so that the + # history is always alternating between user and model. + output_contents = ( + model_output if model_output else [Content(role="model", parts=[])] + ) + self._comprehensive_history.extend(input_contents) + self._comprehensive_history.extend(output_contents) + if is_valid: + self._curated_history.extend(input_contents) + self._curated_history.extend(output_contents) + + def get_history(self, curated: bool = False) -> list[Content]: + """Returns the chat history. + + Args: + curated: A boolean flag indicating whether to return the curated (valid) + history or the comprehensive (all turns) history. Defaults to False + (returns the comprehensive history). + + Returns: + A list of `Content` objects representing the chat history. 
+ """ + if curated: + return self._curated_history + else: + return self._comprehensive_history def _is_part_type( contents: Union[list[PartUnionDict], PartUnionDict], ) -> TypeGuard[t.ContentType]: - if isinstance(contents, list): - return all(_is_part_type(part) for part in contents) - else: - allowed_part_types = get_args(types.PartUnion) - if type(contents) in allowed_part_types: - return True + if isinstance(contents, list): + return all(_is_part_type(part) for part in contents) else: - # Some images don't pass isinstance(item, PIL.Image.Image) - # For example - if types.PIL_Image is not None and isinstance(contents, types.PIL_Image): - return True - return False + allowed_part_types = get_args(types.PartUnion) + if type(contents) in allowed_part_types: + return True + else: + # Some images don't pass isinstance(item, PIL.Image.Image) + # For example + if types.PIL_Image is not None and isinstance(contents, types.PIL_Image): + return True + return False class Chat(_BaseChat): - """Chat session.""" - - def __init__( - self, - *, - modules: Models, - model: str, - config: Optional[GenerateContentConfigOrDict] = None, - history: list[ContentOrDict], - ): - self._modules = modules - super().__init__( - model=model, - config=config, - history=history, - ) - - def send_message( - self, - message: Union[list[PartUnionDict], PartUnionDict], - config: Optional[GenerateContentConfigOrDict] = None, - ) -> GenerateContentResponse: - """Sends the conversation history with the additional message and returns the model's response. - - Args: - message: The message to send to the model. - config: Optional config to override the default Chat config for this - request. - - Returns: - The model's response. - - Usage: - - .. 
code-block:: python - - chat = client.chats.create(model='gemini-2.0-flash') - response = chat.send_message('tell me a story') - """ - - if not _is_part_type(message): - raise ValueError( - f"Message must be a valid part type: {types.PartUnion} or" - f" {types.PartUnionDict}, got {type(message)}" - ) - input_content = t.t_content(message) - response = self._modules.generate_content( - model=self._model, - contents=self._curated_history + [input_content], # type: ignore[arg-type] - config=config if config else self._config, - ) - model_output = ( - [response.candidates[0].content] - if response.candidates and response.candidates[0].content - else [] - ) - automatic_function_calling_history = ( - response.automatic_function_calling_history - if response.automatic_function_calling_history - else [] - ) - self.record_history( - user_input=input_content, - model_output=model_output, - automatic_function_calling_history=automatic_function_calling_history, - is_valid=_validate_response(response), - ) - return response - - def send_message_stream( - self, - message: Union[list[PartUnionDict], PartUnionDict], - config: Optional[GenerateContentConfigOrDict] = None, - ) -> Iterator[GenerateContentResponse]: - """Sends the conversation history with the additional message and yields the model's response in chunks. - - Args: - message: The message to send to the model. - config: Optional config to override the default Chat config for this - request. - - Yields: - The model's response in chunks. - - Usage: - - .. 
code-block:: python - - chat = client.chats.create(model='gemini-2.0-flash') - for chunk in chat.send_message_stream('tell me a story'): - print(chunk.text) - """ - - if not _is_part_type(message): - raise ValueError( - f"Message must be a valid part type: {types.PartUnion} or" - f" {types.PartUnionDict}, got {type(message)}" - ) - input_content = t.t_content(message) - output_contents = [] - finish_reason = None - is_valid = True - chunk = None - if isinstance(self._modules, Models): - for chunk in self._modules.generate_content_stream( - model=self._model, - contents=self._curated_history + [input_content], # type: ignore[arg-type] - config=config if config else self._config, - ): - if not _validate_response(chunk): - is_valid = False - if chunk.candidates and chunk.candidates[0].content: - output_contents.append(chunk.candidates[0].content) - if chunk.candidates and chunk.candidates[0].finish_reason: - finish_reason = chunk.candidates[0].finish_reason - yield chunk - automatic_function_calling_history = ( - chunk.automatic_function_calling_history - if chunk is not None and chunk.automatic_function_calling_history - else [] - ) - self.record_history( - user_input=input_content, - model_output=output_contents, - automatic_function_calling_history=automatic_function_calling_history, - is_valid=is_valid - and output_contents is not None - and finish_reason is not None, - ) + """Chat session.""" + + def __init__( + self, + *, + modules: Models, + model: str, + config: Optional[GenerateContentConfigOrDict] = None, + history: list[ContentOrDict], + ): + self._modules = modules + super().__init__( + model=model, + config=config, + history=history, + ) + + def send_message( + self, + message: Union[list[PartUnionDict], PartUnionDict], + config: Optional[GenerateContentConfigOrDict] = None, + ) -> GenerateContentResponse: + """Sends the conversation history with the additional message and returns the model's response. 
+ + Args: + message: The message to send to the model. + config: Optional config to override the default Chat config for this + request. + + Returns: + The model's response. + + Usage: + + .. code-block:: python + + chat = client.chats.create(model='gemini-2.0-flash') + response = chat.send_message('tell me a story') + """ + + if not _is_part_type(message): + raise ValueError( + f"Message must be a valid part type: {types.PartUnion} or" + f" {types.PartUnionDict}, got {type(message)}" + ) + input_content = t.t_content(message) + method_config = config if config else self._config + method_config = _extra_utils.get_usage_header(method_config, usage="chat") + response = self._modules.generate_content( + model=self._model, + contents=self._curated_history + [input_content], # type: ignore[arg-type] + config=method_config, + ) + model_output = ( + [response.candidates[0].content] + if response.candidates and response.candidates[0].content + else [] + ) + automatic_function_calling_history = ( + response.automatic_function_calling_history + if response.automatic_function_calling_history + else [] + ) + self.record_history( + user_input=input_content, + model_output=model_output, + automatic_function_calling_history=automatic_function_calling_history, + is_valid=_validate_response(response), + ) + return response + + def send_message_stream( + self, + message: Union[list[PartUnionDict], PartUnionDict], + config: Optional[GenerateContentConfigOrDict] = None, + ) -> Iterator[GenerateContentResponse]: + """Sends the conversation history with the additional message and yields the model's response in chunks. + + Args: + message: The message to send to the model. + config: Optional config to override the default Chat config for this + request. + + Yields: + The model's response in chunks. + + Usage: + + .. 
code-block:: python + + chat = client.chats.create(model='gemini-2.0-flash') + for chunk in chat.send_message_stream('tell me a story'): + print(chunk.text) + """ + + if not _is_part_type(message): + raise ValueError( + f"Message must be a valid part type: {types.PartUnion} or" + f" {types.PartUnionDict}, got {type(message)}" + ) + input_content = t.t_content(message) + output_contents = [] + finish_reason = None + is_valid = True + chunk = None + method_config = config if config else self._config + method_config = _extra_utils.get_usage_header(method_config, usage="chat") + if isinstance(self._modules, Models): + for chunk in self._modules.generate_content_stream( + model=self._model, + contents=self._curated_history + [input_content], # type: ignore[arg-type] + config=method_config, + ): + if not _validate_response(chunk): + is_valid = False + if chunk.candidates and chunk.candidates[0].content: + output_contents.append(chunk.candidates[0].content) + if chunk.candidates and chunk.candidates[0].finish_reason: + finish_reason = chunk.candidates[0].finish_reason + yield chunk + automatic_function_calling_history = ( + chunk.automatic_function_calling_history + if chunk is not None and chunk.automatic_function_calling_history + else [] + ) + self.record_history( + user_input=input_content, + model_output=output_contents, + automatic_function_calling_history=automatic_function_calling_history, + is_valid=is_valid + and output_contents is not None + and finish_reason is not None, + ) class Chats: - """A util class to create chat sessions.""" - - def __init__(self, modules: Models): - self._modules = modules - - def create( - self, - *, - model: str, - config: Optional[GenerateContentConfigOrDict] = None, - history: Optional[list[ContentOrDict]] = None, - ) -> Chat: - """Creates a new chat session. - - Args: - model: The model to use for the chat. - config: The configuration to use for the generate content request. - history: The history to use for the chat. 
- - Returns: - A new chat session. - """ - return Chat( - modules=self._modules, - model=model, - config=config, - history=history if history else [], - ) + """A util class to create chat sessions.""" + + def __init__(self, modules: Models): + self._modules = modules + + def create( + self, + *, + model: str, + config: Optional[GenerateContentConfigOrDict] = None, + history: Optional[list[ContentOrDict]] = None, + ) -> Chat: + """Creates a new chat session. + + Args: + model: The model to use for the chat. + config: The configuration to use for the generate content request. + history: The history to use for the chat. + + Returns: + A new chat session. + """ + return Chat( + modules=self._modules, + model=model, + config=config, + history=history if history else [], + ) class AsyncChat(_BaseChat): - """Async chat session.""" - - def __init__( - self, - *, - modules: AsyncModels, - model: str, - config: Optional[GenerateContentConfigOrDict] = None, - history: list[ContentOrDict], - ): - self._modules = modules - super().__init__( - model=model, - config=config, - history=history, - ) - - async def send_message( - self, - message: Union[list[PartUnionDict], PartUnionDict], - config: Optional[GenerateContentConfigOrDict] = None, - ) -> GenerateContentResponse: - """Sends the conversation history with the additional message and returns model's response. - - Args: - message: The message to send to the model. - config: Optional config to override the default Chat config for this - request. - - Returns: - The model's response. - - Usage: - - .. 
code-block:: python - - chat = client.aio.chats.create(model='gemini-2.0-flash') - response = await chat.send_message('tell me a story') - """ - if not _is_part_type(message): - raise ValueError( - f"Message must be a valid part type: {types.PartUnion} or" - f" {types.PartUnionDict}, got {type(message)}" - ) - input_content = t.t_content(message) - response = await self._modules.generate_content( - model=self._model, - contents=self._curated_history + [input_content], # type: ignore[arg-type] - config=config if config else self._config, - ) - model_output = ( - [response.candidates[0].content] - if response.candidates and response.candidates[0].content - else [] - ) - automatic_function_calling_history = ( - response.automatic_function_calling_history - if response.automatic_function_calling_history - else [] - ) - self.record_history( - user_input=input_content, - model_output=model_output, - automatic_function_calling_history=automatic_function_calling_history, - is_valid=_validate_response(response), - ) - return response - - async def send_message_stream( - self, - message: Union[list[PartUnionDict], PartUnionDict], - config: Optional[GenerateContentConfigOrDict] = None, - ) -> AsyncIterator[GenerateContentResponse]: - """Sends the conversation history with the additional message and yields the model's response in chunks. - - Args: - message: The message to send to the model. - config: Optional config to override the default Chat config for this - request. - - Yields: - The model's response in chunks. - - Usage: - - .. 
code-block:: python - - chat = client.aio.chats.create(model='gemini-2.0-flash') - async for chunk in await chat.send_message_stream('tell me a story'): - print(chunk.text) - """ - - if not _is_part_type(message): - raise ValueError( - f"Message must be a valid part type: {types.PartUnion} or" - f" {types.PartUnionDict}, got {type(message)}" - ) - input_content = t.t_content(message) - - async def async_generator(): # type: ignore[no-untyped-def] - output_contents = [] - finish_reason = None - is_valid = True - chunk = None - async for chunk in await self._modules.generate_content_stream( # type: ignore[attr-defined] - model=self._model, - contents=self._curated_history + [input_content], # type: ignore[arg-type] - config=config if config else self._config, - ): - if not _validate_response(chunk): - is_valid = False - if chunk.candidates and chunk.candidates[0].content: - output_contents.append(chunk.candidates[0].content) - if chunk.candidates and chunk.candidates[0].finish_reason: - finish_reason = chunk.candidates[0].finish_reason - yield chunk - - if not output_contents or finish_reason is None: - is_valid = False - - self.record_history( - user_input=input_content, - model_output=output_contents, - automatic_function_calling_history=chunk.automatic_function_calling_history - if chunk is not None and chunk.automatic_function_calling_history - else [], - is_valid=is_valid, - ) - - return async_generator() # type: ignore[no-untyped-call, no-any-return] + """Async chat session.""" + + def __init__( + self, + *, + modules: AsyncModels, + model: str, + config: Optional[GenerateContentConfigOrDict] = None, + history: list[ContentOrDict], + ): + self._modules = modules + super().__init__( + model=model, + config=config, + history=history, + ) + + async def send_message( + self, + message: Union[list[PartUnionDict], PartUnionDict], + config: Optional[GenerateContentConfigOrDict] = None, + ) -> GenerateContentResponse: + """Sends the conversation history with the 
additional message and returns model's response. + + Args: + message: The message to send to the model. + config: Optional config to override the default Chat config for this + request. + + Returns: + The model's response. + + Usage: + + .. code-block:: python + + chat = client.aio.chats.create(model='gemini-2.0-flash') + response = await chat.send_message('tell me a story') + """ + if not _is_part_type(message): + raise ValueError( + f"Message must be a valid part type: {types.PartUnion} or" + f" {types.PartUnionDict}, got {type(message)}" + ) + input_content = t.t_content(message) + method_config = config if config else self._config + method_config = _extra_utils.get_usage_header(method_config, usage="chat") + response = await self._modules.generate_content( + model=self._model, + contents=self._curated_history + [input_content], # type: ignore[arg-type] + config=method_config, + ) + model_output = ( + [response.candidates[0].content] + if response.candidates and response.candidates[0].content + else [] + ) + automatic_function_calling_history = ( + response.automatic_function_calling_history + if response.automatic_function_calling_history + else [] + ) + self.record_history( + user_input=input_content, + model_output=model_output, + automatic_function_calling_history=automatic_function_calling_history, + is_valid=_validate_response(response), + ) + return response + + async def send_message_stream( + self, + message: Union[list[PartUnionDict], PartUnionDict], + config: Optional[GenerateContentConfigOrDict] = None, + ) -> AsyncIterator[GenerateContentResponse]: + """Sends the conversation history with the additional message and yields the model's response in chunks. + + Args: + message: The message to send to the model. + config: Optional config to override the default Chat config for this + request. + + Yields: + The model's response in chunks. + + Usage: + + .. 
code-block:: python + + chat = client.aio.chats.create(model='gemini-2.0-flash') + async for chunk in await chat.send_message_stream('tell me a story'): + print(chunk.text) + """ + + if not _is_part_type(message): + raise ValueError( + f"Message must be a valid part type: {types.PartUnion} or" + f" {types.PartUnionDict}, got {type(message)}" + ) + input_content = t.t_content(message) + method_config = config if config else self._config + method_config = _extra_utils.get_usage_header( + method_config, usage="chat" + ) + + async def async_generator(): # type: ignore[no-untyped-def] + output_contents = [] + finish_reason = None + is_valid = True + chunk = None + async for chunk in await self._modules.generate_content_stream( # type: ignore[attr-defined] + model=self._model, + contents=self._curated_history + [input_content], # type: ignore[arg-type] + config=method_config, + ): + if not _validate_response(chunk): + is_valid = False + if chunk.candidates and chunk.candidates[0].content: + output_contents.append(chunk.candidates[0].content) + if chunk.candidates and chunk.candidates[0].finish_reason: + finish_reason = chunk.candidates[0].finish_reason + yield chunk + + if not output_contents or finish_reason is None: + is_valid = False + + self.record_history( + user_input=input_content, + model_output=output_contents, + automatic_function_calling_history=( + chunk.automatic_function_calling_history + if chunk is not None and chunk.automatic_function_calling_history + else [] + ), + is_valid=is_valid, + ) + + return async_generator() # type: ignore[no-untyped-call, no-any-return] class AsyncChats: - """A util class to create async chat sessions.""" - - def __init__(self, modules: AsyncModels): - self._modules = modules - - def create( - self, - *, - model: str, - config: Optional[GenerateContentConfigOrDict] = None, - history: Optional[list[ContentOrDict]] = None, - ) -> AsyncChat: - """Creates a new chat session. - - Args: - model: The model to use for the chat. 
- config: The configuration to use for the generate content request. - history: The history to use for the chat. - - Returns: - A new chat session. - """ - return AsyncChat( - modules=self._modules, - model=model, - config=config, - history=history if history else [], - ) + """A util class to create async chat sessions.""" + + def __init__(self, modules: AsyncModels): + self._modules = modules + + def create( + self, + *, + model: str, + config: Optional[GenerateContentConfigOrDict] = None, + history: Optional[list[ContentOrDict]] = None, + ) -> AsyncChat: + """Creates a new chat session. + + Args: + model: The model to use for the chat. + config: The configuration to use for the generate content request. + history: The history to use for the chat. + + Returns: + A new chat session. + """ + return AsyncChat( + modules=self._modules, + model=model, + config=config, + history=history if history else [], + ) diff --git a/google/genai/models.py b/google/genai/models.py index 547132546..b4f9efde6 100644 --- a/google/genai/models.py +++ b/google/genai/models.py @@ -6476,6 +6476,7 @@ def generate_content( response = types.GenerateContentResponse() i = 0 while remaining_remote_calls_afc > 0: + parsed_config = _extra_utils.get_usage_header(parsed_config) i += 1 response = self._generate_content( model=model, contents=contents, config=parsed_config @@ -6644,6 +6645,7 @@ def generate_content_stream( func_response_parts = None i = 0 while remaining_remote_calls_afc > 0: + parsed_config = _extra_utils.get_usage_header(parsed_config) i += 1 response = self._generate_content_stream( model=model, contents=contents, config=parsed_config @@ -8639,6 +8641,7 @@ async def generate_content( response = types.GenerateContentResponse() while remaining_remote_calls_afc > 0: + final_parsed_config = _extra_utils.get_usage_header(final_parsed_config) response = await self._generate_content( model=model, contents=contents, config=final_parsed_config ) @@ -8834,6 +8837,7 @@ async def 
async_generator(model, contents, config): # type: ignore[no-untyped-d chunk = None i = 0 while remaining_remote_calls_afc > 0: + config = _extra_utils.get_usage_header(config) i += 1 response = await self._generate_content_stream( model=model, contents=contents, config=config