Skip to content

fix: ensure thread safety and proper resource cleanup #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
public class MockStatsDServer(
host: String = DEFAULT_HOST,
port: Int = DEFAULT_PORT
) : StatsDServer(host = host, port = port) {
) : StatsDServer(initialHost = host, initialPort = port) {
private val calls = ConcurrentLinkedQueue<String>()

protected override fun onMessage(message: String) {
Expand Down
53 changes: 34 additions & 19 deletions src/main/kotlin/me/kpavlov/mocks/statsd/server/StatsDServer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.slf4j.LoggerFactory
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetAddress
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors

public const val RANDOM_PORT: Int = 0
Expand All @@ -18,23 +19,24 @@ private const val BUFFER_SIZE = 8932
*/
@Suppress("TooManyFunctions")
public open class StatsDServer(
host: String = DEFAULT_HOST,
port: Int = DEFAULT_PORT
) {
initialHost: String = DEFAULT_HOST,
private val initialPort: Int = DEFAULT_PORT
) : AutoCloseable {
private val logger = LoggerFactory.getLogger(javaClass)
private val socket = DatagramSocket(port, InetAddress.getByName(host))
private val repository = MetricRepository()
private val executorService = Executors.newSingleThreadExecutor()
private val executorService = Executors.newCachedThreadPool()

private var shouldRun = false

private val host = InetAddress.getByName(initialHost)
private var serverSocket: DatagramSocket? = null
/**
* Retrieves the port number on which the server is running.
*
* @return the port number
*/
public fun port(): Int = socket.localPort
public fun host(): String = socket.localAddress.hostAddress
public fun port(): Int = serverSocket?.localPort ?: throw IllegalStateException("Server is not started")
public fun host(): String = serverSocket?.localAddress?.hostAddress ?: throw IllegalStateException("Server is not started")

/**
* Reset collected metrics
Expand All @@ -49,24 +51,31 @@ public open class StatsDServer(
val buffer = ByteArray(BUFFER_SIZE)
val packet = DatagramPacket(buffer, buffer.size)

val latch = CountDownLatch(1)

executorService.submit {
shouldRun = true
serverSocket = DatagramSocket(initialPort, host)
logger.info("Starting StatsD server on ${host()}:${port()}")
while (shouldRun) {
socket.receive(packet)
val message = String(packet.data, 0, packet.length)
if (logger.isDebugEnabled) {
logger.debug("Received: {}", message)
}
@Suppress("TooGenericExceptionCaught")
try {
onMessage(message)
handleMessage(message)
} catch (e: Exception) {
logger.error("Can't handle message: $message", e)
serverSocket.use { socket ->
latch.countDown()
while (shouldRun) {
socket?.receive(packet)
val message = String(packet.data, 0, packet.length)
if (logger.isDebugEnabled) {
logger.debug("Received: {}", message)
}
@Suppress("TooGenericExceptionCaught")
try {
onMessage(message)
handleMessage(message)
} catch (e: Exception) {
logger.error("Can't handle message: $message", e)
}
}
}
}
latch.await()
}

private fun handleMessage(message: String) {
Expand Down Expand Up @@ -134,6 +143,12 @@ public open class StatsDServer(
public fun stop() {
logger.info("Stopping StatsD server on ${host()}:${port()}")
shouldRun = false
serverSocket?.close()
}

override fun close() {
stop()
executorService.shutdownNow()
}

private fun findMetric(metricName: String, tags: Map<String, String>? = null): Metric? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package me.kpavlov.mocks.statsd.server

import me.kpavlov.mocks.statsd.client.StatsDClient
import me.kpavlov.mocks.statsd.junit5.StatsDJUnit5Extension
import org.junit.jupiter.api.AfterAll
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.extension.ExtendWith

Expand All @@ -15,4 +17,14 @@ internal class Junit5ExtensionTest : BaseStatsDServerTest() {
statsd = StatsDJUnit5Extension.statsDServer()
client = StatsDClient(port = statsd.port())
}

@BeforeEach
fun beforeEach() {
client = StatsDClient(port = statsd.port())
}

@AfterAll
fun afterAll() {
statsd.stop()
}
}