diff --git a/doc/devices/braket_remote.rst b/doc/devices/braket_remote.rst index 14e50e80..4db2845a 100644 --- a/doc/devices/braket_remote.rst +++ b/doc/devices/braket_remote.rst @@ -66,6 +66,10 @@ You can set a timeout by using the ``poll_timeout_seconds`` argument; the device will retry circuits that do not complete within the timeout. A timeout of 30 to 60 seconds is recommended for circuits with fewer than 25 qubits. +Each of the submitted circuit can be visualised using the attribute ``circuits`` on the device + +>> print(remote_device.circuits[0]) + Device options ~~~~~~~~~~~~~~ diff --git a/src/braket/pennylane_plugin/braket_device.py b/src/braket/pennylane_plugin/braket_device.py index c07f8659..4d6ef09b 100644 --- a/src/braket/pennylane_plugin/braket_device.py +++ b/src/braket/pennylane_plugin/braket_device.py @@ -160,7 +160,9 @@ def __init__( self._parallel = parallel self._max_parallel = max_parallel self._circuit = None + self._circuits = [] self._task = None + self._tasks = [] self._noise_model = noise_model self._parametrize_differentiable = parametrize_differentiable self._run_kwargs = run_kwargs @@ -179,7 +181,9 @@ def __init__( def reset(self): super().reset() self._circuit = None + self._circuits = [] self._task = None + self._tasks = [] @property def operations(self) -> frozenset[str]: @@ -195,11 +199,21 @@ def circuit(self) -> Circuit: """Circuit: The last circuit run on this device.""" return self._circuit + @property + def circuits(self) -> list[Circuit]: + """Circuit: The circuits run on this device.""" + return self._circuits + @property def task(self) -> QuantumTask: """QuantumTask: The task corresponding to the last run circuit.""" return self._task + @property + def tasks(self) -> list[QuantumTask]: + """The tasks corresponding to the circuits run on this device.""" + return self._tasks + @property def parallel(self) -> bool: """bool: Whether the device supports parallel execution of batches.""" @@ -686,6 +700,8 @@ def __init__( self._poll_interval_seconds = poll_interval_seconds self._max_connections = max_connections self._max_retries = max_retries + self._circuits = [] + self._tasks = [] @property def use_grouping(self) -> bool: @@ -698,6 +714,8 @@ def use_grouping(self) -> bool: return not ("provides_jacobian" in caps and caps["provides_jacobian"]) def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs): + self._circuits = braket_circuits + batch_shots = 0 if self.analytic else self.shots if self._supports_program_sets: program_set = ( ProgramSet.zip(braket_circuits, input_sets=inputs) @@ -712,6 +730,7 @@ def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs poll_interval_seconds=self._poll_interval_seconds, **self._run_kwargs, ) + self._tasks = [task] return self._braket_program_set_to_pl_result(task.result(), pl_circuits) task_batch = self._device.run_batch( braket_circuits, @@ -724,7 +743,7 @@ def _run_task_batch(self, braket_circuits, pl_circuits, batch_shots: int, inputs inputs=inputs, **self._run_kwargs, ) - + self._tasks = task_batch.tasks # Call results() to retrieve the Braket results in parallel. try: braket_results_batch = task_batch.results( diff --git a/test/unit_tests/test_braket_device.py b/test/unit_tests/test_braket_device.py index f40c17c0..91046b0d 100644 --- a/test/unit_tests/test_braket_device.py +++ b/test/unit_tests/test_braket_device.py @@ -203,11 +203,15 @@ def test_reset(): """Tests that the members of the device are cleared on reset.""" dev = _aws_device(wires=2) dev._circuit = CIRCUIT + dev._circuits = [CIRCUIT, CIRCUIT] dev._task = TASK + dev._tasks = [TASK, TASK] dev.reset() assert dev.circuit is None + assert dev.circuits == [] assert dev.task is None + assert dev.tasks == [] def test_apply(): @@ -1115,6 +1119,24 @@ def test_batch_execute_program_set_noncommuting(): @patch.object(AwsDevice, "properties", new_callable=mock.PropertyMock) @patch.object(AwsDevice, "run_batch") +def test_aws_device_batch_execute_parallel_circuits_persistance(mock_run_batch): + mock_run_batch.return_value = TASK_BATCH + dev = _aws_device(wires=4, foo="bar", parallel=True) + assert dev.parallel is True + + with QuantumTape() as circuit: + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) + qml.probs(wires=[0]) + qml.expval(qml.PauliX(1)) + qml.var(qml.PauliY(2)) + qml.sample(qml.PauliZ(3)) + + circuits = [circuit, circuit] + dev.batch_execute(circuits) + assert dev.circuits[1] + + def test_aws_device_batch_execute_parallel(mock_run_batch, mock_properties): """Test batch_execute(parallel=True) correctly calls batch execution methods for AwsDevices in Braket SDK""" @@ -1135,6 +1157,8 @@ def test_aws_device_batch_execute_parallel(mock_run_batch, mock_properties): circuits = [circuit, circuit] batch_results = dev.batch_execute(circuits) + + assert dev.tasks[0] for results in batch_results: assert np.allclose( results[0],