diff --git a/src/opengradient/abi/TEERegistry.abi b/src/opengradient/abi/TEERegistry.abi index 51fe5b3b..5de0d5ef 100644 --- a/src/opengradient/abi/TEERegistry.abi +++ b/src/opengradient/abi/TEERegistry.abi @@ -1,80 +1,304 @@ [ - { - "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], - "name": "getActiveTEEs", - "outputs": [ - { - "components": [ - {"internalType": "address", "name": "owner", "type": "address"}, - {"internalType": "address", "name": "paymentAddress", "type": "address"}, - {"internalType": "string", "name": "endpoint", "type": "string"}, - {"internalType": "bytes", "name": "publicKey", "type": "bytes"}, - {"internalType": "bytes", "name": "tlsCertificate", "type": "bytes"}, - {"internalType": "bytes32", "name": "pcrHash", "type": "bytes32"}, - {"internalType": "uint8", "name": "teeType", "type": "uint8"}, - {"internalType": "bool", "name": "enabled", "type": "bool"}, - {"internalType": "uint256", "name": "registeredAt", "type": "uint256"}, - {"internalType": "uint256", "name": "lastHeartbeatAt", "type": "uint256"} + { + "inputs": [ + { + "internalType": "uint8", + "name": "teeType", + "type": "uint8" + } ], - "internalType": "struct TEERegistry.TEEInfo[]", - "name": "", - "type": "tuple[]" - } - ], - "stateMutability": "view", - "type": "function" - }, - { - "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], - "name": "getEnabledTEEs", - "outputs": [{"internalType": "bytes32[]", "name": "", "type": "bytes32[]"}], - "stateMutability": "view", - "type": "function" - }, - { - "inputs": [{"internalType": "uint8", "name": "teeType", "type": "uint8"}], - "name": "getTEEsByType", - "outputs": [{"internalType": "bytes32[]", "name": "", "type": "bytes32[]"}], - "stateMutability": "view", - "type": "function" - }, - { - "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], - "name": "getTEE", - "outputs": [ - { - "components": [ - {"internalType": "address", "name": "owner", "type": "address"}, - {"internalType": "address", "name": "paymentAddress", "type": "address"}, - {"internalType": "string", "name": "endpoint", "type": "string"}, - {"internalType": "bytes", "name": "publicKey", "type": "bytes"}, - {"internalType": "bytes", "name": "tlsCertificate", "type": "bytes"}, - {"internalType": "bytes32", "name": "pcrHash", "type": "bytes32"}, - {"internalType": "uint8", "name": "teeType", "type": "uint8"}, - {"internalType": "bool", "name": "enabled", "type": "bool"}, - {"internalType": "uint256", "name": "registeredAt", "type": "uint256"}, - {"internalType": "uint256", "name": "lastHeartbeatAt", "type": "uint256"} + "name": "getActiveTEEs", + "outputs": [ + { + "components": [ + { + "internalType": "address", + "name": "owner", + "type": "address" + }, + { + "internalType": "address", + "name": "paymentAddress", + "type": "address" + }, + { + "internalType": "string", + "name": "endpoint", + "type": "string" + }, + { + "internalType": "bytes", + "name": "publicKey", + "type": "bytes" + }, + { + "internalType": "bytes", + "name": "tlsCertificate", + "type": "bytes" + }, + { + "internalType": "bytes32", + "name": "pcrHash", + "type": "bytes32" + }, + { + "internalType": "uint8", + "name": "teeType", + "type": "uint8" + }, + { + "internalType": "bool", + "name": "enabled", + "type": "bool" + }, + { + "internalType": "uint256", + "name": "registeredAt", + "type": "uint256" + }, + { + "internalType": "uint256", + "name": "lastHeartbeatAt", + "type": "uint256" + }, + { + "components": [ + { + "internalType": "uint8", + "name": "keyId", + "type": "uint8" + }, + { + "internalType": "uint16", + "name": "kemId", + "type": "uint16" + }, + { + "internalType": "uint16", + "name": "kdfId", + "type": "uint16" + }, + { + "internalType": "uint16", + "name": "aeadId", + "type": "uint16" + }, + { + "internalType": "bytes", + "name": "publicKey", + "type": "bytes" + }, + { + "internalType": "bytes", + "name": "keyConfig", + "type": "bytes" + }, + { + "internalType": "uint256", + "name": "registeredAt", + "type": "uint256" + } + ], + "internalType": "struct TEERegistry.OhttpConfig", + "name": "ohttpConfig", + "type": "tuple" + } + ], + "internalType": "struct TEERegistry.TEEInfo[]", + "name": "", + "type": "tuple[]" + } ], - "internalType": "struct TEERegistry.TEEInfo", - "name": "", - "type": "tuple" - } - ], - "stateMutability": "view", - "type": "function" - }, - { - "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], - "name": "isTEEActive", - "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], - "stateMutability": "view", - "type": "function" - }, - { - "inputs": [{"internalType": "bytes32", "name": "teeId", "type": "bytes32"}], - "name": "isTEEEnabled", - "outputs": [{"internalType": "bool", "name": "", "type": "bool"}], - "stateMutability": "view", - "type": "function" - } + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "uint8", + "name": "teeType", + "type": "uint8" + } + ], + "name": "getEnabledTEEs", + "outputs": [ + { + "internalType": "bytes32[]", + "name": "", + "type": "bytes32[]" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "uint8", + "name": "teeType", + "type": "uint8" + } + ], + "name": "getTEEsByType", + "outputs": [ + { + "internalType": "bytes32[]", + "name": "", + "type": "bytes32[]" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "bytes32", + "name": "teeId", + "type": "bytes32" + } + ], + "name": "getTEE", + "outputs": [ + { + "components": [ + { + "internalType": "address", + "name": "owner", + "type": "address" + }, + { + "internalType": "address", + "name": "paymentAddress", + "type": "address" + }, + { + "internalType": "string", + "name": "endpoint", + "type": "string" + }, + { + "internalType": "bytes", + "name": "publicKey", + "type": "bytes" + }, + { + "internalType": "bytes", + "name": "tlsCertificate", + "type": "bytes" + }, + { + "internalType": "bytes32", + "name": "pcrHash", + "type": "bytes32" + }, + { + "internalType": "uint8", + "name": "teeType", + "type": "uint8" + }, + { + "internalType": "bool", + "name": "enabled", + "type": "bool" + }, + { + "internalType": "uint256", + "name": "registeredAt", + "type": "uint256" + }, + { + "internalType": "uint256", + "name": "lastHeartbeatAt", + "type": "uint256" + }, + { + "components": [ + { + "internalType": "uint8", + "name": "keyId", + "type": "uint8" + }, + { + "internalType": "uint16", + "name": "kemId", + "type": "uint16" + }, + { + "internalType": "uint16", + "name": "kdfId", + "type": "uint16" + }, + { + "internalType": "uint16", + "name": "aeadId", + "type": "uint16" + }, + { + "internalType": "bytes", + "name": "publicKey", + "type": "bytes" + }, + { + "internalType": "bytes", + "name": "keyConfig", + "type": "bytes" + }, + { + "internalType": "uint256", + "name": "registeredAt", + "type": "uint256" + } + ], + "internalType": "struct TEERegistry.OhttpConfig", + "name": "ohttpConfig", + "type": "tuple" + } + ], + "internalType": "struct TEERegistry.TEEInfo", + "name": "", + "type": "tuple" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "bytes32", + "name": "teeId", + "type": "bytes32" + } + ], + "name": "isTEEActive", + "outputs": [ + { + "internalType": "bool", + "name": "", + "type": "bool" + } + ], + "stateMutability": "view", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "bytes32", + "name": "teeId", + "type": "bytes32" + } + ], + "name": "isTEEEnabled", + "outputs": [ + { + "internalType": "bool", + "name": "", + "type": "bool" + } + ], + "stateMutability": "view", + "type": "function" + } ] diff --git a/src/opengradient/client/tee_registry.py b/src/opengradient/client/tee_registry.py index 9e6c8028..de7be8e8 100644 --- a/src/opengradient/client/tee_registry.py +++ b/src/opengradient/client/tee_registry.py @@ -4,7 +4,7 @@ import random import ssl from dataclasses import dataclass -from typing import List, NamedTuple, Optional +from typing import Any, List, NamedTuple, Optional, Sequence from web3 import Web3 @@ -17,8 +17,41 @@ TEE_TYPE_VALIDATOR = 1 +class OhttpConfig(NamedTuple): + """Mirrors the on-chain TEERegistry.OhttpConfig struct. + + The HPKE key material a client needs to encrypt an Oblivious HTTP request to + this TEE (the same configuration the chat-app browser client reads). + + Attributes: + key_id: OHTTP key configuration id. + kem_id: HPKE KEM id (0x0020 = DHKEM(X25519, HKDF-SHA256)). + kdf_id: HPKE KDF id (0x0001 = HKDF-SHA256). + aead_id: HPKE AEAD id (0x0003 = ChaCha20-Poly1305). + public_key: The TEE's HPKE (X25519) recipient public key. + key_config: The serialized OHTTP key config blob. + registered_at: Block timestamp the OHTTP config was registered. + """ + + key_id: int + kem_id: int + kdf_id: int + aead_id: int + public_key: bytes + key_config: bytes + registered_at: int + + class TEEInfo(NamedTuple): - """Mirrors the on-chain TEERegistry.TEEInfo struct.""" + """Mirrors the on-chain TEERegistry.TEEInfo struct (full record). + + This is a thin positional wrapper over the tuple web3 decodes from the + contract (built via ``TEEInfo(*raw)``), so every field holds the raw + decoded value. In particular ``ohttp_config`` is the *raw* decoded + sub-tuple, not a parsed `OhttpConfig` — use `_parse_ohttp_config` (as + `TEERegistry` does) to coerce it. The parsed, typed form is surfaced on + `TEEEndpoint.ohttp_config`. + """ owner: str payment_address: str @@ -30,16 +63,32 @@ class TEEInfo(NamedTuple): enabled: bool registered_at: int last_heartbeat_at: int + ohttp_config: Sequence[Any] @dataclass(frozen=True) class TEEEndpoint: - """A verified TEE with its endpoint URL and TLS certificate from the registry.""" + """A verified TEE resolved from the registry. + + Carries everything needed for both trust paths: the endpoint + pinned TLS + cert for a direct x402 connection, and the OHTTP/HPKE key material + + signing key for the oblivious-HTTP relay path. + + Attributes: + tee_id: keccak256 of the TEE's signing public key (0x-prefixed hex). + endpoint: The TEE gateway endpoint URL. + tls_cert_der: DER-encoded TLS certificate pinned at registration. + payment_address: x402 settlement address for this TEE. + signing_public_key_der: DER (SPKI) RSA public key the TEE signs with. + ohttp_config: The TEE's OHTTP/HPKE key configuration, if present. + """ tee_id: str endpoint: str tls_cert_der: bytes payment_address: str + signing_public_key_der: bytes = b"" + ohttp_config: Optional[OhttpConfig] = None class TEERegistry: @@ -103,6 +152,8 @@ def get_active_tees_by_type(self, tee_type: int) -> List[TEEEndpoint]: endpoint=tee.endpoint, tls_cert_der=bytes(tee.tls_certificate), payment_address=tee.payment_address, + signing_public_key_der=bytes(tee.public_key), + ohttp_config=_parse_ohttp_config(tee.ohttp_config), ) ) @@ -112,6 +163,10 @@ def get_llm_tee(self) -> Optional[TEEEndpoint]: """ Return a random active LLM proxy TEE from the registry. + The returned ``TEEEndpoint`` is the full record: endpoint + pinned TLS + cert for direct x402 connections, plus the OHTTP/HPKE ``ohttp_config`` + and ``signing_public_key_der`` for the oblivious-HTTP relay path. + Returns: TEEEndpoint for an active LLM proxy TEE, or None if none are available. """ @@ -122,6 +177,49 @@ def get_llm_tee(self) -> Optional[TEEEndpoint]: return random.choice(tees) + def get_llm_tee_ohttp_config(self) -> Optional[TEEEndpoint]: + """ + Return a random active LLM proxy TEE that advertises an OHTTP config. + + Like ``get_llm_tee`` but skips TEEs missing HPKE key material, so the + result is guaranteed usable for the Oblivious HTTP path. + + Returns: + A TEEEndpoint with a non-empty ``ohttp_config``, or None. + """ + candidates = [ + tee + for tee in self.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + if tee.ohttp_config is not None and len(tee.ohttp_config.public_key) == 32 + ] + if not candidates: + logger.warning("No active LLM proxy TEEs with an OHTTP config found in registry") + return None + + return random.choice(candidates) + + +def _parse_ohttp_config(raw: Sequence[Any]) -> Optional[OhttpConfig]: + """Coerce the decoded on-chain ohttpConfig tuple into an OhttpConfig. + + Returns None when the TEE has no OHTTP config registered (empty public key). + """ + try: + cfg = OhttpConfig( + key_id=int(raw[0]), + kem_id=int(raw[1]), + kdf_id=int(raw[2]), + aead_id=int(raw[3]), + public_key=bytes(raw[4]), + key_config=bytes(raw[5]), + registered_at=int(raw[6]), + ) + except (TypeError, IndexError, ValueError): + return None + if not cfg.public_key: + return None + return cfg + def build_ssl_context_from_der(der_cert: bytes) -> ssl.SSLContext: """ diff --git a/tests/tee_registry_test.py b/tests/tee_registry_test.py index fa37e2cf..252e0404 100644 --- a/tests/tee_registry_test.py +++ b/tests/tee_registry_test.py @@ -7,6 +7,7 @@ TEE_TYPE_LLM_PROXY, TEE_TYPE_VALIDATOR, TEERegistry, + _parse_ohttp_config, build_ssl_context_from_der, ) @@ -18,19 +19,33 @@ def _make_tee_info( payment_address="0xPayment", pub_key=b"pubkey", tls_cert_der=b"\x01\x02\x03", + ohttp_public_key=b"\x55" * 32, ): - """Build a tuple matching the TEEInfo struct order from the new contract.""" + """Build a tuple matching the full TEEInfo struct order from the contract. + + Includes the trailing ``ohttpConfig`` sub-tuple (keyId, kemId, kdfId, aeadId, + publicKey, keyConfig, registeredAt) that the full registry read parses. + """ return ( "0xOwner", # owner payment_address, # paymentAddress endpoint, # endpoint - pub_key, # publicKey + pub_key, # publicKey (RSA signing key, DER) tls_cert_der, # tlsCertificate b"\x00" * 32, # pcrHash 0, # teeType True, # enabled (always True from getActiveTEEs) 1000, # registeredAt 2000, # lastHeartbeatAt + ( # ohttpConfig + 1, # keyId + 0x0020, # kemId (X25519) + 0x0001, # kdfId (HKDF-SHA256) + 0x0003, # aeadId (ChaCha20-Poly1305) + ohttp_public_key, # publicKey (HPKE X25519) + b"keyconfig", # keyConfig + 3000, # registeredAt + ), ) @@ -97,6 +112,12 @@ def test_returns_active_tees(self, mock_contract): assert result[0].endpoint == "https://tee.example.com" assert result[0].payment_address == "0xPayment" assert result[0].tls_cert_der == b"\x01\x02\x03" + # Full registry read: the signing key + OHTTP config come back too. + assert result[0].signing_public_key_der == b"pubkey" + assert result[0].ohttp_config is not None + assert result[0].ohttp_config.kem_id == 0x0020 + assert result[0].ohttp_config.aead_id == 0x0003 + assert len(result[0].ohttp_config.public_key) == 32 contract.functions.getActiveTEEs.assert_called_once_with(TEE_TYPE_LLM_PROXY) def test_skips_tee_with_empty_endpoint(self, mock_contract): @@ -174,6 +195,84 @@ def test_queries_llm_proxy_type(self, mock_contract): contract.functions.getActiveTEEs.assert_called_once_with(TEE_TYPE_LLM_PROXY) +class TestParseOhttpConfig: + def test_parses_full_config(self): + # The raw decoded ohttpConfig sub-tuple from the TEEInfo fixture. + raw = _make_tee_info()[-1] + cfg = _parse_ohttp_config(raw) + + assert cfg is not None + assert cfg.key_id == 1 + assert cfg.kem_id == 0x0020 + assert cfg.kdf_id == 0x0001 + assert cfg.aead_id == 0x0003 + assert cfg.public_key == b"\x55" * 32 + assert cfg.key_config == b"keyconfig" + assert cfg.registered_at == 3000 + + def test_returns_none_on_empty_public_key(self): + raw = _make_tee_info(ohttp_public_key=b"")[-1] + assert _parse_ohttp_config(raw) is None + + def test_returns_none_on_malformed_tuple(self): + # Too few fields -> IndexError swallowed -> None. + assert _parse_ohttp_config((1, 2, 3)) is None + + +class TestGetActiveTeesOhttpConfig: + def test_endpoint_ohttp_config_is_none_when_no_hpke_key(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [_make_tee_info(ohttp_public_key=b"")] + + result = registry.get_active_tees_by_type(TEE_TYPE_LLM_PROXY) + assert len(result) == 1 + # No usable HPKE material -> parsed config drops to None (endpoint still returned). + assert result[0].ohttp_config is None + + +class TestGetLlmTeeOhttpConfig: + def test_filters_out_tees_without_hpke_key(self, mock_contract): + registry, contract = mock_contract + + # One TEE with a valid 32-byte HPKE key, one without any. + contract.functions.getActiveTEEs.return_value.call.return_value = [ + _make_tee_info(endpoint="https://no-ohttp.example.com", pub_key=b"pubkey0", ohttp_public_key=b""), + _make_tee_info(endpoint="https://has-ohttp.example.com", pub_key=b"pubkey1", ohttp_public_key=b"\x55" * 32), + ] + + result = registry.get_llm_tee_ohttp_config() + + assert result is not None + assert result.endpoint == "https://has-ohttp.example.com" + assert result.ohttp_config is not None + assert len(result.ohttp_config.public_key) == 32 + + def test_returns_none_when_no_usable_ohttp_config(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [ + _make_tee_info(ohttp_public_key=b""), + ] + + assert registry.get_llm_tee_ohttp_config() is None + + def test_returns_none_when_no_tees(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [] + + assert registry.get_llm_tee_ohttp_config() is None + + def test_queries_llm_proxy_type(self, mock_contract): + registry, contract = mock_contract + + contract.functions.getActiveTEEs.return_value.call.return_value = [] + registry.get_llm_tee_ohttp_config() + + contract.functions.getActiveTEEs.assert_called_once_with(TEE_TYPE_LLM_PROXY) + + # --- build_ssl_context_from_der Tests ---