YTZ – So you’re giving me two things: ???, and a function to call vjp on
MT – Yes
YTZ – Good
YTZ – I probably won’t vjp bonds.parameterize directly; instead I’ll chain it into a larger function and vjp that
YTZ – One change I’ve made is that bonds.parameterize now also returns the bond indices, not just the parameters. The indices are passed through vjp via the has_aux kwarg, so it doesn’t attempt to differentiate with respect to them
YTZ – I “hide” those when needed using functools.partial
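A minimal sketch of the has_aux pattern, assuming a hypothetical stand-in `parameterize_bonds` (not the real bonds.parameterize) that gathers per-bond (k, r0) parameters from a table of FF bond types and also returns the bond indices. jax.vjp with has_aux=True hands the indices back as auxiliary output, and functools.partial binds the integer arrays so only the FF parameters are differentiated.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for bonds.parameterize: maps FF-level bond parameters
# (one (k, r0) row per bond type) to per-bond system parameters, and also
# returns the bond indices as auxiliary, non-differentiated output.
def parameterize_bonds(ff_params, bond_idxs, bond_types):
    system_params = ff_params[bond_types]  # gather, shape (n_bonds, 2)
    return system_params, bond_idxs

bond_idxs = np.array([[0, 1], [1, 2], [2, 3]])  # atom pairs for each bond
bond_types = np.array([0, 1, 0])                # FF bond type used by each bond
ff_params = jnp.array([[300.0, 1.1], [250.0, 1.5]])

# Bind the integer arrays so the resulting function takes only ff_params.
fn = functools.partial(parameterize_bonds, bond_idxs=bond_idxs, bond_types=bond_types)

# has_aux=True: fn returns (output, aux); aux comes back untouched as a third value.
system_params, vjp_fn, idxs = jax.vjp(fn, ff_params, has_aux=True)

# Pull a cotangent on the system params back to a gradient on the FF params.
(d_ff_params,) = vjp_fn(jnp.ones_like(system_params))
```

Each FF bond type’s gradient accumulates over every bond that uses it, so type 0 (used by two bonds) receives twice the cotangent of type 1.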
Having these indices is useful when I’m doing alchemical transformations, since they let me map the particles/parameters of one molecule onto another.
JW – More details? What if there are different numbers of atoms?
YTZ – It’s fine if they have different numbers of atoms. As the interpolation scales from one molecule to the other, particles that don’t exist in a given end state are assigned vdW terms of 0 (with some complicated hacks to handle soft-core potentials)
YTZ – Does it handle things like improper folds?
MT – Yup, the mapping matrix can handle fractional (float) weights and sums. This generalizes well to even hairier things like torsion interpolation
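One way to picture the mapping matrix, as a sketch under assumptions: the two end-state molecules, the matrices `M_A` and `M_B`, and the linear interpolation in `lam` are all hypothetical, and the soft-core handling mentioned above is not modeled. Each row is a particle in the combined alchemical system; a row of zeros means the particle is absent in that end state, so its vdW epsilon is 0 there. Because the map is a matrix product, entries could also be fractional and a row could sum several contributions.

```python
import jax
import jax.numpy as jnp

# Hypothetical end states, 3 atoms each. A's atoms 0, 1 correspond to B's
# atoms 0, 1; A's atom 2 and B's atom 2 have no partner, so the combined
# alchemical system has 4 particles.
# Rows = combined particles, columns = atoms of the end-state molecule.
M_A = jnp.array([[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 0.0, 1.0],
                 [0.0, 0.0, 0.0]])  # particle 3 does not exist in A
M_B = jnp.array([[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 0.0, 0.0],   # particle 2 does not exist in B
                 [0.0, 0.0, 1.0]])

def interpolate_eps(eps_A, eps_B, lam):
    """Linearly interpolate per-particle vdW epsilons from A (lam=0) to B (lam=1)."""
    return (1.0 - lam) * (M_A @ eps_A) + lam * (M_B @ eps_B)

eps_A = jnp.array([0.1, 0.2, 0.3])
eps_B = jnp.array([0.15, 0.25, 0.4])

eps_at_A = interpolate_eps(eps_A, eps_B, 0.0)  # particle 3 has epsilon 0
eps_at_B = interpolate_eps(eps_A, eps_B, 1.0)  # particle 2 has epsilon 0

# The map is linear, so gradients flow back to each end state's parameters.
d_eps_A = jax.grad(lambda a: interpolate_eps(a, eps_B, 0.5).sum())(eps_A)
```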
YTZ – The loss should take FF parameters as input and return a scalar.
YTZ – The toy loss function currently returns a per-parameter loss. Instead, have it reduce to a scalar with jnp.linalg.norm
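A toy loss reduced to a scalar with jnp.linalg.norm, so it can go straight into jax.grad (which only accepts scalar-output functions). The residual against a hypothetical `target_params` array is an illustrative assumption, not the actual toy loss.

```python
import jax
import jax.numpy as jnp

def toy_loss(ff_params, target_params):
    # Per-parameter residuals: returning these directly is what made a
    # loss non-scalar.
    residuals = ff_params - target_params
    # Reduce to a single scalar so jax.grad can differentiate it.
    return jnp.linalg.norm(residuals)

ff_params = jnp.array([3.0, 4.0])
target_params = jnp.zeros(2)

loss = toy_loss(ff_params, target_params)             # 5.0
grads = jax.grad(toy_loss)(ff_params, target_params)  # residuals / norm
```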
YTZ – If I could make stuff work with jax.grad, I wouldn’t need the system object to make things that are vjp-able
Example:
Free energy = A(q(p)), where A is the free energy, a function of the system params q, which are in turn a function of the FF params p
Loss = A(q(p)) - A0, where A0 is the experimentally measured free energy
So dL/dp = dL/dA . dA/dq . dq/dp
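The chain rule above can be sketched as a sequence of jax.vjp calls, one per stage, with the cotangent pushed backward from L to p. `q`, `A`, `A0`, and the squared-error loss are toy assumptions (with the plain difference A - A0 from the notes, dL/dA would simply be 1); the point is only that composing the per-stage vjps reproduces jax.grad on the whole pipeline.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy stages; neither is the real model.
def q(p):
    """FF params p -> system params q."""
    return jnp.stack([p[0] * p[1], p[0] + p[1], p[1] ** 2])

def A(q_vals):
    """System params q -> a scalar 'free energy'."""
    return jnp.sum(jnp.sin(q_vals))

A0 = 1.0  # stand-in for the experimentally measured free energy

def loss_from_A(a):
    return (a - A0) ** 2  # assumed squared error

p = jnp.array([0.5, 2.0])

# Forward pass, keeping a vjp closure for each stage.
q_vals, vjp_q = jax.vjp(q, p)       # carries dq/dp
a, vjp_A = jax.vjp(A, q_vals)       # carries dA/dq
L, vjp_L = jax.vjp(loss_from_A, a)  # carries dL/dA

# Backward pass: seed with 1.0 and pull back through each stage in reverse.
(dL_dA,) = vjp_L(jnp.ones_like(L))
(dL_dq,) = vjp_A(dL_dA)
(dL_dp,) = vjp_q(dL_dq)

# Same result as differentiating the composed function directly.
full_grad = jax.grad(lambda p: loss_from_A(A(q(p))))(p)
```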
Or, in some cases, we have small molecule params ff_p and protein params pro_p
So, combined_q = np.concatenate([ff_q(ff_p), pro_q])
Then A(combined_q) = A(combined_q(ff_q(ff_p), pro_q))
In the future, when we also modify protein params, the last term would become pro_q(pro_p)
So the thing I need is a vjp function that lets me do vjp_fn(dL/dq) and get dL/dp
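A sketch of the concatenated case, assuming a toy `ff_q` and fixed protein params `pro_q` (both hypothetical). jax.vjp is taken with respect to ff_p only, so the returned vjp_fn maps a cotangent dL/dq over the full concatenated vector back to dL/dp, and the entries that belong to the protein simply drop out.

```python
import jax
import jax.numpy as jnp

# Hypothetical small-molecule parameterization: FF params -> system params.
def ff_q(ff_p):
    return jnp.concatenate([ff_p * 2.0, ff_p ** 2])

pro_q = jnp.array([0.3, 0.7, 1.1])  # fixed protein system params (not fitted)

def combined_q(ff_p):
    return jnp.concatenate([ff_q(ff_p), pro_q])

ff_p = jnp.array([1.0, 3.0])
q_vals, vjp_fn = jax.vjp(combined_q, ff_p)

# dL/dq would come from the free-energy calculation; a placeholder is used here.
dL_dq = jnp.arange(q_vals.shape[0], dtype=q_vals.dtype)
(dL_dp,) = vjp_fn(dL_dq)
```

If the protein params are fitted later, pro_q(pro_p) can become a second differentiable argument; jax.vjp(fn, ff_p, pro_p) would then return a vjp function that produces a tuple with one cotangent per argument.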