From 412fe5078730243efbbe620ab1c92078cee71c30 Mon Sep 17 00:00:00 2001 From: Matthias Kurz Date: Tue, 30 Jun 2026 14:12:38 +0200 Subject: [PATCH 1/2] Add server WebSocket compression #61 * Negotiate RFC 7692 permessage-deflate for server-side WebSockets. * Add compression settings, documentation, release notes, and MiMa filters. * Cover negotiation, compression, fragmentation, context takeover, max allocation, and low-level frame handling. --- .../paradox/client-side/websocket-support.md | 5 + .../paradox/server-side/websocket-support.md | 31 + .../websocket-compression.excludes | 19 + http-core/src/main/resources/reference.conf | 38 + .../engine/parsing/HttpHeaderParser.scala | 1 + .../pekko/http/impl/engine/ws/Handshake.scala | 25 +- .../impl/engine/ws/PerMessageDeflate.scala | 345 +++++ .../pekko/http/impl/engine/ws/WebSocket.scala | 2 + .../impl/settings/WebSocketSettingsImpl.scala | 52 +- .../engine/parsing/HttpHeaderParserSpec.scala | 3 + .../impl/engine/ws/WebSocketServerSpec.scala | 1206 ++++++++++++++++- 11 files changed, 1703 insertions(+), 24 deletions(-) create mode 100644 http-core/src/main/mima-filters/2.0.x.backwards.excludes/websocket-compression.excludes create mode 100644 http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/PerMessageDeflate.scala diff --git a/docs/src/main/paradox/client-side/websocket-support.md b/docs/src/main/paradox/client-side/websocket-support.md index ad6badff35..133cf1280b 100644 --- a/docs/src/main/paradox/client-side/websocket-support.md +++ b/docs/src/main/paradox/client-side/websocket-support.md @@ -9,6 +9,11 @@ the lifetime of the connection. Therefore a WebSocket connection is modelled as @apidoc[Flow[Message, Message, Mat]] to or a @apidoc[Flow[Message, Message, Mat]] that you connect a @apidoc[Source[Message, Mat]] and a @apidoc[Sink[Message, Mat]] to. +@@@ note +Client-side WebSocket compression is not currently implemented. Apache Pekko HTTP can negotiate +`permessage-deflate` compression for server-side WebSocket connections when a client requests it. +@@@ + A WebSocket request starts with a regular HTTP request which contains an `Upgrade` header (and possibly other regular HTTP request properties), so in addition to the flow of messages there also is an initial response from the server, this is modelled with @apidoc[WebSocketUpgradeResponse]. diff --git a/docs/src/main/paradox/server-side/websocket-support.md b/docs/src/main/paradox/server-side/websocket-support.md index 181de6320a..a6fc825e9d 100644 --- a/docs/src/main/paradox/server-side/websocket-support.md +++ b/docs/src/main/paradox/server-side/websocket-support.md @@ -151,6 +151,37 @@ In case you need to keep inactive connections alive, you can either tweak your i +## WebSocket compression + +Apache Pekko HTTP can negotiate the RFC 7692 `permessage-deflate` WebSocket extension for server-side WebSocket +connections. Compression is enabled by default, but is only used when the client requests it with the +`Sec-WebSocket-Extensions: permessage-deflate` header during the WebSocket handshake. + +You can disable WebSocket compression globally for the server: + +``` +pekko.http.server.websocket.compression.enabled = false +``` + +The server exposes additional settings for the negotiated extension under +`pekko.http.server.websocket.compression.permessage-deflate`, including `compression-level`, +`preferred-client-window-size`, `allow-server-no-context`, and `preferred-client-no-context`. See the +@ref[configuration reference](../configuration.md) for the complete list of settings and defaults. + +@@@ note +The `server_no_context_takeover` and `client_no_context_takeover` extension parameters affect whether compression +dictionaries are retained across messages. Retaining context generally improves compression ratio, while disabling +context takeover can reduce cross-message information retention and memory lifetime. The defaults keep context takeover +enabled unless explicitly configured otherwise. + +Pekko HTTP uses the JDK `Deflater` and `Inflater` implementation for `permessage-deflate`. The JDK API +does not expose zlib `windowBits` or `memLevel`, so Pekko HTTP does not accept `server_max_window_bits` +values below `15` and does not provide server window-size or memory-level settings. A client may still +request `client_max_window_bits`; when it does, Pekko HTTP can include the configured +`preferred-client-window-size` in the handshake response to ask the client to use that window size for +client-to-server messages. +@@@ + ## Automatic keep-alive Ping support For long running websocket connections it may be beneficial to enable automatic heartbeat using `Ping` frames. diff --git a/http-core/src/main/mima-filters/2.0.x.backwards.excludes/websocket-compression.excludes b/http-core/src/main/mima-filters/2.0.x.backwards.excludes/websocket-compression.excludes new file mode 100644 index 0000000000..c01eaba7b5 --- /dev/null +++ b/http-core/src/main/mima-filters/2.0.x.backwards.excludes/websocket-compression.excludes @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Add server-side WebSocket compression support. +ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.pekko.http.impl.engine.ws.Handshake#Server.buildResponse") diff --git a/http-core/src/main/resources/reference.conf b/http-core/src/main/resources/reference.conf index 81f1aa9036..1784bfd93a 100644 --- a/http-core/src/main/resources/reference.conf +++ b/http-core/src/main/resources/reference.conf @@ -348,6 +348,44 @@ pekko.http { # Enable verbose debug logging for all ingoing and outgoing frames log-frames = false + + compression { + # Whether the server should support WebSocket compression using the RFC 7692 + # permessage-deflate extension. Compression is negotiated during the + # WebSocket handshake and is only used when the client requests it. + enabled = true + + # Maximum size of a decompressed WebSocket message. If this value is + # exceeded while inflating a compressed message, the connection is closed + # with a WebSocket protocol error. + # Set to 0 to disable this limit. + max-allocation = 64k + + permessage-deflate { + # Pekko HTTP uses the JDK Deflater/Inflater implementation for + # permessage-deflate. The JDK API does not expose zlib windowBits or + # memLevel, so clients cannot negotiate a server_max_window_bits value + # below 15. client_max_window_bits can still be negotiated through + # preferred-client-window-size. + + # DEFLATE compression level used for server-to-client messages. Valid + # values are 0-9, where 0 uses no compression and 9 favors compression + # ratio over CPU usage. + compression-level = 6 + + # The client_max_window_bits value Pekko HTTP should request for + # client-to-server messages when the client sends client_max_window_bits + # without an explicit value. Valid values are 8-15. + preferred-client-window-size = 15 + + # Whether a client may request server_no_context_takeover. + allow-server-no-context = false + + # Whether Pekko HTTP should request client_no_context_takeover when the + # client indicates that it supports this parameter. + preferred-client-no-context = false + } + } } } diff --git a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/parsing/HttpHeaderParser.scala b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/parsing/HttpHeaderParser.scala index 63010f6e49..9b887c7431 100644 --- a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/parsing/HttpHeaderParser.scala +++ b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/parsing/HttpHeaderParser.scala @@ -487,6 +487,7 @@ private[http] object HttpHeaderParser { "content-type", "expect", "host", + "sec-websocket-extensions", "sec-websocket-key", "sec-websocket-protocol", "sec-websocket-version", diff --git a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/Handshake.scala b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/Handshake.scala index 9e0865cc7b..b3cde0a7ff 100644 --- a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/Handshake.scala +++ b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/Handshake.scala @@ -23,6 +23,7 @@ import scala.collection.immutable.Seq import pekko.event.LoggingAdapter import pekko.http.impl.util._ import pekko.http.impl.engine.server.UpgradeToOtherProtocolResponseHeader +import pekko.http.impl.settings.WebSocketSettingsImpl import pekko.http.scaladsl.model.headers._ import pekko.http.scaladsl.model.ws.Message import pekko.http.scaladsl.model._ @@ -122,6 +123,13 @@ private[http] object Handshake { case OptionVal.Some(p) => p.protocols case _ => Nil } + val clientRequestedExtensions = headers.collect { + case extensions: `Sec-WebSocket-Extensions` => extensions.extensions + }.flatten + val perMessageDeflate = + PerMessageDeflate.negotiate( + clientRequestedExtensions, + settings.asInstanceOf[WebSocketSettingsImpl].compression) val header = new UpgradeToWebSocketLowLevel { def requestedProtocols: Seq[String] = clientSupportedSubprotocols @@ -132,7 +140,7 @@ private[http] object Handshake { require( subprotocol.forall(chosen => clientSupportedSubprotocols.contains(chosen)), s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]") - buildResponse(key.get, handler, subprotocol, settings, log) + buildResponse(key.get, handler, subprotocol, perMessageDeflate, settings, log) } def handleFrames( @@ -169,11 +177,16 @@ private[http] object Handshake { */ def buildResponse(key: `Sec-WebSocket-Key`, handler: Either[Graph[FlowShape[FrameEvent, FrameEvent], Any], Graph[FlowShape[Message, Message], Any]], - subprotocol: Option[String], settings: WebSocketSettings, log: LoggingAdapter): HttpResponse = { + subprotocol: Option[String], + perMessageDeflate: Option[PerMessageDeflate.Negotiated], + settings: WebSocketSettings, + log: LoggingAdapter): HttpResponse = { val frameHandler = handler match { - case Left(frameHandler) => frameHandler + case Left(frameHandler) => + perMessageDeflate.map(_.frameEventBidiFlow(settings.randomFactory).join(frameHandler)).getOrElse(frameHandler) case Right(messageHandler) => - WebSocket.stack(serverSide = true, settings, log = log).join(messageHandler) + WebSocket.stack(serverSide = true, settings, perMessageDeflate = perMessageDeflate, log = log) + .join(messageHandler) } HttpResponse( @@ -182,7 +195,9 @@ private[http] object Handshake { List( UpgradeHeader, ConnectionUpgradeHeader, - `Sec-WebSocket-Accept`.forKey(key), + `Sec-WebSocket-Accept`.forKey(key)) ::: + perMessageDeflate.map(p => `Sec-WebSocket-Extensions`(Seq(p.responseExtension))).toList ::: + List( UpgradeToOtherProtocolResponseHeader(WebSocket.framing.join(frameHandler)))) } } diff --git a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/PerMessageDeflate.scala b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/PerMessageDeflate.scala new file mode 100644 index 0000000000..eb71b21d03 --- /dev/null +++ b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/PerMessageDeflate.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.pekko.http.impl.engine.ws + +import java.io.ByteArrayOutputStream +import java.util.Random +import java.util.zip.Deflater +import java.util.zip.Inflater +import java.util.zip.DataFormatException + +import org.apache.pekko +import pekko.NotUsed +import pekko.annotation.InternalApi +import pekko.http.impl.settings.WebSocketCompressionSettingsImpl +import pekko.http.scaladsl.model.headers.WebSocketExtension +import pekko.stream.scaladsl.BidiFlow +import pekko.stream.scaladsl.Flow +import pekko.stream.stage.GraphStage +import pekko.stream.stage.GraphStageLogic +import pekko.stream.stage.InHandler +import pekko.stream.stage.OutHandler +import pekko.stream.{ Attributes, FlowShape, Inlet, Outlet } +import pekko.util.ByteString + +import scala.collection.immutable +import scala.collection.immutable.ListMap + +/** + * INTERNAL API + */ +@InternalApi +private[http] object PerMessageDeflate { + private val ExtensionName = "permessage-deflate" + private val ClientMaxWindowBits = "client_max_window_bits" + private val ServerMaxWindowBits = "server_max_window_bits" + private val ClientNoContextTakeover = "client_no_context_takeover" + private val ServerNoContextTakeover = "server_no_context_takeover" + private val EmptyStoredBlock = ByteString(0x00, 0x00, 0xFF.toByte, 0xFF.toByte) + + final case class Negotiated( + responseExtension: WebSocketExtension, + serverNoContextTakeover: Boolean, + clientNoContextTakeover: Boolean, + settings: WebSocketCompressionSettingsImpl) { + def bidiFlow: BidiFlow[FrameEventOrError, FrameEventOrError, FrameEvent, FrameEvent, NotUsed] = + BidiFlow.fromFlows(inflaterFlow, deflaterFlow) + + def frameEventBidiFlow( + maskRandom: () => Random): BidiFlow[FrameEvent, FrameEvent, FrameEvent, FrameEvent, NotUsed] = + BidiFlow.fromFlows( + Flow[FrameEvent] + .via(Masking.unmaskIf(condition = true)) + .via(inflaterFlow) + .map { + case frame: FrameEvent => frame + case FrameError(ex) => throw ex + } + .via(Masking.maskIf(condition = true, maskRandom)), + deflaterFlow) + + private def inflaterFlow: Flow[FrameEventOrError, FrameEventOrError, NotUsed] = + Flow.fromGraph(new LifecycleMapConcatStage( + "PerMessageDeflate.inflater", + () => new InflaterFlow(clientNoContextTakeover, settings))) + + private def deflaterFlow: Flow[FrameEvent, FrameEvent, NotUsed] = + Flow.fromGraph(new LifecycleMapConcatStage( + "PerMessageDeflate.deflater", + () => new DeflaterFlow(serverNoContextTakeover, settings))) + } + + def negotiate( + requested: immutable.Seq[WebSocketExtension], + settings: WebSocketCompressionSettingsImpl): Option[Negotiated] = { + if (!settings.enabled) None + else { + requested.collectFirst(Function.unlift { extension => + if (extension.name.equalsIgnoreCase(ExtensionName)) negotiate(extension, settings) else None + }) + } + } + + private def negotiate( + extension: WebSocketExtension, + settings: WebSocketCompressionSettingsImpl): Option[Negotiated] = { + var responseParams = ListMap.empty[String, String] + var clientNoContext = false + var serverNoContext = false + var accepted = true + + extension.params.foreach { + case (ClientMaxWindowBits, value) => + if (value.isEmpty) responseParams += ClientMaxWindowBits -> settings.preferredClientWindowSize.toString + else if (validWindowBits(value)) responseParams += ClientMaxWindowBits -> value + else accepted = false + case (ServerMaxWindowBits, value) => + if (value == "15") responseParams += ServerMaxWindowBits -> value + else accepted = false + case (ClientNoContextTakeover, "") => + clientNoContext = settings.preferredClientNoContext + if (clientNoContext) responseParams += ClientNoContextTakeover -> "" + case (ServerNoContextTakeover, "") => + if (settings.allowServerNoContext) { + serverNoContext = true + responseParams += ServerNoContextTakeover -> "" + } else accepted = false + case _ => + accepted = false + } + + if (accepted) { + Some(Negotiated(WebSocketExtension(ExtensionName, responseParams), serverNoContext, clientNoContext, settings)) + } else None + } + + private def validWindowBits(value: String): Boolean = + value.length <= 2 && value.forall(_.isDigit) && { + val parsed = value.toInt + parsed >= 8 && parsed <= 15 + } + + private final class InflaterFlow( + noContextTakeover: Boolean, + settings: WebSocketCompressionSettingsImpl) + extends LifecycleMapConcat[FrameEventOrError, FrameEventOrError] { + private var inflater = new Inflater(true) + private var compressedFrame: Option[CompressedFrame] = None + private var compressedMessageInProgress = false + private var decompressedMessageBytes = 0L + private var bypassFrameInProgress = false + + override def apply(event: FrameEventOrError): immutable.Iterable[FrameEventOrError] = event match { + case start @ FrameStart(header, data) + if header.rsv1 && + (header.opcode == Protocol.Opcode.Text || + header.opcode == Protocol.Opcode.Binary) => + if (compressedMessageInProgress || compressedFrame.isDefined) + throw new ProtocolException("Unexpected data frame while fragmented message is open") + if (header.rsv2 || header.rsv3) throw new ProtocolException("Unexpected reserved bit for compressed message") + compressedMessageInProgress = !header.fin + compressedFrame = Some(CompressedFrame(header.copy(rsv1 = false, length = 0), data, appendTail = header.fin)) + if (start.lastPart) finishFrame() else Nil + case start @ FrameStart(header, _) if bypassFrameInProgress => + throw new ProtocolException(s"Unexpected frame ${header.opcode} while frame data is open") + case start @ FrameStart(header, _) + if (compressedFrame.isDefined || compressedMessageInProgress) && header.opcode.isControl => + bypassFrameInProgress = !start.lastPart + start :: Nil + case start @ FrameStart(header, data) + if compressedMessageInProgress && header.opcode == Protocol.Opcode.Continuation => + if (header.rsv1 || header.rsv2 || header.rsv3) + throw new ProtocolException("Unexpected reserved bit for continuation frame") + compressedMessageInProgress = !header.fin + compressedFrame = Some(CompressedFrame(header.copy(length = 0), data, appendTail = header.fin)) + if (start.lastPart) finishFrame() else Nil + case start @ FrameStart(header, _) if compressedFrame.isDefined || compressedMessageInProgress => + throw new ProtocolException(s"Unexpected frame ${header.opcode} while fragmented message is open") + case data: FrameData if bypassFrameInProgress => + bypassFrameInProgress = !data.lastPart + data :: Nil + case data: FrameData if compressedFrame.isDefined => + compressedFrame = compressedFrame.map(_.append(data.data)) + if (data.lastPart) finishFrame() else Nil + case other => other :: Nil + } + + private def finishFrame(): immutable.Iterable[FrameEventOrError] = { + val frame = compressedFrame.get + compressedFrame = None + val inflated = inflate(frame.data, frame.appendTail) + if (frame.appendTail) decompressedMessageBytes = 0L + if (frame.appendTail && noContextTakeover) { + inflater.end() + inflater = new Inflater(true) + } + FrameStart(frame.header.copy(length = inflated.length), inflated) :: Nil + } + + private def inflate(data: ByteString, appendTail: Boolean): ByteString = { + try { + val input = if (appendTail) data ++ EmptyStoredBlock else data + inflater.setInput(input.toArray) + val output = new ByteArrayOutputStream() + val buffer = new Array[Byte](1024) + var count = inflater.inflate(buffer) + while (count > 0) { + decompressedMessageBytes += count + if (settings.maxAllocation > 0 && decompressedMessageBytes > settings.maxAllocation) + throw new ProtocolException("WebSocket decompressed message exceeds configured maximum allocation") + output.write(buffer, 0, count) + count = inflater.inflate(buffer) + } + ByteString.fromArray(output.toByteArray) + } catch { + case ex: DataFormatException => + throw new ProtocolException(s"Invalid WebSocket compressed message: ${ex.getMessage}") + } + } + + override def close(): Unit = + inflater.end() + } + + private final class DeflaterFlow( + noContextTakeover: Boolean, + settings: WebSocketCompressionSettingsImpl) + extends LifecycleMapConcat[FrameEvent, FrameEvent] { + private var deflater = new Deflater(settings.compressionLevel, true) + private var frame: Option[UncompressedFrame] = None + private var messageInProgress = false + private var bypassFrameInProgress = false + + override def apply(event: FrameEvent): immutable.Iterable[FrameEvent] = event match { + case FrameStart(header, _) + if (header.opcode == Protocol.Opcode.Text || + header.opcode == Protocol.Opcode.Binary) && + (header.rsv1 || header.rsv2 || header.rsv3) => + throw new ProtocolException("Unexpected reserved bit for outbound WebSocket message") + case FrameStart(header, _) + if header.opcode == Protocol.Opcode.Continuation && + (header.rsv1 || header.rsv2 || header.rsv3) => + throw new ProtocolException("Unexpected reserved bit for outbound WebSocket continuation frame") + case start @ FrameStart(header, data) + if header.opcode == Protocol.Opcode.Text || + header.opcode == Protocol.Opcode.Binary => + if (messageInProgress || frame.isDefined) + throw new ProtocolException("Unexpected data frame while fragmented message is open") + messageInProgress = !header.fin + frame = Some(UncompressedFrame(header.copy(length = 0, rsv1 = true), data, removeTail = header.fin)) + if (start.lastPart) finishFrame() else Nil + case start @ FrameStart(header, _) if bypassFrameInProgress => + throw new ProtocolException(s"Unexpected frame ${header.opcode} while frame data is open") + case start @ FrameStart(header, _) if (frame.isDefined || messageInProgress) && header.opcode.isControl => + bypassFrameInProgress = !start.lastPart + start :: Nil + case start @ FrameStart(header, data) if messageInProgress && header.opcode == Protocol.Opcode.Continuation => + messageInProgress = !header.fin + frame = Some(UncompressedFrame(header.copy(length = 0), data, removeTail = header.fin)) + if (start.lastPart) finishFrame() else Nil + case start @ FrameStart(header, _) if frame.isDefined || messageInProgress => + throw new ProtocolException(s"Unexpected frame ${header.opcode} while fragmented message is open") + case data: FrameData if bypassFrameInProgress => + bypassFrameInProgress = !data.lastPart + data :: Nil + case data: FrameData if frame.isDefined => + frame = frame.map(_.append(data.data)) + if (data.lastPart) finishFrame() else Nil + case other => other :: Nil + } + + private def finishFrame(): immutable.Iterable[FrameEvent] = { + val current = frame.get + frame = None + val compressed = deflate(current.data, current.removeTail) + if (current.removeTail && noContextTakeover) { + deflater.end() + deflater = new Deflater(settings.compressionLevel, true) + } + FrameStart(current.header.copy(length = compressed.length), compressed) :: Nil + } + + private def deflate(data: ByteString, removeTail: Boolean): ByteString = { + deflater.setInput(data.toArray) + val output = new ByteArrayOutputStream() + val buffer = new Array[Byte](1024) + var count = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH) + while (count > 0) { + output.write(buffer, 0, count) + count = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH) + } + val bytes = ByteString.fromArray(output.toByteArray) + if (removeTail && bytes.endsWith(EmptyStoredBlock)) bytes.dropRight(EmptyStoredBlock.length) else bytes + } + + override def close(): Unit = + deflater.end() + } + + private trait LifecycleMapConcat[-In, +Out] extends (In => immutable.Iterable[Out]) { + def close(): Unit + } + + private final class LifecycleMapConcatStage[In, Out]( + name: String, + create: () => LifecycleMapConcat[In, Out]) + extends GraphStage[FlowShape[In, Out]] { + private val in = Inlet[In](s"$name.in") + private val out = Outlet[Out](s"$name.out") + override val shape: FlowShape[In, Out] = FlowShape(in, out) + + override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = + new GraphStageLogic(shape) with InHandler with OutHandler { + private val handler = create() + private var pending = Iterator.empty[Out] + private var upstreamFinished = false + + override def onPush(): Unit = { + pending = handler(grab(in)).iterator + pushOrPull() + } + + override def onPull(): Unit = + pushOrPull() + + override def onUpstreamFinish(): Unit = { + upstreamFinished = true + if (!pending.hasNext) completeStage() + } + + override def postStop(): Unit = + handler.close() + + private def pushOrPull(): Unit = + if (pending.hasNext) push(out, pending.next()) + else if (upstreamFinished) completeStage() + else if (!hasBeenPulled(in)) pull(in) + + setHandler(in, this) + setHandler(out, this) + } + } + + private final case class CompressedFrame(header: FrameHeader, data: ByteString, appendTail: Boolean) { + def append(next: ByteString): CompressedFrame = copy(data = data ++ next) + } + + private final case class UncompressedFrame(header: FrameHeader, data: ByteString, removeTail: Boolean) { + def append(next: ByteString): UncompressedFrame = copy(data = data ++ next) + } +} diff --git a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/WebSocket.scala b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/WebSocket.scala index 0ee01d0c48..5d1d37656a 100644 --- a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/WebSocket.scala +++ b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/WebSocket.scala @@ -47,10 +47,12 @@ private[http] object WebSocket { def stack( serverSide: Boolean, websocketSettings: WebSocketSettings, + perMessageDeflate: Option[PerMessageDeflate.Negotiated] = None, closeTimeout: FiniteDuration = 3.seconds, // TODO put close timeout into the settings? log: LoggingAdapter): BidiFlow[FrameEvent, Message, Message, FrameEvent, NotUsed] = masking(serverSide, websocketSettings.randomFactory).atop( FrameLogger.logFramesIfEnabled(websocketSettings.logFrames)).atop( + perMessageDeflate.map(_.bidiFlow).getOrElse(BidiFlow.identity)).atop( frameHandling(serverSide, closeTimeout, log)).atop( periodicKeepAlive(websocketSettings)).atop( messageAPI(serverSide, closeTimeout)) diff --git a/http-core/src/main/scala/org/apache/pekko/http/impl/settings/WebSocketSettingsImpl.scala b/http-core/src/main/scala/org/apache/pekko/http/impl/settings/WebSocketSettingsImpl.scala index 7410699bda..831ca8c973 100644 --- a/http-core/src/main/scala/org/apache/pekko/http/impl/settings/WebSocketSettingsImpl.scala +++ b/http-core/src/main/scala/org/apache/pekko/http/impl/settings/WebSocketSettingsImpl.scala @@ -31,7 +31,8 @@ private[pekko] final case class WebSocketSettingsImpl( periodicKeepAliveMode: String, periodicKeepAliveMaxIdle: Duration, periodicKeepAliveData: () => ByteString, - logFrames: Boolean) + logFrames: Boolean, + compression: WebSocketCompressionSettingsImpl) extends pekko.http.scaladsl.settings.WebSocketSettings { require( @@ -58,21 +59,62 @@ private[pekko] object WebSocketSettingsImpl { // on purpose not extending Settin def serverFromRoot(root: Config): WebSocketSettingsImpl = server(root.getConfig("pekko.http.server.websocket")) def server(config: Config): WebSocketSettingsImpl = - fromConfig(config) + fromConfig(config, WebSocketCompressionSettingsImpl.fromConfig(config.getConfig("compression"))) def clientFromRoot(root: Config): WebSocketSettingsImpl = client(root.getConfig("pekko.http.client.websocket")) def client(config: Config): WebSocketSettingsImpl = - fromConfig(config) + fromConfig(config, WebSocketCompressionSettingsImpl.Disabled) - private def fromConfig(inner: Config): WebSocketSettingsImpl = { + private def fromConfig(inner: Config, compression: WebSocketCompressionSettingsImpl): WebSocketSettingsImpl = { val c = inner WebSocketSettingsImpl( Randoms.SecureRandomInstances, c.getString("periodic-keep-alive-mode"), // mode could be extended to be a factory of pings, if we'd need control over the data field c.getPotentiallyInfiniteDuration("periodic-keep-alive-max-idle"), NoPeriodicKeepAliveData, - c.getBoolean("log-frames")) + c.getBoolean("log-frames"), + compression) } } + +/** INTERNAL API */ +@InternalApi +private[pekko] final case class WebSocketCompressionSettingsImpl( + enabled: Boolean, + maxAllocation: Long, + compressionLevel: Int, + preferredClientWindowSize: Int, + allowServerNoContext: Boolean, + preferredClientNoContext: Boolean) { + require(maxAllocation >= 0, "websocket compression max-allocation must be >= 0") + require(compressionLevel >= 0 && compressionLevel <= 9, "websocket compression level must be between 0 and 9") + require( + preferredClientWindowSize >= 8 && preferredClientWindowSize <= 15, + "websocket compression preferred-client-window-size must be between 8 and 15") +} + +/** INTERNAL API */ +@InternalApi +private[pekko] object WebSocketCompressionSettingsImpl { + val Disabled: WebSocketCompressionSettingsImpl = + WebSocketCompressionSettingsImpl( + enabled = false, + maxAllocation = 0, + compressionLevel = 6, + preferredClientWindowSize = 15, + allowServerNoContext = false, + preferredClientNoContext = false) + + def fromConfig(c: Config): WebSocketCompressionSettingsImpl = { + val perMessageDeflate = c.getConfig("permessage-deflate") + WebSocketCompressionSettingsImpl( + c.getBoolean("enabled"), + c.getBytes("max-allocation"), + perMessageDeflate.getInt("compression-level"), + perMessageDeflate.getInt("preferred-client-window-size"), + perMessageDeflate.getBoolean("allow-server-no-context"), + perMessageDeflate.getBoolean("preferred-client-no-context")) + } +} diff --git a/http-core/src/test/scala/org/apache/pekko/http/impl/engine/parsing/HttpHeaderParserSpec.scala b/http-core/src/test/scala/org/apache/pekko/http/impl/engine/parsing/HttpHeaderParserSpec.scala index 09092a28e0..32b18dba72 100644 --- a/http-core/src/test/scala/org/apache/pekko/http/impl/engine/parsing/HttpHeaderParserSpec.scala +++ b/http-core/src/test/scala/org/apache/pekko/http/impl/engine/parsing/HttpHeaderParserSpec.scala @@ -312,6 +312,9 @@ abstract class HttpHeaderParserSpec(mode: String, newLine: String) extends Pekko parseAndCache(s"Connection: close${newLine}x")(s"CONNECTION: close${newLine}x") shouldEqual Connection("close") parseAndCache(s"Host: spray.io:123${newLine}x")(s"HOST: spray.io:123${newLine}x") shouldEqual Host("spray.io", 123) + parseAndCache(s"Sec-WebSocket-Extensions: permessage-deflate${newLine}x")( + s"SEC-WEBSOCKET-EXTENSIONS: permessage-deflate${newLine}x") shouldEqual `Sec-WebSocket-Extensions`( + Seq(WebSocketExtension("permessage-deflate"))) // don't parse other headers parseAndCache(s"User-Agent: hmpf${newLine}x")(s"USER-AGENT: hmpf${newLine}x") shouldEqual RawHeader("User-Agent", diff --git a/http-core/src/test/scala/org/apache/pekko/http/impl/engine/ws/WebSocketServerSpec.scala b/http-core/src/test/scala/org/apache/pekko/http/impl/engine/ws/WebSocketServerSpec.scala index 82a9c46271..e0a8d334a3 100644 --- a/http-core/src/test/scala/org/apache/pekko/http/impl/engine/ws/WebSocketServerSpec.scala +++ b/http-core/src/test/scala/org/apache/pekko/http/impl/engine/ws/WebSocketServerSpec.scala @@ -13,6 +13,10 @@ package org.apache.pekko.http.impl.engine.ws +import java.io.ByteArrayOutputStream +import java.util.zip.Deflater +import java.util.zip.Inflater + import org.apache.pekko import pekko.actor.ActorSystem import pekko.http.scaladsl.model.ws._ @@ -22,12 +26,15 @@ import pekko.stream.scaladsl.{ Flow, Keep, Sink, Source } import pekko.stream.testkit.Utils import pekko.util.ByteString import pekko.http.impl.engine.server.HttpServerTestSetupBase +import pekko.http.impl.settings.WebSocketSettingsImpl import pekko.http.impl.util.PekkoSpecWithMaterializer import scala.concurrent.duration._ class WebSocketServerSpec extends PekkoSpecWithMaterializer("pekko.http.server.websocket.log-frames = on") { spec => + private val EmptyDeflateBlock = ByteString(0x00) + "The server-side WebSocket integration should" should { "establish a websocket connection when the user requests it" should { "when user handler instantly tries to send messages" in Utils.assertAllStagesStopped { @@ -177,21 +184,1192 @@ class WebSocketServerSpec extends PekkoSpecWithMaterializer("pekko.http.server.w } } "prevent the selection of an unavailable subprotocol" in pending - "reject invalid WebSocket handshakes" should { - "missing `Upgrade: websocket` header" in pending - "missing `Connection: upgrade` header" in pending - "missing `Sec-WebSocket-Key header" in pending - "`Sec-WebSocket-Key` with wrong amount of base64 encoded data" in pending - "missing `Sec-WebSocket-Version` header" in pending - "unsupported `Sec-WebSocket-Version`" in pending - } - } + "support permessage-deflate compression" should { + "negotiate permessage-deflate when requested by the client" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") - class TestSetup extends HttpServerTestSetupBase with WSTestSetupBase { - implicit def system: ActorSystem = spec.system - implicit def materializer: Materializer = spec.materializer + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) - def expectBytes(length: Int): ByteString = netOut.expectBytes(length) - def expectBytes(bytes: ByteString): Unit = netOut.expectBytes(bytes) + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "negotiate permessage-deflate from repeated extension headers" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest( + "Sec-WebSocket-Extensions: unknown-extension\r\n" + + "Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "negotiate permessage-deflate from a fallback offer in the same extension header" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest( + "Sec-WebSocket-Extensions: permessage-deflate; server_max_window_bits=14, permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "inflate inbound permessage-deflate messages before passing them to the application" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSFrame(Protocol.Opcode.Text, deflatePerMessage(ByteString("compressed client message")), fin = true, + mask = true, rsv1 = true) + expectCompressedTextFrame("compressed client message") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "inflate empty inbound permessage-deflate messages before passing them to the application" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSFrame(Protocol.Opcode.Text, EmptyDeflateBlock, fin = true, mask = true, rsv1 = true) + expectCompressedTextFrame("") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "pass uncompressed inbound messages through when permessage-deflate is negotiated" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSFrame(Protocol.Opcode.Text, ByteString("plain client message"), fin = true, mask = true) + expectCompressedTextFrame("plain client message") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "handle fragmented compressed messages with interleaved control frames" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + val (firstPart, secondPart) = + deflatePerMessageFrames(ByteString("fragmented compressed client message"), splitAt = 16) + + sendWSFrame(Protocol.Opcode.Text, firstPart, fin = false, mask = true, rsv1 = true) + sendWSFrame(Protocol.Opcode.Ping, ByteString("ping"), fin = true, mask = true) + val firstResponse = expectCompressedFrame(Protocol.Opcode.Text, fin = false, rsv1 = true) + expectWSFrame(Protocol.Opcode.Pong, ByteString("ping"), fin = true) + sendWSFrame(Protocol.Opcode.Continuation, secondPart, fin = true, mask = true) + + val secondResponse = expectCompressedFrame(Protocol.Opcode.Continuation, fin = false, rsv1 = false) + val finalResponse = expectCompressedFrame(Protocol.Opcode.Continuation, fin = true, rsv1 = false) + inflatePerMessageFrames(Seq(firstResponse, secondResponse, finalResponse)).utf8String shouldEqual + "fragmented compressed client message" + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "not finish fragmented compressed messages at split frame data boundaries" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + val (firstFrame, secondFrame) = + deflatePerMessageFrames(ByteString("split frame data compressed client message"), splitAt = 22) + + sendWSFrameInTwoNetworkChunks(Protocol.Opcode.Text, firstFrame, fin = false, mask = true, rsv1 = true) + sendWSFrame(Protocol.Opcode.Continuation, secondFrame, fin = true, mask = true) + + val firstResponse = expectCompressedFrame(Protocol.Opcode.Text, fin = false, rsv1 = true) + val secondResponse = expectCompressedFrame(Protocol.Opcode.Continuation, fin = false, rsv1 = false) + val finalResponse = expectCompressedFrame(Protocol.Opcode.Continuation, fin = true, rsv1 = false) + inflatePerMessageFrames(Seq(firstResponse, secondResponse, finalResponse)).utf8String shouldEqual + "split frame data compressed client message" + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "finish fragmented compressed messages when the last fragment only completes the inflater" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + val payload = + ByteString.fromArrayUnsafe(hexToBytes( + "677170647a777a737574656b707a787a6f6a7561756578756f6b7868616371716c657a6d64697479766d726f6" + + "269746c6376777464776f6f72767a726f64667278676764687775786f6762766d776d706b76697773777a7072" + + "6a6a737279707a7078697a6c69616d7461656d646278626d786f66666e686e776a7a7461746d7a776668776b6" + + "f6f736e73746575637a6d727a7175707a6e74627578687871767771697a71766c64626d78726d6d7675756877" + + "62667963626b687a726d676e646263776e67797264706d6c6863626577616967706a78636a72697464756e627" + + "977616f79736475676f76736f7178746a7a7479626c64636b6b6778637768746c62")) + val frames = deflatePerMessageFrames(payload, compressionLevel = 9, fragmentCount = 4) + frames.last.length shouldEqual 1 + frames.last.head shouldEqual 1 + + sendWSFrame(Protocol.Opcode.Text, frames(0), fin = false, mask = true, rsv1 = true) + sendWSFrame(Protocol.Opcode.Continuation, frames(1), fin = false, mask = true) + sendWSFrame(Protocol.Opcode.Continuation, frames(2), fin = false, mask = true) + sendWSFrame(Protocol.Opcode.Continuation, frames(3), fin = true, mask = true) + + val firstResponse = expectCompressedFrame(Protocol.Opcode.Text, fin = false, rsv1 = true) + inflatePerMessageFrames(firstResponse +: expectCompressedContinuationFramesUntilFinal()) shouldEqual payload + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "deflate outbound messages sent by the application" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages( + Flow.fromSinkAndSource(Sink.ignore, Source.single(TextMessage.Strict("compressed server message")))) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + expectCompressedTextFrame("compressed server message") + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + closeNetworkInput() + expectNetworkClose() + } + } + + "deflate empty outbound messages sent by the application" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages( + Flow.fromSinkAndSource(Sink.ignore, Source.single(TextMessage.Strict("")))) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + val payload = expectCompressedFrame(Protocol.Opcode.Text, fin = true, rsv1 = true) + payload shouldEqual EmptyDeflateBlock + inflatePerMessage(payload) shouldEqual ByteString.empty + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + closeNetworkInput() + expectNetworkClose() + } + } + + "deflate streamed outbound messages without aggregating the full message" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages( + Flow.fromSinkAndSource( + Sink.ignore, + Source.single(TextMessage(Source(List("streamed ", "server ", "message")))))) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + val firstResponse = expectCompressedFrame(Protocol.Opcode.Text, fin = false, rsv1 = true) + val secondResponse = expectCompressedFrame(Protocol.Opcode.Continuation, fin = false, rsv1 = false) + val thirdResponse = expectCompressedFrame(Protocol.Opcode.Continuation, fin = false, rsv1 = false) + val finalResponse = expectCompressedFrame(Protocol.Opcode.Continuation, fin = true, rsv1 = false) + inflatePerMessageFrames(Seq(firstResponse, secondResponse, thirdResponse, + finalResponse)).utf8String shouldEqual + "streamed server message" + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + closeNetworkInput() + expectNetworkClose() + } + } + + "fail invalid compressed messages with a protocol error" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSFrame(Protocol.Opcode.Text, ByteString(0xFF, 0xFF, 0xFF), fin = true, mask = true, rsv1 = true) + expectWSCloseFrame(Protocol.CloseCodes.ProtocolError) + + closeNetworkInput() + expectNetworkClose() + } + } + + "fail non-continuation frames during fragmented compressed messages with a protocol error" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + val (firstPart, _) = deflatePerMessageFrames(ByteString("fragmented compressed client message"), splitAt = 16) + + sendWSFrame(Protocol.Opcode.Text, firstPart, fin = false, mask = true, rsv1 = true) + sendWSFrame(Protocol.Opcode.Text, ByteString("not a continuation"), fin = true, mask = true) + expectCompressedFrame(Protocol.Opcode.Text, fin = false, rsv1 = true) + expectWSCloseFrame(Protocol.CloseCodes.ProtocolError) + + closeNetworkInput() + expectNetworkClose() + } + } + + "not negotiate permessage-deflate when websocket compression is disabled" in Utils.assertAllStagesStopped { + new TestSetup { + override def settings = { + val defaults = super.settings.websocketSettings.asInstanceOf[WebSocketSettingsImpl] + super.settings.withWebsocketSettings( + defaults.copy(compression = defaults.compression.copy(enabled = false))) + } + + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "negotiate configured permessage-deflate parameters" in Utils.assertAllStagesStopped { + new TestSetup { + override def settings = { + val defaults = super.settings.websocketSettings.asInstanceOf[WebSocketSettingsImpl] + super.settings.withWebsocketSettings( + defaults.copy(compression = defaults.compression.copy( + preferredClientWindowSize = 11, + allowServerNoContext = true, + preferredClientNoContext = true))) + } + + sendWebSocketRequest( + "Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits; server_max_window_bits=15; " + + "client_no_context_takeover; server_no_context_takeover\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits=11; server_max_window_bits=15; client_no_context_takeover; server_no_context_takeover + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "negotiate explicit client_max_window_bits values" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits=12\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits=12 + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "not negotiate invalid client_max_window_bits values" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits=7\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "not negotiate client_no_context_takeover when it was not requested by the client" in + Utils.assertAllStagesStopped { + new TestSetup { + override def settings = { + val defaults = super.settings.websocketSettings.asInstanceOf[WebSocketSettingsImpl] + super.settings.withWebsocketSettings( + defaults.copy(compression = defaults.compression.copy(preferredClientNoContext = true))) + } + + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "not negotiate permessage-deflate when server_no_context_takeover is not allowed" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "reset the server compressor when server_no_context_takeover is negotiated" in Utils.assertAllStagesStopped { + new TestSetup { + override def settings = { + val defaults = super.settings.websocketSettings.asInstanceOf[WebSocketSettingsImpl] + super.settings.withWebsocketSettings( + defaults.copy(compression = defaults.compression.copy(allowServerNoContext = true))) + } + + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val message = "same server message same server message same server message" + val response = upgrade.get.handleMessages( + Flow.fromSinkAndSource( + Sink.ignore, + Source(List(TextMessage.Strict(message), TextMessage.Strict(message))))) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + val firstPayload = expectCompressedFrame(Protocol.Opcode.Text, fin = true, rsv1 = true) + val secondPayload = expectCompressedFrame(Protocol.Opcode.Text, fin = true, rsv1 = true) + firstPayload shouldEqual deflatePerMessage(ByteString(message)) + secondPayload shouldEqual firstPayload + inflatePerMessage(firstPayload).utf8String shouldEqual message + inflatePerMessage(secondPayload).utf8String shouldEqual message + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + closeNetworkInput() + expectNetworkClose() + } + } + + "reset the server decompressor when client_no_context_takeover is negotiated" in Utils.assertAllStagesStopped { + new TestSetup { + override def settings = { + val defaults = super.settings.websocketSettings.asInstanceOf[WebSocketSettingsImpl] + super.settings.withWebsocketSettings( + defaults.copy(compression = defaults.compression.copy(preferredClientNoContext = true))) + } + + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + val Seq(firstPayload, secondPayload) = deflatePerMessagesWithContext( + ByteString("same client message same client message same client message"), + ByteString("same client message same client message same client message")) + + sendWSFrame(Protocol.Opcode.Text, firstPayload, fin = true, mask = true, rsv1 = true) + expectCompressedTextFrame("same client message same client message same client message") + sendWSFrame(Protocol.Opcode.Text, secondPayload, fin = true, mask = true, rsv1 = true) + expectWSCloseFrame(Protocol.CloseCodes.ProtocolError) + + closeNetworkInput() + expectNetworkClose() + } + } + + "not compress server control frames" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.asInstanceOf[UpgradeToWebSocketLowLevel].handleFrames( + Flow.fromSinkAndSource( + Sink.ignore, + Source.single(FrameEvent.fullFrame(Protocol.Opcode.Ping, None, ByteString("server ping"), fin = true)))) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + expectWSFrame(Protocol.Opcode.Ping, ByteString("server ping"), fin = true) + netOut.expectComplete() + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + closeNetworkInput() + } + } + + "fail low-level outbound data frames with reserved bits when compression is negotiated" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.asInstanceOf[UpgradeToWebSocketLowLevel].handleFrames( + Flow.fromSinkAndSource( + Sink.ignore, + Source.single(FrameEvent.fullFrame( + Protocol.Opcode.Text, + None, + ByteString("already compressed"), + fin = true, + rsv1 = true)))) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + netOut.expectError() + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + closeNetworkInput() + } + } + + "fail low-level outbound continuation frames with reserved bits when compression is negotiated" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.asInstanceOf[UpgradeToWebSocketLowLevel].handleFrames( + Flow.fromSinkAndSource( + Sink.ignore, + Source(List( + FrameEvent.fullFrame(Protocol.Opcode.Text, None, ByteString("fragmented "), fin = false), + FrameEvent.fullFrame( + Protocol.Opcode.Continuation, + None, + ByteString("message"), + fin = true, + rsv1 = true))))) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + expectCompressedFrame(Protocol.Opcode.Text, fin = false, rsv1 = true) + netOut.expectError() + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + closeNetworkInput() + } + } + + "not negotiate unsupported server_max_window_bits values" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate; server_max_window_bits=14\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + + "fail compressed messages exceeding max-allocation" in Utils.assertAllStagesStopped { + new TestSetup { + override def settings = { + val defaults = super.settings.websocketSettings.asInstanceOf[WebSocketSettingsImpl] + super.settings.withWebsocketSettings( + defaults.copy(compression = defaults.compression.copy(maxAllocation = 4))) + } + + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message]) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSFrame(Protocol.Opcode.Text, deflatePerMessage(ByteString("too large")), fin = true, mask = true, + rsv1 = true) + expectWSCloseFrame(Protocol.CloseCodes.ProtocolError) + + closeNetworkInput() + expectNetworkClose() + } + } + + "fail fragmented compressed messages exceeding max-allocation cumulatively" in Utils.assertAllStagesStopped { + new TestSetup { + override def settings = { + val defaults = super.settings.websocketSettings.asInstanceOf[WebSocketSettingsImpl] + super.settings.withWebsocketSettings( + defaults.copy(compression = defaults.compression.copy(maxAllocation = 4))) + } + + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages(Flow[Message].mapConcat(_ => Nil: List[Message])) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + val (firstFrame, secondFrame) = deflatePerMessageFrames(ByteString("too large"), splitAt = 4) + + sendWSFrame(Protocol.Opcode.Text, firstFrame, fin = false, mask = true, rsv1 = true) + sendWSFrame(Protocol.Opcode.Continuation, secondFrame, fin = true, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.ProtocolError) + + closeNetworkInput() + expectNetworkClose() + } + } + + "apply compression to low-level frame handlers" in Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.asInstanceOf[UpgradeToWebSocketLowLevel].handleFrames(Flow[FrameEvent].map { + case start @ FrameStart(header, data) => + start.copy(header = header.copy(mask = None), data = FrameEventParser.mask(data, header.mask)) + case other => other + }) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Sec-WebSocket-Extensions: permessage-deflate + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + sendWSFrame(Protocol.Opcode.Text, deflatePerMessage(ByteString("compressed low-level message")), fin = true, + mask = true, rsv1 = true) + expectCompressedTextFrame("compressed low-level message") + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + closeNetworkInput() + expectNetworkClose() + } + } + } + + "reject invalid WebSocket handshakes" should { + "missing `Upgrade: websocket` header" in pending + "missing `Connection: upgrade` header" in pending + "missing `Sec-WebSocket-Key header" in pending + "`Sec-WebSocket-Key` with wrong amount of base64 encoded data" in pending + "missing `Sec-WebSocket-Version` header" in pending + "unsupported `Sec-WebSocket-Version`" in pending + } + } + + class TestSetup extends HttpServerTestSetupBase with WSTestSetupBase { + implicit def system: ActorSystem = spec.system + implicit def materializer: Materializer = spec.materializer + + def expectBytes(length: Int): ByteString = netOut.expectBytes(length) + def expectBytes(bytes: ByteString): Unit = netOut.expectBytes(bytes) + + def sendWebSocketRequest(extraHeader: String): Unit = + send( + s"""GET /chat HTTP/1.1 + |Host: server.example.com + |Upgrade: websocket + |Connection: Upgrade + |Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== + |Origin: http://example.com + |Sec-WebSocket-Version: 13 + |$extraHeader + |""".stripMargin) + + def sendWSFrameInTwoNetworkChunks( + opcode: Protocol.Opcode, + data: ByteString, + fin: Boolean, + mask: Boolean = false, + rsv1: Boolean = false): Unit = { + val maskValue = if (mask) Some(0x12345678) else None + val payload = maskValue match { + case Some(value) => WSTestUtils.maskedBytes(data, value)._1 + case None => data + } + val header = WSTestUtils.frameHeader(opcode, data.length, fin, maskValue, rsv1 = rsv1) + val splitAt = math.max(1, payload.length / 2) + send(header ++ payload.take(splitAt)) + send(payload.drop(splitAt)) + } + + def expectCompressedTextFrame(message: String): Unit = { + val payload = expectCompressedFrame(Protocol.Opcode.Text, fin = true, rsv1 = true) + payload should not be ByteString(message) + inflatePerMessage(payload).utf8String shouldEqual message + } + + def expectCompressedFrame(opcode: Protocol.Opcode, fin: Boolean, rsv1: Boolean): ByteString = { + val header = expectBytes(2) + ((header(0) & Protocol.FIN_MASK) != 0) shouldEqual fin + ((header(0) & Protocol.RSV1_MASK) != 0) shouldEqual rsv1 + (header(0) & Protocol.OP_MASK) shouldEqual opcode.code + (header(1) & Protocol.MASK_MASK) shouldEqual 0 + val length = header(1) & 0x7F + expectBytes(length) + } + + def expectCompressedContinuationFramesUntilFinal(): Seq[ByteString] = { + def read(acc: Vector[ByteString]): Vector[ByteString] = { + val header = expectBytes(2) + val fin = (header(0) & Protocol.FIN_MASK) != 0 + ((header(0) & Protocol.RSV1_MASK) != 0) shouldEqual false + (header(0) & Protocol.OP_MASK) shouldEqual Protocol.Opcode.Continuation.code + (header(1) & Protocol.MASK_MASK) shouldEqual 0 + val length = header(1) & 0x7F + val payload = expectBytes(length) + val next = acc :+ payload + if (fin) next else read(next) + } + + read(Vector.empty) + } + } + + private def deflatePerMessage(data: ByteString): ByteString = { + val deflater = new Deflater(6, true) + try { + deflater.setInput(data.toArray) + val output = new ByteArrayOutputStream() + val buffer = new Array[Byte](256) + var count = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH) + while (count > 0) { + output.write(buffer, 0, count) + count = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH) + } + val compressed = ByteString.fromArray(output.toByteArray) + compressed.dropRight(4) + } finally { + deflater.end() + } + } + + private def deflatePerMessageFrames(data: ByteString, splitAt: Int): (ByteString, ByteString) = { + val deflater = new Deflater(6, true) + try { + val (first, second) = data.splitAt(splitAt) + (deflateFrame(deflater, first, removeTail = false), deflateFrame(deflater, second, removeTail = true)) + } finally { + deflater.end() + } + } + + private def deflatePerMessageFrames( + data: ByteString, + compressionLevel: Int, + fragmentCount: Int): Seq[ByteString] = { + val compressed = deflatePerMessage(data, compressionLevel) + val fragmentLength = compressed.length / (fragmentCount - 1) + (0 until fragmentCount).map { index => + val offset = index * fragmentLength + val length = if (index == fragmentCount - 1) compressed.length - offset else fragmentLength + compressed.slice(offset, offset + length) + } + } + + private def deflatePerMessagesWithContext(messages: ByteString*): Seq[ByteString] = { + val deflater = new Deflater(6, true) + try messages.map(message => deflateFrame(deflater, message, removeTail = true)) + finally { + deflater.end() + } + } + + private def deflatePerMessage(data: ByteString, compressionLevel: Int): ByteString = { + val deflater = new Deflater(compressionLevel, true) + try deflateFrame(deflater, data, removeTail = true) + finally { + deflater.end() + } + } + + private def deflateFrame(deflater: Deflater, data: ByteString, removeTail: Boolean): ByteString = { + deflater.setInput(data.toArray) + val output = new ByteArrayOutputStream() + val buffer = new Array[Byte](256) + var count = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH) + while (count > 0) { + output.write(buffer, 0, count) + count = deflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH) + } + val compressed = ByteString.fromArray(output.toByteArray) + if (removeTail) compressed.dropRight(4) else compressed + } + + private def hexToBytes(hex: String): Array[Byte] = + hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray + + private def inflatePerMessageFrames(frames: Seq[ByteString]): ByteString = { + val inflater = new Inflater(true) + try { + frames.zipWithIndex.foldLeft(ByteString.empty) { + case (inflated, (frame, index)) => + val data = + if (index == frames.length - 1) frame ++ ByteString(0x00, 0x00, 0xFF.toByte, 0xFF.toByte) else frame + inflated ++ inflateFrame(inflater, data) + } + } finally { + inflater.end() + } + } + + private def inflateFrame(inflater: Inflater, data: ByteString): ByteString = { + inflater.setInput(data.toArray) + val output = new ByteArrayOutputStream() + val buffer = new Array[Byte](256) + var count = inflater.inflate(buffer) + while (count > 0) { + output.write(buffer, 0, count) + count = inflater.inflate(buffer) + } + ByteString.fromArray(output.toByteArray) + } + + private def inflatePerMessage(data: ByteString): ByteString = { + val inflater = new Inflater(true) + try { + inflater.setInput((data ++ ByteString(0x00, 0x00, 0xFF.toByte, 0xFF.toByte)).toArray) + val output = new ByteArrayOutputStream() + val buffer = new Array[Byte](256) + var count = inflater.inflate(buffer) + while (count > 0) { + output.write(buffer, 0, count) + count = inflater.inflate(buffer) + } + ByteString.fromArray(output.toByteArray) + } finally { + inflater.end() + } } } From d626ab3f75eb9ff95c8e836b41b1a0d8df2fb22b Mon Sep 17 00:00:00 2001 From: Matthias Kurz Date: Tue, 30 Jun 2026 15:11:23 +0200 Subject: [PATCH 2/2] Allow disabling WebSocket compression per upgrade #61 * Add WebSocketUpgrade overloads that accept a compressionEnabled flag. * Wire accepted WebSockets to decline negotiated compression per request. * Document and test per-upgrade compression disabling. --- .../paradox/server-side/websocket-support.md | 3 + .../websocket-compression.excludes | 3 +- .../pekko/http/impl/engine/ws/Handshake.scala | 22 ++++- .../ws/UpgradeToWebSocketLowLevel.scala | 7 ++ .../javadsl/model/ws/WebSocketUpgrade.scala | 48 +++++++++++ .../scaladsl/model/ws/WebSocketUpgrade.scala | 86 +++++++++++++++++++ .../impl/engine/ws/WebSocketServerSpec.scala | 32 +++++++ 7 files changed, 196 insertions(+), 5 deletions(-) diff --git a/docs/src/main/paradox/server-side/websocket-support.md b/docs/src/main/paradox/server-side/websocket-support.md index a6fc825e9d..f0f8326a6a 100644 --- a/docs/src/main/paradox/server-side/websocket-support.md +++ b/docs/src/main/paradox/server-side/websocket-support.md @@ -168,6 +168,9 @@ The server exposes additional settings for the negotiated extension under `preferred-client-window-size`, `allow-server-no-context`, and `preferred-client-no-context`. See the @ref[configuration reference](../configuration.md) for the complete list of settings and defaults. +If compression is enabled globally, a route can still decline compression for a single accepted WebSocket by using the +`handleMessages` or `handleMessagesWith` overload with `compressionEnabled = false`. + @@@ note The `server_no_context_takeover` and `client_no_context_takeover` extension parameters affect whether compression dictionaries are retained across messages. Retaining context generally improves compression ratio, while disabling diff --git a/http-core/src/main/mima-filters/2.0.x.backwards.excludes/websocket-compression.excludes b/http-core/src/main/mima-filters/2.0.x.backwards.excludes/websocket-compression.excludes index c01eaba7b5..e1742d6fab 100644 --- a/http-core/src/main/mima-filters/2.0.x.backwards.excludes/websocket-compression.excludes +++ b/http-core/src/main/mima-filters/2.0.x.backwards.excludes/websocket-compression.excludes @@ -15,5 +15,6 @@ # specific language governing permissions and limitations # under the License. -# Add server-side WebSocket compression support. +# Add server-side WebSocket compression support and per-upgrade compression control. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.pekko.http.impl.engine.ws.Handshake#Server.buildResponse") +ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.pekko.http.javadsl.model.ws.WebSocketUpgrade.handleMessagesWith") diff --git a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/Handshake.scala b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/Handshake.scala index b3cde0a7ff..e98a900c5f 100644 --- a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/Handshake.scala +++ b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/Handshake.scala @@ -136,20 +136,34 @@ private[http] object Handshake { def handle( handler: Either[Graph[FlowShape[FrameEvent, FrameEvent], Any], Graph[FlowShape[Message, Message], Any]], - subprotocol: Option[String]): HttpResponse = { + subprotocol: Option[String], + compressionEnabled: Boolean): HttpResponse = { require( subprotocol.forall(chosen => clientSupportedSubprotocols.contains(chosen)), s"Tried to choose invalid subprotocol '$subprotocol' which wasn't offered by the client: [${requestedProtocols.mkString(", ")}]") - buildResponse(key.get, handler, subprotocol, perMessageDeflate, settings, log) + val acceptedPerMessageDeflate = if (compressionEnabled) perMessageDeflate else None + buildResponse(key.get, handler, subprotocol, acceptedPerMessageDeflate, settings, log) } def handleFrames( handlerFlow: Graph[FlowShape[FrameEvent, FrameEvent], Any], subprotocol: Option[String]): HttpResponse = - handle(Left(handlerFlow), subprotocol) + handle(Left(handlerFlow), subprotocol, compressionEnabled = true) + + override private[http] def handleFrames( + handlerFlow: Graph[FlowShape[FrameEvent, FrameEvent], Any], + subprotocol: Option[String], + compressionEnabled: Boolean): HttpResponse = + handle(Left(handlerFlow), subprotocol, compressionEnabled) override def handleMessages(handlerFlow: Graph[FlowShape[Message, Message], Any], subprotocol: Option[String] = None): HttpResponse = - handle(Right(handlerFlow), subprotocol) + handle(Right(handlerFlow), subprotocol, compressionEnabled = true) + + override def handleMessages( + handlerFlow: Graph[FlowShape[Message, Message], Any], + subprotocol: Option[String], + compressionEnabled: Boolean): HttpResponse = + handle(Right(handlerFlow), subprotocol, compressionEnabled) } OptionVal.Some(header) } else OptionVal.None diff --git a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/UpgradeToWebSocketLowLevel.scala b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/UpgradeToWebSocketLowLevel.scala index 805dfeb561..ca5fbd1653 100644 --- a/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/UpgradeToWebSocketLowLevel.scala +++ b/http-core/src/main/scala/org/apache/pekko/http/impl/engine/ws/UpgradeToWebSocketLowLevel.scala @@ -42,4 +42,11 @@ private[http] abstract class UpgradeToWebSocketLowLevel extends InternalCustomHe @InternalApi private[http] def handleFrames(handlerFlow: Graph[FlowShape[FrameEvent, FrameEvent], Any], subprotocol: Option[String] = None): HttpResponse + + @InternalApi + private[http] def handleFrames( + handlerFlow: Graph[FlowShape[FrameEvent, FrameEvent], Any], + subprotocol: Option[String], + compressionEnabled: Boolean): HttpResponse = + handleFrames(handlerFlow, subprotocol) } diff --git a/http-core/src/main/scala/org/apache/pekko/http/javadsl/model/ws/WebSocketUpgrade.scala b/http-core/src/main/scala/org/apache/pekko/http/javadsl/model/ws/WebSocketUpgrade.scala index 93ba7eeae0..ddb24aef43 100644 --- a/http-core/src/main/scala/org/apache/pekko/http/javadsl/model/ws/WebSocketUpgrade.scala +++ b/http-core/src/main/scala/org/apache/pekko/http/javadsl/model/ws/WebSocketUpgrade.scala @@ -41,6 +41,17 @@ trait WebSocketUpgrade { */ def handleMessagesWith(handlerFlow: Graph[FlowShape[Message, Message], ? <: Any]): HttpResponse + /** + * Returns a response that can be used to answer a WebSocket handshake request. The connection will afterwards + * use the given handlerFlow to handle WebSocket messages from the client. + * + * The compressionEnabled flag allows declining negotiated WebSocket compression for this accepted WebSocket. + * + * @since 2.0.0 + */ + def handleMessagesWith( + handlerFlow: Graph[FlowShape[Message, Message], ? <: Any], compressionEnabled: Boolean): HttpResponse + /** * Returns a response that can be used to answer a WebSocket handshake request. The connection will afterwards * use the given handlerFlow to handle WebSocket messages from the client. The given subprotocol must be one @@ -48,6 +59,20 @@ trait WebSocketUpgrade { */ def handleMessagesWith(handlerFlow: Graph[FlowShape[Message, Message], ? <: Any], subprotocol: String): HttpResponse + /** + * Returns a response that can be used to answer a WebSocket handshake request. The connection will afterwards + * use the given handlerFlow to handle WebSocket messages from the client. The given subprotocol must be one + * of the ones offered by the client. + * + * The compressionEnabled flag allows declining negotiated WebSocket compression for this accepted WebSocket. + * + * @since 2.0.0 + */ + def handleMessagesWith( + handlerFlow: Graph[FlowShape[Message, Message], ? <: Any], + subprotocol: String, + compressionEnabled: Boolean): HttpResponse + /** * Returns a response that can be used to answer a WebSocket handshake request. The connection will afterwards * use the given inSink to handle WebSocket messages from the client and the given outSource to send messages to the client. @@ -55,6 +80,17 @@ trait WebSocketUpgrade { def handleMessagesWith( inSink: Graph[SinkShape[Message], ? <: Any], outSource: Graph[SourceShape[Message], ? <: Any]): HttpResponse + /** + * Returns a response that can be used to answer a WebSocket handshake request. The connection will afterwards + * use the given inSink to handle WebSocket messages from the client and the given outSource to send messages to the client. + * + * The compressionEnabled flag allows declining negotiated WebSocket compression for this accepted WebSocket. + * + * @since 2.0.0 + */ + def handleMessagesWith(inSink: Graph[SinkShape[Message], ? <: Any], outSource: Graph[SourceShape[Message], ? <: Any], + compressionEnabled: Boolean): HttpResponse + /** * Returns a response that can be used to answer a WebSocket handshake request. The connection will afterwards * use the given inSink to handle WebSocket messages from the client and the given outSource to send messages to the client. @@ -63,4 +99,16 @@ trait WebSocketUpgrade { */ def handleMessagesWith(inSink: Graph[SinkShape[Message], ? <: Any], outSource: Graph[SourceShape[Message], ? <: Any], subprotocol: String): HttpResponse + + /** + * Returns a response that can be used to answer a WebSocket handshake request. The connection will afterwards + * use the given inSink to handle WebSocket messages from the client and the given outSource to send messages to the client. + * + * The given subprotocol must be one of the ones offered by the client. + * The compressionEnabled flag allows declining negotiated WebSocket compression for this accepted WebSocket. + * + * @since 2.0.0 + */ + def handleMessagesWith(inSink: Graph[SinkShape[Message], ? <: Any], outSource: Graph[SourceShape[Message], ? <: Any], + subprotocol: String, compressionEnabled: Boolean): HttpResponse } diff --git a/http-core/src/main/scala/org/apache/pekko/http/scaladsl/model/ws/WebSocketUpgrade.scala b/http-core/src/main/scala/org/apache/pekko/http/scaladsl/model/ws/WebSocketUpgrade.scala index 4a26a9b712..34de701c24 100644 --- a/http-core/src/main/scala/org/apache/pekko/http/scaladsl/model/ws/WebSocketUpgrade.scala +++ b/http-core/src/main/scala/org/apache/pekko/http/scaladsl/model/ws/WebSocketUpgrade.scala @@ -52,6 +52,25 @@ trait WebSocketUpgrade extends jm.ws.WebSocketUpgrade { handlerFlow: Graph[FlowShape[Message, Message], Any], subprotocol: Option[String] = None): HttpResponse + /** + * The high-level interface to create a WebSocket server based on "messages". + * + * Returns a response to return in a request handler that will signal the + * low-level HTTP implementation to upgrade the connection to WebSocket and + * use the supplied handler to handle incoming WebSocket messages. + * + * Optionally, a subprotocol out of the ones requested by the client can be chosen. + * The `compressionEnabled` flag allows declining negotiated WebSocket compression + * for this accepted WebSocket. + * + * @since 2.0.0 + */ + def handleMessages( + handlerFlow: Graph[FlowShape[Message, Message], Any], + subprotocol: Option[String], + compressionEnabled: Boolean): HttpResponse = + handleMessages(handlerFlow, subprotocol) + /** * The high-level interface to create a WebSocket server based on "messages". * @@ -68,6 +87,27 @@ trait WebSocketUpgrade extends jm.ws.WebSocketUpgrade { subprotocol: Option[String] = None): HttpResponse = handleMessages(scaladsl.Flow.fromSinkAndSource(inSink, outSource), subprotocol) + /** + * The high-level interface to create a WebSocket server based on "messages". + * + * Returns a response to return in a request handler that will signal the + * low-level HTTP implementation to upgrade the connection to WebSocket and + * use the supplied inSink to consume messages received from the client and + * the supplied outSource to produce message to sent to the client. + * + * Optionally, a subprotocol out of the ones requested by the client can be chosen. + * The `compressionEnabled` flag allows declining negotiated WebSocket compression + * for this accepted WebSocket. + * + * @since 2.0.0 + */ + def handleMessagesWithSinkSource( + inSink: Graph[SinkShape[Message], Any], + outSource: Graph[SourceShape[Message], Any], + subprotocol: Option[String], + compressionEnabled: Boolean): HttpResponse = + handleMessages(scaladsl.Flow.fromSinkAndSource(inSink, outSource), subprotocol, compressionEnabled) + import scala.jdk.CollectionConverters._ /** @@ -81,6 +121,16 @@ trait WebSocketUpgrade extends jm.ws.WebSocketUpgrade { def handleMessagesWith(handlerFlow: Graph[FlowShape[jm.ws.Message, jm.ws.Message], ? <: Any]): HttpResponse = handleMessages(JavaMapping.toScala(handlerFlow)) + /** + * Java API + * + * @since 2.0.0 + */ + def handleMessagesWith( + handlerFlow: Graph[FlowShape[jm.ws.Message, jm.ws.Message], ? <: Any], + compressionEnabled: Boolean): HttpResponse = + handleMessages(JavaMapping.toScala(handlerFlow), None, compressionEnabled) + /** * Java API */ @@ -88,6 +138,17 @@ trait WebSocketUpgrade extends jm.ws.WebSocketUpgrade { handlerFlow: Graph[FlowShape[jm.ws.Message, jm.ws.Message], ? <: Any], subprotocol: String): HttpResponse = handleMessages(JavaMapping.toScala(handlerFlow), subprotocol = Some(subprotocol)) + /** + * Java API + * + * @since 2.0.0 + */ + def handleMessagesWith( + handlerFlow: Graph[FlowShape[jm.ws.Message, jm.ws.Message], ? <: Any], + subprotocol: String, + compressionEnabled: Boolean): HttpResponse = + handleMessages(JavaMapping.toScala(handlerFlow), subprotocol = Some(subprotocol), compressionEnabled) + /** * Java API */ @@ -97,6 +158,19 @@ trait WebSocketUpgrade extends jm.ws.WebSocketUpgrade { /** * Java API + * + * @since 2.0.0 + */ + def handleMessagesWith( + inSink: Graph[SinkShape[jm.ws.Message], ? <: Any], + outSource: Graph[SourceShape[jm.ws.Message], ? <: Any], + compressionEnabled: Boolean): HttpResponse = + handleMessages(createScalaFlow(inSink, outSource), None, compressionEnabled) + + /** + * Java API + * + * @since 2.0.0 */ def handleMessagesWith( inSink: Graph[SinkShape[jm.ws.Message], ? <: Any], @@ -104,6 +178,18 @@ trait WebSocketUpgrade extends jm.ws.WebSocketUpgrade { subprotocol: String): HttpResponse = handleMessages(createScalaFlow(inSink, outSource), subprotocol = Some(subprotocol)) + /** + * Java API + * + * @since 2.0.0 + */ + def handleMessagesWith( + inSink: Graph[SinkShape[jm.ws.Message], ? <: Any], + outSource: Graph[SourceShape[jm.ws.Message], ? <: Any], + subprotocol: String, + compressionEnabled: Boolean): HttpResponse = + handleMessages(createScalaFlow(inSink, outSource), subprotocol = Some(subprotocol), compressionEnabled) + private def createScalaFlow(inSink: Graph[SinkShape[jm.ws.Message], ? <: Any], outSource: Graph[SourceShape[jm.ws.Message], ? <: Any]): Graph[FlowShape[Message, Message], NotUsed] = JavaMapping.toScala(scaladsl.Flow.fromSinkAndSourceMat(inSink, outSource)(scaladsl.Keep.none): Graph[FlowShape[ diff --git a/http-core/src/test/scala/org/apache/pekko/http/impl/engine/ws/WebSocketServerSpec.scala b/http-core/src/test/scala/org/apache/pekko/http/impl/engine/ws/WebSocketServerSpec.scala index e0a8d334a3..0146cc9488 100644 --- a/http-core/src/test/scala/org/apache/pekko/http/impl/engine/ws/WebSocketServerSpec.scala +++ b/http-core/src/test/scala/org/apache/pekko/http/impl/engine/ws/WebSocketServerSpec.scala @@ -695,6 +695,38 @@ class WebSocketServerSpec extends PekkoSpecWithMaterializer("pekko.http.server.w } } + "not negotiate permessage-deflate when compression is disabled for the accepted WebSocket" in + Utils.assertAllStagesStopped { + new TestSetup { + sendWebSocketRequest("Sec-WebSocket-Extensions: permessage-deflate\r\n") + + val request = expectRequest() + val upgrade = request.attribute(webSocketUpgrade) + val response = upgrade.get.handleMessages( + Flow.fromSinkAndSource(Sink.ignore, Source.single(TextMessage.Strict("plain server message"))), + None, + compressionEnabled = false) + responses.sendNext(response) + + expectResponseWithWipedDate( + """HTTP/1.1 101 Switching Protocols + |Upgrade: websocket + |Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= + |Server: pekko-http/test + |Date: XXXX + |Connection: upgrade + | + |""") + + expectWSFrame(Protocol.Opcode.Text, ByteString("plain server message"), fin = true) + expectWSCloseFrame(Protocol.CloseCodes.Regular) + + sendWSCloseFrame(Protocol.CloseCodes.Regular, mask = true) + closeNetworkInput() + expectNetworkClose() + } + } + "negotiate configured permessage-deflate parameters" in Utils.assertAllStagesStopped { new TestSetup { override def settings = {