diff --git a/modules/junit-jupiter/src/main/java/org/testcontainers/junit/jupiter/TestcontainersExtension.java b/modules/junit-jupiter/src/main/java/org/testcontainers/junit/jupiter/TestcontainersExtension.java index ca99a2d2382..ee7d9b1f663 100644 --- a/modules/junit-jupiter/src/main/java/org/testcontainers/junit/jupiter/TestcontainersExtension.java +++ b/modules/junit-jupiter/src/main/java/org/testcontainers/junit/jupiter/TestcontainersExtension.java @@ -1,5 +1,7 @@ package org.testcontainers.junit.jupiter; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import lombok.Getter; import org.junit.jupiter.api.extension.AfterAllCallback; import org.junit.jupiter.api.extension.AfterEachCallback; @@ -41,6 +43,7 @@ public class TestcontainersExtension private static final String SHARED_LIFECYCLE_AWARE_CONTAINERS = "sharedLifecycleAwareContainers"; private static final String LOCAL_LIFECYCLE_AWARE_CONTAINERS = "localLifecycleAwareContainers"; + private static final ConcurrentHashMap<String, StoreAdapterThread> STORE_ADAPTER_THREADS = new ConcurrentHashMap<>(); @Override public void beforeAll(ExtensionContext context) { @@ -74,7 +77,7 @@ private void startContainers(List<StoreAdapter> storeAdapters, Store store, Exte Stream<Startable> startables = storeAdapters .stream() .map(storeAdapter -> { - store.getOrComputeIfAbsent(storeAdapter.getKey(), k -> storeAdapter); + store.getOrComputeIfAbsent(storeAdapter.getKey(), k -> storeAdapterStart(k, storeAdapter)); return storeAdapter.container; }); Startables.deepStart(startables).join(); @@ -130,6 +133,20 @@ public void afterEach(ExtensionContext context) { signalAfterTestToContainersFor(LOCAL_LIFECYCLE_AWARE_CONTAINERS, context); } + private static synchronized StoreAdapter storeAdapterStart(String key, StoreAdapter adapter) { + boolean isInitialized = STORE_ADAPTER_THREADS.containsKey(key); + + if (!isInitialized) { + StoreAdapter storeAdapter = adapter.start(); + STORE_ADAPTER_THREADS.put(key, new StoreAdapterThread(storeAdapter)); + return storeAdapter; + } + + StoreAdapterThread storeAdapter = STORE_ADAPTER_THREADS.get(key); + storeAdapter.threads.incrementAndGet(); + return storeAdapter.storeAdapter; + } + private void signalBeforeTestToContainers( List<TestLifecycleAware> lifecycleAwareContainers, TestDescription testDescription @@ -267,9 +284,9 @@ private static StoreAdapter getContainerInstance(final Object testInstance, fina private static class StoreAdapter implements CloseableResource { @Getter - private String key; + private final String key; - private Startable container; + private final Startable container; private StoreAdapter(Class<?> declaringClass, String fieldName, Startable container) { this.key = declaringClass.getName() + "." + fieldName; @@ -283,7 +300,25 @@ private StoreAdapter start() { @Override public void close() { - container.stop(); + int total = STORE_ADAPTER_THREADS.getOrDefault(key, StoreAdapterThread.NULL).threads.decrementAndGet(); + if (total < 1) { + container.stop(); + STORE_ADAPTER_THREADS.remove(key); + } + } + + } + + private static class StoreAdapterThread { + + public static final StoreAdapterThread NULL = new StoreAdapterThread(null); + public final StoreAdapter storeAdapter; + public final AtomicInteger threads = new AtomicInteger(1); + + private StoreAdapterThread(StoreAdapter storeAdapter) { + this.storeAdapter = storeAdapter; } + } + } diff --git a/modules/junit-jupiter/src/test/java/org/testcontainers/junit/jupiter/ParallelContainerTests.java b/modules/junit-jupiter/src/test/java/org/testcontainers/junit/jupiter/ParallelContainerTests.java new file mode 100644 index 00000000000..c6f806005a4 --- /dev/null +++ b/modules/junit-jupiter/src/test/java/org/testcontainers/junit/jupiter/ParallelContainerTests.java @@ -0,0 +1,88 @@ +package org.testcontainers.junit.jupiter; + +import static org.junit.Assert.assertEquals; +import static org.testcontainers.junit.jupiter.JUnitJupiterTestImages.POSTGRES_IMAGE; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; +import org.testcontainers.containers.PostgreSQLContainer; + +@Testcontainers +@Execution(ExecutionMode.CONCURRENT) +public class ParallelContainerTests { + + @Container + protected static final PostgreSQLContainer<?> POSTGRE_SQL_CONTAINER = new PostgreSQLContainer<>(POSTGRES_IMAGE) + .withDatabaseName("foo") + .withUsername("foo") + .withPassword("secret"); + + @Test + void container_should_be_running_first_test() throws SQLException { + assertContainerIsRunning(POSTGRE_SQL_CONTAINER); + } + + @Test + void container_should_be_running_second_test() throws SQLException { + assertContainerIsRunning(POSTGRE_SQL_CONTAINER); + } + + @Nested + @Execution(ExecutionMode.CONCURRENT) + class FirstParallelTest extends BaseContainerTests { + + @Test + void container_should_be_running() throws SQLException { + assertContainerIsRunning(POSTGRE_SQL_BASE_CONTAINER); + } + + } + + @Nested + @Execution(ExecutionMode.CONCURRENT) + class SecondParallelTest extends BaseContainerTests { + + @Test + void container_should_be_running() throws SQLException { + assertContainerIsRunning(POSTGRE_SQL_BASE_CONTAINER); + } + + } + + @Testcontainers + private static class BaseContainerTests { + + @Container + protected static final PostgreSQLContainer<?> POSTGRE_SQL_BASE_CONTAINER = new PostgreSQLContainer<>(POSTGRES_IMAGE) + .withDatabaseName("foo") + .withUsername("foo") + .withPassword("secret"); + + } + + @SuppressWarnings({"SqlDialectInspection", "SqlNoDataSourceInspection"}) + private static void assertContainerIsRunning(PostgreSQLContainer<?> container) throws SQLException { + HikariConfig hikariConfig = new HikariConfig(); + hikariConfig.setJdbcUrl(container.getJdbcUrl()); + hikariConfig.setUsername("foo"); + hikariConfig.setPassword("secret"); + + try (HikariDataSource ds = new HikariDataSource(hikariConfig)) { + Statement statement = ds.getConnection().createStatement(); + statement.execute("SELECT 1"); + ResultSet resultSet = statement.getResultSet(); + resultSet.next(); + + int resultSetInt = resultSet.getInt(1); + assertEquals(1, resultSetInt); + } + } + +} diff --git a/modules/junit-jupiter/src/test/resources/junit-platform.properties b/modules/junit-jupiter/src/test/resources/junit-platform.properties new file mode 100644 index 00000000000..c8a668e1816 --- /dev/null +++ b/modules/junit-jupiter/src/test/resources/junit-platform.properties @@ -0,0 +1,3 @@ +junit.jupiter.execution.parallel.enabled=true +junit.jupiter.execution.parallel.mode.default=same_thread +junit.jupiter.execution.parallel.mode.classes.default=same_thread