"""
.. module: lemur.database
    :platform: Unix
    :synopsis: This module contains all of the database related methods
        needed for lemur to interact with a datastore

    :copyright: (c) 2018 by Netflix Inc., see AUTHORS for more
    :license: Apache, see LICENSE for more details.
.. moduleauthor:: Kevin Glisson
"""
from inflection import underscore
from sqlalchemy import exc, func, distinct
from sqlalchemy.orm import make_transient, lazyload
from sqlalchemy.sql import and_, or_

from lemur.exceptions import AttrNotFound, DuplicateError
from lemur.extensions import db


def filter_none(kwargs):
    """
    Remove all falsy values from a given dict. SQLAlchemy does not like to
    have values that are None passed to it.

    Note that every falsy value is dropped, not only ``None``: ``""``, ``0``,
    ``False`` and empty containers are removed as well.

    :param kwargs: Dict to filter
    :return: New dict containing only the entries whose value is truthy
    """
    # Deliberately truthiness-based: switching to ``is not None`` would let
    # values such as ``False`` or ``""`` through to ``find_all``.
    return {k: v for k, v in kwargs.items() if v}


def session_query(model):
    """
    Returns a SQLAlchemy query object for the specified `model`.

    If `model` has a ``query`` attribute already, that object will be returned.
    Otherwise a query will be created and returned based on `session`.

    :param model: sqlalchemy model
    :return: query object for model
    """
    return model.query if hasattr(model, "query") else db.session.query(model)


def create_query(model, kwargs):
    """
    Returns a SQLAlchemy query object for specified `model`. Model
    filtered by the kwargs passed.

    :param model: sqlalchemy model
    :param kwargs: column name -> value equality filters (``filter_by``)
    :return: filtered query object
    """
    s = session_query(model)
    return s.filter_by(**kwargs)


def commit():
    """
    Helper to commit the current session.
    """
    db.session.commit()


def add(model):
    """
    Helper to add a `model` to the current session.

    :param model: model instance to add (not committed here)
    """
    db.session.add(model)


def get_model_column(model, field):
    """
    Return the table column named ``field`` on ``model``.

    :param model: declarative model class (must expose ``__table__``)
    :param field: column name
    :return: the matching column object
    :raise AttrNotFound: if ``field`` is listed in the model's optional
        ``sensitive_fields`` or is not a column of the model's table
    """
    # Sensitive columns are reported as "not found" so they can never be used
    # for filtering or sorting through these helpers.
    if field in getattr(model, "sensitive_fields", ()):
        raise AttrNotFound(field)
    # Public ``ColumnCollection.get`` instead of the private ``_data`` mapping,
    # which is an implementation detail newer SQLAlchemy releases do not have.
    column = model.__table__.columns.get(field)
    if column is None:
        raise AttrNotFound(field)
    return column


def find_all(query, model, kwargs):
    """
    Returns a query object that ensures that all kwargs are present.

    Each value may be a comma-separated string, a list/tuple/set of values or
    a single scalar. The column must match one of the given values (``IN``),
    and all conditions are combined with ``AND``. Falsy values are ignored
    (see ``filter_none``).

    :param query: query to filter
    :param model: model whose columns the kwargs name
    :param kwargs: column name -> value(s)
    :return: filtered query
    :raise AttrNotFound: if a key is not a usable column of ``model``
    """
    conditions = []
    for attr, value in filter_none(kwargs).items():
        if isinstance(value, str):
            values = value.split(",")
        elif isinstance(value, (list, tuple, set, frozenset)):
            values = list(value)
        else:
            # Scalars (e.g. integer ids) previously crashed on ``.split``;
            # treat them as a single-element list instead.
            values = [value]
        conditions.append(get_model_column(model, attr).in_(values))
    return query.filter(and_(*conditions))


def find_any(query, model, kwargs):
    """
    Returns a query object that allows any kwarg to be present.

    Builds one equality condition per kwarg and combines them with ``OR``.

    :param query: query to filter
    :param model: model whose columns the kwargs name
    :param kwargs: column name -> value
    :return: filtered query
    :raise AttrNotFound: if a key is not a usable column of ``model``
    """
    conditions = [
        get_model_column(model, attr) == value for attr, value in kwargs.items()
    ]
    return query.filter(or_(*conditions))


def get(model, value, field="id"):
    """
    Returns one object filtered by the field and value.

    :param model: model class to query
    :param value: value the column must equal
    :param field: column name, defaults to ``"id"``
    :return: the matching object, or ``None`` if there is no match
        (``Query.scalar`` raises if more than one row matches)
    """
    query = session_query(model)
    return query.filter(get_model_column(model, field) == value).scalar()


def get_all(model, value, field="id"):
    """
    Returns query object with the fields and value filtered.

    :param model: model class to query
    :param value: value the column must equal
    :param field: column name, defaults to ``"id"``
    :return: filtered (unexecuted) query object
    """
    query = session_query(model)
    return query.filter(get_model_column(model, field) == value)


def create(model):
    """
    Helper that attempts to create a new instance of an object.

    Adds ``model`` to the session, commits, and refreshes it from the database.

    :param model: model instance to persist
    :return: the refreshed model
    :raise DuplicateError: when the commit fails with an ``IntegrityError``
    """
    try:
        db.session.add(model)
        commit()
    except exc.IntegrityError as e:
        # ``diag.message_detail`` is psycopg2-specific — presumably Lemur only
        # runs on PostgreSQL; confirm before supporting other drivers.
        raise DuplicateError(e.orig.diag.message_detail) from e

    db.session.refresh(model)
    return model


def update(model):
    """
    Helper that attempts to update a model.

    Commits the current session and refreshes ``model`` from the database.

    :param model: model instance already tracked by the session
    :return: the refreshed model
    """
    commit()
    db.session.refresh(model)
    return model


def delete(model):
    """
    Helper that attempts to delete a model.

    Marks ``model`` for deletion when it is truthy, then commits the session.
    The commit happens even when there is nothing to delete.

    :param model: model instance to delete, or a falsy value
    """
    if model:
        db.session.delete(model)

    db.session.commit()


def filter(query, model, terms):
    """
    Helper that searches for 'like' strings in column values.

    ``terms[0]`` names the column (converted to snake_case with ``underscore``)
    and ``terms[1]`` is the substring matched case-insensitively
    (``ILIKE '%<term>%'``). ``%`` and ``_`` inside the term are not escaped,
    so they act as SQL wildcards.

    :param query: query to filter
    :param model: model owning the column
    :param terms: ``(column_name, search_string)`` pair
    :return: filtered query
    """
    column = get_model_column(model, underscore(terms[0]))
    return query.filter(column.ilike("%{}%".format(terms[1])))


def sort(query, model, field, direction):
    """
    Returns objects of the specified `model` in the field and direction given

    :param query: query to order
    :param model: model owning the column
    :param field: column name (converted to snake_case with ``underscore``)
    :param direction: ``"desc"`` for descending; anything else is ascending
    :return: ordered query
    """
    column = get_model_column(model, underscore(field))
    return query.order_by(column.desc() if direction == "desc" else column.asc())


def paginate(query, page, count):
    """
    Returns the items given the count and page specified

    :param query: Flask-SQLAlchemy query supporting ``paginate``
    :param page: 1-based page number
    :param count: items per page
    :return: ``dict(items=..., total=..., current=...)`` where ``total`` is
        the overall row count and ``current`` the number of items returned
    """
    total = get_count(query)
    # Keyword arguments work with Flask-SQLAlchemy 2.x and 3.x alike; 3.x made
    # ``paginate``'s parameters keyword-only.
    items = query.paginate(page=page, per_page=count).items
    return dict(items=items, total=total, current=len(items))


def update_list(model, model_attr, item_model, items):
    """
    Helper that correctly updates a models items
    depending on what has changed

    After the call, ``getattr(model, model_attr)`` holds exactly the objects
    whose ids appear in ``items``:
    - existing members whose id is absent from ``items`` are removed;
    - ids in ``items`` that are not yet members are loaded from
      ``item_model`` and appended.

    :param model: model instance owning the collection
    :param model_attr: name of the collection attribute on ``model``
    :param item_model: model class used to load newly referenced items by id
    :param items: iterable of mappings, each with an ``"id"`` key
    :return: ``model``
    """
    items = list(items)
    collection = getattr(model, model_attr)
    wanted_ids = {i["id"] for i in items}

    # Iterate over a snapshot: removing from a list while iterating over it
    # skips elements.
    for existing in list(collection):
        if existing.id not in wanted_ids:
            collection.remove(existing)

    present_ids = {existing.id for existing in collection}
    for i in items:
        if i["id"] not in present_ids:
            collection.append(get(item_model, i["id"]))
            present_ids.add(i["id"])

    return model


def clone(model):
    """
    Clones the given model and removes its primary key

    Note: no copy is made. The same instance is expunged from the session,
    made transient and has its ``id`` cleared, so it can be added again as a
    new object.

    :param model: persistent model instance
    :return: the same instance, now transient with ``id = None``
    """
    db.session.expunge(model)
    make_transient(model)
    model.id = None
    return model


def get_count(q):
    """
    Count the number of rows a query would return.

    Instead of wrapping the query in a subquery, this rewrites the query's
    own SELECT to contain only a count expression and executes that.

    :param q: query with exactly one entity
    :return: the row count
    :raise Exception: if the query has more than one entity
    :raise NotImplementedError: for a single-column query with both
        GROUP BY and DISTINCT
    """
    disable_group_by = False
    if len(q._entities) > 1:
        # currently support only one entity
        raise Exception("only one entity is supported for get_count, got: %s" % q)
    entity = q._entities[0]
    if hasattr(entity, "column"):
        # _ColumnEntity has column attr - on case: query(Model.column)...
        col = entity.column
        if q._group_by and q._distinct:
            # which query can have both?
            raise NotImplementedError
        if q._group_by or q._distinct:
            col = distinct(col)
        if q._group_by:
            # With a single grouped column, COUNT(DISTINCT col) equals the
            # number of groups, so GROUP BY can be dropped below.
            disable_group_by = True
        count_func = func.count(col)
    else:
        # _MapperEntity doesn't have column attr - on case: query(Model)...
        # NOTE(review): q._distinct is not reflected in this branch — confirm
        # that whole-model DISTINCT queries are not counted here.
        count_func = func.count()
    if q._group_by and not disable_group_by:
        # COUNT(*) OVER () gives every grouped row the total number of groups;
        # .scalar() below reads it from the first row.
        count_func = count_func.over(None)
    # lazyload("*") keeps eager-loaded relationships from adding joins, and
    # ORDER BY is irrelevant to a count.
    count_q = (
        q.options(lazyload("*"))
        .statement.with_only_columns([count_func])
        .order_by(None)
    )
    if disable_group_by:
        count_q = count_q.group_by(None)
    count = q.session.execute(count_q).scalar()
    return count


def sort_and_page(query, model, args):
    """
    Helper that allows us to combine sorting and paging

    Pops ``sort_by``, ``sort_dir``, ``page`` and ``count`` from ``args``, so
    the caller's dict is mutated. A ``KeyError`` is raised if any of them is
    missing. The remaining entries are applied as ``find_all`` filters.

    :param query: base query
    :param model: model being listed
    :param args: request arguments dict
    :return: ``dict(items=..., total=...)`` for the requested 1-based page
    """
    sort_by = args.pop("sort_by")
    sort_dir = args.pop("sort_dir")
    page = args.pop("page")
    count = args.pop("count")

    # ``user`` is never applied as a column filter: a truthy value is removed
    # here and a falsy one is dropped by ``filter_none`` inside ``find_all``.
    if args.get("user"):
        args.pop("user")

    query = find_all(query, model, args)

    if sort_by and sort_dir:
        query = sort(query, model, sort_by, sort_dir)

    total = get_count(query)

    # offset calculated at zero
    page -= 1
    items = query.offset(count * page).limit(count).all()
    return dict(items=items, total=total)