We’re working on a program (timemachine) that can compute derivatives of observables with respect to force field (FF) parameters
We use the chain rule: we take simulation snapshots, compute the derivatives of their energies wrt the FF params, and can then tweak the FF params and see how the observable would change
One challenge is knowing which FF parameters are “broadcasted” multiple times in a system
In principle, jax could be used to compute these gradients. But one part of the calculation is far too slow to differentiate that way, so for it we have a set of hand-written GPU kernels.
These kernels have been highly optimized to compute Hessian-vector products and Jacobian products.
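The kernels themselves are hand-written CUDA, but the mathematical operation they implement can be sketched in jax. A minimal sketch of a Hessian-vector product via the standard forward-over-reverse recipe, using a toy potential (the names `energy`, `x`, and `v` are illustrative):

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Hessian-vector product via forward-over-reverse:
    # differentiate grad(f) along direction v, never forming the Hessian.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

# Toy potential: harmonic energy of coordinates x.
def energy(x):
    return 0.5 * jnp.sum(x ** 2)

x = jnp.ones(3)
v = jnp.array([1.0, 0.0, 0.0])
print(hvp(energy, x, v))  # Hessian is the identity here, so this returns v
```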
JDC – The advantage of jax is that other people take care of the backend – it can plug into CPUs and GPUs and perform any needed performance enhancements without our intervention.
We use jax for a lot of things, not just autodiff, but also its optimizers, which contain all the routines needed for optimizing parameters (in this case, FF params)
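For illustration, a minimal sketch of fitting parameters with one of jax's bundled optimizers; the loss here is a hypothetical stand-in, and note the module path has moved between jax releases (it was jax.experimental.optimizers in older versions):

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers  # jax.experimental.optimizers in older jax

# Hypothetical loss over FF parameters, standing in for an observable mismatch.
def loss(ff_params):
    return jnp.sum((ff_params - 1.0) ** 2)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
opt_state = opt_init(jnp.zeros(4))

for step in range(100):
    grads = jax.grad(loss)(get_params(opt_state))
    opt_state = opt_update(step, grads, opt_state)

print(get_params(opt_state))  # approaches the optimum at 1.0
```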
“How do we turn FF parameters into System parameters, and be able to compute jacobian for them?”
JDC – We basically need to track which FF PARAMETER corresponds to which INSTANCE of a parameter in the System. If we do it right, we should be able to compute the gradient of the SYSTEM ENERGY with respect to a FF PARAMETER.
dU/dFFParam = dU/dSysParam · dSysParam/dFFParam, by the chain rule. One method of computing dU/dSysParam is jax, which offers reverse-mode automatic differentiation. Forming dSysParam/dFFParam explicitly would be enormously expensive, since that jacobian can easily have a million rows and tens of thousands of columns. So you NEVER form the entire jacobian. Instead you make a function that computes the jacobian left-multiplied by some vector.
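A minimal sketch of this chain rule in jax, assuming hypothetical `ff_to_sys` and `energy` functions standing in for the real mapping and potential:

```python
import jax
import jax.numpy as jnp

ff_params = jnp.array([1.0, 2.0, 3.0])

def ff_to_sys(ff_params):
    # Placeholder for the FF -> System mapping; real mappings are large and sparse.
    return jnp.repeat(ff_params, 4)

def energy(sys_params):
    # Placeholder potential; in practice this is the expensive part.
    return jnp.sum(jnp.sin(sys_params))

# dU/dFFParam via a VJP, never materializing dSysParam/dFFParam:
sys_params = ff_to_sys(ff_params)
dU_dsys = jax.grad(energy)(sys_params)   # dU/dSysParam
_, pullback = jax.vjp(ff_to_sys, ff_params)
(dU_dff,) = pullback(dU_dsys)            # dU/dSysParam . dSysParam/dFFParam
print(dU_dff)
```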
I’d hoped to have support in OFF for vector-jacobian products (VJPs) – in our case, for dU/dSysParam. It would be great for OFFTK to have a function that provides this: a function that transforms a function f into its jacobian, and that also knows how to left-multiply that jacobian by some vector.
JDC – Specific implementation?
It’s up to you to implement the VJP however you want. The time-consuming part is computing dU/dSysParams.
JDC – OpenFF Evaluator computes dU/dFFParams by finite difference currently, which is really inefficient. When we can do this analytically, we’ll unlock a lot more science
The vector we talk about in the VJP is the “adjoint”, defined as the gradient of the loss (the objective function, built essentially from the potential) with respect to the system parameters.
Timemachine does not compute dU/dSysParam itself. It provides one factor that is then right-multiplied by a vector we care about. You just want dU/dParameters, whereas timemachine computes a more complex mixed partial, since it deals with trajectories.
JDC – So, what output can we provide to be useful to YTZ?
JDC – Our goals are to:
- Provide a lightweight python container for systems
- Have abstractions that match the concept of modular parameter types from the FF
- Provide interoperability with other representations of parameterized systems, and maybe eventually provide a library for other simulation engines to interact with our System object
- Provide good, friendly, portable serialized representations
- Easily render the object model into several performant representations (such as an OMM System) or a differentiable representation
- Provide a flattened view into the FF parameters and their multiple instances of application in the system
If OFF provides a parameterize(off_top, ff) method, it would be nice to have it output a jax-compatible object, since then we can call jax.vjp on it. One difficulty with jax is that the body of the function has to be traceable.
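A small sketch of what “traceable” means in practice; `energy_traceable` and `energy_opaque` are hypothetical illustrations, not proposed API:

```python
import jax
import jax.numpy as jnp

# Traceable: a pure function of array inputs, built from jax.numpy primitives.
def energy_traceable(sys_params):
    return jnp.sum(sys_params ** 2)

# NOT traceable: converts the traced value to host-side numpy / Python floats,
# which jax cannot record into its compute graph.
def energy_opaque(sys_params):
    import numpy as np
    return float(np.sum(np.asarray(sys_params) ** 2))

v = jnp.array(1.0)  # scalar cotangent for a scalar energy
_, pullback = jax.vjp(energy_traceable, jnp.arange(3.0))
print(pullback(v))  # works
# jax.vjp(energy_opaque, ...) would raise a tracer conversion error.
```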
JDC – The mapping from FF parameters into parameterized system “parameter slots” will most likely be a large, sparse, (mostly) boolean matrix (or simple combinations). We can provide a function that will build that representation from the contents of the system.
It’s not all 1’s and 0’s, because of things like idivfs
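A toy sketch of such a mapping matrix, with hypothetical dimensions and an idivf-style fractional entry:

```python
import jax.numpy as jnp

# Hypothetical: 2 FF parameters applied across 5 system parameter slots.
# Most entries are 0/1 (pure broadcast), but idivf-style divisors
# make some entries fractional rather than boolean.
mapping = jnp.array([
    [1.0, 0.0],
    [1.0, 0.0],
    [0.0, 1.0 / 3.0],  # idivf = 3 for this application of the parameter
    [0.0, 1.0 / 3.0],
    [0.0, 1.0 / 3.0],
])

ff_params = jnp.array([10.0, 6.0])
sys_params = mapping @ ff_params  # dSysParam/dFFParam is exactly `mapping`
print(sys_params)                 # [10. 10. 2. 2. 2.]
```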
JDC – Instead of having a parameterize function make the jax object directly, it will create an OFF System object, which itself has an OFFSystem.to_jax method.
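A rough sketch of what that could look like; the class layout and attribute names here are assumptions, not the agreed design:

```python
import jax
import jax.numpy as jnp

class OFFSystem:
    """Hypothetical sketch of the proposed container."""

    def __init__(self, mapping):
        # FF-parameter -> system-slot matrix, as discussed above.
        self.mapping = mapping

    def to_jax(self):
        # Return a pure, traceable function of the FF parameters.
        # Static data (the mapping) is closed over; only the function's
        # inputs need to be jax-traceable.
        def ff_to_sys(ff_params):
            return self.mapping @ ff_params
        return ff_to_sys

# Usage: the returned function composes with jax transforms like vjp.
system = OFFSystem(jnp.eye(3))
sys_params, pullback = jax.vjp(system.to_jax(), jnp.ones(3))
```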
There may be implementation issues for to_jax – what kind of object will be returned? The jax representation has to be a function composed of jax primitives
JDC – This should be possible. We can also have to_tensorflow and to_pytorch methods that operate similarly.
An issue is that packages like jax, pytorch, and tf hide away the compute graph as much as they can. The to_jax method would need to return a jax function. to_tensorflow would need to return a tensorflow node, which may need to be pinned to a device (modern infrastructure may need a trace of the function itself).
For OFF, you want to fit parameters using OFF Evaluator. That requires doing lots of finite-difference calcs.
MT – Clarify “backwards mode”?
Suppose you have a function f(g(h(x))) and you want to compute df/dx. By the chain rule,
df/dx = df/dg · dg/dh · dh/dx = A · B · C
A · (B · C) - forward mode (associates from the input side)
(A · B) · C - backward (reverse) mode (associates from the output side)
YTZ – Basically, if you have a bunch of matrices that you’re going to multiply in sequence, the first being AxB, the second BxC, the third CxD, then you should start with the SMALLEST dimensions first and parenthesize those multiplications so they happen first; this way you reduce the total number of operations you need to perform to get the same result.
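A numpy sketch of this ordering argument, with hypothetical dimensions. For a scalar-valued function the leading jacobian is a single row, so the output-side (reverse-mode) association is far cheaper:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((1, 1000))     # df/dg: scalar output => a single row
B = rng.standard_normal((1000, 1000))  # dg/dh
C = rng.standard_normal((1000, 100))   # dh/dx

# Reverse mode: ((A @ B) @ C)
# cost ~ 1*1000*1000 + 1*1000*100 ~ 1e6 multiply-adds
rev = (A @ B) @ C

# Forward mode: (A @ (B @ C))
# cost ~ 1000*1000*100 + 1*1000*100 ~ 1e8 multiply-adds
fwd = A @ (B @ C)

assert np.allclose(rev, fwd)  # same result, very different work
```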
YTZ – It’d be useful if Matt ran some parameter fitting experiments himself, since that would provide insight into the process. It would show common patterns of expensive/tricky steps in the process. I’ve learned a lot by doing this, and by following what Lee Ping found.
YTZ – We should get a better idea of what would be a useful output for to_jax by just trying it out ourselves.
YTZ – I’ll be at virtual May meeting.
MT – Could we meet before that? Maybe in two weeks?
(General) yes – Matt will reach out to schedule