Ninjax

General Modules for JAX

Basics

class nj.Module(*args, name=None, **kwargs)

Base class for users to inherit their modules from. Provides automatic name scoping via the metaclass and helper methods for accessing state.

property path

The unique name scope of this module instance as a string.

property name

The name of this module instance as a string.

get(name, *args, **kwargs)

Retrieve or create a state entry that belongs to this module.

put(name, value)

Update or create a single state entry that belongs to this module.

getm(pattern='.*', allow_empty=False)

Read the state entries of this module, optionally filtered by regex.

putm(mapping)

Update or create multiple state entries that belong to this module.
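
To make the regex filtering of getm and the bulk update of putm concrete, here is a minimal sketch in plain Python. It does not use Ninjax itself. It assumes that state is a flat dictionary keyed by "<module path>/<entry name>" strings and that the pattern is matched against the entry name with re.match; both are assumptions for illustration, and the free functions getm and putm below are hypothetical stand-ins for the methods.

```python
import re

# Hypothetical flat state; the key layout is an assumption of this sketch.
state = {
    'encoder/kernel': [1.0, 2.0],
    'encoder/bias': [0.0],
    'decoder/kernel': [3.0],
}

def getm(state, path, pattern='.*', allow_empty=False):
  """Return the entries under `path` whose local name matches `pattern`."""
  prefix = path + '/'
  found = {
      key: value for key, value in state.items()
      if key.startswith(prefix) and re.match(pattern, key[len(prefix):])}
  if not found and not allow_empty:
    raise KeyError(f'No entries under {path!r} match {pattern!r}.')
  return found

def putm(state, mapping):
  """Return a copy of the state with the given entries updated or created."""
  return {**state, **mapping}

encoder = getm(state, 'encoder')                     # both encoder entries
kernels = getm(state, 'encoder', pattern='kern.*')   # only the kernel
state = putm(state, {'encoder/bias': [0.5]})         # overwrite one entry
```

In this sketch, allow_empty only controls whether an empty result raises an error or is returned as an empty dictionary.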

class nj.Variable(*args, name=None, **kwargs)

read()

write(value)

nj.pure(fun, nested=False)

Wrap an impure function that uses global state so that the state is passed in and out explicitly. The result is a pure function that composes with JAX transformations. The pure function is called as follows: out, state = fun(state, rng, *args, **kwargs).
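
The calling convention out, state = fun(state, rng, *args, **kwargs) can be illustrated with a hand-written pure function in plain JAX. This sketch does not call nj.pure; the function step and the state keys 'scale' and 'calls' are hypothetical. It only shows why a function of this shape composes with transformations such as jax.jit.

```python
import jax
import jax.numpy as jnp

def step(state, rng, x):
  """Hand-written pure function with the shape `out, state = fun(state, rng, x)`."""
  noise = 0.1 * jax.random.normal(rng, x.shape)
  out = x * state['scale'] + noise
  # State is never mutated in place; an updated copy is returned instead.
  new_state = {'scale': state['scale'], 'calls': state['calls'] + 1}
  return out, new_state

state = {'scale': jnp.float32(2.0), 'calls': jnp.int32(0)}
rng = jax.random.PRNGKey(0)

# Because all inputs and outputs are explicit, the function can be compiled.
fast_step = jax.jit(step)
out, state = fast_step(state, rng, jnp.ones(3))
out, state = fast_step(state, rng, jnp.ones(3))
```

The caller owns the state and feeds the returned value into the next call, which is the same pattern the wrapped function follows.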

nj.rng(amount=None, reserve=16)

Split the global RNG key and return a new local key.
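
The idea of splitting a stored key and handing out a fresh local key can be sketched with jax.random.split. The KeySource class below is hypothetical and is not how Ninjax implements rng(); in particular, it does not model the amount and reserve parameters.

```python
import jax

class KeySource:
  """Hypothetical holder of a key that hands out fresh subkeys by splitting."""

  def __init__(self, seed):
    self.key = jax.random.PRNGKey(seed)

  def next(self):
    # Keep one half as the new stored key and return the other half.
    self.key, local = jax.random.split(self.key)
    return local

source = KeySource(0)
first = source.next()
second = source.next()  # differs from `first` because the stored key advanced
```

Each call consumes the stored key exactly once, so no two callers receive the same local key.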

Transforms

nj.grad(fun, keys, has_aux=False)

Compute the gradient of an impure function with respect to the specified state entries or modules. The transformed function returns a tuple containing the computed value, selected state entries, their gradients, and if applicable auxiliary outputs of the function.
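
The return structure described above (value, selected state entries, and their gradients) can be approximated in plain JAX with jax.value_and_grad. This is a sketch under assumptions, not the Ninjax implementation: grad_wrt, loss_fn, and the state keys are hypothetical, state is passed explicitly instead of being global, and auxiliary outputs are not handled.

```python
import jax
import jax.numpy as jnp

def loss_fn(state, x, y):
  """Mean squared error of a linear model whose parameters live in `state`."""
  pred = x @ state['linear/kernel'] + state['linear/bias']
  return jnp.mean((pred - y) ** 2)

def grad_wrt(fun, keys):
  """Differentiate `fun` only with respect to the state entries in `keys`."""
  def wrapper(state, *args):
    selected = {k: state[k] for k in keys}
    frozen = {k: v for k, v in state.items() if k not in keys}
    def inner(selected):
      return fun({**frozen, **selected}, *args)
    value, grads = jax.value_and_grad(inner)(selected)
    return value, selected, grads
  return wrapper

state = {'linear/kernel': jnp.ones((2, 1)), 'linear/bias': jnp.zeros((1,))}
x = jnp.array([[1.0, 2.0]])
y = jnp.array([[0.0]])
value, selected, grads = grad_wrt(loss_fn, ['linear/kernel'])(state, x, y)
```

Only the kernel receives a gradient here; the bias is treated as a constant because it was not listed in keys.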

nj.jit(fun, static=(), donate=(), **jit_kwargs)

Compiles a pure function for fast execution. Only the first call of the function is allowed to create state entries.

nj.pmap(fun, axis_name=None, static=(), donate=(), **pmap_kwargs)

Compiles a pure function for fast execution across multiple devices. Only the first call of the function is allowed to create state entries.

Control Flow

nj.cond(pred, true_fun, false_fun, *operands)

nj.scan(fun, carry, xs, reverse=False, unroll=1, modify=False)

Advanced

nj.context()

Access and modify the global context from within an impure function. For advanced users only. Prefer to use module methods to access and modify state and rng() to get the next RNG key.

nj.creating()

Indicates whether the program is currently allowed to create state entries. Can be used for initialization logic that should be excluded from compiled functions.

nj.scope(name, absolute=False)

Enter a relative or absolute name scope. Name scopes are used to make names of state entries unique.

Integrations

class nj.HaikuModule(*args, name=None, **kwargs)

class nj.FlaxModule(*args, name=None, **kwargs)

class nj.OptaxModule(*args, name=None, **kwargs)

Reference

Module(*args[, name])

Base class for users to inherit their modules from.

Variable(*args[, name])

pure(fun[, nested])

Wrap an impure function that uses global state to explicitly pass the state in and out.

rng([amount, reserve])

Split the global RNG key and return a new local key.

grad(fun, keys[, has_aux])

Compute the gradient of an impure function with respect to the specified state entries or modules.

jit(fun[, static, donate])

Compiles a pure function for fast execution.

pmap(fun[, axis_name, static, donate])

Compiles a pure function for fast execution across multiple devices.

cond(pred, true_fun, false_fun, *operands)

scan(fun, carry, xs[, reverse, unroll, modify])

context()

Access and modify the global context from within an impure function.

creating()

Indicates whether the program is currently allowed to create state entries.

scope(name[, absolute])

Enter a relative or absolute name scope.

HaikuModule(*args[, name])

FlaxModule(*args[, name])

OptaxModule(*args[, name])