PML-327: Backprop guard implemented for memristor#219
Conversation
ben9871
left a comment
There was a problem hiding this comment.
A stray debug print remains, some issues with detaching tensors and observability, can enhance tests too
| """ | ||
| ) | ||
|
|
||
| # Detach history that is too long ago |
There was a problem hiding this comment.
This detaches an old tensor stored in memristive_history, but it does not detach the recurrent state chain used by future forwards. By the time this line runs, newer self.memristive_state[i] tensors have already been computed from the original graph. Replacing an old list entry with old.detach() does not rewrite the graph inside later states. With num_backprop_steps=0, a loss on the fourth forward still produces gradients for the first three inputs. The fix should detach the active state that will be reused by the next timestep, not only the historical copy kept for inspection. Concretely, the recurrence needs a truncated-BPTT(Backprop through time) boundary such that self.memristive_state[i] is detached when it falls outside the requested sliding window.
| num_backprop_steps_a = metadata_a["num_backprop_steps"] | ||
| num_backprop_steps_b = metadata_b["num_backprop_steps"] | ||
|
|
||
| # Check memPS_a history: only the last (num_backprop_steps + 1) should be attached |
There was a problem hiding this comment.
This checks requires_grad on tensors stored in the history list, but that is not a sufficient test of the sliding-window contract. A detached old history entry can coexist with newer state tensors that still retain graph links to older forwards. Add a behavioural test that backpropagates a loss from the last timestep and verifies which previous timestep inputs/outputs receive gradients.
| self.memristive_history[i][position_to_detach] = ( | ||
| self.memristive_history[i][position_to_detach].detach() | ||
| ) | ||
| print(self.memristive_history[i]) |
There was a problem hiding this comment.
: ) This debug print() runs inside QuantumLayer.forward() for memristive layers. It will pollute training logs and notebooks once a sequence has multiple timesteps.
| update_rule: Callable, | ||
| initial_state: float, | ||
| name: str | None = None, | ||
| num_backprop_steps: int = 0, |
There was a problem hiding this comment.
num_backprop_steps is a window length, so it should be validated as a non-negative integer at the builder boundary. Right now num_backprop_steps=-1 is accepted and stored in metadata, then the forward-time indexing logic can compute an invalid position. Invalid API input should fail early with a clear ValueError.
| dtype=self.dtype, | ||
| ) | ||
| self.memristive_history[i] = [self.memristive_state[i]] | ||
| if self._memristive_metadata[i]["num_backprop_steps"] == 0: |
There was a problem hiding this comment.
This .detach() call has no effect as written because Tensor.detach() is not in-place. It returns a new detached tensor, and that return value is immediately discarded here. The reset tensor is newly created with torch.full(...) and normally has no graph anyway, so this branch can probably be removed. If the intent is to detach explicitly, assign the return value back: self.memristive_history[i][0] = self.memristive_history[i][0].detach().
Summary
Detaching history memristor states after num_backprop_steps forward passes to limit the memristive history depth accessed by the backwards.
Related Issue
PML-327
Type of change
Proposed changes
Adding a
num_backprop_stepsparameter to the add_memristve_phase_shifter method. This argument detaches any tensor in the memristive history that is associated with a forward pass at leastnum_backprop_steps+1iterations away.How to test / How to run
Documentation