scipy.stats.mode #8978
Here's a quick jit-compatible implementation if you want to make use of it (not thoroughly tested):

```python
import jax.numpy as jnp
from jax import vmap


def mode(x, axis=None):
    def _mode(x):
        # A static `size` keeps the output shape fixed, which jit requires.
        vals, counts = jnp.unique(x, return_counts=True, size=x.size)
        # vals is sorted, and argmax returns the first maximum.
        return vals[jnp.argmax(counts)]

    if axis is None:
        return _mode(x)
    else:
        # Move the reduction axis to the front, flatten the rest,
        # and map _mode over each column.
        x = jnp.moveaxis(x, axis, 0)
        return vmap(_mode, in_axes=(1,))(x.reshape(x.shape[0], -1)).reshape(x.shape[1:])


x = jnp.array([1, 1, 2, 2, 2, 3, 4])
print(mode(x))
# 2

y = jnp.array([[1, 1, 2, 3],
               [4, 5, 6, 5]])
print(mode(y, axis=1))
# [1 5]
```
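SciPy's `scipy.stats.mode` returns the modal value together with its count. A minimal sketch, assuming one wants that same pair from the snippet above and wants to check the jit-compatibility claim directly: the function is wrapped in `jax.jit` with `axis` marked static, since it determines array shapes. The name `mode_with_count` is hypothetical, not a JAX API.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="axis")
def mode_with_count(x, axis=None):
    """Return (most common value, its count), loosely mirroring SciPy's pair."""

    def _mode(v):
        # A static `size` keeps output shapes fixed, which jit requires.
        vals, counts = jnp.unique(v, return_counts=True, size=v.size)
        # vals is sorted and argmax picks the first maximum,
        # so ties resolve to the smallest value.
        i = jnp.argmax(counts)
        return vals[i], counts[i]

    if axis is None:
        return _mode(x.ravel())
    x = jnp.moveaxis(x, axis, 0)
    cols = x.reshape(x.shape[0], -1)
    # vmap over columns; each call returns a (value, count) pair.
    vals, counts = jax.vmap(_mode, in_axes=1)(cols)
    return vals.reshape(x.shape[1:]), counts.reshape(x.shape[1:])


y = jnp.array([[1, 1, 2, 3],
               [4, 5, 6, 5]])
m, c = mode_with_count(y, axis=1)
print(m, c)
```

Because `axis` is static, each distinct `axis` value triggers a separate compilation; the array contents themselves never need to be concrete.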
Great, thanks!

It's definitely in scope to add this to

@gianlucadetommaso can I work on this?

@Yash621 please go ahead. I haven't had much time to take this up myself.

@gianlucadetommaso can you suggest where I should place this function?

You could put the implementation in a new submodule, maybe something like

I'd love to jump in and collaborate with @Yash621 on this, if that's okay.

I think the PR in #9141 is quite close; it just needs a bit more work before it can be merged.
It would be great if you could add the JAX equivalent of `scipy.stats.mode`, which is currently unavailable. A use case may be an ML classification task with ensembles, where multiple models give different predictions and we are interested in finding the most common one.

As an example, consider `predictions` to be a two-dimensional DeviceArray of predictions, with `shape = (number of models, number of data points)`. It would be nice if one could compute `mode_prediction = jax.scipy.stats.mode(predictions, axis=0)`, returning the most common prediction for every data point.
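For the ensemble case specifically, where predictions are integer class labels, a mode along axis 0 can also be computed by vote counting instead of `jnp.unique`. A minimal sketch, assuming labels lie in `[0, num_classes)` and `num_classes` is known ahead of time (it must be static under `jit`); `majority_vote` is a hypothetical helper name, not a JAX API, and one-hot counting is a swapped-in technique rather than how a general `mode` would work.

```python
import jax
import jax.numpy as jnp


def majority_vote(predictions, num_classes):
    """Most common label per data point (column) of an integer array.

    Assumes `predictions` has shape (num_models, num_points) and holds
    integer class labels in [0, num_classes).
    """
    # One-hot encode: shape (num_models, num_points, num_classes).
    votes = jax.nn.one_hot(predictions, num_classes, dtype=jnp.int32)
    # Count votes per class for each data point: (num_points, num_classes).
    counts = votes.sum(axis=0)
    # argmax returns the first maximum, so ties go to the smallest label.
    return jnp.argmax(counts, axis=-1)


majority_vote_jit = jax.jit(majority_vote, static_argnames="num_classes")

# 3 models, 4 data points.
predictions = jnp.array([[0, 1, 2, 2],
                         [0, 2, 2, 1],
                         [1, 1, 0, 1]])
print(majority_vote_jit(predictions, num_classes=3))
# [0 1 2 1]
```

The trade-off is memory: the one-hot tensor has `num_models * num_points * num_classes` entries, which is cheap for a handful of classes but wasteful when labels are sparse or unbounded, where the `jnp.unique`-based approach is the better fit.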