| Matt will post the “optimize-with-jax” notebook here.
YTZ – Looks pretty good.
YTZ – I probably won’t vjp bonds.parameterize directly; instead it will be chained into a larger function that I’ll vjp.
YTZ – One change I’ve made: instead of using bonds.parameterize as-is, my version also returns the bond indices. These are returned as auxiliary output via jax.vjp’s has_aux kwarg, so vjp doesn’t attempt to differentiate with respect to them. I “hide” them when needed using functools.partial. Having these indices is useful for alchemical transformations, since they let me map the particles/parameters in one molecule to those in another.
YTZ – Handle things like improper folds?
YTZ – The loss should return a scalar: the input is the FF parameters, the output is a scalar.
YTZ – If I could make everything work with jax.grad, I wouldn’t need the system object to make things vjp-able. For example:
Free energy = A(q(p)), where A is the free energy, a function of the system params q, which are in turn a function of the FF params p.
Loss = A(q(p)) - A0, where A0 is the experimentally measured free energy.
So dL/dp = dL/dA · dA/dq · dq/dp.
In some cases we have small-molecule params ff_p and protein params pro_p. Then combined_q = np.concatenate([ff_q(ff_p), pro_q]), and A(combined_q) = A(combined_q(ff_q(ff_p), pro_q)).
So what I need is a vjp function that lets me call vjp_fn(dL/dq) and get dL/dp.
MT – We should be in good shape to provide this.
|
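A minimal sketch of the has_aux pattern YTZ describes, assuming a hypothetical `parameterize_bonds` that gathers per-bond (k, r0) values from a per-type parameter table; the real bonds.parameterize signature may differ. functools.partial binds the integer inputs so jax.vjp only differentiates `ff_params`, and the bond indices come back untouched as aux data.

```python
import functools

import jax
import jax.numpy as jnp


def parameterize_bonds(ff_params, bond_types, bond_idxs):
    """Hypothetical stand-in for bonds.parameterize.

    ff_params:  (n_types, 2) float array of (k, r0) per bond type.
    bond_types: (n_bonds,) int array mapping each bond to its type.
    bond_idxs:  (n_bonds, 2) int array of atom indices.
    Returns (system_params, bond_idxs); the indices ride along as aux data.
    """
    system_params = ff_params[bond_types]  # gather per-bond (k, r0)
    return system_params, bond_idxs


bond_types = jnp.array([0, 1, 0])
bond_idxs = jnp.array([[0, 1], [1, 2], [2, 3]])
ff_params = jnp.array([[100.0, 1.0], [200.0, 1.5]])

# Bind the non-differentiable inputs so the function vjp sees takes only ff_params.
fn = functools.partial(parameterize_bonds, bond_types=bond_types, bond_idxs=bond_idxs)

# has_aux=True: fn returns (output, aux); only `output` is differentiated.
q, vjp_fn, idxs = jax.vjp(fn, ff_params, has_aux=True)

# Pull back a cotangent on the system params (here all ones) to the FF params.
dL_dq = jnp.ones_like(q)
(dL_dp,) = vjp_fn(dL_dq)
```

Because bond type 0 is used twice, its row of `dL_dp` accumulates two contributions; the gather's transpose is a scatter-add. In a larger pipeline `fn` would be one stage of the chained function YTZ vjps, and `idxs` would feed the alchemical mapping.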
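The chain rule dL/dp = dL/dA · dA/dq · dq/dp can be checked on toy functions. In this sketch `ff_q`, `free_energy`, `A0`, and `pro_q` are hypothetical placeholders; in the real workflow dL/dq would come from the free-energy estimator rather than from autodiff of a closed-form A. The sketch computes dL/dp two ways: pulling dL/dq back through `jax.vjp` of the combined parameterization, and calling `jax.grad` on the scalar loss end to end.

```python
import jax
import jax.numpy as jnp


def ff_q(ff_p):
    # Hypothetical small-molecule parameterization q(p); any differentiable map works.
    return jnp.concatenate([ff_p**2, jnp.sin(ff_p)])


def free_energy(q):
    # Hypothetical stand-in for A(q); a real estimator would run simulations.
    return jnp.sum(q**2)


A0 = 1.0                          # placeholder "experimental" free energy
pro_q = jnp.array([0.5, -0.25])   # protein params, held fixed (not differentiated)


def combined_q(ff_p):
    # combined_q = concatenate([ff_q(ff_p), pro_q])
    return jnp.concatenate([ff_q(ff_p), pro_q])


def loss(ff_p):
    # Scalar loss L = A(q(p)) - A0, as in the notes.
    return free_energy(combined_q(ff_p)) - A0


ff_p = jnp.array([0.1, 0.2, 0.3])

# Route 1: vjp of q(p), then pull back dL/dq to get dL/dp.
q, vjp_fn = jax.vjp(combined_q, ff_p)
dL_dq = jax.grad(lambda q: free_energy(q) - A0)(q)  # dL/dA = 1, so this is dA/dq
(dL_dp_vjp,) = vjp_fn(dL_dq)

# Route 2: jax.grad of the scalar loss, end to end.
dL_dp_grad = jax.grad(loss)(ff_p)
```

Both routes give the same gradient; route 1 is the shape needed when A is not itself a JAX function, since only `vjp_fn` and an externally computed dL/dq are required.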