Skip to content

Commit e441e48

Browse files
committed
Reworked approach in place - sliding window test passes
Signed-off-by: Chuck Ketcham <[email protected]>
1 parent 58b68f2 commit e441e48

File tree

3 files changed

+676
-492
lines changed

3 files changed

+676
-492
lines changed

libs/qec/lib/decoder.cpp

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaqx::tensor<uint8_t> &)
2020
INSTANTIATE_REGISTRY(cudaq::qec::decoder, const cudaqx::tensor<uint8_t> &,
2121
const cudaqx::heterogeneous_map &)
2222

23+
// Include decoder implementations AFTER registry instantiation
24+
#include "decoders/sliding_window.h"
25+
2326
namespace cudaq::qec {
2427

2528
struct decoder::rt_impl {
@@ -50,6 +53,18 @@ struct decoder::rt_impl {
5053

5154
/// The id of the decoder (for instrumentation)
5255
uint32_t decoder_id = 0;
56+
57+
bool is_sliding_window = false;
58+
59+
/// The number of syndromes per round. Only used for sliding window decoder.
60+
size_t num_syndromes_per_round = 0;
61+
62+
/// Whether the first round detectors are included. Only used for sliding
63+
/// window decoder.
64+
bool has_first_round_detectors = false;
65+
66+
/// The current round. Only used for sliding window decoder.
67+
uint32_t current_round = 0;
5368
};
5469

5570
void decoder::rt_impl_deleter::operator()(rt_impl *p) const { delete p; }
@@ -174,6 +189,23 @@ uint32_t decoder::get_decoder_id() const { return pimpl->decoder_id; }
174189

175190
void decoder::set_D_sparse(const std::vector<std::vector<uint32_t>> &D_sparse) {
176191
this->D_sparse = D_sparse;
192+
auto *sw_decoder = dynamic_cast<sliding_window *>(this);
193+
194+
if (sw_decoder != nullptr) {
195+
pimpl->is_sliding_window = true;
196+
pimpl->num_syndromes_per_round = sw_decoder->get_num_syndromes_per_round();
197+
// Check if first row is a first-round detector (single syndrome index)
198+
pimpl->has_first_round_detectors =
199+
(D_sparse.size() > 0 && D_sparse[0].size() == 1);
200+
pimpl->current_round = 0;
201+
pimpl->persistent_detector_buffer.resize(pimpl->num_syndromes_per_round);
202+
pimpl->persistent_soft_detector_buffer.resize(
203+
pimpl->num_syndromes_per_round);
204+
205+
} else {
206+
pimpl->is_sliding_window = false;
207+
}
208+
177209
pimpl->num_msyn_per_decode = calculate_num_msyn_per_decode(D_sparse);
178210
pimpl->msyn_buffer.clear();
179211
pimpl->msyn_buffer.resize(pimpl->num_msyn_per_decode);
@@ -182,7 +214,23 @@ void decoder::set_D_sparse(const std::vector<std::vector<uint32_t>> &D_sparse) {
182214

183215
void decoder::set_D_sparse(const std::vector<int64_t> &D_sparse_vec_in) {
184216
set_sparse_from_vec(D_sparse_vec_in, this->D_sparse);
185-
pimpl->num_msyn_per_decode = calculate_num_msyn_per_decode(D_sparse);
217+
auto *sw_decoder = dynamic_cast<sliding_window *>(this);
218+
219+
if (sw_decoder != nullptr) {
220+
pimpl->is_sliding_window = true;
221+
pimpl->num_syndromes_per_round = sw_decoder->get_num_syndromes_per_round();
222+
// Check if first row is a first-round detector (single syndrome index)
223+
pimpl->has_first_round_detectors =
224+
(this->D_sparse.size() > 0 && this->D_sparse[0].size() == 1);
225+
pimpl->current_round = 0;
226+
pimpl->persistent_detector_buffer.resize(pimpl->num_syndromes_per_round);
227+
pimpl->persistent_soft_detector_buffer.resize(
228+
pimpl->num_syndromes_per_round);
229+
} else {
230+
pimpl->is_sliding_window = false;
231+
}
232+
233+
pimpl->num_msyn_per_decode = calculate_num_msyn_per_decode(this->D_sparse);
186234
pimpl->msyn_buffer.clear();
187235
pimpl->msyn_buffer.resize(pimpl->num_msyn_per_decode);
188236
pimpl->msyn_buffer_index = 0;
@@ -195,12 +243,23 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
195243
printf("Syndrome buffer overflow. Syndrome will be ignored.\n");
196244
return false;
197245
}
246+
247+
pimpl->current_round++;
198248
bool did_decode = false;
199249
for (std::size_t i = 0; i < syndrome_length; i++) {
200250
pimpl->msyn_buffer[pimpl->msyn_buffer_index] = syndrome[i];
201251
pimpl->msyn_buffer_index++;
202252
}
203-
if (pimpl->msyn_buffer_index == pimpl->msyn_buffer.size()) {
253+
254+
bool should_decode = false;
255+
if (!pimpl->is_sliding_window) {
256+
should_decode = (pimpl->msyn_buffer_index == pimpl->msyn_buffer.size());
257+
} else {
258+
should_decode =
259+
(pimpl->current_round >= 2) ||
260+
(pimpl->current_round == 1 && pimpl->has_first_round_detectors);
261+
}
262+
if (should_decode) {
204263
// These are just for logging. They are initialized in such a way to avoid
205264
// dynamic memory allocation if logging is disabled.
206265
std::vector<uint32_t> log_msyn;
@@ -223,11 +282,34 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
223282
}
224283

225284
// Decode now.
226-
for (std::size_t i = 0; i < this->D_sparse.size(); i++) {
227-
pimpl->persistent_detector_buffer[i] = 0;
228-
for (auto col : this->D_sparse[i])
229-
pimpl->persistent_detector_buffer[i] ^= pimpl->msyn_buffer[col];
285+
if (!pimpl->is_sliding_window) {
286+
for (std::size_t i = 0; i < this->D_sparse.size(); i++) {
287+
pimpl->persistent_detector_buffer[i] = 0;
288+
for (auto col : this->D_sparse[i])
289+
pimpl->persistent_detector_buffer[i] ^= pimpl->msyn_buffer[col];
290+
}
291+
} else {
292+
// For sliding window decoder, syndrome_length must equal
293+
// num_syndromes_per_round
294+
assert(syndrome_length == pimpl->num_syndromes_per_round);
295+
if (pimpl->current_round == 1 && pimpl->has_first_round_detectors) {
296+
// First round: only compute first-round detectors (direct copy)
297+
for (std::size_t i = 0; i < pimpl->num_syndromes_per_round; i++) {
298+
pimpl->persistent_detector_buffer[i] = pimpl->msyn_buffer[i];
299+
}
300+
} else {
301+
// Buffer is full with 2 rounds: compute timelike detectors (XOR of two
302+
// rounds)
303+
for (std::size_t i = 0; i < pimpl->num_syndromes_per_round; i++) {
304+
std::size_t index =
305+
(pimpl->current_round - 2) * pimpl->num_syndromes_per_round;
306+
pimpl->persistent_detector_buffer[i] =
307+
pimpl->msyn_buffer[index + i] ^
308+
pimpl->msyn_buffer[index + i + pimpl->num_syndromes_per_round];
309+
}
310+
}
230311
}
312+
231313
if (should_log) {
232314
log_msyn.reserve(pimpl->msyn_buffer.size());
233315
for (std::size_t d = 0, D = pimpl->msyn_buffer.size(); d < D; d++) {
@@ -246,6 +328,14 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
246328
convert_vec_hard_to_soft(pimpl->persistent_detector_buffer,
247329
pimpl->persistent_soft_detector_buffer);
248330
auto decoded_result = decode(pimpl->persistent_soft_detector_buffer);
331+
332+
// If we didn't get a decoded result, just return
333+
if (pimpl->is_sliding_window) {
334+
if (decoded_result.result.size() == 0) {
335+
return false;
336+
}
337+
}
338+
249339
if (should_log) {
250340
log_t2 = std::chrono::high_resolution_clock::now();
251341
for (std::size_t e = 0, E = decoded_result.result.size(); e < E; e++)
@@ -297,6 +387,7 @@ bool decoder::enqueue_syndrome(const uint8_t *syndrome,
297387
did_decode = true;
298388
// Prepare for more data.
299389
pimpl->msyn_buffer_index = 0;
390+
pimpl->current_round = 0;
300391
}
301392
return did_decode;
302393
}
@@ -345,6 +436,7 @@ std::size_t decoder::get_num_observables() const { return O_sparse.size(); }
345436
void decoder::reset_decoder() {
346437
// Zero out all data that is considered "per-shot" memory.
347438
pimpl->msyn_buffer_index = 0;
439+
pimpl->current_round = 0;
348440
pimpl->msyn_buffer.clear();
349441
pimpl->msyn_buffer.resize(pimpl->num_msyn_per_decode);
350442
pimpl->corrections.clear();

0 commit comments

Comments
 (0)