339 lines
7.9 KiB
Python
339 lines
7.9 KiB
Python
"""
|
|
.. module: lemur.database
|
|
:platform: Unix
|
|
:synopsis: This module contains all of the database related methods
|
|
needed for lemur to interact with a datastore
|
|
|
|
:copyright: (c) 2018 by Netflix Inc., see AUTHORS for more
|
|
:license: Apache, see LICENSE for more details.
|
|
|
|
.. moduleauthor:: Kevin Glisson <kglisson@netflix.com>
|
|
"""
|
|
from inflection import underscore
|
|
from sqlalchemy import exc, func, distinct
|
|
from sqlalchemy.orm import make_transient, lazyload
|
|
from sqlalchemy.sql import and_, or_
|
|
|
|
from lemur.exceptions import AttrNotFound, DuplicateError
|
|
from lemur.extensions import db
|
|
|
|
|
|
def filter_none(kwargs):
|
|
"""
|
|
Remove all `None` values froma given dict. SQLAlchemy does not
|
|
like to have values that are None passed to it.
|
|
|
|
:param kwargs: Dict to filter
|
|
:return: Dict without any 'None' values
|
|
"""
|
|
n_kwargs = {}
|
|
for k, v in kwargs.items():
|
|
if v:
|
|
n_kwargs[k] = v
|
|
return n_kwargs
|
|
|
|
|
|
def session_query(model):
|
|
"""
|
|
Returns a SQLAlchemy query object for the specified `model`.
|
|
|
|
If `model` has a ``query`` attribute already, that object will be returned.
|
|
Otherwise a query will be created and returned based on `session`.
|
|
|
|
:param model: sqlalchemy model
|
|
:return: query object for model
|
|
"""
|
|
return model.query if hasattr(model, "query") else db.session.query(model)
|
|
|
|
|
|
def create_query(model, kwargs):
|
|
"""
|
|
Returns a SQLAlchemy query object for specified `model`. Model
|
|
filtered by the kwargs passed.
|
|
|
|
:param model:
|
|
:param kwargs:
|
|
:return:
|
|
"""
|
|
s = session_query(model)
|
|
return s.filter_by(**kwargs)
|
|
|
|
|
|
def commit():
|
|
"""
|
|
Helper to commit the current session.
|
|
"""
|
|
db.session.commit()
|
|
|
|
|
|
def add(model):
|
|
"""
|
|
Helper to add a `model` to the current session.
|
|
|
|
:param model:
|
|
:return:
|
|
"""
|
|
db.session.add(model)
|
|
|
|
|
|
def get_model_column(model, field):
|
|
if field in getattr(model, "sensitive_fields", ()):
|
|
raise AttrNotFound(field)
|
|
column = model.__table__.columns._data.get(field, None)
|
|
if column is None:
|
|
raise AttrNotFound(field)
|
|
|
|
return column
|
|
|
|
|
|
def find_all(query, model, kwargs):
|
|
"""
|
|
Returns a query object that ensures that all kwargs
|
|
are present.
|
|
|
|
:param query:
|
|
:param model:
|
|
:param kwargs:
|
|
:return:
|
|
"""
|
|
conditions = []
|
|
kwargs = filter_none(kwargs)
|
|
for attr, value in kwargs.items():
|
|
if not isinstance(value, list):
|
|
value = value.split(",")
|
|
|
|
conditions.append(get_model_column(model, attr).in_(value))
|
|
|
|
return query.filter(and_(*conditions))
|
|
|
|
|
|
def find_any(query, model, kwargs):
|
|
"""
|
|
Returns a query object that allows any kwarg
|
|
to be present.
|
|
|
|
:param query:
|
|
:param model:
|
|
:param kwargs:
|
|
:return:
|
|
"""
|
|
or_args = []
|
|
for attr, value in kwargs.items():
|
|
or_args.append(or_(get_model_column(model, attr) == value))
|
|
exprs = or_(*or_args)
|
|
return query.filter(exprs)
|
|
|
|
|
|
def get(model, value, field="id"):
|
|
"""
|
|
Returns one object filtered by the field and value.
|
|
|
|
:param model:
|
|
:param value:
|
|
:param field:
|
|
:return:
|
|
"""
|
|
query = session_query(model)
|
|
return query.filter(get_model_column(model, field) == value).scalar()
|
|
|
|
|
|
def get_all(model, value, field="id"):
|
|
"""
|
|
Returns query object with the fields and value filtered.
|
|
|
|
:param model:
|
|
:param value:
|
|
:param field:
|
|
:return:
|
|
"""
|
|
query = session_query(model)
|
|
return query.filter(get_model_column(model, field) == value)
|
|
|
|
|
|
def create(model):
|
|
"""
|
|
Helper that attempts to create a new instance of an object.
|
|
|
|
:param model:
|
|
:return: :raise IntegrityError:
|
|
"""
|
|
try:
|
|
db.session.add(model)
|
|
commit()
|
|
except exc.IntegrityError as e:
|
|
raise DuplicateError(e.orig.diag.message_detail)
|
|
|
|
db.session.refresh(model)
|
|
return model
|
|
|
|
|
|
def update(model):
|
|
"""
|
|
Helper that attempts to update a model.
|
|
|
|
:param model:
|
|
:return:
|
|
"""
|
|
commit()
|
|
db.session.refresh(model)
|
|
return model
|
|
|
|
|
|
def delete(model):
|
|
"""
|
|
Helper that attempts to delete a model.
|
|
|
|
:param model:
|
|
"""
|
|
if model:
|
|
db.session.delete(model)
|
|
db.session.commit()
|
|
|
|
|
|
def filter(query, model, terms):
|
|
"""
|
|
Helper that searched for 'like' strings in column values.
|
|
|
|
:param query:
|
|
:param model:
|
|
:param terms:
|
|
:return:
|
|
"""
|
|
column = get_model_column(model, underscore(terms[0]))
|
|
return query.filter(column.ilike("%{}%".format(terms[1])))
|
|
|
|
|
|
def sort(query, model, field, direction):
|
|
"""
|
|
Returns objects of the specified `model` in the field and direction
|
|
given
|
|
|
|
:param query:
|
|
:param model:
|
|
:param field:
|
|
:param direction:
|
|
"""
|
|
column = get_model_column(model, underscore(field))
|
|
return query.order_by(column.desc() if direction == "desc" else column.asc())
|
|
|
|
|
|
def paginate(query, page, count):
|
|
"""
|
|
Returns the items given the count and page specified
|
|
|
|
:param query:
|
|
:param page:
|
|
:param count:
|
|
"""
|
|
total = get_count(query)
|
|
items = query.paginate(page, count).items
|
|
return dict(items=items, total=total, current=len(items))
|
|
|
|
|
|
def update_list(model, model_attr, item_model, items):
|
|
"""
|
|
Helper that correctly updates a models items
|
|
depending on what has changed
|
|
|
|
:param model_attr:
|
|
:param item_model:
|
|
:param items:
|
|
:param model:
|
|
:return:
|
|
"""
|
|
ids = []
|
|
|
|
for i in getattr(model, model_attr):
|
|
if i.id not in ids:
|
|
getattr(model, model_attr).remove(i)
|
|
|
|
for i in items:
|
|
for item in getattr(model, model_attr):
|
|
if item.id == i["id"]:
|
|
break
|
|
else:
|
|
getattr(model, model_attr).append(get(item_model, i["id"]))
|
|
|
|
return model
|
|
|
|
|
|
def clone(model):
|
|
"""
|
|
Clones the given model and removes it's primary key
|
|
:param model:
|
|
:return:
|
|
"""
|
|
db.session.expunge(model)
|
|
make_transient(model)
|
|
model.id = None
|
|
return model
|
|
|
|
|
|
def get_count(q):
|
|
"""
|
|
Count the number of rows in a table. More efficient than count(*)
|
|
:param q:
|
|
:return:
|
|
"""
|
|
disable_group_by = False
|
|
if len(q._entities) > 1:
|
|
# currently support only one entity
|
|
raise Exception("only one entity is supported for get_count, got: %s" % q)
|
|
entity = q._entities[0]
|
|
if hasattr(entity, "column"):
|
|
# _ColumnEntity has column attr - on case: query(Model.column)...
|
|
col = entity.column
|
|
if q._group_by and q._distinct:
|
|
# which query can have both?
|
|
raise NotImplementedError
|
|
if q._group_by or q._distinct:
|
|
col = distinct(col)
|
|
if q._group_by:
|
|
# need to disable group_by and enable distinct - we can do this because we have only 1 entity
|
|
disable_group_by = True
|
|
count_func = func.count(col)
|
|
else:
|
|
# _MapperEntity doesn't have column attr - on case: query(Model)...
|
|
count_func = func.count()
|
|
if q._group_by and not disable_group_by:
|
|
count_func = count_func.over(None)
|
|
count_q = (
|
|
q.options(lazyload("*"))
|
|
.statement.with_only_columns([count_func])
|
|
.order_by(None)
|
|
)
|
|
if disable_group_by:
|
|
count_q = count_q.group_by(None)
|
|
count = q.session.execute(count_q).scalar()
|
|
return count
|
|
|
|
|
|
def sort_and_page(query, model, args):
|
|
"""
|
|
Helper that allows us to combine sorting and paging
|
|
|
|
:param query:
|
|
:param model:
|
|
:param args:
|
|
:return:
|
|
"""
|
|
sort_by = args.pop("sort_by")
|
|
sort_dir = args.pop("sort_dir")
|
|
page = args.pop("page")
|
|
count = args.pop("count")
|
|
|
|
if args.get("user"):
|
|
user = args.pop("user")
|
|
|
|
query = find_all(query, model, args)
|
|
|
|
if sort_by and sort_dir:
|
|
query = sort(query, model, sort_by, sort_dir)
|
|
|
|
total = get_count(query)
|
|
|
|
# offset calculated at zero
|
|
page -= 1
|
|
items = query.offset(count * page).limit(count).all()
|
|
return dict(items=items, total=total)
|