I have a NumPy array with shape (x, y, z) = (5, 50, 4). For every (x, y) pair, I want to find the index of the maximum value along the z axis; this index is in range(4). I want to set all of these "maximum" elements to 1 and all other elements to 0.
To explain it another way, I want to look at all vectors in the "z" direction (there are x*y of these vectors total). I want to set the maximum element to 1 and all other elements to 0. For example, the vector (0.25, 0.1, 0.5, 0.15) will become (0, 0, 1, 0).
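One way to make the "x*y vectors" view concrete is to reshape the array to 2-D, so that each row is one z-vector, and then set exactly one entry per row with integer-array indexing. This is a minimal sketch, assuming z is the last axis; the function name `one_hot_rows` is hypothetical:

```python
import numpy as np

def one_hot_rows(arr):
    """Set the max of each last-axis vector to 1 and everything else to 0."""
    z = arr.shape[-1]
    flat = arr.reshape(-1, z)               # shape (x*y, z): one row per z-vector
    idx = flat.argmax(axis=1)               # position of the max within each row
    out = np.zeros_like(flat)
    out[np.arange(flat.shape[0]), idx] = 1  # one (row, column) pair per vector
    return out.reshape(arr.shape)           # back to (x, y, z)

v = np.array([[[0.25, 0.1, 0.5, 0.15]]])    # the example vector, shape (1, 1, 4)
print(one_hot_rows(v))                      # [[[0. 0. 1. 0.]]]
```

Because `argmax` returns the first maximum, a vector with tied maxima gets a single 1 at the earliest tied position.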
I've tried many different things. The argmax function seems like it should help. But how do I use it to select elements correctly? I have tried things like...
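```python
import numpy as np

data = np.random.random((5, 50, 4))  # stand-in for my real data
x = data
i = x.argmax(axis=2)  # shape = (5, 50)
x[i]  # shape = (5, 50, 50, 4)
x[:, :, i]  # shape = (5, 50, 5, 50)
x[np.unravel_index(i, x.shape)]  # shape = (5, 50)
```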
The last one, which uses np.unravel_index, has the correct shape, but the selected elements are NOT the maximum values along the z axis. So I'm stuck. If anyone could help at all, it would be really awesome. Thanks!
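Why the last attempt goes wrong: `argmax(axis=2)` returns a position *within* each z-vector (0 to 3), whereas `np.unravel_index(i, x.shape)` treats each value as a flat index into the whole array, so every selected element comes from `x[0, 0, :]`. To select along the z axis, the index array has to be paired with the matching x and y positions. A sketch, assuming NumPy 1.15 or later for `np.take_along_axis`, alongside an equivalent version with explicit index grids:

```python
import numpy as np

x = np.random.random((5, 50, 4))
i = x.argmax(axis=2)                          # shape (5, 50), values in range(4)

# Option 1: take_along_axis wants an index array with the same ndim as x
maxima = np.take_along_axis(x, i[..., np.newaxis], axis=2)[..., 0]  # shape (5, 50)

# Option 2: open index grids for the x and y axes, broadcast against i
xx, yy = np.ogrid[:x.shape[0], :x.shape[1]]   # shapes (5, 1) and (1, 50)
maxima2 = x[xx, yy, i]                        # shape (5, 50)

# Both agree with the per-vector maximum
print(np.array_equal(maxima, x.max(axis=2)), np.array_equal(maxima2, x.max(axis=2)))  # True True
```

The same `(xx, yy, i)` index tuple also works on the left-hand side of an assignment, e.g. `out[xx, yy, i] = 1` on a zero array.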
Edit: Here is one way I have found to do this, but if anyone has something faster, please let me know!
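```python
import numpy as np

def fix_vector(a):
    # Return a new 1-D array: 1 at the argmax of a, 0 everywhere else
    i = a.argmax()
    a = np.zeros_like(a)
    a[i] = 1
    return a

# Apply fix_vector to every z-vector (axis 2) of x
y = np.apply_along_axis(fix_vector, axis=2, arr=x)
```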
I would really like to optimize this if possible, since I call this function MANY times.
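`np.apply_along_axis` calls the Python function once per z-vector, which is what makes it slow when it runs many times. A vectorized sketch writes the ones straight into a zero array with `np.put_along_axis` (NumPy 1.15 or later). This is not necessarily the solution DSM posted, and the name `one_hot_max` is hypothetical; because it uses `argmax`, ties go to the first maximum, just as in `fix_vector`:

```python
import numpy as np

def one_hot_max(x, axis=2):
    """Vectorized fix_vector: 1 at each argmax along `axis`, 0 elsewhere."""
    i = np.expand_dims(x.argmax(axis=axis), axis)  # keep the reduced axis as length 1
    out = np.zeros_like(x)
    np.put_along_axis(out, i, 1, axis=axis)        # write a 1 at each max, in place
    return out

x = np.array([[[0.25, 0.1, 0.5, 0.15],
               [0.9, 0.05, 0.03, 0.02]]])          # shape (1, 2, 4)
print(one_hot_max(x))
# [[[0. 0. 1. 0.]
#   [1. 0. 0. 0.]]]
```

A shorter-looking alternative, `(x == x.max(axis=2, keepdims=True)).astype(x.dtype)`, differs on ties: it puts a 1 at every tied maximum instead of only the first.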
Edit: Thanks DSM for a nice solution. Here is a small example dataset, as requested in the comments.
```python
data = np.random.random((3, 5, 4))
desired_output = np.apply_along_axis(fix_vector, axis=2, arr=data)
```
This uses the fix_vector function I posted above, but DSM's solution is faster. Thanks again!