diff --git a/lemur/certificates/service.py b/lemur/certificates/service.py index 0bd50694..c8a5365b 100644 --- a/lemur/certificates/service.py +++ b/lemur/certificates/service.py @@ -54,7 +54,7 @@ def get_by_name(name): def get_by_serial(serial): """ - Retrieves certificate by it's Serial. + Retrieves certificate(s) by serial number. :param serial: :return: """ @@ -64,6 +64,22 @@ def get_by_serial(serial): return Certificate.query.filter(Certificate.serial == serial).all() +def get_by_attributes(conditions): + """ + Retrieves certificate(s) by conditions given in a hash of given key=>value pairs. + :param serial: + :return: + """ + # Ensure that each of the given conditions corresponds to actual columns + # if not, silently remove it + for attr in conditions.keys(): + if attr not in Certificate.__table__.columns: + conditions.pop(attr) + + query = database.session_query(Certificate) + return database.find_all(query, Certificate, conditions).all() + + def delete(cert_id): """ Delete's a certificate. diff --git a/lemur/sources/service.py b/lemur/sources/service.py index 227f1bce..5002041c 100644 --- a/lemur/sources/service.py +++ b/lemur/sources/service.py @@ -116,7 +116,12 @@ def sync_certificates(source, user): for certificate in certificates: exists = False - if certificate.get('name'): + + if certificate.get('search', None): + conditions = certificate.pop('search') + exists = certificate_service.get_by_attributes(conditions) + + if not exists and certificate.get('name'): result = certificate_service.get_by_name(certificate['name']) if result: exists = [result] diff --git a/lemur/tests/test_certificates.py b/lemur/tests/test_certificates.py index 1a4d644b..0f46e4a5 100644 --- a/lemur/tests/test_certificates.py +++ b/lemur/tests/test_certificates.py @@ -41,6 +41,89 @@ def test_get_or_increase_name(session, certificate): assert get_or_increase_name('certificate1', int(serial, 16)) == 'certificate1-{}-1'.format(serial) +def test_get_all_certs(session, certificate): + from lemur.certificates.service import get_all_certs + assert len(get_all_certs()) > 1 + + +def test_get_by_name(session, certificate): + from lemur.certificates.service import get_by_name + + found = get_by_name(certificate.name) + + assert found + + +def test_get_by_serial(session, certificate): + from lemur.certificates.service import get_by_serial + + found = get_by_serial(certificate.serial) + + assert found + + +def test_delete_cert(session): + from lemur.certificates.service import delete, get + from lemur.tests.factories import CertificateFactory + + delete_this = CertificateFactory(name='DELETEME') + session.commit() + + cert_exists = get(delete_this.id) + + # it needs to exist first + assert cert_exists + + delete(delete_this.id) + cert_exists = get(delete_this.id) + + # then not exist after delete + assert not cert_exists + + +def test_get_by_attributes(session, certificate): + from lemur.certificates.service import get_by_attributes + + # Should get one cert + certificate1 = get_by_attributes({ + 'name': 'SAN-san.example.org-LemurTrustUnittestsClass1CA2018-20171231-20471231' + }) + + # Should get one cert using multiple attrs + certificate2 = get_by_attributes({ + 'name': 'test-cert-11111111-1', + 'cn': 'san.example.org' + }) + + # Should get multiple certs + multiple = get_by_attributes({ + 'cn': 'LemurTrust Unittests Class 1 CA 2018', + 'issuer': 'LemurTrustUnittestsRootCA2018' + }) + + assert len(certificate1) == 1 + assert len(certificate2) == 1 + assert len(multiple) > 1 + + +def test_find_duplicates(session): + from lemur.certificates.service import find_duplicates + + cert = { + 'body': SAN_CERT_STR, + 'chain': INTERMEDIATE_CERT_STR + } + + dups1 = find_duplicates(cert) + + cert['chain'] = '' + + dups2 = find_duplicates(cert) + + assert len(dups1) > 0 + assert len(dups2) > 0 + + def test_get_certificate_primitives(certificate): from lemur.certificates.service import get_certificate_primitives