from __future__ import annotations
from typing import TYPE_CHECKING, Type
from ..error import BadRequest
from ..generic import Map, Tuple
from ..json import to_json
from ..scalar.bound import Range
from ..scalar.number import Bool, UInt
from ..scalar.ref import If, Ref, form_of
from ..state import State
from ..uri import URI
from .base import Collection, Column
if TYPE_CHECKING:
from ..service import Model
class Schema(object):
    """
    A `Table` schema which comprises a primary key and value :class:`Column` s.

    `key` and `values` are concatenated by :meth:`columns`, so they are expected to be
    sequences of the same type (e.g. both lists).
    """

    def __init__(self, key, values=None):
        """Create a schema with the given `key` columns, `values` columns (empty by default), and no indices."""

        self.key = key
        # a fresh list per instance, so that schemas created without `values` never share one default list
        self.values = [] if values is None else values
        self.indices = []

    def __json__(self):
        # encoded as [[key, values], indices]
        return to_json([[self.key, self.values], Tuple(self.indices)])

    def columns(self):
        """Return the key columns followed by the value columns."""

        return self.key + self.values

    def create_index(self, name, columns):
        """Add an index called `name` over the given `columns`, then return this schema to allow chaining."""

        self.indices.append((name, columns))
        return self
class Table(Collection):
    """A `Table` defined by a primary key, values, and optional indices."""

    __uri__ = URI(Collection) + "/table"

    def __getitem__(self, key):
        """Return the row with the given key, or a :class:`NotFound` error."""

        return self._get("", key, rtype=Map)

    # TODO: re-enable this functionality after implementing Graph slicing
    def aggregate(self, columns, fn):
        """
        Apply the given callback to slices of this `Table` grouped by the given columns.

        Returns a stream of tuples of the form (<unique column values>, <callback result>).

        Example: `orders.aggregate(["customer_id", "product_id"], Table.count)`

        Currently disabled: always raises `NotImplementedError`.
        """

        raise NotImplementedError("Table.aggregate has been temporarily disabled")

    def contains(self, key):
        """Return `True` if this `Table` contains the given key."""

        return self._get("contains", key, rtype=Bool)

    def columns(self):
        """Return the column schema of this `Table` as a :class:`Tuple`."""

        return self._get("columns", rtype=Tuple)

    def count(self):
        """Return the number of rows in the given slice of this `Table`."""

        return self._get("count", rtype=UInt)

    def delete(self, key):
        """Delete the row of this `Table` with the given `key`."""

        return self._delete("", key)

    def insert(self, key, values=None):
        """
        Insert the given row into this `Table`.

        `values` defaults to an empty list.

        If the key is already present, this will raise a :class:`BadRequest` error.
        """

        # avoid a shared mutable default argument
        if values is None:
            values = []

        return If(
            self.contains(key),
            BadRequest("cannot insert: key already exists"),
            self.upsert(key, values))

    def is_empty(self):
        """Return `True` if this table contains no rows."""

        return self._get("is_empty", rtype=Bool)

    def key_columns(self):
        """Return the schema of the key columns of this `Table`."""

        return self._get("key_columns", rtype=Tuple)

    def key_names(self):
        """Return the `Id` s of the key columns of this `Table`."""

        return self._get("key_names", rtype=Tuple)

    def limit(self, limit):
        """Limit the number of rows returned from this `Table`."""

        return self._get("limit", limit, rtype=Table)

    def order_by(self, columns, reverse=False):
        """
        Set the order in which this `Table`'s rows will be iterated over.

        If no index supports the given order, this will raise a :class:`BadRequest` error.
        """

        return self._get("order", (columns, reverse), rtype=Table)

    def select(self, columns):
        """Return a `Table` containing only the specified columns."""

        return self._get("select", columns, rtype=Table)

    def truncate(self):
        """Delete all rows in this :class:`Table`."""

        return self.delete("")

    def update(self, **values):
        """Update the rows of this table with the given `values`."""

        # NOTE(review): unlike `upsert` (and the sliced `update` returned by `where`), this passes `values`
        # as the second positional argument of `_put`, i.e. the position `upsert` uses for the key.
        # Confirm this is intended against `Collection._put`.
        return self._put("", values)

    def upsert(self, key, values):
        """
        Insert the given row into this `Table`.

        If the row is already present, it will be updated with the given `values`.
        """

        return self._put("", key, values)

    def where(self, **bounds):
        """
        Return a slice of this `Table` whose column values fall within the specified range.

        A `slice` bound is converted to a :class:`Range`; calling with no bounds returns this `Table` itself.
        The returned slice supports `update` and `truncate`, which apply to the matching rows of this `Table`.
        Calling `delete` or `upsert` on the slice raises a `RuntimeError`.

        If there is no index which supports the given range, this will raise a :class:`BadRequest` error.
        """

        if not bounds:
            return self

        parent = self
        bounds = handle_bounds(bounds)  # `bounds` came from **kwargs, so this is a dict

        class WriteableView(Table):
            def delete(self, key):
                # a single row cannot be addressed by key within a slice
                raise RuntimeError(f"cannot delete the row at {key} from a slice {self} of a table {parent}")

            def update(self, **values):
                return parent._put("", list(bounds.items()), values)

            def upsert(self, key, values):
                raise RuntimeError(f"cannot upsert ({key}, {values}) into a slice {self} of a table {parent}")

            def truncate(self):
                return parent._delete("", list(bounds.items()))

        return self._get("", list(bounds.items()), rtype=WriteableView)
def handle_bounds(bounds):
    """
    Normalize the `bounds` given to `Table.where`.

    - `None` becomes an empty dict.
    - A :class:`State` is replaced by its form (via `form_of`), which is then normalized recursively.
    - A :class:`Ref` or :class:`URI` is returned unchanged.
    - Otherwise `bounds` is treated as a mapping of column name to bound. The result is a new dict in which
      every Python `slice` bound is converted with `Range.from_slice`, and other bounds are kept as-is.
    """

    if bounds is None:
        return {}

    if isinstance(bounds, State):
        return handle_bounds(form_of(bounds))

    if isinstance(bounds, (Ref, URI)):
        return bounds

    normalized = {}
    for column in bounds:
        bound = bounds[column]
        normalized[column] = Range.from_slice(bound) if isinstance(bound, slice) else bound

    return normalized
# TODO: move to the graph package
def create_schema(modelclass: Type[Model]) -> Schema:
    """
    Create a table schema for the given model.

    The key columns are those returned by `modelclass.key()`. Each attribute defined on `modelclass` that is not
    present on any of its direct base classes is considered as a value column:

    - a :class:`Column` attribute is used as-is;
    - a :class:`Model` subclass attribute contributes its single key column, and an index named
      `class_name(<that model>)` is created over that column.

    A referenced model whose key is not exactly one column is skipped. Value columns are ordered by attribute
    name, so the resulting schema is the same in every interpreter process.
    """

    # imported here rather than at module level (`..service` is only imported under TYPE_CHECKING at the top),
    # presumably to avoid a circular import
    from ..service import Model, class_name

    base_attributes = set()
    for base in modelclass.__bases__:
        base_attributes.update(dir(base))

    values = []
    indices = []

    # Sort the names: iterating a set of strings directly yields an order that changes between runs
    # (str hashing is randomized per process), and the order of the value columns is part of the schema.
    for name in sorted(set(dir(modelclass)) - base_attributes):
        attr = getattr(modelclass, name)

        if isinstance(attr, Column):
            values.append(attr)
        elif isinstance(attr, type) and issubclass(attr, Model):
            key = attr.key()
            if len(key) != 1:
                # a composite (or empty) key cannot be represented as a single reference column
                continue

            values.append(key[0])
            indices.append((class_name(attr), [key[0].name]))

    schema = Schema(modelclass.key(), values)
    for index_name, index_columns in indices:
        schema.create_index(index_name, index_columns)

    return schema