Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 86 additions & 89 deletions qek/data/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from typing import Any, Callable, Generator, Generic, Sequence, TypeVar, cast
from numpy.typing import NDArray
from pasqal_cloud import SDK
from pasqal_cloud.batch import Batch
from pasqal_cloud.device import BaseConfig, EmuTNConfig, EmulatorType
from pasqal_cloud.job import Job
from pasqal_cloud.utils.filters import BatchFilters
from pasqal_cloud.utils.filters import JobFilters
from pathlib import Path
import numpy as np
import os
Expand Down Expand Up @@ -561,7 +560,7 @@ class PasqalCloudExtracted(BaseExtracted):
def __init__(
self,
compiled: list[Compiled],
batch_ids: list[str],
job_ids: list[str],
sdk: SDK,
state_extractor: Callable[[Job, pl.Sequence], dict[str, int] | None],
path: Path | None = None,
Expand All @@ -571,13 +570,13 @@ def __init__(

Arguments:
compiled: The result of compiling a set of graphs.
batch_ids: The ids of the batches on the cloud API, in the same order as `compiled`.
job_ids: The ids of the jobs on the cloud API, in the same order as `compiled`.
state_extractor: A callback used to extract the counter from a job.
Used as various cloud back-ends return different formats.
path: If provided, a path at which to save the results once they're available.
"""
self._compiled = compiled
self._batch_ids = batch_ids
self._job_ids = job_ids
self._results: SyncExtracted | None = None
self._path = path
self._sdk = sdk
Expand All @@ -592,28 +591,28 @@ def _wait(self) -> None:
if self._results is not None:
# Results are already available.
return
pending_batch_ids: set[str] = set(self._batch_ids)
completed_batches: dict[str, Batch] = {}
while len(pending_batch_ids) > 0:
pending_job_ids: set[str] = set(self._job_ids)
completed_jobs: dict[str, Job] = {}
while len(pending_job_ids) > 0:
time.sleep(SLEEP_DELAY_S)

# Fetch up to 100 pending batches (upstream limits).
MAX_BATCH_LEN = 100
check_ids: list[str | UUID] = [cast(str | UUID, id) for id in pending_batch_ids][
:MAX_BATCH_LEN
# Fetch up to 100 pending jobs (upstream limits).
MAX_JOB_LEN = 100
check_ids: list[str | UUID] = [cast(str | UUID, id) for id in pending_job_ids][
:MAX_JOB_LEN
]

# Update their status.
check_batches = self._sdk.get_batches(filters=BatchFilters(id=check_ids))
for batch in check_batches.results:
assert isinstance(batch, Batch)
if batch.status not in {"PENDING", "RUNNING"}:
logger.debug("Job %s is now complete", batch.id)
pending_batch_ids.discard(batch.id)
completed_batches[batch.id] = batch
check_jobs = self._sdk.get_jobs(filters=JobFilters(id=check_ids))
for job in check_jobs.results:
assert isinstance(job, Job)
if job.status not in {"PENDING", "RUNNING"}:
logger.debug("Job %s is now complete", job.id)
pending_job_ids.discard(job.id)
completed_jobs[job.id] = job

# At this point, all batches are complete.
self._ingest(completed_batches)
# At this point, all jobs are complete.
self._ingest(completed_jobs)

def __await__(self) -> Generator[Any, Any, None]:
"""
Expand All @@ -628,72 +627,69 @@ def __await__(self) -> Generator[Any, Any, None]:
if self._results is not None:
# Results are already available.
return
pending_batch_ids: set[str] = set(self._batch_ids)
completed_batches: dict[str, Batch] = {}
while len(pending_batch_ids) > 0:
pending_job_ids: set[str] = set(self._job_ids)
completed_jobs: dict[str, Job] = {}
while len(pending_job_ids) > 0:
yield from asyncio.sleep(SLEEP_DELAY_S).__await__()

# Fetch up to 100 pending batches (upstream limits).
MAX_BATCH_LEN = 100
check_ids: list[str | UUID] = [cast(str | UUID, id) for id in pending_batch_ids][
:MAX_BATCH_LEN
# Fetch up to 100 pending jobs (upstream limits).
MAX_JOB_LEN = 100
check_ids: list[str | UUID] = [cast(str | UUID, id) for id in pending_job_ids][
:MAX_JOB_LEN
]

# Update their status.
check_batches = self._sdk.get_batches(
filters=BatchFilters(id=check_ids)
check_jobs = self._sdk.get_jobs(
filters=JobFilters(id=check_ids)
) # Ideally, this should be async, see https://github.com/pasqal-io/pasqal-cloud/issues/162.
for batch in check_batches.results:
assert isinstance(batch, Batch)
if batch.status not in {"PENDING", "RUNNING"}:
logger.debug("Job %s is now complete", batch.id)
pending_batch_ids.discard(batch.id)
completed_batches[batch.id] = batch
for job in check_jobs.results:
assert isinstance(job, Job)
if job.status not in {"PENDING", "RUNNING"}:
logger.debug("Job %s is now complete", job.id)
pending_job_ids.discard(job.id)
completed_jobs[job.id] = job

# At this point, all batches are complete.
self._ingest(completed_batches)
# At this point, all jobs are complete.
self._ingest(completed_jobs)

def _ingest(self, batches: dict[str, Batch]) -> None:
def _ingest(self, jobs: dict[str, Job]) -> None:
"""
Ingest data received from the remote server.

No I/O.
"""
assert len(batches) == len(self._batch_ids)
assert len(jobs) == len(self._job_ids)

raw_data = []
targets: list[int] = []
sequences = []
states = []
for i, id in enumerate(self._batch_ids):
batch = batches[id]
for i, id in enumerate(self._job_ids):
job = jobs[id]
compiled = self._compiled[i]
# Note: There's only one job per batch.
assert len(batch.jobs) == 1
for job in batch.jobs.values():
if job.status == "DONE":
state_dict = self._state_extractor(job, compiled.sequence)
if state_dict is None:
logger.warning(
"Batch %s (graph %s) did not return a usable state, skipping",
i,
compiled.graph.id,
)
continue
raw_data.append(compiled.graph)
if compiled.graph.target is not None:
targets.append(compiled.graph.target)
sequences.append(compiled.sequence)
states.append(state_dict)
else:
# If some sequences failed, let's skip them and proceed as well as we can.
if job.status == "DONE":
state_dict = self._state_extractor(job, compiled.sequence)
if state_dict is None:
logger.warning(
"Batch %s (graph %s) failed with errors %s, skipping",
"Job %s (graph %s) did not return a usable state, skipping",
i,
compiled.graph.id,
job.status,
job.errors,
)
continue
raw_data.append(compiled.graph)
if compiled.graph.target is not None:
targets.append(compiled.graph.target)
sequences.append(compiled.sequence)
states.append(state_dict)
else:
# If some sequences failed, let's skip them and proceed as well as we can.
logger.warning(
"Job %s (graph %s) failed with status %s and errors %s, skipping",
i,
compiled.graph.id,
job.status,
job.errors,
)
self._results = SyncExtracted(
raw_data=raw_data, targets=targets, sequences=sequences, states=states
)
Expand Down Expand Up @@ -754,9 +750,9 @@ class BaseRemoteExtractor(BaseExtractor[GraphType], Generic[GraphType]):
device_name: The name of the device to use. As of this writing,
the default value of "FRESNEL" represents the latest QPU
available through the Pasqal Cloud API.
batch_id: Use this to resume a workflow e.g. after turning off
job_id: Use this to resume a workflow e.g. after turning off
your computer while the QPU was executing your sequences.
Warning: A batch started with one executor MUST NOT be resumed
Warning: A job started with one executor MUST NOT be resumed
with a different executor.
"""

Expand All @@ -767,7 +763,7 @@ def __init__(
username: str,
device_name: str,
password: str | None = None,
batch_ids: list[str] | None = None,
job_ids: list[str] | None = None,
path: Path | None = None,
):
sdk = SDK(username=username, project_id=project_id, password=password)
Expand All @@ -778,11 +774,11 @@ def __init__(

super().__init__(device=device, compiler=compiler, path=path)
self._sdk = sdk
self._batch_ids: list[str] | None = batch_ids
self._job_ids: list[str] | None = job_ids

@property
def batch_ids(self) -> list[str] | None:
return self._batch_ids
def job_ids(self) -> list[str] | None:
return self._job_ids

@abc.abstractmethod
def run(
Expand All @@ -803,7 +799,7 @@ def _run(
logger.warning("No sequences to run, did you forget to call compile()?")
return PasqalCloudExtracted(
compiled=[],
batch_ids=[],
job_ids=[],
sdk=self._sdk,
path=self.path,
state_extractor=state_extractor,
Expand All @@ -814,34 +810,36 @@ def _run(
# If we want to add more runs, we'll need to split them across several jobs.
max_runs = device.max_runs if isinstance(device.max_runs, int) else 500

if self._batch_ids is None:
if self._job_ids is None:
# Enqueue jobs.
self._batch_ids = []
self._job_ids = []
for compiled in self.sequences:
logger.debug("Enqueuing execution of compiled graph #%s", compiled.graph.id)
batch = self._sdk.create_batch(
job = self._sdk.create_batch(
compiled.sequence.to_abstract_repr(),
jobs=[{"runs": max_runs}],
wait=False,
emulator=emulator,
configuration=config,
)
assert len(job.ordered_jobs) == 1
job_id = job.ordered_jobs[0].id
logger.info(
"Remote execution of compiled graph #%s starting, batched with id %s",
"Remote execution of compiled graph #%s starting, job with id %s",
compiled.graph.id,
batch.id,
job_id,
)
self._batch_ids.append(batch.id)
self._job_ids.append(job_id)
logger.info(
"All %s jobs enqueued for remote execution, with ids %s",
len(self._batch_ids),
self._batch_ids,
len(self._job_ids),
self._job_ids,
)
assert len(self._batch_ids) == len(self.sequences)
assert len(self._job_ids) == len(self.sequences)

return PasqalCloudExtracted(
compiled=self.sequences,
batch_ids=self._batch_ids,
job_ids=self._job_ids,
sdk=self._sdk,
path=self.path,
state_extractor=state_extractor,
Expand Down Expand Up @@ -876,7 +874,7 @@ class RemoteQPUExtractor(BaseRemoteExtractor[GraphType]):
device_name: The name of the device to use. As of this writing,
the default value of "FRESNEL" represents the latest QPU
available through the Pasqal Cloud API.
batch_id: Use this to resume a workflow e.g. after turning off
job_id: Use this to resume a workflow e.g. after turning off
your computer while the QPU was executing your sequences.
"""

Expand All @@ -887,7 +885,7 @@ def __init__(
username: str,
device_name: str = "FRESNEL",
password: str | None = None,
batch_ids: list[str] | None = None,
job_ids: list[str] | None = None,
path: Path | None = None,
):
super().__init__(
Expand All @@ -896,7 +894,7 @@ def __init__(
username=username,
device_name=device_name,
password=password,
batch_ids=batch_ids,
job_ids=job_ids,
path=path,
)

Expand Down Expand Up @@ -927,7 +925,7 @@ class RemoteEmuMPSExtractor(BaseRemoteExtractor[GraphType]):
device_name: The name of the device to use. As of this writing,
the default value of "FRESNEL" represents the latest QPU
available through the Pasqal Cloud API.
batch_id: Use this to resume a workflow e.g. after turning off
job_id: Use this to resume a workflow e.g. after turning off
your computer while the QPU was executing your sequences.
"""

Expand All @@ -938,7 +936,7 @@ def __init__(
username: str,
device_name: str = "FRESNEL",
password: str | None = None,
batch_ids: list[str] | None = None,
job_ids: list[str] | None = None,
path: Path | None = None,
):
super().__init__(
Expand All @@ -947,17 +945,16 @@ def __init__(
username=username,
device_name=device_name,
password=password,
batch_ids=batch_ids,
job_ids=job_ids,
path=path,
)

def run(self, dt: int = 10) -> PasqalCloudExtracted:
def extractor(job: Job, sequence: pl.Sequence) -> dict[str, int] | None:
cutoff_duration = int(ceil(sequence.get_duration() / dt) * dt)
full_result = job.full_result
if full_result is None:
return None
result = full_result["bitstring"][cutoff_duration]
result = full_result["counter"]
if result is None:
return None
assert isinstance(result, dict)
Expand Down
3 changes: 1 addition & 2 deletions qek/target/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,7 @@ async def run(
bag = cast(dict[str, dict[int, Counter[str]]], job.result)

assert self._sequence is not None
cutoff_duration = int(ceil(self._sequence.get_duration() / dt) * dt)
return bag["bitstring"][cutoff_duration]
return bag["counter"]


if os.name == "posix":
Expand Down
Empty file added tests/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions tests/cloud_fixtures/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Fixtures

This folder contains fixture data that can be used to mock the cloud responses for tests.

## Files

- `device_specs.json`: An extract from the response of the `GET /devices/public-specs` endpoint. It can be used to retrieve the Fresnel device specs.
3 changes: 3 additions & 0 deletions tests/cloud_fixtures/device_specs.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"FRESNEL": "{\"name\": \"Fresnel\", \"dimensions\": 2, \"rydberg_level\": 60, \"min_atom_distance\": 5, \"max_atom_num\": 80, \"max_radial_distance\": 38, \"interaction_coeff_xy\": null, \"supports_slm_mask\": false, \"max_layout_filling\": 0.5, \"optimal_layout_filling\": 0.45, \"min_layout_traps\": 10, \"max_layout_traps\": 200, \"max_sequence_duration\": 6000, \"max_runs\": 500, \"reusable_channels\": false, \"pre_calibrated_layouts\": [], \"version\": \"1\", \"channels\": [{\"id\": \"rydberg_global\", \"basis\": \"ground-rydberg\", \"addressing\": \"Global\", \"max_abs_detuning\": 48.69468613064179, \"max_amp\": 12.566370614359172, \"min_retarget_interval\": null, \"fixed_retarget_t\": null, \"max_targets\": null, \"clock_period\": 4, \"min_duration\": 16, \"max_duration\": 6000, \"min_avg_amp\": 0.5654866776461628, \"mod_bandwidth\": 8, \"eom_config\": {\"limiting_beam\": \"RED\", \"max_limiting_amp\": 138.23007675795088, \"intermediate_detuning\": 2513.2741228718346, \"controlled_beams\": [\"BLUE\"], \"mod_bandwidth\": 40, \"custom_buffer_time\": 240, \"multiple_beam_control\": false, \"red_shift_coeff\": 1.656}}], \"is_virtual\": false}"
}
Loading
Loading