"""Dispatch RPC calls and published events to the registered risotto functions."""
from copy import copy
from traceback import print_exc
from typing import Dict, Callable

import asyncpg
from tiramisu import Config

from .utils import _
from .error import CallError, NotAllowedError
from .logger import log
from .config import DEBUG
from .config import get_config
from .context import Context
from . import register


class CallDispatcher:
    """Mixin that executes RPC ("rpc" pattern) messages and validates their responses.

    Uses attributes/methods provided by the concrete dispatcher: ``messages``,
    ``injected_self``, ``build_new_context``, ``check_message_type``,
    ``load_kwargs_to_config`` and ``publish``.
    """

    def valid_public_function(self, risotto_context: Context, kwargs: Dict, public_only: bool):
        """Refuse a private message when only public messages may be called.

        Raises:
            NotAllowedError: ``public_only`` is true and the message is not public.
        """
        if public_only and not self.messages[risotto_context.version][risotto_context.message]['public']:
            msg = _(f'the message {risotto_context.message} is private')
            log.error_msg(risotto_context, kwargs, msg)
            raise NotAllowedError(msg)

    def valid_call_returns(self, risotto_context: Context, returns: Dict, kwargs: Dict):
        """Check that *returns* matches the response description of the current message.

        A response flagged ``multi`` must be a list of dicts; otherwise a single
        dict is expected.  Each dict is loaded into a tiramisu ``Config`` built
        from the response description.

        Raises:
            CallError: wrong container type, unknown or invalid returned
                parameter, or a response that cannot be rendered as a dict.
            ValueError: a mandatory response parameter is missing.
            Exception: the message has no response description.
        """
        response = self.messages[risotto_context.version][risotto_context.message]['response']
        # Must be checked before the first use of ``response`` below.
        if response is None:
            raise Exception(_(f'message "{risotto_context.message}" has no response description'))
        module_name = risotto_context.function.__module__.split('.')[-2]
        function_name = risotto_context.function.__name__
        if response.impl_get_information('multi'):
            if not isinstance(returns, list):
                err = _(f'function {module_name}.{function_name} has to return a list')
                log.error_msg(risotto_context, kwargs, err)
                raise CallError(str(err))
        else:
            if not isinstance(returns, dict):
                log.error_msg(risotto_context, kwargs, returns)
                err = _(f'function {module_name}.{function_name} has to return a dict')
                log.error_msg(risotto_context, kwargs, err)
                raise CallError(str(err))
            returns = [returns]
        for ret in returns:
            # Items of a multi response are not otherwise checked; a non-dict
            # would fail on ``.items()`` before any ``key`` is bound.
            if not isinstance(ret, dict):
                err = _(f'function {module_name}.{function_name} has to return a list of dict')
                log.error_msg(risotto_context, kwargs, err)
                raise CallError(str(err))
            config = Config(response, display_name=lambda self, dyn_name: self.impl_getname())
            for key, value in ret.items():
                try:
                    config.option(key).value.set(value)
                except AttributeError as exc:
                    err = _(f'function {module_name}.{function_name} return the unknown parameter "{key}"')
                    log.error_msg(risotto_context, kwargs, err)
                    raise CallError(str(err)) from exc
                except ValueError as exc:
                    err = _(f'function {module_name}.{function_name} return the parameter "{key}" with an unvalid value "{value}"')
                    log.error_msg(risotto_context, kwargs, err)
                    raise CallError(str(err)) from exc
            mandatories = list(config.value.mandatory())
            if mandatories:
                mand = [path.split('.')[-1] for path in mandatories]
                raise ValueError(_(f'missing parameters in response: {mand} in message "{risotto_context.message}"'))
            try:
                config.value.dict()
            except Exception as exc:
                err = _(f'function {module_name}.{function_name} return an invalid response {exc}')
                log.error_msg(risotto_context, kwargs, err)
                raise CallError(str(err)) from exc

    async def call(self, version: str, message: str, old_risotto_context: Context, public_only: bool=False, **kwargs):
        """Execute the function associated with ``version.message``.

        Arguments are validated before the call and the return value after it.
        When the message declares a ``notification``, one event is published
        per returned item.

        Raises:
            NotAllowedError: private message called with ``public_only``.
            CallError: wrong message type, invalid arguments, failure of the
                called function, or an invalid response.
        """
        risotto_context = self.build_new_context(old_risotto_context, version, message, 'rpc')
        self.valid_public_function(risotto_context, kwargs, public_only)
        self.check_message_type(risotto_context, kwargs)
        try:
            tiramisu_config = self.load_kwargs_to_config(risotto_context, kwargs)
            obj = self.messages[version][message]
            kw = tiramisu_config.option(message).value.dict()
            risotto_context.function = obj['function']
            if obj['risotto_context']:
                kw['risotto_context'] = risotto_context
            injected_self = self.injected_self[obj['module']]
            if obj.get('database'):
                db_conf = get_config().get('database')
                # One short-lived connection per call, always closed; the
                # function runs inside a single transaction on it.
                connection = await asyncpg.connect(database=db_conf.get('dbname'),
                                                   user=db_conf.get('user'))
                try:
                    risotto_context.connection = connection
                    async with connection.transaction():
                        returns = await risotto_context.function(injected_self, **kw)
                finally:
                    await connection.close()
            else:
                returns = await risotto_context.function(injected_self, **kw)
        except CallError:
            raise
        except Exception as err:
            if get_config().get('global').get('debug'):
                print_exc()
            log.error_msg(risotto_context, kwargs, err)
            raise CallError(str(err)) from err
        # valid returns
        self.valid_call_returns(risotto_context, returns, kwargs)
        # log the success
        log.info_msg(risotto_context, kwargs, _(f'returns {returns}'))
        # notification: publish one event per returned item
        if obj.get('notification'):
            notif_version, notif_message = obj['notification'].split('.', 1)
            send_returns = returns if isinstance(returns, list) else [returns]
            for ret in send_returns:
                await self.publish(notif_version, notif_message, risotto_context, **ret)
        return returns


class PublishDispatcher:
    """Mixin that delivers published ("event" pattern) messages to every subscriber."""

    async def publish(self, version, message, old_risotto_context, public_only=False, **kwargs):
        """Send event ``version.message`` to all functions registered for it.

        Invalid arguments are logged and the event is dropped.  Each subscriber
        runs independently: its failure is logged and delivery continues with
        the next one.  When a subscriber declares a ``notification``, its
        return value (a dict, or a list of dicts) is published as that event.
        ``public_only`` is accepted for signature symmetry with ``call`` and is
        not used.

        Raises:
            CallError: the message is not an event message.
        """
        risotto_context = self.build_new_context(old_risotto_context, version, message, 'event')
        self.check_message_type(risotto_context, kwargs)
        try:
            config = self.load_kwargs_to_config(risotto_context, kwargs)
            config_arguments = config.option(message).value.dict()
        except CallError:
            return
        except Exception as err:
            # if there is a problem with arguments, just send an error and do nothing
            if DEBUG:
                print_exc()
            log.error_msg(risotto_context, kwargs, err)
            return

        # config is ok, so publish the message
        for function_obj in self.messages[version][message]['functions']:
            function = function_obj['function']
            module_name = function.__module__.split('.')[-2]
            function_name = function.__name__
            info_msg = _(f'in module {module_name}.{function_name}')
            try:
                # only pass the arguments this function declares
                kw = {key: value for key, value in config_arguments.items()
                      if key in function_obj['arguments']}
                if function_obj['risotto_context']:
                    kw['risotto_context'] = risotto_context
                # send event
                returns = await function(self.injected_self[function_obj['module']], **kw)
                # forward the subscriber's result as its notification, if any
                notification = function_obj.get('notification')
                if notification:
                    notif_version, notif_message = notification.split('.', 1)
                    send_returns = returns if isinstance(returns, list) else [returns]
                    for ret in send_returns:
                        await self.publish(notif_version, notif_message, risotto_context, **ret)
            except Exception as err:
                if DEBUG:
                    print_exc()
                log.error_msg(risotto_context, kwargs, err, info_msg)
            else:
                log.info_msg(risotto_context, kwargs, info_msg)


class Dispatcher(register.RegisterDispatcher, CallDispatcher, PublishDispatcher):
    """ Manage message (call or publish)
    so launch a function when a message is called
    """

    def build_new_context(self, old_risotto_context: Context, version: str, message: str, type: str):
        """Create the context for a new call or publish derived from *old_risotto_context*.

        The username is inherited; ``paths`` is a copy of the parent's paths
        with ``version.message`` appended.  ``type`` is ``'rpc'`` or
        ``'event'`` in this file.
        """
        uri = version + '.' + message
        risotto_context = Context()
        risotto_context.username = old_risotto_context.username
        risotto_context.paths = copy(old_risotto_context.paths)
        risotto_context.paths.append(uri)
        risotto_context.uri = uri
        risotto_context.type = type
        risotto_context.message = message
        risotto_context.version = version
        return risotto_context

    def check_message_type(self, risotto_context: Context, kwargs: Dict):
        """Raise CallError if the message's registered pattern differs from the context type."""
        if self.messages[risotto_context.version][risotto_context.message]['pattern'] != risotto_context.type:
            msg = _(f'{risotto_context.uri} is not a {risotto_context.type} message')
            log.error_msg(risotto_context, kwargs, msg)
            raise CallError(msg)

    def load_kwargs_to_config(self, risotto_context: Context, kwargs: Dict):
        """Create a new Config from ``self.option`` and load *kwargs* into it.

        Raises:
            AttributeError: a key of *kwargs* is not a parameter of the message.
            ValueError: mandatory parameters are missing.
        """
        # create a new config
        config = Config(self.option)
        # set message's option
        config.option('message').value.set(risotto_context.message)
        # store values
        subconfig = config.option(risotto_context.message)
        for key, value in kwargs.items():
            try:
                subconfig.option(key).value.set(value)
            except AttributeError as exc:
                if DEBUG:
                    print_exc()
                raise AttributeError(_(f'unknown parameter "{key}"')) from exc
        # check mandatories options
        mandatories = list(config.value.mandatory())
        if mandatories:
            mand = [path.split('.')[-1] for path in mandatories]
            raise ValueError(_(f'missing parameters: {mand}'))
        # return the config
        return config

    def get_service(self, name: str):
        """Return the object injected for module *name*."""
        return self.injected_self[name]


dispatcher = Dispatcher()
register.dispatcher = dispatcher