
Item

Notes

  • Matt will post the “optimize-with-jax” notebook here

  • YTZ – Looks pretty good.

    • YTZ – So you’re giving me two things: methods to get the parameters, and a function to call vjp on

    • MT – Yes

    • YTZ – Good

  • YTZ – I probably won’t vjp bonds.parameterize directly; instead I’ll chain it into a larger function and vjp that

  • YTZ – One change I’ve made: instead of returning only the parameters, bonds.parameterize also returns the bond indices, which are passed to vjp with the has_aux kwarg so it doesn’t attempt to differentiate with respect to them

    • YTZ – I “hide” those when needed using functools.partial

    • Having these indices is useful when I’m doing alchemical transformations, so I can map from the particles/parameters in one molecule to another.

      • JW – More details? What if there are different numbers of atoms?

        • YTZ – It’s fine if they have different numbers of atoms. As the interpolation scales from one molecule to the other, the particles that don’t exist are assigned vdW terms of 0 (with some complicated hacks to handle soft-core potentials)
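A minimal sketch of the has_aux / functools.partial pattern described above, assuming a toy bond handler. The name parameterize_bonds, the (atom_i, atom_j, param_idx) match format, and the (k, r0) parameter table are hypothetical stand-ins, not the real bonds.parameterize API.

```python
import functools

import jax
import jax.numpy as jnp


def parameterize_bonds(params, smirks_matches):
    """Hypothetical bond handler: returns (system params, bond indices).

    `smirks_matches` is a list of (atom_i, atom_j, param_idx) tuples,
    standing in for whatever SMIRKS matching actually produces.
    """
    bond_idxs = jnp.array([[i, j] for i, j, _ in smirks_matches], dtype=jnp.int32)
    param_idxs = jnp.array([k for _, _, k in smirks_matches], dtype=jnp.int32)
    # Gather one (k, r0) row per bond from the force-field parameter table.
    system_params = params[param_idxs]
    return system_params, bond_idxs


# Toy force-field table: each row is (k, r0) for one SMIRKS pattern.
ff_params = jnp.array([[500.0, 0.10], [300.0, 0.15]])
matches = [(0, 1, 0), (1, 2, 1), (2, 3, 0)]

# functools.partial hides the non-differentiable matches argument,
# leaving a function of the FF parameters alone.
fn = functools.partial(parameterize_bonds, smirks_matches=matches)

# has_aux=True tells jax.vjp that the second output is auxiliary data:
# it is returned as-is and never differentiated.
system_params, vjp_fn, bond_idxs = jax.vjp(fn, ff_params, has_aux=True)

# Pull a cotangent on the system params back to the FF params.
(d_ff,) = vjp_fn(jnp.ones_like(system_params))
```

Because pattern 0 matches two bonds, its row in d_ff accumulates two contributions, while the returned bond_idxs remain available for mapping particles between molecules.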

  • YTZ – Can it handle things like improper folds?

    • MT – Yup, the mapping matrix can handle floats and sums. This generalizes well to even hairier cases like torsion interpolation
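To illustrate what “floats and sums” in a mapping matrix could mean, here is a hypothetical sketch: the matrix, the per-particle values, and the linear lambda interpolation are all assumptions for illustration, not the actual alchemical scheme discussed above.

```python
import jax.numpy as jnp

# Hypothetical per-particle parameters for molecule A (3 particles).
q_a = jnp.array([1.0, 2.0, 4.0])

# Hypothetical mapping matrix from A's 3 particles onto B's 2 particles.
# Row 0: B particle 0 takes A particle 0 directly.
# Row 1: B particle 1 is a float-weighted sum of A particles 1 and 2.
mapping = jnp.array([
    [1.0, 0.0, 0.0],
    [0.0, 0.5, 0.5],
])
q_a_on_b = mapping @ q_a

# Hypothetical end-state parameters for molecule B.
q_b = jnp.array([1.5, 2.0])


def interpolate(lam):
    # Linear interpolation between the mapped A parameters and B's.
    return (1.0 - lam) * q_a_on_b + lam * q_b
```

Since the mapping is a linear map, it stays differentiable, so it can sit inside the larger chained function that gets vjp'd.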

  • YTZ – Loss should return a scalar. Input should be FF parameters, output should be scalar.

    • YTZ – The toy loss function is returning a per-parameter loss. Instead, have it reduce to a scalar with jnp.linalg.norm
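A minimal sketch of the suggested fix, assuming a toy loss that compares FF parameters against a hypothetical target vector: reducing the residual with jnp.linalg.norm yields the scalar that jax.grad requires.

```python
import jax
import jax.numpy as jnp


def toy_loss(ff_params, target):
    # The residual is per-parameter (a vector); the norm reduces it
    # to the single scalar that jax.grad needs.
    return jnp.linalg.norm(ff_params - target)


ff_params = jnp.array([3.0, 4.0])
target = jnp.zeros(2)  # hypothetical target values

loss = toy_loss(ff_params, target)            # 5.0
grad = jax.grad(toy_loss)(ff_params, target)  # [0.6, 0.8]
```

The norm is not differentiable at zero residual, so a sum of squares (jnp.sum(residual ** 2)) is a common alternative when the loss can reach zero.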

  • YTZ – If I could make stuff work with jax.grad, I wouldn’t need the system object to make things that are vjp-able

    • Ex.

      • Free energy = A(q(p)) – A is the free energy, which is a function of the system params, which are in turn a function of the FF params

      • Loss = A(q(p)) - A0 – A0 is the experimentally measured FE

      • So dL/dp = dL/dA . dA/dq . dq/dp

      • Or, in some cases, we have small-molecule params ff_p and protein params pro_q

      • So, combined_q = np.concatenate([ff_q(ff_p), pro_q])

      • Then A(combined_q) = A(combined_q(ff_q(ff_p), pro_q))

        • In the future where we modify protein params, the last term would be pro_q(pro_p)

      • So the thing I need is a vjp function that lets me do vjp_fn(dL/dq) and get dL/dp

      • MT – We should be in good shape to provide this
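A minimal sketch of the chain dL/dp = dL/dA · dA/dq · dq/dp using jax.vjp, assuming toy stand-ins: ff_q, free_energy, A0, and the numeric values are hypothetical, and in practice dL/dq would come from the free-energy estimator rather than autodiff of a closed-form A.

```python
import jax
import jax.numpy as jnp


def ff_q(ff_p):
    # Hypothetical small-molecule handler: FF params -> system params.
    return jnp.concatenate([ff_p, 2.0 * ff_p])


def free_energy(combined_q):
    # Hypothetical closed-form stand-in for A(q).
    return jnp.sum(combined_q ** 2)


A0 = 1.0                        # placeholder for the measured free energy
pro_q = jnp.array([0.5, -0.5])  # protein system params, held fixed
ff_p = jnp.array([1.0, 2.0])    # small-molecule FF params


def loss_of_q(q):
    # L = A(q) - A0, as in the notes.
    return free_energy(q) - A0


# Forward pass through the FF step, keeping its vjp function.
ff_sys, vjp_fn = jax.vjp(ff_q, ff_p)
combined_q = jnp.concatenate([ff_sys, pro_q])

# dL/dq over all system params (small molecule + protein).
dL_dq = jax.grad(loss_of_q)(combined_q)

# Only the small-molecule slice is pulled back: vjp_fn(dL/dq) -> dL/dp.
(dL_dp,) = vjp_fn(dL_dq[: ff_sys.shape[0]])
```

If protein params are modified later, pro_q would itself come from a vjp on pro_p and receive the remaining slice dL_dq[ff_sys.shape[0]:].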

  • Dry run in January?

    • Will potentialhandlers be ready?

      • MT – Nonbonded will be the hardest

    • Do we meet the timemachine input spec?

    • YTZ – Unit-wise, I want unitless floats

      • MT – We do this (once data is exported to array representations)

Action items

  •  

Decisions