from tiramisu import Config
from inspect import signature
from traceback import print_exc
from copy import copy
from typing import Dict, Callable

from .utils import undefined, _
from .error import RegistrationError, CallError, NotAllowedError
from .message import get_messages
from .logger import log
from .config import DEBUG
from .context import Context


def register(uri: str, notification: str = undefined):
    """Decorator to register function to the dispatcher.

    :param uri: versioned message name "<version>.<message>"; only the first
        dot separates the version from the message name
    :param notification: for RPC messages, the versioned event published with
        the function's returns (or None for no notification); it is mandatory
        for RPC messages, see ``RegisterDispatcher.set_function``
    :returns: the decorator, which registers the function and returns it unchanged
    """
    version, uri = uri.split('.', 1)

    def decorator(function):
        dispatcher.set_function(version, uri, notification, function)
        # return the function so the decorated name is not rebound to None
        return function
    return decorator


class RegisterDispatcher:
    """Registration side of the dispatcher: validates and stores message functions.

    Uses attributes initialised by ``Dispatcher.__init__``: ``uris``,
    ``function_names``, ``injected_self``, ``messages`` and ``option``.
    """

    def get_function_args(self, function):
        """Return the parameter names of *function*, skipping the first one (self)."""
        # remove self
        first_argument_index = 1
        return [param.name for param in list(signature(function).parameters.values())[first_argument_index:]]

    @staticmethod
    def _split_risotto_context(function_args):
        """Return ``(True, args_without_it)`` if the first argument is risotto_context,
        otherwise ``(False, function_args)``."""
        # a function may take no argument besides self: do not index an empty list
        if function_args and function_args[0] == 'risotto_context':
            return True, function_args[1:]
        return False, function_args

    def _get_message_args(self, uri):
        """Return the set of argument names declared by message *uri*."""
        # load config
        config = Config(self.option)
        config.property.read_write()
        # set message to the uri name
        config.option('message').value.set(uri)
        # get message argument
        return set(config.option(uri).value.dict().keys())

    def _get_function_args_without_context(self, function):
        """Return the set of *function* argument names, ignoring self and risotto_context."""
        return set(self._split_risotto_context(self.get_function_args(function))[1])

    def _valid_rpc_params(self, version, uri, function, module_name):
        """Check that an RPC function has strictly all the message arguments, with the same names.

        :raises RegistrationError: if the function has missing or extra arguments
        """
        message_args = self._get_message_args(uri)
        function_args = self._get_function_args_without_context(function)
        # compare message arguments with function parameter
        # it must not have more or less arguments
        if message_args != function_args:
            # raise if arguments are not equal
            msg = []
            missing_function_args = message_args - function_args
            if missing_function_args:
                msg.append(_(f'missing arguments: {missing_function_args}'))
            extra_function_args = function_args - message_args
            if extra_function_args:
                msg.append(_(f'extra arguments: {extra_function_args}'))
            function_name = function.__name__
            msg = _(' and ').join(msg)
            raise RegistrationError(_(f'error with {module_name}.{function_name} arguments: {msg}'))

    def _valid_event_params(self, version, uri, function, module_name):
        """Check an event function's arguments: fewer than the message's is allowed, more is not.

        :raises RegistrationError: if the function has arguments the message does not declare
        """
        message_args = self._get_message_args(uri)
        function_args = self._get_function_args_without_context(function)
        # compare message arguments with function parameter
        # it can have less arguments but not more
        extra_function_args = function_args - message_args
        if extra_function_args:
            # raise if too many arguments
            function_name = function.__name__
            msg = _(f'extra arguments: {extra_function_args}')
            raise RegistrationError(_(f'error with {module_name}.{function_name} arguments: {msg}'))

    def set_function(self, version: str, uri: str, notification: str, function: Callable):
        """Register *function* for message *uri* (a message name without its version).

        RPC messages accept exactly one function, stored as a dict under
        ``self.uris[version][uri]``. Event messages accept several functions,
        appended to a list under the same key.

        :raises RegistrationError: on namespace mismatch, unknown message,
            duplicate function name in the module, duplicate RPC registration,
            invalid arguments, or a missing (RPC) / unsupported (event) notification
        """
        # xxx module can only be register with v1.xxxx..... message
        module_name = function.__module__.split('.')[-2]
        uri_namespace = uri.split('.', 1)[0]
        if uri_namespace != module_name:
            raise RegistrationError(_(f'cannot registered to {uri} message in module {module_name}'))
        # check if message exists
        try:
            if not Config(self.option).option(uri).option.type() == 'message':
                raise RegistrationError(_(f'{uri} is not a valid message'))
        except AttributeError:
            raise RegistrationError(_(f'{uri} is not a valid message'))

        # create an uris' version if needed
        if version not in self.uris:
            self.uris[version] = {}
            self.function_names[version] = {}
        # valid function is unique per module
        if module_name not in self.function_names[version]:
            self.function_names[version][module_name] = []
        function_name = function.__name__
        if function_name in self.function_names[version][module_name]:
            raise RegistrationError(_(f'multiple registration of {module_name}.{function_name} function'))
        self.function_names[version][module_name].append(function_name)

        # True if first argument is the risotto_context
        inject_risotto_context, function_args = self._split_risotto_context(self.get_function_args(function))

        if self.messages[uri]['pattern'] == 'rpc':
            # check if a RPC function is already register for this uri
            if uri in self.uris[version]:
                raise RegistrationError(_(f'uri {uri} already registered'))
            # valid function's arguments
            self._valid_rpc_params(version, uri, function, module_name)
            # register this function
            dico = {'module': module_name,
                    'function': function,
                    'risotto_context': inject_risotto_context}
            if notification is undefined:
                raise RegistrationError(_(f'notification is mandatory when registered {uri} with {module_name}.{function_name} even if you set None'))
            if notification:
                dico['notification'] = notification
            self.uris[version][uri] = dico
        else:
            # if event
            if notification and notification is not undefined:
                raise RegistrationError(_(f'notification not supported yet'))
            # valid function's arguments
            self._valid_event_params(version, uri, function, module_name)
            # register this function
            if uri not in self.uris[version]:
                self.uris[version][uri] = []
            self.uris[version][uri].append({'module': module_name,
                                            'function': function,
                                            'arguments': function_args,
                                            'risotto_context': inject_risotto_context})

    def set_module(self, module_name, module):
        """Instantiate ``module.Risotto`` and keep it as the ``self`` injected into its functions.

        :raises RegistrationError: if instantiation raises AttributeError
            (e.g. the module has no Risotto class)
        """
        try:
            self.injected_self[module_name] = module.Risotto()
        except AttributeError as err:
            raise RegistrationError(_(f'unable to register the module {module_name}, this module must have Risotto class')) from err

    def validate(self):
        """Check that every known message has at least one registered function.

        :raises RegistrationError: listing the messages without a function
        """
        # FIXME only v1 supported
        # .get: do not crash with KeyError when no v1 function was registered at all
        missing_messages = set(self.messages.keys()) - set(self.uris.get('v1', {}).keys())
        if missing_messages:
            raise RegistrationError(_(f'missing uri {missing_messages}'))


class Dispatcher(RegisterDispatcher):
    """Manage message (call or publish) so launch a function when a message is called."""

    def __init__(self):
        # reference to instanciate module (to inject self in method): {"module_name": instance_of_module}
        self.injected_self = {}
        # list of uris with informations: {"v1": {"module_name.xxxxx": yyyyyy}}
        self.uris = {}
        # all function for a module, to avoid conflict name {"v1": {"module_name": ["function_name"]}}
        self.function_names = {}
        self.messages, self.option = get_messages()

    @staticmethod
    def _function_label(function):
        """Return "<module>.<function>" for log and error messages."""
        module_name = function.__module__.split('.')[-2]
        return f'{module_name}.{function.__name__}'

    def new_context(self, context: Context, version: str, uri: str):
        """Return a child Context: same username, *context* paths plus "<version>.<uri>"."""
        new_context = Context()
        # copy so the parent's paths list is not mutated
        new_context.paths = copy(context.paths)
        new_context.paths.append(version + '.' + uri)
        new_context.username = context.username
        return new_context

    def check_public_function(self, version: str, uri: str, context: Context, kwargs: Dict, public_only: bool):
        """Raise NotAllowedError (and log it) if *public_only* and the message is not public."""
        if public_only and not self.messages[uri]['public']:
            msg = _(f'the message {version}.{uri} is private')
            log.error_msg(version, uri, context, kwargs, 'call', msg)
            raise NotAllowedError(msg)

    def check_pattern(self, version: str, uri: str, type: str, context: Context, kwargs: Dict):
        """Raise CallError (and log it) if the message pattern is not *type* ('rpc' or 'event')."""
        if self.messages[uri]['pattern'] != type:
            msg = _(f'{version}.{uri} is not a {type} message')
            log.error_msg(version, uri, context, kwargs, 'call', msg)
            raise CallError(msg)

    def set_config(self, uri: str, kwargs: Dict):
        """Create a new Config for message *uri*, set *kwargs* in it and check mandatories.

        :returns: the config, in read-only mode
        :raises AttributeError: for an argument unknown to the message
        :raises ValueError: if mandatory arguments are missing
        """
        # create a new config
        config = Config(self.option)
        config.property.read_write()
        # set message option
        config.option('message').value.set(uri)
        # store values
        subconfig = config.option(uri)
        for key, value in kwargs.items():
            try:
                subconfig.option(key).value.set(value)
            except AttributeError:
                raise AttributeError(_(f'unknown parameter "{key}"'))
        # check mandatories options
        config.property.read_only()
        mandatories = list(config.value.mandatory())
        if mandatories:
            mand = [mand.split('.')[-1] for mand in mandatories]
            raise ValueError(_(f'missing parameters: {mand}'))
        # return the config
        return config

    def valid_call_returns(self, function: Callable, returns: Dict, version: str, uri: str, context: Context, kwargs: Dict):
        """Check an RPC function's return value against the message's response description.

        :raises CallError: (logged) if *returns* is not a dict, contains an
            unknown key or an invalid value, or does not form a valid response
        """
        if not isinstance(returns, dict):
            err = _(f'function {self._function_label(function)} has to return a dict')
            log.error_msg(version, uri, context, kwargs, 'call', err)
            raise CallError(str(err))
        response = self.messages[uri]['response']
        if response is None:
            raise Exception(_(f'message {version}.{uri} has no response description'))
        config = Config(response, display_name=lambda self, dyn_name: self.impl_getname())
        config.property.read_write()
        try:
            for key, value in returns.items():
                config.option(key).value.set(value)
        except AttributeError:
            err = _(f'function {self._function_label(function)} return the unknown parameter "{key}"')
            log.error_msg(version, uri, context, kwargs, 'call', err)
            raise CallError(str(err))
        except ValueError:
            err = _(f'function {self._function_label(function)} return the parameter "{key}" with an unvalid value "{value}"')
            log.error_msg(version, uri, context, kwargs, 'call', err)
            raise CallError(str(err))
        config.property.read_only()
        try:
            # reading the whole config in read-only mode surfaces remaining errors
            config.value.dict()
        except Exception as err:
            err = _(f'function {self._function_label(function)} return an invalid response {err}')
            log.error_msg(version, uri, context, kwargs, 'call', err)
            raise CallError(str(err))

    async def call(self, version, uri, risotto_context, public_only=False, **kwargs):
        """Execute the function associated with *uri*; arguments are validated first.

        If the registration declared a notification, the returns are published
        as that event.

        :returns: the function's return dict
        :raises NotAllowedError: if *public_only* and the message is private
        :raises CallError: on wrong pattern, invalid arguments, function error,
            or invalid returns
        """
        new_context = self.new_context(risotto_context, version, uri)
        self.check_public_function(version, uri, new_context, kwargs, public_only)
        self.check_pattern(version, uri, 'rpc', new_context, kwargs)
        try:
            config = self.set_config(uri, kwargs)
            obj = self.uris[version][uri]
            kw = config.option(uri).value.dict()
            if obj['risotto_context']:
                kw['risotto_context'] = new_context
            returns = await obj['function'](self.injected_self[obj['module']], **kw)
        except CallError:
            raise
        except Exception as err:
            if DEBUG:
                print_exc()
            log.error_msg(version, uri, new_context, kwargs, 'call', err)
            raise CallError(str(err)) from err
        # valid returns
        self.valid_call_returns(obj['function'], returns, version, uri, new_context, kwargs)
        # log the success
        log.info_msg(version, uri, new_context, kwargs, 'call', _(f'returns {returns}'))
        # notification
        if obj.get('notification'):
            notif_version, notif_message = obj['notification'].split('.', 1)
            await self.publish(notif_version, notif_message, new_context, **returns)
        return returns

    async def publish(self, version, uri, risotto_context, public_only=False, **kwargs):
        """Publish event *uri* to every function registered for it.

        Argument errors are logged and no function is called. A failing
        function is logged and does not prevent the other functions from running.
        Each function receives only the message arguments it declares.

        :raises CallError: if *uri* is not an event message
        """
        # NOTE(review): public_only is accepted but never checked here, unlike call()
        new_context = self.new_context(risotto_context, version, uri)
        self.check_pattern(version, uri, 'event', new_context, kwargs)
        try:
            config = self.set_config(uri, kwargs)
            config_arguments = config.option(uri).value.dict()
        except Exception as err:
            # if there is a problem with arguments, just send an error et do nothing
            if DEBUG:
                print_exc()
            log.error_msg(version, uri, new_context, kwargs, 'publish', err)
            return
        # config is ok, so publish the message
        for function_obj in self.uris.get(version, {}).get(uri, []):
            function = function_obj['function']
            info_msg = _(f'in module {self._function_label(function)}')
            try:
                # build argument for this function
                kw = {key: value for key, value in config_arguments.items()
                      if key in function_obj['arguments']}
                if function_obj['risotto_context']:
                    kw['risotto_context'] = new_context
                # send event
                await function(self.injected_self[function_obj['module']], **kw)
            except Exception as err:
                if DEBUG:
                    print_exc()
                log.error_msg(version, uri, new_context, kwargs, 'publish', err, info_msg)
            else:
                log.info_msg(version, uri, new_context, kwargs, 'publish', info_msg)


dispatcher = Dispatcher()