diff --git a/.fern/metadata.json b/.fern/metadata.json index 3f69c9919..d8a581559 100644 --- a/.fern/metadata.json +++ b/.fern/metadata.json @@ -1,7 +1,7 @@ { "cliVersion": "5.37.7", "generatorName": "fernapi/fern-python-sdk", - "generatorVersion": "5.14.4", + "generatorVersion": "5.14.6", "generatorConfig": { "inline_request_params": false, "extras": { @@ -94,10 +94,10 @@ } ] }, - "originGitCommit": "6b200e3df80ccea85f9d69dc3c5572410454d8b7", + "originGitCommit": "2fca5765e1ee013584ac0e3caafee618cec00d7f", "originGitCommitIsDirty": true, "invokedBy": "ci", - "requestedVersion": "7.0.1", + "requestedVersion": "7.0.2", "ciProvider": "github", - "sdkVersion": "7.0.1" + "sdkVersion": "7.0.2" } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 28a78da98..a8812d780 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ dynamic = ["version"] [tool.poetry] name = "cohere" -version = "7.0.1" +version = "7.0.2" description = "" readme = "README.md" authors = [] diff --git a/src/cohere/core/client_wrapper.py b/src/cohere/core/client_wrapper.py index d7d1ef431..2aa50a0b6 100644 --- a/src/cohere/core/client_wrapper.py +++ b/src/cohere/core/client_wrapper.py @@ -31,12 +31,12 @@ def get_headers(self) -> typing.Dict[str, str]: import platform headers: typing.Dict[str, str] = { - "User-Agent": "cohere/7.0.1", + "User-Agent": "cohere/7.0.2", "X-Fern-Language": "Python", "X-Fern-Runtime": f"python/{platform.python_version()}", "X-Fern-Platform": f"{platform.system().lower()}/{platform.release()}", "X-Fern-SDK-Name": "cohere", - "X-Fern-SDK-Version": "7.0.1", + "X-Fern-SDK-Version": "7.0.2", **(self.get_custom_headers() or {}), } if self._client_name is not None: diff --git a/src/cohere/core/serialization.py b/src/cohere/core/serialization.py index c36e865cc..1d753e26f 100644 --- a/src/cohere/core/serialization.py +++ b/src/cohere/core/serialization.py @@ -26,6 +26,75 @@ def __init__(self, *, alias: str) -> None: self.alias = alias +# Resolving type hints (typing.get_type_hints) is expensive because it eval/compiles +# forward-reference annotations. The result is constant for a given type, so we cache it. +# This is critical for hot paths like SSE event parsing, where the same (often large +# discriminated-union) type is converted on every single event. +_type_hints_cache: typing.Dict[typing.Any, typing.Dict[str, typing.Any]] = {} + + +def _get_cached_type_hints(expected_type: typing.Any) -> typing.Dict[str, typing.Any]: + try: + cached = _type_hints_cache.get(expected_type) + except TypeError: + # Unhashable type; resolve without caching. + return _resolve_type_hints(expected_type) + if cached is None: + cached = _resolve_type_hints(expected_type) + _type_hints_cache[expected_type] = cached + return cached + + +def _resolve_type_hints(expected_type: typing.Any) -> typing.Dict[str, typing.Any]: + try: + return typing_extensions.get_type_hints(expected_type, include_extras=True) + except NameError: + # The type contains a circular reference, so we use the __annotations__ attribute directly. + return getattr(expected_type, "__annotations__", {}) + + +# Whether convert_and_respect_annotation_metadata can possibly rewrite anything for a given +# annotation, i.e. whether any reachable model/TypedDict field carries a FieldMetadata alias. +# This is constant per type, so we cache it and use it to short-circuit the recursive walk. +_requires_conversion_cache: typing.Dict[typing.Any, bool] = {} + + +def _requires_conversion(type_: typing.Any) -> bool: + try: + cached = _requires_conversion_cache.get(type_) + except TypeError: + # Unhashable annotation; compute without caching. + return _compute_requires_conversion(type_, set()) + if cached is None: + cached = _compute_requires_conversion(type_, set()) + _requires_conversion_cache[type_] = cached + return cached + + +def _compute_requires_conversion(type_: typing.Any, seen: typing.Set[typing.Any]) -> bool: + clean_type = _remove_annotations(type_) + + try: + if clean_type in seen: + return False + seen = seen | {clean_type} + except TypeError: + # Unhashable type; skip cycle tracking (the type graph is finite in practice). + pass + + # Models / TypedDicts: a field alias here means we must dealias; otherwise recurse into fields. + if (inspect.isclass(clean_type) and issubclass(clean_type, pydantic.BaseModel)) or typing_extensions.is_typeddict( + clean_type + ): + annotations = _get_cached_type_hints(clean_type) + if _get_alias_to_field_name(annotations): + return True + return any(_compute_requires_conversion(hint, seen) for hint in annotations.values()) + + # Containers / unions: recurse into the type arguments (List/Set/Sequence/Dict/Union/etc.). + return any(_compute_requires_conversion(arg, seen) for arg in typing_extensions.get_args(clean_type)) + + def convert_and_respect_annotation_metadata( *, object_: typing.Any, @@ -57,6 +126,13 @@ def convert_and_respect_annotation_metadata( return None if inner_type is None: inner_type = annotation + # The only thing this function ever rewrites is keys that carry a FieldMetadata + # alias. If nothing in the (cached) type graph has such an alias, the conversion is + # a content-identity transform, so we can skip the entire recursive walk. This is + # the hot path for SSE streaming, where a large discriminated union would otherwise + # be traversed on every single event. + if not _requires_conversion(annotation): + return object_ clean_type = _remove_annotations(inner_type) # Pydantic models @@ -160,12 +236,7 @@ def _convert_mapping( direction: typing.Literal["read", "write"], ) -> typing.Mapping[str, object]: converted_object: typing.Dict[str, object] = {} - try: - annotations = typing_extensions.get_type_hints(expected_type, include_extras=True) - except NameError: - # The TypedDict contains a circular reference, so - # we use the __annotations__ attribute directly. - annotations = getattr(expected_type, "__annotations__", {}) + annotations = _get_cached_type_hints(expected_type) aliases_to_field_names = _get_alias_to_field_name(annotations) for key, value in object_.items(): if direction == "read" and key in aliases_to_field_names: @@ -221,12 +292,12 @@ def _remove_annotations(type_: typing.Any) -> typing.Any: def get_alias_to_field_mapping(type_: typing.Any) -> typing.Dict[str, str]: - annotations = typing_extensions.get_type_hints(type_, include_extras=True) + annotations = _get_cached_type_hints(type_) return _get_alias_to_field_name(annotations) def get_field_to_alias_mapping(type_: typing.Any) -> typing.Dict[str, str]: - annotations = typing_extensions.get_type_hints(type_, include_extras=True) + annotations = _get_cached_type_hints(type_) return _get_field_to_alias_name(annotations)