Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Chebyshev class and functions #11093

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

gderossi
Copy link

@gderossi gderossi commented Jun 14, 2022

Adds a Chebyshev convenience class and cheb* functions that act on arrays, as described in #11055.

@jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Jun 14, 2022

Hi - thanks for the PR! In the past when we've discussed the polynomial interface, we've made the choice to avoid adding polynomial class types in favor of the more direct functional API (e.g. #70 (comment)). One concern I have here is that when you start to define class-based interfaces, you have to be much more careful to ensure that they're compatible with JAX's transformation model.

Is there any functionality here that cannot be expressed directly via chebyshev functions, as opposed to a full chebyshev class?

@gderossi
Copy link
Author

@gderossi gderossi commented Jun 14, 2022

I believe the Chebyshev functions capture all of the functionality, so getting rid of the full class shouldn't cause any problems and would honestly simplify the interface. Would you like me to remove it?

@jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Jun 14, 2022

I would lean toward keeping only the simplest functional APIs in JAX, partly because it makes long-term maintenance more sustainable (the class definition in this PR is a gigantic API surface to test and maintain). But I'm happy to let other team members weigh-in if they disagree.

Remove Chebyshev class

Implement Chebyshev functions
@gderossi
Copy link
Author

@gderossi gderossi commented Jun 14, 2022

Alright, I removed the Chebyshev class and its abstract base class to streamline the API. If it's decided later that they are worth including, I have local copies and can add them back in.

Copy link
Collaborator

@jakevdp jakevdp left a comment

OK, a bunch of comments below. First-off, there is a lot in here that is incompatible or inefficent with JAX's computation model (for example, loop-based implementations, object arrays, value-dependent shapes)

Before moving forward with this, we should probably think about whether those pieces of the implementation are important; if so, I'd probably lean toward not including these APIs in JAX, because the user experience will not be good (incompatibility with JAX transforms is not great for JAX library code).

On the other hand, if the APIs can be changed in a way that they will become compatible with JAX, then we should think about making those changes before diving-in to a review at the level of granularity of my comments below.

What do you think?

i = 0
j = dlen

while i < j:
Copy link
Collaborator

@jakevdp jakevdp Jun 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How long are typical series? If it's more than a few, we should avoid using the Python while loop here because we could end up with very long compile times.

jax/_src/numpy/polynomials/chebyshev.py Show resolved Hide resolved
c0 = jnp.asarray(c[-2])
c1 = jnp.asarray(c[-1])

for i in range(n-1, 1, -1):
Copy link
Collaborator

@jakevdp jakevdp Jun 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here about Python control flow when the size gets large. We may want to use lax.fori_loop instead unless n is bounded by a small number.

jax/_src/numpy/polynomials/chebyshev.py Show resolved Hide resolved
[pol] = pu.as_series([pol])
deg = len(pol) - 1
res = 0
for i in range(deg, -1, -1):
Copy link
Collaborator

@jakevdp jakevdp Jun 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here: consider fori_loop

jax/_src/numpy/polynomials/chebyshev.py Show resolved Hide resolved
jax/_src/numpy/polynomials/chebyshev.py Show resolved Hide resolved
jax/_src/numpy/polynomials/polyutils.py Show resolved Hide resolved
jax/_src/numpy/polynomials/polyutils.py Show resolved Hide resolved
jax/_src/numpy/polynomials/polyutils.py Show resolved Hide resolved
@gderossi
Copy link
Author

@gderossi gderossi commented Jun 16, 2022

Alright, I've cleaned up a lot of the code following your suggestions and to better match the existing polynomial functions in JAX. I've also been thinking more about JAX compatibility, and I wanted to ask a couple questions. The object arrays and value-based shapes were unnecessary, but loop implementations are necessary (or at least the natural way to implement many functions) because of the relationships between terms in a Chebyshev series.

So, what do you consider "large" for a series or loop parameter? If a typical series is large, would replacing Python control flow with JAX control flow primitives resolve the incompatibility/inefficiency issues? If so, I will make those changes and then push everything.

@jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Jun 16, 2022

The issue with Python loops is that they are flattened by JAX tracing, and then passed in an unrolled state to the XLA compiler. The mechanics of the compiler are complicated, but a good rule of thumb is that compilation time will be roughly quadratic in the number of operations. So if you write a Python for loop over 100 elements, then compilation will be roughly 10000 times more costly than the equivalent fori_loop implementation. There's not really any hard line regarding what size the loops should be, but if you expect the functions to operate on larger inputs, it's something to keep in mind. Does that make sense?

@gderossi
Copy link
Author

@gderossi gderossi commented Jun 16, 2022

Yes, that's very helpful, thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants