New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Chebyshev class and functions #11093
base: main
Are you sure you want to change the base?
Conversation
|
Hi - thanks for the PR! In the past when we've discussed the polynomial interface, we've made the choice to avoid adding polynomial class types in favor of the more direct functional API (e.g. #70 (comment)). One concern I have here is that when you start to define class-based interfaces, you have to be much more careful to ensure that they're compatible with JAX's transformation model. Is there any functionality here that cannot be expressed directly via chebyshev functions, as opposed to a full chebyshev class? |
|
I believe the Chebyshev functions capture all of the functionality, so getting rid of the full class shouldn't cause any problems and would honestly simplify the interface. Would you like me to remove it? |
|
I would lean toward keeping only the simplest functional APIs in JAX, partly because it makes long-term maintenance more sustainable (the class definition in this PR is a gigantic API surface to test and maintain). But I'm happy to let other team members weigh-in if they disagree. |
Remove Chebyshev class Implement Chebyshev functions
|
Alright, I removed the Chebyshev class and its abstract base class to streamline the API. If it's decided later that they are worth including, I have local copies and can add them back in. |
OK, a bunch of comments below. First-off, there is a lot in here that is incompatible or inefficent with JAX's computation model (for example, loop-based implementations, object arrays, value-dependent shapes)
Before moving forward with this, we should probably think about whether those pieces of the implementation are important; if so, I'd probably lean toward not including these APIs in JAX, because the user experience will not be good (incompatibility with JAX transforms is not great for JAX library code).
On the other hand, if the APIs can be changed in a way that they will become compatible with JAX, then we should think about making those changes before diving-in to a review at the level of granularity of my comments below.
What do you think?
| i = 0 | ||
| j = dlen | ||
|
|
||
| while i < j: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How long are typical series? If it's more than a few, we should avoid using the Python while loop here because we could end up with very long compile times.
| c0 = jnp.asarray(c[-2]) | ||
| c1 = jnp.asarray(c[-1]) | ||
|
|
||
| for i in range(n-1, 1, -1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comment here about Python control flow when the size gets large. We may want to use lax.fori_loop instead unless n is bounded by a small number.
| [pol] = pu.as_series([pol]) | ||
| deg = len(pol) - 1 | ||
| res = 0 | ||
| for i in range(deg, -1, -1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also here: consider fori_loop
|
Alright, I've cleaned up a lot of the code following your suggestions and to better match the existing polynomial functions in JAX. I've also been thinking more about JAX compatibility, and I wanted to ask a couple questions. The object arrays and value-based shapes were unnecessary, but loop implementations are necessary (or at least the natural way to implement many functions) because of the relationships between terms in a Chebyshev series. So, what do you consider "large" for a series or loop parameter? If a typical series is large, would replacing Python control flow with JAX control flow primitives resolve the incompatibility/inefficiency issues? If so, I will make those changes and then push everything. |
|
The issue with Python loops is that they are flattened by JAX tracing, and then passed in an unrolled state to the XLA compiler. The mechanics of the compiler are complicated, but a good rule of thumb is that compilation time will be roughly quadratic in the number of operations. So if you write a Python |
|
Yes, that's very helpful, thank you! |
Adds a Chebyshev convenience class and cheb* functions that act on arrays, as described in #11055.