diff --git a/sql/risotto.sql b/sql/risotto.sql new file mode 100644 index 0000000..ad58260 --- /dev/null +++ b/sql/risotto.sql @@ -0,0 +1,8 @@ +CREATE TABLE log( + Msg VARCHAR(255) NOT NULL, + Level VARCHAR(10) NOT NULL, + Path VARCHAR(255), + Username VARCHAR(100) NOT NULL, + Data JSON, + Date timestamp DEFAULT current_timestamp +); diff --git a/src/risotto/config.py b/src/risotto/config.py index 5cd933e..bacd6b3 100644 --- a/src/risotto/config.py +++ b/src/risotto/config.py @@ -1,21 +1,81 @@ from os import environ +from os.path import isfile +from configobj import ConfigObj -CONFIGURATION_DIR = environ.get('CONFIGURATION_DIR', '/srv/risotto/configurations') -PROVIDER_FACTORY_CONFIG_DIR = environ.get('PROVIDER_FACTORY_CONFIG_DIR', '/srv/factory') -TMP_DIR = '/tmp' -DEFAULT_USER = environ.get('DEFAULT_USER', 'Anonymous') -RISOTTO_DB_NAME = environ.get('RISOTTO_DB_NAME', 'risotto') -RISOTTO_DB_PASSWORD = environ.get('RISOTTO_DB_PASSWORD', 'risotto') -RISOTTO_DB_USER = environ.get('RISOTTO_DB_USER', 'risotto') -TIRAMISU_DB_NAME = environ.get('TIRAMISU_DB_NAME', 'tiramisu') -TIRAMISU_DB_PASSWORD = environ.get('TIRAMISU_DB_PASSWORD', 'tiramisu') -TIRAMISU_DB_USER = environ.get('TIRAMISU_DB_USER', 'tiramisu') -DB_ADDRESS = environ.get('DB_ADDRESS', 'localhost') -MESSAGE_PATH = environ.get('MESSAGE_PATH', '/root/risotto-message/messages') -SQL_DIR = environ.get('SQL_DIR', './sql') -CACHE_ROOT_PATH = environ.get('CACHE_ROOT_PATH', '/var/cache/risotto') -SRV_SEED_PATH = environ.get('SRV_SEED_PATH', '/srv/seed') +CONFIG_FILE = environ.get('CONFIG_FILE', '/etc/risotto/risotto.conf') + + +if isfile(CONFIG_FILE): + config = ConfigObj(CONFIG_FILE) +else: + config = {} + + +if 'RISOTTO_PORT' in environ: + RISOTTO_PORT = environ['RISOTTO_PORT'] +else: + RISOTTO_PORT = config.get('RISOTTO_PORT', 8080) +if 'CONFIGURATION_DIR' in environ: + CONFIGURATION_DIR = environ['CONFIGURATION_DIR'] +else: + CONFIGURATION_DIR = config.get('CONFIGURATION_DIR', '/srv/risotto/configurations') +if 'PROVIDER_FACTORY_CONFIG_DIR' in environ: + PROVIDER_FACTORY_CONFIG_DIR = environ['PROVIDER_FACTORY_CONFIG_DIR'] +else: + PROVIDER_FACTORY_CONFIG_DIR = config.get('PROVIDER_FACTORY_CONFIG_DIR', '/srv/factory') +if 'DEFAULT_USER' in environ: + DEFAULT_USER = environ['DEFAULT_USER'] +else: + DEFAULT_USER = config.get('DEFAULT_USER', 'Anonymous') +if 'RISOTTO_DB_NAME' in environ: + RISOTTO_DB_NAME = environ['RISOTTO_DB_NAME'] +else: + RISOTTO_DB_NAME = config.get('RISOTTO_DB_NAME', 'risotto') +if 'RISOTTO_DB_PASSWORD' in environ: + RISOTTO_DB_PASSWORD = environ['RISOTTO_DB_PASSWORD'] +else: + RISOTTO_DB_PASSWORD = config.get('RISOTTO_DB_PASSWORD', 'risotto') +if 'RISOTTO_DB_USER' in environ: + RISOTTO_DB_USER = environ['RISOTTO_DB_USER'] +else: + RISOTTO_DB_USER = config.get('RISOTTO_DB_USER', 'risotto') +if 'TIRAMISU_DB_NAME' in environ: + TIRAMISU_DB_NAME = environ['TIRAMISU_DB_NAME'] +else: + TIRAMISU_DB_NAME = config.get('TIRAMISU_DB_NAME', 'tiramisu') +if 'TIRAMISU_DB_PASSWORD' in environ: + TIRAMISU_DB_PASSWORD = environ['TIRAMISU_DB_PASSWORD'] +else: + TIRAMISU_DB_PASSWORD = config.get('TIRAMISU_DB_PASSWORD', 'tiramisu') +if 'TIRAMISU_DB_USER' in environ: + TIRAMISU_DB_USER = environ['TIRAMISU_DB_USER'] +else: + TIRAMISU_DB_USER = config.get('TIRAMISU_DB_USER', 'tiramisu') +if 'DB_ADDRESS' in environ: + DB_ADDRESS = environ['DB_ADDRESS'] +else: + DB_ADDRESS = config.get('DB_ADDRESS', 'localhost') +if 'MESSAGE_PATH' in environ: + MESSAGE_PATH = environ['MESSAGE_PATH'] +else: + MESSAGE_PATH = config.get('MESSAGE_PATH', '/root/risotto-message/messages') +if 'SQL_DIR' in environ: + SQL_DIR = environ['SQL_DIR'] +else: + SQL_DIR = config.get('SQL_DIR', './sql') +if 'CACHE_ROOT_PATH' in environ: + CACHE_ROOT_PATH = environ['CACHE_ROOT_PATH'] +else: + CACHE_ROOT_PATH = config.get('CACHE_ROOT_PATH', '/var/cache/risotto') +if 'SRV_SEED_PATH' in environ: + SRV_SEED_PATH = environ['SRV_SEED_PATH'] +else: + SRV_SEED_PATH = config.get('SRV_SEED_PATH', '/srv/seed') +if 'TMP_DIR' in environ: + TMP_DIR = environ['TMP_DIR'] +else: + TMP_DIR = config.get('TMP_DIR', '/tmp') def dsn_factory(database, user, password, address=DB_ADDRESS): @@ -23,26 +83,29 @@ def dsn_factory(database, user, password, address=DB_ADDRESS): return f'postgres:///{database}?host={mangled_address}/&user={user}&password={password}' +_config = {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD), + 'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD), + }, + 'http_server': {'port': RISOTTO_PORT, + 'default_user': DEFAULT_USER}, + 'global': {'message_root_path': MESSAGE_PATH, + 'configurations_dir': CONFIGURATION_DIR, + 'debug': True, + 'internal_user': '_internal', + 'check_role': True, + 'admin_user': DEFAULT_USER, + 'sql_dir': SQL_DIR, + 'tmp_dir': TMP_DIR, + }, + 'cache': {'root_path': CACHE_ROOT_PATH}, + 'servermodel': {'internal_source_path': SRV_SEED_PATH, + 'internal_source': 'internal'}, + 'submodule': {'allow_insecure_https': False, + 'pki': '192.168.56.112'}, + 'provider': {'factory_configuration_dir': PROVIDER_FACTORY_CONFIG_DIR, + 'factory_configuration_filename': 'infra.json'}, + } + + def get_config(): - return {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD), - 'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD), - }, - 'http_server': {'port': 8080, - 'default_user': DEFAULT_USER}, - 'global': {'message_root_path': MESSAGE_PATH, - 'configurations_dir': CONFIGURATION_DIR, - 'debug': True, - 'internal_user': 'internal', - 'check_role': True, - 'admin_user': DEFAULT_USER, - 'sql_dir': SQL_DIR}, - 'cache': {'root_path': CACHE_ROOT_PATH}, - 'servermodel': {'internal_source_path': SRV_SEED_PATH, - 'internal_source': 'internal', - 'internal_distribution': 'last', - 'internal_release_name': 'none'}, - 'submodule': {'allow_insecure_https': False, - 'pki': '192.168.56.112'}, - 'provider': {'factory_configuration_dir': PROVIDER_FACTORY_CONFIG_DIR, - 'factory_configuration_filename': 'infra.json'}, - } + return _config diff --git a/src/risotto/controller.py b/src/risotto/controller.py index cc0a67c..c4b3296 100644 --- a/src/risotto/controller.py +++ b/src/risotto/controller.py @@ -1,8 +1,5 @@ -from .config import get_config from .dispatcher import dispatcher from .context import Context -from .remote import remote -from . import services from .utils import _ @@ -10,50 +7,48 @@ class Controller: """Common controller used to add a service in Risotto """ def __init__(self, - test: bool): - self.risotto_modules = services.get_services_list() + test: bool, + ): + pass async def call(self, uri: str, risotto_context: Context, *args, - **kwargs): + **kwargs, + ): """ a wrapper to dispatcher's call""" - version, module, message = uri.split('.', 2) - uri = module + '.' + message if args: raise ValueError(_(f'the URI "{uri}" can only be called with keyword arguments')) - if module not in self.risotto_modules: - return await remote.remote_call(module, - version, - message, - kwargs) + current_uri = risotto_context.paths[-1] + current_module = risotto_context.module + version, message = uri.split('.', 1) + module = message.split('.', 1)[0] + if current_module != module: + raise ValueError(_(f'cannot call to external module ("{module}") to the URI "{uri}" from "{current_module}"')) return await dispatcher.call(version, - uri, + message, risotto_context, - **kwargs) + **kwargs, + ) async def publish(self, uri: str, risotto_context: Context, - *args, - **kwargs): + *args, + **kwargs, + ): """ a wrapper to dispatcher's publish""" - version, module, submessage = uri.split('.', 2) version, message = uri.split('.', 1) if args: raise ValueError(_(f'the URI "{uri}" can only be published with keyword arguments')) - if module not in self.risotto_modules: - await remote.remote_call(module, - version, - submessage, - kwargs) - else: - await dispatcher.publish(version, - message, - risotto_context, - **kwargs) + await dispatcher.publish(version, + message, + risotto_context, + **kwargs, + ) async def on_join(self, - risotto_context): + risotto_context, + ): pass diff --git a/src/risotto/dispatcher.py b/src/risotto/dispatcher.py index 92c400b..08b2343 100644 --- a/src/risotto/dispatcher.py +++ b/src/risotto/dispatcher.py @@ -1,7 +1,9 @@ try: from tiramisu3 import Config + from tiramisu3.error import ValueOptionError except: from tiramisu import Config + from tiramisu.error import ValueOptionError from traceback import print_exc from copy import copy from typing import Dict, Callable, List, Optional @@ -13,8 +15,7 @@ from .logger import log from .config import get_config from .context import Context from . import register -#from .remote import Remote -import asyncpg +from .remote import Remote class CallDispatcher: @@ -42,28 +43,28 @@ class CallDispatcher: raise Exception('hu?') else: for ret in returns: - async with await Config(response, display_name=lambda self, dyn_name: self.impl_getname()) as config: + async with await Config(response, display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config: await config.property.read_write() try: for key, value in ret.items(): await config.option(key).value.set(value) except AttributeError: - err = _(f'function {module_name}.{function_name} return the unknown parameter "{key}"') + err = _(f'function {module_name}.{function_name} return the unknown parameter "{key}" for the uri "{risotto_context.version}.{risotto_context.message}"') await log.error_msg(risotto_context, kwargs, err) raise CallError(str(err)) except ValueError: - err = _(f'function {module_name}.{function_name} return the parameter "{key}" with an unvalid value "{value}"') + err = _(f'function {module_name}.{function_name} return the parameter "{key}" with an unvalid value "{value}" for the uri "{risotto_context.version}.{risotto_context.message}"') await log.error_msg(risotto_context, kwargs, err) raise CallError(str(err)) await config.property.read_only() mandatories = await config.value.mandatory() if mandatories: mand = [mand.split('.')[-1] for mand in mandatories] - raise ValueError(_(f'missing parameters in response: {mand} in message "{risotto_context.message}"')) + raise ValueError(_(f'missing parameters in response of the uri "{risotto_context.version}.{risotto_context.message}": {mand} in message')) try: await config.value.dict() except Exception as err: - err = _(f'function {module_name}.{function_name} return an invalid response {err}') + err = _(f'function {module_name}.{function_name} return an invalid response {err} for the uri "{risotto_context.version}.{risotto_context.message}"') await log.error_msg(risotto_context, kwargs, err) raise CallError(str(err)) @@ -72,14 +73,21 @@ class CallDispatcher: message: str, old_risotto_context: Context, check_role: bool=False, - **kwargs): + internal: bool=True, + **kwargs, + ): """ execute the function associate with specified uri arguments are validate before """ risotto_context = self.build_new_context(old_risotto_context, version, message, - 'rpc') + 'rpc', + ) + if version not in self.messages: + raise CallError(_(f'cannot find version of message "{version}"')) + if message not in self.messages[version]: + raise CallError(_(f'cannot find message "{version}.{message}"')) function_objs = [self.messages[version][message]] # do not start a new database connection if hasattr(old_risotto_context, 'connection'): @@ -89,7 +97,9 @@ class CallDispatcher: risotto_context, check_role, kwargs, - function_objs) + function_objs, + internal, + ) else: try: async with self.pool.acquire() as connection: @@ -106,7 +116,9 @@ class CallDispatcher: risotto_context, check_role, kwargs, - function_objs) + function_objs, + internal, + ) except CallError as err: raise err except Exception as err: @@ -132,26 +144,38 @@ class PublishDispatcher: message: str, old_risotto_context: Context, check_role: bool=False, - **kwargs) -> None: + internal: bool=True, + **kwargs, + ) -> None: risotto_context = self.build_new_context(old_risotto_context, version, message, - 'event') + 'event', + ) try: function_objs = self.messages[version][message].get('functions', []) except KeyError: raise ValueError(_(f'cannot find message {version}.{message}')) # do not start a new database connection if hasattr(old_risotto_context, 'connection'): + # publish to remove + remote_kw = dumps({'kwargs': kwargs, + 'context': risotto_context.__dict__, + }) risotto_context.connection = old_risotto_context.connection + # FIXME should be better :/ + remote_kw = remote_kw.replace("'", "''") + await risotto_context.connection.execute(f'NOTIFY "{version}.{message}", \'{remote_kw}\'') return await self.launch(version, message, risotto_context, check_role, kwargs, - function_objs) - try: - async with self.pool.acquire() as connection: + function_objs, + internal, + ) + async with self.pool.acquire() as connection: + try: await connection.set_type_codec( 'json', encoder=dumps, @@ -165,28 +189,29 @@ class PublishDispatcher: risotto_context, check_role, kwargs, - function_objs) - except CallError as err: - raise err - except Exception as err: - # if there is a problem with arguments, just send an error and do nothing - if get_config()['global']['debug']: - print_exc() - async with self.pool.acquire() as connection: - await connection.set_type_codec( - 'json', - encoder=dumps, - decoder=loads, - schema='pg_catalog' - ) - risotto_context.connection = connection - async with connection.transaction(): - await log.error_msg(risotto_context, kwargs, err) - raise err + function_objs, + internal, + ) + except CallError as err: + pass + except Exception as err: + # if there is a problem with arguments, log and do nothing + if get_config()['global']['debug']: + print_exc() + async with self.pool.acquire() as connection: + await connection.set_type_codec( + 'json', + encoder=dumps, + decoder=loads, + schema='pg_catalog' + ) + risotto_context.connection = connection + async with connection.transaction(): + await log.error_msg(risotto_context, kwargs, err) class Dispatcher(register.RegisterDispatcher, -# Remote, + Remote, CallDispatcher, PublishDispatcher): """ Manage message (call or publish) @@ -196,7 +221,8 @@ class Dispatcher(register.RegisterDispatcher, old_risotto_context: Context, version: str, message: str, - type: str): + type: str, + ) -> Context: """ This is a new call or a new publish, so create a new context """ uri = version + '.' + message @@ -212,7 +238,8 @@ class Dispatcher(register.RegisterDispatcher, async def check_message_type(self, risotto_context: Context, - kwargs: Dict): + kwargs: Dict, + ) -> None: if self.messages[risotto_context.version][risotto_context.message]['pattern'] != risotto_context.type: msg = _(f'{risotto_context.uri} is not a {risotto_context.type} message') await log.error_msg(risotto_context, kwargs, msg) @@ -222,23 +249,31 @@ class Dispatcher(register.RegisterDispatcher, risotto_context: Context, uri: str, kwargs: Dict, - check_role: bool): + check_role: bool, + internal: bool, + ): """ create a new Config et set values to it """ # create a new config async with await Config(self.option) as config: await config.property.read_write() # set message's option - await config.option('message').value.set(risotto_context.message) + await config.option('message').value.set(uri) # store values - subconfig = config.option(risotto_context.message) + subconfig = config.option(uri) + extra_parameters = {} for key, value in kwargs.items(): - try: - await subconfig.option(key).value.set(value) - except AttributeError: - if get_config()['global']['debug']: - print_exc() - raise ValueError(_(f'unknown parameter in "{uri}": "{key}"')) + if not internal or not key.startswith('_'): + try: + await subconfig.option(key).value.set(value) + except AttributeError: + if get_config()['global']['debug']: + print_exc() + raise ValueError(_(f'unknown parameter in "{uri}": "{key}"')) + except ValueOptionError as err: + raise ValueError(_(f'invalid parameter in "{uri}": {err}')) + else: + extra_parameters[key] = value # check mandatories options if check_role and get_config().get('global').get('check_role'): await self.check_role(subconfig, @@ -250,7 +285,10 @@ class Dispatcher(register.RegisterDispatcher, mand = [mand.split('.')[-1] for mand in mandatories] raise ValueError(_(f'missing parameters in "{uri}": {mand}')) # return complete an validated kwargs - return await subconfig.value.dict() + parameters = await subconfig.value.dict() + if extra_parameters: + parameters.update(extra_parameters) + return parameters def get_service(self, name: str): @@ -265,7 +303,7 @@ class Dispatcher(register.RegisterDispatcher, # Verify if user exists and get ID sql = ''' SELECT UserId - FROM RisottoUser + FROM UserUser WHERE UserLogin = $1 ''' user_id = await connection.fetchval(sql, @@ -283,8 +321,8 @@ class Dispatcher(register.RegisterDispatcher, # Check role select_role_uri = ''' SELECT RoleName - FROM URI, RoleURI - WHERE URI.URIName = $1 AND RoleURI.URIId = URI.URIId + FROM UserURI, UserRoleURI + WHERE UserURI.URIName = $1 AND UserRoleURI.URIId = UserURI.URIId ''' select_role_user = ''' SELECT RoleAttribute, RoleAttributeValue @@ -309,19 +347,24 @@ class Dispatcher(register.RegisterDispatcher, risotto_context: Context, check_role: bool, kwargs: Dict, - function_objs: List) -> Optional[Dict]: + function_objs: List, + internal: bool, + ) -> Optional[Dict]: await self.check_message_type(risotto_context, kwargs) config_arguments = await self.load_kwargs_to_config(risotto_context, f'{version}.{message}', kwargs, - check_role) + check_role, + internal, + ) # config is ok, so send the message for function_obj in function_objs: function = function_obj['function'] - module_name = function.__module__.split('.')[-2] + submodule_name = function_obj['module'] function_name = function.__name__ - info_msg = _(f'in module {module_name}.{function_name}') + risotto_context.module = submodule_name.split('.', 1)[0] + info_msg = _(f'in module {submodule_name}.{function_name}') # build argument for this function if risotto_context.type == 'rpc': kw = config_arguments diff --git a/src/risotto/http.py b/src/risotto/http.py index cbfdac7..dae48d8 100644 --- a/src/risotto/http.py +++ b/src/risotto/http.py @@ -20,13 +20,11 @@ from . import services extra_routes = {} -RISOTTO_MODULES = services.get_services_list() - - def create_context(request): risotto_context = Context() risotto_context.username = request.match_info.get('username', - get_config()['http_server']['default_user']) + get_config()['http_server']['default_user'], + ) return risotto_context @@ -53,8 +51,9 @@ class extra_route_handler: function_name = cls.function.__module__ # if not 'api' function if function_name != 'risotto.http': - module_name = function_name.split('.')[-2] - kwargs['self'] = dispatcher.injected_self[module_name] + risotto_module_name, submodule_name = function_name.split('.', 2)[:-1] + module_name = risotto_module_name.split('_')[-1] + kwargs['self'] = dispatcher.injected_self[module_name + '.' + submodule_name] try: returns = await cls.function(**kwargs) except NotAllowedError as err: @@ -85,7 +84,9 @@ async def handle(request): message, risotto_context, check_role=True, - **kwargs) + internal=False, + **kwargs, + ) except NotAllowedError as err: raise HTTPNotFound(reason=str(err)) except CallError as err: @@ -100,8 +101,8 @@ async def handle(request): async def api(request, risotto_context): - global tiramisu - if not tiramisu: + global TIRAMISU + if not TIRAMISU: # check all URI that have an associated role # all URI without role is concidered has a private URI uris = [] @@ -109,18 +110,21 @@ async def api(request, async with connection.transaction(): # Check role with ACL sql = ''' - SELECT URI.URIName - FROM URI, RoleURI - WHERE RoleURI.URIId = URI.URIId + SELECT UserURI.URIName + FROM UserURI, UserRoleURI + WHERE UserRoleURI.URIId = UserURI.URIId ''' uris = [uri['uriname'] for uri in await connection.fetch(sql)] - async with await Config(get_messages(current_module_names=RISOTTO_MODULES, + risotto_modules = services.get_services_list() + async with await Config(get_messages(current_module_names=risotto_modules, load_shortarg=True, current_version=risotto_context.version, - uris=uris)[1]) as config: + uris=uris, + )[1], + display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config: await config.property.read_write() - tiramisu = await config.option.dict(remotable='none') - return tiramisu + TIRAMISU = await config.option.dict(remotable='none') + return TIRAMISU async def get_app(loop): @@ -138,9 +142,9 @@ async def get_app(loop): versions.append(version) print() print(_('======== Registered messages ========')) - for message in messages: + for message, message_infos in messages.items(): web_message = f'/api/{version}/{message}' - pattern = dispatcher.messages[version][message]['pattern'] + pattern = message_infos['pattern'] print(f' - {web_message} ({pattern})') routes.append(post(web_message, handle)) print() @@ -152,6 +156,9 @@ async def get_app(loop): extra_handler = type(api_route['path'], (extra_route_handler,), api_route) routes.append(get(api_route['path'], extra_handler)) print(f' - {api_route["path"]} (http_get)') + # last version is default version + routes.append(get('/api', extra_handler)) + print(f' - /api (http_get)') print() if extra_routes: print(_('======== Registered extra routes ========')) @@ -162,11 +169,12 @@ async def get_app(loop): extra_handler = type(path, (extra_route_handler,), extra) routes.append(get(path, extra_handler)) print(f' - {path} (http_get)') - print() del extra_routes app.router.add_routes(routes) + await dispatcher.register_remote() + print() await dispatcher.on_join() return await loop.create_server(app.make_handler(), '*', get_config()['http_server']['port']) -tiramisu = None +TIRAMISU = None diff --git a/src/risotto/message.py b/src/risotto/message.py index da6dea6..e4453c9 100644 --- a/src/risotto/message.py +++ b/src/risotto/message.py @@ -248,7 +248,8 @@ def get_message_file_path(version, def list_messages(uris, current_module_names, - current_version): + current_version, + ): def get_module_paths(current_module_names): if current_module_names is None: current_module_names = listdir(join(MESSAGE_ROOT_PATH, version)) @@ -412,7 +413,7 @@ def load_customtypes() -> None: custom_type = CustomType(load(message_file, Loader=SafeLoader)) ret[version][custom_type.getname()] = custom_type except Exception as err: - raise Exception(_(f'enable to load type {err}: {message}')) + raise Exception(_(f'enable to load type "{message}": {err}')) return ret @@ -431,9 +432,9 @@ def _get_description(description, def _get_option(name, arg, - file_path, + uri, select_option, - optiondescription): + ): """generate option """ props = [] @@ -443,7 +444,7 @@ def _get_option(name, props.append(Calculation(calc_value, Params(ParamValue('disabled'), kwargs={'condition': ParamOption(select_option, todict=True), - 'expected': ParamValue(optiondescription), + 'expected': ParamValue(uri), 'reverse_condition': ParamValue(True)}), calc_value_property_help)) @@ -472,25 +473,25 @@ def _get_option(name, elif type_ == 'Float': obj = FloatOption(**kwargs) else: - raise Exception('unsupported type {} in {}'.format(type_, file_path)) + raise Exception('unsupported type {} in {}'.format(type_, uri)) obj.impl_set_information('ref', arg.ref) return obj def get_options(message_def, - file_path, + uri, select_option, - optiondescription, - load_shortarg): + load_shortarg, + ): """build option with args/kwargs """ options =[] for name, arg in message_def.parameters.items(): current_opt = _get_option(name, arg, - file_path, + uri, select_option, - optiondescription) + ) options.append(current_opt) if hasattr(arg, 'shortarg') and arg.shortarg and load_shortarg: options.append(SymLinkOption(arg.shortarg, current_opt)) @@ -498,17 +499,18 @@ def get_options(message_def, def _parse_responses(message_def, - file_path): + uri, + ): """build option with returns """ if message_def.response.parameters is None: - raise Exception('message "{}" did not returned any valid parameters.'.format(message_def.message)) + raise Exception(f'message "{message_def.message}" did not returned any valid parameters') options = [] names = [] for name, obj in message_def.response.parameters.items(): if name in names: - raise Exception('multi response with name {} in {}'.format(name, file_path)) + raise Exception(f'multi response with name "{name}" in "{uri}"') names.append(name) kwargs = {'name': name, @@ -531,15 +533,17 @@ def _parse_responses(message_def, else: kwargs['properties'] = ('mandatory',) options.append(option(**kwargs)) - od = OptionDescription(message_def.message, + od = OptionDescription(uri, message_def.response.description, - options) + options, + ) od.impl_set_information('multi', message_def.response.multi) return od def _get_root_option(select_option, - optiondescriptions): + optiondescriptions, + ): """get root option """ def _get_od(curr_ods): @@ -581,19 +585,21 @@ def _get_root_option(select_option, def get_messages(current_module_names, load_shortarg=False, current_version=None, - uris=None): + uris=None, + ): """generate description from yml files """ optiondescriptions = {} optiondescriptions_info = {} messages = list(list_messages(uris, current_module_names, - current_version)) + current_version, + )) messages.sort() - optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages] + # optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages] select_option = ChoiceOption('message', 'Nom du message.', - tuple(optiondescriptions_name), + tuple(messages), properties=frozenset(['mandatory', 'positional'])) for uri in messages: message_def = get_message(uri, @@ -601,23 +607,26 @@ def get_messages(current_module_names, ) optiondescriptions_info[message_def.message] = {'pattern': message_def.pattern, 'default_roles': message_def.default_roles, - 'version': message_def.version} + 'version': message_def.version, + } if message_def.pattern == 'rpc': if not message_def.response: raise Exception(f'rpc without response is not allowed {uri}') optiondescriptions_info[message_def.message]['response'] = _parse_responses(message_def, - uri) + uri, + ) elif message_def.response: raise Exception(f'response is not allowed for {uri}') message_def.options = get_options(message_def, uri, select_option, - message_def.message, - load_shortarg) - optiondescriptions[message_def.message] = (message_def.description, message_def.options) + load_shortarg, + ) + optiondescriptions[uri] = (message_def.description, message_def.options) root = _get_root_option(select_option, - optiondescriptions) + optiondescriptions, + ) return optiondescriptions_info, root diff --git a/src/risotto/register.py b/src/risotto/register.py index 1f16f4a..c2db5cb 100644 --- a/src/risotto/register.py +++ b/src/risotto/register.py @@ -3,9 +3,10 @@ try: except: from tiramisu import Config from inspect import signature -from typing import Callable, Optional -import asyncpg +from typing import Callable, Optional, List +from asyncpg import create_pool from json import dumps, loads +from pkg_resources import iter_entry_points import risotto from .utils import _ from .error import RegistrationError @@ -13,7 +14,7 @@ from .message import get_messages from .context import Context from .config import get_config from .logger import log -from pkg_resources import iter_entry_points + class Services(): services = {} @@ -25,11 +26,14 @@ class Services(): self.services.setdefault(entry_point.name, []) self.services_loaded = True - def load_modules(self): + def load_modules(self, + limit_services: Optional[List[str]]=None, + ) -> None: for entry_point in iter_entry_points(group='risotto_modules'): service_name, module_name = entry_point.name.split('.') - setattr(self, module_name, entry_point.load()) - self.services[service_name].append(module_name) + if limit_services is None or service_name in limit_services: + setattr(self, module_name, entry_point.load()) + self.services[service_name].append(module_name) self.modules_loaded = True def get_services(self): @@ -37,10 +41,12 @@ class Services(): self.load_services() return [(s, getattr(self, s)) for s in self.services] - def get_modules(self): + def get_modules(self, + limit_services: Optional[List[str]]=None, + ) -> List[str]: if not self.modules_loaded: - self.load_modules() - return [(m, getattr(self, m)) for s in self.services.values() for m in s] + self.load_modules(limit_services=limit_services) + return [(module + '.' + submodule, getattr(self, submodule)) for module, submodules in self.services.items() for submodule in submodules] def get_services_list(self): return self.services.keys() @@ -52,11 +58,13 @@ class Services(): dispatcher, validate: bool=True, test: bool=False, + limit_services: Optional[List[str]]=None, ): - for module_name, module in self.get_modules(): - dispatcher.set_module(module_name, + for submodule_name, module in self.get_modules(limit_services=limit_services): + dispatcher.set_module(submodule_name, module, - test) + test, + ) if validate: dispatcher.validate() @@ -65,8 +73,10 @@ services = Services() services.load_services() setattr(risotto, 'services', services) + def register(uris: str, - notification: str=None): + notification: str=None, + ) -> None: """ Decorator to register function to the dispatcher """ if not isinstance(uris, list): @@ -106,29 +116,38 @@ class RegisterDispatcher: return {param.name for param in list(signature(function).parameters.values())[first_argument_index:]} async def get_message_args(self, - message: str): + message: str, + version: str, + ): # load config - async with await Config(self.option) as config: + async with await Config(self.option, display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config: + uri = f'{version}.{message}' await config.property.read_write() # set message to the message name - await config.option('message').value.set(message) + await config.option('message').value.set(uri) # get message argument - dico = await config.option(message).value.dict() + dico = await config.option(uri).value.dict() return set(dico.keys()) async def valid_rpc_params(self, version: str, message: str, function: Callable, - module_name: str): + module_name: str, + ): """ parameters function must have strictly all arguments with the correct name """ # get message arguments - message_args = await self.get_message_args(message) + message_args = await self.get_message_args(message, + version, + ) # get function arguments function_args = self.get_function_args(function) # compare message arguments with function parameter # it must not have more or less arguments + for arg in function_args - message_args: + if arg.startswith('_'): + message_args.add(arg) if message_args != function_args: # raise if arguments are not equal msg = [] @@ -146,11 +165,14 @@ class RegisterDispatcher: version: str, message: str, function: Callable, - module_name: str): + module_name: str, + ): """ parameters function validation for event messages """ # get message arguments - message_args = await self.get_message_args(message) + message_args = await self.get_message_args(message, + version, + ) # get function arguments function_args = self.get_function_args(function) # compare message arguments with function parameter @@ -166,7 +188,8 @@ class RegisterDispatcher: version: str, message: str, notification: str, - function: Callable): + function: Callable, + ): """ register a function to an URI URI is a message """ @@ -175,14 +198,16 @@ class RegisterDispatcher: if message not in self.messages[version]: raise RegistrationError(_(f'the message {message} not exists')) - # xxx module can only be register with v1.xxxx..... message - module_name = function.__module__.split('.')[-2] - message_namespace = message.split('.', 1)[0] - message_risotto_module, message_namespace, message_name = message.split('.', 2) - if message_risotto_module not in self.risotto_modules: + # xxx submodule can only be register with v1.yyy.xxx..... message + risotto_module_name, submodule_name = function.__module__.split('.')[-3:-1] + module_name = risotto_module_name.split('_')[-1] + message_module, message_submodule, message_name = message.split('.', 2) + if message_module not in self.risotto_modules: raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"')) - if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name: - raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"')) + if self.messages[version][message]['pattern'] == 'rpc' and \ + module_name != message_module and \ + message_submodule != submodule_name: + raise RegistrationError(_(f'cannot registered the "{message}" message in submodule "{module_name}.{submodule_name}"')) # True if first argument is the risotto_context function_args = self.get_function_args(function) @@ -198,10 +223,11 @@ class RegisterDispatcher: register = self.register_event register(version, message, - module_name, + f'{module_name}.{submodule_name}', function, function_args, - notification) + notification, + ) def register_rpc(self, version: str, @@ -209,7 +235,8 @@ class RegisterDispatcher: module_name: str, function: Callable, function_args: list, - notification: Optional[str]): + notification: Optional[str], + ): self.messages[version][message]['module'] = module_name self.messages[version][message]['function'] = function self.messages[version][message]['arguments'] = function_args @@ -222,7 +249,8 @@ class RegisterDispatcher: module_name: str, function: Callable, function_args: list, - notification: Optional[str]): + notification: Optional[str], + ): if 'functions' not in self.messages[version][message]: self.messages[version][message]['functions'] = [] @@ -233,13 +261,17 @@ class RegisterDispatcher: dico['notification'] = notification self.messages[version][message]['functions'].append(dico) - def set_module(self, module_name, module, test): + def set_module(self, + submodule_name, + module, + test, + ): """ register and instanciate a new module """ try: - self.injected_self[module_name] = module.Risotto(test) + self.injected_self[submodule_name] = module.Risotto(test) except AttributeError as err: - raise RegistrationError(_(f'unable to register the module {module_name}, this module must have Risotto class')) + raise RegistrationError(_(f'unable to register the module {submodule_name}, this module must have Risotto class')) def validate(self): """ check if all messages have a function @@ -255,7 +287,9 @@ class RegisterDispatcher: if missing_messages: raise RegistrationError(_(f'no matching function for uri {missing_messages}')) - async def on_join(self): + async def on_join(self, + truncate: bool=False, + ) -> None: internal_user = get_config()['global']['internal_user'] async with self.pool.acquire() as connection: await connection.set_type_codec( @@ -264,14 +298,18 @@ class RegisterDispatcher: decoder=loads, schema='pg_catalog' ) + if truncate: + async with connection.transaction(): + await connection.execute('TRUNCATE InfraServer, InfraSite, InfraZone, Log, ProviderDeployment, ProviderFactoryCluster, ProviderFactoryClusterNode, SettingApplicationservice, SettingApplicationServiceDependency, SettingRelease, SettingServer, SettingServermodel, SettingSource, UserRole, UserRoleURI, UserURI, UserUser, InfraServermodel, ProviderZone, ProviderServer, ProviderSource, ProviderApplicationservice ProviderServermodel') async with connection.transaction(): - for module_name, module in self.injected_self.items(): + for submodule_name, module in self.injected_self.items(): risotto_context = Context() risotto_context.username = internal_user - risotto_context.paths.append(f'{module_name}.on_join') + risotto_context.paths.append(f'internal.{submodule_name}.on_join') risotto_context.type = None risotto_context.connection = connection - info_msg = _(f'in module {module_name}.on_join') + risotto_context.module = submodule_name.split('.', 1)[0] + info_msg = _(f'in module risotto_{submodule_name}.on_join') await log.info_msg(risotto_context, None, info_msg) @@ -280,18 +318,20 @@ class RegisterDispatcher: async def load(self): # valid function's arguments db_conf = get_config()['database']['dsn'] - self.pool = await asyncpg.create_pool(db_conf) + self.pool = await create_pool(db_conf) async with self.pool.acquire() as connection: async with connection.transaction(): for version, messages in self.messages.items(): for message, message_infos in messages.items(): if message_infos['pattern'] == 'rpc': - module_name = message_infos['module'] - function = message_infos['function'] - await self.valid_rpc_params(version, - message, - function, - module_name) + # module not available during test + if 'module' in message_infos: + module_name = message_infos['module'] + function = message_infos['function'] + await self.valid_rpc_params(version, + message, + function, + module_name) elif 'functions' in message_infos: # event with functions for function_infos in message_infos['functions']: diff --git a/src/risotto/remote.py b/src/risotto/remote.py index a3c2aad..dc5d358 100644 --- a/src/risotto/remote.py +++ b/src/risotto/remote.py @@ -1,61 +1,42 @@ -from aiohttp import ClientSession -from requests import get, post -from json import dumps -#from tiramisu_api import Config +from asyncio import get_event_loop, ensure_future +from json import loads +from .context import Context from .config import get_config from .utils import _ -# -# -# ALLOW_INSECURE_HTTPS = get_config()['module']['allow_insecure_https'] class Remote: - submodules = {} + async def register_remote(self) -> None: + print() + print(_('======== Registered remote event ========')) + self.listened_connection = await self.pool.acquire() + for version, messages in self.messages.items(): + for message, message_infos in messages.items(): + # event not emit locally + if message_infos['pattern'] == 'event': + module, submodule, submessage = message.split('.', 2) + if f'{module}.{submodule}' not in self.injected_self: + uri = f'{version}.{message}' + print(f' - {uri}') + await self.listened_connection.add_listener(uri, self.to_async_publish) - async def _get_config(self, - module: str, - url: str) -> None: - if module not in self.submodules: - session = ClientSession() - async with session.get(url) as resp: - if resp.status != 200: - try: - json = await resp.json() - err = json['error']['kwargs']['reason'] - except: - err = await resp.text() - raise Exception(err) - json = await resp.json() - self.submodules[module] = json - return Config(self.submodules[module]) - - async def remote_call(self, - module: str, - version: str, - submessage: str, - payload) -> dict: - try: - domain_name = get_config()['module'][module] - except KeyError: - raise ValueError(_(f'cannot find information of remote module "{module}" to access to "{version}.{module}.{submessage}"')) - remote_url = f'http://{domain_name}:8080/api/{version}' - message_url = f'{remote_url}/{submessage}' - - config = await self._get_config(module, - remote_url) - for key, value in payload.items(): - path = submessage + '.' + key - config.option(path).value.set(value) - session = ClientSession() - async with session.post(message_url, data=dumps(payload)) as resp: - response = await resp.json() - if 'error' in response: - if 'reason' in response['error']['kwargs']: - raise Exception("{}".format(response['error']['kwargs']['reason'])) - raise Exception('erreur inconnue') - return response['response'] - - -remote = Remote() + def to_async_publish(self, + con: 'asyncpg.connection.Connection', + pid: int, + uri: str, + payload: str, + ) -> None: + version, message = uri.split('.', 1) + loop = get_event_loop() + remote_kw = loads(payload) + context = Context() + for key, value in remote_kw['context'].items(): + setattr(context, key, value) + callback = lambda: ensure_future(self.publish(version, + message, + context, + **remote_kw['kwargs'], + )) + loop.call_soon(callback) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/storage.py b/tests/storage.py index 8a47d7b..fdb4d73 100644 --- a/tests/storage.py +++ b/tests/storage.py @@ -1,5 +1,15 @@ -from tiramisu import Storage -from risotto.config import DATABASE_DIR +try: + from tiramisu3 import Storage +except: + from tiramisu import Storage +from os.path import isfile as _isfile +import os as _os +_envfile = '/etc/risotto/risotto.conf' +if _isfile(_envfile): + with open(_envfile, 'r') as fh_env: + for line in fh_env.readlines(): + key, value = line.strip().split('=') + _os.environ[key] = value -STORAGE = Storage(engine='sqlite3', dir_database=DATABASE_DIR, name='test') +STORAGE = Storage(engine='sqlite3') diff --git a/tests/test_config.py b/tests/test_config.py index d340bc9..1105616 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,20 +1,29 @@ from importlib import import_module import pytest -from tiramisu import list_sessions, delete_session +try: + from tiramisu3 import list_sessions, delete_session as _delete_session +except: + from tiramisu import list_sessions, delete_session as _delete_session from .storage import STORAGE +from risotto import services from risotto.context import Context -from risotto.services import load_services +#from risotto.services import load_services from risotto.dispatcher import dispatcher +SOURCE_NAME = 'test' +SERVERMODEL_NAME = 'sm1' + + def setup_module(module): - load_services(['config'], - validate=False) +# load_services(['config'], +# validate=False) + services.link_to_dispatcher(dispatcher, limit_services=['setting'], validate=False) config_module = dispatcher.get_service('config') config_module.save_storage = STORAGE - dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True) - dispatcher.set_module('servermodel', import_module(f'.servermodel', 'fake_services'), True) + #dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True) + #dispatcher.set_module('servermodel', import_module(f'.servermodel', 'fake_services'), True) def setup_function(function): @@ -23,11 +32,11 @@ def setup_function(function): config_module.servermodel = {} -def teardown_function(function): +async def delete_session(): # delete all sessions config_module = dispatcher.get_service('config') - for session in list_sessions(storage=config_module.save_storage): - delete_session(storage=config_module.save_storage, session_id=session) + for session in await list_sessions(storage=config_module.save_storage): + await _delete_session(storage=config_module.save_storage, session_id=session) def get_fake_context(module_name): @@ -38,127 +47,166 @@ def get_fake_context(module_name): return risotto_context -@pytest.mark.asyncio -async def test_on_join(): +async def onjoin(source=True): config_module = dispatcher.get_service('config') assert config_module.servermodel == {} assert config_module.server == {} + await delete_session() # + #config_module.cache_root_path = 'tests/data' + await dispatcher.load() + await dispatcher.on_join(truncate=True) + if source: + fake_context = get_fake_context('config') + await dispatcher.call('v1', + 'setting.source.create', + fake_context, + source_name=SOURCE_NAME, + source_directory='tests/data', + ) + + +INTERNAL_SOURCE = {'source_name': 'internal', 'source_directory': '/srv/risotto/seed/internal'} +TEST_SOURCE = {'source_name': 'test', 'source_directory': 'tests/data'} + + +############################################################################################################################## +# Source / Release +############################################################################################################################## +@pytest.mark.asyncio +async def test_source_on_join(): + # onjoin must create internal source + sources = [INTERNAL_SOURCE] + await onjoin(False) fake_context = get_fake_context('config') - config_module.cache_root_path = 'tests/data' - await config_module.on_join(fake_context) - assert list(config_module.servermodel.keys()) == [1, 2] - assert list(config_module.server) == [3] - assert set(config_module.server[3]) == {'server', 'server_to_deploy', 'funcs_file'} - assert config_module.server[3]['funcs_file'] == 'tests/data/1/funcs.py' + assert await dispatcher.call('v1', + 'setting.source.list', + fake_context, + ) == sources + await delete_session() @pytest.mark.asyncio -async def test_server_created(): +async def test_source_create(): + sources = [INTERNAL_SOURCE, TEST_SOURCE] + await onjoin() config_module = dispatcher.get_service('config') + assert list(config_module.servermodel.keys()) == ['last_base'] + assert list(config_module.server) == [] fake_context = get_fake_context('config') - config_module.cache_root_path = 'tests/data' - await config_module.on_join(fake_context) - # - assert list(config_module.server) == [3] - await dispatcher.publish('v1', - 'server.created', - fake_context, - server_id=4, - server_name='name3', - server_description='description3', - server_servermodel_id=2) - assert list(config_module.server) == [3, 4] - assert set(config_module.server[4]) == {'server', 'server_to_deploy', 'funcs_file'} - assert config_module.server[4]['funcs_file'] == 'tests/data/2/funcs.py' + assert await dispatcher.call('v1', + 'setting.source.list', + fake_context, + ) == sources + await delete_session() @pytest.mark.asyncio -async def test_server_deleted(): - config_module = dispatcher.get_service('config') +async def test_source_describe(): + await onjoin() fake_context = get_fake_context('config') - config_module.cache_root_path = 'tests/data' - await config_module.on_join(fake_context) - # - assert list(config_module.server) == [3] - await dispatcher.publish('v1', - 'server.created', - fake_context, - server_id=4, - server_name='name4', - server_description='description4', - server_servermodel_id=2) - assert list(config_module.server) == [3, 4] - await dispatcher.publish('v1', - 'server.deleted', - fake_context, - server_id=4) - assert list(config_module.server) == [3] + assert await dispatcher.call('v1', + 'setting.source.describe', + fake_context, + source_name='internal', + ) == INTERNAL_SOURCE + assert await dispatcher.call('v1', + 'setting.source.describe', + fake_context, + source_name=SOURCE_NAME, + ) == TEST_SOURCE + await delete_session() + + +@pytest.mark.asyncio +async def test_release_internal_list(): + releases = [{'release_distribution': 'last', + 'release_name': 'none', + 'source_name': 'internal'}] + + await onjoin() + fake_context = get_fake_context('config') + assert await dispatcher.call('v1', + 'setting.source.release.list', + fake_context, + source_name='internal', + ) == releases + await delete_session() + + +@pytest.mark.asyncio +async def test_release_list(): + releases = [{'release_distribution': 'last', + 'release_name': '1', + 'source_name': 'test'}] + + await onjoin() + fake_context = get_fake_context('config') + assert await dispatcher.call('v1', + 'setting.source.release.list', + fake_context, + source_name='test', + ) == releases + await delete_session() + + +@pytest.mark.asyncio +async def test_release_describe(): + + await onjoin() + fake_context = get_fake_context('config') + assert await dispatcher.call('v1', + 'setting.source.release.describe', + fake_context, + source_name='internal', + release_distribution='last', + ) == {'release_distribution': 'last', + 'release_name': 'none', + 'source_name': 'internal'} + assert await dispatcher.call('v1', + 'setting.source.release.describe', + fake_context, + source_name='test', + release_distribution='last', + ) == {'release_distribution': 'last', + 'release_name': '1', + 'source_name': 'test'} + await delete_session() + + +############################################################################################################################## +# Servermodel +############################################################################################################################## +async def create_servermodel(name=SERVERMODEL_NAME, + parents_name=['base'], + ): + fake_context = get_fake_context('config') + await dispatcher.call('v1', + 'setting.servermodel.create', + fake_context, + servermodel_name=name, + servermodel_description='servermodel 1', + parents_name=parents_name, + source_name=SOURCE_NAME, + release_distribution='last', + ) @pytest.mark.asyncio async def test_servermodel_created(): + await onjoin() config_module = dispatcher.get_service('config') - fake_context = get_fake_context('config') - config_module.cache_root_path = 'tests/data' - await config_module.on_join(fake_context) # - assert list(config_module.servermodel) == [1, 2] - servermodel = {'servermodeid': 3, - 'servermodelname': 'name3'} - await dispatcher.publish('v1', - 'servermodel.created', - fake_context, - servermodel_id=3, - servermodel_description='name3', - release_id=1, - servermodel_name='name3') - assert list(config_module.servermodel) == [1, 2, 3] - assert not list(await config_module.servermodel[3].config.parents()) - - -@pytest.mark.asyncio -async def test_servermodel_herited_created(): - config_module = dispatcher.get_service('config') - fake_context = get_fake_context('config') - config_module.cache_root_path = 'tests/data' - await config_module.on_join(fake_context) - # - assert list(config_module.servermodel) == [1, 2] - await dispatcher.publish('v1', - 'servermodel.created', - fake_context, - servermodel_id=3, - servermodel_name='name3', - release_id=1, - servermodel_description='name3', - servermodel_parents_id=[1]) - assert list(config_module.servermodel) == [1, 2, 3] - assert len(list(await config_module.servermodel[3].config.parents())) == 1 - - -@pytest.mark.asyncio -async def test_servermodel_multi_herited_created(): - config_module = dispatcher.get_service('config') - fake_context = get_fake_context('config') - config_module.cache_root_path = 'tests/data' - await config_module.on_join(fake_context) - # - assert list(config_module.servermodel) == [1, 2] - await dispatcher.publish('v1', - 'servermodel.created', - fake_context, - servermodel_id=3, - servermodel_name='name3', - release_id=1, - servermodel_description='name3', - servermodel_parents_id=[1, 2]) - assert list(config_module.servermodel) == [1, 2, 3] - assert len(list(await config_module.servermodel[3].config.parents())) == 2 - - + assert list(config_module.servermodel) == ['last_base'] + await create_servermodel() + assert list(config_module.servermodel) == ['last_base', 'last_sm1'] + assert not list(await config_module.servermodel['last_base'].config.parents()) + assert len(list(await config_module.servermodel['last_sm1'].config.parents())) == 1 + await delete_session() +# +# #@pytest.mark.asyncio -#async def test_servermodel_updated_not_exists(): +#async def test_servermodel_herited_created(): # config_module = dispatcher.get_service('config') # fake_context = get_fake_context('config') # config_module.cache_root_path = 'tests/data' @@ -166,7 +214,28 @@ async def test_servermodel_multi_herited_created(): # # # assert list(config_module.servermodel) == [1, 2] # await dispatcher.publish('v1', -# 'servermodel.updated', +# 'servermodel.created', +# fake_context, +# servermodel_id=3, +# servermodel_name='name3', +# release_id=1, +# servermodel_description='name3', +# servermodel_parents_id=[1]) +# assert list(config_module.servermodel) == [1, 2, 3] +# assert len(list(await config_module.servermodel[3].config.parents())) == 1 +# await delete_session() +# +# +#@pytest.mark.asyncio +#async def test_servermodel_multi_herited_created(): +# config_module = dispatcher.get_service('config') +# fake_context = get_fake_context('config') +# config_module.cache_root_path = 'tests/data' +# await config_module.on_join(fake_context) +# # +# assert list(config_module.servermodel) == [1, 2] +# await dispatcher.publish('v1', +# 'servermodel.created', # fake_context, # servermodel_id=3, # servermodel_name='name3', @@ -175,164 +244,315 @@ async def test_servermodel_multi_herited_created(): # servermodel_parents_id=[1, 2]) # assert list(config_module.servermodel) == [1, 2, 3] # assert len(list(await config_module.servermodel[3].config.parents())) == 2 +# await delete_session() # # -# @pytest.mark.asyncio -# async def test_servermodel_updated1(): -# config_module = dispatcher.get_service('config') -# fake_context = get_fake_context('config') -# config_module.cache_root_path = 'tests/data' -# await config_module.on_join(fake_context) -# # -# assert list(config_module.servermodel) == [1, 2] -# metaconfig1 = config_module.servermodel[1] -# metaconfig2 = config_module.servermodel[2] -# mixconfig1 = (await metaconfig1.config.list())[0] -# mixconfig2 = (await metaconfig2.config.list())[0] -# assert len(list(await metaconfig1.config.parents())) == 0 -# assert len(list(await metaconfig2.config.parents())) == 1 -# assert len(list(await mixconfig1.config.list())) == 1 -# assert len(list(await mixconfig2.config.list())) == 0 -# # -# await dispatcher.publish('v1', -# 'servermodel.updated', -# fake_context, -# servermodel_id=1, -# servermodel_name='name1-1', -# release_id=1, -# servermodel_description='name1-1') -# assert set(config_module.servermodel) == {1, 2} -# assert config_module.servermodel[1].information.get('servermodel_name') == 'name1-1' -# assert metaconfig1 != config_module.servermodel[1] -# assert metaconfig2 == config_module.servermodel[2] -# metaconfig1 = config_module.servermodel[1] -# assert mixconfig1 != next(metaconfig1.config.list()) -# mixconfig1 = next(metaconfig1.config.list()) -# # -# assert len(list(await metaconfig1.config.parents())) == 0 -# assert len(list(await metaconfig2.config.parents())) == 1 -# assert len(list(await mixconfig1.config.list())) == 1 -# assert len(list(await mixconfig2.config.list())) == 0 -# -# -# @pytest.mark.asyncio -# async def test_servermodel_updated2(): -# config_module = dispatcher.get_service('config') -# fake_context = get_fake_context('config') -# config_module.cache_root_path = 'tests/data' -# await config_module.on_join(fake_context) -# # create a new servermodel -# assert list(config_module.servermodel) == [1, 2] -# mixconfig1 = next(config_module.servermodel[1].config.list()) -# mixconfig2 = next(config_module.servermodel[2].config.list()) -# assert len(list(mixconfig1.config.list())) == 1 -# assert len(list(mixconfig2.config.list())) == 0 -# await dispatcher.publish('v1', -# 'servermodel.created', -# fake_context, -# servermodel_id=3, -# servermodel_name='name3', -# release_id=1, -# servermodel_description='name3', -# servermodel_parents_id=[1]) -# assert list(config_module.servermodel) == [1, 2, 3] -# assert len(list(await config_module.servermodel[3].config.parents())) == 1 -# assert await config_module.servermodel[3].information.get('servermodel_name') == 'name3' -# assert len(list(await mixconfig1.config.list())) == 2 -# assert len(list(await mixconfig2.config.list())) == 0 -# # -# await dispatcher.publish('v1', -# 'servermodel.updated', -# fake_context, -# servermodel_id=3, -# servermodel_name='name3-1', -# release_id=1, -# servermodel_description='name3-1', -# servermodel_parents_id=[1, 2]) -# assert list(config_module.servermodel) == [1, 2, 3] -# assert config_module.servermodel[3].information.get('servermodel_name') == 'name3-1' -# assert len(list(mixconfig1.config.list())) == 2 -# assert len(list(mixconfig2.config.list())) == 1 -# -# -# @pytest.mark.asyncio -# async def test_servermodel_updated_config(): -# config_module = dispatcher.get_service('config') -# fake_context = get_fake_context('config') -# config_module.cache_root_path = 'tests/data' -# await config_module.on_join(fake_context) -# # -# config_module.servermodel[1].property.read_write() -# assert config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.get() == 'non' -# config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.set('oui') -# assert config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.get() == 'oui' -# # -# await dispatcher.publish('v1', -# 'servermodel.updated', -# fake_context, -# servermodel_id=1, -# servermodel_name='name1-1', -# release_id=1, -# servermodel_description='name1-1') -# assert config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.get() == 'oui' +##@pytest.mark.asyncio +##async def test_servermodel_updated_not_exists(): +## config_module = dispatcher.get_service('config') +## fake_context = get_fake_context('config') +## config_module.cache_root_path = 'tests/data' +## await config_module.on_join(fake_context) +## # +## assert list(config_module.servermodel) == [1, 2] +## await dispatcher.publish('v1', +## 'servermodel.updated', +## fake_context, +## servermodel_id=3, +## servermodel_name='name3', +## release_id=1, +## servermodel_description='name3', +## servermodel_parents_id=[1, 2]) +## assert list(config_module.servermodel) == [1, 2, 3] +## assert len(list(await config_module.servermodel[3].config.parents())) == 2 +## await delete_session() +## +## +## @pytest.mark.asyncio +## async def test_servermodel_updated1(): +## config_module = dispatcher.get_service('config') +## fake_context = get_fake_context('config') +## config_module.cache_root_path = 'tests/data' +## await config_module.on_join(fake_context) +## # +## assert list(config_module.servermodel) == [1, 2] +## metaconfig1 = config_module.servermodel[1] +## metaconfig2 = config_module.servermodel[2] +## mixconfig1 = (await metaconfig1.config.list())[0] +## mixconfig2 = (await metaconfig2.config.list())[0] +## assert len(list(await metaconfig1.config.parents())) == 0 +## assert len(list(await metaconfig2.config.parents())) == 1 +## assert len(list(await mixconfig1.config.list())) == 1 +## assert len(list(await mixconfig2.config.list())) == 0 +## # +## await dispatcher.publish('v1', +## 'servermodel.updated', +## fake_context, +## servermodel_id=1, +## servermodel_name='name1-1', +## release_id=1, +## servermodel_description='name1-1') +## assert set(config_module.servermodel) == {1, 2} +## assert config_module.servermodel[1].information.get('servermodel_name') == 'name1-1' +## assert metaconfig1 != config_module.servermodel[1] +## assert metaconfig2 == config_module.servermodel[2] +## metaconfig1 = config_module.servermodel[1] +## assert mixconfig1 != next(metaconfig1.config.list()) +## mixconfig1 = next(metaconfig1.config.list()) +## # +## assert len(list(await metaconfig1.config.parents())) == 0 +## assert len(list(await metaconfig2.config.parents())) == 1 +## assert len(list(await mixconfig1.config.list())) == 1 +## assert len(list(await mixconfig2.config.list())) == 0 +## await delete_session() +## +## +## @pytest.mark.asyncio +## async def test_servermodel_updated2(): +## config_module = dispatcher.get_service('config') +## fake_context = get_fake_context('config') +## config_module.cache_root_path = 'tests/data' +## await config_module.on_join(fake_context) +## # create a new servermodel +## assert list(config_module.servermodel) == [1, 2] +## mixconfig1 = next(config_module.servermodel[1].config.list()) +## mixconfig2 = next(config_module.servermodel[2].config.list()) +## assert len(list(mixconfig1.config.list())) == 1 +## assert len(list(mixconfig2.config.list())) == 0 +## await dispatcher.publish('v1', +## 'servermodel.created', +## fake_context, +## servermodel_id=3, +## servermodel_name='name3', +## release_id=1, +## servermodel_description='name3', +## servermodel_parents_id=[1]) +## assert list(config_module.servermodel) == [1, 2, 3] +## assert len(list(await config_module.servermodel[3].config.parents())) == 1 +## assert await config_module.servermodel[3].information.get('servermodel_name') == 'name3' +## assert len(list(await mixconfig1.config.list())) == 2 +## assert len(list(await mixconfig2.config.list())) == 0 +## # +## await dispatcher.publish('v1', +## 'servermodel.updated', +## fake_context, +## servermodel_id=3, +## servermodel_name='name3-1', +## release_id=1, +## servermodel_description='name3-1', +## servermodel_parents_id=[1, 2]) +## assert list(config_module.servermodel) == [1, 2, 3] +## assert config_module.servermodel[3].information.get('servermodel_name') == 'name3-1' +## assert len(list(mixconfig1.config.list())) == 2 +## assert len(list(mixconfig2.config.list())) == 1 +## await delete_session() +## +## +## @pytest.mark.asyncio +## async def test_servermodel_updated_config(): +## config_module = dispatcher.get_service('config') +## fake_context = get_fake_context('config') +## config_module.cache_root_path = 'tests/data' +## await config_module.on_join(fake_context) +## # +## config_module.servermodel[1].property.read_write() +## assert config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.get() == 'non' +## config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.set('oui') +## assert config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.get() == 'oui' +## # +## await dispatcher.publish('v1', +## 'servermodel.updated', +## fake_context, +## servermodel_id=1, +## servermodel_name='name1-1', +## release_id=1, +## servermodel_description='name1-1') +## assert config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.get() == 'oui' +## await delete_session() + + +############################################################################################################################## +# Server +############################################################################################################################## +@pytest.mark.asyncio +async def test_server_created_base(): + await onjoin() + config_module = dispatcher.get_service('config') + fake_context = get_fake_context('config') + # + assert list(config_module.server) == [] + await dispatcher.on_join(truncate=True) + server_name = 'dns.test.lan' + await dispatcher.publish('v1', + 'infra.server.created', + fake_context, + server_name=server_name, + server_description='description_created', + servermodel_name='base', + release_distribution='last', + site_name='site_1', + zones_name=['zones'], + zones_ip=['1.1.1.1'], + ) + assert list(config_module.server) == [server_name] + assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'} + assert config_module.server[server_name]['funcs_file'] == '/var/cache/risotto/servermodel/last/base/funcs.py' + await delete_session() + + +@pytest.mark.asyncio +async def test_server_created_own_sm(): + await onjoin() + config_module = dispatcher.get_service('config') + fake_context = get_fake_context('config') + await create_servermodel() + # + assert list(config_module.server) == [] + await dispatcher.on_join(truncate=True) + server_name = 'dns.test.lan' + await dispatcher.publish('v1', + 'infra.server.created', + fake_context, + server_name=server_name, + server_description='description_created', + servermodel_name=SERVERMODEL_NAME, + source_name=SOURCE_NAME, + release_distribution='last', + site_name='site_1', + zones_name=['zones'], + zones_ip=['1.1.1.1'], + ) + assert list(config_module.server) == [server_name] + assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'} + assert config_module.server[server_name]['funcs_file'] == '/var/cache/risotto/servermodel/last/sm1/funcs.py' + await delete_session() + + +#@pytest.mark.asyncio +#async def test_server_deleted(): +# config_module = dispatcher.get_service('config') +# config_module.cache_root_path = 'tests/data' +# await config_module.on_join(fake_context) +# # +# assert list(config_module.server) == [3] +# await dispatcher.publish('v1', +# 'server.created', +# fake_context, +# server_id=4, +# server_name='name4', +# server_description='description4', +# server_servermodel_id=2) +# assert list(config_module.server) == [3, 4] +# await dispatcher.publish('v1', +# 'server.deleted', +# fake_context, +# server_id=4) +# assert list(config_module.server) == [3] +# await delete_session() @pytest.mark.asyncio async def test_server_configuration_get(): + await onjoin() config_module = dispatcher.get_service('config') fake_context = get_fake_context('config') - config_module.cache_root_path = 'tests/data' - await config_module.on_join(fake_context) + await create_servermodel() + await dispatcher.on_join(truncate=True) + server_name = 'dns.test.lan' + await dispatcher.publish('v1', + 'infra.server.created', + fake_context, + server_name=server_name, + server_description='description_created', + servermodel_name=SERVERMODEL_NAME, + source_name=SOURCE_NAME, + release_distribution='last', + site_name='site_1', + zones_name=['zones'], + zones_ip=['1.1.1.1'], + ) # - await config_module.server[3]['server_to_deploy'].property.read_write() - assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'non' - await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.set('oui') - assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' - assert await config_module.server[3]['server'].option('creole.general.mode_conteneur_actif').value.get() == 'non' + await config_module.server[server_name]['server'].property.read_write() + assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 1 + await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.set(2) + assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 2 + assert await config_module.server[server_name]['server_to_deploy'].option('configuration.general.number_of_interfaces').value.get() == 1 # + configuration = {'server_name': server_name, + 'deployed': False, + 'configuration': {'configuration.general.number_of_interfaces': 1, + 'configuration.general.interfaces_list': [0], + 'configuration.interface_0.domain_name_eth0': 'dns.test.lan' + } + } values = await dispatcher.call('v1', - 'config.configuration.server.get', + 'setting.config.configuration.server.get', fake_context, - server_id=3) - configuration = {'configuration': - {'creole.general.mode_conteneur_actif': 'non', - 'creole.general.master.master': [], - 'creole.general.master.slave1': [], - 'creole.general.master.slave2': [], - 'containers.container0.files.file0.mkdir': False, - 'containers.container0.files.file0.name': '/etc/mailname', - 'containers.container0.files.file0.rm': False, - 'containers.container0.files.file0.source': 'mailname', - 'containers.container0.files.file0.activate': True}, - 'server_id': 3, - 'deployed': True} + server_name=server_name, + deployed=False, + ) + assert values == configuration # - values = await dispatcher.call('v1', - 'config.configuration.server.get', - fake_context, - server_id=3, - deployed=False) - configuration['configuration']['creole.general.mode_conteneur_actif'] = 'oui' - configuration['deployed'] = False - assert values == configuration + await delete_session() @pytest.mark.asyncio -async def test_config_deployed(): +async def test_server_configuration_deployed(): + await onjoin() config_module = dispatcher.get_service('config') fake_context = get_fake_context('config') - config_module.cache_root_path = 'tests/data' - await config_module.on_join(fake_context) + await create_servermodel() + await dispatcher.on_join(truncate=True) + server_name = 'dns.test.lan' + await dispatcher.publish('v1', + 'infra.server.created', + fake_context, + server_name=server_name, + server_description='description_created', + servermodel_name=SERVERMODEL_NAME, + source_name=SOURCE_NAME, + release_distribution='last', + site_name='site_1', + zones_name=['zones'], + zones_ip=['1.1.1.1'], + ) # - await config_module.server[3]['server_to_deploy'].property.read_write() - assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'non' - await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.set('oui') - assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' - assert await config_module.server[3]['server'].option('creole.general.mode_conteneur_actif').value.get() == 'non' - values = await dispatcher.publish('v1', - 'config.configuration.server.deploy', - fake_context, - server_id=3) - assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' - assert await config_module.server[3]['server'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' + await config_module.server[server_name]['server'].property.read_write() + assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 1 + await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.set(2) + assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 2 + assert await config_module.server[server_name]['server_to_deploy'].option('configuration.general.number_of_interfaces').value.get() == 1 + # + configuration = {'server_name': server_name, + 'deployed': False, + 'configuration': {'configuration.general.number_of_interfaces': 1, + 'configuration.general.interfaces_list': [0], + 'configuration.interface_0.domain_name_eth0': 'dns.test.lan' + } + } + try: + await dispatcher.call('v1', + 'setting.config.configuration.server.get', + fake_context, + server_name=server_name, + ) + except: + pass + else: + raise Exception('should raise propertyerror') + + values = await dispatcher.call('v1', + 'setting.config.configuration.server.deploy', + fake_context, + server_name=server_name, + ) + assert values == {'server_name': 'dns.test.lan', 'deployed': True} + await dispatcher.call('v1', + 'setting.config.configuration.server.get', + fake_context, + server_name=server_name, + ) + + # + await delete_session() diff --git a/tests/test_session.py b/tests/test_session.py index e3663ea..a801166 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -2,7 +2,7 @@ from importlib import import_module import pytest from .storage import STORAGE from risotto.context import Context -from risotto.services import load_services +#from risotto.services import load_services from risotto.dispatcher import dispatcher from risotto.services.session.storage import storage_server, storage_servermodel @@ -16,9 +16,9 @@ def get_fake_context(module_name): def setup_module(module): - load_services(['config', 'session'], - validate=False, - test=True) + #load_services(['config', 'session'], + # validate=False, + # test=True) config_module = dispatcher.get_service('config') config_module.save_storage = STORAGE dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True)