Ninjax
General Modules for JAX
Basics
- class nj.Module(*args, name=None, **kwargs)
Base class for users to inherit their modules from. Provides automatic name scoping via the metaclass and helper functions for accessing state.
- property path
The unique name scope of this module instance as a string.
- property name
The name of this module instance as a string.
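The following is a minimal plain-Python sketch of how a metaclass can give every instance a unique scope path. It illustrates the idea only and is not Ninjax's implementation: `ScopedMeta`, `_SCOPE`, `scope`, and the example classes are hypothetical names, and requiring the `name` keyword is a choice made for this sketch.

```python
import contextlib

# Hypothetical scope stack for this sketch; not Ninjax internals.
_SCOPE = ['']


@contextlib.contextmanager
def scope(path):
  """Make `path` the active scope for the duration of the block."""
  _SCOPE.append(path)
  try:
    yield
  finally:
    _SCOPE.pop()


class ScopedMeta(type):
  """Metaclass that assigns a path before running the user's __init__."""

  def __call__(cls, *args, name=None, **kwargs):
    if name is None:
      # Assumption of this sketch: every instance must be named.
      raise TypeError('Provide a name via the name keyword argument.')
    obj = cls.__new__(cls)
    parent = _SCOPE[-1]
    obj._path = f'{parent}/{name}' if parent else name
    # Submodules created inside __init__ are nested under this path.
    with scope(obj._path):
      obj.__init__(*args, **kwargs)
    return obj


class Module(metaclass=ScopedMeta):

  @property
  def path(self):
    return self._path

  @property
  def name(self):
    return self._path.split('/')[-1]


class Linear(Module):

  def __init__(self, units):
    self.units = units


class MLP(Module):

  def __init__(self, units):
    self.hidden = Linear(units, name='hidden')
    self.out = Linear(1, name='out')


model = MLP(8, name='mlp')
print(model.path, model.hidden.path, model.out.name)
```

Because the metaclass intercepts construction, subclasses never handle `name` in their own `__init__`, which is what makes the scoping automatic from the user's point of view.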
Transforms
- nj.grad(fun, keys, has_aux=False)
Compute the gradient of an impure function with respect to the specified state entries or modules. The transformed function returns a tuple containing the computed value, the selected state entries, their gradients, and, if applicable, the auxiliary outputs of the function.
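To make the shape of the returned tuple concrete, here is a sketch in plain JAX that mimics the described behavior for a function taking its state explicitly as a dictionary. It does not call Ninjax: `grad_wrt_keys`, `loss`, and the state layout are hypothetical, and the real transform works on impure functions whose state is implicit.

```python
import jax
import jax.numpy as jnp


def grad_wrt_keys(fun, keys, has_aux=False):
  """Differentiate fun(state, *args) with respect to state[k] for k in keys."""

  def wrapper(state, *args):
    selected = {k: state[k] for k in keys}
    rest = {k: v for k, v in state.items() if k not in keys}

    def inner(selected):
      # Only the selected entries are traced for differentiation.
      return fun({**rest, **selected}, *args)

    if has_aux:
      (value, aux), grads = jax.value_and_grad(inner, has_aux=True)(selected)
      return value, selected, grads, aux
    value, grads = jax.value_and_grad(inner)(selected)
    return value, selected, grads

  return wrapper


def loss(state, x, y):
  pred = x @ state['w'] + state['b']
  return jnp.mean((pred - y) ** 2)


state = {'w': jnp.zeros((3,)), 'b': jnp.zeros(()), 'step': jnp.array(0)}
x = jnp.ones((4, 3))
y = jnp.ones((4,))

value, params, grads = grad_wrt_keys(loss, ['w', 'b'])(state, x, y)
print(value, grads)
```

The `step` entry is left out of `keys`, so it is passed through unchanged and receives no gradient. With `has_aux=True`, the wrapped function is expected to return a `(value, aux)` pair and the result gains a fourth element.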
Control Flow
Advanced
- nj.context()
Access and modify the global context from within an impure function. For advanced users only. Prefer to use module methods to access and modify state and rng() to get the next RNG key.
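The general pattern behind a global context can be sketched in plain Python: an impure function reads and writes a context object, and a wrapper installs that context for the duration of the call so that state passes in and out explicitly. This is an illustration under assumptions, not Ninjax's code; `_CONTEXT`, the `pure` wrapper shown here, `counter`, and the `(state, out)` return order are all hypothetical.

```python
# Hypothetical context stack for this sketch; None means no active context.
_CONTEXT = [None]


def context():
  """Return the active state dict; only valid inside a wrapped call."""
  if _CONTEXT[-1] is None:
    raise RuntimeError('No active context; wrap the function with pure().')
  return _CONTEXT[-1]


def pure(fun):
  """Wrap fun, which uses context(), into a state-in, state-out function."""

  def wrapper(state, *args, **kwargs):
    local = dict(state)  # Copy so the caller's state is not mutated.
    _CONTEXT.append(local)
    try:
      out = fun(*args, **kwargs)
    finally:
      _CONTEXT.pop()
    return local, out

  return wrapper


def counter(amount):
  # Impure: reads and writes state through the global context.
  ctx = context()
  ctx['count'] = ctx.get('count', 0) + amount
  return ctx['count']


state = {}
state, out = pure(counter)(state, 3)
state, out = pure(counter)(state, 4)
print(state, out)
```

Touching the context directly, as `counter` does, bypasses any name scoping a module would apply, which is one reason to prefer module methods for ordinary state access.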
Integrations
Reference
- Base class for users to inherit their modules from.
- Wrap an impure function that uses global state to explicitly pass the state in and out.
- Split the global RNG key and return a new local key.
- Compute the gradient of an impure function with respect to the specified state entries or modules.
- Compiles a pure function for fast execution.
- Compiles a pure function for fast execution across multiple devices.
- Access and modify the global context from within an impure function.
- Indicates whether the program is currently allowed to create state entries.
- Enter a relative or absolute name scope.