15
15
*/
16
16
package org .springframework .batch .core .repository .dao .mongodb ;
17
17
18
- import java .util .ArrayList ;
19
18
import java .util .Collection ;
20
- import java .util .Comparator ;
21
19
import java .util .List ;
22
- import java .util .Optional ;
23
20
24
21
import org .springframework .batch .core .job .JobExecution ;
25
22
import org .springframework .batch .core .job .JobInstance ;
26
23
import org .springframework .batch .core .step .StepExecution ;
27
24
import org .springframework .batch .core .repository .dao .StepExecutionDao ;
28
25
import org .springframework .batch .core .repository .persistence .converter .JobExecutionConverter ;
29
26
import org .springframework .batch .core .repository .persistence .converter .StepExecutionConverter ;
27
+ import org .springframework .data .domain .Sort ;
30
28
import org .springframework .data .mongodb .core .MongoOperations ;
31
29
import org .springframework .data .mongodb .core .query .Query ;
32
30
import org .springframework .jdbc .support .incrementer .DataFieldMaxValueIncrementer ;
36
34
37
35
/**
38
36
* @author Mahmoud Ben Hassine
37
+ * @author Jinwoo Bae
39
38
* @since 5.2.0
40
39
*/
41
40
public class MongoStepExecutionDao implements StepExecutionDao {
@@ -100,34 +99,42 @@ public StepExecution getStepExecution(JobExecution jobExecution, Long stepExecut
100
99
@ Override
101
100
public StepExecution getLastStepExecution (JobInstance jobInstance , String stepName ) {
102
101
// TODO optimize the query
103
- // get all step executions
104
- List <org .springframework .batch .core .repository .persistence .StepExecution > stepExecutions = new ArrayList <>();
105
- Query query = query (where ("jobInstanceId" ).is (jobInstance .getId ()));
102
+ Query jobExecutionQuery = query (where ("jobInstanceId" ).is (jobInstance .getId ()));
106
103
List <org .springframework .batch .core .repository .persistence .JobExecution > jobExecutions = this .mongoOperations
107
- .find (query , org .springframework .batch .core .repository .persistence .JobExecution .class ,
104
+ .find (jobExecutionQuery , org .springframework .batch .core .repository .persistence .JobExecution .class ,
108
105
JOB_EXECUTIONS_COLLECTION_NAME );
109
- for (org .springframework .batch .core .repository .persistence .JobExecution jobExecution : jobExecutions ) {
110
- stepExecutions .addAll (jobExecution .getStepExecutions ());
111
- }
112
- // sort step executions by creation date then id (see contract) and return the
113
- // first one
114
- Optional <org .springframework .batch .core .repository .persistence .StepExecution > lastStepExecution = stepExecutions
115
- .stream ()
116
- .filter (stepExecution -> stepExecution .getName ().equals (stepName ))
117
- .min (Comparator
118
- .comparing (org .springframework .batch .core .repository .persistence .StepExecution ::getCreateTime )
119
- .thenComparing (org .springframework .batch .core .repository .persistence .StepExecution ::getId ));
120
- if (lastStepExecution .isPresent ()) {
121
- org .springframework .batch .core .repository .persistence .StepExecution stepExecution = lastStepExecution .get ();
122
- JobExecution jobExecution = this .jobExecutionConverter .toJobExecution (jobExecutions .stream ()
123
- .filter (execution -> execution .getJobExecutionId ().equals (stepExecution .getJobExecutionId ()))
124
- .findFirst ()
125
- .get (), jobInstance );
126
- return this .stepExecutionConverter .toStepExecution (stepExecution , jobExecution );
106
+
107
+ if (jobExecutions .isEmpty ()) {
108
+ return null ;
127
109
}
128
- else {
110
+
111
+ List <Long > jobExecutionIds = jobExecutions .stream ()
112
+ .map (org .springframework .batch .core .repository .persistence .JobExecution ::getJobExecutionId )
113
+ .toList ();
114
+
115
+ Query stepExecutionQuery = query (where ("name" ).is (stepName ).and ("jobExecutionId" ).in (jobExecutionIds ))
116
+ .with (Sort .by (Sort .Direction .DESC , "createTime" , "stepExecutionId" ))
117
+ .limit (1 );
118
+
119
+ org .springframework .batch .core .repository .persistence .StepExecution stepExecution = this .mongoOperations
120
+ .findOne (stepExecutionQuery , org .springframework .batch .core .repository .persistence .StepExecution .class ,
121
+ STEP_EXECUTIONS_COLLECTION_NAME );
122
+
123
+ if (stepExecution == null ) {
129
124
return null ;
130
125
}
126
+
127
+ org .springframework .batch .core .repository .persistence .JobExecution jobExecution = jobExecutions .stream ()
128
+ .filter (execution -> execution .getJobExecutionId ().equals (stepExecution .getJobExecutionId ()))
129
+ .findFirst ()
130
+ .orElse (null );
131
+
132
+ if (jobExecution != null ) {
133
+ JobExecution jobExecutionDomain = this .jobExecutionConverter .toJobExecution (jobExecution , jobInstance );
134
+ return this .stepExecutionConverter .toStepExecution (stepExecution , jobExecutionDomain );
135
+ }
136
+
137
+ return null ;
131
138
}
132
139
133
140
@ Override
@@ -144,22 +151,23 @@ public void addStepExecutions(JobExecution jobExecution) {
144
151
145
152
@ Override
146
153
public long countStepExecutions (JobInstance jobInstance , String stepName ) {
147
- long count = 0 ;
148
- // TODO optimize the count query
149
- Query query = query (where ("jobInstanceId" ).is (jobInstance .getId ()));
150
- List <org .springframework .batch .core .repository .persistence .JobExecution > jobExecutions = this .mongoOperations
151
- .find (query , org .springframework .batch .core .repository .persistence .JobExecution .class ,
152
- JOB_EXECUTIONS_COLLECTION_NAME );
153
- for (org .springframework .batch .core .repository .persistence .JobExecution jobExecution : jobExecutions ) {
154
- List <org .springframework .batch .core .repository .persistence .StepExecution > stepExecutions = jobExecution
155
- .getStepExecutions ();
156
- for (org .springframework .batch .core .repository .persistence .StepExecution stepExecution : stepExecutions ) {
157
- if (stepExecution .getName ().equals (stepName )) {
158
- count ++;
159
- }
160
- }
154
+ Query jobExecutionQuery = query (where ("jobInstanceId" ).is (jobInstance .getId ()));
155
+ List <Long > jobExecutionIds = this .mongoOperations
156
+ .find (jobExecutionQuery , org .springframework .batch .core .repository .persistence .JobExecution .class ,
157
+ JOB_EXECUTIONS_COLLECTION_NAME )
158
+ .stream ()
159
+ .map (org .springframework .batch .core .repository .persistence .JobExecution ::getJobExecutionId )
160
+ .toList ();
161
+
162
+ if (jobExecutionIds .isEmpty ()) {
163
+ return 0 ;
161
164
}
162
- return count ;
165
+
166
+ // Count step executions directly from BATCH_STEP_EXECUTION collection
167
+ Query stepQuery = query (where ("name" ).is (stepName ).and ("jobExecutionId" ).in (jobExecutionIds ));
168
+ return this .mongoOperations .count (stepQuery ,
169
+ org .springframework .batch .core .repository .persistence .StepExecution .class ,
170
+ STEP_EXECUTIONS_COLLECTION_NAME );
163
171
}
164
172
165
173
}
0 commit comments