Fix filtering on boolean columns, broken with SQLAlchemy 1.2 upgrade

SQLAlchemy 1.2 does not allow comparing string values to boolean
columns. This caused errors like:

    sqlalchemy.exc.StatementError: (builtins.TypeError) Not a boolean value: 'true'

For more details see http://docs.sqlalchemy.org/en/latest/changelog/migration_12.html#boolean-datatype-now-enforces-strict-true-false-none-values
This commit is contained in:
Marti Raudsepp 2018-04-02 18:33:51 +03:00
parent b4b9a913b3
commit 8e2b2123f1
7 changed files with 30 additions and 14 deletions

View File

@ -9,6 +9,7 @@
""" """
from lemur import database from lemur import database
from lemur.common.utils import truthiness
from lemur.extensions import metrics from lemur.extensions import metrics
from lemur.authorities.models import Authority from lemur.authorities.models import Authority
from lemur.roles import service as role_service from lemur.roles import service as role_service
@ -170,8 +171,8 @@ def render(args):
if filt: if filt:
terms = filt.split(';') terms = filt.split(';')
if 'active' in filt: # this is really weird but strcmp seems to not work here?? if 'active' in filt:
query = query.filter(Authority.active == terms[1]) query = query.filter(Authority.active == truthiness(terms[1]))
else: else:
query = database.filter(query, Authority, terms) query = database.filter(query, Authority, terms)

View File

@ -8,7 +8,7 @@
import arrow import arrow
from flask import current_app from flask import current_app
from sqlalchemy import func, or_, not_, cast, Boolean, Integer from sqlalchemy import func, or_, not_, cast, Integer
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
@ -17,7 +17,7 @@ from cryptography.hazmat.primitives import hashes, serialization
from lemur import database from lemur import database
from lemur.extensions import metrics, signals from lemur.extensions import metrics, signals
from lemur.plugins.base import plugins from lemur.plugins.base import plugins
from lemur.common.utils import generate_private_key from lemur.common.utils import generate_private_key, truthiness
from lemur.roles.models import Role from lemur.roles.models import Role
from lemur.domains.models import Domain from lemur.domains.models import Domain
@ -319,9 +319,9 @@ def render(args):
elif 'destination' in terms: elif 'destination' in terms:
query = query.filter(Certificate.destinations.any(Destination.id == terms[1])) query = query.filter(Certificate.destinations.any(Destination.id == terms[1]))
elif 'notify' in filt: elif 'notify' in filt:
query = query.filter(Certificate.notify == cast(terms[1], Boolean)) query = query.filter(Certificate.notify == truthiness(terms[1]))
elif 'active' in filt: elif 'active' in filt:
query = query.filter(Certificate.active == terms[1]) query = query.filter(Certificate.active == truthiness(terms[1]))
elif 'cn' in terms: elif 'cn' in terms:
query = query.filter( query = query.filter(
or_( or_(

View File

@ -175,3 +175,9 @@ def windowed_query(q, column, windowsize):
column, windowsize): column, windowsize):
for row in q.filter(whereclause).order_by(column): for row in q.filter(whereclause).order_by(column):
yield row yield row
def truthiness(s):
"""If input string resembles something truthy then return True, else False."""
return s.lower() in ('true', 'yes', 'on', 't', '1')

View File

@ -13,6 +13,7 @@ import arrow
from sqlalchemy import func from sqlalchemy import func
from lemur import database from lemur import database
from lemur.common.utils import truthiness
from lemur.endpoints.models import Endpoint, Policy, Cipher from lemur.endpoints.models import Endpoint, Policy, Cipher
from lemur.extensions import metrics from lemur.extensions import metrics
@ -142,7 +143,7 @@ def render(args):
if filt: if filt:
terms = filt.split(';') terms = filt.split(';')
if 'active' in filt: # this is really weird but strcmp seems to not work here?? if 'active' in filt: # this is really weird but strcmp seems to not work here??
query = query.filter(Endpoint.active == terms[1]) query = query.filter(Endpoint.active == truthiness(terms[1]))
elif 'port' in filt: elif 'port' in filt:
if terms[1] != 'null': # ng-table adds 'null' if a number is removed if terms[1] != 'null': # ng-table adds 'null' if a number is removed
query = query.filter(Endpoint.port == terms[1]) query = query.filter(Endpoint.port == terms[1])

View File

@ -12,6 +12,7 @@ from flask import current_app
from lemur import database from lemur import database
from lemur.certificates.models import Certificate from lemur.certificates.models import Certificate
from lemur.common.utils import truthiness
from lemur.notifications.models import Notification from lemur.notifications.models import Notification
@ -169,10 +170,8 @@ def render(args):
if filt: if filt:
terms = filt.split(';') terms = filt.split(';')
if terms[0] == 'active' and terms[1] == 'false': if terms[0] == 'active':
query = query.filter(Notification.active == False) # noqa query = query.filter(Notification.active == truthiness(terms[1]))
elif terms[0] == 'active' and terms[1] == 'true':
query = query.filter(Notification.active == True) # noqa
else: else:
query = database.filter(query, Notification, terms) query = database.filter(query, Notification, terms)

View File

@ -5,9 +5,10 @@
""" """
import arrow import arrow
from sqlalchemy import or_, cast, Boolean, Integer from sqlalchemy import or_, cast, Integer
from lemur import database from lemur import database
from lemur.common.utils import truthiness
from lemur.plugins.base import plugins from lemur.plugins.base import plugins
from lemur.roles.models import Role from lemur.roles.models import Role
@ -181,9 +182,9 @@ def render(args):
elif 'destination' in terms: elif 'destination' in terms:
query = query.filter(PendingCertificate.destinations.any(Destination.id == terms[1])) query = query.filter(PendingCertificate.destinations.any(Destination.id == terms[1]))
elif 'notify' in filt: elif 'notify' in filt:
query = query.filter(PendingCertificate.notify == cast(terms[1], Boolean)) query = query.filter(PendingCertificate.notify == truthiness(terms[1]))
elif 'active' in filt: elif 'active' in filt:
query = query.filter(PendingCertificate.active == terms[1]) query = query.filter(PendingCertificate.active == truthiness(terms[1]))
elif 'cn' in terms: elif 'cn' in terms:
query = query.filter( query = query.filter(
or_( or_(

View File

@ -717,3 +717,11 @@ def test_certificates_upload_patch(client, token, status):
def test_sensitive_sort(client): def test_sensitive_sort(client):
resp = client.get(api.url_for(CertificatesList) + '?sortBy=private_key&sortDir=asc', headers=VALID_ADMIN_HEADER_TOKEN) resp = client.get(api.url_for(CertificatesList) + '?sortBy=private_key&sortDir=asc', headers=VALID_ADMIN_HEADER_TOKEN)
assert "'private_key' is not sortable or filterable" in resp.json['message'] assert "'private_key' is not sortable or filterable" in resp.json['message']
def test_boolean_filter(client):
resp = client.get(api.url_for(CertificatesList) + '?filter=notify;true', headers=VALID_ADMIN_HEADER_TOKEN)
assert resp.status_code == 200
# Also don't crash with invalid input (we currently treat that as false)
resp = client.get(api.url_for(CertificatesList) + '?filter=notify;whatisthis', headers=VALID_ADMIN_HEADER_TOKEN)
assert resp.status_code == 200