Skip to content

PML-327: Backprop guard implemented for memristor#219

Open
LF-Vigneux wants to merge 28 commits into
merlinquantum:release/0.4from
LF-Vigneux:PML-327-backpropagation-guard-memristor
Open

PML-327: Backprop guard implemented for memristor#219
LF-Vigneux wants to merge 28 commits into
merlinquantum:release/0.4from
LF-Vigneux:PML-327-backpropagation-guard-memristor

Conversation

@LF-Vigneux
Copy link
Copy Markdown
Contributor

Summary

Detaching history memristor states after num_backprop_steps forward passes to limit the memristive history depth accessed by the backwards.

Related Issue

PML-327

Type of change

  • Bug fix
  • New feature
  • Documentation update
  • Refactor / Cleanup
  • Performance improvement
  • CI / Build / Tooling
  • Breaking change (requires migration notes)

Proposed changes

Adding a num_backprop_stepsparameter to the add_memristve_phase_shifter method. This argument detaches any tensor in the memristive history that is associated with a forward pass at least num_backprop_steps+1 iterations away.

How to test / How to run

pytest -q

Documentation

  • User docs updated (Sphinx)
  • Examples / notebooks updated
  • Docstrings updated
  • Updated the API

@ben9871 ben9871 self-assigned this May 21, 2026
Copy link
Copy Markdown
Contributor

@ben9871 ben9871 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A stray debug print remains, some issues with detaching tensors and observability, can enhance tests too

Comment thread merlin/algorithms/layer.py Outdated
"""
)

# Detach history that is too long ago
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This detaches an old tensor stored in memristive_history, but it does not detach the recurrent state chain used by future forwards. By the time this line runs, newer self.memristive_state[i] tensors have already been computed from the original graph. Replacing an old list entry with old.detach() does not rewrite the graph inside later states. With num_backprop_steps=0, a loss on the fourth forward still produces gradients for the first three inputs. The fix should detach the active state that will be reused by the next timestep, not only the historical copy kept for inspection. Concretely, the recurrence needs a truncated-BPTT(Backprop through time) boundary such that self.memristive_state[i] is detached when it falls outside the requested sliding window.

Comment thread tests/algorithms/test_layer.py Outdated
num_backprop_steps_a = metadata_a["num_backprop_steps"]
num_backprop_steps_b = metadata_b["num_backprop_steps"]

# Check memPS_a history: only the last (num_backprop_steps + 1) should be attached
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This checks requires_grad on tensors stored in the history list, but that is not a sufficient test of the sliding-window contract. A detached old history entry can coexist with newer state tensors that still retain graph links to older forwards. Add a behavioural test that backpropagates a loss from the last timestep and verifies which previous timestep inputs/outputs receive gradients.

Comment thread merlin/algorithms/layer.py Outdated
self.memristive_history[i][position_to_detach] = (
self.memristive_history[i][position_to_detach].detach()
)
print(self.memristive_history[i])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

: ) This debug print() runs inside QuantumLayer.forward() for memristive layers. It will pollute training logs and notebooks once a sequence has multiple timesteps.

Comment thread merlin/builder/circuit_builder.py Outdated
update_rule: Callable,
initial_state: float,
name: str | None = None,
num_backprop_steps: int = 0,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_backprop_steps is a window length, so it should be validated as a non-negative integer at the builder boundary. Right now num_backprop_steps=-1 is accepted and stored in metadata, then the forward-time indexing logic can compute an invalid position. Invalid API input should fail early with a clear ValueError.

Comment thread merlin/algorithms/layer.py Outdated
dtype=self.dtype,
)
self.memristive_history[i] = [self.memristive_state[i]]
if self._memristive_metadata[i]["num_backprop_steps"] == 0:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This .detach() call has no effect as written because Tensor.detach() is not in-place. It returns a new detached tensor, and that return value is immediately discarded here. The reset tensor is newly created with torch.full(...) and normally has no graph anyway, so this branch can probably be removed. If the intent is to detach explicitly, assign the return value back: self.memristive_history[i][0] = self.memristive_history[i][0].detach().

@CassNot CassNot added this to the v0.4 milestone May 28, 2026
@CassNot CassNot added the enhancement New feature or request label May 29, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants