Merge pull request #2099 from castrapel/count_accurate

More accurate db count functionality
This commit is contained in:
Curtis 2018-11-13 09:26:14 -08:00 committed by GitHub
commit d5bf85b3b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 27 additions and 3 deletions

View File

@ -10,8 +10,8 @@
.. moduleauthor:: Kevin Glisson <kglisson@netflix.com> .. moduleauthor:: Kevin Glisson <kglisson@netflix.com>
""" """
from inflection import underscore from inflection import underscore
from sqlalchemy import exc, func from sqlalchemy import exc, func, distinct
from sqlalchemy.orm import make_transient from sqlalchemy.orm import make_transient, lazyload
from sqlalchemy.sql import and_, or_ from sqlalchemy.sql import and_, or_
from lemur.exceptions import AttrNotFound, DuplicateError from lemur.exceptions import AttrNotFound, DuplicateError
@ -273,7 +273,31 @@ def get_count(q):
:param q: :param q:
:return: :return:
""" """
count_q = q.statement.with_only_columns([func.count()]).group_by(None).order_by(None) disable_group_by = False
if len(q._entities) > 1:
# currently support only one entity
raise Exception('only one entity is supported for get_count, got: %s' % q)
entity = q._entities[0]
if hasattr(entity, 'column'):
# _ColumnEntity has column attr - on case: query(Model.column)...
col = entity.column
if q._group_by and q._distinct:
# which query can have both?
raise NotImplementedError
if q._group_by or q._distinct:
col = distinct(col)
if q._group_by:
# need to disable group_by and enable distinct - we can do this because we have only 1 entity
disable_group_by = True
count_func = func.count(col)
else:
# _MapperEntity doesn't have column attr - on case: query(Model)...
count_func = func.count()
if q._group_by and not disable_group_by:
count_func = count_func.over(None)
count_q = q.options(lazyload('*')).statement.with_only_columns([count_func]).order_by(None)
if disable_group_by:
count_q = count_q.group_by(None)
count = q.session.execute(count_q).scalar() count = q.session.execute(count_q).scalar()
return count return count