Highlights
- Arctic Code Vault Contributor
502 contributions in the last year
Contribution activity
September 2020
Created a pull request in google/jax that received 11 comments
Add a prototype implementation of recursive checkpointing
Example:
```python
import jax
import jax.numpy as jnp
from jax.experimental.checkpoint import checkpoint_recursive

@checkpoint_recursive
def f(x):
  for i in r…
```
+101 −0 • 11 comments
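The PR's `checkpoint_recursive` prototype is truncated above and may not exist in released JAX. The following is a hedged sketch of the underlying idea using the stable `jax.checkpoint` transform instead. It splits a loop in half, wraps each half in `jax.checkpoint`, and recurses, so the backward pass keeps only the values at half boundaries and recomputes the interiors. The `step` body and the split-in-half strategy are illustrative assumptions, not the PR's method.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical loop body; any differentiable function works here.
    return jnp.sin(x) * 1.1


def run_plain(x, n):
    # Reference: apply `step` n times with ordinary reverse-mode AD,
    # which keeps every intermediate needed for the backward pass.
    for _ in range(n):
        x = step(x)
    return x


def run_recursive(x, n):
    # Apply `step` n times, splitting the loop into two halves and
    # wrapping each half in jax.checkpoint. Intermediates inside a half
    # are not saved for the backward pass; they are recomputed from the
    # half's input, and nested halves repeat this recursively.
    if n == 0:
        return x
    if n == 1:
        return step(x)
    half = n // 2
    first = jax.checkpoint(lambda y: run_recursive(y, half))
    second = jax.checkpoint(lambda y: run_recursive(y, n - half))
    return second(first(x))


x = jnp.float32(0.3)
grad_plain = jax.grad(lambda v: run_plain(v, 8))(x)
grad_ckpt = jax.grad(lambda v: run_recursive(v, 8))(x)
print(grad_plain, grad_ckpt)  # the two gradients should agree
```

The recursion depth is roughly log2(n). The sketch only checks that gradients match. It does not measure memory, so treat it as an illustration of the structure rather than a tuned checkpointing schedule.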
- Support %indexRef with dependent references
- Improve index rematerialization capabilities in AD
- Make it possible to `make install` Dex.
- Improve defunctionalization of case statements
- Stop capturing loop indices when reifying the tangent function
- Differentiation through inactive cases + misc improvements in AD
- Try using GitHub Actions for CI instead of Travis
- Defunctionalize through case expressions
- Another bunch of AD improvements
- Don't linearize inactive expressions
- Flesh out transposition
- Flesh out linearization
- Expand the C API and Python bindings
- A batch of AD fixes and improvements
- A prototype of Python bindings
- Make it possible to use ptxas instead of the CUDA driver for SASS compilation
- Improve DCE and inliner
- Inline table builders (`for`s) with a single use site
- Add support for `all_to_all` in vmap
- Consider lists as groups of axis names too
- Fix the abstract eval and translation rule for all_to_all
- Fix axis_index inside nested pmaps
- Delete dead axis_index code
- Add back the batching rule for ppermute
- Extend axis env while translating the pmapped jaxpr to XLA
- Interrupt `lu` transformation generators whenever an exception occurs
- Add more context to the axis_frame error message.
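Several of the JAX entries above concern giving collectives such as `ppermute`, `all_to_all` and `axis_index` a meaning under `vmap`. As a minimal sketch, this assumes a recent JAX release in which `jax.vmap(..., axis_name=...)` supports these collectives. The mapped axis then behaves like a set of virtual devices: `axis_index` reports each element's position, and `ppermute` moves values between positions. The ring-shift permutation and the axis size `N` are illustrative choices, not taken from the PRs.

```python
import jax
import jax.numpy as jnp

N = 4  # size of the vmapped axis, i.e. the number of "virtual devices"


def ring_shift(x):
    # Send each element's value to the next position around the ring.
    # A (source, destination) pair (i, i + 1) means position i + 1
    # receives the value held at position i.
    perm = [(i, (i + 1) % N) for i in range(N)]
    return jax.lax.ppermute(x, axis_name="i", perm=perm)


def position(x):
    # Index of the current element along the named axis "i";
    # the argument only fixes the size of the mapped axis.
    return jax.lax.axis_index("i")


xs = jnp.arange(N, dtype=jnp.float32) * 10.0  # [0., 10., 20., 30.]
shifted = jax.vmap(ring_shift, axis_name="i")(xs)
positions = jax.vmap(position, axis_name="i")(xs)
print(shifted)
print(positions)
```

Under `pmap` the same functions would move data between physical devices. Under `vmap` the batching rules implement the same semantics as array operations on the batch dimension.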