Skip to content

Commit 3956cca

Browse files
authored
Merge branch 'main' into update-cudaq-sha-1763775326
2 parents bbb8dbc + 0104722 commit 3956cca

File tree

7 files changed

+768
-519
lines changed

7 files changed

+768
-519
lines changed

libs/qec/lib/decoder.cpp

Lines changed: 91 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaqx::tensor<uint8_t> &)
2020
INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaqx::tensor<uint8_t> &,
2121
const cudaqx::heterogeneous_map &)
2222

23+
// Include decoder implementations AFTER registry instantiation
24+
#include "decoders/sliding_window.h"
25+
2326
namespace cudaq::qec {
2427

2528
struct decoder::rt_impl {
@@ -50,6 +53,18 @@ struct decoder::rt_impl {
5053

5154
/// The id of the decoder (for instrumentation)
5255
uint32_t decoder_id = 0;
56+
57+
bool is_sliding_window = false;
58+
59+
/// The number of syndromes per round. Only used for sliding window decoder.
60+
size_t num_syndromes_per_round = 0;
61+
62+
/// Whether the first round detectors are included. Only used for sliding
63+
/// window decoder.
64+
bool has_first_round_detectors = false;
65+
66+
/// The current round. Only used for sliding window decoder.
67+
uint32_t current_round = 0;
5368
};
5469

5570
void decoder::rt_impl_deleter::operator()(rt_impl *p) const { delete p; }
@@ -175,20 +190,41 @@ void decoder::set_decoder_id(uint32_t decoder_id) {
175190

176191
uint32_t decoder::get_decoder_id() const { return pimpl->decoder_id; }
177192

178-
void decoder::set_D_sparse(const std::vector<std::vector<uint32_t>> &D_sparse) {
179-
this->D_sparse = D_sparse;
193+
template <typename PimplType>
194+
void set_D_sparse_common(decoder *decoder,
195+
const std::vector<std::vector<uint32_t>> &D_sparse,
196+
PimplType *pimpl) {
197+
auto *sw_decoder = dynamic_cast<sliding_window *>(decoder);
198+
199+
if (sw_decoder != nullptr) {
200+
pimpl->is_sliding_window = true;
201+
pimpl->num_syndromes_per_round = sw_decoder->get_num_syndromes_per_round();
202+
// Check if first row is a first-round detector (single syndrome index)
203+
pimpl->has_first_round_detectors =
204+
(D_sparse.size() > 0 && D_sparse[0].size() == 1);
205+
pimpl->current_round = 0;
206+
pimpl->persistent_detector_buffer.resize(pimpl->num_syndromes_per_round);
207+
pimpl->persistent_soft_detector_buffer.resize(
208+
pimpl->num_syndromes_per_round);
209+
210+
} else {
211+
pimpl->is_sliding_window = false;
212+
}
213+
180214
pimpl->num_msyn_per_decode = calculate_num_msyn_per_decode(D_sparse);
181215
pimpl->msyn_buffer.clear();
182216
pimpl->msyn_buffer.resize(pimpl->num_msyn_per_decode);
183217
pimpl->msyn_buffer_index = 0;
184218
}
185219

220+
void decoder::set_D_sparse(const std::vector<std::vector<uint32_t>> &D_sparse) {
221+
this->D_sparse = D_sparse;
222+
set_D_sparse_common(this, D_sparse, pimpl.get());
223+
}
224+
186225
void decoder::set_D_sparse(const std::vector<int64_t> &D_sparse_vec_in) {
187226
set_sparse_from_vec(D_sparse_vec_in, this->D_sparse);
188-
pimpl->num_msyn_per_decode = calculate_num_msyn_per_decode(D_sparse);
189-
pimpl->msyn_buffer.clear();
190-
pimpl->msyn_buffer.resize(pimpl->num_msyn_per_decode);
191-
pimpl->msyn_buffer_index = 0;
227+
set_D_sparse_common(this, this->D_sparse, pimpl.get());
192228
}
193229

194230
bool decoder::enqueue_syndrome(const uint8_t *syndrome,
@@ -198,12 +234,23 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
198234
printf("Syndrome buffer overflow. Syndrome will be ignored.\n");
199235
return false;
200236
}
237+
238+
pimpl->current_round++;
201239
bool did_decode = false;
202240
for (std::size_t i = 0; i < syndrome_length; i++) {
203241
pimpl->msyn_buffer[pimpl->msyn_buffer_index] = syndrome[i];
204242
pimpl->msyn_buffer_index++;
205243
}
206-
if (pimpl->msyn_buffer_index == pimpl->msyn_buffer.size()) {
244+
245+
bool should_decode = false;
246+
if (!pimpl->is_sliding_window) {
247+
should_decode = (pimpl->msyn_buffer_index == pimpl->msyn_buffer.size());
248+
} else {
249+
should_decode =
250+
(pimpl->current_round >= 2) ||
251+
(pimpl->current_round == 1 && pimpl->has_first_round_detectors);
252+
}
253+
if (should_decode) {
207254
// These are just for logging. They are initialized in such a way to avoid
208255
// dynamic memory allocation if logging is disabled.
209256
std::vector<uint32_t> log_msyn;
@@ -226,11 +273,34 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
226273
}
227274

228275
// Decode now.
229-
for (std::size_t i = 0; i < this->D_sparse.size(); i++) {
230-
pimpl->persistent_detector_buffer[i] = 0;
231-
for (auto col : this->D_sparse[i])
232-
pimpl->persistent_detector_buffer[i] ^= pimpl->msyn_buffer[col];
276+
if (!pimpl->is_sliding_window) {
277+
for (std::size_t i = 0; i < this->D_sparse.size(); i++) {
278+
pimpl->persistent_detector_buffer[i] = 0;
279+
for (auto col : this->D_sparse[i])
280+
pimpl->persistent_detector_buffer[i] ^= pimpl->msyn_buffer[col];
281+
}
282+
} else {
283+
// For sliding window decoder, syndrome_length must equal
284+
// num_syndromes_per_round
285+
assert(syndrome_length == pimpl->num_syndromes_per_round);
286+
if (pimpl->current_round == 1 && pimpl->has_first_round_detectors) {
287+
// First round: only compute first-round detectors (direct copy)
288+
for (std::size_t i = 0; i < pimpl->num_syndromes_per_round; i++) {
289+
pimpl->persistent_detector_buffer[i] = pimpl->msyn_buffer[i];
290+
}
291+
} else {
292+
// Buffer is full with 2 rounds: compute timelike detectors (XOR of two
293+
// rounds)
294+
std::size_t index =
295+
(pimpl->current_round - 2) * pimpl->num_syndromes_per_round;
296+
for (std::size_t i = 0; i < pimpl->num_syndromes_per_round; i++) {
297+
pimpl->persistent_detector_buffer[i] =
298+
pimpl->msyn_buffer[index + i] ^
299+
pimpl->msyn_buffer[index + i + pimpl->num_syndromes_per_round];
300+
}
301+
}
233302
}
303+
234304
if (should_log) {
235305
log_msyn.reserve(pimpl->msyn_buffer.size());
236306
for (std::size_t d = 0, D = pimpl->msyn_buffer.size(); d < D; d++) {
@@ -249,6 +319,14 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
249319
convert_vec_hard_to_soft(pimpl->persistent_detector_buffer,
250320
pimpl->persistent_soft_detector_buffer);
251321
auto decoded_result = decode(pimpl->persistent_soft_detector_buffer);
322+
323+
// If we didn't get a decoded result, just return
324+
if (pimpl->is_sliding_window) {
325+
if (decoded_result.result.size() == 0) {
326+
return false;
327+
}
328+
}
329+
252330
if (should_log) {
253331
log_t2 = std::chrono::high_resolution_clock::now();
254332
for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++)
@@ -300,6 +378,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
300378
did_decode = true;
301379
// Prepare for more data.
302380
pimpl->msyn_buffer_index = 0;
381+
pimpl->current_round = 0;
303382
}
304383
return did_decode;
305384
}
@@ -348,6 +427,7 @@ std::size_t decoder::get_num_observables() const { return O_sparse.size(); }
348427
void decoder::reset_decoder() {
349428
// Zero out all data that is considered "per-shot" memory.
350429
pimpl->msyn_buffer_index = 0;
430+
pimpl->current_round = 0;
351431
pimpl->msyn_buffer.clear();
352432
pimpl->msyn_buffer.resize(pimpl->num_msyn_per_decode);
353433
pimpl->corrections.clear();

0 commit comments

Comments
 (0)