|
| 1 | + |
| 2 | +## Optimizing Molecular Dynamics Weights with Machine Learning Tools |
| 3 | + |
| 4 | +**By James Holton, with contributions from Karson Chrispens, Steve, and Marcus Collins** |
| 5 | + |
| 6 | +In our latest round of diffuse scattering experiments, we ran into an intriguing optimization problem that feels a lot like training a neural network. |
| 7 | + |
| 8 | +### The Scientific Setup |
| 9 | + |
| 10 | +For each 3D pixel in reciprocal space (indexed by **h**), we have: |
| 11 | + |
| 12 | +* **Observed data**, ( y(h) ), from experiment |
| 13 | +* **Predicted data**, ( x(h) ), computed from molecular dynamics (MD) trajectories |
| 14 | + |
| 15 | +We evaluate agreement using the **Pearson correlation coefficient**: |
| 16 | + |
| 17 | +CC=⟨xy⟩−⟨x⟩⟨y⟩⟨x2⟩−⟨x⟩2⟨y2⟩−⟨y⟩2 |
| 18 | +CC= |
| 19 | +⟨x |
| 20 | +2 |
| 21 | +⟩−⟨x⟩ |
| 22 | +2 |
| 23 | + |
| 24 | +Each prediction ( x(h) ) is derived from **structure factors** ( F(h, t) ) across time points in the MD simulation: |
| 25 | + |
| 26 | +[ |
| 27 | +x(h) = \langle F(h)^2 \rangle_t - \langle F(h) \rangle_t^2 |
| 28 | +] |
| 29 | + |
| 30 | +The goal is to assign **weights** ( w(t) ) to each time point to maximize ( CC ): |
| 31 | + |
| 32 | +[ |
| 33 | +x'(h) = \sum_t w_t F(h,t)^2 - \left( \sum_t w_t F(h,t) \right)^2 |
| 34 | +] |
| 35 | + |
| 36 | +If we can find optimal weights, we can identify which regions of the trajectory best match experimental reality — potentially distinguishing “good” frames from those that detract from agreement. |
| 37 | + |
| 38 | +### Community Brainstorming |
| 39 | + |
| 40 | +**Steve** suggested asking whether CC is the right target — perhaps a likelihood might better capture the physics. |
| 41 | + |
| 42 | +**Karson Chrispens** proposed leveraging machine learning frameworks like **JAX** or **PyTorch** to treat the weights as trainable parameters. By backpropagating through the Pearson correlation, an optimizer like Adam could efficiently learn the optimal weights. |
| 43 | + |
| 44 | +**James Holton** suspected this approach could outperform traditional non-linear least-squares optimization and shared example MTZ datasets for testing. |
| 45 | + |
| 46 | +**Steve** also mentioned using a **genetic algorithm** if the weights were binary (0 or 1), though acknowledged the continuous formulation might not have a unique minimum. |
| 47 | + |
| 48 | +### Prototyping the Optimizer |
| 49 | + |
| 50 | +Karson quickly implemented a JAX-based prototype using **reciprocalspaceship** for MTZ I/O and **optax** for optimization. |
| 51 | +The loss function was simply **–CC**, and weights were constrained to (0, 1) via a sigmoid. |
| 52 | + |
| 53 | +When tested on toy datasets and real MTZ files, the optimizer: |
| 54 | + |
| 55 | +* Successfully recovered **50:50** weights for mixtures of two “ground-truth” structures. |
| 56 | +* Produced sensible intermediate values when one or both inputs were “wrong.” |
| 57 | +* Converged robustly from different initializations. |
| 58 | + |
| 59 | +Example output for a ground-truth mixture: |
| 60 | + |
| 61 | +``` |
| 62 | +Final weights: [0.46, 0.54] |
| 63 | +Final CC: 1.0000 |
| 64 | +``` |
| 65 | + |
| 66 | +And for mismatched data: |
| 67 | + |
| 68 | +``` |
| 69 | +Final weights: [0.76, 0.24] |
| 70 | +Final CC: 0.78 |
| 71 | +``` |
| 72 | + |
| 73 | +### Discussion |
| 74 | + |
| 75 | +**Marcus Collins** noted that this approach resembles computing **Boltzmann-like factors** for each configuration and suggested PyTorch could be an equally good (and more common) platform. He also cautioned that Pearson CC may not be the optimal objective function. |
| 76 | + |
| 77 | +Karson confirmed that JAX runs efficiently on GPUs and planned to scale the approach to larger datasets by stacking multiple MTZ files. |
| 78 | + |
| 79 | +### Where This Might Go Next |
| 80 | + |
| 81 | +This prototype demonstrates that **gradient-based optimization** can efficiently identify the contribution of different MD frames to observed diffuse scattering patterns. Future directions include: |
| 82 | + |
| 83 | +* Expanding to full MD trajectories with thousands of frames. |
| 84 | +* Experimenting with alternate objectives (e.g., likelihood, cross-entropy). |
| 85 | +* Incorporating **crystal symmetry** and **resolution weighting**. |
| 86 | +* Exploring physical interpretations of the resulting weights. |
| 87 | + |
| 88 | +### Code and Data |
| 89 | + |
| 90 | +Karson’s implementation, `pearson_target.py`, is available [here](https://github.com/k-chrispens/simulation_timeseries_optim), and the test MTZ data can be downloaded from |
| 91 | +[here](http://bl831.als.lbl.gov/~jamesh/pickup/diffUSE_CC_opt_test.tgz) |
| 92 | + |
| 93 | +--- |
| 94 | + |
| 95 | +**TL;DR:** |
| 96 | +By treating MD frame weights as trainable parameters in a differentiable Pearson correlation objective, we can use ML optimizers like Adam to rapidly identify which parts of a trajectory best explain experimental diffuse scattering — turning a brute-force search into a smooth, data-driven optimization problem. |
| 97 | + |
0 commit comments