Suitability of current infrastructure for jax optimization
MT shows a notebook.
JC – Do calls like get_forcefield_parameters always return jax-numpy arrays, or can they return other types?
YTZ – The required return types depend on where the chain rule should be applied. You might want to survey the broader community and ask which tools they’re using. In general, I need to be able to get dL/dQ using jax.
JC – If we plug in espaloma as a different model, we’ll have other forms of output.
YTZ – Your current FF fitting must somehow compute dL/dQ.
JC – Yes, but forcebalance isn’t ideal. We’d prefer a continuously differentiable representation, so in the future we’ll replace forcebalance with this.
YTZ – I’ve shown you our pipeline. We use modern tooling. What you need to do is figure out how to replace forcebalance.
JC – The question is “which methods do we need to expose to support our future fitting efforts?”
YTZ – Let’s stop using the term “JIT”; I think what you really want is automatic differentiation, where you can chain operations on data structures together. If we all agreed to use one package (jax, or numpy, or something else), we wouldn’t have this problem. But if different frontends compute dL/dQ with different ML packages, each one needs its own set of chain rules. What OpenFF should avoid is writing the same thing three times for three different packages.
JC – We can’t commit to one particular ML package; instead we need a package-independent API that can easily be extended to implementations in different packages.
YTZ – What you actually want to get at is the underlying VJP (vector-Jacobian product) representation in these different libraries.
JC – I’m not sure we want to provide that for several different libraries. What YTZ needs is a way to take a snapshot of coordinates and a FF parameter vector and return potential energies.
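YTZ’s dL/dQ is the chain rule through the energy, dL/dQ = (dL/dU)(dU/dQ), and a VJP is the primitive that pulls the upstream factor dL/dU back onto Q. A minimal jax sketch, assuming a toy harmonic-bond energy with one shared parameter pair Q = [k, r0] and a squared-error loss against a single reference energy; `energy`, `loss`, and all numbers are illustrative, not OpenFF or timemachine code:

```python
import jax
import jax.numpy as jnp

# Toy topology: 3 particles joined by 2 bonds (pairs of particle indices).
BONDS = jnp.array([[0, 1], [1, 2]])

def energy(positions, q):
    """Harmonic bond energy U(x, Q); q = [k, r0] is shared by every bond (toy form)."""
    k, r0 = q[0], q[1]
    d = positions[BONDS[:, 0]] - positions[BONDS[:, 1]]
    r = jnp.linalg.norm(d, axis=1)
    return jnp.sum(0.5 * k * (r - r0) ** 2)

def loss(q, positions, u_ref):
    """Squared error L between the model energy and a reference energy."""
    return (energy(positions, q) - u_ref) ** 2

positions = jnp.array([[0.0, 0.0, 0.0],
                       [1.2, 0.0, 0.0],
                       [1.2, 1.1, 0.0]])
q = jnp.array([100.0, 1.0])
u_ref = 1.0

# dL/dQ in one call, by reverse-mode autodiff through loss and energy.
dl_dq = jax.grad(loss)(q, positions, u_ref)

# The same quantity assembled explicitly: the VJP of U pulls the
# upstream cotangent dL/dU back onto Q, i.e. dL/dQ = (dL/dU) * (dU/dQ).
u, energy_vjp = jax.vjp(lambda p: energy(positions, p), q)
dl_du = 2.0 * (u - u_ref)
(dl_dq_chain,) = energy_vjp(dl_du)
```

Both routes give dL/dQ ≈ [0.075, -90] for these numbers. The explicit `jax.vjp` form is the piece that lets a different package supply dL/dU and still chain through the energy.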
Is this fundamentally the thing you need (given positions and fundamental parameters, return the PE)?
YTZ – It’s not just dU/dQ that I need; there’s more upstream.
JC – I think we’re looking at the wrong level.
YTZ – It’s not the responsibility of the OpenFF consortium to think through the details of everyone’s downstream implementation. What’s the scope of this conversation: is it to enumerate use cases, or to understand how I’m doing differentiable FF fitting?
JC – Three use cases:
1. Change parameters and manually see how the energy changes.
2. Output to various formats.
3. Expose a flat vector of FF parameters in a manner that lets people optimize it using automatic differentiation.
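For the third use case, one way to expose a flat parameter vector is to flatten a nested parameter container and keep the inverse mapping, so an optimizer works on a plain 1-D array and the result can be written back into a force-field structure afterwards. A minimal sketch using `jax.flatten_util.ravel_pytree`; the dictionary layout, SMIRKS keys, values, and the placeholder `objective` are hypothetical stand-ins, not the OpenFF toolkit’s data model:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical nested FF parameters keyed by handler and SMIRKS pattern.
ff_params = {
    "Bonds": {
        "[#6:1]-[#6:2]": jnp.array([300.0, 1.53]),  # [k, length]
        "[#6:1]-[#1:2]": jnp.array([340.0, 1.09]),
    },
    "Angles": {
        "[*:1]~[#6:2]~[*:3]": jnp.array([50.0, 1.91]),  # [k, angle]
    },
}

# A flat 1-D vector for the optimizer, plus the inverse mapping back to the nesting.
flat, unravel = ravel_pytree(ff_params)

def objective(flat_params):
    """Placeholder objective; a real one would compare computed and reference data."""
    return jnp.sum((flat_params - 1.0) ** 2)

# One gradient-descent step on the flat vector, then map the result back.
learning_rate = 1e-3
flat_new = flat - learning_rate * jax.grad(objective)(flat)
ff_params_new = unravel(flat_new)
```

The optimizer never needs to know the nesting; only `unravel` does, which keeps the “array in, array out” contract independent of how parameters are organized internally.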
JC – We’re talking about the third.
YTZ – So you’re talking about implementing the energy functions themselves?
JC – We could provide a factory for this. But I wonder if there’s a simpler way to approach the API that will be more easily extensible.
YTZ – Let’s talk about the details of writing the jax-based version of the energy function. It’s easy for valence terms. The difficulty comes from the nonbonded part: I still can’t write an economical nonbonded kernel in jax for more than about 1k particles, because of complications with neighbor lists and PME computations.
JC – None of our training set is larger than 100 atoms.
YTZ – We need to compute derivatives for fully solvated systems, so we have some bespoke CUDA code for the nonbonded part.
JC – If we provided a factory “to_jax_potential”, could that be used as part of your calculation? One option for us is to use timemachine under the hood for that. Then we could put that whole thing in the public API.
YTZ – That would restrict you to only the functional forms supported by timemachine.
JC – That would be sufficient for your and our current needs, though we understand it would run into trouble if we tried to go to the condensed phase.
YTZ – This could make timemachine a dependency of that code.
JC – It could be an optional dependency; we already handle those.
YTZ – Our code may be too unstable. This would also put you on the hook for implementing all the functional forms for different bonds/angles/torsions. That seems like a large increase in scope.
JW – I’m curious whether it is sufficient to provide the arrays presented in the earlier notebook in jax format, plus a way to ingest an array of the same shape and turn it back into a FF at the end of the optimization.
YTZ – Could you clarify your use of “factory”?
JC – This would be a function that returns an object or a jax-based method/function, which could be called as myfunc(positions, ff_params) and returns energies.
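JC’s factory could look roughly like the sketch below: it closes over a fixed topology and returns a jitted `myfunc(positions, ff_params)` that yields a potential energy and can be differentiated with respect to either argument. `to_jax_potential`, the flat four-entry parameter layout, and the harmonic bond/angle forms are assumptions for illustration; the sketch deliberately leaves out the nonbonded terms that YTZ identifies as the hard part.

```python
import jax
import jax.numpy as jnp

def to_jax_potential(bond_idx, angle_idx):
    """Hypothetical factory returning energy(positions, ff_params) for a fixed topology.

    ff_params is assumed to be a flat vector [k_bond, r0, k_angle, theta0],
    with one value per term type shared by all bonds / angles (illustrative).
    """
    bond_idx = jnp.asarray(bond_idx)
    angle_idx = jnp.asarray(angle_idx)

    def energy(positions, ff_params):
        k_b, r0, k_a, theta0 = ff_params
        # Harmonic bonds.
        d = positions[bond_idx[:, 0]] - positions[bond_idx[:, 1]]
        r = jnp.linalg.norm(d, axis=1)
        u_bond = jnp.sum(0.5 * k_b * (r - r0) ** 2)
        # Harmonic angles, with the vertex at the middle index.
        v1 = positions[angle_idx[:, 0]] - positions[angle_idx[:, 1]]
        v2 = positions[angle_idx[:, 2]] - positions[angle_idx[:, 1]]
        cos_t = jnp.sum(v1 * v2, axis=1) / (
            jnp.linalg.norm(v1, axis=1) * jnp.linalg.norm(v2, axis=1))
        theta = jnp.arccos(jnp.clip(cos_t, -1.0, 1.0))
        u_angle = jnp.sum(0.5 * k_a * (theta - theta0) ** 2)
        return u_bond + u_angle

    return jax.jit(energy)

myfunc = to_jax_potential(bond_idx=[[0, 1], [1, 2]], angle_idx=[[0, 1, 2]])
positions = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
ff_params = jnp.array([100.0, 1.0, 50.0, jnp.pi / 2])

u = myfunc(positions, ff_params)                            # potential energy
forces = -jax.grad(myfunc, argnums=0)(positions, ff_params)  # -dU/dx
du_dparams = jax.grad(myfunc, argnums=1)(positions, ff_params)  # dU/dQ for fitting
```

Because positions and parameters are both explicit arguments, one function serves for forces and for fitting. A nonbonded term written as a naive all-pairs sum would plug in the same way, but it scales as O(N²) without a neighbor list, which is the cost problem YTZ describes for solvated systems.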
YTZ – I was thinking that we were trying to determine how to start from FF parameters and have some machinery to backpropagate through them arbitrarily.
Scope: either (1) OpenFF provides flat arrays and methods that plug cleanly into timemachine evaluations, OR (2) OpenFF provides a jax-based method energy(positions, ff_params) that returns energies, solving the nonbonded problem internally or using timemachine on the backend.
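Under the first option (flat arrays plugged into timemachine evaluations), energies and parameter derivatives come from an external engine, and jax only has to chain through them. One way to register externally supplied derivatives is `jax.custom_vjp`. A minimal sketch: `external_energy_and_grad` is a hypothetical stand-in for the engine (timemachine, or the bespoke CUDA nonbonded code YTZ mentions), written in jnp only so the example runs end to end; a real engine would typically be reached through a host callback such as `jax.pure_callback`, which is omitted here.

```python
import jax
import jax.numpy as jnp

def external_energy_and_grad(params, r):
    """Hypothetical stand-in for an engine call.

    Returns U and an analytic dU/dparams for harmonic bonds, with
    params = [k, r0] and r = precomputed bond lengths (illustrative).
    """
    k, r0 = params[0], params[1]
    dr = r - r0
    u = jnp.sum(0.5 * k * dr ** 2)
    du_dparams = jnp.stack([jnp.sum(0.5 * dr ** 2), -jnp.sum(k * dr)])
    return u, du_dparams

@jax.custom_vjp
def engine_energy(params, r):
    u, _ = external_energy_and_grad(params, r)
    return u

def engine_energy_fwd(params, r):
    u, du_dparams = external_energy_and_grad(params, r)
    # Residuals for the backward pass: the engine's own derivative, and r
    # (kept only so a zero cotangent of the right shape can be returned).
    return u, (du_dparams, r)

def engine_energy_bwd(res, u_bar):
    du_dparams, r = res
    # Chain rule: upstream cotangent dL/dU times the engine-supplied dU/dparams.
    # Bond lengths are treated as fixed data, so their cotangent is zero.
    return u_bar * du_dparams, jnp.zeros_like(r)

engine_energy.defvjp(engine_energy_fwd, engine_energy_bwd)

# Usage: jax differentiates a loss through the engine without seeing its internals.
r = jnp.array([1.2, 1.1])
params = jnp.array([100.0, 1.0])
u_ref = 1.0
dl_dparams = jax.grad(lambda p: (engine_energy(p, r) - u_ref) ** 2)(params)
```

The design point is that jax never differentiates the engine’s internals; it only multiplies the engine’s dU/dparams by whatever upstream cotangent the loss produces.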
MT – Are there existing implementations of the second option?
JC – JaxMD and timemachine.
MT – This raises the question of whether we should take ownership of that feature. Thinking ahead, as new popular ML libraries come up, we could host implementations of this in different packages.
JC – We’ll eventually need automatically differentiable energies to replace forcebalance regardless.
YTZ – We have a pipeline that works; it’s not perfect. My question is “at what point do we go back to the newest version of the OpenFF toolkit?”. If OpenFF can provide a differentiable version of the parameters, that lets us largely merge back in. How to replace forcebalance is a bigger question and will require more people.
JC – Agreed; we’re currently making sure we implement this in a future-compatible way. So while we don’t have the manpower now, we want to make the job easier for whoever takes it on.
MT – Let’s focus on “will there need to be huge changes to the core object model to support ML libraries?”
JC – I think the critical question there is “are parameter assignments differentiable?”
YTZ – There’s some duck-typing that happens if parameter assignments are passed in as jax-numpy arrays.
YTZ – To MT’s question: when a future person tries to implement U, they’ll be implementing it as a function of system parameters. So they’ll parameterize the system and…
MT – Good. I interpret that to mean “it won’t require changes to the core object model”.
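JC’s question of whether parameter assignment is differentiable splits in two. The assignment step itself (matching chemistry to parameter types) is discrete and has no useful derivative, but once its result is recorded as integer indices, looking up per-term values is an ordinary array gather, so gradients flow back to the shared FF parameters and accumulate across every term that uses a given type. A minimal sketch, assuming assignment has already produced `bond_type` indices; the table, indices, and bond lengths are illustrative:

```python
import jax
import jax.numpy as jnp

# FF parameter table: one row per bond type, columns [k, r0] (illustrative values).
ff_table = jnp.array([[300.0, 1.5],    # type 0
                      [340.0, 1.1]])   # type 1

# Result of the discrete assignment step (e.g. SMIRKS matching): one type per bond.
bond_type = jnp.array([0, 1, 1, 1])
bond_length = jnp.array([1.6, 1.1, 1.2, 1.0])

def energy(table):
    # Gather per-bond parameters, shape (n_bonds, 2); differentiable w.r.t. table.
    per_bond = table[bond_type]
    k, r0 = per_bond[:, 0], per_bond[:, 1]
    return jnp.sum(0.5 * k * (bond_length - r0) ** 2)

# Gradients from bonds that share a type are summed into that type's row.
du_dtable = jax.grad(energy)(ff_table)
```

Here the three type-1 bonds contribute to a single row of `du_dtable`; their r0 derivatives cancel (deviations 0, +0.1, -0.1) while their k derivatives add.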
JC – Let’s focus on “what do we need to provide via the API?”
YTZ – The main thing we need in order to merge back with the OpenFF toolkit is for the .parameterize methods that MT has been showing to be differentiable through VJPs.
JC – In that workflow, which part needs to provide jax arrays?
YTZ – Currently the parameters aren’t a jax array (or even a numpy array).
JW – Would it be sufficient to provide a
JC – Should Matt, Jeff, and Josh Fass meet on this, since Josh has a foot in each world?
YTZ – JF isn’t yet very experienced with the VJP code. In a month he may be more up to speed.
JC – Ok, it’ll be hard to move forward on this until we
YTZ – The big requirement is that timemachine needs to be able to see
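YTZ’s requirement that the .parameterize methods be differentiable through VJPs can be read as: given dU/d(system parameters) from the engine, return dU/d(FF parameters). A minimal sketch with `jax.vjp`, assuming a hypothetical `parameterize` that maps a flat FF vector to per-bond system parameters through fixed assignment indices; `du_dsys` is a made-up stand-in for the cotangent an engine such as timemachine would supply:

```python
import jax
import jax.numpy as jnp

# Fixed result of parameter assignment: FF row used by each bond (illustrative).
BOND_TYPE = jnp.array([0, 1, 1])

def parameterize(ff_flat):
    """Hypothetical: flat [k0, r0_0, k1, r0_1] -> per-bond system params (n_bonds, 2)."""
    table = ff_flat.reshape(-1, 2)
    return table[BOND_TYPE]

ff_flat = jnp.array([300.0, 1.5, 340.0, 1.1])
sys_params, pullback = jax.vjp(parameterize, ff_flat)

# Stand-in for dU/d(system params) as an external engine might return it.
du_dsys = jnp.array([[0.1, -2.0],
                     [0.2, 1.0],
                     [0.3, 3.0]])

# The VJP maps the engine's cotangent back onto the flat FF vector.
(du_dff,) = pullback(du_dsys)
```

The toolkit side would only need to expose `parameterize` and its pullback; the engine side stays free to compute `du_dsys` however it likes.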
To ensure that YTZ can stop maintaining a fork of the OpenFF toolkit by using the functionality that will soon be available, OpenFF will test its implementation of the System object by using it in place of the handlers in …, and once that is complete, will schedule another meeting with YTZ.