diff --git a/src/risotto/dispatcher.py b/src/risotto/dispatcher.py index ac29dfb..bba0446 100644 --- a/src/risotto/dispatcher.py +++ b/src/risotto/dispatcher.py @@ -59,7 +59,7 @@ class CallDispatcher: mandatories = await config.value.mandatory() if mandatories: mand = [mand.split('.')[-1] for mand in mandatories] - raise ValueError(_(f'missing parameters in response: {mand} in message "{risotto_context.message}"')) + raise ValueError(_(f'missing parameters in response of the message "{risotto_context.version}.{risotto_context.message}": {mand} in message')) try: await config.value.dict() except Exception as err: @@ -72,7 +72,9 @@ class CallDispatcher: message: str, old_risotto_context: Context, check_role: bool=False, - **kwargs): + internal: bool=True, + **kwargs, + ): """ execute the function associate with specified uri arguments are validate before """ @@ -80,6 +82,10 @@ class CallDispatcher: version, message, 'rpc') + if version not in self.messages: + raise CallError(_(f'cannot find version of message "{version}"')) + if message not in self.messages[version]: + raise CallError(_(f'cannot find message "{version}.{message}"')) function_objs = [self.messages[version][message]] # do not start a new database connection if hasattr(old_risotto_context, 'connection'): @@ -89,7 +95,9 @@ class CallDispatcher: risotto_context, check_role, kwargs, - function_objs) + function_objs, + internal, + ) else: try: async with self.pool.acquire() as connection: @@ -106,7 +114,9 @@ class CallDispatcher: risotto_context, check_role, kwargs, - function_objs) + function_objs, + internal, + ) except CallError as err: raise err except Exception as err: @@ -132,7 +142,9 @@ class PublishDispatcher: message: str, old_risotto_context: Context, check_role: bool=False, - **kwargs) -> None: + internal: bool=True, + **kwargs, + ) -> None: risotto_context = self.build_new_context(old_risotto_context, version, message, @@ -149,7 +161,9 @@ class PublishDispatcher: risotto_context, check_role, kwargs, - function_objs) + function_objs, + internal, + ) try: async with self.pool.acquire() as connection: await connection.set_type_codec( @@ -165,7 +179,9 @@ class PublishDispatcher: risotto_context, check_role, kwargs, - function_objs) + function_objs, + internal, + ) except CallError as err: raise err except Exception as err: @@ -222,7 +238,9 @@ class Dispatcher(register.RegisterDispatcher, risotto_context: Context, uri: str, kwargs: Dict, - check_role: bool): + check_role: bool, + internal: bool, + ): """ create a new Config et set values to it """ # create a new config @@ -232,13 +250,17 @@ class Dispatcher(register.RegisterDispatcher, await config.option('message').value.set(risotto_context.message) # store values subconfig = config.option(risotto_context.message) + extra_parameters = {} for key, value in kwargs.items(): - try: - await subconfig.option(key).value.set(value) - except AttributeError: - if get_config()['global']['debug']: - print_exc() - raise ValueError(_(f'unknown parameter in "{uri}": "{key}"')) + if not internal or not key.startswith('_'): + try: + await subconfig.option(key).value.set(value) + except AttributeError: + if get_config()['global']['debug']: + print_exc() + raise ValueError(_(f'unknown parameter in "{uri}": "{key}"')) + else: + extra_parameters[key] = value # check mandatories options if check_role and get_config().get('global').get('check_role'): await self.check_role(subconfig, @@ -250,7 +272,10 @@ class Dispatcher(register.RegisterDispatcher, mand = [mand.split('.')[-1] for mand in mandatories] raise ValueError(_(f'missing parameters in "{uri}": {mand}')) # return complete an validated kwargs - return await subconfig.value.dict() + parameters = await subconfig.value.dict() + if extra_parameters: + parameters.update(extra_parameters) + return parameters def get_service(self, name: str): @@ -309,13 +334,17 @@ class Dispatcher(register.RegisterDispatcher, risotto_context: Context, check_role: bool, kwargs: Dict, - function_objs: List) -> Optional[Dict]: + function_objs: List, + internal: bool, + ) -> Optional[Dict]: await self.check_message_type(risotto_context, kwargs) config_arguments = await self.load_kwargs_to_config(risotto_context, f'{version}.{message}', kwargs, - check_role) + check_role, + internal, + ) # config is ok, so send the message for function_obj in function_objs: function = function_obj['function'] diff --git a/src/risotto/http.py b/src/risotto/http.py index eedd3ce..cdbd3ff 100644 --- a/src/risotto/http.py +++ b/src/risotto/http.py @@ -82,7 +82,9 @@ async def handle(request): message, risotto_context, check_role=True, - **kwargs) + internal=False, + **kwargs, + ) except NotAllowedError as err: raise HTTPNotFound(reason=str(err)) except CallError as err: diff --git a/src/risotto/register.py b/src/risotto/register.py index 4c7cd92..6a85a1d 100644 --- a/src/risotto/register.py +++ b/src/risotto/register.py @@ -136,6 +136,9 @@ class RegisterDispatcher: function_args = self.get_function_args(function) # compare message arguments with function parameter # it must not have more or less arguments + for arg in function_args - message_args: + if arg.startswith('_'): + message_args.add(arg) if message_args != function_args: # raise if arguments are not equal msg = [] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_config.py b/tests/test_config.py index 83eda7a..1105616 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -66,22 +66,117 @@ async def onjoin(source=True): ) +INTERNAL_SOURCE = {'source_name': 'internal', 'source_directory': '/srv/risotto/seed/internal'} +TEST_SOURCE = {'source_name': 'test', 'source_directory': 'tests/data'} + + +############################################################################################################################## +# Source / Release +############################################################################################################################## @pytest.mark.asyncio -async def test_on_join(): +async def test_source_on_join(): + # onjoin must create internal source + sources = [INTERNAL_SOURCE] await onjoin(False) + fake_context = get_fake_context('config') + assert await dispatcher.call('v1', + 'setting.source.list', + fake_context, + ) == sources await delete_session() @pytest.mark.asyncio async def test_source_create(): + sources = [INTERNAL_SOURCE, TEST_SOURCE] await onjoin() config_module = dispatcher.get_service('config') assert list(config_module.servermodel.keys()) == ['last_base'] assert list(config_module.server) == [] + fake_context = get_fake_context('config') + assert await dispatcher.call('v1', + 'setting.source.list', + fake_context, + ) == sources await delete_session() -# FIXME {source|release}.list {source|release}.describe {source|release}.delete, ... +@pytest.mark.asyncio +async def test_source_describe(): + await onjoin() + fake_context = get_fake_context('config') + assert await dispatcher.call('v1', + 'setting.source.describe', + fake_context, + source_name='internal', + ) == INTERNAL_SOURCE + assert await dispatcher.call('v1', + 'setting.source.describe', + fake_context, + source_name=SOURCE_NAME, + ) == TEST_SOURCE + await delete_session() + + +@pytest.mark.asyncio +async def test_release_internal_list(): + releases = [{'release_distribution': 'last', + 'release_name': 'none', + 'source_name': 'internal'}] + + await onjoin() + fake_context = get_fake_context('config') + assert await dispatcher.call('v1', + 'setting.source.release.list', + fake_context, + source_name='internal', + ) == releases + await delete_session() + + +@pytest.mark.asyncio +async def test_release_list(): + releases = [{'release_distribution': 'last', + 'release_name': '1', + 'source_name': 'test'}] + + await onjoin() + fake_context = get_fake_context('config') + assert await dispatcher.call('v1', + 'setting.source.release.list', + fake_context, + source_name='test', + ) == releases + await delete_session() + + +@pytest.mark.asyncio +async def test_release_describe(): + + await onjoin() + fake_context = get_fake_context('config') + assert await dispatcher.call('v1', + 'setting.source.release.describe', + fake_context, + source_name='internal', + release_distribution='last', + ) == {'release_distribution': 'last', + 'release_name': 'none', + 'source_name': 'internal'} + assert await dispatcher.call('v1', + 'setting.source.release.describe', + fake_context, + source_name='test', + release_distribution='last', + ) == {'release_distribution': 'last', + 'release_name': '1', + 'source_name': 'test'} + await delete_session() + + +############################################################################################################################## +# Servermodel +############################################################################################################################## async def create_servermodel(name=SERVERMODEL_NAME, parents_name=['base'], ): @@ -97,9 +192,6 @@ async def create_servermodel(name=SERVERMODEL_NAME, ) -################################################################################################################################### -# Servermodel -################################################################################################################################### @pytest.mark.asyncio async def test_servermodel_created(): await onjoin() @@ -111,28 +203,6 @@ async def test_servermodel_created(): assert not list(await config_module.servermodel['last_base'].config.parents()) assert len(list(await config_module.servermodel['last_sm1'].config.parents())) == 1 await delete_session() - - -@pytest.mark.asyncio -async def test_servermodel_created(): - await onjoin() - config_module = dispatcher.get_service('config') - fake_context = get_fake_context('config') - # - assert list(config_module.servermodel) == ['last_base'] - await dispatcher.call('v1', - 'setting.servermodel.create', - fake_context, - servermodel_name='sm1', - servermodel_description='servermodel 1', - parents_name=['base'], - source_name=SOURCE_NAME, - release_distribution='last', - ) - assert list(config_module.servermodel) == ['last_base', 'last_sm1'] - assert not list(await config_module.servermodel['last_base'].config.parents()) - assert len(list(await config_module.servermodel['last_sm1'].config.parents())) == 1 - await delete_session() # # #@pytest.mark.asyncio @@ -301,9 +371,9 @@ async def test_servermodel_created(): ## await delete_session() -################################################################################################################################### +############################################################################################################################## # Server -################################################################################################################################### +############################################################################################################################## @pytest.mark.asyncio async def test_server_created_base(): await onjoin()