"""Source code for ``tinychain.math.operator``: differentiable operators and helper functions for differentiating them."""

import logging
import typing

from ..json import to_json
from ..scalar.ref import deref, is_literal, same_as, is_op_ref, Op
from ..scalar.value import Id

from .base import is_numeric
from .interface import Boolean, Numeric, Trigonometric


class Gradients(dict):
    """
    A mapping from an input of an operator graph to its gradient.

    Assigning to a key which is already present adds the new value to the existing one (instead of overwriting it),
    so that gradients arriving along different paths through the graph accumulate.
    """

    def __add__(self, other):
        """Return a new `Gradients` containing the entries of both `self` and `other`, summing any shared keys."""
        grads = Gradients()
        grads.update(self)
        grads.update(other)
        return grads

    def __setitem__(self, key: "Id", value: "Numeric"):
        if key in self:
            dict.__setitem__(self, key, self[key] + value)
        else:
            dict.__setitem__(self, key, value)

    def update(self, __m: typing.Mapping["Id", "Numeric"] = None, **kwargs: "Numeric") -> None:
        """
        Accumulate the entries of the mapping `__m` (if given) and of `kwargs` into this `Gradients`.

        Unlike :meth:`dict.update`, the value for a key which is already present is added to, not replaced.
        """
        if __m is not None:
            for var_id in __m:
                self[var_id] = __m[var_id]

        # each keyword value comes from `kwargs` itself (not from `__m`)
        for var_id, value in kwargs.items():
            self[var_id] = value
class Operator(Op):
    """
    A differentiable operator like addition, multiplication, exponentiation, etc.

    `subject` is the primary operand and `args` holds any other operand(s) (`None` for a unary operator).
    Subclasses implement `forward`, `backward`, `gradients`, and usually `shape` and `__repr__`.
    """

    def __init__(self, subject, args):
        if not is_numeric(subject):
            # a non-Numeric subject is only logged here, not rejected
            logging.info(f"{subject} is the subject of a differentiable Operator but does not implement Numeric")

        Op.__init__(self, subject, args)

    def __args__(self):
        return self.subject, self.args

    def __json__(self):
        # an Operator is serialized as the result of its forward pass
        return to_json(self.forward())

    def __ns__(self, cxt, name_hint):
        cxt.deanonymize(self.subject, name_hint + "_subject")
        cxt.deanonymize(self.args, name_hint + "_args")

        if is_literal(self.subject) or is_op_ref(self.subject):
            cxt.assign(self.subject, name_hint + "_subject")

        if is_op_ref(self.args):
            cxt.assign(self.args, name_hint + "_args")

    def __repr__(self):
        raise NotImplementedError(f"human-readable string representation of {self.__class__.__name__}")

    def __same__(self, other):
        other = operator(other)
        if not other:
            return False

        return type(self) is type(other) and same_as(self.subject, other.subject) and same_as(self.args, other.args)

    @property
    def shape(self):
        """The shape of the result of this operation"""
        raise NotImplementedError(f"{self.__class__.__name__}.shape")

    def forward(self):
        """Return the result of evaluating this `Operator`"""
        raise NotImplementedError(f"{self.__class__}.forward")

    def backward(self, variable=None):
        """
        Return the derivative of this :class:`Operator` (may be a numeric constant or itself an :class:`Operator`).

        If a `variable` is specified, this will be the partial derivative w/r/t the given `variable`.
        """
        raise NotImplementedError(f"{self.__class__}.backward")

    def gradients(self, loss):
        """
        Return the :class:`Gradients` of the inputs of this :class:`Operator`.

        `loss` is the gradient flowing back into the output of this `Operator`; implementations scale it by the
        local derivative w/r/t each operand and recurse into that operand.
        """
        raise NotImplementedError(f"{self.__class__}.gradients")

    def simplify(self):
        """
        Return a simplified but logically equivalent version of this :class:`Operator`, if possible.

        For example, `Mul(2, 1).simplify()` will return 2.

        IMPORTANT: don't call `simplify` until after constructing an entire operator graph.
        This is because `simplify` may discard parts of the operator graph needed to apply the chain rule correctly.
        """
        return self
class Unary(Operator):
    """A differentiable :class:`Operator` with a single Numeric operand, its `subject` (its `args` are `None`)."""

    def __init__(self, subject):
        if is_numeric(subject):
            Operator.__init__(self, subject, None)
        else:
            raise ValueError(f"Unary operator requires a Numeric subject, not {subject}")

    def __args__(self):
        return (self.subject,)

    def __ns__(self, cxt, name_hint):
        assert self.args is None

        subject_name = name_hint + "_subject"
        cxt.deanonymize(self.subject, subject_name)

        if is_op_ref(self.subject):
            cxt.assign(self.subject, subject_name)
class Custom(Unary):
    """
    A custom operator.

    The result of `forward` is computed once, at construction time, and stored as `self._op`;
    that stored result is what gets serialized and what determines the `shape`.
    """

    def __init__(self, subject):
        # calls `Operator.__init__` directly, bypassing the Numeric check in `Unary.__init__`
        Operator.__init__(self, subject, None)
        self._op = self.forward()

    def __json__(self):
        return to_json(self._op)

    def __ns__(self, cxt, name_hint):
        Unary.__ns__(self, cxt, name_hint)

        op_name = name_hint + "_custom_op"
        cxt.deanonymize(self._op, op_name)

    @property
    def shape(self):
        computed = self._op
        return computed.shape
# TODO: Tensor.log(base!=None)
class Abs(Unary):
    """The absolute value of the `subject`."""

    def __repr__(self):
        return f"abs({self.subject})"

    @property
    def shape(self):
        return self.subject.shape

    def forward(self):
        return Numeric.abs(self.subject)

    def backward(self, _variable=None):
        # d|u|/du = u / |u| (the sign of u; undefined where u == 0)
        d = self.subject / self.subject.abs()

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, _variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly: `backward` may already include the subject's own derivative,
        # which `gradients` applies by recursing into the subject
        return gradients(self.subject, loss * (self.subject / self.subject.abs()))

    def simplify(self):
        subject = simplify(self.subject)
        return Abs(subject)
class Exp(Unary):
    """The natural exponential e**`subject`."""

    def __repr__(self):
        return f"e**({self.subject})"

    @property
    def shape(self):
        return self.subject.shape

    def forward(self):
        return Numeric.exp(self.subject)

    def backward(self, variable=None):
        # d/du e**u = e**u
        d = self.subject.exp()

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return derivative_of(self.subject, variable) * d

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        return gradients(self.subject, loss * self.subject.exp())

    def simplify(self):
        subject = simplify(self.subject)

        if is_zero(subject):
            # e**0 == 1 (note that e**1 == e, not 1, so a subject of one is NOT simplified)
            return 1
        else:
            return Exp(subject)
class LogicalNot(Unary):
    """The logical negation of the `subject`; the result has the same shape as the `subject`."""

    def forward(self):
        return Boolean.logical_not(self.subject)

    def simplify(self):
        return LogicalNot(simplify(self.subject))

    @property
    def shape(self):
        return self.subject.shape

    def __repr__(self):
        return f"NOT ({self.subject})"
class Trig(Unary):
    """Base class of the trigonometric and hyperbolic operators, whose result has the shape of the `subject`."""

    @property
    def shape(self):
        return self.subject.shape

    def simplify(self):
        # no special cases: rebuild the same kind of operator around a simplified subject
        return self.__class__(simplify(self.subject))
class Sin(Trig):
    """The sine of the `subject`."""

    def __repr__(self):
        return f"sin({self.subject})"

    def forward(self):
        return Trigonometric.sin(self.subject)

    def backward(self, variable=None):
        # d/du sin(u) = cos(u)
        d = self.subject.cos()

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly; `gradients` applies the chain rule by recursing into the subject
        return gradients(self.subject, loss * self.subject.cos())
class Cos(Trig):
    """The cosine of the `subject`."""

    def __repr__(self):
        return f"cos({self.subject})"

    def forward(self):
        return Trigonometric.cos(self.subject)

    def backward(self, variable=None):
        # d/du cos(u) = -sin(u) (this is a product with du/dv, not a difference)
        d = -self.subject.sin()

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        return gradients(self.subject, loss * -self.subject.sin())
class Asin(Trig):
    """The arcsine of the `subject`."""

    def __repr__(self):
        return f"asin({self.subject})"

    def forward(self):
        return Trigonometric.asin(self.subject)

    def backward(self, variable=None):
        # d/du asin(u) = (1 - u**2)**-0.5
        d = (1 - (self.subject**2))**-0.5

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly; `gradients` applies the chain rule by recursing into the subject
        return gradients(self.subject, loss * (1 - (self.subject**2))**-0.5)
class Acos(Trig):
    """The arccosine of the `subject`."""

    def __repr__(self):
        return f"acos({self.subject})"

    def forward(self):
        return Trigonometric.acos(self.subject)

    def backward(self, variable=None):
        # d/du acos(u) = -(1 - u**2)**-0.5
        d = -((1 - self.subject**2)**-0.5)

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly; `gradients` applies the chain rule by recursing into the subject
        return gradients(self.subject, loss * -((1 - self.subject**2)**-0.5))
class Sinh(Trig):
    """The hyperbolic sine of the `subject`."""

    def __repr__(self):
        return f"sinh({self.subject})"

    def forward(self):
        return Trigonometric.sinh(self.subject)

    def backward(self, variable=None):
        # d/du sinh(u) = cosh(u)
        d = self.subject.cosh()

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly; `gradients` applies the chain rule by recursing into the subject
        return gradients(self.subject, loss * self.subject.cosh())
class Cosh(Trig):
    """The hyperbolic cosine of the `subject`."""

    def __repr__(self):
        return f"cosh({self.subject})"

    def forward(self):
        return Trigonometric.cosh(self.subject)

    def backward(self, variable=None):
        # d/du cosh(u) = sinh(u)
        d = self.subject.sinh()

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly; `gradients` applies the chain rule by recursing into the subject
        return gradients(self.subject, loss * self.subject.sinh())
class Asinh(Trig):
    """The inverse hyperbolic sine of the `subject`."""

    def __repr__(self):
        return f"asinh({self.subject})"

    def forward(self):
        return Trigonometric.asinh(self.subject)

    def backward(self, variable=None):
        # d/du asinh(u) = (u**2 + 1)**-0.5
        d = (self.subject**2 + 1)**-0.5

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly; `gradients` applies the chain rule by recursing into the subject
        return gradients(self.subject, loss * (self.subject**2 + 1)**-0.5)
class Acosh(Trig):
    """The inverse hyperbolic cosine of the `subject`."""

    def __repr__(self):
        return f"acosh({self.subject})"

    def forward(self):
        return Trigonometric.acosh(self.subject)

    def backward(self, variable=None):
        # d/du acosh(u) = (u**2 - 1)**-0.5
        d = ((self.subject**2) - 1)**-0.5

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly; `gradients` applies the chain rule by recursing into the subject
        return gradients(self.subject, loss * ((self.subject**2) - 1)**-0.5)
class Tan(Trig):
    """The tangent of the `subject`."""

    def __repr__(self):
        return f"tan({self.subject})"

    def forward(self):
        return Trigonometric.tan(self.subject)

    def backward(self, variable=None):
        # d/du tan(u) = 1 / cos(u)**2
        d = 1 / (self.subject.cos()**2)

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly; `gradients` applies the chain rule by recursing into the subject
        return gradients(self.subject, loss * (1 / (self.subject.cos()**2)))
class Tanh(Trig):
    """The hyperbolic tangent of the `subject`."""

    def __repr__(self):
        return f"tanh({self.subject})"

    def forward(self):
        return Trigonometric.tanh(self.subject)

    def backward(self, variable=None):
        # d/du tanh(u) = 1 - tanh(u)**2
        d = 1 - self.subject.tanh()**2

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        # use the local derivative directly; `gradients` applies the chain rule by recursing into the subject
        return gradients(self.subject, loss * (1 - self.subject.tanh()**2))
class Atan(Trig):
    """The arctangent of the `subject`."""

    def __repr__(self):
        return f"atan({self.subject})"

    def forward(self):
        return Trigonometric.atan(self.subject)

    def backward(self, variable=None):
        # d/du atan(u) = 1 / (u**2 + 1)
        d = 1 / (self.subject**2 + 1)

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        return gradients(self.subject, loss * (self.subject**2 + 1)**(-1))
class Atanh(Trig):
    """The inverse hyperbolic tangent of the `subject`."""

    def __repr__(self):
        return f"atanh({self.subject})"

    def forward(self):
        return Trigonometric.atanh(self.subject)

    def backward(self, variable=None):
        # d/du atanh(u) = 1 / (1 - u**2)
        d = 1 / (1 - (self.subject**2))

        if operator(self.subject):
            # chain rule: the subject is itself the output of an Operator
            return d * derivative_of(self.subject, variable)

        # a subject which is not the output of an Operator is treated as the free variable
        return d

    def gradients(self, loss):
        return gradients(self.subject, loss / (1 - self.subject**2))
class Cond(Operator):
    """
    A boolean condition: selects `then` where the `subject` is true and `or_else` elsewhere (via `NDArray.cond`).

    `args` is the pair `(then, or_else)`.
    """

    def __repr__(self):
        then, or_else = self.args
        return f"cond({self.subject}, {then}, {or_else})"

    def forward(self):
        from ..collection.tensor import NDArray
        return NDArray.cond(self.subject, *self.args)

    def backward(self, variable=None):
        # the condition only selects between the branches, so it is not itself differentiated:
        # d/dv cond(c, a, b) = cond(c, da/dv, db/dv)
        then, or_else = self.args
        return self.subject.cond(derivative_of(then, variable), derivative_of(or_else, variable))

    def gradients(self, loss):
        # route the loss to whichever branch was selected; the (boolean) condition itself has no gradient
        then, or_else = self.args
        zero = zeros_like(loss)
        return gradients(then, self.subject.cond(loss, zero)) + gradients(or_else, self.subject.cond(zero, loss))
class Dual(Operator):
    """A differentiable operator with two arguments: a Numeric `subject` and Numeric `args`."""

    def __init__(self, subject, args):
        name = self.__class__.__name__

        if not is_numeric(subject):
            raise ValueError(f"{name} requires a Numeric subject, not {subject}")
        elif not is_numeric(args):
            raise ValueError(f"{name} requires Numeric args, not {args}")

        Operator.__init__(self, subject, args)
class Log(Operator):
    """
    The logarithm of the `subject`.

    NOTE(review): `args` is presumably the logarithm base, but `forward` does not pass it to `Numeric.log`,
    so this computes (and differentiates) the natural logarithm -- confirm.
    """

    def __repr__(self):
        return f"log({self.subject})"

    @property
    def shape(self):
        return self.subject.shape

    def forward(self):
        return Numeric.log(self.subject)

    def backward(self, variable=None):
        if operator(self.subject):
            # chain rule: d/dv ln(u) = (du/dv) / u
            return derivative_of(self.subject, variable) / self.subject

        # a subject which is not the output of an Operator is treated as the free variable: d/du ln(u) = 1 / u
        return 1 / self.subject

    def gradients(self, loss):
        return gradients(self.subject, loss / self.subject)

    def simplify(self):
        subject = simplify(self.subject)
        args = simplify(self.args)
        return Log(subject, args)
class MatMul(Dual):
    """Matrix multiplication of the `subject` by the `args`."""

    def __repr__(self):
        return f"({self.subject}) @ ({self.args})"

    @property
    def shape(self):
        from ..shape import Shape

        # (..., m, k) @ (..., k, n) -> (..., m, n): rows of the subject by columns of the args
        return Shape(self.subject.shape[:-2]) + Shape((self.subject.shape[-2], self.args.shape[-1]))

    def forward(self):
        from ..collection.tensor import NDArray
        return NDArray.__matmul__(self.subject, self.args)

    def backward(self, variable=None):
        # product rule: d(A @ B) = dA @ B + A @ dB
        subject = derivative_of(self.subject, variable, keepdims=True)
        arg = derivative_of(self.args, variable, keepdims=True)
        return (subject @ self.args) + (self.subject @ arg)

    def gradients(self, loss):
        # TODO: don't assume that self.subject.ndim == 2 and self.args.ndim == 2
        return (gradients(self.subject, loss @ self.args.transpose([1, 0])) +
                gradients(self.args, self.subject.transpose([1, 0]) @ loss))

    def simplify(self):
        subject = simplify(self.subject)
        args = simplify(self.args)

        from ..collection.tensor import NDArray

        if is_zero(subject) or is_zero(args):
            return zeros_like(self)
        elif isinstance(subject, NDArray) and isinstance(args, NDArray):
            return MatMul(subject, args)
        else:
            return self
class Pow(Dual):
    """The `subject` raised to the power of the `args` (the exponent)."""

    def __repr__(self):
        return f"({self.subject})**({self.args})"

    @property
    def shape(self):
        return self.subject.shape

    def forward(self):
        return Numeric.pow(self.subject, self.args)

    def backward(self, variable=None):
        subject, exponent = self.subject, self.args

        if same_as(exponent, variable):
            # partial derivative w/r/t the exponent: d/dw u**w = u**w * ln(u)
            return (subject**exponent) * subject.log()

        # power rule: d/du u**w = w * u**(w - 1)
        d = exponent * (subject**(exponent - 1))

        if operator(subject):
            # chain rule through the subject
            d = d * derivative_of(subject, variable)

        if operator(exponent):
            # chain rule through the exponent: add u**w * ln(u) * dw/dv
            # (only for a computed exponent, to avoid taking ln(u) of e.g. a negative base with a constant exponent)
            d = d + (subject**exponent) * subject.log() * derivative_of(exponent, variable)

        return d

    def gradients(self, loss):
        subject_grad = loss * self.args * self.subject**(self.args - 1)
        args_grad = loss * self.subject.log() * self.subject**self.args
        return gradients(self.subject, subject_grad) + gradients(self.args, args_grad)

    def simplify(self):
        subject = simplify(self.subject)
        args = simplify(self.args)

        if is_one(subject) or is_zero(args):
            return 1
        elif is_one(args):
            return subject

        return Pow(subject, args)
class DualBroadcast(Operator):
    """A differentiable operator with two operands whose shapes are broadcast together."""

    @property
    def shape(self):
        subject, args = self.subject, self.args

        # a literal operand does not contribute a shape, so the result takes the shape of the other operand
        if is_literal(subject):
            return args.shape
        elif is_literal(args):
            return subject.shape
        else:
            return subject.shape.broadcast(args.shape)
class LogicalAnd(DualBroadcast):
    """The logical AND of the `subject` and the `args`."""

    def forward(self):
        return Boolean.logical_and(self.subject, self.args)

    def __repr__(self):
        return f"({self.subject}) AND ({self.args})"
class LogicalOr(DualBroadcast):
    """The logical OR of the `subject` and the `args`."""

    def forward(self):
        return Boolean.logical_or(self.subject, self.args)

    def __repr__(self):
        return f"({self.subject}) OR ({self.args})"
class LogicalXor(DualBroadcast):
    """The logical XOR of the `subject` and the `args`."""

    def forward(self):
        return Boolean.logical_xor(self.subject, self.args)

    def __repr__(self):
        return f"({self.subject}) XOR ({self.args})"
class Add(DualBroadcast):
    """The sum of the `subject` and the `args`."""

    def __repr__(self):
        return f"({self.subject}) + ({self.args})"

    def forward(self):
        return Numeric.add(self.subject, self.args)

    def backward(self, variable=None):
        # d(u + w) = du + dw
        return derivative_of(self.subject, variable) + derivative_of(self.args, variable)

    def gradients(self, loss):
        # the loss flows unchanged to both operands
        return gradients(self.subject, loss) + gradients(self.args, loss)

    def simplify(self):
        lhs = simplify(self.subject)
        rhs = simplify(self.args)
        lhs_is_zero = is_zero(lhs)
        rhs_is_zero = is_zero(rhs)

        if lhs_is_zero:
            return 0 if rhs_is_zero else rhs

        if rhs_is_zero:
            return lhs

        return Add(lhs, rhs)
class Mul(DualBroadcast):
    """The product of the `subject` and the `args`."""

    def __repr__(self):
        return f"({self.subject}) * ({self.args})"

    def forward(self):
        return Numeric.mul(self.subject, self.args)

    def backward(self, variable=None):
        # product rule: d(u * w) = du * w + u * dw
        d_subject = derivative_of(self.subject, variable)
        d_args = derivative_of(self.args, variable)
        return (d_subject * self.args) + (self.subject * d_args)

    def gradients(self, loss):
        # each operand's gradient is the loss scaled by the other operand
        return gradients(self.subject, self.args * loss) + gradients(self.args, self.subject * loss)

    def simplify(self):
        lhs = simplify(self.subject)
        rhs = simplify(self.args)

        if is_zero(lhs) or is_zero(rhs):
            return 0

        lhs_is_one = is_one(lhs)
        rhs_is_one = is_one(rhs)

        if lhs_is_one:
            return 1 if rhs_is_one else rhs

        if rhs_is_one:
            return lhs

        return Mul(lhs, rhs)
class Sub(DualBroadcast):
    """The difference of the `subject` minus the `args`."""

    def __repr__(self):
        return f"({self.subject}) - ({self.args})"

    def forward(self):
        return Numeric.sub(self.subject, self.args)

    def backward(self, variable=None):
        # d(u - w) = du - dw
        d_subject = derivative_of(self.subject, variable)
        d_args = derivative_of(self.args, variable)
        return d_subject - d_args

    def gradients(self, loss):
        # the subject receives the loss as-is, the args receive its negation
        return gradients(self.subject, loss) + gradients(self.args, -loss)

    def simplify(self):
        minuend = simplify(self.subject)
        subtrahend = simplify(self.args)

        if is_zero(subtrahend):
            return minuend

        if same_as(minuend, subtrahend):
            return 0

        return Sub(minuend, subtrahend)
class Div(DualBroadcast):
    """The quotient of the `subject` divided by the `args`; a literal zero divisor is rejected at construction."""

    def __init__(self, subject, args):
        if same_as(args, 0):
            raise ValueError(f"cannot divide {subject} by {args}")

        DualBroadcast.__init__(self, subject, args)

    def __repr__(self):
        return f"({self.subject}) / ({self.args})"

    def forward(self):
        return Numeric.div(self.subject, self.args)

    def backward(self, variable=None):
        # quotient rule: d(u / w) = (du * w - u * dw) / w**2
        d_subject = derivative_of(self.subject, variable)
        d_args = derivative_of(self.args, variable)
        return ((d_subject * self.args) - (self.subject * d_args)) / (self.args**2)

    def gradients(self, loss):
        subject_grads = gradients(self.subject, loss / self.args)
        args_grads = gradients(self.args, -self.subject * loss / self.args**2)
        return subject_grads + args_grads

    def simplify(self):
        numerator = simplify(self.subject)
        denominator = simplify(self.args)

        if is_zero(numerator):
            return 0

        if is_one(denominator):
            return numerator

        return Div(numerator, denominator)
def chain_rule(op, *args):
    """
    Compute the chain rule coefficient of the given :class:`Operator` `op` w/r/t its operands `args`.

    Only those `args` which are themselves the output of an :class:`Operator` contribute:
    if none do, this returns `1`; if exactly one does, this returns its derivative; otherwise this returns
    the sum, over each such operand, of the partial derivative of `op` w/r/t it times its own derivative.

    Raises a `ValueError` if `op` is not the output of an :class:`Operator`.
    """
    resolved = operator(op)
    if not resolved:
        raise ValueError(f"cannot apply the chain rule to a constant {op}")

    differentiable = [arg for arg in args if operator(arg)]

    if len(differentiable) == 0:
        return 1

    if len(differentiable) == 1:
        return derivative_of(differentiable[0])

    return sum(derivative_of(resolved, arg) * derivative_of(arg) for arg in differentiable)
def constant(numeric):
    """
    Return the given `numeric` state as a constant, i.e. not the result of a differentiable :class:`Operator`.

    A literal is returned as-is. Otherwise, while the state is the output of an `Operator`, it is replaced by
    the result of that operator's `forward` pass, wrapped in the original type of `numeric`.
    """
    if is_literal(numeric):
        return numeric

    rtype = type(numeric)
    if not is_numeric(numeric):
        raise ValueError(f"a non-numeric state {numeric} (type {rtype}) cannot be a numeric constant")

    op = operator(numeric)
    while op:
        numeric = rtype(form=op.forward())
        op = operator(numeric)

    return numeric
def derivative_of(state_or_function, variable=None, keepdims=False):
    """
    Find the derivative of the given `state_or_function`.

    If `state_or_function` is callable, its `derivative(variable)` method is called and the result returned
    (a `ValueError` is raised if it has no `derivative` attribute).

    Otherwise it is treated as a differentiable state, and its derivative is:
     - one, if it is the `variable` itself;
     - for a :class:`Variable`: one if no `variable` is given, else zero (the variable is held constant);
     - zero, if it is not the output of an :class:`Operator`;
     - otherwise, the result of that operator's `backward(variable)`.

    `keepdims` controls whether constant derivatives are built with `ones_like`/`zeros_like` keeping the dimensions
    of the state; for an operator, this only applies when `backward` returns exactly 0 or 1.

    Raises a `ValueError` if the derivative is not defined.
    """
    if callable(state_or_function):
        if not hasattr(state_or_function, "derivative"):
            raise ValueError(f"not a differentiable function: {state_or_function}")

        return state_or_function.derivative(variable)

    state = state_or_function

    if same_as(state, variable):
        # a partial derivative, and this state is the free variable
        return ones_like(state, keepdims)

    # TODO: it doesn't make sense to import from the ML package in the math package
    from ..ml.variable import Variable

    if isinstance(state, Variable):
        # with no `variable` this is a total derivative; otherwise this Variable is held constant
        return ones_like(state, keepdims) if variable is None else zeros_like(state, keepdims)

    if is_constant(state):
        return zeros_like(state, keepdims)

    op = operator(state)
    if not op:
        raise ValueError(f"the derivative of {state} is not defined")

    d = op.backward(variable)

    if keepdims:
        if same_as(d, 0):
            return zeros_like(state)

        if same_as(d, 1):
            return ones_like(state)

    return d
def _collect_gradients(numeric, loss):
    """Return the :class:`Gradients` of every input reachable from `numeric`, given the `loss`."""
    if is_literal(numeric):
        return Gradients()

    if operator(numeric):
        return operator(numeric).gradients(loss)

    if is_constant(numeric):
        # a leaf of the operator graph receives the loss directly
        return Gradients({numeric: loss})

    if is_numeric(numeric):
        raise ValueError(f"cannot compute gradients of {numeric} w/r/t {loss}")

    raise ValueError(f"not a numeric state: {numeric}")


def gradients(numeric, loss, variables=None):
    """
    Return the gradient of a `numeric` state with respect to the given `loss`.

    If `variables` is `None`, a :class:`Gradients` whose keys are the inputs of the graph is returned.
    If a list or tuple of variables is given, a corresponding list of gradients is returned.
    If any other single variable is given, its gradient is returned.
    A `KeyError` is raised if a requested variable is not reachable in the graph.
    """
    grads = _collect_gradients(numeric, loss)

    if variables is None:
        return grads

    if isinstance(variables, (list, tuple)):
        missing = [var for var in variables if var not in grads]
        if missing:
            raise KeyError(f"not reachable by traversing the operator graph {numeric}: {missing}")

        return [grads[var] for var in variables]

    if variables not in grads:
        raise KeyError(f"{variables} is not reachable from operator {numeric}")

    return grads[variables]
def is_constant(numeric):
    """
    Return `True` if the given `numeric` state is not produced by any :class:`Operator`.

    In other words, return `False` if it is the result of a differentiable function.
    """
    producer = operator(numeric)
    return producer is None
def is_one(numeric):
    """
    Return `True` if the given `numeric` state is a constant with value one.

    Any chain of `Transform` operators is followed back to its underlying subject before checking.
    """
    if same_as(numeric, 1):
        return True

    from ..collection.tensor import Dense, NDArray, Transform

    op = operator(numeric)
    while isinstance(op, Transform):
        numeric = op.subject
        op = operator(numeric)

    if same_as(numeric, 1):
        return True

    return bool(isinstance(numeric, NDArray) and same_as(numeric, Dense.ones_like(numeric)))
def is_zero(numeric):
    """
    Return `True` if the given `numeric` state is a constant with value zero.

    Any chain of `Transform` operators is followed back to its underlying subject before checking.
    """
    if same_as(numeric, 0):
        return True

    from ..collection.tensor import Sparse, NDArray, Transform

    op = operator(numeric)
    while isinstance(op, Transform):
        numeric = op.subject
        op = operator(numeric)

    if same_as(numeric, 0):
        return True

    return bool(isinstance(numeric, NDArray) and same_as(numeric, Sparse.zeros_like(numeric)))
def operator(state_or_ref):
    """
    Return the `Operator` instance which produces the given `state_or_ref`, if any.

    References are followed with `deref` until an `Operator` is found; `None` is returned
    once `deref` no longer changes the value.
    """
    current = state_or_ref

    while not isinstance(current, Operator):
        resolved = deref(current)
        if resolved is current:
            return None

        current = resolved

    return current
def simplify(state):
    """
    Simplify the given operator graph, if possible.

    For example, `simplify(Add(0, 2))` will return `2`.

    The producing operator's `simplify` is applied repeatedly until the result stops changing. A non-literal
    result is wrapped in the original type of `state`. Raises a `TypeError` for a non-numeric `state`.
    """
    if is_literal(state):
        return state

    if not is_numeric(state):
        raise TypeError(f"cannot simplify a non-numeric state: {state}")

    rtype = type(state)

    op = operator(state)
    while op:
        candidate = op.simplify()
        if same_as(candidate, state):
            break

        state = candidate
        op = operator(state)

    return state if is_literal(state) else rtype(form=state)
def ones_like(numeric, keepdims=True):
    """
    Construct a constant with each element equal to one, with the same shape and data type as `numeric`.

    A `Number` yields a one of the same `Number` type; a literal (or any state when `keepdims` is false) yields
    `Number(1)`; otherwise `ones_like` of the state's own type is used, falling back to `Dense.ones_like`.
    """
    from ..collection.tensor import Dense
    from ..scalar.number import Number

    if isinstance(numeric, Number):
        return type(numeric)(form=1)

    if is_literal(numeric) or not keepdims:
        return Number(1)

    numeric_type = type(numeric)
    factory = numeric_type if hasattr(numeric_type, "ones_like") else Dense
    return factory.ones_like(numeric)
def zeros_like(state, keepdims=True):
    """
    Construct a constant with each element equal to zero, with the same shape and data type as `state`.

    A `Number` yields a zero of the same `Number` type; a literal (or any state when `keepdims` is false) yields
    `Number(0)`; otherwise `zeros_like` of the state's own type is used, falling back to `Sparse.zeros_like`.
    """
    from ..collection.tensor import Sparse
    from ..scalar.number import Number

    if isinstance(state, Number):
        return type(state)(form=0)

    if is_literal(state) or not keepdims:
        return Number(0)

    state_type = type(state)
    factory = state_type if hasattr(state_type, "zeros_like") else Sparse
    return factory.zeros_like(state)
def _debug_shape(numeric):
    """
    Debugging aid: print the shape of `numeric` if it is a literal.

    Otherwise, report that it is not a literal and recurse into whichever operands of its producing
    `Operator` are `NDArray`s. Raises a `ValueError` if `numeric` has no `shape` attribute.
    """
    if not hasattr(numeric, "shape"):
        raise ValueError(f"{numeric} has no shape")

    if is_literal(numeric.shape):
        print(f"the shape of {numeric} is {numeric.shape}")
        return

    print(f"{numeric} does not have a literal shape")

    op = operator(numeric)
    if op:
        from ..collection.tensor import NDArray

        for operand in (op.subject, op.args):
            if isinstance(operand, NDArray):
                _debug_shape(operand)