diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 0000000..20029ea --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,27 @@ +version="2.7.5" +align = none +align.openParenCallSite = false +align.openParenDefnSite = false +align.tokens = [] +assumeStandardLibraryStripMargin = false +binPack.parentConstructors = false +continuationIndent.callSite = 2 +continuationIndent.defnSite = 2 +danglingParentheses = true +docstrings = ScalaDoc +docstrings.blankFirstLine = yes +encoding = UTF-8 +importSelectors = singleLine +includeCurlyBraceInSelectChains = true +indentOperator = spray +lineEndings = unix +maxColumn = 80 +newlines.alwaysBeforeTopLevelStatements = true +newlines.sometimesBeforeColonInMethodReturnType = false +optIn.breakChainOnFirstMethodDot = true +rewrite.rules = [ + PreferCurlyFors +] +spaces { + inImportCurlyBraces = false +} diff --git a/benchmarks/src/main/scala/fs2/netty/benchmarks/echo/Fs2Netty.scala b/benchmarks/src/main/scala/fs2/netty/benchmarks/echo/Fs2Netty.scala index 4ac553b..d1ba27b 100644 --- a/benchmarks/src/main/scala/fs2/netty/benchmarks/echo/Fs2Netty.scala +++ b/benchmarks/src/main/scala/fs2/netty/benchmarks/echo/Fs2Netty.scala @@ -29,7 +29,7 @@ object Fs2Netty extends IOApp { val port = Port(args(1).toInt).get val rsrc = Network[IO] flatMap { net => - val handlers = net.server(host, port) map { client => + val handlers = net.server(host, port, options = Nil) map { client => client.reads.through(client.writes).attempt.void } diff --git a/build.sbt b/build.sbt index c96fe33..cc6635e 100644 --- a/build.sbt +++ b/build.sbt @@ -25,7 +25,7 @@ ThisBuild / organizationName := "Typelevel" ThisBuild / startYear := Some(2021) -ThisBuild / crossScalaVersions := Seq("2.12.12", "2.13.4", "3.0.0-M3") +ThisBuild / crossScalaVersions := Seq("2.12.12", "2.13.4") // "3.0.0-M3" temporarily removed to easier/speedier ThisBuild / githubWorkflowOSes ++= Seq("macos-latest", "windows-latest") diff --git a/core/src/main/scala/fs2/netty/Socket.scala b/core/src/main/scala/fs2/netty/NettyChannelInitializer.scala similarity index 56% rename from core/src/main/scala/fs2/netty/Socket.scala rename to core/src/main/scala/fs2/netty/NettyChannelInitializer.scala index 018e949..f080980 100644 --- a/core/src/main/scala/fs2/netty/Socket.scala +++ b/core/src/main/scala/fs2/netty/NettyChannelInitializer.scala @@ -14,18 +14,20 @@ * limitations under the License. 
*/ -package fs2 -package netty +package fs2.netty -import com.comcast.ip4s.{IpAddress, SocketAddress} +import fs2.netty.pipeline.socket.Socket +import io.netty.channel.socket.SocketChannel +import io.netty.channel.{Channel, ChannelInitializer} -trait Socket[F[_]] { +trait NettyChannelInitializer[F[_], O, I] { - def localAddress: F[SocketAddress[IpAddress]] - def remoteAddress: F[SocketAddress[IpAddress]] + def toSocketChannelInitializer( + cb: Socket[F, O, I] => F[Unit] + ): F[ChannelInitializer[SocketChannel]] = + toChannelInitializer[SocketChannel](cb) - def reads: Stream[F, Byte] - - def write(bytes: Chunk[Byte]): F[Unit] - def writes: Pipe[F, Byte, INothing] + def toChannelInitializer[C <: Channel]( + cb: Socket[F, O, I] => F[Unit] + ): F[ChannelInitializer[C]] } diff --git a/core/src/main/scala/fs2/netty/Network.scala b/core/src/main/scala/fs2/netty/Network.scala index 43e1364..8d7c635 100644 --- a/core/src/main/scala/fs2/netty/Network.scala +++ b/core/src/main/scala/fs2/netty/Network.scala @@ -17,117 +17,183 @@ package fs2 package netty -import cats.effect.{Async, Concurrent, Resource, Sync} import cats.effect.std.{Dispatcher, Queue} +import cats.effect.{Async, Concurrent, Resource, Sync} import cats.syntax.all._ - import com.comcast.ip4s.{Host, IpAddress, Port, SocketAddress} - +import fs2.netty.pipeline.NettyPipeline +import fs2.netty.pipeline.socket.Socket import io.netty.bootstrap.{Bootstrap, ServerBootstrap} -import io.netty.channel.{Channel, ChannelInitializer, ChannelOption => JChannelOption, EventLoopGroup, ServerChannel} -import io.netty.channel.socket.SocketChannel +import io.netty.buffer.ByteBuf +import io.netty.channel.{Channel, EventLoopGroup, ServerChannel, ChannelOption => JChannelOption} import java.net.InetSocketAddress import java.util.concurrent.ThreadFactory import java.util.concurrent.atomic.AtomicInteger +// TODO: Do we need to distinguish between TCP (connection based network) and UDP (connection-less network)? final class Network[F[_]: Async] private ( - parent: EventLoopGroup, - child: EventLoopGroup, - clientChannelClazz: Class[_ <: Channel], - serverChannelClazz: Class[_ <: ServerChannel]) { + parent: EventLoopGroup, // TODO: custom value class? 
+ child: EventLoopGroup, + clientChannelClazz: Class[_ <: Channel], + serverChannelClazz: Class[_ <: ServerChannel] +) { def client( - addr: SocketAddress[Host], - options: List[ChannelOption] = Nil) - : Resource[F, Socket[F]] = - Dispatcher[F] flatMap { disp => - Resource suspend { - Concurrent[F].deferred[Socket[F]] flatMap { d => - addr.host.resolve[F] flatMap { resolved => - Sync[F] delay { - val bootstrap = new Bootstrap - bootstrap.group(child) - .channel(clientChannelClazz) - .option(JChannelOption.AUTO_READ.asInstanceOf[JChannelOption[Any]], false) // backpressure - .handler(initializer(disp)(d.complete(_).void)) - - options.foreach(opt => bootstrap.option(opt.key, opt.value)) - - val connectChannel = Sync[F] defer { - val cf = bootstrap.connect(resolved.toInetAddress, addr.port.value) - fromNettyFuture[F](cf.pure[F]).as(cf.channel()) - } - - Resource.make(connectChannel <* d.get)(ch => fromNettyFuture(Sync[F].delay(ch.close())).void).evalMap(_ => d.get) - } - } + addr: SocketAddress[Host], + options: List[ChannelOption] + ): Resource[F, Socket[F, ByteBuf, ByteBuf]] = + for { + disp <- Dispatcher[F] + pipeline <- Resource.eval(NettyPipeline(disp)) + c <- client(addr, pipeline, options) + } yield c + + def client[O, I]( + addr: SocketAddress[Host], + pipelineInitializer: NettyChannelInitializer[F, O, I], + options: List[ChannelOption] + ): Resource[F, Socket[F, O, I]] = + Resource.suspend { + for { + futureSocket <- Concurrent[F].deferred[Socket[F, O, I]] + + initializer <- pipelineInitializer.toSocketChannelInitializer( + futureSocket.complete(_).void + ) + + resolvedHost <- addr.host.resolve[F] + + bootstrap <- Sync[F].delay { + val bootstrap = new Bootstrap + bootstrap + .group(child) + .channel(clientChannelClazz) + .option( + JChannelOption.AUTO_READ.asInstanceOf[JChannelOption[Any]], + false + ) // backpressure TODO: backpressure creating the connection or is this reads? + .handler(initializer) + + options.foreach(opt => bootstrap.option(opt.key, opt.value)) + bootstrap } - } + + // TODO: Log properly as info, debug, or trace. Or send as an event to another stream. Maybe the whole network could have an event stream. + _ <- Sync[F].delay(println(bootstrap.config())) + + connectChannel = Sync[F] defer { + val cf = + bootstrap.connect(resolvedHost.toInetAddress, addr.port.value) + fromNettyFuture[F](cf.pure[F]).as(cf.channel()) + } + } yield Resource + .make(connectChannel <* futureSocket.get)(ch => + fromNettyFuture(Sync[F].delay(ch.close())).void + ) + .evalMap(_ => futureSocket.get) } + //TODO: Add back default args for opts, removed to fix compilation error for overloaded method def server( - host: Option[Host], - port: Port, - options: List[ChannelOption] = Nil) - : Stream[F, Socket[F]] = + host: Option[Host], + port: Port, + options: List[ChannelOption] + ): Stream[F, Socket[F, ByteBuf, ByteBuf]] = Stream.resource(serverResource(host, Some(port), options)).flatMap(_._2) + // TODO: maybe here it's nicer to have the I first then O?, or will that be confusing if Socket has reversed order? 
+ def server[O, I: Socket.Decoder]( + host: Option[Host], + port: Port, + pipelineInitializer: NettyChannelInitializer[F, O, I], + options: List[ChannelOption] + ): Stream[F, Socket[F, O, I]] = + Stream + .resource( + serverResource[O, I](host, Some(port), pipelineInitializer, options) + ) + .flatMap(_._2) + def serverResource( - host: Option[Host], - port: Option[Port], - options: List[ChannelOption] = Nil) - : Resource[F, (SocketAddress[IpAddress], Stream[F, Socket[F]])] = - Dispatcher[F] flatMap { disp => - Resource suspend { - Queue.unbounded[F, Socket[F]] flatMap { sockets => - host.traverse(_.resolve[F]) flatMap { resolved => - Sync[F] delay { - val bootstrap = new ServerBootstrap - bootstrap.group(parent, child) - .option(JChannelOption.AUTO_READ.asInstanceOf[JChannelOption[Any]], false) // backpressure - .channel(serverChannelClazz) - .childHandler(initializer(disp)(sockets.offer)) - - options.foreach(opt => bootstrap.option(opt.key, opt.value)) - - val connectChannel = Sync[F] defer { - val cf = bootstrap.bind( - resolved.map(_.toInetAddress).orNull, - port.map(_.value).getOrElse(0)) - fromNettyFuture[F](cf.pure[F]).as(cf.channel()) - } - - val connection = Resource.make(connectChannel) { ch => - fromNettyFuture[F](Sync[F].delay(ch.close())).void - } - - connection evalMap { ch => - Sync[F].delay(SocketAddress.fromInetSocketAddress(ch.localAddress().asInstanceOf[InetSocketAddress])).tupleRight( - Stream.repeatEval(Sync[F].delay(ch.read()) *> sockets.take)) - } - } - } + host: Option[Host], + port: Option[Port], + options: List[ChannelOption] + ): Resource[ + F, + (SocketAddress[IpAddress], Stream[F, Socket[F, ByteBuf, ByteBuf]]) + ] = + for { + dispatcher <- Dispatcher[F] + pipeline <- Resource.eval(NettyPipeline[F](dispatcher)) + sr <- serverResource(host, port, pipeline, options) + } yield sr + + def serverResource[O, I: Socket.Decoder]( + host: Option[Host], + port: Option[Port], + pipelineInitializer: NettyChannelInitializer[F, O, I], + options: List[ChannelOption] + ): Resource[F, (SocketAddress[IpAddress], Stream[F, Socket[F, O, I]])] = + Resource suspend { + for { + clientConnections <- Queue.unbounded[F, Socket[F, O, I]] + + resolvedHost <- host.traverse(_.resolve[F]) + + socketInitializer <- pipelineInitializer.toSocketChannelInitializer( + clientConnections.offer + ) + + bootstrap <- Sync[F] delay { + val bootstrap = new ServerBootstrap + bootstrap + .group(parent, child) + .option( + JChannelOption.AUTO_READ.asInstanceOf[JChannelOption[Any]], + false + ) // backpressure for accepting connections, not reads on any individual connection + //.childOption() TODO: Any useful ones? + .channel(serverChannelClazz) + .childHandler(socketInitializer) + + options.foreach(opt => bootstrap.option(opt.key, opt.value)) + bootstrap } - } - } - private[this] def initializer( - disp: Dispatcher[F])( - result: Socket[F] => F[Unit]) - : ChannelInitializer[SocketChannel] = - new ChannelInitializer[SocketChannel] { - def initChannel(ch: SocketChannel) = { - val p = ch.pipeline() - ch.config().setAutoRead(false) - - disp unsafeRunAndForget { - SocketHandler[F](disp, ch) flatMap { s => - Sync[F].delay(p.addLast(s)) *> result(s) - } + // TODO: Log properly as info, debug, or trace. Also can print localAddress + _ <- Sync[F].delay(println(bootstrap.config())) + + // TODO: is the right name? Bind uses the parent ELG that calla TCP accept which yields a connection to child ELG? 
+ tcpAcceptChannel = Sync[F] defer { + val cf = bootstrap.bind( + resolvedHost.map(_.toInetAddress).orNull, + port.map(_.value).getOrElse(0) + ) + fromNettyFuture[F](cf.pure[F]).as(cf.channel()) + } + } yield Resource + .make(tcpAcceptChannel) { ch => + fromNettyFuture[F](Sync[F].delay(ch.close())).void + } + .evalMap { ch => + Sync[F] + .delay( + SocketAddress.fromInetSocketAddress( + ch.localAddress().asInstanceOf[InetSocketAddress] + ) + ) + .tupleRight( + Stream.repeatEval( + Sync[F].delay(ch.read()) *> clientConnections.take + ) + ) } - } } + + implicit val decoder: Socket.Decoder[Byte] = new Socket.Decoder[Byte] { + override def decode(x: AnyRef): Either[String, Byte] = ??? + } } object Network { @@ -135,23 +201,33 @@ object Network { private[this] val (eventLoopClazz, serverChannelClazz, clientChannelClazz) = { val (e, s, c) = uring().orElse(epoll()).orElse(kqueue()).getOrElse(nio()) - (e, s.asInstanceOf[Class[_ <: ServerChannel]], c.asInstanceOf[Class[_ <: Channel]]) + ( + e, + s.asInstanceOf[Class[_ <: ServerChannel]], + c.asInstanceOf[Class[_ <: Channel]] + ) } def apply[F[_]: Async]: Resource[F, Network[F]] = { // TODO configure threads def instantiate(name: String) = Sync[F] delay { - val constr = eventLoopClazz.getDeclaredConstructor(classOf[Int], classOf[ThreadFactory]) - val result = constr.newInstance(new Integer(1), new ThreadFactory { - private val ctr = new AtomicInteger(0) - def newThread(r: Runnable): Thread = { - val t = new Thread(r) - t.setDaemon(true) - t.setName(s"fs2-netty-$name-io-worker-${ctr.getAndIncrement()}") - t.setPriority(Thread.MAX_PRIORITY) - t + val constr = eventLoopClazz.getDeclaredConstructor( + classOf[Int], + classOf[ThreadFactory] + ) + val result = constr.newInstance( + new Integer(1), + new ThreadFactory { + private val ctr = new AtomicInteger(0) + def newThread(r: Runnable): Thread = { + val t = new Thread(r) + t.setDaemon(true) + t.setName(s"fs2-netty-$name-io-worker-${ctr.getAndIncrement()}") + t.setPriority(Thread.MAX_PRIORITY) + t + } } - }) + ) result.asInstanceOf[EventLoopGroup] } @@ -164,7 +240,10 @@ object Network { (instantiateR("server"), instantiateR("client")) mapN { (server, client) => try { val meth = eventLoopClazz.getDeclaredMethod("setIoRatio", classOf[Int]) - meth.invoke(server, new Integer(90)) // TODO tweak this a bit more; 100 was worse than 50 and 90 was a dramatic step up from both + meth.invoke( + server, + new Integer(90) + ) // TODO tweak this a bit more; 100 was worse than 50 and 90 was a dramatic step up from both meth.invoke(client, new Integer(90)) } catch { case _: Exception => () @@ -176,13 +255,27 @@ object Network { private[this] def uring() = try { - if (sys.props.get("fs2.netty.use.io_uring").map(_.toBoolean).getOrElse(false)) { + if ( + sys.props + .get("fs2.netty.use.io_uring") + .map(_.toBoolean) + .getOrElse(false) + ) { Class.forName("io.netty.incubator.channel.uring.IOUringEventLoop") - Some(( - Class.forName("io.netty.incubator.channel.uring.IOUringEventLoopGroup"), - Class.forName("io.netty.incubator.channel.uring.IOUringServerSocketChannel"), - Class.forName("io.netty.incubator.channel.uring.IOUringSocketChannel"))) + Some( + ( + Class.forName( + "io.netty.incubator.channel.uring.IOUringEventLoopGroup" + ), + Class.forName( + "io.netty.incubator.channel.uring.IOUringServerSocketChannel" + ), + Class.forName( + "io.netty.incubator.channel.uring.IOUringSocketChannel" + ) + ) + ) } else { None } @@ -194,10 +287,13 @@ object Network { try { Class.forName("io.netty.channel.epoll.EpollEventLoop") - 
Some(( - Class.forName("io.netty.channel.epoll.EpollEventLoopGroup"), - Class.forName("io.netty.channel.epoll.EpollServerSocketChannel"), - Class.forName("io.netty.channel.epoll.EpollSocketChannel"))) + Some( + ( + Class.forName("io.netty.channel.epoll.EpollEventLoopGroup"), + Class.forName("io.netty.channel.epoll.EpollServerSocketChannel"), + Class.forName("io.netty.channel.epoll.EpollSocketChannel") + ) + ) } catch { case _: Throwable => None } @@ -206,10 +302,13 @@ object Network { try { Class.forName("io.netty.channel.kqueue.KQueueEventLoop") - Some(( - Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup"), - Class.forName("io.netty.channel.kqueue.KQueueServerSocketChannel"), - Class.forName("io.netty.channel.kqueue.KQueueSocketChannel"))) + Some( + ( + Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup"), + Class.forName("io.netty.channel.kqueue.KQueueServerSocketChannel"), + Class.forName("io.netty.channel.kqueue.KQueueSocketChannel") + ) + ) } catch { case _: Throwable => None } @@ -218,5 +317,6 @@ object Network { ( Class.forName("io.netty.channel.nio.NioEventLoopGroup"), Class.forName("io.netty.channel.socket.nio.NioServerSocketChannel"), - Class.forName("io.netty.channel.socket.nio.NioSocketChannel")) + Class.forName("io.netty.channel.socket.nio.NioSocketChannel") + ) } diff --git a/core/src/main/scala/fs2/netty/SocketHandler.scala b/core/src/main/scala/fs2/netty/SocketHandler.scala deleted file mode 100644 index 1e053cc..0000000 --- a/core/src/main/scala/fs2/netty/SocketHandler.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright 2021 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package fs2 -package netty - -import cats.{Applicative, Functor} -import cats.effect.{Async, Poll, Sync} -import cats.effect.std.{Dispatcher, Queue} -import cats.syntax.all._ - -import com.comcast.ip4s.{IpAddress, SocketAddress} - -import io.netty.buffer.{ByteBuf, Unpooled} -import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} -import io.netty.channel.socket.SocketChannel - -private final class SocketHandler[F[_]: Async] ( - disp: Dispatcher[F], - channel: SocketChannel, - bufs: Queue[F, AnyRef]) // ByteBuf | Throwable | Null - extends ChannelInboundHandlerAdapter - with Socket[F] { - - val localAddress: F[SocketAddress[IpAddress]] = - Sync[F].delay(SocketAddress.fromInetSocketAddress(channel.localAddress())) - - val remoteAddress: F[SocketAddress[IpAddress]] = - Sync[F].delay(SocketAddress.fromInetSocketAddress(channel.remoteAddress())) - - private[this] def take(poll: Poll[F]): F[ByteBuf] = - poll(bufs.take) flatMap { - case null => Applicative[F].pure(null) // EOF marker - case buf: ByteBuf => buf.pure[F] - case t: Throwable => t.raiseError[F, ByteBuf] - } - - private[this] val fetch: Stream[F, ByteBuf] = - Stream.bracketFull[F, ByteBuf](poll => Sync[F].delay(channel.read()) *> take(poll)) { (b, _) => - if (b != null) - Sync[F].delay(b.release()).void - else - Applicative[F].unit - } - - lazy val reads: Stream[F, Byte] = - Stream force { - Functor[F].ifF(isOpen)( - fetch.flatMap(b => if (b == null) Stream.empty else Stream.chunk(toChunk(b))) ++ reads, - Stream.empty) - } - - def write(bytes: Chunk[Byte]): F[Unit] = - fromNettyFuture[F](Sync[F].delay(channel.writeAndFlush(toByteBuf(bytes)))).void - - val writes: Pipe[F, Byte, INothing] = - _.chunks.evalMap(c => write(c) *> isOpen).takeWhile(b => b).drain - - private[this] val isOpen: F[Boolean] = - Sync[F].delay(channel.isOpen()) - - override def channelRead(ctx: ChannelHandlerContext, msg: AnyRef) = - disp.unsafeRunAndForget(bufs.offer(msg)) - - override def exceptionCaught(ctx: ChannelHandlerContext, t: Throwable) = - disp.unsafeRunAndForget(bufs.offer(t)) - - override def channelInactive(ctx: ChannelHandlerContext) = - try { - disp.unsafeRunAndForget(bufs.offer(null)) - } catch { - case _: IllegalStateException => () // sometimes we can see this due to race conditions in shutdown - } - - private[this] def toByteBuf(chunk: Chunk[Byte]): ByteBuf = - chunk match { - case Chunk.ArraySlice(arr, off, len) => - Unpooled.wrappedBuffer(arr, off, len) - - case c: Chunk.ByteBuffer => - Unpooled.wrappedBuffer(c.toByteBuffer) - - case c => - Unpooled.wrappedBuffer(c.toArray) - } - - private[this] def toChunk(buf: ByteBuf): Chunk[Byte] = - if (buf.hasArray()) - Chunk.array(buf.array()) - else if (buf.nioBufferCount() > 0) - Chunk.byteBuffer(buf.nioBuffer()) - else - ??? -} - -private object SocketHandler { - def apply[F[_]: Async](disp: Dispatcher[F], channel: SocketChannel): F[SocketHandler[F]] = - Queue.unbounded[F, AnyRef] map { bufs => - new SocketHandler(disp, channel, bufs) - } -} diff --git a/core/src/main/scala/fs2/netty/embedded/EmbeddedChannelWithAutoRead.scala b/core/src/main/scala/fs2/netty/embedded/EmbeddedChannelWithAutoRead.scala new file mode 100644 index 0000000..fa67f68 --- /dev/null +++ b/core/src/main/scala/fs2/netty/embedded/EmbeddedChannelWithAutoRead.scala @@ -0,0 +1,142 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2.netty.embedded + +import io.netty.buffer.ByteBuf +import io.netty.channel.embedded.EmbeddedChannel +import io.netty.channel.{ChannelFuture, ChannelPromise} +import io.netty.util.ReferenceCountUtil + +import java.nio.channels.ClosedChannelException +import java.util + +// Based off of https://github.com/netty/netty/pull/9935/files - WARNING: Java code below +// Should be committed back upstream +class EmbeddedChannelWithAutoRead extends EmbeddedChannel { + + /** + * Used to simulate socket buffers. When autoRead is false, all inbound information will be temporarily stored here. + */ + private lazy val tempInboundMessages = + new util.ArrayDeque[util.AbstractMap.SimpleEntry[Any, ChannelPromise]]() + + def areInboundMessagesBuffered: Boolean = !tempInboundMessages.isEmpty + + def writeInboundFixed(msgs: Any*): Boolean = { + ensureOpen() + if (msgs.isEmpty) + return !inboundMessages().isEmpty + + if (!config().isAutoRead) { + msgs.foreach(msg => + tempInboundMessages.add( + new util.AbstractMap.SimpleEntry[Any, ChannelPromise](msg, null) + ) + ) + return false + } + + val p = pipeline + for (m <- msgs) { + p.fireChannelRead(m) + } + + flushInbound() + !inboundMessages().isEmpty + } + + override def writeOneInbound( + msg: Any, + promise: ChannelPromise + ): ChannelFuture = { + val (isChannelOpen, exception) = + if (isOpen) + (true, null) + else (false, new ClosedChannelException) + + if (isChannelOpen) { + if (!config().isAutoRead) { + tempInboundMessages.add( + new util.AbstractMap.SimpleEntry[Any, ChannelPromise](msg, promise) + ) + return promise + } else + pipeline().fireChannelRead(msg) + } + + if (exception == null) + promise.setSuccess() + else + promise.setFailure(exception) + } + + override def doClose(): Unit = { + super.doClose() + if (!tempInboundMessages.isEmpty) { + var exception: ClosedChannelException = null; + while (true) { + val entry = tempInboundMessages.poll() + if (entry == null) { + return + } + val value = entry.getKey; + if (value != null) { + ReferenceCountUtil.release(value); + } + val promise: ChannelPromise = entry.getValue; + if (promise != null) { + if (exception == null) { + exception = new ClosedChannelException(); + } + promise.tryFailure(exception); + } + } + } + } + + override def doBeginRead(): Unit = { + if (!tempInboundMessages.isEmpty) { + while (true) { + val pair = tempInboundMessages.poll(); + if (pair == null) { + return + } + + val msg = pair.getKey; + if (msg != null) { +// println(s"Firing read ${debug(msg)}") + pipeline().fireChannelRead(msg) + } + + val promise = pair.getValue + if (promise != null) { + try { + checkException() + promise.setSuccess() + } catch { + case e: Throwable => + promise.setFailure(e) + } + } + } + + // fire channelReadComplete. 
+ val _ = flushInbound() + } + + } +} diff --git a/core/src/main/scala/fs2/netty/embedded/Fs2NettyEmbeddedChannel.scala b/core/src/main/scala/fs2/netty/embedded/Fs2NettyEmbeddedChannel.scala new file mode 100644 index 0000000..314e25f --- /dev/null +++ b/core/src/main/scala/fs2/netty/embedded/Fs2NettyEmbeddedChannel.scala @@ -0,0 +1,137 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2 +package netty.embedded + +import cats.effect.{Async, Sync} +import cats.implicits._ +import fs2.netty.embedded.Fs2NettyEmbeddedChannel.Encoder +import fs2.netty.NettyChannelInitializer +import fs2.netty.pipeline.socket.Socket +import io.netty.buffer.{ByteBuf, Unpooled} +import io.netty.channel.embedded.EmbeddedChannel + +import java.util +import java.util.Queue + +/** + * Better, safer, and clearer api for testing channels + * For use in tests only. + * @param underlying + * @param F + * @tparam F + */ +final case class Fs2NettyEmbeddedChannel[F[_]] private ( + underlying: EmbeddedChannelWithAutoRead +)(implicit + F: Sync[F] +) { + + // TODO: write examples (a spec?) for these + + def writeAllInboundWithoutFlush[A]( + a: A* + )(implicit encoder: Encoder[A]): F[Unit] = + for { + encodedObjects <- F.delay(a.map(encoder.encode)) + _ <- encodedObjects.traverse(bb => + F.delay(underlying.writeOneInbound(bb)) + ) // returns channelFutures + } yield () + + /** + * @param a + * @param encoder + * @tparam A + * @return `true` if the write operation did add something to the inbound buffer + */ + def writeAllInboundThenFlushThenRunAllPendingTasks[A](a: A*)(implicit + encoder: Encoder[A] + ): F[Boolean] = for { + encodedObjects <- F.delay(a.map(encoder.encode)) + areMsgsAdded <- F.delay( + underlying.writeInboundFixed(encodedObjects: _*) + ) // areByteBufsAddedToUnhandledBuffer? onUnhandledInboundMessage + } yield areMsgsAdded + + def flushInbound(): F[Unit] = F.delay(underlying.flushInbound()).void + + def flushOutbound(): F[Unit] = F.delay(underlying.flushOutbound()).void + + def inboundMessages: F[util.Queue[AnyRef]] = + F.delay(underlying.inboundMessages()) + + def outboundMessages: F[util.Queue[AnyRef]] = + F.delay(underlying.outboundMessages()) + + def runScheduledPendingTasks: F[Long] = F.delay { + underlying.runScheduledPendingTasks() + } + + def isOpen: F[Boolean] = F.pure(underlying.isOpen) + + def isClosed: F[Boolean] = F.pure(!underlying.isOpen) + + def close(): F[Unit] = F.delay(underlying.close()).void +} + +object Fs2NettyEmbeddedChannel { + + val NoTasksToRun: Long = -1L + + def apply[F[_], O, I]( + initializer: NettyChannelInitializer[F, O, I] + )(implicit F: Async[F]): F[(Fs2NettyEmbeddedChannel[F], Socket[F, O, I])] = + for { + channel <- F.delay( + new EmbeddedChannelWithAutoRead() + ) // With FlowControl/Dispatcher fixes EmbeddedChannelWithAutoRead might not be needed after all. 
+ socket <- F.async[Socket[F, O, I]] { cb => + initializer + .toChannelInitializer[EmbeddedChannel] { socket => + F.delay(cb(socket.asRight[Throwable])) + } + .flatMap { initializer => + F.delay(channel.pipeline().addFirst(initializer)) *> F.delay( + channel.runPendingTasks() + ) + } + .as[Option[F[Unit]]](None) + } + } yield (new Fs2NettyEmbeddedChannel[F](channel), socket) + + // TODO: Functor and contramap + trait Encoder[A] { + def encode(a: A): ByteBuf + } + + object CommonEncoders { + implicit val byteBufEncoder: Encoder[ByteBuf] = identity + + implicit val byteArrayEncoder: Encoder[Array[Byte]] = (a: Array[Byte]) => + Unpooled.wrappedBuffer(a) + + implicit val byteEncoder: Encoder[Byte] = (a: Byte) => + Unpooled.buffer(1, 1).writeByte(a.toInt) + + implicit val stringEncoder: Encoder[String] = (str: String) => + byteArrayEncoder.encode(str.getBytes) + +// implicit def listEncoder[A](implicit decoder: Decoder[A]): Encoder[List[A]] = (list: List[A]) => +// list.map() + } +} diff --git a/core/src/main/scala/fs2/netty/incudator/http/ExampleHttpServer.scala b/core/src/main/scala/fs2/netty/incudator/http/ExampleHttpServer.scala new file mode 100644 index 0000000..64ca2f5 --- /dev/null +++ b/core/src/main/scala/fs2/netty/incudator/http/ExampleHttpServer.scala @@ -0,0 +1,180 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package fs2.netty.incudator.http + +import cats.data.Kleisli +import cats.effect.{ExitCode, IO, IOApp} +import fs2.netty.incudator.http.HttpClientConnection.WebSocketResponse +import io.netty.handler.codec.http._ +import io.netty.handler.codec.http.websocketx._ + +import scala.concurrent.duration._ + +object ExampleHttpServer extends IOApp { + + private[this] val HttpRouter = + Kleisli[IO, FullHttpRequest, FullHttpResponse] { request => + if (request.uri() == "/health_check") + IO { + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.OK + ) + } + else if (request.uri() == "/echo") + IO { + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.OK, + request.content() // echo back body + ) + } + else + IO { + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.NOT_FOUND + ) + } + } + + private[this] val ChatRooms = + scala.collection.mutable.Map.empty[String, List[WebSocket[IO]]] + + private[this] val GenericWebSocketConfig = WebSocketConfig( + maxFramePayloadLength = 65536, + allowExtensions = false, + subProtocols = List.empty[String], + utf8FrameValidation = true + ) + + private[this] val WebSocketRouter = + Kleisli[IO, FullHttpRequest, WebSocketResponse[IO]] { request => + if (request.uri() == "/dev/null") + IO { + WebSocketResponse.SwitchToWebSocketProtocol[IO]( + GenericWebSocketConfig, + { + case Left(handshakeError: WebSocketHandshakeException) => + IO.unit + + case Left(error) => + IO.unit + + case Right((handshakeComplete, wsConn)) => + wsConn.reads.compile.drain + } + ) + } + else if (request.uri() == "/echo") + IO { + WebSocketResponse.SwitchToWebSocketProtocol[IO]( + GenericWebSocketConfig, + { + case Left(handshakeError: WebSocketHandshakeException) => + IO.unit + + case Left(error) => + IO.unit + + case Right((handshakeComplete, wsConn)) => + wsConn.reads + .evalMap { // TODO: is switchMap cleaner? + case frame: PingWebSocketFrame => + wsConn.write(new PongWebSocketFrame(frame.content())) + + case _: PongWebSocketFrame => + IO.unit + + case frame: TextWebSocketFrame => + wsConn.write(frame) + + case frame: CloseWebSocketFrame => + wsConn.write(frame) + + case frame: BinaryWebSocketFrame => + wsConn.write(frame) + + case _: ContinuationWebSocketFrame => + IO.unit + } + .attempt + .compile + .drain + } + ) + } + else if (request.uri() == "/chat") + IO { + WebSocketResponse.SwitchToWebSocketProtocol[IO]( + GenericWebSocketConfig, + { + case Left(handshakeError: WebSocketHandshakeException) => + IO.unit + + case Left(error) => + IO.unit + + case Right((handshakeComplete, webSocket)) => + for { + roomId <- IO( + handshakeComplete + .requestUri() + .split("//?") + .last + .split("=") + .last + ) // e.g. 
/chat?roomId=123abc + + _ <- IO(ChatRooms.updateWith(roomId) { + case Some(connections) => + Some(webSocket :: connections) + case None => + Some(List(webSocket)) + }) + + // TODO: broadcast reads to all connections in a chat room + } yield () + } + ) + } + else + IO( + WebSocketResponse + .`4xx`[IO](404, body = None, EmptyHttpHeaders.INSTANCE) + ) + } + + override def run(args: List[String]): IO[ExitCode] = + HttpServer + .start[IO]( + HttpServer.HttpConfigs( + requestTimeoutPeriod = 500.milliseconds, + HttpServer.HttpConfigs.Parsing.default + ) + ) + .evalMap { httpClientConnections => + httpClientConnections + .map(_.successfullyDecodedReads(HttpRouter, WebSocketRouter)) + .parJoin(65536) + .compile + .drain + } + .useForever + .as(ExitCode.Success) + +} diff --git a/core/src/main/scala/fs2/netty/incudator/http/HttpClientConnection.scala b/core/src/main/scala/fs2/netty/incudator/http/HttpClientConnection.scala new file mode 100644 index 0000000..08d3b3a --- /dev/null +++ b/core/src/main/scala/fs2/netty/incudator/http/HttpClientConnection.scala @@ -0,0 +1,250 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2.netty.incudator.http + +import cats.Applicative +import cats.data.Kleisli +import cats.effect.Sync +import cats.syntax.all._ +import fs2.Stream +import fs2.netty.incudator.http.HttpClientConnection._ +import fs2.netty.pipeline.socket.Socket +import io.netty.buffer.Unpooled +import io.netty.channel.{ChannelHandlerContext, ChannelPipeline} +import io.netty.handler.codec.TooLongFrameException +import io.netty.handler.codec.http._ +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.HandshakeComplete +import io.netty.handler.codec.http.websocketx.{WebSocketFrame, WebSocketServerProtocolHandler} + +// TODO: this is just a fancy function over Socket, so maybe just make this an object and a function? +class HttpClientConnection[F[_]: Sync]( + clientSocket: Socket[ + F, + FullHttpResponse, + FullHttpRequest + ] +) { + + // TODO: Add idle state handler and handle io.netty.handler.timeout.IdleStateEvent from read side with + // HttpResponseStatus.REQUEST_TIMEOUT for a clean close in case of race condition where client is in the process of + // sending a request. However, need to track weather a request is inFlight, so as not to close connection just + // because server is taking long to respond, triggering a Idle Event from read side. Client is just waiting so it + // cannot send another request. 
+ + def successfullyDecodedReads( + httpRouter: Kleisli[F, FullHttpRequest, FullHttpResponse], + webSocketRouter: Kleisli[F, FullHttpRequest, WebSocketResponse[F]] + ): Stream[F, Unit] = + clientSocket.reads + .evalMap { request => + if (request.decoderResult().isFailure) + createResponseForDecodeError(request.decoderResult().cause()) + .flatMap(clientSocket.write) + else if (isWebSocketRequest(request)) + transitionToWebSocketsOrRespond( + webSocketRouter, + request + ) + else + httpRouter(request).flatMap(clientSocket.write) + } + + private def createResponseForDecodeError( + cause: Throwable + ): F[DefaultFullHttpResponse] = + Sync[F].delay { + cause match { + case ex: TooLongFrameException if isTooLongHeaderException(ex) => + val resp = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.REQUEST_HEADER_FIELDS_TOO_LARGE + ) + HttpUtil.setKeepAlive(resp, true) + resp + + case ex: TooLongFrameException if isTooLongInitialLineException(ex) => + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.REQUEST_URI_TOO_LONG + ) + // Netty will close connection here + + // TODO: HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE + case _ => + val resp = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.INTERNAL_SERVER_ERROR + ) + HttpUtil.setKeepAlive(resp, false) + resp + } + } + + implicit val decoder = new Socket.Decoder[WebSocketFrame] { + override def decode(x: AnyRef): Either[String, WebSocketFrame] = ??? + } + + private def transitionToWebSocketsOrRespond( + webSocketRouter: Kleisli[F, FullHttpRequest, WebSocketResponse[F]], + request: FullHttpRequest + ): F[Unit] = + webSocketRouter(request).flatMap { + case WebSocketResponse.SwitchToWebSocketProtocol( + wsConfigs, + cb + ) => + clientSocket + .mutatePipeline[WebSocketFrame, WebSocketFrame]( + installWebSocketHandlersAndContinueWebSocketUpgrade( + request, + wsConfigs + ) + ) + .flatMap { connection => + connection.events + // only take 1st event since Netty will only first once + .collectFirst { case hc: HandshakeComplete => hc } + .evalTap(handshakeComplete => + connection + // TODO: maybe like a covary method? + .mutatePipeline[WebSocketFrame, WebSocketFrame](_ => + Applicative[F].unit + ) + .map(wsConn => + cb( + ( + handshakeComplete, + new WebSocket[F](underlying = wsConn) + ).asRight[Throwable] + ) + ) + ) + .compile + .drain + } + .onError { case e => + cb(e.asLeft[(HandshakeComplete, WebSocket[F])]) + } + .void + + case WebSocketResponse.`3xx`(code, body, headers) => + wsResponse(code, body, headers).flatMap(clientSocket.write) + + case WebSocketResponse.`4xx`(code, body, headers) => + wsResponse(code, body, headers).flatMap(clientSocket.write) + + case WebSocketResponse.`5xx`(code, body, headers) => + wsResponse(code, body, headers).flatMap(clientSocket.write) + } + + private def installWebSocketHandlersAndContinueWebSocketUpgrade( + request: FullHttpRequest, + wsConfigs: WebSocketConfig + )(pipeline: ChannelPipeline): F[Unit] = + for { + // TODO: FS2-Netty should re-add itself back as last handler, perhaps it 1st removes itself then re-adds. + // We'll also remove this handler after handshake, so might be better to manually add + // WebSocketServerProtocolHandshakeHandler and Utf8FrameValidator since almost none of the other logic from + // WebSocketServerProtocolHandler will be needed. Maybe just the logic around close frame should be ported over. 
+ handler <- Applicative[F].pure( + new WebSocketServerProtocolHandler(wsConfigs.toNetty) { + + /* + Default `exceptionCaught` of `WebSocketServerProtocolHandler` returns a 400 w/o any headers like `Content-length`. + Let higher layer handler this. Catch WebSocketHandshakeException + */ + override def exceptionCaught( + ctx: ChannelHandlerContext, + cause: Throwable + ): Unit = ctx.fireExceptionCaught(cause) + } + ) + + _ <- Sync[F].delay(pipeline.addLast(handler)) + + _ <- Sync[F].delay( + handler.channelRead(pipeline.context(handler), request) + ) + } yield () + + private def wsResponse( + code: Int, + body: Option[String], + headers: HttpHeaders + ): F[FullHttpResponse] = + Sync[F].delay( + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.valueOf(code), + body.fold(Unpooled.EMPTY_BUFFER)(s => + Unpooled.wrappedBuffer(s.getBytes()) + ), + headers, + EmptyHttpHeaders.INSTANCE + ) + ) +} + +object HttpClientConnection { + + private def isWebSocketRequest(request: FullHttpRequest): Boolean = { + // this is the minimum that Netty checks + request.method() == HttpMethod.GET && request + .headers() + .contains(HttpHeaderNames.SEC_WEBSOCKET_KEY) + } + + private def isTooLongHeaderException(cause: TooLongFrameException) = + cause.getMessage.contains("header") + + private def isTooLongInitialLineException(cause: TooLongFrameException) = + cause.getMessage.contains("line") + + sealed abstract class WebSocketResponse[F[_]] + + object WebSocketResponse { + + // One of throwable could be WebSocketHandshakeException + final case class SwitchToWebSocketProtocol[F[_]]( + wsConfigs: WebSocketConfig, + cb: Either[Throwable, (HandshakeComplete, WebSocket[F])] => F[ + Unit + ] + ) extends WebSocketResponse[F] + + // TODO: refined types for code would be nice + final case class `3xx`[F[_]]( + code: Int, + body: Option[String], + headers: HttpHeaders + ) extends WebSocketResponse[F] + + final case class `4xx`[F[_]]( + code: Int, + body: Option[String], + headers: HttpHeaders + ) extends WebSocketResponse[F] + + final case class `5xx`[F[_]]( + code: Int, + body: Option[String], + headers: HttpHeaders + ) extends WebSocketResponse[F] + + } + +} diff --git a/core/src/main/scala/fs2/netty/incudator/http/HttpPipeliningBlockerHandler.scala b/core/src/main/scala/fs2/netty/incudator/http/HttpPipeliningBlockerHandler.scala new file mode 100644 index 0000000..e18b5f2 --- /dev/null +++ b/core/src/main/scala/fs2/netty/incudator/http/HttpPipeliningBlockerHandler.scala @@ -0,0 +1,72 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package fs2.netty.incudator.http + +import io.netty.channel.{ChannelDuplexHandler, ChannelHandlerContext, ChannelPromise} +import io.netty.handler.codec.http.{DefaultFullHttpResponse, FullHttpRequest, FullHttpResponse, HttpResponseStatus, HttpUtil, HttpVersion} +import io.netty.util.ReferenceCountUtil + +class HttpPipeliningBlockerHandler extends ChannelDuplexHandler { + + private var clientAttemptingHttpPipelining = false + private var isHttpRequestInFlight = false + + override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = + msg match { + case request: FullHttpRequest => + if (!isHttpRequestInFlight) { + isHttpRequestInFlight = true + super.channelRead(ctx, msg) + } else { + /* + Stop reading since we're going to close channel + */ + ctx.channel().config().setAutoRead(false) // TODO: remove this now? + ReferenceCountUtil.release(request) + clientAttemptingHttpPipelining = true + } + + case _ => + super.channelRead(ctx, msg) + } + + override def write( + ctx: ChannelHandlerContext, + msg: Any, + promise: ChannelPromise + ): Unit = { + msg match { + case _: FullHttpResponse => + super.write(ctx, msg, promise) + isHttpRequestInFlight = false + if (clientAttemptingHttpPipelining) { + // TODO: at some point, this can be made more robust to check if 1st response was sent. + // Perhaps channel is closed. In which case, don't need to send. + val response = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.TOO_MANY_REQUESTS + ) + HttpUtil.setKeepAlive(response, false) + HttpUtil.setContentLength(response, 0) + ctx.writeAndFlush(response) + } + + case _ => + super.write(ctx, msg, promise) + } + } +} diff --git a/core/src/main/scala/fs2/netty/incudator/http/HttpServer.scala b/core/src/main/scala/fs2/netty/incudator/http/HttpServer.scala new file mode 100644 index 0000000..31886ee --- /dev/null +++ b/core/src/main/scala/fs2/netty/incudator/http/HttpServer.scala @@ -0,0 +1,136 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package fs2.netty.incudator.http + +import cats.Eval +import cats.effect.std.Dispatcher +import cats.effect.{Async, Resource} +import cats.syntax.all._ +import fs2.Stream +import fs2.netty.Network +import fs2.netty.pipeline.NettyPipeline +import fs2.netty.pipeline.socket.Socket +import io.netty.handler.codec.http._ +import io.netty.handler.timeout.ReadTimeoutHandler + +import scala.concurrent.duration.FiniteDuration + +object HttpServer { + + implicit val decoder = new Socket.Decoder[FullHttpRequest] { + + override def decode(x: AnyRef): Either[String, FullHttpRequest] = x match { + case req: FullHttpRequest => req.asRight[String] + case _ => "non http message, pipeline error".asLeft[FullHttpRequest] + } + } + + def start[F[_]: Async]( + httpConfigs: HttpConfigs + ): Resource[F, Stream[F, HttpClientConnection[F]]] = + for { + network <- Network[F] + + dispatcher <- Dispatcher[F] + + pipeline <- Resource.eval( + NettyPipeline[F, FullHttpResponse, FullHttpRequest]( + dispatcher, + List( + Eval.always( + new HttpServerCodec( + httpConfigs.parsing.maxInitialLineLength, + httpConfigs.parsing.maxHeaderSize, + httpConfigs.parsing.maxChunkSize + ) + ), + Eval.always(new HttpServerKeepAliveHandler), + Eval.always( + new HttpObjectAggregator( + httpConfigs.parsing.maxHttpContentLength + ) + ), + // TODO: this also closes channel when exception is fired, should HttpClientConnection just handle that Idle Events? + Eval.always( + new ReadTimeoutHandler( + httpConfigs.requestTimeoutPeriod.length, + httpConfigs.requestTimeoutPeriod.unit + ) + ) + // new HttpPipeliningBlockerHandler + ) + ) + ) + + rawHttpClientConnection <- network + .serverResource( + host = None, + port = None, + pipeline, + options = Nil + ) + .map(_._2) + + } yield rawHttpClientConnection.map(new HttpClientConnection[F](_)) + + /** + * @param requestTimeoutPeriod - limit on how long connection can remain open w/o any requests + */ + final case class HttpConfigs( + requestTimeoutPeriod: FiniteDuration, + parsing: HttpConfigs.Parsing + ) + + // TODO: what about `Int Refined NonNegative` or validated or custom value types? + object HttpConfigs { + + /** + * @param maxHttpContentLength - limit on body/entity size + * @param maxInitialLineLength - limit on how long url can be, along with HTTP preamble, i.e. "GET HTTP 1.1 ..." + * @param maxHeaderSize - limit on size of single header + */ + final case class Parsing( + maxHttpContentLength: Int, + maxInitialLineLength: Int, + maxHeaderSize: Int + ) { + def maxChunkSize: Int = Parsing.DefaultMaxChunkSize + } + + object Parsing { + + private val DefaultMaxChunkSize: Int = + 8192 // Netty default + + val DefaultMaxHttpContentLength: Int = + 65536 // Netty default + + val DefaultMaxInitialLineLength: Int = + 4096 // Netty default + + val DefaultMaxHeaderSize: Int = 8192 // Netty default + + val default: Parsing = Parsing( + DefaultMaxHttpContentLength, + DefaultMaxInitialLineLength, + DefaultMaxHeaderSize + ) + } + + } + +} diff --git a/core/src/main/scala/fs2/netty/incudator/http/WebSocket.scala b/core/src/main/scala/fs2/netty/incudator/http/WebSocket.scala new file mode 100644 index 0000000..50385f3 --- /dev/null +++ b/core/src/main/scala/fs2/netty/incudator/http/WebSocket.scala @@ -0,0 +1,58 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2.netty.incudator.http + +import fs2.netty.pipeline.socket.Socket +import fs2.{INothing, Pipe, Stream} +import io.netty.channel.ChannelPipeline +import io.netty.handler.codec.http.websocketx.WebSocketFrame + +class WebSocket[F[_]]( + underlying: Socket[ + F, + WebSocketFrame, + WebSocketFrame + ] +) extends Socket[F, WebSocketFrame, WebSocketFrame] { + + // override def localAddress: F[SocketAddress[IpAddress]] = underlying.localAddress +// +// override def remoteAddress: F[SocketAddress[IpAddress]] = underlying.remoteAddress + + override def reads: Stream[F, WebSocketFrame] = underlying.reads + + // TODO: this will be aware of close frames + override def write(output: WebSocketFrame): F[Unit] = + underlying.write(output) + + override def writes: Pipe[F, WebSocketFrame, INothing] = underlying.writes + + override def events: Stream[F, AnyRef] = underlying.events + + override def isOpen: F[Boolean] = underlying.isOpen + + override def isClosed: F[Boolean] = underlying.isClosed + + override def isDetached: F[Boolean] = underlying.isDetached + + override def close(): F[Unit] = underlying.close() + + override def mutatePipeline[O2, I2: Socket.Decoder]( + mutator: ChannelPipeline => F[Unit] + ): F[Socket[F, O2, I2]] = + underlying.mutatePipeline(mutator) +} diff --git a/core/src/main/scala/fs2/netty/incudator/http/WebSocketConfig.scala b/core/src/main/scala/fs2/netty/incudator/http/WebSocketConfig.scala new file mode 100644 index 0000000..f8d0548 --- /dev/null +++ b/core/src/main/scala/fs2/netty/incudator/http/WebSocketConfig.scala @@ -0,0 +1,92 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2.netty.incudator.http + +import fs2.netty.incudator.http.WebSocketConfig.DisableTimeout +import io.netty.handler.codec.http.websocketx.{WebSocketCloseStatus, WebSocketDecoderConfig, WebSocketServerProtocolConfig} + + +/** + * + * @param maxFramePayloadLength - limit on payload length from Text and Binary Frames + * @param allowExtensions - WS extensions like those for compression + * @param subProtocols - optional subprotocols to negotiate + * @param utf8FrameValidation - optionally validate text frames' payloads are utf8 + */ +final case class WebSocketConfig( + maxFramePayloadLength: Int, + allowExtensions: Boolean, + subProtocols: List[String], + utf8FrameValidation: Boolean + ) { + + def toNetty: WebSocketServerProtocolConfig = + WebSocketServerProtocolConfig + .newBuilder() + + // Match all paths, let application filter requests. 
+      .websocketPath("/")
+      .checkStartsWith(true)
+      .subprotocols(subProtocolsCsv)
+
+      // Application will handle timeouts for WS Handshake request. Set this far into the future b/c Netty
+      // doesn't allow non-positive values in configs.
+      .handshakeTimeoutMillis(200000L) // 200 sec
+
+      // Application will handle all inbound close frames; this flag tells Netty not to handle them
+      .handleCloseFrames(false)
+
+      // Application will handle all Control Frames
+      .dropPongFrames(false)
+
+      // Netty's WebSocketCloseFrameHandler ensures a Close Frame is sent when the channel closes
+      // (if one wasn't already sent).
+      // It also checks that no new messages are sent after Close Frame is sent, throwing a ClosedChannelException.
+      // It would be nice to set it to INTERNAL_SERVER_ERROR since applications should handle closes, but b/c
+      // of a weird bug in Netty, this is triggered when UTF8 validation fails, so setting it to INVALID_PAYLOAD_DATA.
+      .sendCloseFrame(WebSocketCloseStatus.INVALID_PAYLOAD_DATA)
+
+      // Netty can check that Close Frame has been sent in some time period, but we don't need this option because
+      // application should close channel immediately after each close
+      .forceCloseTimeoutMillis(DisableTimeout)
+
+      .decoderConfig(
+        WebSocketDecoderConfig
+          .newBuilder()
+          .maxFramePayloadLength(maxFramePayloadLength)
+
+          // Servers must set this to true
+          .expectMaskedFrames(true)
+
+          // Allows loosening the masking requirement on received frames. Should NOT be set.
+          .allowMaskMismatch(false)
+          .allowExtensions(allowExtensions)
+          .closeOnProtocolViolation(true)
+          .withUTF8Validator(utf8FrameValidation)
+          .build()
+      )
+      .build()
+
+  private def subProtocolsCsv = subProtocols match {
+    case Nil => null // no subprotocols
+    case list => list.mkString(", ")
+  }
+}
+
+object WebSocketConfig {
+  private val DisableTimeout = 0L
+}
\ No newline at end of file
diff --git a/core/src/main/scala/fs2/netty/pipeline/NettyPipeline.scala b/core/src/main/scala/fs2/netty/pipeline/NettyPipeline.scala
new file mode 100644
index 0000000..ab2ad83
--- /dev/null
+++ b/core/src/main/scala/fs2/netty/pipeline/NettyPipeline.scala
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2021 Typelevel
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package fs2.netty.pipeline
+
+import cats.Eval
+import cats.effect.std.Dispatcher
+import cats.effect.{Async, Sync}
+import cats.syntax.all._
+import fs2.netty.NettyChannelInitializer
+import fs2.netty.pipeline.socket.{Socket, SocketHandler}
+import io.netty.buffer.ByteBuf
+import io.netty.channel.{Channel, ChannelHandler, ChannelHandlerAdapter, ChannelInitializer}
+import io.netty.handler.flow.FlowControlHandler
+
+class NettyPipeline[F[_]: Async, O, I: Socket.Decoder] private (
+  handlers: List[Eval[ChannelHandler]]
+)(
+  dispatcher: Dispatcher[F]
+) extends NettyChannelInitializer[F, O, I] {
+
+  // TODO: there are other interesting types of channels
+  // TODO: Remember ChannelInitializer is Sharable!
+ override def toChannelInitializer[C <: Channel]( + cb: Socket[F, O, I] => F[Unit] + ): F[ChannelInitializer[C]] = Sync[F].delay { (ch: C) => + { + val p = ch.pipeline() + ch.config().setAutoRead(false) + + handlers + .map(_.value) + .foldLeft(p)((pipeline, handler) => pipeline.addLast(handler)) + /* `channelRead` on ChannelInboundHandler's may get invoked more than once despite autoRead being turned off + and handler calling read to control read rate, i.e. backpressure. Netty's solution is to use `FlowControlHandler`. + Below from https://stackoverflow.com/questions/45887006/how-to-ensure-channelread-called-once-after-each-read-in-netty: + | Also note that most decoders will automatically perform a read even if you have AUTO_READ=false since they + | need to read enough data in order to yield at least one message to subsequent (i.e. your) handlers... but + | after they yield a message, they won't auto-read from the socket again. + */ + .addLast(new FlowControlHandler(false)) + + dispatcher.unsafeRunAndForget { + // TODO: read up on CE3 Dispatcher, how is it different than Context Switch? Is this taking place async? Also is cats.effect.Effect removed in CE3? + SocketHandler[F, O, I](dispatcher, ch) + .flatTap(h => + Sync[F].delay(p.addLast(h)) + ) // TODO: pass EventExecutorGroup + .flatMap(cb) + //TODO: Wonder if cb should be invoked on handlerAdded in SocketHandler? Technically, Socket isn't + // "fully active"; SocketHandler is in the pipeline but is marked with ADD_PENDING status (whatever that means, + // maybe it's ok). Need to work out expectation of callers. And if this is addLast is called from a different + // thread, then handlerAdded will be scheduled by Netty to execute in the future. + } + } + } +} + +object NettyPipeline { + + def apply[F[_]: Async]( + dispatcher: Dispatcher[F] + ): F[NettyPipeline[F, ByteBuf, ByteBuf]] = + apply(dispatcher, handlers = Nil) + + def apply[F[_]: Async, O, I: Socket.Decoder]( + dispatcher: Dispatcher[F], + handlers: List[Eval[ChannelHandler]] + ): F[NettyPipeline[F, O, I]] = + Sync[F].delay( + new NettyPipeline[F, O, I]( + memoizeSharableHandlers(handlers) + )(dispatcher) + ) + + /* + Netty will throw an exception if Sharable handler is added to more than one channel. + */ + private[this] def memoizeSharableHandlers[E, O, I: Socket.Decoder, F[ + _ + ]: Async](handlers: List[Eval[ChannelHandler]]) = + handlers.map(eval => + eval.flatMap { + case adapter: ChannelHandlerAdapter if adapter.isSharable => + eval.memoize + case _ => + eval + } + ) +} diff --git a/core/src/main/scala/fs2/netty/pipeline/prebuilt/AlternativeBytePipeline.scala b/core/src/main/scala/fs2/netty/pipeline/prebuilt/AlternativeBytePipeline.scala new file mode 100644 index 0000000..a127cbb --- /dev/null +++ b/core/src/main/scala/fs2/netty/pipeline/prebuilt/AlternativeBytePipeline.scala @@ -0,0 +1,99 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package fs2.netty.pipeline.prebuilt + +import cats.effect.std.Dispatcher +import cats.effect.{Async, Sync} +import cats.syntax.all._ +import fs2.netty.pipeline.NettyPipeline +import fs2.netty.pipeline.prebuilt.AlternativeBytePipeline._ +import fs2.netty.NettyChannelInitializer +import fs2.netty.pipeline.socket.Socket +import fs2.{Chunk, INothing, Pipe, Stream} +import io.netty.buffer.{ByteBuf, ByteBufUtil, Unpooled} +import io.netty.channel.{Channel, ChannelInitializer, ChannelPipeline} + +// This class and BytePipeline highlight the different way to create +// sockets, i.e. rely on Netty handlers or encode transforms in fs2. +class AlternativeBytePipeline[F[_]: Async]( + byteBufPipeline: NettyPipeline[F, ByteBuf, ByteBuf] +) extends NettyChannelInitializer[F, Chunk[Byte], Byte] { + + override def toChannelInitializer[C <: Channel]( + cb: Socket[F, Chunk[Byte], Byte] => F[Unit] + ): F[ChannelInitializer[C]] = + byteBufPipeline + .toChannelInitializer { byteBufSocket => + Sync[F] + .delay(new ByteBufToByteChunkSocket[F](byteBufSocket)) + .flatMap(cb) + } +} + +object AlternativeBytePipeline { + + def apply[F[_]: Async]( + dispatcher: Dispatcher[F] + ): F[AlternativeBytePipeline[F]] = + for { + byteBufPipeline <- NettyPipeline.apply[F](dispatcher) + } yield new AlternativeBytePipeline(byteBufPipeline) + + private class ByteBufToByteChunkSocket[F[_]: Async]( + socket: Socket[F, ByteBuf, ByteBuf] + ) extends Socket[F, Chunk[Byte], Byte] { + + override lazy val reads: Stream[F, Byte] = + socket.reads + .evalMap(bb => + Sync[F].delay(ByteBufUtil.getBytes(bb)).map(Chunk.array(_)) + ) + .flatMap(Stream.chunk) + + override lazy val events: Stream[F, AnyRef] = socket.events + + override def write(output: Chunk[Byte]): F[Unit] = + socket.write(toByteBuf(output)) + + override lazy val writes: Pipe[F, Chunk[Byte], INothing] = + _.map(toByteBuf).through(socket.writes) + + override val isOpen: F[Boolean] = socket.isOpen + + override val isClosed: F[Boolean] = socket.isClosed + + override val isDetached: F[Boolean] = socket.isDetached + + override def close(): F[Unit] = socket.close() + + override def mutatePipeline[O2, I2: Socket.Decoder]( + mutator: ChannelPipeline => F[Unit] + ): F[Socket[F, O2, I2]] = socket.mutatePipeline[O2, I2](mutator) + + private[this] def toByteBuf(chunk: Chunk[Byte]): ByteBuf = + chunk match { + case Chunk.ArraySlice(arr, off, len) => + Unpooled.wrappedBuffer(arr, off, len) + + case c: Chunk.ByteBuffer => + Unpooled.wrappedBuffer(c.toByteBuffer) + + case c => + Unpooled.wrappedBuffer(c.toArray) + } + } +} diff --git a/core/src/main/scala/fs2/netty/pipeline/prebuilt/BytePipeline.scala b/core/src/main/scala/fs2/netty/pipeline/prebuilt/BytePipeline.scala new file mode 100644 index 0000000..c4ea3d3 --- /dev/null +++ b/core/src/main/scala/fs2/netty/pipeline/prebuilt/BytePipeline.scala @@ -0,0 +1,110 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package fs2.netty.pipeline.prebuilt + +import cats.Eval +import cats.effect.std.Dispatcher +import cats.effect.{Async, Sync} +import cats.syntax.all._ +import fs2.netty.pipeline.NettyPipeline +import fs2.netty.pipeline.prebuilt.BytePipeline._ +import fs2.netty.NettyChannelInitializer +import fs2.netty.pipeline.socket.Socket +import fs2.{Chunk, INothing, Pipe, Stream} +import io.netty.buffer.{ByteBuf, Unpooled} +import io.netty.channel.{Channel, ChannelInitializer, ChannelPipeline} +import io.netty.handler.codec.bytes.ByteArrayDecoder + +class BytePipeline[F[_]: Async]( + byteArrayPipeline: NettyPipeline[F, ByteBuf, Array[Byte]] +) extends NettyChannelInitializer[F, Chunk[Byte], Byte] { + + override def toChannelInitializer[C <: Channel]( + cb: Socket[F, Chunk[Byte], Byte] => F[Unit] + ): F[ChannelInitializer[C]] = + byteArrayPipeline + .toChannelInitializer { byteArraySocket => + /* + TODO: Can't do this b/c ProFunctor isn't Chunk aware + Sync[F].delay(byteArraySocket.dimap[Chunk[Byte], Byte](toByteBuf)(Chunk.array(_))) + */ + Sync[F].delay(new ChunkingByteSocket[F](byteArraySocket)).flatMap(cb) + } + +} + +object BytePipeline { + + def apply[F[_]: Async](dispatcher: Dispatcher[F]): F[BytePipeline[F]] = + for { + pipeline <- NettyPipeline[F, ByteBuf, Array[Byte]]( + dispatcher, + handlers = List( + Eval.always(new ByteArrayDecoder) + ) + ) + } yield new BytePipeline(pipeline) + + implicit val byteArraySocketDecoder: Socket.Decoder[Array[Byte]] = { + case array: Array[Byte] => array.asRight[String] + case _ => + "pipeline is misconfigured".asLeft[Array[Byte]] + } + + private class ChunkingByteSocket[F[_]: Async]( + socket: Socket[F, ByteBuf, Array[Byte]] + ) extends Socket[F, Chunk[Byte], Byte] { + + override lazy val reads: Stream[F, Byte] = + socket.reads.map(Chunk.array(_)).flatMap(Stream.chunk) + + override lazy val events: Stream[F, AnyRef] = socket.events + + override def write(output: Chunk[Byte]): F[Unit] = + socket.write(toByteBuf(output)) + + override lazy val writes: Pipe[F, Chunk[Byte], INothing] = + _.map(toByteBuf).through(socket.writes) + + override val isOpen: F[Boolean] = socket.isOpen + + override val isClosed: F[Boolean] = socket.isClosed + + override val isDetached: F[Boolean] = socket.isDetached + + override def close(): F[Unit] = socket.close() + + override def mutatePipeline[O2, I2: Socket.Decoder]( + mutator: ChannelPipeline => F[Unit] + ): F[Socket[F, O2, I2]] = socket.mutatePipeline[O2, I2](mutator) + + } + + // TODO: alloc over unpooled? + private def toByteBuf(chunk: Chunk[Byte]): ByteBuf = + chunk match { + case Chunk.ArraySlice(arr, off, len) => + Unpooled.wrappedBuffer(arr, off, len) + + case c: Chunk.ByteBuffer => + Unpooled.wrappedBuffer(c.toByteBuffer) + + case c => + Unpooled.wrappedBuffer(c.toArray) + } + +} diff --git a/core/src/main/scala/fs2/netty/pipeline/socket/NoopChannel.scala b/core/src/main/scala/fs2/netty/pipeline/socket/NoopChannel.scala new file mode 100644 index 0000000..ae2ba7e --- /dev/null +++ b/core/src/main/scala/fs2/netty/pipeline/socket/NoopChannel.scala @@ -0,0 +1,150 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2.netty.pipeline.socket + +import io.netty.buffer.ByteBufAllocator +import io.netty.channel._ +import io.netty.util.{Attribute, AttributeKey} + +import java.net.SocketAddress + +/** + * Void Channel for SocketHandler to prevent writes and further channel effects. + * Reading state of parent channel is still allowed as it is safe, i.e. no side-effects. + * @param parent Channel to reference for reading state + */ +class NoopChannel(parent: Channel) extends Channel { + + override def id(): ChannelId = parent.id() + + override def eventLoop(): EventLoop = parent.eventLoop() + + override def parent(): Channel = parent + + override def config(): ChannelConfig = parent.config() + + override def isOpen: Boolean = parent.isOpen + + override def isRegistered: Boolean = parent.isRegistered + + override def isActive: Boolean = parent.isActive + + override def metadata(): ChannelMetadata = parent.metadata + + override def localAddress(): SocketAddress = parent.localAddress + + override def remoteAddress(): SocketAddress = parent.remoteAddress + + override def closeFuture(): ChannelFuture = parent.voidPromise() + + override def isWritable: Boolean = false + + override def bytesBeforeUnwritable(): Long = parent.bytesBeforeUnwritable + + override def bytesBeforeWritable(): Long = parent.bytesBeforeWritable + + override def unsafe(): Channel.Unsafe = parent.unsafe() + + override def pipeline(): ChannelPipeline = parent.pipeline + + override def alloc(): ByteBufAllocator = parent.alloc + + override def read(): Channel = parent.read + + override def flush(): Channel = parent.flush + + override def compareTo(o: Channel): Int = parent.compareTo(o) + + override def attr[T](key: AttributeKey[T]): Attribute[T] = parent.attr(key) + + override def hasAttr[T](key: AttributeKey[T]): Boolean = parent.hasAttr(key) + + override def bind(localAddress: SocketAddress): ChannelFuture = + parent.voidPromise() + + override def connect(remoteAddress: SocketAddress): ChannelFuture = + parent.voidPromise() + + override def connect( + remoteAddress: SocketAddress, + localAddress: SocketAddress + ): ChannelFuture = parent.voidPromise() + + override def disconnect(): ChannelFuture = parent.voidPromise() + + override def close(): ChannelFuture = parent.voidPromise() + + override def deregister(): ChannelFuture = parent.voidPromise() + + override def bind( + localAddress: SocketAddress, + promise: ChannelPromise + ): ChannelFuture = parent.voidPromise() + + override def connect( + remoteAddress: SocketAddress, + promise: ChannelPromise + ): ChannelFuture = parent.voidPromise() + + override def connect( + remoteAddress: SocketAddress, + localAddress: SocketAddress, + promise: ChannelPromise + ): ChannelFuture = parent.voidPromise() + + override def disconnect(promise: ChannelPromise): ChannelFuture = + parent.voidPromise() + + override def close(promise: ChannelPromise): ChannelFuture = + parent.voidPromise() + + override def deregister(promise: ChannelPromise): ChannelFuture = + parent.voidPromise() + + /* + Below are the key methods we want to overwrite to stop writes + */ + + override def 
write(msg: Any): ChannelFuture = + parent.newPromise().setFailure(new NoopChannel.NoopFailure) + + override def write(msg: Any, promise: ChannelPromise): ChannelFuture = + parent.newPromise().setFailure(new NoopChannel.NoopFailure) + + override def writeAndFlush(msg: Any, promise: ChannelPromise): ChannelFuture = + parent.newPromise().setFailure(new NoopChannel.NoopFailure) + + override def writeAndFlush(msg: Any): ChannelFuture = + parent.newPromise().setFailure(new NoopChannel.NoopFailure) + + override def newPromise(): ChannelPromise = parent.newPromise + + override def newProgressivePromise(): ChannelProgressivePromise = + parent.newProgressivePromise + + override def newSucceededFuture(): ChannelFuture = + parent.newPromise().setFailure(new NoopChannel.NoopFailure) + + override def newFailedFuture(cause: Throwable): ChannelFuture = + parent.newPromise().setFailure(new NoopChannel.NoopFailure) + + override def voidPromise(): ChannelPromise = parent.voidPromise +} + +object NoopChannel { + private class NoopFailure extends Throwable("Noop channel") +} diff --git a/core/src/main/scala/fs2/netty/pipeline/socket/Socket.scala b/core/src/main/scala/fs2/netty/pipeline/socket/Socket.scala new file mode 100644 index 0000000..caee684 --- /dev/null +++ b/core/src/main/scala/fs2/netty/pipeline/socket/Socket.scala @@ -0,0 +1,118 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fs2 +package netty.pipeline.socket + +import cats.arrow.Profunctor +import cats.syntax.all._ +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelPipeline + +// TODO: `I <: ReferenceCounted` to avoid type erasure. This is a very big constraint on the Netty channel, although for HTTP +// and WS use cases this is completely ok. One alternative is scala reflections api, but will overhead be acceptable +// along the critical code path (assuming high volume servers/clients)? +// Think through variance of types. +trait Socket[F[_], O, I] { + + // TODO: Temporarily disabling while making Socket generic enough to test with EmbeddedChannel. Furthermore, these + // methods restrict Socket to be a InetChannel which isn't compatible with EmbeddedChannel. Netty also works with + // DomainSocketChannel and LocalChannel which have DomainSocketAddress and LocalAddress respectively(both for IPC?), + // not IpAddresses. + // Can these be provided on the server or client network resource construction rather than on the Socket? +// def localAddress: F[SocketAddress[IpAddress]] +// def remoteAddress: F[SocketAddress[IpAddress]] + + def reads: Stream[F, I] + + /** + * Handlers may optionally generate events to communicate with downstream handlers. These include but not limited to + * signals about handshake complete, timeouts, and errors. 
+ * + * Some examples from Netty: + * - ChannelInputShutdownReadComplete + * - ChannelInputShutdownEvent + * - SslCompletionEvent + * - ProxyConnectionEvent + * - HandshakeComplete + * - Http2FrameStreamEvent + * - IdleStateEvent + * @return + */ + def events: Stream[F, AnyRef] + + def write(output: O): F[Unit] + def writes: Pipe[F, O, INothing] + + def isOpen: F[Boolean] + def isClosed: F[Boolean] + def isDetached: F[Boolean] + + def close(): F[Unit] + + def mutatePipeline[O2, I2: Socket.Decoder]( + mutator: ChannelPipeline => F[Unit] + ): F[Socket[F, O2, I2]] +} + +object Socket { + + trait Decoder[A] { + def decode(x: AnyRef): Either[String, A] + } + + private[this] val ByteBufClassName = classOf[ByteBuf].getName + + implicit val ByteBufDecoder: Decoder[ByteBuf] = { + case bb: ByteBuf => bb.asRight[String] + case x => + s"pipeline error, expected $ByteBufClassName, but got ${x.getClass.getName}" + .asLeft[ByteBuf] + } + + //todo Do we then define an IO instance of this? + // Maybe we need to have a custom typeclass that also accounts for pipeline handling type C? Although contravariance + // should handle that? + implicit def ProfunctorInstance[F[_]]: Profunctor[Socket[F, *, *]] = + new Profunctor[Socket[F, *, *]] { + + override def dimap[A, B, C, D]( + fab: Socket[F, A, B] + )(f: C => A)(g: B => D): Socket[F, C, D] = + new Socket[F, C, D] { + override def reads: Stream[F, D] = fab.reads.map(g) + + override def events: Stream[F, AnyRef] = fab.events + + override def write(output: C): F[Unit] = fab.write(f(output)) + + override def writes: Pipe[F, C, INothing] = + _.map(f).through(fab.writes) + + override def isOpen: F[Boolean] = fab.isOpen + + override def isClosed: F[Boolean] = fab.isClosed + + override def isDetached: F[Boolean] = fab.isDetached + + override def close(): F[Unit] = fab.close() + + override def mutatePipeline[O2, I2: Decoder]( + mutator: ChannelPipeline => F[Unit] + ): F[Socket[F, O2, I2]] = fab.mutatePipeline(mutator) + } + } +} diff --git a/core/src/main/scala/fs2/netty/pipeline/socket/SocketHandler.scala b/core/src/main/scala/fs2/netty/pipeline/socket/SocketHandler.scala new file mode 100644 index 0000000..e03afb2 --- /dev/null +++ b/core/src/main/scala/fs2/netty/pipeline/socket/SocketHandler.scala @@ -0,0 +1,217 @@ +/* + * Copyright 2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package fs2 +package netty.pipeline.socket + +import cats.effect._ +import cats.effect.std.{Dispatcher, Queue} +import cats.syntax.all._ +import cats.{Applicative, Functor} +import fs2.netty.fromNettyFuture +import io.netty.buffer.ByteBuf +import io.netty.channel._ +import io.netty.handler.flow.FlowControlHandler +import io.netty.util.ReferenceCountUtil + +final class SocketHandler[F[_]: Async: Concurrent, O, I] private ( + disp: Dispatcher[F], + private var channel: Channel, + readsQueue: Queue[F, Option[Either[Throwable, I]]], + eventsQueue: Queue[F, AnyRef], + pipelineMutationSwitch: Deferred[F, Unit] +)(implicit inboundDecoder: Socket.Decoder[I]) + extends ChannelInboundHandlerAdapter + with Socket[F, O, I] { + +// override val localAddress: F[SocketAddress[IpAddress]] = +// Sync[F].delay(SocketAddress.fromInetSocketAddress(channel.localAddress())) +// +// override val remoteAddress: F[SocketAddress[IpAddress]] = +// Sync[F].delay(SocketAddress.fromInetSocketAddress(channel.remoteAddress())) + + // TODO: we can avoid Option boxing if I <: Null + private[this] def take(poll: Poll[F]): F[Option[I]] = + poll(readsQueue.take) flatMap { + case None => + Applicative[F].pure(none[I]) // EOF marker + + case Some(Right(i)) => + Applicative[F].pure(i.some) + + case Some(Left(t)) => + t.raiseError[F, Option[I]] + } + + private[this] val fetch: Stream[F, I] = + Stream + .bracketFull[F, Option[I]](poll => + Sync[F].delay(channel.read()) *> take(poll) + ) { (opt, _) => + opt.fold(Applicative[F].unit)(i => + Sync[F] + .delay(ReferenceCountUtil.safeRelease(i)) + .void // TODO: check ref count before release? + ) + } + .unNoneTerminate + .interruptWhen(pipelineMutationSwitch.get.attempt) + + override lazy val reads: Stream[F, I] = + Stream force { + Functor[F].ifF(isOpen)( + fetch.flatMap(i => + if (i == null) Stream.empty else Stream.emit(i) + ) ++ reads, + Stream.empty + ) + } + + override lazy val events: Stream[F, AnyRef] = + Stream + .fromQueueUnterminated(eventsQueue) + .interruptWhen(pipelineMutationSwitch.get.attempt) + + override def write(output: O): F[Unit] = + fromNettyFuture[F]( + /*Sync[F].delay(println(s"Write ${debug(output)}")) *>*/ Sync[F].delay( + channel.writeAndFlush(output) + ) + ).void + + override val writes: Pipe[F, O, INothing] = + _.evalMap(o => write(o) *> isOpen).takeWhile(bool => bool).drain + + override val isOpen: F[Boolean] = + Sync[F].delay(channel.isOpen) + + override val isClosed: F[Boolean] = isOpen.map(bool => !bool) + + override val isDetached: F[Boolean] = + Sync[F].delay(channel.isInstanceOf[NoopChannel]) + + override def close(): F[Unit] = fromNettyFuture[F]( + Sync[F].delay(channel.close()) + ).void + + override def channelRead(ctx: ChannelHandlerContext, msg: AnyRef) = + inboundDecoder.decode( + ReferenceCountUtil.touch( + msg, + s"Last touch point in FS2-Netty for ${msg.getClass.getSimpleName}" + ) + ) match { + case Left(errorMsg) => + // TODO: Netty logs if release fails, but perhaps we want to catch error and do custom logging/reporting/handling + ReferenceCountUtil.safeRelease(msg) + + case Right(i) => + // TODO: what's the perf impact of unsafeRunSync-only vs. unsafeRunAndForget-&-FlowControlHandler? 
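+      // Read hand-off: Netty invokes channelRead on the event loop, and the decoded message is offered to
+      // readsQueue, which `take`/`fetch` above drain to back the `reads` stream.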
+// println(s"READ ${debug(msg)}") + disp.unsafeRunAndForget(readsQueue.offer(i.asRight[Exception].some)) + } + + private def debug(x: Any) = x match { + case bb: ByteBuf => + val b = bb.readByte() + bb.resetReaderIndex() + val arr = Array[Byte](1) + arr(0) = b + new String(arr) + + case _ => + "blah" + } + + override def exceptionCaught(ctx: ChannelHandlerContext, t: Throwable) = + disp.unsafeRunAndForget(readsQueue.offer(t.asLeft[I].some)) + + override def channelInactive(ctx: ChannelHandlerContext) = + try { + //TODO: Is ordering preserved? + disp.unsafeRunAndForget(readsQueue.offer(None)) + } catch { + case _: IllegalStateException => + () // sometimes we can see this due to race conditions in shutdown + } + + override def userEventTriggered( + ctx: ChannelHandlerContext, + evt: AnyRef + ): Unit = + //TODO: Is ordering preserved? Might indeed be best to not run this handler in a separate thread pool (unless + // netty manages ordering...which isn't likely as it should just hand off to ec) and call dispatcher manually + // where needed. This way we can keep a thread-unsafe mutable queue. + disp.unsafeRunAndForget(eventsQueue.offer(evt)) + + override def mutatePipeline[O2, I2: Socket.Decoder]( + mutator: ChannelPipeline => F[Unit] + ): F[Socket[F, O2, I2]] = + for { + // TODO: Edge cases aren't fully tested + _ <- pipelineMutationSwitch.complete( + () + ) // shutdown the events and reads streams + oldChannel = channel // Save reference, as we first stop socket processing + _ <- Sync[F].delay { + channel = new NoopChannel(channel) + } // shutdown writes + _ <- Sync[F].delay( + oldChannel.pipeline().removeLast() + ) //remove SocketHandler + _ <- Sync[F].delay( + oldChannel.pipeline().removeLast() + ) //remove FlowControlHandler + /* + TODO: Above may dump remaining messages into fireChannelRead, do we care about those messages? Should we + signal up that this happened? Probably should as certain apps may care about a peer not behaving according to + the expected protocol. In this case, we add a custom handler to capture those messages, then either: + - raiseError on the new reads stream, or + - set a Signal + Also need to think through edge case where Netty is concurrently calling channel read vs. this manipulating + pipeline. Maybe protocols need to inform this layer about when exactly to transition. + */ + _ <- mutator(oldChannel.pipeline()) + sh <- SocketHandler[F, O2, I2](disp, oldChannel) + // TODO: pass a name for debugging purposes? 
+ _ <- Sync[F].delay( + oldChannel.pipeline().addLast(new FlowControlHandler(false)) + ) + _ <- Sync[F].delay(oldChannel.pipeline().addLast(sh)) + } yield sh + + // not to self: if we want to schedule an action to be done when channel is closed, can also do `ctx.channel.closeFuture.addListener` +} + +object SocketHandler { + + def apply[F[_]: Async: Concurrent, O, I: Socket.Decoder]( + disp: Dispatcher[F], + channel: Channel + ): F[SocketHandler[F, O, I]] = + for { + readsQueue <- Queue.unbounded[F, Option[Either[Throwable, I]]] + eventsQueue <- Queue.unbounded[F, AnyRef] + pipelineMutationSwitch <- Deferred[F, Unit] + } yield new SocketHandler( + disp, + channel, + readsQueue, + eventsQueue, + pipelineMutationSwitch + ) + +} diff --git a/core/src/test/scala/fs2/netty/NetworkSpec.scala b/core/src/test/scala/fs2/netty/NetworkSpec.scala index 5aa3b22..6079a3a 100644 --- a/core/src/test/scala/fs2/netty/NetworkSpec.scala +++ b/core/src/test/scala/fs2/netty/NetworkSpec.scala @@ -17,14 +17,17 @@ package fs2 package netty -import cats.effect.IO import cats.effect.testing.specs2.CatsResource - +import cats.effect.{IO, Resource} +import io.netty.buffer.{ByteBuf, Unpooled} import org.specs2.mutable.SpecificationLike +import java.nio.charset.Charset +import scala.collection.mutable.ListBuffer + class NetworkSpec extends CatsResource[IO, Network[IO]] with SpecificationLike { - val resource = Network[IO] + override val resource: Resource[IO, Network[IO]] = Network[IO] "network tcp sockets" should { "create a network instance" in { @@ -32,32 +35,63 @@ class NetworkSpec extends CatsResource[IO, Network[IO]] with SpecificationLike { } "support a simple echo use-case" in withResource { net => - val data = List[Byte](1, 2, 3, 4, 5, 6, 7) + val msg = "Echo me" - val rsrc = net.serverResource(None, None) flatMap { - case (isa, incoming) => + val rsrc = net.serverResource(None, None, Nil) flatMap { + case (ip, incoming) => val handler = incoming flatMap { socket => - socket.reads.through(socket.writes) + socket.reads + .evalTap(bb => IO(bb.retain())) + .through(socket.writes) } for { _ <- handler.compile.drain.background - results <- net.client(isa) flatMap { socket => - Stream.emits(data) - .through(socket.writes) - .merge(socket.reads) - .take(data.length.toLong) - .compile.resource.toList + results <- net.client(ip, options = Nil) flatMap { cSocket => + Stream + // Send individual bytes as the simplest use case + .emits(msg.getBytes) + .evalMap(byteToByteBuf) + .through(cSocket.writes) + .merge(cSocket.reads) + .flatMap(byteBufToStream) + .take(msg.length.toLong) + .map(byteToString) + .compile + .resource + .toList + .map(_.mkString) } } yield results } - rsrc.use(IO.pure(_)) flatMap { results => + rsrc.use(IO.pure) flatMap { results => IO { - results mustEqual data + results mustEqual msg } } } } + + private def byteToByteBuf(byte: Byte): IO[ByteBuf] = IO { + val arr = new Array[Byte](1) + arr(0) = byte + Unpooled.wrappedBuffer(new String(arr).getBytes()) + } + + private def byteBufToStream(bb: ByteBuf): Stream[IO, Byte] = { + val buffer = new ListBuffer[Byte] + bb.forEachByte((value: Byte) => { + val _ = buffer.addOne(value) + true + }) + Stream.fromIterator[IO].apply[Byte](buffer.iterator, 1) + } + + private def byteToString(b: Byte): String = { + val arr = new Array[Byte](1) + arr(0) = b + new String(arr, Charset.defaultCharset()) + } } diff --git a/core/src/test/scala/fs2/netty/pipeline/NettyPipelineSpec.scala b/core/src/test/scala/fs2/netty/pipeline/NettyPipelineSpec.scala new file mode 100644 index 
0000000..38267e7 --- /dev/null +++ b/core/src/test/scala/fs2/netty/pipeline/NettyPipelineSpec.scala @@ -0,0 +1,660 @@ +package fs2.netty.pipeline + +import cats.Eval +import cats.effect.std.{Dispatcher, Queue} +import cats.effect.testing.specs2.CatsResource +import cats.effect.{IO, Resource} +import cats.syntax.all._ +import fs2.Stream +import fs2.netty.embedded.Fs2NettyEmbeddedChannel +import fs2.netty.embedded.Fs2NettyEmbeddedChannel.CommonEncoders._ +import fs2.netty.embedded.Fs2NettyEmbeddedChannel.Encoder +import fs2.netty.pipeline.NettyPipelineSpec._ +import fs2.netty.pipeline.socket.Socket +import io.netty.buffer.{ByteBuf, Unpooled} +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.socket.ChannelInputShutdownReadComplete +import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandler} +import io.netty.handler.codec.MessageToMessageDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} +import io.netty.handler.codec.string.StringDecoder +import io.netty.util.ReferenceCountUtil +import org.specs2.mutable.SpecificationLike + +import java.nio.channels.ClosedChannelException +import java.util +import scala.collection.mutable.ListBuffer +import scala.concurrent.duration._ + +class NettyPipelineSpec + extends CatsResource[IO, Dispatcher[IO]] + with SpecificationLike { + + override val resource: Resource[IO, Dispatcher[IO]] = Dispatcher[IO] + + "default pipeline, i.e. no extra Channel handlers and reads and writes are on ByteBuf's" should { + "no activity on Netty channel should correspond to no activity on socket and vice-versa" in withResource { + implicit dispatcher => + for { + // Given a socket and embedded channel from the default Netty Pipeline + x <- NettyEmbeddedChannelWithByteBufPipeline + (channel, socket) = x + + // Then configs should be setup for backpressure + _ <- IO(channel.underlying.config().isAutoRead should beFalse) + _ <- IO(channel.underlying.config().isAutoClose should beTrue) + _ <- IO( + channel.underlying + .config() + .getWriteBufferLowWaterMark shouldEqual 32 * 1024 + ) + _ <- IO( + channel.underlying + .config() + .getWriteBufferHighWaterMark shouldEqual 64 * 1024 + ) + + // When flushing inbound events, i.e. calling read complete, on an empty channel + _ <- channel.flushInbound() + + // Then there are no reads on socket reads stream + // TODO: what's the canonical way to check for empty stream? 
+ reads <- socket.reads + .interruptAfter(Duration.Zero) + .compile + .toList + _ <- IO(reads should beEmpty) + + // When trigger Netty to run events when there aren't any to run + nextTaskTime <- channel.runScheduledPendingTasks + _ <- IO(nextTaskTime shouldEqual Fs2NettyEmbeddedChannel.NoTasksToRun) + + // Then there should be no events on the socket events stream + events: List[AnyRef] <- socket.events + .interruptAfter(Duration.Zero) + .compile + .toList + _ <- IO(events should beEmpty) + + // When there's no activity on Netty channel, but channel is still active + _ <- IO(channel.isOpen) + + // Then there should not be any exceptions + isOpen <- socket.isOpen + _ <- IO(isOpen shouldEqual true) + isClosed <- socket.isClosed + _ <- IO(isClosed shouldEqual false) + + // When there's no activity on socket writes + _ <- channel.flushOutbound() + + // Then there's no message on Netty channel outbound queue + writes <- channel.outboundMessages + _ <- IO(writes.isEmpty shouldEqual true) + + // And finally, no exceptions on the Netty channel + _ <- IO(channel.underlying.checkException()) + } yield ok + } + + "reading on socket without backpressure results in from Netty reading onto its channel" in withResource { + implicit dispatcher => + for { + // Given a socket and embedded channel from the default Netty Pipeline + x <- NettyEmbeddedChannelWithByteBufPipeline + (channel, socket) = x + + // And list of ByteBuf's + byteBufs <- IO( + List( + Unpooled.wrappedBuffer("hello".getBytes), + Unpooled.buffer(1, 1).writeByte(' '), + Unpooled.copiedBuffer("world".getBytes) + ) + ) + + // And a socket that doesn't backpressure reads, i.e. always accepts elements from stream + queue <- Queue.unbounded[IO, String] + _ <- socket.reads + .flatMap(byteBufToByteStream) + .map(byteToString) + .evalMap(queue.offer) + .compile + .drain + .background + .use { _ => + for { + // When writing each ByteBuf individually to the channel + _ <- channel + .writeAllInboundThenFlushThenRunAllPendingTasks(byteBufs: _*) + + // Then messages should be consumed from Netty + _ <- IO.sleep(200.millis) + _ <- IO( + channel.underlying.areInboundMessagesBuffered shouldEqual false + ) + + // And reads on socket yield the original message sent on channel + str <- (0 until "hello world".length).toList + .traverseFilter(_ => queue.tryTake) + .map(_.mkString) + _ <- IO(str shouldEqual "hello world") + + // And ByteBuf's should be released + _ <- IO(byteBufs.map(_.refCnt()) shouldEqual List.fill(3)(0)) + } yield () + } + } yield ok + } + + "backpressure on socket reads results in Netty NOT reading onto its channel" in withResource { + implicit dispatcher => + for { + // Given a socket and embedded channel from the default Netty Pipeline + x <- NettyEmbeddedChannelWithByteBufPipeline + (channel, socket) = x + + // And list of ByteBuf's + byteBufs <- IO( + List( + Unpooled.wrappedBuffer("hello".getBytes), + Unpooled.buffer(1, 1).writeByte(' '), + Unpooled.copiedBuffer("world".getBytes) + ) + ) + + // And a socket with backpressure, i.e. 
socket reads aren't being consumed + + // When writing each ByteBuf to the channel + areMsgsAdded <- channel + .writeAllInboundThenFlushThenRunAllPendingTasks(byteBufs: _*) + + // Then messages are NOT added onto the Netty channel + _ <- IO.sleep( + 200.millis + ) // give fs2-Netty chance to read like in non-backpressure test + _ <- IO(areMsgsAdded should beFalse) + _ <- IO( + channel.underlying.areInboundMessagesBuffered shouldEqual true + ) + + // And reads on socket yield the original message sent on channel + str <- socket.reads + .flatMap(byteBufToByteStream) + .take(11) + .map(byteToString) + .compile + .toList + .map(_.mkString) + _ <- IO(str shouldEqual "hello world") + + // And ByteBuf's should be released + _ <- IO(byteBufs.map(_.refCnt()) shouldEqual List.fill(3)(0)) + } yield ok + } + + "writing onto fs2-netty socket appear on Netty's channel" in withResource { + implicit dispatcher => + for { + // Given a socket and embedded channel from the default Netty Pipeline + x <- NettyEmbeddedChannelWithByteBufPipeline + (channel, socket) = x + + // And list of ByteBuf's + byteBufs <- IO( + List( + Unpooled.wrappedBuffer("hello".getBytes), + Unpooled.buffer(1, 1).writeByte(' '), + Unpooled.copiedBuffer("world".getBytes) + ) + ) + + // When writing each ByteBuf to the socket + _ <- byteBufs.traverse(socket.write) + + // Then Netty channel has messages in its outbound queue + str <- Stream + .fromIterator[IO]((0 until 3).iterator, chunkSize = 100) + .evalMap(_ => IO(channel.underlying.readOutbound[ByteBuf]())) + .flatMap(byteBufToByteStream) + .take(11) + .map(byteToString) + .compile + .toList + .map(_.mkString) + _ <- IO(str shouldEqual "hello world") + + // And ByteBuf's are not released. Embedded channel doesn't release, but real channel should. 
+ _ <- IO(byteBufs.map(_.refCnt()) shouldEqual List.fill(3)(1)) + _ <- IO.unit.guarantee( + IO(byteBufs.foreach(ReferenceCountUtil.release)) + ) + } yield ok + } + + "closed connection in Netty appears as closed streams in fs2-netty" in withResource { + implicit dispatcher => + for { + x <- NettyEmbeddedChannelWithByteBufPipeline + (channel, socket) = x + + // Netty sanity check + _ <- channel.isOpen.flatMap(isOpen => IO(isOpen should beTrue)) + _ <- socket.isOpen.flatMap(isOpen => IO(isOpen should beTrue)) + + _ <- channel.close() + + // Netty sanity check + _ <- channel.isClosed.flatMap(isClosed => IO(isClosed should beTrue)) + _ <- socket.isOpen.flatMap(isOpen => IO(isOpen should beFalse)) + _ <- socket.isClosed.flatMap(isClosed => IO(isClosed should beTrue)) + } yield ok + } + + "closing connection in fs2-netty closes underlying Netty channel" in withResource { + implicit dispatcher => + for { + x <- NettyEmbeddedChannelWithByteBufPipeline + (channel, socket) = x + + _ <- socket.close() + + _ <- channel.isClosed.flatMap(isClosed => IO(isClosed should beTrue)) + _ <- socket.isOpen.flatMap(isOpen => IO(isOpen should beFalse)) + _ <- socket.isClosed.flatMap(isClosed => IO(isClosed should beTrue)) + } yield ok + } + + "writing onto a closed socket is a no-op and throws an exception" in withResource { + implicit dispatcher => + for { + x <- NettyEmbeddedChannelWithByteBufPipeline + (channel, socket) = x + _ <- channel.close() + + byteBuf <- IO(Unpooled.wrappedBuffer("hi".getBytes)) + caughtClosedChannelException <- socket + .write(byteBuf) + .as(false) + .handleErrorWith { + case _: ClosedChannelException => true.pure[IO] + case _ => false.pure[IO] + } + + _ <- IO(caughtClosedChannelException shouldEqual true) + + _ <- channel.outboundMessages.flatMap(out => + IO(out.isEmpty shouldEqual true) + ) + } yield ok + } + + "exceptions in Netty pipeline raises an exception on the reads stream" in withResource { + implicit dispatcher => + for { + x <- NettyEmbeddedChannelWithByteBufPipeline + (channel, socket) = x + + _ <- IO( + channel.underlying + .pipeline() + .fireExceptionCaught(new Throwable("unit test error")) + ) + + errMsg <- socket.reads + .map(_ => "") + .handleErrorWith(t => Stream.emit(t.getMessage)) + .compile + .last + _ <- IO(errMsg shouldEqual "unit test error".some) + + _ <- channel.isOpen.flatMap(isOpen => IO(isOpen shouldEqual true)) + _ <- socket.isOpen.flatMap(isOpen => IO(isOpen shouldEqual true)) + } yield ok + } + + "pipeline events appear in fs2-netty as events stream" in withResource { + implicit dispatcher => + for { + x <- NettyEmbeddedChannelWithByteBufPipeline + (channel, socket) = x + + _ <- IO( + channel.underlying + .pipeline() + .fireUserEventTriggered(ChannelInputShutdownReadComplete.INSTANCE) + ) + + event <- socket.events.take(1).compile.last + } yield event should_=== Some(ChannelInputShutdownReadComplete.INSTANCE) + } + + "mutations" should { + "no-op mutation creates a Socket with same behavior as original, while original Socket is unregistered from pipeline and channel" in withResource { + dispatcher => + for { + // Given a channel and socket for the default pipeline + pipeline <- NettyPipeline[IO](dispatcher) + x <- Fs2NettyEmbeddedChannel[IO, ByteBuf, ByteBuf]( + pipeline + ) + (channel, socket) = x + + // Then socket is attached to a pipeline + _ <- socket.isDetached.map(_ should beFalse) + + // When performing a no-op socket pipeline mutation + newSocket <- socket.mutatePipeline[ByteBuf, ByteBuf](_ => IO.unit) + + // Then new socket should be able to 
receive and write ByteBuf's + encoder = implicitly[Encoder[Byte]] + byteBufs = "hello world".getBytes().map(encoder.encode).toList + _ <- channel + .writeAllInboundThenFlushThenRunAllPendingTasks(byteBufs: _*) + _ <- newSocket.reads + // fs2-netty automatically releases + .evalMap(bb => IO(bb.retain())) + .take(11) + .through(newSocket.writes) + .compile + .drain + str <- (0 until 11).toList + .traverse { _ => + IO(channel.underlying.readOutbound[ByteBuf]()) + .flatMap(bb => IO(bb.readByte())) + } + .map(_.toArray) + .map(new String(_)) + _ <- IO(str shouldEqual "hello world") + + // And new socket is attached to a pipeline + _ <- newSocket.isDetached.map(_ should beFalse) + + // And old socket is no longer attached to a pipeline + _ <- socket.isDetached.map(_ should beTrue) + + // And old socket should not receive any of the ByteBuf's + oldSocketReads <- socket.reads + .interruptAfter(1.second) + .compile + .toList + _ <- IO(oldSocketReads should beEmpty) + + // Nor should old socket be able to write. + oldSocketWrite <- socket.write(Unpooled.EMPTY_BUFFER).attempt + _ <- IO(oldSocketWrite should beLeft[Throwable].like { case t => + t.getMessage should_=== ("Noop channel") + }) + _ <- IO(channel.underlying.outboundMessages().isEmpty should beTrue) + } yield ok + } + + // varies I/O types and along with adding a handler that changes byteBufs to constant strings, affects reads stream and socket writes + "vary the Socket types" in withResource { dispatcher => + for { + // Given a channel and socket for the default pipeline + pipeline <- NettyPipeline[IO](dispatcher) + x <- Fs2NettyEmbeddedChannel[IO, ByteBuf, ByteBuf]( + pipeline + ) + (channel, socket) = x + + pipelineDecoder = new Socket.Decoder[Array[Byte]] { + override def decode(x: AnyRef): Either[String, Array[Byte]] = + x match { + case array: Array[Byte] => array.asRight[String] + case _ => + "whoops, pipeline is misconfigured".asLeft[Array[Byte]] + } + } + byteSocket <- socket + .mutatePipeline[Array[Byte], Array[Byte]] { pipeline => + for { + _ <- IO(pipeline.addLast(new ByteArrayDecoder)) + _ <- IO(pipeline.addLast(new ByteArrayEncoder)) + } yield () + }(pipelineDecoder) + + byteBuf = implicitly[Encoder[Array[Byte]]] + .encode("hello world".getBytes()) + _ <- channel + .writeAllInboundThenFlushThenRunAllPendingTasks(byteBuf) + _ <- byteSocket.reads + .take(1) + .through(byteSocket.writes) + .compile + .drain + + str <- IO(channel.underlying.readOutbound[ByteBuf]()) + .flatTap(bb => IO(bb.readableBytes() shouldEqual 11)) + .tupleRight(new Array[Byte](11)) + .flatMap { case (buf, bytes) => IO(buf.readBytes(bytes)).as(bytes) } + .map(new String(_)) + _ <- IO(str shouldEqual "hello world") + } yield ok + } + + // pipeline mutation error + + // socket decode error + + // test reads, writes, events, and exceptions in combination to ensure order of events makes sense + } + + // test pipeline with ByteArrayEncoder/Decoder passed into pipeline, not mutation + } + + "custom pipelines" should { + implicit val stringSocketDecoder: Socket.Decoder[String] = { + case str: String => str.asRight[String] + case _ => "pipeline misconfigured".asLeft[String] + } + + "custom handlers can change the types of reads and writes " in withResource { + dispatcher => + for { + pipeline <- NettyPipeline[IO, String, String]( + dispatcher, + handlers = List(Eval.now(new StringDecoder)) + ) + x <- Fs2NettyEmbeddedChannel[IO, String, String](pipeline) + (channel, socket) = x + + _ <- channel.writeAllInboundThenFlushThenRunAllPendingTasks( + "hello", + " ", + "world" 
+          )
+
+          strings <- socket.reads.take(3).compile.toList
+
+          _ <- IO(strings.mkString("") should_=== "hello world")
+
+          _ <- socket.write("output message")
+
+          msg <- IO(channel.underlying.readOutbound[String]())
+          _ <- IO(msg should_=== "output message")
+        } yield ok
+    }
+
+    // tests should enforce that ByteBuf is read off the embedded channel ^^
+
+    "non sharable handlers must always be evaluated per channel" in withResource {
+      dispatcher =>
+        for {
+          pipeline <- NettyPipeline[IO, String, String](
+            dispatcher,
+            handlers =
+              List(Eval.always(new StatefulMessageToReadCountChannelHandler))
+          )
+          x <- Fs2NettyEmbeddedChannel[IO, String, String](pipeline)
+          (channelOne, socketOne) = x
+          y <- Fs2NettyEmbeddedChannel[IO, String, String](pipeline)
+          (channelTwo, socketTwo) = y
+
+          inputs = List("a", "b", "c")
+
+          // for the same input to each channel we expect the same output, i.e. the same scan of counts
+          _ <- channelOne.writeAllInboundThenFlushThenRunAllPendingTasks(
+            inputs: _*
+          )
+          countsOne <- socketOne.reads.take(3).map(_.toInt).compile.toList
+          _ <- IO(countsOne should_=== List(1, 2, 3))
+
+          _ <- channelTwo.writeAllInboundThenFlushThenRunAllPendingTasks(
+            inputs: _*
+          )
+          countsTwo <- socketTwo.reads.take(3).map(_.toInt).compile.toList
+          _ <- IO(countsTwo should_=== List(1, 2, 3))
+        } yield ok
+    }
+
+    "sharable handlers are memoized per channel regardless of the eval policy" in withResource {
+      dispatcher =>
+        for {
+          pipeline <- NettyPipeline[IO, String, String](
+            dispatcher,
+            handlers = List(
+              Eval.always(
+                new SharableStatefulByteBufToReadCountChannelHandler
+              ),
+              Eval.now(
+                new SharableStatefulStringToReadCountChannelHandler
+              ),
+              Eval.later(
+                new SharableStatefulStringToReadCountChannelHandler
+              )
+            )
+          )
+          x <- Fs2NettyEmbeddedChannel[IO, String, String](pipeline)
+          (channelOne, socketOne) = x
+          y <- Fs2NettyEmbeddedChannel[IO, String, String](pipeline)
+          (channelTwo, socketTwo) = y
+
+          inputs = List("a", "b", "c")
+
+          _ <- channelOne.writeAllInboundThenFlushThenRunAllPendingTasks(
+            inputs: _*
+          )
+          countsOne <- socketOne.reads.take(3).map(_.toInt).compile.toList
+          _ <- IO(countsOne should_=== List(1, 2, 3))
+
+          _ <- channelTwo.writeAllInboundThenFlushThenRunAllPendingTasks(
+            inputs: _*
+          )
+          countsTwo <- socketTwo.reads.take(3).map(_.toInt).compile.toList
+          _ <- IO(countsTwo should_=== List(4, 5, 6))
+        } yield ok
+    }
+  }
+
+  private def byteToString(byte: Byte): String = {
+    val bytes = new Array[Byte](1)
+    bytes(0) = byte
+    new String(bytes)
+  }
+}
+
+object NettyPipelineSpec {
+
+  private def NettyEmbeddedChannelWithByteBufPipeline(implicit
+    dispatcher: Dispatcher[IO]
+  ) =
+    for {
+      pipeline <- NettyPipeline[IO](dispatcher)
+      x <- Fs2NettyEmbeddedChannel[IO, ByteBuf, ByteBuf](pipeline)
+    } yield x
+
+  private def byteBufToByteStream(bb: ByteBuf): Stream[IO, Byte] = {
+    val buffer = new ListBuffer[Byte]
+    bb.forEachByte((value: Byte) => {
+      val _ = buffer.addOne(value)
+      true
+    })
+    Stream.fromIterator[IO](buffer.iterator, 1)
+  }
+
+  /**
+    * Does not use MessageToMessageDecoder, SimpleChannelInboundHandler, or anything that extends ChannelHandlerAdapter.
+    * Netty tracks whether a ChannelHandlerAdapter is annotated with @Sharable when it is added, and will throw an
+    * exception if a non-Sharable handler is reused, e.g.
+    * io.netty.channel.ChannelInitializer exceptionCaught
+    * WARNING: Failed to initialize a channel. Closing: [id: 0xembedded, L:embedded - R:embedded]
+    * io.netty.channel.ChannelPipelineException: fs2.netty.NettyPipelineSpec$StatefulMessageToReadCountChannelHandler is not a @Sharable handler, so can't be added or removed multiple times
+    */
+  private class StatefulMessageToReadCountChannelHandler
+    extends ChannelInboundHandler {
+    private var readCounter = 0
+
+    override def channelRegistered(ctx: ChannelHandlerContext): Unit =
+      ctx.fireChannelRegistered()
+
+    override def channelUnregistered(ctx: ChannelHandlerContext): Unit =
+      ctx.fireChannelUnregistered()
+
+    override def channelActive(ctx: ChannelHandlerContext): Unit =
+      ctx.fireChannelActive()
+
+    override def channelInactive(ctx: ChannelHandlerContext): Unit =
+      ctx.fireChannelInactive()
+
+    override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {
+      ReferenceCountUtil.safeRelease(msg)
+      readCounter += 1
+      ctx.fireChannelRead(readCounter.toString)
+    }
+
+    override def channelReadComplete(ctx: ChannelHandlerContext): Unit =
+      ctx.fireChannelReadComplete()
+
+    override def userEventTriggered(
+      ctx: ChannelHandlerContext,
+      evt: Any
+    ): Unit =
+      // forward the event itself so downstream handlers (and the Socket events stream) still see it
+      ctx.fireUserEventTriggered(evt)
+
+    override def channelWritabilityChanged(ctx: ChannelHandlerContext): Unit =
+      ctx.fireChannelWritabilityChanged()
+
+    override def exceptionCaught(
+      ctx: ChannelHandlerContext,
+      cause: Throwable
+    ): Unit =
+      ctx.fireExceptionCaught(cause)
+
+    override def handlerAdded(ctx: ChannelHandlerContext): Unit = ()
+
+    override def handlerRemoved(ctx: ChannelHandlerContext): Unit = ()
+  }
+
+  @Sharable
+  private class SharableStatefulStringToReadCountChannelHandler
+    extends MessageToMessageDecoder[String] {
+    private var readCounter = 0
+
+    override def decode(
+      ctx: ChannelHandlerContext,
+      msg: String,
+      out: util.List[AnyRef]
+    ): Unit = {
+      readCounter += 1
+      out.add(readCounter.toString)
+    }
+  }
+
+  @Sharable
+  private class SharableStatefulByteBufToReadCountChannelHandler
+    extends MessageToMessageDecoder[ByteBuf] {
+    private var readCounter = 0
+
+    override def decode(
+      ctx: ChannelHandlerContext,
+      msg: ByteBuf,
+      out: util.List[AnyRef]
+    ): Unit = {
+      readCounter += 1
+      out.add(readCounter.toString)
+    }
+  }
+
+}
diff --git a/core/src/test/scala/fs2/netty/pipeline/prebuilt/BytePipelineSpec.scala b/core/src/test/scala/fs2/netty/pipeline/prebuilt/BytePipelineSpec.scala
new file mode 100644
index 0000000..970a29f
--- /dev/null
+++ b/core/src/test/scala/fs2/netty/pipeline/prebuilt/BytePipelineSpec.scala
@@ -0,0 +1,67 @@
+package fs2.netty.pipeline.prebuilt
+
+import cats.effect.std.Dispatcher
+import cats.effect.testing.specs2.CatsResource
+import cats.effect.{IO, Resource}
+import cats.syntax.all._
+import fs2.Chunk
+import fs2.netty.embedded.Fs2NettyEmbeddedChannel
+import fs2.netty.embedded.Fs2NettyEmbeddedChannel.CommonEncoders._
+import io.netty.buffer.ByteBuf
+import org.specs2.mutable.SpecificationLike
+
+class BytePipelineSpec
+  extends CatsResource[IO, Dispatcher[IO]]
+  with SpecificationLike {
+
+  override val resource: Resource[IO, Dispatcher[IO]] = Dispatcher[IO]
+
+  "can echo back what is written" in withResource { dispatcher =>
+    for {
+      pipeline <- BytePipeline(dispatcher)
+      x <- Fs2NettyEmbeddedChannel[IO, Chunk[Byte], Byte](pipeline)
+      (channel, socket) = x
+
+      _ <- channel.writeAllInboundThenFlushThenRunAllPendingTasks("hello world")
+      _ <- socket.reads
+        .take(5)
+        .chunks
+        .through(socket.writes)
+        .compile
+        .drain
+
+      str <- IO(channel.underlying.readOutbound[ByteBuf]())
+        .flatTap(bb => IO(bb.readableBytes()
shouldEqual 5)) + .tupleRight(new Array[Byte](5)) + .flatMap { case (buf, bytes) => IO(buf.readBytes(bytes)).as(bytes) } + .map(new String(_)) + + _ <- IO(str shouldEqual "hello") + } yield ok + } + + "alternative can echo back what is written" in withResource { dispatcher => + for { + pipeline <- AlternativeBytePipeline(dispatcher) + x <- Fs2NettyEmbeddedChannel[IO, Chunk[Byte], Byte](pipeline) + (channel, socket) = x + + _ <- channel.writeAllInboundThenFlushThenRunAllPendingTasks("hello world") + _ <- socket.reads + .take(5) + .chunks + .through(socket.writes) + .compile + .drain + + str <- IO(channel.underlying.readOutbound[ByteBuf]()) + .flatTap(bb => IO(bb.readableBytes() shouldEqual 5)) + .tupleRight(new Array[Byte](5)) + .flatMap { case (buf, bytes) => IO(buf.readBytes(bytes)).as(bytes) } + .map(new String(_)) + + _ <- IO(str shouldEqual "hello") + } yield ok + } + +}