try:
    from tiramisu3 import Config
except ImportError:
    from tiramisu import Config
from inspect import signature
from typing import Any, Callable, List, Optional, Set, Tuple, Union
from asyncpg import create_pool
from json import dumps, loads
from pkg_resources import iter_entry_points
from traceback import print_exc


import risotto
from .utils import _
from .error import RegistrationError
from .message import get_messages
from .context import Context
from .config import get_config
from .logger import log


class Services:
    """Registry of risotto services and their modules, discovered through entry points.

    ``services`` maps a service name to ``{module_name: loaded_entry_point}``.
    """
    # NOTE(review): these are class attributes, so the ``services`` dict is shared
    # by every instance of Services; this module only builds a single instance.
    services = {}
    modules_loaded = False
    services_loaded = False

    def load_services(self):
        """Register every ``risotto_services`` entry point as a (still empty) service."""
        for entry_point in iter_entry_points(group='risotto_services'):
            self.services.setdefault(entry_point.name, {})
        self.services_loaded = True

    def load_modules(self,
                     limit_services: Optional[List[str]]=None,
                     ) -> None:
        """Load every ``risotto_modules`` entry point, whose name is ``service.module``.

        :param limit_services: when given, only modules belonging to one of these
            services are loaded
        :raises KeyError: if a module belongs to a service that load_services did
            not register
        """
        for entry_point in iter_entry_points(group='risotto_modules'):
            service_name, module_name = entry_point.name.split('.')
            if limit_services is None or service_name in limit_services:
                self.services[service_name][module_name] = entry_point.load()
        self.modules_loaded = True

    def get_modules(self,
                    limit_services: Optional[List[str]]=None,
                    ) -> List[Tuple[str, Any]]:
        """Return ``(service.module, loaded_entry_point)`` pairs for every loaded module.

        Modules are loaded on the first call only; ``limit_services`` is ignored
        once modules have been loaded.
        """
        if not self.modules_loaded:
            self.load_modules(limit_services=limit_services)
        return [(module + '.' + submodule, entry_point)
                for module, submodules in self.services.items()
                for submodule, entry_point in submodules.items()]

    def get_services_list(self):
        """Return the names of the registered services (a dict keys view)."""
        return self.services.keys()

    def get_modules_list(self):
        """Return the names of every loaded module, across all services."""
        return [module
                for service in self.services
                for module in self.services[service]]

    def link_to_dispatcher(self,
                           dispatcher,
                           validate: bool=True,
                           test: bool=False,
                           limit_services: Optional[List[str]]=None,
                           ):
        """Hand every loaded module to ``dispatcher`` and optionally validate it.

        :param dispatcher: object exposing ``set_module`` and ``validate``
        :param validate: call ``dispatcher.validate()`` once all modules are set
        :param test: forwarded to ``dispatcher.set_module``
        :param limit_services: forwarded to get_modules
        """
        for submodule_name, module in self.get_modules(limit_services=limit_services):
            dispatcher.set_module(submodule_name,
                                  module,
                                  test,
                                  )
        if validate:
            dispatcher.validate()


services = Services()
services.load_services()
setattr(risotto, 'services', services)


def register(uris: Union[str, List[str]],
             notification: Optional[str]=None,
             ) -> Callable[[Callable], Callable]:
    """Decorator to register a function to the dispatcher.

    :param uris: one URI (``version.message``) or a list of URIs
    :param notification: optional notification stored with the registration
    :return: a decorator that registers the function and returns it unchanged
    """
    if not isinstance(uris, list):
        uris = [uris]

    def decorator(function):
        try:
            for uri in uris:
                dispatcher.set_function(uri,
                                        notification,
                                        function,
                                        function.__module__,
                                        )
        except NameError:
            # ``dispatcher`` is looked up as a global of this module; when it is
            # not defined yet, registration is skipped. When registering URIs,
            # use get_dispatcher before the URI is registered.
            pass
        # return the function so the decorated name is not replaced by None
        return function
    return decorator


class RegisterDispatcher:
    """Collect message definitions and the functions registered to handle them."""

    def __init__(self):
        # reference to instantiated modules (to inject self in methods):
        # {"module_name": instance_of_module}
        self.injected_self = {}
        # postgresql pool, created by load()
        self.pool = None
        # load tiramisu objects
        self.risotto_modules = services.get_services_list()
        messages, self.option = get_messages(self.risotto_modules)
        # messages grouped by version: {"v1": {"module_name.xxxxx": yyyyyy}}
        self.messages = {}
        for tiramisu_message, obj in messages.items():
            obj['message'] = tiramisu_message
            self.messages.setdefault(obj['version'], {})[tiramisu_message] = obj

    def get_function_args(self, function: Callable) -> Set[str]:
        """Return the parameter names of ``function``, skipping the first two.

        The two skipped parameters are self and risotto_context.
        """
        first_argument_index = 2
        return {param.name
                for param in list(signature(function).parameters.values())[first_argument_index:]}

    async def get_message_args(self,
                               message: str,
                               version: str,
                               ) -> Set[str]:
        """Return the keys of the option dictionary declared by ``version.message``."""
        # load config
        async with await Config(self.option, display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config:
            uri = f'{version}.{message}'
            await config.property.read_write()
            # set message to the message name
            await config.option('message').value.set(uri)
            # get message argument
            dico = await config.option(uri).value.dict()
            return set(dico.keys())

    async def valid_rpc_params(self,
                               version: str,
                               message: str,
                               function: Callable,
                               module_name: str,
                               ):
        """Check that an rpc function accepts exactly the message's arguments.

        Function parameters starting with ``_`` are allowed even when the
        message does not declare them.

        :raises RegistrationError: if arguments are missing or extra
        """
        # get message arguments
        message_args = await self.get_message_args(message,
                                                   version,
                                                   )
        # get function arguments
        function_args = self.get_function_args(function)
        # compare message arguments with function parameter
        # it must not have more or less arguments
        message_args |= {arg for arg in function_args - message_args if arg.startswith('_')}
        if message_args != function_args:
            # raise if arguments are not equal
            msg = []
            missing_function_args = message_args - function_args
            if missing_function_args:
                msg.append(_(f'missing arguments: {missing_function_args}'))
            extra_function_args = function_args - message_args
            if extra_function_args:
                msg.append(_(f'extra arguments: {extra_function_args}'))
            function_name = function.__name__
            msg = _(' and ').join(msg)
            raise RegistrationError(_(f'error with {module_name}.{function_name} arguments: {msg}'))

    async def valid_event_params(self,
                                 version: str,
                                 message: str,
                                 function: Callable,
                                 module_name: str,
                                 ):
        """Check that an event function accepts no argument the message lacks.

        The function may take fewer arguments than the message, but not more.

        :raises RegistrationError: if the function has extra arguments
        """
        # get message arguments
        message_args = await self.get_message_args(message,
                                                   version,
                                                   )
        # get function arguments
        function_args = self.get_function_args(function)
        # compare message arguments with function parameter
        # it can have less arguments but not more
        extra_function_args = function_args - message_args
        if extra_function_args:
            # raise if too many arguments
            function_name = function.__name__
            msg = _(f'extra arguments: {extra_function_args}')
            raise RegistrationError(_(f'error with {module_name}.{function_name} arguments: {msg}'))

    def set_function(self,
                     uri: str,
                     notification: str,
                     function: Callable,
                     full_module_name: str,
                     ):
        """Register a function to a URI (a ``version.message`` string).

        :raises RegistrationError: if the message is unknown (including an unknown
            version), the module cannot handle it, or the URI is already registered
        """
        version, message = uri.split('.', 1)
        # check if message exists; an unknown version is reported the same way
        # instead of escaping as a bare KeyError
        if message not in self.messages.get(version, {}):
            raise RegistrationError(_(f'the message {message} not exists'))
        message_infos = self.messages[version][message]
        # xxx submodule can only be register with v1.yyy.xxx..... message
        risotto_module_name, submodule_name = full_module_name.split('.')[-3:-1]
        module_name = risotto_module_name.split('_')[-1]
        message_module, message_submodule, _message_name = message.split('.', 2)
        if message_module not in self.risotto_modules:
            raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"'))
        # NOTE(review): with ``and`` the registration is only refused when BOTH the
        # module and the submodule differ; confirm whether ``or`` was intended.
        if message_infos['pattern'] == 'rpc' and \
                module_name != message_module and \
                message_submodule != submodule_name:
            raise RegistrationError(_(f'cannot registered the "{message}" message in submodule "{module_name}.{submodule_name}"'))
        # names of the function's arguments, without self and risotto_context
        function_args = self.get_function_args(function)
        # check if already register
        if 'function' in message_infos:
            raise RegistrationError(_(f'uri {uri} already registered'))
        # register
        if message_infos['pattern'] == 'rpc':
            register_method = self.register_rpc
        else:
            register_method = self.register_event
        register_method(version,
                        message,
                        f'{module_name}.{submodule_name}',
                        full_module_name,
                        function,
                        function_args,
                        notification,
                        )

    def register_rpc(self,
                     version: str,
                     message: str,
                     module_name: str,
                     full_module_name: str,
                     function: Callable,
                     function_args: Set[str],
                     notification: Optional[str],
                     ):
        """Store the single handler of an rpc message in its message entry."""
        message_infos = self.messages[version][message]
        message_infos['module'] = module_name
        message_infos['full_module_name'] = full_module_name
        message_infos['function'] = function
        message_infos['arguments'] = function_args
        if notification:
            message_infos['notification'] = notification

    def register_event(self,
                       version: str,
                       message: str,
                       module_name: str,
                       full_module_name: str,
                       function: Callable,
                       function_args: Set[str],
                       notification: Optional[str],
                       ):
        """Append a handler to the ``functions`` list of an event message."""
        dico = {'module': module_name,
                'full_module_name': full_module_name,
                'function': function,
                'arguments': function_args,
                }
        if notification:
            dico['notification'] = notification
        self.messages[version][message].setdefault('functions', []).append(dico)

    def set_module(self,
                   submodule_name,
                   module,
                   test,
                   ):
        """Instantiate ``module.Risotto(test)`` and keep it under ``submodule_name``.

        A module without a ``Risotto`` attribute is reported and skipped.
        """
        try:
            risotto_class = module.Risotto
        except AttributeError:
            print(_(f'unable to register the module {submodule_name}, this module must have Risotto class'))
            return
        # instantiate outside the try so an AttributeError raised by the
        # constructor itself is not misreported as a missing Risotto class
        self.injected_self[submodule_name] = risotto_class(test)

    def validate(self):
        """Check that every rpc message has a function.

        Events without any handler are only reported.

        :raises RegistrationError: listing the rpc URIs that have no function
        """
        missing_messages = []
        for version, messages in self.messages.items():
            for message, message_obj in messages.items():
                if 'functions' not in message_obj and 'function' not in message_obj:
                    if message_obj['pattern'] == 'event':
                        # French idiom, "preaches in the desert": an event nobody listens to
                        print(f'{version}.{message} prêche dans le désert')
                    else:
                        missing_messages.append(f'{version}.{message}')
        if missing_messages:
            raise RegistrationError(_(f'no matching function for uri {missing_messages}'))

    async def on_join(self,
                      truncate: bool=False,
                      ) -> None:
        """Call ``on_join`` on every instantiated module with an internal-user context.

        Errors raised by a module's ``on_join`` are logged, not propagated; the
        traceback is also printed when the ``debug`` option is enabled.

        :param truncate: when True, truncate the listed tables first
        """
        internal_user = get_config()['global']['internal_user']
        async with self.pool.acquire() as connection:
            await connection.set_type_codec(
                'json',
                encoder=dumps,
                decoder=loads,
                schema='pg_catalog'
            )
            if truncate:
                async with connection.transaction():
                    await connection.execute('TRUNCATE InfraServer, InfraSite, InfraZone, Log, '
                                             'ProviderDeployment, ProviderFactoryCluster, ProviderFactoryClusterNode, '
                                             'SettingApplicationservice, SettingApplicationServiceDependency, SettingRelease, '
                                             'SettingServer, SettingServermodel, SettingSource, '
                                             'UserRole, UserRoleURI, UserURI, UserUser, '
                                             'InfraServermodel, ProviderZone, ProviderServer, ProviderSource, '
                                             'ProviderApplicationservice, ProviderServermodel')
            async with connection.transaction():
                for submodule_name, module in self.injected_self.items():
                    risotto_context = Context()
                    risotto_context.username = internal_user
                    risotto_context.paths.append(f'internal.{submodule_name}.on_join')
                    risotto_context.type = None
                    risotto_context.pool = self.pool
                    risotto_context.connection = connection
                    risotto_context.module = submodule_name.split('.', 1)[0]
                    info_msg = _(f'in function risotto_{submodule_name}.on_join')
                    await log.info_msg(risotto_context,
                                       None,
                                       info_msg)
                    try:
                        await module.on_join(risotto_context)
                    except Exception as err:
                        if get_config()['global']['debug']:
                            print_exc()
                        msg = _(f'on_join returns an error in module {submodule_name}: {err}')
                        await log.error_msg(risotto_context, {}, msg)

    async def load(self):
        """Create the PostgreSQL pool, then validate every registered function's arguments.

        :raises RegistrationError: from valid_rpc_params / valid_event_params
        """
        db_conf = get_config()['database']['dsn']
        self.pool = await create_pool(db_conf)
        async with self.pool.acquire() as connection:
            async with connection.transaction():
                for version, messages in self.messages.items():
                    for message, message_infos in messages.items():
                        if message_infos['pattern'] == 'rpc':
                            # module not available during test
                            if 'module' in message_infos:
                                module_name = message_infos['module']
                                function = message_infos['function']
                                await self.valid_rpc_params(version,
                                                            message,
                                                            function,
                                                            module_name)
                        elif 'functions' in message_infos:
                            # event with functions
                            for function_infos in message_infos['functions']:
                                module_name = function_infos['module']
                                function = function_infos['function']
                                await self.valid_event_params(version,
                                                              message,
                                                              function,
                                                              module_name)