
> JAX can automatically differentiate native Python ... functions

SymPy can differentiate functions, but they have to be set up properly as symbolic expressions. How can JAX differentiate native functions?

(Or do they mean numerical differentiation, like a finite difference estimation?)



Automatic differentiation (AD) is an algorithm for efficiently computing the value of the derivative of a function implemented by arbitrary code. It does not use numerical approximation; it combines algorithmic cleverness with a table of analytic derivatives for elementary functions. Despite relying on analytic derivatives, it does not compute the analytic derivative of the function; it only computes the _value_ of the derivative at a particular input.
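A minimal sketch of forward-mode AD with dual numbers, assuming only `+` and `*` are needed. The `Dual` class, `g` and `derivative_at` are hypothetical names for illustration, not JAX's API. Each number carries its value together with the value of its derivative, so what comes out is a number, never a formula:

```python
class Dual:
    """A number paired with the value of its derivative w.r.t. the input."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    @staticmethod
    def lift(other):
        # Plain constants have derivative 0.
        return other if isinstance(other, Dual) else Dual(other, 0.0)

    def __add__(self, other):
        other = Dual.lift(other)
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (u * v)' = u' * v + u * v'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def g(x):
    # Ordinary Python code with no symbolic setup.
    return 3 * x * x + 2 * x + 1


def derivative_at(func, x):
    # Seed the input with dx/dx = 1 and read off the output's derivative.
    return func(Dual(x, 1.0)).deriv


print(derivative_at(g, 2.0))  # 14.0, i.e. 6*2 + 2; no expression for g' is ever built
```

The "table of analytic derivatives" here is just the sum and product rules inside `__add__` and `__mul__`; a real system has many more entries.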

In order to differentiate native functions, you have to accept a fairly loose definition of "derivative" that essentially ignores the points where a function is discontinuous or not differentiable. For example, for the function below we can say that the derivative is piecewise constant: 0 for all x < 0 and 1 for all x >= 0, even though strictly it is undefined at x = 0. AD extends this idea to "each output produced by a function follows a single path, so the corresponding derivative follows the same path".

```python
def f(x):
    if x >= 0:
        return x
    else:
        return 0
```
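To see how AD treats the branch, here is a hedged sketch using a bare-bones dual number (`Dual` and `derivative_at` are hypothetical helpers, not any library's API). Only the comparison is overloaded, since that is all `f` needs. At x = 0 the `x >= 0` branch is taken, so the reported derivative is 1 even though the classical derivative does not exist there:

```python
class Dual:
    """Minimal dual number: just enough for the branch in f below."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __ge__(self, other):
        # The branch is decided by the value alone; the derivative
        # simply follows whichever path was taken.
        other_value = other.value if isinstance(other, Dual) else other
        return self.value >= other_value


def f(x):
    if x >= 0:
        return x
    else:
        return 0


def derivative_at(func, x):
    out = func(Dual(x, 1.0))
    # A plain constant such as the literal 0 does not depend on x.
    return out.deriv if isinstance(out, Dual) else 0.0


for point in (-2.0, 0.0, 3.0):
    print(point, derivative_at(f, point))  # derivative 0.0 at -2.0, then 1.0 at 0.0 and 3.0
```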

AD cannot magically determine the derivative of `lambda x: exp(x)` or similar; it needs a lookup table of derivatives for elementary functions, for example: https://github.com/HIPS/autograd/blob/304552b1ba42e42bce97f0... However, AD does support differentiating through program control flow, including function calls, loops and conditionals (subject to the caveats above), which is much more difficult to do analytically.
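Both points can be sketched with a tiny forward-mode setup. The `ELEMENTARY` table, `elementary` and `program` are hypothetical, not autograd's actual code: the derivatives of `exp`, `log` and `sin` are entered by hand, while the loop and the `if` need no special treatment, because only the operations that actually execute get differentiated:

```python
import math

# Hypothetical lookup table: name -> (function, its analytic derivative).
# AD cannot derive these entries; someone has to write them down.
ELEMENTARY = {
    "exp": (math.exp, math.exp),
    "log": (math.log, lambda v: 1.0 / v),
    "sin": (math.sin, math.cos),
}


class Dual:
    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv


def elementary(name, x):
    fn, dfn = ELEMENTARY[name]
    # Chain rule: d/dt fn(x(t)) = dfn(x) * x'(t)
    return Dual(fn(x.value), dfn(x.value) * x.deriv)


def program(x, steps):
    # Ordinary control flow: a loop with a data-dependent branch.
    for _ in range(steps):
        if x.value > 1.0:
            x = elementary("log", x)
        else:
            x = elementary("exp", x)
    return x


out = program(Dual(0.5, 1.0), steps=2)  # takes exp first, then log: log(exp(x))
print(out.value, out.deriv)  # approximately 0.5 and 1.0
```

Since the path taken computes log(exp(x)) = x, the derivative comes out as (approximately) 1, exactly as the chain rule applied along that path predicts.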

The idea has now been generalized to all linear operators in functional programming https://www.youtube.com/watch?v=MmkNSsGAZhw&feature=youtu.be and the associated paper https://dl.acm.org/citation.cfm?doid=3243631.3236765


After checking the wiki link: does it internally construct a kind of AST for the symbolic derivative, but only return particular values, not the derivative itself? That is, since the new function returns derivative values, it seems it must itself be the derivative...

Looking at your GitHub link, multiplication (the product rule) doesn't seem to be handled there (there's only `def_linear(anp.multiply)`).

Would an implementation be something really straightforward, like:

    *(f(x),g(x)): f(x)*g'(x) + f'(x)*g(x)

That is, it would internally construct an AST of the derivative, but only return results at specific points, not the AST itself.

(Actually, the Wikipedia example for forward accumulation (https://wikipedia.org/wiki/Automatic_differentiation#Forward...) does include a product differentiated this way, so I guess I got it right.)
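In forward mode, at least, the product rule from the question can be applied to plain (value, derivative) pairs, so no AST of the derivative has to be kept: each step produces two numbers and forgets how they were obtained. (Reverse mode, by contrast, does record the sequence of operations, though not a symbolic derivative.) A minimal sketch under that assumption, with hypothetical helpers `const`, `add` and `mul`:

```python
def const(c):
    # A constant: derivative 0.
    return (c, 0.0)


def add(a, b):
    (fa, da), (fb, db) = a, b
    return (fa + fb, da + db)


def mul(a, b):
    # The product rule from the question, applied to numbers:
    # (f*g)' = f*g' + f'*g
    (fa, da), (fb, db) = a, b
    return (fa * fb, fa * db + da * fb)


x = (3.0, 1.0)            # the input x = 3, with dx/dx = 1
f = mul(x, x)             # f(x) = x^2    -> (9.0, 6.0)
g = add(x, const(5.0))    # g(x) = x + 5  -> (8.0, 1.0)
h = mul(f, g)             # h(x) = f(x)*g(x)
print(h)                  # (72.0, 57.0): h(3) = 72 and h'(3) = 3*9 + 10*3 = 57
```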


> multiplication (product rule) doesn't seem to be handled there (only `def_linear(anp.multiply)`)

`def_linear` is what handles the product rule. Other product operations (like `anp.cross`, `anp.inner`, etc.) are implemented the same way. It's called "linear" because products are multilinear functions, i.e. they are linear in each individual argument, so you can get the derivative with respect to each argument by simple substitution: (x + Δx)y - xy = Δx·y. Together with the chain rule for multi-parameter functions, the classic product rule falls out for free.
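The substitution idea can be sketched as follows. This mirrors what `def_linear` is described as doing above, but `multilinear_jvp` and `dot` are hypothetical stand-ins, not autograd's code, and the sketch assumes scalar-valued products:

```python
import operator


def multilinear_jvp(fn, args, tangents):
    """Directional derivative of a scalar-valued multilinear fn at args.

    Because fn is linear in each argument separately, the derivative with
    respect to argument i is just fn with args[i] replaced by tangents[i].
    Summing over the arguments (chain rule for several parameters) gives
    the full directional derivative.
    """
    total = 0.0
    for i, t in enumerate(tangents):
        substituted = list(args)
        substituted[i] = t
        total = total + fn(*substituted)
    return total


def dot(u, v):
    # Another multilinear product, on plain Python lists.
    return sum(a * b for a, b in zip(u, v))


# Product rule for x*y at x=3, y=4 along (dx, dy) = (1, 1): dx*y + x*dy = 4 + 3
print(multilinear_jvp(operator.mul, (3.0, 4.0), (1.0, 1.0)))  # 7.0
# Same recipe, different product: d/dt dot(u + t*du, v + t*dv) at t = 0
print(multilinear_jvp(dot, ([1.0, 2.0], [3.0, 4.0]), ([1.0, 0.0], [0.0, 1.0])))  # 5.0
```

Note that no product-rule code appears anywhere; it emerges from substitution plus summing over arguments.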


Neither: they mean automatic differentiation [1], which is not symbolic like SymPy or Mathematica, but is not finite differences either.

Often, the interface for doing this is to just extend an existing numerical or array type and overload all the arithmetic operations to keep track of the gradients. Then ordinary code for numerical computations will "just work". That's basically what they've done here, except with a sophisticated compiler.

1: https://en.wikipedia.org/wiki/Automatic_differentiation
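A minimal reverse-mode sketch of the operator-overloading approach, assuming only `+` and `*` on scalars. The `Var` class and `backward` function are hypothetical and not how JAX is actually implemented; the point is that `poly` is plain Python that knows nothing about gradients:

```python
class Var:
    """A scalar that remembers which Vars it was computed from."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs of (input Var, local partial derivative)
        self.grad = 0.0

    @staticmethod
    def lift(other):
        return other if isinstance(other, Var) else Var(other)

    def __add__(self, other):
        other = Var.lift(other)
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    __radd__ = __add__

    def __mul__(self, other):
        other = Var.lift(other)
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    __rmul__ = __mul__


def backward(output):
    # Order the recorded graph so each node comes after everything it depends on.
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    output.grad = 1.0
    # Walk backwards, pushing each node's gradient to its inputs (chain rule).
    for node in reversed(order):
        for parent, local in node.parents:
            parent.grad += local * node.grad


def poly(x, y):
    # Ordinary numerical code, unaware of Var.
    return x * y + 3 * x * x


x, y = Var(2.0), Var(5.0)
z = poly(x, y)
backward(z)
print(z.value, x.grad, y.grad)  # 22.0 17.0 2.0
```

Here dz/dx = y + 6x = 17 and dz/dy = x = 2, both obtained from a single backward pass.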


How's AD different from simply calling the function with a slightly perturbed input and observing how much the output changes? I assume it's more efficient with multivariable functions because the aforementioned method requires one call per parameter?


It is indeed much more efficient. In general, reverse-mode AD can evaluate a scalar function together with its full gradient at a cost that is a small constant multiple of the effort required to compute the scalar function alone, regardless of the number of parameters.
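A hand-written illustration of the scaling, not a real AD system: `f_and_grad` is roughly what a reverse sweep amounts to for the hypothetical function f(w) = sin(w·x). The backward pass does about as much work as the forward pass and yields all n partial derivatives at once, while the forward-difference estimate `fd_grad` needs n + 1 calls to `f`:

```python
import math


def f(w, x):
    # Scalar function of the parameters w (x is fixed data).
    return math.sin(sum(wi * xi for wi, xi in zip(w, x)))


def f_and_grad(w, x):
    # Forward sweep, keeping the intermediate s = w . x
    s = sum(wi * xi for wi, xi in zip(w, x))
    y = math.sin(s)
    # Reverse sweep: dy/ds once, then one multiply per parameter.
    s_bar = math.cos(s)
    w_bar = [s_bar * xi for xi in x]
    return y, w_bar


def fd_grad(w, x, h=1e-6):
    # Forward differences: one baseline call plus one call per parameter.
    calls = 1
    base = f(w, x)
    grad = []
    for i in range(len(w)):
        w_step = list(w)
        w_step[i] += h
        grad.append((f(w_step, x) - base) / h)
        calls += 1
    return grad, calls


n = 100
w = [0.1] * n
x = [i / n for i in range(n)]
y, g_ad = f_and_grad(w, x)
g_fd, calls = fd_grad(w, x)
print(calls)  # 101 calls to f, versus one forward and one reverse sweep
print(max(abs(a - b) for a, b in zip(g_ad, g_fd)))  # small finite-difference error
```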


> How's AD different from simply calling the function with a slightly perturbed input and observing how much the output changes?

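AD gives a more precise result, because it calculates the value of the derivative at the point you want (up to floating-point rounding) without any perturbed input; finite differences suffer from truncation error when the step is large and from cancellation error when it is small. AD is also faster: calculating the value of the derivative usually requires a number of elementary operations comparable to calculating the value of the function, whereas a finite difference needs to evaluate the function at least twice, plus once more for every additional input variable.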
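A small sketch of the precision point, using hypothetical helpers `Dual` and `dual_sin`: the forward-difference error depends on the step size h, while the dual-number derivative is just cos(1.0) evaluated directly:

```python
import math


class Dual:
    """Minimal dual number supporting sin, for comparison with finite differences."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv


def dual_sin(x):
    # Analytic rule from the lookup table: sin' = cos
    return Dual(math.sin(x.value), math.cos(x.value) * x.deriv)


def forward_difference(fn, x, h):
    # Two evaluations of fn; accuracy depends on the choice of h.
    return (fn(x + h) - fn(x)) / h


x0 = 1.0
exact = math.cos(x0)
ad = dual_sin(Dual(x0, 1.0)).deriv
for h in (1e-1, 1e-3, 1e-13):
    fd = forward_difference(math.sin, x0, h)
    print(h, abs(fd - exact))  # truncation error dominates for large h, cancellation for tiny h
print("AD error:", abs(ad - exact))  # 0.0 here: AD evaluates cos(1.0) directly
```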



