import typing
from ..service import library_uri, Library
from ..collection.tensor import einsum, Dense, Sparse, Tensor
from ..decorators import closure, get as get_op, post
from ..error import BadRequest
from ..generic import Map, Tuple
from ..scalar.number import Number, UInt, F32, Int
from ..scalar.ref import after, cond, while_loop, Get
from ..scalar.value import Value, Version
from ..uri import URI
from .base import product
from .constants import NS
# from "Numerical Recipes in C" p. 65
EPS = F32(10**-6)
[docs]def diagonal(tensor):
"""Return the diagonal of the given `tensor` of matrices"""
rtype = type(tensor) if isinstance(tensor, Tensor) else Tensor
op = Get(URI(tensor, "diagonal"))
return rtype(form=op)
[docs]def with_diagonal(matrix, diag):
"""Construct a view of the given `matrix` with the diagonal set to `diag`."""
eye = Sparse.eye(matrix.shape[-2])
return eye.cond(diag, matrix)
# TODO: replace this helper class with a `typing.TypedDict`
[docs]class PLUFactorization(Map):
"""PLU factorization of a given `[N, N]` matrix."""
@property
def p(self) -> Tensor:
"""Permutation matrix as an `[N, N]` `Tensor`"""
return Tensor(self['p'])
@property
def l(self) -> Tensor:
"""Lower-triangular matrix as an `[N, N]` `Tensor`"""
return Tensor(self['l'])
@property
def u(self) -> Tensor:
"""Upper-triangular matrix as an `[N, N]` `Tensor`"""
return Tensor(self['u'])
@property
def num_permutations(self) -> UInt:
"""The number of permutations calculated during the factorization"""
return UInt(self['num_permutations'])
# TODO: replace Tuple.range with a Stream after re-implementing Stream
[docs]class LinearAlgebra(Library):
NAME = "linalg"
VERSION = Version("0.0.0")
__uri__ = library_uri(None, NS, NAME, VERSION)
# TODO: vectorize to support a `Tensor` containing a batch of matrices
@post
def householder(self, cxt, x: Tensor) -> typing.Tuple[Tensor, Tensor]:
"""Compute the Householder vector of the given column vector `a`."""
cxt.alpha = x[0]
cxt.s = cond(x.shape[0] > 1, (x[1:]**2).sum(), 0.0)
cxt.t = (cxt.alpha**2 + cxt.s)**0.5
cxt.v = x.copy()
cxt.v_zero = cond(cxt.alpha <= 0, cxt.alpha - cxt.t, -cxt.s / (cxt.alpha + cxt.t))
tau = cond(cxt.s.abs() < EPS, 0, 2 * cxt.v_zero**2 / (cxt.s + cxt.v_zero**2))
v = cond(Int(x.shape[0]) > 1, after(cxt.v[0].write(cxt.v_zero), cxt.v / cxt.v_zero), cxt.v)
return v, tau
# TODO: vectorize to support a `Tensor` containing a batch of matrices
@post
def qr(self, txn, a: Tensor) -> typing.Tuple[Tensor, Tensor]:
"""Compute the QR decomposition of the given matrix `a`"""
txn.shape = a.shape
txn.n, txn.m = txn.shape.unpack(2)
txn.q_init = Dense.zeros([txn.n, txn.n])
txn.u_init = Dense.zeros([txn.n, txn.n])
txn.u = after(txn.u_init[:, 0].write(a[:, 0]), txn.u_init).copy()
txn.q = after(
after(txn.u, txn.q_init[:, 0].write(txn.u_init[:, 0] / norm(txn.u_init[:, 0]))),
txn.q_init
).copy()
@closure(a, txn.q, txn.u)
@get_op
def q_step(cxt, i: UInt) -> Tensor:
@closure(a, i, txn.q, txn.u)
@get_op
def u_step(j: UInt) -> Map:
col = txn.u[:, i].copy() - (txn.q[:, j] * (a[:, i] * txn.q[:, j]).sum())
return txn.u[:, i].write(col)
cxt.update_u = after(txn.u[:, i].write(a[:, i]), Tuple.range(i).for_each(u_step))
cxt.update_q = after(
cxt.update_u,
txn.q[:, i].write(txn.u[:, i] / norm(txn.u[:, i])))
return after(cxt.update_q, {})
n = cond(txn.n <= txn.m, txn.n, txn.m)
txn.update_q = Tuple.range((1, n)).for_each(q_step)
txn._q = Tensor(after(txn.update_q, txn.q))
txn._r = Dense.zeros([txn.n, txn.m])
@closure(txn._r, a, txn._q, txn.m)
@get_op
def r_step(i: UInt):
@closure(txn._r, a, txn._q, i, txn.m)
@get_op
def r_step_inner(j: UInt):
return txn._r[i, j].write((a[:, j] * txn._q[:, i]).sum())
return Tuple.range((i, txn.m)).for_each(r_step_inner)
return after(Tuple.range(txn.n).for_each(r_step), (txn._q, txn._r))
@post
def plu(self, txn, x: Tensor) -> PLUFactorization:
"""Compute the PLU factorization of the given matrix `x`.
Args:
`x`: a matrix with shape `[N, N]`
Returns `(p, l, u)` where
`p` is the permutation matrix,
`l` is lower triangular with unit diagonal elements, and
`u` is upper triangular.
"""
# TODO: use a TypedDict as the return annotation
@post
def permute_rows(x: Tensor, p: Tensor, start_from: UInt) -> Map:
@closure(start_from)
@post
def step(p: Tensor, x: Tensor, k: UInt) -> Map:
p_k, p_kp1 = p[start_from].copy(), p[k + 1].copy()
x_k, x_kp1 = x[start_from].copy(), x[k + 1].copy()
return after(
[
p[start_from].write(p_kp1),
p[k + 1].write(p_k),
x[start_from].write(x_kp1),
x[k + 1].write(x_k),
],
{'p': p, 'x': x, 'k': k + 1}
)
@post
def while_cond(cxt, x: Tensor, k: UInt):
cxt.valid_k = k < (x.shape[0] - 1)
cxt.valid_x_k_k = x[k, k].abs() < 1e-3
return cxt.valid_k.logical_and(cxt.valid_x_k_k)
return while_loop(while_cond, step, {
'p': p.copy(),
'x': x.copy(),
'k': start_from
})
txn.permute_rows = permute_rows
@closure(txn.permute_rows)
@post
def step(p: Tensor, l: Tensor, u: Tensor, i: UInt, num_permutations: UInt) -> Map:
pu = txn.permute_rows(p=p, x=u, start_from=i)
u = Tensor(pu['x'])
p = Tensor(pu['p'])
n = UInt(pu['k']) - i
factor = Tensor(u[i + 1:, i] / u[i, i])
return after(
when=[
l[i + 1:, i].write(factor),
u[i + 1:].write(u[i + 1:] - factor.expand_dims() * u[i]),
],
then=Map(p=p, l=l, u=u, i=i + 1, num_permutations=num_permutations + n))
@post
def factor_cond(u: Tensor, i: UInt):
return i < UInt(u.shape[0]) - 1
txn.factorization = while_loop(factor_cond, step, {
'p': Sparse.eye(x.shape[0]).cast(F32).as_dense().copy(),
'l': Sparse.eye(x.shape[0]).cast(F32).as_dense().copy(),
'u': x.copy(),
'i': 0,
"num_permutations": 0,
})
return cond(
_is_square(x),
txn.factorization,
BadRequest("PLU decomposition requires a square matrix, not {{x}}", x=x))
@post
def det(self, x: Tensor) -> F32:
"""Computes the determinant of a square matrix `x`.
Args:
`x`: a matrix with shape `[N, N]`
Returns:
The determinant for `x`
"""
plu_result = self.plu(x=x)
sign = Int(-1).pow(plu_result.num_permutations)
determinant = diagonal(plu_result.u).product() * sign
return cond(
_is_square(x),
determinant,
BadRequest("determinant requires a square matrix, not {{x}}", x=x))
@post
def slogdet(self, cxt, x: Dense) -> typing.Tuple[Tensor, Tensor]:
"""
Compute the sign and log of the absolute value of the determinant of one or more square matrices.
Args:
`x`: a `Tensor` of square matrices with shape `[..., M, M]`
Returns:
`(sign, logdet)` where:
`sign` is a `Tensor` of signs of determinants `{-1, +1}` with shape `[...]`
`logdet` is a `Tensor` of the natural log of the absolute values of determinants with shape `[...]`
"""
cxt.batch_shape = x.shape[:-2]
cxt.batch_size = product(cxt.batch_shape)
cxt.sign_result = Dense.create([cxt.batch_size])
cxt.logdet_result = Dense.create([cxt.batch_size])
cxt.copy = x.reshape(Tuple([cxt.batch_size]) + x.shape[-2:]).copy()
@closure(cxt.copy, cxt.sign_result, cxt.logdet_result)
@get_op
def step(i: UInt):
d = self.det(x=cxt.copy[i])
logdet = F32(d.abs().log())
sign = cond(d > 0, 1, -1) * 1
return [
cxt.sign_result[i].write(sign),
cxt.logdet_result[i].write(logdet),
]
sign, determinants = after(
Tuple.range((0, cxt.batch_size)).for_each(step),
[cxt.sign_result, cxt.logdet_result])
return Tensor(sign).reshape(cxt.batch_shape), Tensor(determinants).reshape(cxt.batch_shape)
@post
def svd_matrix(self, cxt, A: Tensor, l=UInt(0), epsilon=EPS, max_iter=UInt(30)) -> typing.Tuple[Tensor, Tensor, Tensor]:
"""
Compute the singular value decomposition of the given matrix `A`
Returns:
`(U, s, V)`: :class:`Tensor` s such that `A` ~= `u * (identity([P, P]) * s) * v`, where `P = min(N, M)`.
"""
cxt.shape = A.shape
cxt.n_orig, cxt.m_orig = cxt.shape.unpack(2)
k = cond(l == 0, Value.min(cxt.n_orig, cxt.m_orig), l)
cxt.A_orig = A.copy()
cxt.A1, n, m = cond(
cxt.n_orig > cxt.m_orig,
[A.transpose() @ A, Tensor(A).shape[1], Tensor(A).shape[1]],
cond(
cxt.n_orig < cxt.m_orig,
[A @ Tensor(A).transpose(), A.shape[0], A.shape[0]],
[A, cxt.n_orig, cxt.m_orig]
),
).unpack(3) # TODO: this call to `unpack` should not be necessary
Q, R = self.qr(a=Dense.random_uniform([n, k]).abs())
@closure(cxt.A1)
@post
def step(i: UInt, Q_prev: Tensor, Q: Tensor):
Z = Tensor(cxt.A1) @ Q
_Q, _R = self.qr(a=Z)
_err = _Q.sub(Q_prev).pow(2).sum()
_Q_prev = _Q.copy()
return Map(i=i + 1, Q_prev=_Q_prev, Q=_Q, R=_R, err=_err)
cxt.step = step
@closure(epsilon, max_iter)
@post
def while_cond(i: UInt, err: F32):
return (F32(err).abs() > epsilon).logical_and(i < max_iter)
cxt.cond = while_cond
result_loop = while_loop(cxt.cond, cxt.step, Map(
i=UInt(0),
Q_prev=Tensor(Q).copy(),
Q=Tensor(Q).copy(),
R=Tensor(R),
err=F32(1.0)))
Q, R = Tensor(result_loop['Q']), Tensor(result_loop['R'])
singular_values = diagonal(R).pow(0.5)
cxt.eye = Sparse.eye(singular_values.shape[0]).as_dense().copy()
cxt.inv_matrix = (cxt.eye * singular_values.pow(-1))
cxt.Q_T = Q.transpose()
cxt.vec_sing_values_upd = cond(
cxt.n_orig == cxt.m_orig,
Map(left_vecs=cxt.Q_T, right_vecs=cxt.Q_T, singular_values=singular_values.pow(2)),
Map(
left_vecs=einsum('ij,jk->ik', [einsum('ij,jk->ik', [cxt.A_orig, Q]), cxt.inv_matrix]),
right_vecs=cxt.Q_T,
singular_values=singular_values))
vec_sing_values = cond(
cxt.n_orig < cxt.m_orig,
Map(
left_vecs=cxt.Q_T,
right_vecs=einsum('ij,jk->ik', [einsum('ij,jk->ik', [cxt.inv_matrix, Q]), cxt.A_orig]),
singular_values=singular_values),
cxt.vec_sing_values_upd)
return vec_sing_values['left_vecs'], vec_sing_values['singular_values'], vec_sing_values['right_vecs']
# TODO: update to support `Tensor` (not just `Dense`) after `Sparse.concatenate` is implemented
@post
def svd_parallel(self, txn, A: Tensor, l=UInt(0), epsilon=EPS, max_iter=UInt(30)) -> typing.Tuple[Tensor, Tensor, Tensor]:
"""
Given a `Tensor` of `matrices`, return the singular value decomposition `(s, u, v)` of each matrix.
Currently only implemented for `Dense` matrices.
"""
txn.N, txn.M = A.shape[-2:].unpack(2)
txn.batch_shape = A.shape[:-2]
txn.num_matrices = product(txn.batch_shape)
txn.matrices = A.reshape([txn.num_matrices, txn.N, txn.M]).copy()
@closure(txn.matrices, l, epsilon, max_iter)
@get_op
def matrix_svd(i: UInt) -> typing.Tuple[Tensor, Tensor, Tensor]:
return self.svd_matrix(A=txn.matrices[i], l=l, epsilon=epsilon, max_iter=max_iter)
txn.indices = Tuple.range(txn.num_matrices)
txn.UsV_tuples = txn.indices.map(matrix_svd)
def getter(j):
@closure(txn.UsV_tuples)
@get_op
def getter(i: UInt) -> Dense:
tensor = Tensor(Tuple(txn.UsV_tuples[i])[j])
return tensor.expand_dims(0)
return getter
txn.U = Dense.concatenate(txn.indices.map(getter(0)), axis=0)
txn.s = Dense.concatenate(txn.indices.map(getter(1)), axis=0)
txn.V = Dense.concatenate(txn.indices.map(getter(2)), axis=0)
return (
txn.U.reshape(Tuple.concatenate(txn.batch_shape, txn.U.shape[1:])).copy(),
txn.s.reshape(Tuple.concatenate(txn.batch_shape, [Number.min(txn.N, txn.M)])).copy(),
txn.V.reshape(Tuple.concatenate(txn.batch_shape, txn.V.shape[1:])).copy(),
)
@post
def svd(self, A: Tensor, l=UInt(0), epsilon=EPS, max_iter=UInt(30)) -> typing.Tuple[Tensor, Tensor, Tensor]:
"""
Computes `svd_matrix` for each matrix in `A`.
For `A` with shape `[..., N, M]`, `svd` returns a tuple `(U, s, V)` such that
`A[...]` ~= `u[...] * (identity([P, P]) * s[...]) * v[...]`, where `P = min(N, M)`.
"""
return cond(
A.ndim == 2,
self.svd_matrix(A=A, l=l, epsilon=epsilon, max_iter=max_iter),
self.svd_parallel(A=A, l=l, epsilon=epsilon, max_iter=max_iter))
[docs]def norm(x):
"""Helper function to compute the norm of a matrix `x`"""
return (x**2).sum()**0.5
def _is_square(x):
return (x.ndim == 2).logical_and(x.shape[0] == x.shape[1])