lemur/lemur/database.py

309 lines
6.7 KiB
Python
Raw Normal View History

2015-06-22 22:47:27 +02:00
"""
.. module: lemur.database
:platform: Unix
:synopsis: This module contains all of the database related methods
needed for lemur to interact with a datastore
:copyright: (c) 2018 by Netflix Inc., see AUTHORS for more details.
:license: Apache, see LICENSE for more details.
.. moduleauthor:: Kevin Glisson <kglisson@netflix.com>
"""
from inflection import underscore
2018-05-24 21:55:52 +02:00
from sqlalchemy import exc, func
2015-06-22 22:47:27 +02:00
from sqlalchemy.sql import and_, or_
from sqlalchemy.orm import make_transient
2015-06-22 22:47:27 +02:00
from lemur.extensions import db
2015-07-21 22:06:13 +02:00
from lemur.exceptions import AttrNotFound, DuplicateError
2015-06-22 22:47:27 +02:00
def filter_none(kwargs):
    """
    Build a new dict holding only the entries of ``kwargs`` whose value is truthy.

    SQLAlchemy does not like ``None`` values being passed to it, so they are
    stripped here. The test is plain truthiness, so other falsy values
    (``0``, ``False``, ``''``, empty containers) are dropped as well.

    :param kwargs: Dict to filter (not modified)
    :return: New dict without any falsy values
    """
    return {key: val for key, val in kwargs.items() if val}
def session_query(model):
    """
    Return a SQLAlchemy query object for ``model``.

    When ``model`` already exposes a ``query`` attribute, that attribute is
    returned unchanged; otherwise a new query is built from ``db.session``.

    :param model: sqlalchemy model
    :return: query object for model
    """
    if hasattr(model, 'query'):
        return model.query
    return db.session.query(model)
def create_query(model, kwargs):
    """
    Return a query for ``model`` restricted by equality on each key/value of ``kwargs``.

    :param model: sqlalchemy model to query
    :param kwargs: dict of column name -> value, forwarded to ``filter_by``
    :return: filtered query object
    """
    base_query = session_query(model)
    filtered = base_query.filter_by(**kwargs)
    return filtered
def commit():
    """
    Commit whatever is pending in ``db.session``.
    """
    session = db.session
    session.commit()
def add(model):
    """
    Stage ``model`` in ``db.session``; nothing is committed here.

    :param model: model instance to add to the session
    :return: None
    """
    session = db.session
    session.add(model)
def get_model_column(model, field):
    """
    Resolve ``field`` to a column object of ``model``'s table.

    Names listed in the model's optional ``sensitive_fields`` attribute are
    refused, so the find/filter/sort helpers in this module can never be
    pointed at them.

    :param model: sqlalchemy model class
    :param field: column name to look up
    :return: the matching column object
    :raise AttrNotFound: if ``field`` is sensitive or is not a column of the table
    """
    if field in getattr(model, 'sensitive_fields', ()):
        raise AttrNotFound(field)
    # Use the public ``ColumnCollection.get`` instead of the private ``_data``
    # mapping, which is a SQLAlchemy implementation detail.
    column = model.__table__.columns.get(field, None)
    if column is None:
        raise AttrNotFound(field)
    return column
2015-06-22 22:47:27 +02:00
def find_all(query, model, kwargs):
    """
    Narrow ``query`` so that every (truthy) kwarg is matched.

    Each value is treated as a set of allowed values and becomes an ``IN``
    clause on the named column; all clauses are ANDed together. Lists, tuples
    and sets are used as-is, strings are split on commas (``"a,b"`` ->
    ``["a", "b"]``), and any other scalar is wrapped in a one-element list.

    :param query: query to filter
    :param model: sqlalchemy model whose columns are referenced
    :param kwargs: dict of column name -> value(s); falsy values are ignored
    :return: filtered query, or ``query`` unchanged when nothing is left to filter on
    :raise AttrNotFound: from ``get_model_column`` for unknown/sensitive fields
    """
    conditions = []
    for attr, value in filter_none(kwargs).items():
        if isinstance(value, (list, tuple, set, frozenset)):
            values = list(value)
        elif hasattr(value, 'split'):
            values = value.split(',')
        else:
            # Non-string scalars (e.g. ints) used to crash on ``.split``.
            values = [value]
        conditions.append(get_model_column(model, attr).in_(values))

    if not conditions:
        # ``and_()`` without arguments is deprecated in newer SQLAlchemy and
        # would not restrict anything anyway.
        return query
    return query.filter(and_(*conditions))
def find_any(query, model, kwargs):
    """
    Narrow ``query`` so that at least one kwarg is matched.

    Every key/value pair becomes an equality test against the named column,
    and the tests are ORed together. Values are compared as given: they are
    not passed through ``filter_none`` and are not split on commas.

    :param query: query to filter
    :param model: sqlalchemy model whose columns are referenced
    :param kwargs: dict of column name -> value
    :return: filtered query, or ``query`` unchanged when ``kwargs`` is empty
    :raise AttrNotFound: from ``get_model_column`` for unknown/sensitive fields
    """
    clauses = [get_model_column(model, attr) == value for attr, value in kwargs.items()]
    if not clauses:
        # ``or_()`` without arguments is deprecated in newer SQLAlchemy and
        # would not restrict anything anyway.
        return query
    return query.filter(or_(*clauses))
def get(model, value, field="id"):
    """
    Fetch the ``model`` row whose ``field`` column equals ``value``.

    :param model: sqlalchemy model to query
    :param value: value the column must equal
    :param field: column name to match on (defaults to ``"id"``)
    :return: the result of ``.scalar()`` on the filtered query
    """
    base = session_query(model)
    matches = base.filter(get_model_column(model, field) == value)
    return matches.scalar()
2015-06-22 22:47:27 +02:00
def get_all(model, value, field="id"):
    """
    Build a query for all ``model`` rows whose ``field`` column equals ``value``.

    The query is returned without being executed.

    :param model: sqlalchemy model to query
    :param value: value the column must equal
    :param field: column name to match on (defaults to ``"id"``)
    :return: filtered query object
    """
    base = session_query(model)
    column = get_model_column(model, field)
    return base.filter(column == value)
2015-06-22 22:47:27 +02:00
def create(model):
    """
    Add ``model`` to the session, commit, and refresh it from the database.

    :param model: new model instance to persist
    :return: the refreshed model
    :raise DuplicateError: if the commit fails with an ``IntegrityError``
    """
    try:
        db.session.add(model)
        commit()
    except exc.IntegrityError as e:
        # A failed commit leaves the session unusable until it is rolled back;
        # without this, every later use of the session would fail as well.
        db.session.rollback()
        # ``diag.message_detail`` is driver-specific (psycopg2 provides it) and
        # may be missing or None, so fall back to the driver error's text.
        diag = getattr(e.orig, 'diag', None)
        detail = getattr(diag, 'message_detail', None) or str(e.orig)
        raise DuplicateError(detail)
    db.session.refresh(model)
    return model
def update(model):
    """
    Commit pending changes, then reload ``model`` from the database.

    :param model: model instance that was modified
    :return: the refreshed model
    """
    commit()
    session = db.session
    session.refresh(model)
    return model
def delete(model):
    """
    Delete ``model`` from the session and commit.

    Nothing happens at all (no delete, no commit) when ``model`` is falsy,
    e.g. ``None``.

    :param model: model instance to delete, or a falsy value
    """
    if not model:
        return
    session = db.session
    session.delete(model)
    session.commit()
2015-06-22 22:47:27 +02:00
def filter(query, model, terms):
    """
    Restrict ``query`` to rows whose column contains a substring, case-insensitively.

    :param query: query to filter
    :param model: sqlalchemy model whose columns are referenced
    :param terms: sequence whose first item is the field name (converted with
        ``inflection.underscore``) and whose second item is the text to find
    :return: query filtered with ``ILIKE '%<text>%'``
    """
    target = get_model_column(model, underscore(terms[0]))
    # NOTE(review): ``%`` and ``_`` in the search text are not escaped, so they
    # act as LIKE wildcards.
    pattern = '%{}%'.format(terms[1])
    return query.filter(target.ilike(pattern))
2015-06-22 22:47:27 +02:00
def sort(query, model, field, direction):
    """
    Order ``query`` by ``field``.

    The order is descending when ``direction == 'desc'`` and ascending for any
    other value.

    :param query: query to order
    :param model: sqlalchemy model whose columns are referenced
    :param field: field name, converted with ``inflection.underscore``
    :param direction: ``'desc'`` for descending, anything else for ascending
    :return: ordered query
    """
    target = get_model_column(model, underscore(field))
    if direction == 'desc':
        ordering = target.desc()
    else:
        ordering = target.asc()
    return query.order_by(ordering)
2015-06-22 22:47:27 +02:00
def paginate(query, page, count):
    """
    Return one page of results using the query's ``paginate`` method.

    The arguments are passed by keyword. Flask-SQLAlchemy 3.x makes
    ``page``/``per_page`` keyword-only, and earlier versions accept them by
    keyword too.

    :param query: query object providing ``paginate``
    :param page: page number to fetch
    :param count: number of items per page
    :return: whatever ``query.paginate`` returns
    """
    return query.paginate(page=page, per_page=count)
def update_list(model, model_attr, item_model, items):
    """
    Make the collection ``model.<model_attr>`` match the ids listed in ``items``.

    Existing members whose ``id`` does not appear in ``items`` are removed.
    Ids in ``items`` that are not yet present are loaded with
    ``get(item_model, id)`` and appended. Members already present are kept as-is.

    :param model: object owning the collection
    :param model_attr: name of the list-like collection attribute on ``model``
    :param item_model: model used to load newly referenced items
    :param items: iterable of dicts, each with an ``'id'`` key
    :return: ``model`` (its collection is mutated in place)
    """
    items = list(items)  # iterated twice below, so materialize generators
    wanted_ids = {item['id'] for item in items}
    collection = getattr(model, model_attr)

    # Iterate over a snapshot: removing from the list while iterating it
    # directly skips every other element.
    for existing in list(collection):
        if existing.id not in wanted_ids:
            collection.remove(existing)

    present_ids = {existing.id for existing in collection}
    for item in items:
        if item['id'] not in present_ids:
            collection.append(get(item_model, item['id']))
            # Avoid appending the same id twice if ``items`` repeats it.
            present_ids.add(item['id'])
    return model
def clone(model):
    """
    Turn ``model`` into an unsaved copy: detach it from the session and clear its primary key.

    The object is expunged from ``db.session``, made transient, and its ``id``
    is set to ``None``. Presumably this lets it be added again as a new row.

    :param model: persistent model instance to clone (modified in place)
    :return: the same instance, now transient and without an id
    """
    session = db.session
    session.expunge(model)
    make_transient(model)
    model.id = None
    return model
2018-05-24 21:55:52 +02:00
def get_count(q):
    """
    Count the rows matched by ``q`` with a single ``SELECT count(*)``.

    The query's column list is replaced by ``count(*)`` and its ORDER BY is
    removed. This is presumably cheaper than ``Query.count()``, which wraps
    the statement in a subquery.

    NOTE(review): the list form of ``with_only_columns([...])`` is deprecated
    in SQLAlchemy 1.4+. It is kept here for compatibility with older versions.

    :param q: ORM query to count
    :return: number of matching rows
    """
    statement = q.statement.with_only_columns([func.count()])
    unordered = statement.order_by(None)
    return q.session.execute(unordered).scalar()
2015-06-22 22:47:27 +02:00
def sort_and_page(query, model, args):
    """
    Filter, sort, count and page ``query`` in one call.

    ``args`` is consumed in place:

    * ``sort_by``, ``sort_dir``, ``page`` and ``count`` must be present and are popped.
    * ``user`` is popped when present and truthy.
    * Every remaining entry is applied as a filter through ``find_all``.

    :param query: base query
    :param model: sqlalchemy model being listed
    :param args: dict of listing arguments (mutated)
    :return: ``dict(items=<objects on the requested page>, total=<count before paging>)``
    """
    sort_by = args.pop('sort_by')
    sort_dir = args.pop('sort_dir')
    page = args.pop('page')
    count = args.pop('count')

    # Presumably ``user`` is not a column filter; drop it before find_all.
    if args.get('user'):
        args.pop('user')

    query = find_all(query, model, args)

    if sort_by and sort_dir:
        query = sort(query, model, sort_by, sort_dir)

    total = get_count(query)

    # Pages are 1-based for callers and offsets are 0-based. Clamp so that a
    # page <= 0 returns the first page instead of emitting a negative OFFSET.
    offset = count * max(page - 1, 0)
    items = query.offset(offset).limit(count).all()
    return dict(items=items, total=total)