From cc6dd3efe33033bfd62180c16fae5e9bdc0c1639 Mon Sep 17 00:00:00 2001 From: Emmanuel Garette Date: Tue, 10 Mar 2020 14:03:37 +0100 Subject: [PATCH] can have multi domain locally --- src/risotto/controller.py | 22 +++---- src/risotto/http.py | 14 +++-- src/risotto/message/message.py | 100 ++++++++++++------------------- src/risotto/register.py | 68 +++++++-------------- src/risotto/services/__init__.py | 33 ++++++++++ 5 files changed, 116 insertions(+), 121 deletions(-) create mode 100644 src/risotto/services/__init__.py diff --git a/src/risotto/controller.py b/src/risotto/controller.py index e770570..1ac8223 100644 --- a/src/risotto/controller.py +++ b/src/risotto/controller.py @@ -2,6 +2,7 @@ from .config import get_config from .dispatcher import dispatcher from .context import Context from .remote import remote +from .services import list_modules class Controller: @@ -9,17 +10,17 @@ class Controller: """ def __init__(self, test: bool): - self.submodule = get_config()['global']['module_name'] + self.risotto_modules = list_modules() async def call(self, uri: str, risotto_context: Context, **kwargs): """ a wrapper to dispatcher's call""" - version, submodule, message = uri.split('.', 2) - uri = submodule + '.' + message - if submodule != self.submodule: - return await remote.call_or_publish(submodule, + version, module, message = uri.split('.', 2) + uri = module + '.' + message + if module not in self.risotto_modules: + return await remote.call_or_publish(module, version, message, kwargs) @@ -33,14 +34,15 @@ class Controller: risotto_context: Context, **kwargs): """ a wrapper to dispatcher's publish""" - version, submodule, uri = uri.split('.', 2) - if submodule != self.submodule: - await remote.call_or_publish(submodule, + version, module, submessage = uri.split('.', 2) + version, message = uri.split('.', 1) + if module not in self.risotto_modules: + await remote.call_or_publish(module, version, - message, + submessage, kwargs) await dispatcher.publish(version, - uri, + message, risotto_context, **kwargs) diff --git a/src/risotto/http.py b/src/risotto/http.py index 58ead74..5a366ff 100644 --- a/src/risotto/http.py +++ b/src/risotto/http.py @@ -11,12 +11,15 @@ from .error import CallError, NotAllowedError, RegistrationError from .message import get_messages from .logger import log from .config import get_config -from .services import load_services +from .services import list_modules, load_submodules extra_routes = {} +RISOTTO_MODULES = list_modules() + + def create_context(request): risotto_context = Context() risotto_context.username = request.match_info.get('username', @@ -88,8 +91,8 @@ async def handle(request): if get_config()['global']['debug']: print_exc() raise HTTPInternalServerError(reason=str(err)) - return Response(text=dumps({'response': text}, - content_type='application/json')) + return Response(text=dumps({'response': text}), + content_type='application/json') async def api(request, @@ -108,7 +111,8 @@ async def api(request, WHERE RoleURI.URIId = URI.URIId ''' uris = [uri['uriname'] for uri in await connection.fetch(sql)] - async with await Config(get_messages(load_shortarg=True, + async with await Config(get_messages(current_module_names=RISOTTO_MODULES, + load_shortarg=True, current_version=risotto_context.version, uris=uris)[1]) as config: await config.property.read_write() @@ -120,7 +124,7 @@ async def get_app(loop): """ build all routes """ global extra_routes - load_services() + load_submodules(dispatcher) app = Application(loop=loop) routes = [] default_storage.engine('dictionary') diff --git a/src/risotto/message/message.py b/src/risotto/message/message.py index 25ce419..2f003a5 100644 --- a/src/risotto/message/message.py +++ b/src/risotto/message/message.py @@ -13,7 +13,6 @@ from ..config import get_config from ..utils import _ MESSAGE_ROOT_PATH = get_config()['global']['message_root_path'] -MODULE_NAME = get_config()['global']['module_name'] CUSTOMTYPES = {} groups.addgroup('message') @@ -229,12 +228,12 @@ def _parse_parameters(raw_defs, def get_message(uri: str, - current_module_name: str): + current_module_names: str): try: version, message = uri.split('.', 1) path = get_message_file_path(version, message, - current_module_name) + current_module_names) with open(path, "r") as message_file: return MessageDefinition(load(message_file.read(), Loader=SafeLoader), version, @@ -247,22 +246,21 @@ def get_message(uri: str, def get_message_file_path(version, message, - current_module_name): + current_module_names): module_name, filename = message.split('.', 1) - if current_module_name and module_name != current_module_name: - raise Exception(f'should only load message for {current_module_name}, not {message}') + if current_module_names and module_name not in current_module_names: + raise Exception(f'should only load message for {current_module_names}, not {message}') return join(MESSAGE_ROOT_PATH, version, module_name, 'messages', filename + '.yml') def list_messages(uris, - current_module_name, + current_module_names, current_version): - def get_module_paths(): - if current_module_name is not None: - yield current_module_name, join(MESSAGE_ROOT_PATH, version, current_module_name, 'messages') - else: - for module_name in listdir(join(MESSAGE_ROOT_PATH, version)): - yield module_name, join(MESSAGE_ROOT_PATH, version, module_name, 'messages') + def get_module_paths(current_module_names): + if current_module_names is None: + current_module_names = listdir(join(MESSAGE_ROOT_PATH, version)) + for module_name in current_module_names: + yield module_name, join(MESSAGE_ROOT_PATH, version, module_name, 'messages') if current_version: versions = [current_version] @@ -270,7 +268,7 @@ def list_messages(uris, versions = listdir(join(MESSAGE_ROOT_PATH)) versions.sort() for version in versions: - for module_name, message_path in get_module_paths(): + for module_name, message_path in get_module_paths(current_module_names): for message in listdir(message_path): if message.endswith('.yml'): uri = version + '.' + module_name + '.' + message.rsplit('.', 1)[0] @@ -390,47 +388,28 @@ class CustomType: return self.title -def load_customtypes(current_module_name: str) -> None: +def load_customtypes(current_module_names: str) -> None: versions = listdir(MESSAGE_ROOT_PATH) versions.sort() - def convert_properties(customtype: str, - version: str) -> None: - """ if properties include an other customtype, replace it - """ - properties = {} - for key, value in customtype.properties.items(): - type_ = value.type - if type_.startswith('[]'): - if type_ in CUSTOMTYPES[version]: - raise Exception(_('cannot have []CustomType')) - properties[key] = value - else: - if type_ in CUSTOMTYPES[version]: - print('====== ca existe') - properties[key] = CUSTOMTYPES[version][ttype_] - else: - properties[key] = value - customtype.properties = properties for version in versions: if version not in CUSTOMTYPES: CUSTOMTYPES[version] = {} - types_path = join(MESSAGE_ROOT_PATH, - version, - current_module_name, - 'types') - for message in listdir(types_path): - if message.endswith('.yml'): - path = join(types_path, message) - # remove extension - message = message.rsplit('.', 1)[0] - with open(path, "r") as message_file: - try: - custom_type = CustomType(load(message_file, Loader=SafeLoader)) - convert_properties(custom_type, - version) - CUSTOMTYPES[version][custom_type.getname()] = custom_type - except Exception as err: - raise Exception(_(f'enable to load type {err}: {message}')) + for current_module_name in current_module_names: + types_path = join(MESSAGE_ROOT_PATH, + version, + current_module_name, + 'types') + for message in listdir(types_path): + if message.endswith('.yml'): + path = join(types_path, message) + # remove extension + message = message.rsplit('.', 1)[0] + with open(path, "r") as message_file: + try: + custom_type = CustomType(load(message_file, Loader=SafeLoader)) + CUSTOMTYPES[version][custom_type.getname()] = custom_type + except Exception as err: + raise Exception(_(f'enable to load type {err}: {message}')) @@ -612,10 +591,10 @@ def _get_root_option(select_option, optiondescriptions): return OptionDescription('root', 'root', options_obj) -def get_messages(load_shortarg=False, +def get_messages(current_module_names, + load_shortarg=False, current_version=None, - uris=None, - current_module_name=MODULE_NAME): + uris=None): """generate description from yml files """ global CUSTOMTYPES @@ -623,7 +602,7 @@ def get_messages(load_shortarg=False, optiondescriptions_info = {} needs = {} messages = list(list_messages(uris, - current_module_name, + current_module_names, current_version)) messages.sort() optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages] @@ -631,18 +610,17 @@ def get_messages(load_shortarg=False, 'Nom du message.', tuple(optiondescriptions_name), properties=frozenset(['mandatory', 'positional'])) - if current_module_name is None: + if current_module_names is None: CUSTOMTYPES = {} if not CUSTOMTYPES: - if current_module_name is None: + if current_module_names is None: for version in listdir(MESSAGE_ROOT_PATH): - for module_name in listdir(join(MESSAGE_ROOT_PATH, version)): - load_customtypes(module_name) + load_customtypes(listdir(join(MESSAGE_ROOT_PATH, version))) else: - load_customtypes(current_module_name) + load_customtypes(current_module_names) for message_name in messages: message_def = get_message(message_name, - current_module_name) + current_module_names) optiondescriptions_info[message_def.uri] = {'pattern': message_def.pattern, 'default_roles': message_def.default_roles, 'version': message_name.split('.')[0]} @@ -659,6 +637,6 @@ def get_messages(load_shortarg=False, load_shortarg) root = _get_root_option(select_option, optiondescriptions) - if current_module_name is None: + if current_module_names is None: CUSTOMTYPES = {} return optiondescriptions_info, root diff --git a/src/risotto/register.py b/src/risotto/register.py index 9b348f7..b3af396 100644 --- a/src/risotto/register.py +++ b/src/risotto/register.py @@ -10,6 +10,7 @@ from .message import get_messages from .context import Context from .config import get_config from .logger import log +from .services import list_modules def register(uris: str, @@ -36,22 +37,32 @@ class RegisterDispatcher: # postgresql pool self.pool = None # load tiramisu objects - messages, self.option = get_messages() + self.risotto_modules = list_modules() + messages, self.option = get_messages(self.risotto_modules) # list of uris with informations: {"v1": {"module_name.xxxxx": yyyyyy}} - version = 'v1' self.messages = {} for tiramisu_message, obj in messages.items(): version = obj['version'] if version not in self.messages: self.messages[version] = {} self.messages[version][tiramisu_message] = obj - self.risotto_module = get_config()['global']['module_name'] def get_function_args(self, function: Callable): - # remove self - first_argument_index = 1 - return [param.name for param in list(signature(function).parameters.values())[first_argument_index:]] + # remove self and risotto_context + first_argument_index = 2 + return {param.name for param in list(signature(function).parameters.values())[first_argument_index:]} + + async def get_message_args(self, + message: str): + # load config + async with await Config(self.option) as config: + await config.property.read_write() + # set message to the message name + await config.option('message').value.set(message) + # get message argument + dico = await config.option(message).value.dict() + return set(dico.keys()) async def valid_rpc_params(self, version: str, @@ -60,26 +71,10 @@ class RegisterDispatcher: module_name: str): """ parameters function must have strictly all arguments with the correct name """ - async def get_message_args(): - # load config - async with await Config(self.option) as config: - await config.property.read_write() - # set message to the uri name - await config.option('message').value.set(message) - # get message argument - dico = await config.option(message).value.dict() - return set(dico.keys()) - - def get_function_args(): - function_args = self.get_function_args(function) - # risotto_context is a special argument, remove it - function_args = function_args[1:] - return set(function_args) - # get message arguments - message_args = await get_message_args() + message_args = await self.get_message_args(message) # get function arguments - function_args = get_function_args() + function_args = self.get_function_args(function) # compare message arguments with function parameter # it must not have more or less arguments if message_args != function_args: @@ -102,26 +97,10 @@ class RegisterDispatcher: module_name: str): """ parameters function validation for event messages """ - async def get_message_args(): - # load config - async with await Config(self.option) as config: - await config.property.read_write() - # set message to the message name - await config.option('message').value.set(message) - # get message argument - dico = await config.option(message).value.dict() - return set(dico.keys()) - - def get_function_args(): - function_args = self.get_function_args(function) - # risotto_context is a special argument, remove it - function_args = function_args[1:] - return set(function_args) - # get message arguments - message_args = await get_message_args() + message_args = await self.get_message_args(message) # get function arguments - function_args = get_function_args() + function_args = self.get_function_args(function) # compare message arguments with function parameter # it can have less arguments but not more extra_function_args = function_args - message_args @@ -148,14 +127,13 @@ class RegisterDispatcher: module_name = function.__module__.split('.')[-2] message_namespace = message.split('.', 1)[0] message_risotto_module, message_namespace, message_name = message.split('.', 2) - if message_risotto_module != self.risotto_module: - raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_module}"')) + if message_risotto_module not in self.risotto_modules: + raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"')) if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name: raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"')) # True if first argument is the risotto_context function_args = self.get_function_args(function) - function_args.pop(0) # check if already register if 'function' in self.messages[version][message]: diff --git a/src/risotto/services/__init__.py b/src/risotto/services/__init__.py new file mode 100644 index 0000000..0cb4af0 --- /dev/null +++ b/src/risotto/services/__init__.py @@ -0,0 +1,33 @@ +from os import listdir +from os.path import isdir, isfile, dirname, abspath, basename, join +from importlib import import_module +# from ..dispatcher import dispatcher + + +def list_modules(): + abs_here = dirname(abspath(__file__)) + here = basename(abs_here) + return [name for name in listdir(abs_here) if not name.startswith('__') and isdir(join(abs_here, name))] + + +def load_submodules(dispatcher, + modules=None, + validate: bool=True, + test: bool=False): + abs_here = dirname(abspath(__file__)) + here = basename(abs_here) + module = basename(dirname(abs_here)) + if not modules: + modules = listdir(abs_here) + for module in modules: + absmodule = join(abs_here, module) + if isdir(absmodule): + for submodule in listdir(absmodule): + absfilename = join(absmodule, submodule) + if isdir(absfilename) and isfile(join(absfilename, '__init__.py')): + dispatcher.set_module(submodule, + import_module(f'.{here}.{module}.{submodule}', + f'risotto'), + test) + if validate: + dispatcher.validate()