Highlights
- Arctic Code Vault Contributor
1,590 contributions in the last year
Contribution activity
July 2020
Created a pull request in google/jax that received 8 comments
Properly set X64 flag in github actions
Trying to confirm my suspicions that our github CI is not actually running any x64 tests... Edit: turns out it isn't (see discussion below). The fi…
+4 −4 • 8 comments
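The PR above reports that the CI job was not actually running x64 tests even though it was meant to. Below is a minimal sketch of a guard that would make that kind of misconfiguration fail loudly, assuming the mode is driven by the `JAX_ENABLE_X64` environment variable. The helpers `x64_enabled`, `default_dtypes` and `check_ci_mode` are hypothetical illustrations, not part of JAX or of the PR, and the set of accepted true/false spellings is an assumption.

```python
import os
from typing import Mapping, Tuple

# Accepted spellings are an assumption for this sketch, not JAX's exact parser.
_TRUE = {"1", "true", "t", "yes", "y", "on"}
_FALSE = {"0", "false", "f", "no", "n", "off", ""}


def x64_enabled(env: Mapping[str, str] = os.environ) -> bool:
    """Hypothetical helper: interpret JAX_ENABLE_X64 from an environment mapping."""
    raw = env.get("JAX_ENABLE_X64", "").strip().lower()
    if raw in _TRUE:
        return True
    if raw in _FALSE:
        return False
    raise ValueError(f"unrecognized JAX_ENABLE_X64 value: {raw!r}")


def default_dtypes(x64: bool) -> Tuple[str, str]:
    """Default (int, float) dtype names for Python scalars in each mode."""
    return ("int64", "float64") if x64 else ("int32", "float32")


def check_ci_mode(expect_x64: bool, env: Mapping[str, str] = os.environ) -> None:
    """Raise if the job is not running in the mode it claims to test."""
    actual = x64_enabled(env)
    if actual != expect_x64:
        raise RuntimeError(
            f"expected x64={expect_x64}, but JAX_ENABLE_X64 gives x64={actual}; "
            f"default dtypes would be {default_dtypes(actual)}"
        )


# An x64 test job that forgot to set the variable fails immediately
# instead of silently running 32-bit tests.
check_ci_mode(True, {"JAX_ENABLE_X64": "1"})  # passes
```

Passing the environment in as a mapping keeps the check testable without mutating `os.environ`.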
Other pull requests opened
- WIP: exploring weakly-dtyped device arrays
- Improve dtype & test coverage for jnp.fmod
- Cleanup: pass function name rather than function object
- Add jnp.modf(); improve dtype & test coverage for related functions
- Cleanup: canonicalize several dtypes to prevent noisy warnings
- Improve searchsorted implementation
- jnp.linspace & friends: more carefully handle dtypes
- Cleanup test names in random_test.py
- jax.random: use correct x32/x64 default dtypes.
- Tweak Dockerfile to prevent build failure and add TODO
- Improve test coverage for jax.numpy sorting algorithms
- update README for new jaxlib version
- update jaxlib version and changelog for pypi
- update version and changelog for pypi
- Cleanup: fix type issues in lax_numpy.py
- implement jax.numpy.lexsort
- Fix type mismatch in jet rule for abs
- [x64 deprecation] Create _np_array utility routine
- WIP: x64: make Jax default dtypes 32-bit
- Change onp/np to np/jnp in docs & notebooks
- Cleanup: convert uses of 'import numpy as onp' in tests
- Cleanup: convert uses of `import numpy as onp` in library code
- lax_numpy: rename arguments to match numpy
- Fix compilation bug in histogram_bin_edges
- Remove unused private helper function
- Some pull requests not shown.
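The entry "implement jax.numpy.lexsort" names an API without describing it. As a reference for the NumPy semantics that `jax.numpy` functions generally mirror, here is a pure-Python sketch of `numpy.lexsort` ordering: the last key is the primary sort key and earlier keys break ties. The function name `lexsort_indices` is hypothetical, and this is not how the JAX implementation works.

```python
from typing import List, Sequence


def lexsort_indices(keys: Sequence[Sequence[int]]) -> List[int]:
    """Return indices that sort positions lexicographically, lexsort style.

    The LAST key in `keys` is the primary key; earlier keys break ties.
    Python's sort is stable, so positions equal on every key keep input order.
    """
    if not keys:
        raise ValueError("need at least one key")
    n = len(keys[0])
    if any(len(k) != n for k in keys):
        raise ValueError("all keys must have the same length")
    # Build one tuple per position with the primary (last) key first.
    return sorted(range(n), key=lambda i: tuple(k[i] for k in reversed(keys)))


a = [1, 5, 1, 4, 3, 4, 4]  # primary key
b = [9, 4, 0, 4, 0, 2, 1]  # secondary key
print(lexsort_indices((b, a)))  # [2, 0, 4, 6, 5, 3, 1]
```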
Pull request reviews
- Add jnp.modf(); improve dtype & test coverage for related functions
- refine population_count type check
- Replicating sort_complex functionality from np.sort_complex to jax.numpy
- Add erfcx
- jax.random: use correct x32/x64 default dtypes.
- Enable int{8,16} and uint{8,16} tests in lax_test and lax_numpy_test.
- DeviceArray.__iter__ returns DeviceArrays, without host sync
- Address nan issue in zeta
- tweak jnp.repeat not to use jnp on shapes
- Fix `jax.image._resize` function
- add clarification about jit inside indexing error message
- Error message and docstring updates RE: dynamic_slice
- Note in pmap docs that pmap compiles like jit.
- Don't move batch dimensions to start in jnp.einsum.
- Relax test tolerance for conv_general_dilated gradients.
- Use keyword arguments in einsum.
- Update scipy.ndimage.map_coordinates docstring
- Fix bug in jnp.repeat where, for array repeats with zero entries (e.g. [1,0,3]) and axis=0, the zero entry was not skipped.
- Fix low probability (4/1000) flake of batching_test on GPU.
- A compilable version of np.repeat.
- Make jnp.take work for empty slices of empty arrays.
- Adding np.sort_complex within lax_numpy.py
- Add pshuffle to docs
- Improve np.intersect1d
- Implement complex convolutions on CPU and GPU.
- Some pull request reviews not shown.
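Two of the reviewed PRs concern `jnp.repeat`: a bug where zero entries in an array of repeats were not skipped, and a compilable version of `np.repeat`. The sketch below shows the expected semantics for a 1-D input, plus a gather-style formulation built from cumulative sums and a binary search, in which zero repeats are skipped naturally. The output length is `sum(repeats)`, which under `jit` has to be known ahead of time; this is why `jnp.repeat` accepts a `total_repeat_length` argument. The helper names are hypothetical, and the gather approach is one possible formulation, not necessarily the one used in those PRs.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def _check(values: Sequence[T], repeats: Sequence[int]) -> None:
    if len(values) != len(repeats):
        raise ValueError("values and repeats must have the same length")
    if any(r < 0 for r in repeats):
        raise ValueError("repeats must be non-negative")


def repeat_naive(values: Sequence[T], repeats: Sequence[int]) -> List[T]:
    """Reference semantics of repeat along axis 0 for a 1-D sequence."""
    _check(values, repeats)
    return [v for v, r in zip(values, repeats) for _ in range(r)]


def repeat_by_index(values: Sequence[T], repeats: Sequence[int]) -> List[T]:
    """Same result as a gather: output slot j reads values[src(j)].

    ends[i] is the exclusive end of block i in the output, so slot j belongs
    to the first block whose end exceeds j, i.e. bisect_right(ends, j).
    A zero repeat gives an empty block whose end equals the previous end,
    so bisect_right steps past it and that entry never appears.
    """
    _check(values, repeats)
    ends = list(accumulate(repeats))
    total = ends[-1] if ends else 0  # output length == sum(repeats)
    return [values[bisect_right(ends, j)] for j in range(total)]


print(repeat_naive(["a", "b", "c"], [1, 0, 3]))     # ['a', 'c', 'c', 'c']
print(repeat_by_index(["a", "b", "c"], [1, 0, 3]))  # ['a', 'c', 'c', 'c']
```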
Created an issue in google/jax that received 2 comments
lax silently zeros large integers outside x64 mode
Repro, with jax_enable_x64=False:
```python
>>> from jax import lax
>>> lax.neg(1 << 32)
DeviceArray(0, dtype=int32)
```
If the input is too large for int64, it …
2 comments
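The repro output is consistent with the input being truncated to its low 32 bits before negation: 2**32 has no bits set below bit 32, so it becomes 0, and the negation of 0 is 0. Below is a minimal pure-Python sketch of that two's-complement wraparound, alongside a hypothetical checked conversion that raises instead of silently producing 0. The wraparound model is an assumption inferred from the repro, not taken from JAX's source, and both helper names are illustrative.

```python
INT32_MIN = -(1 << 31)
INT32_MAX = (1 << 31) - 1


def wrap_to_int32(x: int) -> int:
    """Keep the low 32 bits of x and reinterpret them as a signed int32.

    Models two's-complement truncation; an assumption about the non-x64
    conversion in the repro, not JAX's actual code path.
    """
    low = x & 0xFFFFFFFF  # low 32 bits as an unsigned value
    return low - (1 << 32) if low > INT32_MAX else low


def to_int32_checked(x: int) -> int:
    """Hypothetical stricter conversion: refuse values int32 cannot hold."""
    if not INT32_MIN <= x <= INT32_MAX:
        raise OverflowError(f"{x} does not fit in int32")
    return x


x = 1 << 32
print(wrap_to_int32(x))   # 0: every low 32 bit of 2**32 is zero
print(-wrap_to_int32(x))  # 0, matching lax.neg(1 << 32) in the repro
```

Under this model, `(1 << 32) + 5` would silently become `5`, which is why raising (as `to_int32_checked` does) is the less surprising behavior for out-of-range inputs.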