Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .fern/metadata.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"cliVersion": "5.37.7",
"generatorName": "fernapi/fern-python-sdk",
"generatorVersion": "5.14.4",
"generatorVersion": "5.14.6",
"generatorConfig": {
"inline_request_params": false,
"extras": {
Expand Down Expand Up @@ -94,10 +94,10 @@
}
]
},
"originGitCommit": "6b200e3df80ccea85f9d69dc3c5572410454d8b7",
"originGitCommit": "2fca5765e1ee013584ac0e3caafee618cec00d7f",
"originGitCommitIsDirty": true,
"invokedBy": "ci",
"requestedVersion": "7.0.1",
"requestedVersion": "7.0.2",
"ciProvider": "github",
"sdkVersion": "7.0.1"
"sdkVersion": "7.0.2"
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ dynamic = ["version"]

[tool.poetry]
name = "cohere"
version = "7.0.1"
version = "7.0.2"
description = ""
readme = "README.md"
authors = []
Expand Down
4 changes: 2 additions & 2 deletions src/cohere/core/client_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def get_headers(self) -> typing.Dict[str, str]:
import platform

headers: typing.Dict[str, str] = {
"User-Agent": "cohere/7.0.1",
"User-Agent": "cohere/7.0.2",
"X-Fern-Language": "Python",
"X-Fern-Runtime": f"python/{platform.python_version()}",
"X-Fern-Platform": f"{platform.system().lower()}/{platform.release()}",
"X-Fern-SDK-Name": "cohere",
"X-Fern-SDK-Version": "7.0.1",
"X-Fern-SDK-Version": "7.0.2",
**(self.get_custom_headers() or {}),
}
if self._client_name is not None:
Expand Down
87 changes: 79 additions & 8 deletions src/cohere/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,75 @@ def __init__(self, *, alias: str) -> None:
self.alias = alias


# Resolving type hints (typing.get_type_hints) is expensive because it eval/compiles
# forward-reference annotations. The result is constant for a given type, so we cache it.
# This is critical for hot paths like SSE event parsing, where the same (often large
# discriminated-union) type is converted on every single event.
_type_hints_cache: typing.Dict[typing.Any, typing.Dict[str, typing.Any]] = {}


def _get_cached_type_hints(expected_type: typing.Any) -> typing.Dict[str, typing.Any]:
try:
cached = _type_hints_cache.get(expected_type)
except TypeError:
# Unhashable type; resolve without caching.
return _resolve_type_hints(expected_type)
if cached is None:
cached = _resolve_type_hints(expected_type)
_type_hints_cache[expected_type] = cached
return cached


def _resolve_type_hints(expected_type: typing.Any) -> typing.Dict[str, typing.Any]:
try:
return typing_extensions.get_type_hints(expected_type, include_extras=True)
except NameError:
# The type contains a circular reference, so we use the __annotations__ attribute directly.
return getattr(expected_type, "__annotations__", {})


# Whether convert_and_respect_annotation_metadata can possibly rewrite anything for a given
# annotation, i.e. whether any reachable model/TypedDict field carries a FieldMetadata alias.
# This is constant per type, so we cache it and use it to short-circuit the recursive walk.
_requires_conversion_cache: typing.Dict[typing.Any, bool] = {}


def _requires_conversion(type_: typing.Any) -> bool:
try:
cached = _requires_conversion_cache.get(type_)
except TypeError:
# Unhashable annotation; compute without caching.
return _compute_requires_conversion(type_, set())
if cached is None:
cached = _compute_requires_conversion(type_, set())
_requires_conversion_cache[type_] = cached
return cached


def _compute_requires_conversion(type_: typing.Any, seen: typing.Set[typing.Any]) -> bool:
clean_type = _remove_annotations(type_)

try:
if clean_type in seen:
return False
seen = seen | {clean_type}
except TypeError:
# Unhashable type; skip cycle tracking (the type graph is finite in practice).
pass

# Models / TypedDicts: a field alias here means we must dealias; otherwise recurse into fields.
if (inspect.isclass(clean_type) and issubclass(clean_type, pydantic.BaseModel)) or typing_extensions.is_typeddict(
clean_type
):
annotations = _get_cached_type_hints(clean_type)
if _get_alias_to_field_name(annotations):
return True
return any(_compute_requires_conversion(hint, seen) for hint in annotations.values())

# Containers / unions: recurse into the type arguments (List/Set/Sequence/Dict/Union/etc.).
return any(_compute_requires_conversion(arg, seen) for arg in typing_extensions.get_args(clean_type))


def convert_and_respect_annotation_metadata(
*,
object_: typing.Any,
Expand Down Expand Up @@ -57,6 +126,13 @@ def convert_and_respect_annotation_metadata(
return None
if inner_type is None:
inner_type = annotation
# The only thing this function ever rewrites is keys that carry a FieldMetadata
# alias. If nothing in the (cached) type graph has such an alias, the conversion is
# a content-identity transform, so we can skip the entire recursive walk. This is
# the hot path for SSE streaming, where a large discriminated union would otherwise
# be traversed on every single event.
if not _requires_conversion(annotation):
return object_

clean_type = _remove_annotations(inner_type)
# Pydantic models
Expand Down Expand Up @@ -160,12 +236,7 @@ def _convert_mapping(
direction: typing.Literal["read", "write"],
) -> typing.Mapping[str, object]:
converted_object: typing.Dict[str, object] = {}
try:
annotations = typing_extensions.get_type_hints(expected_type, include_extras=True)
except NameError:
# The TypedDict contains a circular reference, so
# we use the __annotations__ attribute directly.
annotations = getattr(expected_type, "__annotations__", {})
annotations = _get_cached_type_hints(expected_type)
aliases_to_field_names = _get_alias_to_field_name(annotations)
for key, value in object_.items():
if direction == "read" and key in aliases_to_field_names:
Expand Down Expand Up @@ -221,12 +292,12 @@ def _remove_annotations(type_: typing.Any) -> typing.Any:


def get_alias_to_field_mapping(type_: typing.Any) -> typing.Dict[str, str]:
annotations = typing_extensions.get_type_hints(type_, include_extras=True)
annotations = _get_cached_type_hints(type_)
return _get_alias_to_field_name(annotations)


def get_field_to_alias_mapping(type_: typing.Any) -> typing.Dict[str, str]:
annotations = typing_extensions.get_type_hints(type_, include_extras=True)
annotations = _get_cached_type_hints(type_)
return _get_field_to_alias_name(annotations)


Expand Down
Loading