from tiramisu import Config
from inspect import signature
from typing import Callable, List, Optional, Set, Union
import asyncpg
from json import dumps, loads


from .utils import _
from .error import RegistrationError
from .message import get_messages
from .context import Context
from .config import INTERNAL_USER, get_config


def register(uris: Union[str, List[str]],
             notification: Optional[str] = None):
    """Decorator to register a function to the dispatcher.

    :param uris: one URI or a list of URIs of the form "<version>.<message>";
                 only the first dot separates the version from the message
    :param notification: optional notification name stored with the
                         registration
    :returns: a decorator that registers the function and returns it unchanged
    """
    if not isinstance(uris, list):
        uris = [uris]

    def decorator(function):
        for uri in uris:
            version, message = uri.split('.', 1)
            # NOTE(review): `dispatcher` is a module-level global that is not
            # defined in the visible code of this file; presumably it is set by
            # the dispatcher module before decorators run — confirm.
            dispatcher.set_function(version, message, notification, function)
        # return the function, otherwise the decorated name is rebound to None
        return function
    return decorator


class RegisterDispatcher:
    """Registry of messages (rpc and event) and of the functions handling them."""

    def __init__(self):
        # reference to instanciated modules (to inject self in method):
        # {"module_name": instance_of_module}
        self.injected_self = {}
        # postgresql pool, created by load()
        self.pool = None
        # list of uris with informations: {"v1": {"module_name.xxxxx": yyyyyy}}
        self.messages = {}
        # load tiramisu objects
        messages, self.option = get_messages()
        # FIXME: the version is hard-coded
        version = 'v1'
        self.messages[version] = {tiramisu_message: obj
                                  for tiramisu_message, obj in messages.items()}

    def get_function_args(self, function: Callable) -> List[str]:
        """Return the parameter names of ``function`` without the first one (``self``)."""
        # remove self
        first_argument_index = 1
        return [param.name for param in list(signature(function).parameters.values())[first_argument_index:]]

    def _get_function_args_without_context(self, function: Callable) -> List[str]:
        """Return the parameter names of ``function`` without ``self`` and ``risotto_context``."""
        # risotto_context is a special argument, always right after self
        return self.get_function_args(function)[1:]

    async def _get_message_args(self, message: str) -> Set[str]:
        """Return the argument names declared for ``message`` in the tiramisu config."""
        # load config
        config = await Config(self.option)
        await config.property.read_write()
        # set message to the message name
        await config.option('message').value.set(message)
        # get message arguments
        dico = await config.option(message).value.dict()
        return set(dico.keys())

    async def valid_rpc_params(self,
                               version: str,
                               message: str,
                               function: Callable,
                               module_name: str):
        """Check that an rpc function accepts exactly the message arguments.

        The function parameters (without ``self`` and ``risotto_context``) must
        be strictly the message arguments: no missing and no extra argument.

        :raises RegistrationError: if the arguments differ
        """
        # get message arguments
        message_args = await self._get_message_args(message)
        # get function arguments
        function_args = set(self._get_function_args_without_context(function))
        # compare message arguments with function parameter
        # it must not have more or less arguments
        if message_args != function_args:
            # raise if arguments are not equal
            msg = []
            missing_function_args = message_args - function_args
            if missing_function_args:
                msg.append(_(f'missing arguments: {missing_function_args}'))
            extra_function_args = function_args - message_args
            if extra_function_args:
                msg.append(_(f'extra arguments: {extra_function_args}'))
            function_name = function.__name__
            msg = _(' and ').join(msg)
            raise RegistrationError(_(f'error with {module_name}.{function_name} arguments: {msg}'))

    async def valid_event_params(self,
                                 version: str,
                                 message: str,
                                 function: Callable,
                                 module_name: str):
        """Check that an event function accepts only message arguments.

        The function may take fewer arguments than the message defines, but
        never an argument the message does not define.

        :raises RegistrationError: if the function has extra arguments
        """
        # get message arguments
        message_args = await self._get_message_args(message)
        # get function arguments
        function_args = set(self._get_function_args_without_context(function))
        # it can have less arguments but not more
        extra_function_args = function_args - message_args
        if extra_function_args:
            # raise if too many arguments
            function_name = function.__name__
            msg = _(f'extra arguments: {extra_function_args}')
            raise RegistrationError(_(f'error with {module_name}.{function_name} arguments: {msg}'))

    def set_function(self,
                     version: str,
                     message: str,
                     notification: str,
                     function: Callable):
        """Register a function to an URI (an URI is a message).

        :raises RegistrationError: if the message does not exist, if an rpc
            message is registered outside its module, if the function has no
            ``risotto_context`` argument or if the rpc is already registered
        """
        # check if message exists
        if message not in self.messages[version]:
            raise RegistrationError(_(f'the message {message} not exists'))

        # xxx module can only be register with v1.xxxx..... message
        module_name = function.__module__.split('.')[-2]
        message_namespace = message.split('.', 1)[0]
        if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name:
            raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"'))

        function_args = self.get_function_args(function)
        if not function_args:
            raise RegistrationError(_(f'function {module_name}.{function.__name__} must have a risotto_context argument'))
        # the first argument is the risotto_context, not a message argument
        function_args.pop(0)

        # check if already register (only rpc messages store a single 'function')
        if 'function' in self.messages[version][message]:
            raise RegistrationError(_(f'uri {version}.{message} already registered'))

        # register
        if self.messages[version][message]['pattern'] == 'rpc':
            register = self.register_rpc
        else:
            register = self.register_event
        register(version,
                 message,
                 module_name,
                 function,
                 function_args,
                 notification)

    def register_rpc(self,
                     version: str,
                     message: str,
                     module_name: str,
                     function: Callable,
                     function_args: list,
                     notification: Optional[str]):
        """Store the single function handling an rpc message."""
        message_infos = self.messages[version][message]
        message_infos['module'] = module_name
        message_infos['function'] = function
        message_infos['arguments'] = function_args
        if notification:
            message_infos['notification'] = notification

    def register_event(self,
                       version: str,
                       message: str,
                       module_name: str,
                       function: Callable,
                       function_args: list,
                       notification: Optional[str]):
        """Append a function to the list of functions handling an event message."""
        if 'functions' not in self.messages[version][message]:
            self.messages[version][message]['functions'] = []

        dico = {'module': module_name,
                'function': function,
                'arguments': function_args}
        if notification:
            dico['notification'] = notification
        self.messages[version][message]['functions'].append(dico)

    def set_module(self, module_name, module, test):
        """Register and instanciate a new module.

        :param module_name: key used in ``injected_self``
        :param module: imported module that must define a ``Risotto`` class
        :param test: passed to the ``Risotto`` constructor
        :raises RegistrationError: if ``module`` has no ``Risotto`` attribute
        """
        # only the attribute lookup is guarded, so that an AttributeError
        # raised inside Risotto.__init__ is not mistaken for a missing class
        try:
            risotto_class = module.Risotto
        except AttributeError as err:
            raise RegistrationError(_(f'unable to register the module {module_name}, this module must have Risotto class')) from err
        self.injected_self[module_name] = risotto_class(test)

    def validate(self):
        """Check that all messages have a function.

        An event message without any function only prints a warning; every
        other message without a function is reported.

        :raises RegistrationError: listing the messages without function
        """
        missing_messages = []
        for messages in self.messages.values():
            for message, message_obj in messages.items():
                if 'functions' in message_obj or 'function' in message_obj:
                    continue
                if message_obj['pattern'] == 'event':
                    # nobody listens to this event: only warn
                    print(f'{message} prêche dans le désert')
                else:
                    missing_messages.append(message)
        if missing_messages:
            raise RegistrationError(_(f'missing uri {missing_messages}'))

    async def on_join(self):
        """Call ``on_join`` on every instanciated module, in one transaction.

        Each module receives a new Context with ``INTERNAL_USER`` as username
        and the shared database connection.
        """
        async with self.pool.acquire() as connection:
            # encode/decode the postgresql json type with the json module
            await connection.set_type_codec(
                'json',
                encoder=dumps,
                decoder=loads,
                schema='pg_catalog'
            )
            async with connection.transaction():
                for module_name, module in self.injected_self.items():
                    risotto_context = Context()
                    risotto_context.username = INTERNAL_USER
                    risotto_context.paths.append(f'{module_name}.on_join')
                    risotto_context.type = None
                    risotto_context.connection = connection
                    await module.on_join(risotto_context)

    async def insert_message(self,
                             connection,
                             uri):
        """Insert ``uri`` in the URI table, ignoring an already existing one."""
        sql = """INSERT INTO URI(URIName) VALUES ($1)
                 ON CONFLICT (URIName) DO NOTHING
              """
        # the statement returns no row: execute instead of fetchval
        await connection.execute(sql, uri)

    async def load(self):
        """Create the database pool, validate functions and record the URIs.

        For every message, the arguments of the registered function(s) are
        checked against the message definition, then "<version>.<message>" is
        inserted in the URI table; everything runs in a single transaction.

        :raises RegistrationError: if a function does not match its message
        """
        db_conf = get_config().get('database')
        engine = db_conf.get('engine')
        host = db_conf.get('host')
        dbname = db_conf.get('dbname')
        dbuser = db_conf.get('user')
        dbpassword = db_conf.get('password')
        dbport = db_conf.get('port')
        cfg = "{}://{}:{}/{}".format(engine, host, dbport, dbname)
        # credentials are passed as keyword arguments rather than inside the
        # DSN, so characters such as '@', ':', '/' or '#' cannot break it
        self.pool = await asyncpg.create_pool(cfg, user=dbuser, password=dbpassword)
        async with self.pool.acquire() as connection:
            async with connection.transaction():
                for version, messages in self.messages.items():
                    for message, message_infos in messages.items():
                        # valid function's arguments
                        if message_infos['pattern'] == 'rpc':
                            module_name = message_infos['module']
                            function = message_infos['function']
                            await self.valid_rpc_params(version, message, function, module_name)
                        elif 'functions' in message_infos:
                            for function_infos in message_infos['functions']:
                                module_name = function_infos['module']
                                function = function_infos['function']
                                await self.valid_event_params(version, message, function, module_name)
                        await self.insert_message(connection, f'{version}.{message}')