"""
.. module: lemur.schemas
    :platform: unix
    :copyright: (c) 2018 by Netflix Inc., see AUTHORS for more
    :license: Apache, see LICENSE for more details.

.. moduleauthor:: Kevin Glisson
"""
from sqlalchemy.orm.exc import NoResultFound

from marshmallow import fields, post_load, pre_load, post_dump
from marshmallow.exceptions import ValidationError

from lemur.common import validators
from lemur.common.schema import LemurSchema, LemurInputSchema, LemurOutputSchema
from lemur.common.fields import (
    KeyUsageExtension,
    ExtendedKeyUsageExtension,
    BasicConstraintsExtension,
    SubjectAlternativeNameExtension,
)

from lemur.plugins import plugins
from lemur.plugins.utils import get_plugin_option
from lemur.roles.models import Role
from lemur.users.models import User
from lemur.authorities.models import Authority
from lemur.dns_providers.models import DnsProvider
from lemur.policies.models import RotationPolicy
from lemur.certificates.models import Certificate
from lemur.destinations.models import Destination
from lemur.notifications.models import Notification


def validate_options(options):
    """
    Ensures that the plugin options are valid.

    Reads the ``interval`` and ``unit`` plugin options and rejects any
    combination that resolves to more than 90 days. A ``unit`` of ``month``
    counts as 30 days and ``week`` as 7 days; any other (or missing) unit
    leaves ``interval`` unscaled, so it is compared directly against 90.

    :param options: the list of plugin option dicts being validated
    :return: None
    :raises ValidationError: if the resulting interval exceeds 90 days
    """
    interval = get_plugin_option("interval", options)
    unit = get_plugin_option("unit", options)

    # Without an interval there is nothing to bound. The previous guard
    # (``not interval and not unit``) let a unit with no interval fall through,
    # which then crashed with a TypeError on ``None * 30`` / ``None > 90``.
    if not interval:
        return

    if unit == "month":
        interval *= 30

    elif unit == "week":
        interval *= 7

    if interval > 90:
        raise ValidationError(
            "Notification cannot be more than 90 days into the future."
        )


def get_object_attribute(data, many=False):
    """
    Decide which attribute ("id" or "name") identifies the associated object(s).

    :param data: a dict (or, when ``many`` is True, a list of dicts) that may
        carry ``id`` and/or ``name`` keys
    :param many: whether ``data`` is a list of objects
    :return: ``"id"`` or ``"name"``
    :raises ValidationError: if neither an id nor a name is available
    """
    if many:
        ids = [d.get("id") for d in data]
        names = [d.get("name") for d in data]

        # Lookup by id only when every entry has one; otherwise fall back to
        # name, which then must be present on every entry.
        if None in ids:
            if None in names:
                raise ValidationError("Associated object require a name or id.")
            else:
                return "name"
        return "id"
    else:
        if data.get("id"):
            return "id"
        elif data.get("name"):
            return "name"
        else:
            raise ValidationError("Associated object require a name or id.")


def fetch_objects(model, data, many=False):
    """
    Resolve association payload(s) into model instance(s) via ``model.query``.

    :param model: the model class to query
    :param data: a dict (or a list of dicts when ``many`` is True) with an
        ``id`` or ``name`` key
    :param many: whether ``data`` is a list of objects
    :return: a list of model instances when ``many`` is True, else one instance
    :raises ValidationError: if any requested object cannot be found
    """
    attr = get_object_attribute(data, many=many)

    if many:
        values = [v[attr] for v in data]
        items = model.query.filter(getattr(model, attr).in_(values)).all()
        found = [getattr(i, attr) for i in items]
        diff = set(values).symmetric_difference(set(found))

        if diff:
            raise ValidationError(
                "Unable to locate {model} with {attr} {diff}".format(
                    model=model,
                    attr=attr,
                    # Ids are integers; stringify so join() cannot raise
                    # TypeError in place of the intended ValidationError.
                    diff=",".join(str(d) for d in diff),
                )
            )

        return items

    else:
        try:
            return model.query.filter(getattr(model, attr) == data[attr]).one()
        except NoResultFound:
            raise ValidationError(
                "Unable to find {model} with {attr}: {data}".format(
                    model=model, attr=attr, data=data[attr]
                )
            )


class AssociatedAuthoritySchema(LemurInputSchema):
    """Loads an ``Authority`` reference (by id or name) into the model instance."""

    id = fields.Int()
    name = fields.String()

    @post_load
    def get_object(self, data, many=False):
        return fetch_objects(Authority, data, many=many)


class AssociatedDnsProviderSchema(LemurInputSchema):
    """Loads a ``DnsProvider`` reference (by id or name) into the model instance."""

    id = fields.Int()
    name = fields.String()

    @post_load
    def get_object(self, data, many=False):
        return fetch_objects(DnsProvider, data, many=many)


class AssociatedRoleSchema(LemurInputSchema):
    """Loads a ``Role`` reference (by id or name) into the model instance."""

    id = fields.Int()
    name = fields.String()

    @post_load
    def get_object(self, data, many=False):
        return fetch_objects(Role, data, many=many)


class AssociatedDestinationSchema(LemurInputSchema):
    """Loads a ``Destination`` reference (by id or name) into the model instance."""

    id = fields.Int()
    name = fields.String()

    @post_load
    def get_object(self, data, many=False):
        return fetch_objects(Destination, data, many=many)


class AssociatedNotificationSchema(LemurInputSchema):
    """Loads a ``Notification`` reference (by id or name) into the model instance."""

    id = fields.Int()
    name = fields.String()

    @post_load
    def get_object(self, data, many=False):
        return fetch_objects(Notification, data, many=many)


class AssociatedCertificateSchema(LemurInputSchema):
    """Loads a ``Certificate`` reference (by id or name) into the model instance."""

    id = fields.Int()
    name = fields.String()

    @post_load
    def get_object(self, data, many=False):
        return fetch_objects(Certificate, data, many=many)


class AssociatedUserSchema(LemurInputSchema):
    """Loads a ``User`` reference (by id or name) into the model instance."""

    id = fields.Int()
    name = fields.String()

    @post_load
    def get_object(self, data, many=False):
        return fetch_objects(User, data, many=many)


class AssociatedRotationPolicySchema(LemurInputSchema):
    """Loads a ``RotationPolicy`` reference (by id or name) into the model instance."""

    id = fields.Int()
    name = fields.String()

    @post_load
    def get_object(self, data, many=False):
        return fetch_objects(RotationPolicy, data, many=many)


class PluginInputSchema(LemurInputSchema):
    """Input schema for a plugin reference plus its options (sub-plugins included)."""

    plugin_options = fields.List(fields.Dict(), validate=validate_options)
    slug = fields.String(required=True)
    title = fields.String()
    description = fields.String()

    @post_load
    def get_object(self, data, many=False):
        """
        Attach the plugin registered under ``slug`` as ``plugin_object`` and
        recursively load any option whose ``type`` contains "plugin".

        :raises ValidationError: if the plugin lookup or any sub-plugin load fails
        """
        try:
            data["plugin_object"] = plugins.get(data["slug"])

            # parse any sub-plugins
            for option in data.get("plugin_options", []):
                if "plugin" in option.get("type", []):
                    sub_data, errors = PluginInputSchema().load(option["value"])
                    # Previously these errors were discarded and partially
                    # loaded sub-plugin data was accepted silently.
                    if errors:
                        raise ValidationError(errors)
                    option["value"] = sub_data

            return data
        except Exception as e:
            raise ValidationError(
                "Unable to find plugin. Slug: {0} Reason: {1}".format(data["slug"], e)
            )


class PluginOutputSchema(LemurOutputSchema):
    """Output schema for a plugin; ``options`` is dumped under ``pluginOptions``."""

    id = fields.Integer()
    label = fields.String()
    description = fields.String()
    active = fields.Boolean()
    options = fields.List(fields.Dict(), dump_to="pluginOptions")
    slug = fields.String()
    title = fields.String()


plugins_output_schema = PluginOutputSchema(many=True)
# NOTE(review): this binds the class, not an instance (unlike
# plugins_output_schema above); kept as-is because callers may rely on it.
plugin_output_schema = PluginOutputSchema


class BaseExtensionSchema(LemurSchema):
    """Base for extension schemas: input keys are converted with ``under`` before
    load and output keys with ``camel`` after dump (helpers from LemurSchema)."""

    @pre_load(pass_many=True)
    def preprocess(self, data, many):
        return self.under(data, many=many)

    @post_dump(pass_many=True)
    def post_process(self, data, many):
        if data:
            data = self.camel(data, many=many)
        return data


class AuthorityKeyIdentifierSchema(BaseExtensionSchema):
    use_key_identifier = fields.Boolean()
    use_authority_cert = fields.Boolean()


class CertificateInfoAccessSchema(BaseExtensionSchema):
    include_aia = fields.Boolean()

    @post_dump
    def handle_keys(self, data):
        # "AIA" is an acronym, so the generic camel-casing is replaced here.
        return {"includeAIA": data["include_aia"]}


class CRLDistributionPointsSchema(BaseExtensionSchema):
    # NOTE(review): String here while sibling include_* flags are Boolean;
    # confirm whether this is intentional.
    include_crl_dp = fields.String()

    @post_dump
    def handle_keys(self, data):
        return {"includeCRLDP": data["include_crl_dp"]}


class SubjectKeyIdentifierSchema(BaseExtensionSchema):
    include_ski = fields.Boolean()

    @post_dump
    def handle_keys(self, data):
        return {"includeSKI": data["include_ski"]}


class CustomOIDSchema(BaseExtensionSchema):
    oid = fields.String()
    encoding = fields.String(validate=validators.encoding)
    value = fields.String()
    is_critical = fields.Boolean()


class NamesSchema(BaseExtensionSchema):
    names = SubjectAlternativeNameExtension()


class ExtensionSchema(BaseExtensionSchema):
    """Aggregate schema for all supported certificate extensions."""

    # some devices balk on default basic constraints
    basic_constraints = BasicConstraintsExtension()
    key_usage = KeyUsageExtension()
    extended_key_usage = ExtendedKeyUsageExtension()
    subject_key_identifier = fields.Nested(SubjectKeyIdentifierSchema)
    sub_alt_names = fields.Nested(NamesSchema)
    authority_key_identifier = fields.Nested(AuthorityKeyIdentifierSchema)
    certificate_info_access = fields.Nested(CertificateInfoAccessSchema)
    crl_distribution_points = fields.Nested(
        CRLDistributionPointsSchema, dump_to="cRL_distribution_points"
    )
    # FIXME: Convert custom OIDs to a custom field in fields.py like other Extensions
    # FIXME: Remove support in UI for Critical custom extensions https://github.com/Netflix/lemur/issues/665
    custom = fields.List(fields.Nested(CustomOIDSchema))


class EndpointNestedOutputSchema(LemurOutputSchema):
    """Output schema for an endpoint nested inside another resource (no envelope)."""

    __envelope__ = False
    id = fields.Integer()
    description = fields.String()
    name = fields.String()
    dnsname = fields.String()
    owner = fields.Email()
    type = fields.String()
    active = fields.Boolean()