Skip to content

Commit c84ccd3

Browse files
authored
callStateMap was accessed from multiple threads without synchronization. changing callStateMap to concurrent hashmap (#36886)
1 parent 94336fa commit c84ccd3

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoLifecycleTest.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
import java.util.ArrayList;
3333
import java.util.Arrays;
3434
import java.util.Collections;
35-
import java.util.HashMap;
3635
import java.util.List;
3736
import java.util.Map;
37+
import java.util.concurrent.ConcurrentHashMap;
3838
import java.util.concurrent.CountDownLatch;
3939
import java.util.concurrent.TimeUnit;
4040
import java.util.concurrent.atomic.AtomicBoolean;
@@ -293,7 +293,7 @@ public void testTeardownCalledAfterExceptionInFinishBundleStateful() {
293293

294294
@Before
295295
public void setup() {
296-
ExceptionThrowingFn.callStateMap = new HashMap<>();
296+
ExceptionThrowingFn.callStateMap.clear();
297297
ExceptionThrowingFn.exceptionWasThrown.set(false);
298298
}
299299

@@ -356,7 +356,7 @@ CallState finalState() {
356356
}
357357

358358
private static class ExceptionThrowingFn<T> extends DoFn<T, T> {
359-
static HashMap<Integer, DelayedCallStateTracker> callStateMap = new HashMap<>();
359+
static Map<Integer, DelayedCallStateTracker> callStateMap = new ConcurrentHashMap<>();
360360
// exception is not necessarily thrown on every instance. But we expect at least
361361
// one during tests
362362
static AtomicBoolean exceptionWasThrown = new AtomicBoolean(false);
@@ -373,7 +373,10 @@ private static void validate(CallState... requiredCallStates) {
373373
Map<Integer, DelayedCallStateTracker> callStates;
374374
synchronized (ExceptionThrowingFn.class) {
375375
callStates =
376-
(Map<Integer, DelayedCallStateTracker>) ExceptionThrowingFn.callStateMap.clone();
376+
(Map<Integer, DelayedCallStateTracker>)
377+
Collections.synchronizedMap(
378+
ExceptionThrowingFn.callStateMap.entrySet().stream()
379+
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())));
377380
}
378381
assertThat(callStates, is(not(anEmptyMap())));
379382
// assert that callStateMap contains only TEARDOWN as a value. Note: We do not expect

0 commit comments

Comments
 (0)