diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java index e96403e48..39c6bcacf 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java @@ -6,7 +6,7 @@ import java.util.ArrayList; import java.util.List; -import java.util.Map; +import java.util.function.Function; import io.modelcontextprotocol.util.Assert; @@ -47,27 +47,18 @@ private DefaultServerTransportSecurityValidator(List allowedOrigins, Lis } @Override - public void validateHeaders(Map> headers) throws ServerTransportSecurityException { - boolean missingHost = true; - for (Map.Entry> entry : headers.entrySet()) { - if (ORIGIN_HEADER.equalsIgnoreCase(entry.getKey())) { - List values = entry.getValue(); - if (values == null || values.isEmpty()) { - throw new ServerTransportSecurityException(403, "Invalid Origin header"); - } - validateOrigin(values.get(0)); - } - else if (HOST_HEADER.equalsIgnoreCase(entry.getKey())) { - missingHost = false; - List values = entry.getValue(); - if (values == null || values.isEmpty()) { - throw new ServerTransportSecurityException(421, "Invalid Host header"); - } - validateHost(values.get(0)); - } + public void validateHeaders(Function> headerAccessor) throws ServerTransportSecurityException { + List originValues = headerAccessor.apply(ORIGIN_HEADER); + if (originValues != null && !originValues.isEmpty()) { + validateOrigin(originValues.get(0)); } - if (!allowedHosts.isEmpty() && missingHost) { - throw new ServerTransportSecurityException(421, "Invalid Host header"); + + if (!allowedHosts.isEmpty()) { + List hostValues = headerAccessor.apply(HOST_HEADER); + if (hostValues == null || hostValues.isEmpty()) { + throw new ServerTransportSecurityException(421, "Invalid Host header"); + } + validateHost(hostValues.get(0)); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java index ce805931f..7d86da480 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java @@ -6,31 +6,68 @@ import java.util.List; import java.util.Map; +import java.util.function.Function; /** * Interface for validating HTTP requests in server transports. Implementations can * validate Origin headers, Host headers, or any other security-related headers according * to the MCP specification. * + *

+ * New implementations should override {@link #validateHeaders(Function) + * validateHeaders(Function)} for more efficient, case-insensitive header access. The + * older {@link #validateHeaders(Map) validateHeaders(Map)} is deprecated and will be + * removed in a future major version. + * * @author Daniel Garnier-Moiroux * @see DefaultServerTransportSecurityValidator * @see ServerTransportSecurityException */ -@FunctionalInterface public interface ServerTransportSecurityValidator { /** * A no-op validator that accepts all requests without validation. */ - ServerTransportSecurityValidator NOOP = headers -> { + ServerTransportSecurityValidator NOOP = new ServerTransportSecurityValidator() { }; /** * Validates the HTTP headers from an incoming request. + * + *

+ * The default implementation converts the map into a case-insensitive header accessor + * and delegates to {@link #validateHeaders(Function)}. * @param headers A map of header names to their values (multi-valued headers * supported) * @throws ServerTransportSecurityException if validation fails + * @deprecated Use {@link #validateHeaders(Function)} instead for more efficient, + * case-insensitive header access. This method will be removed in a future major + * version. + */ + @Deprecated + default void validateHeaders(Map> headers) throws ServerTransportSecurityException { + validateHeaders(name -> headers.entrySet() + .stream() + .filter(e -> e.getKey().equalsIgnoreCase(name)) + .map(Map.Entry::getValue) + .findFirst() + .orElse(List.of())); + } + + /** + * Validates the HTTP headers from an incoming request using a header accessor + * function. + * + *

+ * New implementations should override this method. Header name lookup through the + * accessor should be case-insensitive (e.g., when backed by + * {@code HttpServletRequest.getHeaders}). + * @param headerAccessor A function that returns the list of values for a given header + * name, or an empty list if the header is not present. + * @throws ServerTransportSecurityException if validation fails */ - void validateHeaders(Map> headers) throws ServerTransportSecurityException; + default void validateHeaders(Function> headerAccessor) + throws ServerTransportSecurityException { + } } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java index d4cf8582d..a91cde162 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidatorTests.java @@ -7,6 +7,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -43,50 +44,50 @@ class OriginHeader { @Test void originHeaderMissing() { - assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(emptyAccessor())).doesNotThrowAnyException(); } @Test void originHeaderListEmpty() { - assertThatThrownBy(() -> validator.validateHeaders(Map.of("Origin", List.of()))).isEqualTo(INVALID_ORIGIN); + assertThatCode(() -> validator.validateHeaders(name -> List.of())).doesNotThrowAnyException(); } @Test void caseInsensitive() { - var headers = Map.of("origin", List.of("http://localhost:8080")); + var accessor = headerAccessor("Origin", "http://localhost:8080"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void exactMatch() { - var headers = originHeader("http://localhost:8080"); + var accessor = originAccessor("http://localhost:8080"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentPort() { - var headers = originHeader("http://localhost:3000"); + var accessor = originAccessor("http://localhost:3000"); - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> validator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void differentHost() { - var headers = originHeader("http://example.com:8080"); + var accessor = originAccessor("http://example.com:8080"); - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> validator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void differentScheme() { - var headers = originHeader("https://localhost:8080"); + var accessor = originAccessor("https://localhost:8080"); - assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> validator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Nested @@ -99,37 +100,37 @@ class WildcardPort { @Test void anyPortWithWildcard() { - var headers = originHeader("http://localhost:3000"); + var accessor = originAccessor("http://localhost:3000"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void noPortWithWildcard() { - var headers = originHeader("http://localhost"); + var accessor = originAccessor("http://localhost"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentPortWithWildcard() { - var headers = originHeader("http://localhost:8080"); + var accessor = originAccessor("http://localhost:8080"); - assertThatCode(() -> wildcardValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentHostWithWildcard() { - var headers = originHeader("http://example.com:3000"); + var accessor = originAccessor("http://example.com:3000"); - assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> wildcardValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void differentSchemeWithWildcard() { - var headers = originHeader("https://localhost:3000"); + var accessor = originAccessor("https://localhost:3000"); - assertThatThrownBy(() -> wildcardValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> wildcardValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } } @@ -146,23 +147,23 @@ class MultipleOrigins { @Test void matchingOneOfMultiple() { - var headers = originHeader("http://example.com:3000"); + var accessor = originAccessor("http://example.com:3000"); - assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleOriginsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void matchingWildcardInMultiple() { - var headers = originHeader("http://myapp.example.com:9999"); + var accessor = originAccessor("http://myapp.example.com:9999"); - assertThatCode(() -> multipleOriginsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleOriginsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void notMatchingAny() { - var headers = originHeader("http://malicious.example.com:1234"); + var accessor = originAccessor("http://malicious.example.com:1234"); - assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> multipleOriginsValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } } @@ -176,9 +177,9 @@ void shouldAddMultipleOriginsWithAllowedOriginsMethod() { .allowedOrigins(List.of("http://localhost:8080", "http://example.com:*")) .build(); - var headers = originHeader("http://example.com:3000"); + var accessor = originAccessor("http://example.com:3000"); - assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test @@ -188,11 +189,11 @@ void shouldCombineAllowedOriginMethods() { .allowedOrigins(List.of("http://example.com:*", "http://test.com:3000")) .build(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://localhost:8080"))) + assertThatCode(() -> validator.validateHeaders(originAccessor("http://localhost:8080"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://example.com:9999"))) + assertThatCode(() -> validator.validateHeaders(originAccessor("http://example.com:9999"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(originHeader("http://test.com:3000"))) + assertThatCode(() -> validator.validateHeaders(originAccessor("http://test.com:3000"))) .doesNotThrowAnyException(); } @@ -210,45 +211,45 @@ class HostHeader { @Test void notConfigured() { - assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(emptyAccessor())).doesNotThrowAnyException(); } @Test void missing() { - assertThatThrownBy(() -> hostValidator.validateHeaders(new HashMap<>())).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(emptyAccessor())).isEqualTo(INVALID_HOST); } @Test void listEmpty() { - assertThatThrownBy(() -> hostValidator.validateHeaders(Map.of("Host", List.of()))).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(name -> List.of())).isEqualTo(INVALID_HOST); } @Test void caseInsensitive() { - var headers = Map.of("host", List.of("localhost:8080")); + var accessor = headerAccessor("Host", "localhost:8080"); - assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> hostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void exactMatch() { - var headers = hostHeader("localhost:8080"); + var accessor = hostAccessor("localhost:8080"); - assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> hostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentPort() { - var headers = hostHeader("localhost:3000"); + var accessor = hostAccessor("localhost:3000"); - assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Test void differentHost() { - var headers = hostHeader("example.com:8080"); + var accessor = hostAccessor("example.com:8080"); - assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> hostValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Nested @@ -261,23 +262,23 @@ class HostWildcardPort { @Test void anyPort() { - var headers = hostHeader("localhost:3000"); + var accessor = hostAccessor("localhost:3000"); - assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardHostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void noPort() { - var headers = hostHeader("localhost"); + var accessor = hostAccessor("localhost"); - assertThatCode(() -> wildcardHostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> wildcardHostValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentHost() { - var headers = hostHeader("example.com:3000"); + var accessor = hostAccessor("example.com:3000"); - assertThatThrownBy(() -> wildcardHostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> wildcardHostValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } } @@ -293,30 +294,30 @@ class MultipleHosts { @Test void exactMatch() { - var headers = hostHeader("example.com:3000"); + var accessor = hostAccessor("example.com:3000"); - assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleHostsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void wildcard() { - var headers = hostHeader("myapp.example.com:9999"); + var accessor = hostAccessor("myapp.example.com:9999"); - assertThatCode(() -> multipleHostsValidator.validateHeaders(headers)).doesNotThrowAnyException(); + assertThatCode(() -> multipleHostsValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void differentHost() { - var headers = hostHeader("malicious.example.com:3000"); + var accessor = hostAccessor("malicious.example.com:3000"); - assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Test void differentPort() { - var headers = hostHeader("localhost:8080"); + var accessor = hostAccessor("localhost:8080"); - assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> multipleHostsValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } } @@ -330,9 +331,9 @@ void multipleHosts() { .allowedHosts(List.of("localhost:8080", "example.com:*")) .build(); - assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:3000"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("example.com:3000"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("localhost:8080"))) .doesNotThrowAnyException(); } @@ -343,11 +344,12 @@ void combined() { .allowedHosts(List.of("example.com:*", "test.com:3000")) .build(); - assertThatCode(() -> validator.validateHeaders(hostHeader("localhost:8080"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("localhost:8080"))) + .doesNotThrowAnyException(); + assertThatCode(() -> validator.validateHeaders(hostAccessor("example.com:9999"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(hostHeader("example.com:9999"))) + assertThatCode(() -> validator.validateHeaders(hostAccessor("test.com:3000"))) .doesNotThrowAnyException(); - assertThatCode(() -> validator.validateHeaders(hostHeader("test.com:3000"))).doesNotThrowAnyException(); } } @@ -365,60 +367,180 @@ class CombinedOriginAndHostValidation { @Test void bothValid() { - var header = headers("http://localhost:8080", "localhost:8080"); + var accessor = combinedAccessor("http://localhost:8080", "localhost:8080"); - assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); + assertThatCode(() -> combinedValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void originValidHostInvalid() { - var header = headers("http://localhost:8080", "malicious.example.com:8080"); + var accessor = combinedAccessor("http://localhost:8080", "malicious.example.com:8080"); - assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); + assertThatThrownBy(() -> combinedValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); } @Test void originInvalidHostValid() { - var header = headers("http://malicious.example.com:8080", "localhost:8080"); + var accessor = combinedAccessor("http://malicious.example.com:8080", "localhost:8080"); - assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_ORIGIN); + assertThatThrownBy(() -> combinedValidator.validateHeaders(accessor)).isEqualTo(INVALID_ORIGIN); } @Test void originMissingHostValid() { // Origin missing is OK (same-origin request) - var header = headers(null, "localhost:8080"); + var accessor = combinedAccessor(null, "localhost:8080"); - assertThatCode(() -> combinedValidator.validateHeaders(header)).doesNotThrowAnyException(); + assertThatCode(() -> combinedValidator.validateHeaders(accessor)).doesNotThrowAnyException(); } @Test void originValidHostMissing() { // Host missing is NOT OK when allowedHosts is configured - var header = headers("http://localhost:8080", null); + var accessor = combinedAccessor("http://localhost:8080", null); + + assertThatThrownBy(() -> combinedValidator.validateHeaders(accessor)).isEqualTo(INVALID_HOST); + } + + } + + @Nested + class DeprecatedMapBasedApi { + + @Test + void originValidation() { + Map> headers = new HashMap<>(); + headers.put("Origin", List.of("http://localhost:8080")); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void originRejected() { + Map> headers = new HashMap<>(); + headers.put("Origin", List.of("http://malicious.example.com")); + + assertThatThrownBy(() -> validator.validateHeaders(headers)).isEqualTo(INVALID_ORIGIN); + } + + @Test + void caseInsensitiveHeaderLookup() { + Map> headers = new HashMap<>(); + headers.put("origin", List.of("http://localhost:8080")); + + assertThatCode(() -> validator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void hostValidation() { + DefaultServerTransportSecurityValidator hostValidator = DefaultServerTransportSecurityValidator.builder() + .allowedHost("localhost:8080") + .build(); - assertThatThrownBy(() -> combinedValidator.validateHeaders(header)).isEqualTo(INVALID_HOST); + Map> headers = new HashMap<>(); + headers.put("Host", List.of("localhost:8080")); + + assertThatCode(() -> hostValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + @Test + void hostRejected() { + DefaultServerTransportSecurityValidator hostValidator = DefaultServerTransportSecurityValidator.builder() + .allowedHost("localhost:8080") + .build(); + + Map> headers = new HashMap<>(); + headers.put("Host", List.of("malicious.com:8080")); + + assertThatThrownBy(() -> hostValidator.validateHeaders(headers)).isEqualTo(INVALID_HOST); + } + + @Test + void emptyHeaders() { + assertThatCode(() -> validator.validateHeaders(new HashMap<>())).doesNotThrowAnyException(); + } + + @Test + void combinedOriginAndHost() { + DefaultServerTransportSecurityValidator combinedValidator = DefaultServerTransportSecurityValidator + .builder() + .allowedOrigin("http://localhost:*") + .allowedHost("localhost:*") + .build(); + + Map> headers = new HashMap<>(); + headers.put("Origin", List.of("http://localhost:8080")); + headers.put("Host", List.of("localhost:8080")); + + assertThatCode(() -> combinedValidator.validateHeaders(headers)).doesNotThrowAnyException(); + } + + } + + @Nested + class InterfaceDefaultBridge { + + @Test + void noopAcceptsAll() { + assertThatCode(() -> ServerTransportSecurityValidator.NOOP.validateHeaders(emptyAccessor())) + .doesNotThrowAnyException(); + assertThatCode(() -> ServerTransportSecurityValidator.NOOP.validateHeaders(new HashMap<>())) + .doesNotThrowAnyException(); + } + + @Test + void mapDefaultBridgesToFunctionOverride() { + // A validator that only overrides the Function method should still work + // when called via the deprecated Map method + ServerTransportSecurityValidator functionOnlyValidator = new ServerTransportSecurityValidator() { + @Override + public void validateHeaders(Function> headerAccessor) + throws ServerTransportSecurityException { + List origins = headerAccessor.apply("Origin"); + if (origins != null && !origins.isEmpty() && origins.get(0).contains("evil")) { + throw new ServerTransportSecurityException(403, "Invalid Origin header"); + } + } + }; + + Map> goodHeaders = new HashMap<>(); + goodHeaders.put("Origin", List.of("http://good.example.com")); + assertThatCode(() -> functionOnlyValidator.validateHeaders(goodHeaders)).doesNotThrowAnyException(); + + Map> evilHeaders = new HashMap<>(); + evilHeaders.put("Origin", List.of("http://evil.example.com")); + assertThatThrownBy(() -> functionOnlyValidator.validateHeaders(evilHeaders)).isEqualTo(INVALID_ORIGIN); } } - private static Map> originHeader(String origin) { - return Map.of("Origin", List.of(origin)); + private static Function> emptyAccessor() { + return name -> List.of(); + } + + private static Function> headerAccessor(String headerName, String value) { + Map> headers = new HashMap<>(); + headers.put(headerName, List.of(value)); + return name -> headers.getOrDefault(name, List.of()); + } + + private static Function> originAccessor(String origin) { + return headerAccessor("Origin", origin); } - private static Map> hostHeader(String host) { - return Map.of("Host", List.of(host)); + private static Function> hostAccessor(String host) { + return headerAccessor("Host", host); } - private static Map> headers(String origin, String host) { - var map = new HashMap>(); + private static Function> combinedAccessor(String origin, String host) { + Map> headers = new HashMap<>(); if (origin != null) { - map.put("Origin", List.of(origin)); + headers.put("Origin", List.of(origin)); } if (host != null) { - map.put("Host", List.of(host)); + headers.put("Host", List.of(host)); } - return map; + return name -> headers.getOrDefault(name, List.of()); } }