import contextlib
import functools
import inspect
import re
import threading
from functools import partial as bind
import jax
import jax.numpy as jnp
__version__ = '1.2.1'
###############################################################################
# State
###############################################################################
# When running an impure function that accesses state, it will find the state
# in this global variable. The pure() wrapper populates this global variable
# with the provided state, calls the inner function, and then takes the
# resulting state out of the global variable to return it back to the user.
# To allow multi-threaded programs to use impure functions in parallel, the
# context is a dictionary with a slot for each thread identifier.
CONTEXT = {}
class Context(dict):
    """State entries of one running pure function plus its write policy.

    The mapping itself holds the state entries. The attributes control how
    writes through `__setitem__` and `update` are treated:

    * `create`: allow adding keys that are not present yet.
    * `modify`: allow writes at all; when False any write that is not
      skipped by `ignore` raises.
    * `ignore`: silently skip writes to keys that already exist.

    `rng`, `reserve` and `name` are stored as given and are not interpreted
    by this class. The initial `entries` are copied in directly and are not
    subject to the write policy.
    """

    def __init__(self, entries, rng, create, modify, ignore, reserve, name):
        super().__init__(entries)
        self.create = create  # Allow creating new state entries.
        self.modify = modify  # Allow modifying existing state entries.
        self.ignore = ignore  # Ignore modifications to existing state entries.
        self.rng = rng
        self.reserve = reserve
        self.name = name

    @staticmethod
    def _describe(value):
        """Summarize a rejected value for error messages.

        Values without `shape` or `dtype` attributes (for example nested
        containers of arrays) must not turn the intended RuntimeError into an
        AttributeError, so missing attributes are reported as unknown.
        """
        shape = getattr(value, 'shape', '<unknown>')
        dtype = getattr(value, 'dtype', '<unknown>')
        return f'shape {shape} and dtype {dtype}.'

    def update(self, entries):
        """Write all entries one by one so each passes the write policy."""
        for key, value in dict(entries).items():
            self[key] = value

    def __setitem__(self, key, value):
        """Store an entry if the write policy allows it.

        Raises:
          RuntimeError: If the key is new and `create` is False, or if
            `modify` is False.
        """
        if self.ignore and key in self:
            return  # Do not overwrite existing entries.
        if not self.create and key not in self:
            raise RuntimeError(
                'Can only create state entries during first call. ' +
                f'You were trying to set {key} to ' + self._describe(value))
        if not self.modify:
            raise RuntimeError(
                'Cannot modify state entries here. If you want to modify '
                'state inside of scan() set modify=True. ' +
                f'You were trying to set {key} to ' + self._describe(value))
        super().__setitem__(key, value)
def pure(fun, nested=False):
    """Wrap an impure function that uses global state to explicitly pass the
    state in and out. The result is a pure function that is composable with JAX
    transformation. The pure function can be used as follows:
    `out, state = fun(state, rng, *args, **kwargs)`.

    Args:
      fun: Impure function that reads and writes state via the global context.
      nested: Allow calling the purified function while another purified
        function is already running on the same thread.

    Returns:
      A function `purified(state, rng, *args, create=None, modify=None,
      ignore=None, **kwargs)` that returns `(out, state)`. Flags left at None
      are inherited from the enclosing context if there is one, and otherwise
      default to create=True, modify=True, ignore=False.

    Raises (when the purified function is called):
      ValueError: If `state` is not a dict.
      RuntimeError: If called inside another pure function without
        `nested=True`.
    """
    # Not every callable has a __name__ (e.g. functools.partial objects).
    fun_name = getattr(fun, '__name__', repr(fun))

    @functools.wraps(fun)
    def purified(
            state, rng, *args, create=None, modify=None, ignore=None,
            **kwargs):
        ident = threading.get_ident()
        context = CONTEXT.get(ident, None)
        # Compare against None explicitly. Context is a dict subclass, so an
        # active parent context that holds no state entries yet is falsy and
        # would otherwise be mistaken for "no parent context".
        has_parent = context is not None
        if has_parent:
            create = create if create is not None else context.create
            modify = modify if modify is not None else context.modify
            ignore = ignore if ignore is not None else context.ignore
            # A nested call may only be as permissive as its parent.
            assert context.create or not create, (
                'Parent context disabled create.')
            assert context.modify or not modify, (
                'Parent context disabled modify.')
            assert not context.ignore or ignore, (
                'Parent context enabled ignore.')
        else:
            create = create if create is not None else True
            modify = modify if modify is not None else True
            ignore = ignore if ignore is not None else False
        if not isinstance(state, dict):
            raise ValueError('Must provide a dict as state.')
        if has_parent and (not nested):
            raise RuntimeError(
                f'You are trying to call pure {fun_name}() inside pure '
                f'{context.name}(). Is that intentional? If you want to nest '
                f'pure functions, use pure(..., nested=True) for the inner '
                f'function.')
        before = context
        try:
            # Work on a copy so the caller's state dict is never mutated.
            context = Context(
                state.copy(), rng, create, modify, ignore, [], fun_name)
            CONTEXT[ident] = context
            out = fun(*args, **kwargs)
            state = dict(context)
            return out, state
        finally:
            # Always restore the enclosing context, even if fun() raised.
            CONTEXT[ident] = before
    return purified
def context():
    """Return the context that is active on the calling thread.

    Intended for advanced use. Module methods and rng() are the preferred
    way to touch state and random keys.

    Raises:
      RuntimeError: If no context is registered for the current thread.
    """
    active = CONTEXT.get(threading.get_ident(), None)
    if active is not None:
        return active
    raise RuntimeError('Wrap impure functions in pure() before running them.')
@jax.named_scope('rng')
def rng(amount=None, reserve=16):
    """Split the global RNG key and return a new local key.

    Args:
      amount: If None, return a single key. Otherwise return a batch of
        `amount` keys; `amount=0` yields an empty batch.
      reserve: How many keys to split off at once when the context's pool of
        single keys is empty. One of them becomes the new global key. Values
        below 2 are raised to 2 so that at least one key can be handed out.

    Returns:
      A single key, or an array of `amount` keys.
    """
    ctx = context()
    # Test against None rather than truthiness so that amount=0 returns an
    # empty batch of keys instead of silently falling back to a single key.
    if amount is not None:
        keys = jax.random.split(ctx.rng, amount + 1)
        ctx.rng = keys[0]
        return keys[1:]
    if not ctx.reserve:
        # Refill the pool in one split so repeated single-key requests are
        # cheap. The first key replaces the global key, the rest are pooled.
        keys = jax.random.split(ctx.rng, max(reserve, 2))
        ctx.rng = keys[0]
        ctx.reserve = list(keys[1:])
    return ctx.reserve.pop(0)
def creating():
    """Tell whether the active context currently permits creating new state
    entries. Useful for initialization logic that should be left out of
    compiled functions."""
    active = context()
    return active.create
###############################################################################
# Transformations
###############################################################################
@jax.named_scope('grad')
def grad(fun, keys, has_aux=False):
    """Compute the gradient of an impure function with respect to the specified
    state entries or modules. The transformed function returns a tuple containing
    the computed value, selected state entries, their gradients, and if
    applicable auxiliary outputs of the function.

    Args:
      fun: Impure function. Returns the value to differentiate, or a
        `(value, aux)` pair if `has_aux` is True.
      keys: A state key (str), a Module, or a sequence of those. Modules
        contribute all keys returned by their `getm()`.
      has_aux: Whether `fun` returns auxiliary outputs next to the value.

    Returns:
      A function that returns `(y, x1, dx, aux)` if `has_aux` is True and
      `(y, x1, dx)` otherwise, where `x1` holds the selected state entries
      and `dx` their gradients.
    """
    # A lone string has __len__ but names a single key; without this check it
    # would be iterated character by character.
    if isinstance(keys, str) or not hasattr(keys, '__len__'):
        keys = (keys,)
    if not has_aux:
        # Give every function the (value, aux) shape expected below. Keyword
        # arguments must be forwarded with ** so they stay keyword arguments.
        fun = lambda *args, _fun=fun, **kwargs: (_fun(*args, **kwargs), {})
    fun = pure(fun, nested=True)

    def forward(x1, x2, rng, *args, **kwargs):
        # x1 are the entries to differentiate, x2 is the remaining state.
        (y, aux), state = fun(
            {**x1, **x2}, rng, *args, create=False, **kwargs)
        return y, (aux, state)

    backward = jax.value_and_grad(forward, has_aux=True)

    @functools.wraps(backward)
    def wrapper(*args, **kwargs):
        # Let the function create its state entries first if that is allowed.
        _prerun(fun, *args, **kwargs)
        assert all(isinstance(x, (str, Module)) for x in keys)
        selected = {x for x in keys if isinstance(x, str)}
        for mod in keys:
            if isinstance(mod, Module):
                selected.update(mod.getm())
        x1 = {k: v for k, v in context().items() if k in selected}
        x2 = {k: v for k, v in context().items() if k not in selected}
        (y, (aux, state)), dx = backward(x1, x2, rng(), *args, **kwargs)
        context().update(state)
        return (y, x1, dx, aux) if has_aux else (y, x1, dx)

    return wrapper
def jit(fun, static=(), donate=(), **jit_kwargs):
    """Compile a pure function with jax.jit. State entries may only be created
    by the first call of the returned function.

    Args:
      fun: Pure function called as `fun(state, *args, **kwargs)` and returning
        `(out, state)`.
      static: Names of keyword arguments passed to the compiled function as
        static values.
      donate: Names of keyword arguments passed to the compiled function as
        donated values. Each of them must be supplied on every call.
      **jit_kwargs: Forwarded to jax.jit. `static_argnums` and
        `donate_argnums` are overwritten.

    Returns:
      `wrapper(state, *args, init=True, apply=True, **kwargs)`. It returns
      `(out, state)`, or only the state when `apply` is False.
    """
    # The two leading positional arguments of the compiled helpers carry the
    # static and the donated keyword values.
    jit_kwargs.update(static_argnums=[0], donate_argnums=[1])
    compiled = bind(jax.jit, **jit_kwargs)

    def _init(_static, _donate, *args, **kwargs):
        # Start from an empty state and keep only the resulting state.
        return fun(
            {}, *args, ignore=True,
            **dict(_static), **dict(zip(donate, _donate)), **kwargs)[1]

    def _apply(_static, _donate, *args, **kwargs):
        return fun(
            *args, create=False,
            **dict(_static), **dict(zip(donate, _donate)), **kwargs)

    _init = compiled(_init)
    _apply = compiled(_apply)

    @functools.wraps(fun)
    def wrapper(state, *args, init=True, apply=True, **kwargs):
        static_items = tuple(
            [(k, kwargs.pop(k)) for k in static if k in kwargs])
        donated = tuple([kwargs.pop(k) for k in donate])
        first_call = not hasattr(wrapper, 'keys')
        if first_call and init:
            created = _init(static_items, donated, *args, **kwargs)
            wrapper.keys = set(created.keys())
            # Entries provided by the caller win over freshly created ones.
            state = {**created, **state}
        elif first_call:
            wrapper.keys = set(state.keys())
        if not apply:
            return state
        # Only hand the entries known from the first call to the compiled code.
        selected = {k: v for k, v in state.items() if k in wrapper.keys}
        out, updated = _apply(
            static_items, donated, selected, *args, **kwargs)
        return out, {**state, **updated}

    return wrapper
def pmap(fun, axis_name=None, static=(), donate=(), **pmap_kwargs):
    """Compile a pure function with jax.pmap for execution across multiple
    devices. State entries may only be created by the first call of the
    returned function.

    Args:
      fun: Pure function called as `fun(state, *args, **kwargs)` and returning
        `(out, state)`.
      axis_name: Forwarded to jax.pmap as `axis_name`.
      static: Names of keyword arguments passed to the compiled function as
        static broadcasted values.
      donate: Names of keyword arguments passed to the compiled function as
        donated values. Each of them must be supplied on every call.
      **pmap_kwargs: Forwarded to jax.pmap. `axis_name`,
        `static_broadcasted_argnums` and `donate_argnums` are overwritten.

    Returns:
      `wrapper(state, *args, init=True, apply=True, **kwargs)`. It returns
      `(out, state)`, or only the state when `apply` is False.
    """
    # The two leading positional arguments of the compiled helpers carry the
    # static and the donated keyword values.
    pmap_kwargs.update(
        axis_name=axis_name,
        static_broadcasted_argnums=[0],
        donate_argnums=[1])
    compiled = bind(jax.pmap, **pmap_kwargs)

    def _init(_static, _donate, *args, **kwargs):
        # Start from an empty state and keep only the resulting state.
        return fun(
            {}, *args, ignore=True,
            **dict(_static), **dict(zip(donate, _donate)), **kwargs)[1]

    def _apply(_static, _donate, *args, **kwargs):
        return fun(
            *args, create=False,
            **dict(_static), **dict(zip(donate, _donate)), **kwargs)

    _init = compiled(_init)
    _apply = compiled(_apply)

    @functools.wraps(fun)
    def wrapper(state, *args, init=True, apply=True, **kwargs):
        static_items = tuple(
            [(k, kwargs.pop(k)) for k in static if k in kwargs])
        donated = tuple([kwargs.pop(k) for k in donate])
        first_call = not hasattr(wrapper, 'keys')
        if first_call and init:
            created = _init(static_items, donated, *args, **kwargs)
            wrapper.keys = set(created.keys())
            # Entries provided by the caller win over freshly created ones.
            state = {**created, **state}
        elif first_call:
            wrapper.keys = set(state.keys())
        if not apply:
            return state
        # Only hand the entries known from the first call to the compiled code.
        selected = {k: v for k, v in state.items() if k in wrapper.keys}
        out, updated = _apply(
            static_items, donated, selected, *args, **kwargs)
        return out, {**state, **updated}

    return wrapper
@jax.named_scope('cond')
def cond(pred, true_fun, false_fun, *operands):
    """Stateful counterpart of jax.lax.cond for impure branch functions.

    Both branches are purified, pre-run so they can create their state entries
    when creation is allowed, and then dispatched through jax.lax.cond. Each
    branch receives its own RNG key. The state returned by the executed
    branch is written back into the active context.

    Returns:
      The output of the selected branch.
    """
    true_fun = pure(true_fun, nested=True)
    false_fun = pure(false_fun, nested=True)
    for branch in (true_fun, false_fun):
        _prerun(branch, *operands)

    def on_true(state, key_true, key_false, *args):
        return true_fun(state, key_true, *args)

    def on_false(state, key_true, key_false, *args):
        return false_fun(state, key_false, *args)

    snapshot = dict(context())
    key_true, key_false = rng(2)
    out, state = jax.lax.cond(
        pred, on_true, on_false, snapshot, key_true, key_false, *operands)
    context().update(state)
    return out
@jax.named_scope('scan')
def scan(fun, carry, xs, reverse=False, unroll=1, modify=False):
    """Stateful counterpart of jax.lax.scan for an impure step function.

    The step function is purified and pre-run on the first slice of `xs` so
    it can create its state entries when creation is allowed. Every step
    receives its own RNG key.

    Args:
      fun: Impure step function called as `fun(carry, x)` and returning
        `(carry, y)`.
      carry: Initial carry.
      xs: Inputs that are scanned over their leading axis.
      reverse: Forwarded to jax.lax.scan.
      unroll: Forwarded to jax.lax.scan.
      modify: If True, state is threaded through the loop and written back to
        the context afterwards. If False, every step sees the current context
        state and is not allowed to modify it.

    Returns:
      The tuple `(carry, ys)`.
    """
    fun = pure(fun, nested=True)
    first_slice = jax.tree_util.tree_map(lambda x: x[0], xs)
    _prerun(fun, carry, first_slice)
    length = len(jax.tree_util.tree_leaves(xs)[0])
    rngs = rng(length)

    if not modify:
        def step(carry, inputs):
            x, key = inputs
            (carry, y), _ = fun(
                dict(context()), key, carry, x, create=False, modify=False)
            return carry, y

        carry, ys = jax.lax.scan(
            step, carry, (xs, rngs), length, reverse, unroll)
        return carry, ys

    def stateful_step(packed, inputs):
        # The state travels through the loop as part of the carry.
        carry, state = packed
        x, key = inputs
        (carry, y), state = fun(state, key, carry, x, create=False)
        return (carry, state), y

    initial = (carry, dict(context()))
    (carry, state), ys = jax.lax.scan(
        stateful_step, initial, (xs, rngs), length, reverse, unroll)
    context().update(state)
    return carry, ys
@jax.named_scope('_prerun')
def _prerun(fun, *args, **kwargs):
    """Run a purified function once so it can create its state entries.

    Does nothing unless the active context allows creating entries. The call
    uses ignore=True, its output is thrown away, and the returned state is
    merged into the active context.
    """
    if context().create:
        snapshot = dict(context())
        _, state = fun(snapshot, rng(), *args, ignore=True, **kwargs)
        context().update(state)
###############################################################################
# Modules
###############################################################################
# Name scope that is currently entered; the empty string means top level.
SCOPE = ''


@contextlib.contextmanager
def scope(name, absolute=False):
    """Context manager that enters a name scope and yields the resulting path.

    With `absolute=True`, or when no scope is entered yet, the path is `name`
    itself. Otherwise `name` is appended to the enclosing path with a slash.
    The previous scope is restored on exit. Name scopes are used to make names
    of state entries unique.
    """
    global SCOPE
    if SCOPE is None:
        raise RuntimeError(
            'Purify stateful functions with fn = pure(fn) before running them.')
    previous = SCOPE
    use_name_only = absolute or previous == ''
    SCOPE = name if use_name_only else previous + '/' + name
    try:
        yield SCOPE
    finally:
        SCOPE = previous
class ModuleMeta(type):
    """Meta class that creates a unique path for each module instance and wraps
    the methods and properties of the module to enter the name scope."""

    def __new__(mcs, name, bases, clsdict):
        """This runs once per user module class definition. It wraps the methods of
        the module class to automatically enter the name scope of the module."""
        method_names = []
        for key, value in clsdict.items():
            # Leave dunder members alone, except __call__ which should run
            # inside the module's name scope like any other method.
            if key.startswith('__') and key != '__call__':
                continue
            elif isinstance(value, property):
                # Rebuild the property with each existing accessor wrapped.
                clsdict[key] = property(
                    value.fget if not value.fget
                    else _scope_method(value.fget),
                    value.fset if not value.fset
                    else _scope_method(value.fset),
                    value.fdel if not value.fdel
                    else _scope_method(value.fdel),
                    doc=value.__doc__)
            elif inspect.isfunction(value):
                method_names.append(key)
        cls = super(ModuleMeta, mcs).__new__(mcs, name, bases, clsdict)
        for method_name in method_names:
            method = getattr(cls, method_name)
            method = _scope_method(method)
            setattr(cls, method_name, method)
        return cls

    def __call__(cls, *args, name=None, **kwargs):
        """This runs once per use module instance creation. It derives a unique
        name and path for the module instance.

        Raises:
          KeyError: If `name` is not a string or contains characters other
            than letters, numbers, and underscores.
        """
        if not isinstance(name, str):
            raise KeyError(
                "Please provide a module name via Module(..., name='example').")
        # fullmatch rather than match with '$': the latter also accepts a
        # name that ends in a newline, which the message below rules out.
        if not re.fullmatch(r'[A-Za-z0-9_]+', name):
            raise KeyError(
                'Only letters, numbers, and underscores are allowed in scope '
                'names.')
        # Set up the path before __init__ runs so the constructor can already
        # rely on it.
        obj = cls.__new__(cls)
        with scope(name) as path:
            obj._path = path
            obj._submodules = {}
            init = _scope_method(cls.__init__)
            init(obj, *args, **kwargs)
        return obj
def _scope_method(method):
    """Wrap a method so that it runs inside the name scope of its instance.

    The wrapper enters the instance's `_path` as an absolute scope and a JAX
    named scope labelled with the last path component, then calls the method.
    """
    @functools.wraps(method)
    def scoped(self, *args, **kwargs):
        with scope(self._path, absolute=True):
            label = self._path.split('/')[-1]
            with jax.named_scope(label):
                result = method(self, *args, **kwargs)
                return result
    return scoped
class Module(object, metaclass=ModuleMeta):
    """Base class that user modules inherit from. The meta class takes care of
    name scoping; this class adds helpers to read and write state entries."""

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}({self.path})'

    @property
    def path(self):
        """Unique name scope of this module instance, as a string."""
        return self._path

    @property
    def name(self):
        """Name of this module instance, i.e. the last component of its path."""
        components = self._path.split('/')
        return components[-1]

    def get(self, name, *args, **kwargs):
        """Return the state entry or submodule `name`, creating it if needed.

        An existing submodule or state entry is returned as is. Otherwise the
        first extra positional argument is used as constructor and called with
        the remaining arguments; `name` is passed along if the constructor
        accepts a `name` parameter. Results whose leaves are all arrays are
        stored as state entry, anything else is kept as submodule.
        """
        assert '{' not in name, 'Did you forget to format a string?'
        path = self.path + '/' + name
        try:
            return self._submodules[name]
        except KeyError:
            pass
        state = context()
        if path in state:
            return state[path]
        ctor, *ctor_args = args
        accepts_name = 'name' in inspect.signature(ctor).parameters
        if accepts_name:
            kwargs['name'] = name
        value = ctor(*ctor_args, **kwargs)
        leaves = jax.tree_util.tree_flatten(value)[0]
        is_state = all(isinstance(leaf, jnp.ndarray) for leaf in leaves)
        if is_state:
            context()[path] = value
        else:
            self._submodules[name] = value
        return value

    def put(self, name, value):
        """Write one state entry of this module and return the value."""
        full_key = self.path + '/' + name
        self.putm({full_key: value})
        return value

    def getm(self, pattern=r'.*', allow_empty=False):
        """Return the state entries of this module whose key, relative to the
        module path, matches the regex at its start.

        Raises:
          KeyError: If nothing matches and `allow_empty` is False.
        """
        regex = re.compile(pattern)
        prefix = self.path + '/'
        skip = len(prefix)
        results = {
            key: value for key, value in context().items()
            if key.startswith(prefix) and regex.match(key[skip:])}
        if results or allow_empty:
            return results
        raise KeyError(f'Pattern {regex} matched no state keys.')

    def putm(self, mapping):
        """Write several state entries; every key must lie below this module.

        Raises:
          KeyError: If a key does not start with the module path.
        """
        prefix = self.path + '/'
        for key in mapping:
            if key.startswith(prefix):
                continue
            raise KeyError(f'Key {key} does not belong to module {self.path}.')
        context().update(mapping)
class Variable(Module):
    """Module that holds a single state entry named 'value'."""

    def __init__(self, ctor, *args, **kwargs):
        # Remember how to construct the value; it is created on first read.
        self.ctor, self.args, self.kwargs = ctor, args, kwargs

    def read(self):
        """Return the value, constructing it if it does not exist yet."""
        ctor_args = (self.ctor, *self.args)
        return self.get('value', *ctor_args, **self.kwargs)

    def write(self, value):
        """Store a new value and return it."""
        stored = self.put('value', value)
        return stored
###############################################################################
# Integrations
###############################################################################
class HaikuModule(Module):
    """Wraps a Haiku module constructor as a module of this library."""

    def __init__(self, ctor, *args, **kwargs):
        # Imported lazily so Haiku is only required when this class is used.
        import haiku as hk

        def net(*call_args, **call_kwargs):
            # The Haiku module is built inside the transformed function.
            module = ctor(*args, **kwargs)
            return module(*call_args, **call_kwargs)

        self.transformed = hk.transform(net)

    def __call__(self, *args, **kwargs):
        """Fetch or create the 'state' entry and apply the transformed net."""
        init_fn = self.transformed.init
        init_key = rng()
        state = self.get('state', init_fn, init_key, *args, **kwargs)
        apply_key = rng()
        return self.transformed.apply(state, apply_key, *args, **kwargs)
class FlaxModule(Module):
    """Wraps a Flax module constructor as a module of this library."""

    def __init__(self, ctor, *args, **kwargs):
        module = ctor(*args, **kwargs)
        self.module = module

    def __call__(self, *args, **kwargs):
        """Fetch or create the 'state' entry and apply the wrapped module."""
        init_fn = self.module.init
        init_key = rng()
        variables = self.get('state', init_fn, init_key, *args, **kwargs)
        return self.module.apply(variables, *args, **kwargs)
class OptaxModule(Module):
    """Wraps an Optax optimizer constructor as a module of this library."""

    def __init__(self, ctor, *args, **kwargs):
        optimizer = ctor(*args, **kwargs)
        self.opt = optimizer

    def __call__(self, loss, keys, *args, **kwargs):
        """Take one optimizer step on the state entries selected by `keys`.

        Differentiates `loss` with grad(), updates the optimizer state kept
        under 'state', writes the updated parameters into the context, and
        returns a dict with the mean loss and the global gradient norm.
        """
        # Imported lazily so Optax is only required when this class is used.
        import optax
        grad_fn = grad(loss, keys)
        loss, params, grads = grad_fn(*args, **kwargs)
        old_optstate = self.get('state', self.opt.init, params)
        updates, new_optstate = self.opt.update(grads, old_optstate)
        self.put('state', new_optstate)
        new_params = optax.apply_updates(params, updates)
        context().update(new_params)
        metrics = {}
        metrics['loss'] = loss.mean()
        metrics['grad_norm'] = optax.global_norm(grads)
        return metrics