Skip to content

Make scorer JUnit extension Quarkus specific #1329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,31 +1,21 @@
package io.quarkiverse.langchain4j.testing.scorer;

import java.io.Closeable;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

import org.jboss.logging.Logger;

public class Scorer implements Closeable {
public class Scorer {

private static final Logger LOG = Logger.getLogger(Scorer.class);
private final ExecutorService executor;

public Scorer(int concurrency) {
if (concurrency > 1) {
executor = Executors.newFixedThreadPool(concurrency);
} else {
executor = Executors.newSingleThreadExecutor();
}
}

public Scorer() {
this(1);
public Scorer(ExecutorService executor) {
this.executor = executor;
}

@SuppressWarnings({ "unchecked" })
Expand Down Expand Up @@ -74,10 +64,6 @@ public <T> EvaluationReport<T> evaluate(
return new EvaluationReport<>(orderedEvalutions);
}

public void close() {
executor.shutdown();
}

public record EvaluationResult<T>(
EvaluationSample<T> sample, T result, Throwable thrown, boolean passed) {
public static <T> EvaluationResult<T> fromCompletedEvaluation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.stream.Stream;

Expand All @@ -12,18 +14,20 @@
class ScorerTest {

private Scorer scorer;
private ExecutorService executor;

@AfterEach
void tearDown() {
if (scorer != null) {
scorer.close();
if (executor != null) {
executor.shutdown();
}
}

@SuppressWarnings("unchecked")
@Test
void evaluateShouldReturnCorrectReport() {
scorer = new Scorer(2);
executor = Executors.newFixedThreadPool(2);
scorer = new Scorer(executor);

EvaluationSample<String> sample1 = new EvaluationSample<>(
"Sample1",
Expand Down Expand Up @@ -62,7 +66,8 @@ void evaluateShouldReturnCorrectReport() {
@SuppressWarnings("unchecked")
@Test
void evaluateShouldReturnCorrectlyOrderedReport() {
scorer = new Scorer(2);
executor = Executors.newFixedThreadPool(2);
scorer = new Scorer(executor);
var sleeps = Stream.of(25l, 0l);
var samples = new Samples<>(
sleeps
Expand Down Expand Up @@ -93,7 +98,8 @@ private String sleep(Parameters params) {
@Test
@SuppressWarnings("unchecked")
void evaluateShouldHandleExceptionsInFunction() {
scorer = new Scorer();
executor = Executors.newSingleThreadExecutor();
scorer = new Scorer(executor);
EvaluationSample<String> sample = new EvaluationSample<>(
"Sample1",
new Parameters().add(new Parameter.UnnamedParameter("param1")),
Expand All @@ -118,7 +124,8 @@ void evaluateShouldHandleExceptionsInFunction() {
@Test
@SuppressWarnings("unchecked")
void evaluateShouldHandleMultipleStrategies() {
scorer = new Scorer();
executor = Executors.newSingleThreadExecutor();
scorer = new Scorer(executor);

EvaluationSample<String> sample = new EvaluationSample<>(
"Sample1",
Expand Down
12 changes: 10 additions & 2 deletions testing/scorer/scorer-junit5/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-testing-scorer-core</artifactId>
<version>999-SNAPSHOT</version>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-junit5</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-smallrye-context-propagation</artifactId>
</dependency>

<dependency>
Expand All @@ -48,4 +56,4 @@
</dependency>
</dependencies>

</project>
</project>

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,85 +1,28 @@
package io.quarkiverse.langchain4j.scorer.junit5;

import java.lang.reflect.Field;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CopyOnWriteArrayList;

import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;
import org.junit.platform.commons.support.HierarchyTraversalMode;
import org.junit.platform.commons.support.ReflectionSupport;

import io.quarkiverse.langchain4j.testing.scorer.Samples;
import io.quarkiverse.langchain4j.testing.scorer.Scorer;
import io.quarkiverse.langchain4j.testing.scorer.YamlLoader;

public class ScorerExtension implements BeforeEachCallback, AfterEachCallback, ParameterResolver {
private final List<Scorer> scorers = new CopyOnWriteArrayList<>();

@Override
public void beforeEach(ExtensionContext extensionContext) {
Optional<Class<?>> maybeClass = extensionContext.getTestClass();
if (maybeClass.isPresent()) {
List<Field> fields = ReflectionSupport.findFields(maybeClass.get(),
field -> field.getType().isAssignableFrom(Scorer.class), HierarchyTraversalMode.TOP_DOWN);
for (Field field : fields) {
Scorer sc;
if (field.isAnnotationPresent(ScorerConfiguration.class)) {
ScorerConfiguration annotation = field.getAnnotation(ScorerConfiguration.class);
sc = new Scorer(annotation.concurrency());
} else {
sc = new Scorer();
}
scorers.add(sc);
inject(sc, extensionContext.getRequiredTestInstance(), field);
}
}
}

private void inject(Scorer sc, Object instance, Field field) {
try {
field.setAccessible(true);
field.set(instance, sc);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}

@Override
public void afterEach(ExtensionContext extensionContext) {
for (Scorer scorer : scorers) {
scorer.close();
}
}
public class ScorerExtension implements ParameterResolver {

@Override
public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
throws ParameterResolutionException {
return (parameterContext.findAnnotation(SampleLocation.class).isPresent()
&& parameterContext.getParameter().getType().isAssignableFrom(Samples.class))
|| parameterContext.getParameter().getType().isAssignableFrom(Scorer.class);
&& parameterContext.getParameter().getType().isAssignableFrom(Samples.class));
}

@Override
public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
throws ParameterResolutionException {
if (parameterContext.getParameter().getType().isAssignableFrom(Scorer.class)) {
if (parameterContext.getParameter().isAnnotationPresent(ScorerConfiguration.class)) {
ScorerConfiguration annotation = parameterContext.getParameter().getAnnotation(ScorerConfiguration.class);
return new Scorer(annotation.concurrency());
} else {
return new Scorer();
}
} else {
// List of data samples
String path = parameterContext.findAnnotation(SampleLocation.class).orElseThrow().value();
return YamlLoader.load(path);
}
// List of data samples
String path = parameterContext.findAnnotation(SampleLocation.class).orElseThrow().value();
return YamlLoader.load(path);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.quarkiverse.langchain4j.scorer.junit5;

import jakarta.enterprise.context.ApplicationScoped;

import org.eclipse.microprofile.context.ManagedExecutor;

import io.quarkiverse.langchain4j.testing.scorer.Scorer;

public class ScorerProducer {

@ApplicationScoped
public Scorer scorer(ManagedExecutor executor) {
return new Scorer(executor);
}
}
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,18 @@

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mockito;

import io.quarkiverse.langchain4j.scorer.junit5.SampleLocation;
import io.quarkiverse.langchain4j.scorer.junit5.ScorerConfiguration;
import io.quarkiverse.langchain4j.scorer.junit5.ScorerExtension;
import io.quarkiverse.langchain4j.testing.scorer.Samples;
import io.quarkiverse.langchain4j.testing.scorer.Scorer;

@ExtendWith(ScorerExtension.class)
class ScorerExtensionTest {

@ScorerConfiguration(concurrency = 3)
private Scorer scorerWithConcurrency;

private Scorer defaultScorer;

@Test
void scorerFieldInjectionShouldWork() {
assertThat(scorerWithConcurrency).isNotNull();
assertThat(scorerWithConcurrency).extracting("executor").isNotNull();
assertThat(defaultScorer).isNotNull();
assertThat(defaultScorer).extracting("executor").isNotNull();
}

@Test
void scorerParameterShouldBeResolved(@ScorerConfiguration(concurrency = 2) Scorer scorer) {
assertThat(scorer).isNotNull();
assertThat(scorer).extracting("executor").isNotNull();
}

@Test
void samplesParameterShouldBeResolved(@SampleLocation("src/test/resources/test-samples.yaml") Samples<String> samples) {
assertThat(samples).isNotNull();
assertThat(samples).hasSizeGreaterThan(0);
assertThat(samples.get(0).name()).isEqualTo("Sample1"); // Assuming the YAML has this entry.
}

@Test
void scorerShouldBeClosedAfterTest() {
Scorer mockScorer = Mockito.mock(Scorer.class);
mockScorer.close();
Mockito.verify(mockScorer).close();
}
}
Loading