scipy.stats.mode #8978

Open
gianlucadetommaso opened this issue Dec 16, 2021 · 11 comments
Labels
contributions welcome enhancement good first issue

Comments

@gianlucadetommaso commented Dec 16, 2021

It would be great if you could add the JAX equivalent of scipy.stats.mode, which is currently unavailable.

One use case is an ML classification task with an ensemble, where different models may give different predictions and we want the most common one.

As an example, consider predictions, a two-dimensional DeviceArray with
shape = (number of models, number of data points).

It would be nice if one could compute

mode_prediction = jax.scipy.stats.mode(predictions, axis=0),

returning the most common prediction for every data point.
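For the ensemble use case specifically, where the predictions are integer class labels, a simple jit-friendly alternative is to count votes with a one-hot encoding. This is a sketch under stated assumptions: labels are assumed to lie in [0, num_classes), num_classes is assumed known ahead of time, and majority_vote is a hypothetical helper name, not a JAX API. Ties go to the smallest label because jnp.argmax returns the first maximum, which is also how scipy.stats.mode breaks ties.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper (not part of JAX): per-column majority vote over
# integer class labels, assuming every label is in [0, num_classes).
def majority_vote(predictions, num_classes):
  # (num_models, num_points) -> (num_models, num_points, num_classes).
  # Labels outside [0, num_classes) become all-zero rows and get no vote.
  one_hot = jax.nn.one_hot(predictions, num_classes, dtype=jnp.int32)
  # votes[p, c] = number of models that predicted class c for data point p.
  votes = one_hot.sum(axis=0)
  # argmax returns the first maximum, so ties resolve to the smallest label.
  return jnp.argmax(votes, axis=-1)


# num_classes determines an intermediate shape, so it must be static under jit.
majority_vote_jit = jax.jit(majority_vote, static_argnames="num_classes")

predictions = jnp.array([[0, 1, 2, 1],
                         [0, 2, 2, 1],
                         [1, 2, 0, 0]])  # 3 models, 4 data points
print(majority_vote_jit(predictions, num_classes=3))
# [0 2 2 1]
```

The one-hot intermediate has shape (num_models, num_points, num_classes), so this trades memory for simplicity; a jnp.unique-based mode avoids needing num_classes at all.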

@gianlucadetommaso added the enhancement label Dec 16, 2021

@jakevdp (Collaborator) commented Dec 16, 2021

Here's a quick jit-compatible implementation if you want to make use of it (not thoroughly tested):

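import jax.numpy as jnp
from jax import vmap

def mode(x, axis=None):
  def _mode(x):
    # size=x.size keeps the output shape static so this works under jit;
    # padded entries get a count of 0, so they never win the argmax.
    vals, counts = jnp.unique(x, return_counts=True, size=x.size)
    # argmax returns the first maximum, i.e. the smallest value among ties.
    return vals[jnp.argmax(counts)]
  if axis is None:
    # Mode of the whole (flattened) array.
    return _mode(x)
  else:
    # Move the reduction axis to the front, flatten the remaining axes
    # into columns, and vmap _mode over the columns.
    x = jnp.moveaxis(x, axis, 0)
    return vmap(_mode, in_axes=(1,))(x.reshape(x.shape[0], -1)).reshape(x.shape[1:])

x = jnp.array([1, 1, 2, 2, 2, 3, 4])
print(mode(x))
# 2

y = jnp.array([[1, 1, 2, 3],
               [4, 5, 6, 5]])
print(mode(y, axis=1))
# [1 5]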

@gianlucadetommaso (Author) commented Dec 16, 2021

Great, thanks!

@jakevdp added the contributions welcome and good first issue labels Dec 16, 2021

@jakevdp (Collaborator) commented Dec 16, 2021

It's definitely in scope to add this to jax.scipy.stats. If anyone wants to tackle it, you could take the example implementation from above, modify it so that it returns the same outputs and shapes as scipy.stats.mode (I think scipy.stats.mode returns the counts as well), and add an appropriate test in tests/scipy_stats_test.py.
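Along the lines suggested above, here is a hedged sketch that also returns the counts. It assumes SciPy's default axis=0, that axis=None means "compute over the flattened array", and the keepdims=False output shape that recent SciPy versions default to (older versions kept the reduced axis). ModeResult and mode_with_counts are hypothetical names chosen to mirror SciPy's (mode, count) result, not an existing JAX API.

```python
import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ModeResult(NamedTuple):
  # Hypothetical result type mirroring the (mode, count) pair SciPy returns.
  mode: jax.Array
  count: jax.Array


@functools.partial(jax.jit, static_argnames="axis")
def mode_with_counts(x, axis=0):
  x = jnp.asarray(x)
  if axis is None:
    # axis=None: compute the mode of the flattened array.
    x = x.ravel()
    axis = 0

  def _mode_1d(v):
    # size=v.size keeps shapes static; padded entries have a count of 0.
    vals, counts = jnp.unique(v, return_counts=True, size=v.size)
    i = jnp.argmax(counts)  # first maximum -> smallest value among ties
    return vals[i], counts[i]

  # Move the reduction axis to the front and flatten the rest into columns.
  moved = jnp.moveaxis(x, axis, 0)
  columns = moved.reshape(moved.shape[0], -1)
  modes, counts = jax.vmap(_mode_1d, in_axes=1)(columns)
  # Drop the reduced axis (keepdims=False semantics).
  out_shape = moved.shape[1:]
  return ModeResult(modes.reshape(out_shape), counts.reshape(out_shape))


y = jnp.array([[1, 1, 2, 3],
               [4, 5, 6, 5]])
res = mode_with_counts(y, axis=1)
print(res.mode, res.count)
# [1 5] [2 2]
```

A test in tests/scipy_stats_test.py could then compare both fields against scipy.stats.mode on random integer inputs for several axis values.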

@jayendra13 commented Dec 24, 2021

Let me take this up. Can anyone suggest a good location/file to put this function in?
For SciPy, it's defined here and here.

For now, I'm starting by putting it inside __init__.py.

@Yash621 commented Jan 3, 2022

@gianlucadetommaso can I work on this?

@gianlucadetommaso (Author) commented Jan 3, 2022

@Yash621 please go ahead, I haven't had much time to take this up myself.

@Yash621 commented Jan 4, 2022

@gianlucadetommaso can you suggest where I should place this function?

@jakevdp (Collaborator) commented Jan 4, 2022

You could put the implementation in a new submodule, maybe something like jax/_src/scipy/stats/_core.py.

@elusenji commented Apr 17, 2022

I'd love to jump in and collaborate with @Yash621 on this, if that's okay.

@parthpatel9414 commented Apr 17, 2022

I'd love to collaborate with @Yash621 and @elusenji if this issue is still open.

@jakevdp (Collaborator) commented Apr 17, 2022

I think the PR in #9141 is quite close; it just needs a bit more work before it can be merged.


6 participants