2020-04-09 Meeting notes

Date

Apr 8, 2020

Participants

  • @Matt Thompson

  • @John Chodera

  • @Jeffrey Wagner

  • Yutong Zhao

Goals

  • Scope out things JC and YTZ need this object to do

  • In timemachine issue #310, what is “p” in the dx/dp derivative?

Discussion topics

Yutong/timemachine background

  • We’re working on a program (timemachine) that can compute derivatives of observables with respect to FF parameters

  • We use the chain rule: take simulation snapshots, compute the derivatives of their energies with respect to FF params, and then we can tweak the FF params and see how the observable would change

  • One challenge is knowing which FF parameters are “broadcasted” multiple times in a system

  • In principle, jax could be used to compute these gradients. But there’s one part of the calc that we do, which is way too slow to compute gradients for. For those, we have a set of hand-written GPU kernels.

  • These kernels have been highly optimized to compute Hessian-vector products and Jacobian-vector products.

  • JDC – The advantage of jax is that other people take care of the backend – it can target CPUs and GPUs and apply any needed performance enhancements without our intervention.

  • We use jax for a lot of things, not just autodiff – also its optimizers, which contain all the routines needed for optimizing parameters (in this case, FF params)
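A minimal sketch of the jax autodiff + optimizer pattern described above, using a toy harmonic-bond loss and made-up numbers (not OpenFF or timemachine code):

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

# Toy example: fit the equilibrium length r0 of a single harmonic bond
# U(x; r0) = 0.5 * K * (x - r0)^2 so the energy at an "observed" distance
# is minimized. jax supplies both the gradient and the optimizer loop.
K = 100.0      # fixed force constant (made up)
X_OBS = 1.5    # "observed" bond length (made up)

def loss(r0):
    return 0.5 * K * (X_OBS - r0) ** 2

opt_init, opt_update, get_params = optimizers.sgd(step_size=0.005)
opt_state = opt_init(jnp.array(1.0))   # initial guess for r0
grad_loss = jax.grad(loss)

for step in range(50):
    opt_state = opt_update(step, grad_loss(get_params(opt_state)), opt_state)

r0_fit = float(get_params(opt_state))  # converges toward X_OBS = 1.5
```

In a real fit the loss would be built from observables computed over trajectories, but the loop structure is the same.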

  • “How do we turn FF parameters into System parameters, and be able to compute jacobian for them?”

  • JDC – We basically need to track which FF PARAMETER corresponds to which INSTANCE of a parameter in the System. If we do it right, we should be able to compute the gradient of the SYSTEM ENERGY with respect to a FF PARAMETER.

  • dU/dFFParam = dU/dSysParam (dot) dSysParam/dFFParam can be computed using the Jacobian. One method of computing dU/dSysParam is using jax, which offers reverse-mode automatic differentiation. Forming dSysParam/dFFParam in full would be enormously expensive, since the Jacobian can easily have 1 million rows and tens of thousands of columns. So you NEVER form the entire Jacobian. Instead you make a function that produces the Jacobian left-multiplied by some vector.

  • I’d hoped to have support in OFF for vector-Jacobian products (VJPs) – in our case, for dU/dSysParam. It would be great for OFFTK to have a function that provides this: a function that transforms a function f into its Jacobian, and knows how to left-multiply that Jacobian by some vector.
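The VJP idea can be sketched with jax.vjp on a toy energy function (hypothetical parameter layout and made-up numbers, not the proposed OFFTK API):

```python
import jax
import jax.numpy as jnp

# Toy energy: a few harmonic bonds, as a function of the per-instance
# system parameters (each row is one bond's (k, r0); values made up).
def energy(sys_params, lengths):
    k, r0 = sys_params[:, 0], sys_params[:, 1]
    return jnp.sum(0.5 * k * (lengths - r0) ** 2)

sys_params = jnp.array([[500.0, 1.0],
                        [500.0, 1.0],
                        [300.0, 1.5]])
lengths = jnp.array([1.1, 0.9, 1.6])

# jax.vjp returns U and a function that left-multiplies the Jacobian
# dU/dSysParam by a cotangent vector; for a scalar U that vector is 1.0.
# The full Jacobian is never materialized.
U, vjp_fn = jax.vjp(lambda p: energy(p, lengths), sys_params)
(dU_dSysParam,) = vjp_fn(1.0)   # same shape as sys_params
```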

  • JDC – Specific implementation?

  • It’s up to you to implement the VJP however you want. The time-consuming part is computing dU/dSysParams.

  • JDC – OpenFF Evaluator computes dU/dFFParams by finite differences currently, which is really inefficient. When we can do this analytically, we’ll unlock a lot more science

  • The vector we talk about in the VJP is the “adjoint”, defined as the gradient of the loss (the objective function, basically the potential) with respect to the system parameters.

  • Timemachine does not compute dU/dSystem directly. It provides one form of a vector that should be right-multiplied by a vector that we care about. OFF just wants dU/dParameters, whereas timemachine computes a more complex mixed partial, since it deals with trajectories.

  • JDC – So, what output can we provide to be useful to YTZ?

  • JDC – Our goals are to:

    • Provide a lightweight python container for systems

    • Have abstractions that match concept of modular parameter types from FF

    • Provide interoperability with other representations of parameterized systems, and maybe eventually provide a library for other simulation engines to interact with our System object

    • Good, friendly, portable serialized representations

    • Easily render object model into several performant representations (such as OMM system), or a differentiable representation

    • Provide a flattened view into the FF parameters and their multiple instances of application in the system

  • If OFF provides a parameterize(off_top, ff) method, it would be nice to have it output a jax-compatible object, since then we can call jax.vjp on it. One difficulty with jax is that the body of the function has to be traceable.

  • JDC – The mapping from FF parameters into parameterized system “parameter slots” will most likely be a large, sparse, (mostly) boolean matrix (or simple combinations). We can provide a function that will build that representation from the contents of the system.

  • It’s not all 1’s and 0’s, because of things like idivfs
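The mapping described above can be sketched as a mostly-boolean matrix with made-up numbers (in practice it would be sparse and much larger; the 1/idivf-style scale factor is illustrative):

```python
import jax.numpy as jnp

# 2 FF parameters broadcast into 4 system "parameter slots".
# M[i, j] = weight of FF parameter j in system slot i; mostly 0/1, but a
# slot can carry a scale factor (e.g. a torsion term divided by idivf = 2).
M = jnp.array([[1.0, 0.0],
               [1.0, 0.0],
               [0.0, 1.0],
               [0.0, 0.5]])   # last slot applies FF param 1 scaled by 1/idivf

ff_params = jnp.array([10.0, 4.0])
sys_params = M @ ff_params            # broadcast FF params into the system

# Chain rule through the mapping: dU/dFFParam = M^T @ dU/dSysParam,
# which accumulates the gradient over every instance of each FF parameter.
dU_dSysParam = jnp.array([1.0, 2.0, 3.0, 4.0])   # pretend gradient
dU_dFFParam = M.T @ dU_dSysParam
```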

  • JDC – Instead of having a parameterize function make the jax object directly, it will create an OFF System object, which itself has an OFFSystem.to_jax method

  • There may be implementation issues for to_jax – what kind of object will be returned? The jax representation has to be a function composed of primitives

  • JDC – This should be possible. We can also have a to_tensorflow and to_pytorch that operate similarly.

  • The issue is that packages like jax, PyTorch, and TF hide the compute graph as much as they can. The to_jax method would need to return a jax function. For to_tensorflow, it will need to be a TensorFlow node, which may need to be pinned to a device (modern infra may need a trace of the function itself).

  • For OFF, you want to fit parameters using OFF Evaluator. That requires doing lots of finite-difference calcs.

  • MT – Clarify “backwards mode”?

  • Suppose you have a function f(g(h(x))); compute f'(x)

  • df/dg · dg/dh · dh/dx = a·b·c

  • a·(b·c) – forward mode

  • (a·b)·c – backward (reverse) mode

  • YTZ – Basically, if you have a bunch of matrices that you’re going to multiply in sequence, the first being AxB, the second BxC, the third CxD, then you should start with the SMALLEST dimensions first, and associate them w/ parentheses (multiply them first), and this way you can reduce the total number of operations you need to perform to get all this done.
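The associativity point can be made concrete by counting multiply-adds for a small made-up chain of shapes, where the leftmost factor is the gradient of a scalar loss (a row vector), which is the case that favors reverse mode:

```python
# Multiplying an (m x n) matrix by an (n x p) matrix costs m*n*p multiply-adds.
def matmul_cost(m, n, p):
    return m * n * p

# Made-up shapes: a is 1 x 1000 (gradient of a scalar loss),
# b is 1000 x 1000, c is 1000 x 10.
A, B, C, D = 1, 1000, 1000, 10   # a: AxB, b: BxC, c: CxD

forward  = matmul_cost(B, C, D) + matmul_cost(A, B, D)   # a·(b·c)
backward = matmul_cost(A, B, C) + matmul_cost(A, C, D)   # (a·b)·c
```

Here backward (reverse) mode costs 1,010,000 multiply-adds versus 10,010,000 for forward mode, because every intermediate in the reverse-mode ordering stays a row vector.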

  • YTZ – It’d be useful if Matt ran some parameter fitting experiments himself, since that would provide insight into the process. It would show common patterns of expensive/tricky steps in the process. I’ve learned a lot by doing this, and by following what Lee Ping found.

  • YTZ – We should get a better idea of what would be a useful output for to_jax by just trying it out ourselves.

  • YTZ – I’ll be at virtual May meeting.

  • MT – Could we meet before that? Maybe in two weeks?

  • (General) yes – Matt will reach out to schedule





Action items

Decisions