update tests

This commit is contained in:
Emmanuel Garette 2020-08-26 10:56:34 +02:00
parent c309ebbd56
commit ca101cf094
5 changed files with 151 additions and 47 deletions

View File

@ -59,7 +59,7 @@ class CallDispatcher:
mandatories = await config.value.mandatory() mandatories = await config.value.mandatory()
if mandatories: if mandatories:
mand = [mand.split('.')[-1] for mand in mandatories] mand = [mand.split('.')[-1] for mand in mandatories]
raise ValueError(_(f'missing parameters in response: {mand} in message "{risotto_context.message}"')) raise ValueError(_(f'missing parameters in response of the message "{risotto_context.version}.{risotto_context.message}": {mand} in message'))
try: try:
await config.value.dict() await config.value.dict()
except Exception as err: except Exception as err:
@ -72,7 +72,9 @@ class CallDispatcher:
message: str, message: str,
old_risotto_context: Context, old_risotto_context: Context,
check_role: bool=False, check_role: bool=False,
**kwargs): internal: bool=True,
**kwargs,
):
""" execute the function associate with specified uri """ execute the function associate with specified uri
arguments are validate before arguments are validate before
""" """
@ -80,6 +82,10 @@ class CallDispatcher:
version, version,
message, message,
'rpc') 'rpc')
if version not in self.messages:
raise CallError(_(f'cannot find version of message "{version}"'))
if message not in self.messages[version]:
raise CallError(_(f'cannot find message "{version}.{message}"'))
function_objs = [self.messages[version][message]] function_objs = [self.messages[version][message]]
# do not start a new database connection # do not start a new database connection
if hasattr(old_risotto_context, 'connection'): if hasattr(old_risotto_context, 'connection'):
@ -89,7 +95,9 @@ class CallDispatcher:
risotto_context, risotto_context,
check_role, check_role,
kwargs, kwargs,
function_objs) function_objs,
internal,
)
else: else:
try: try:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
@ -106,7 +114,9 @@ class CallDispatcher:
risotto_context, risotto_context,
check_role, check_role,
kwargs, kwargs,
function_objs) function_objs,
internal,
)
except CallError as err: except CallError as err:
raise err raise err
except Exception as err: except Exception as err:
@ -132,7 +142,9 @@ class PublishDispatcher:
message: str, message: str,
old_risotto_context: Context, old_risotto_context: Context,
check_role: bool=False, check_role: bool=False,
**kwargs) -> None: internal: bool=True,
**kwargs,
) -> None:
risotto_context = self.build_new_context(old_risotto_context, risotto_context = self.build_new_context(old_risotto_context,
version, version,
message, message,
@ -149,7 +161,9 @@ class PublishDispatcher:
risotto_context, risotto_context,
check_role, check_role,
kwargs, kwargs,
function_objs) function_objs,
internal,
)
try: try:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
await connection.set_type_codec( await connection.set_type_codec(
@ -165,7 +179,9 @@ class PublishDispatcher:
risotto_context, risotto_context,
check_role, check_role,
kwargs, kwargs,
function_objs) function_objs,
internal,
)
except CallError as err: except CallError as err:
raise err raise err
except Exception as err: except Exception as err:
@ -222,7 +238,9 @@ class Dispatcher(register.RegisterDispatcher,
risotto_context: Context, risotto_context: Context,
uri: str, uri: str,
kwargs: Dict, kwargs: Dict,
check_role: bool): check_role: bool,
internal: bool,
):
""" create a new Config et set values to it """ create a new Config et set values to it
""" """
# create a new config # create a new config
@ -232,13 +250,17 @@ class Dispatcher(register.RegisterDispatcher,
await config.option('message').value.set(risotto_context.message) await config.option('message').value.set(risotto_context.message)
# store values # store values
subconfig = config.option(risotto_context.message) subconfig = config.option(risotto_context.message)
extra_parameters = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
if not internal or not key.startswith('_'):
try: try:
await subconfig.option(key).value.set(value) await subconfig.option(key).value.set(value)
except AttributeError: except AttributeError:
if get_config()['global']['debug']: if get_config()['global']['debug']:
print_exc() print_exc()
raise ValueError(_(f'unknown parameter in "{uri}": "{key}"')) raise ValueError(_(f'unknown parameter in "{uri}": "{key}"'))
else:
extra_parameters[key] = value
# check mandatories options # check mandatories options
if check_role and get_config().get('global').get('check_role'): if check_role and get_config().get('global').get('check_role'):
await self.check_role(subconfig, await self.check_role(subconfig,
@ -250,7 +272,10 @@ class Dispatcher(register.RegisterDispatcher,
mand = [mand.split('.')[-1] for mand in mandatories] mand = [mand.split('.')[-1] for mand in mandatories]
raise ValueError(_(f'missing parameters in "{uri}": {mand}')) raise ValueError(_(f'missing parameters in "{uri}": {mand}'))
# return complete an validated kwargs # return complete an validated kwargs
return await subconfig.value.dict() parameters = await subconfig.value.dict()
if extra_parameters:
parameters.update(extra_parameters)
return parameters
def get_service(self, def get_service(self,
name: str): name: str):
@ -309,13 +334,17 @@ class Dispatcher(register.RegisterDispatcher,
risotto_context: Context, risotto_context: Context,
check_role: bool, check_role: bool,
kwargs: Dict, kwargs: Dict,
function_objs: List) -> Optional[Dict]: function_objs: List,
internal: bool,
) -> Optional[Dict]:
await self.check_message_type(risotto_context, await self.check_message_type(risotto_context,
kwargs) kwargs)
config_arguments = await self.load_kwargs_to_config(risotto_context, config_arguments = await self.load_kwargs_to_config(risotto_context,
f'{version}.{message}', f'{version}.{message}',
kwargs, kwargs,
check_role) check_role,
internal,
)
# config is ok, so send the message # config is ok, so send the message
for function_obj in function_objs: for function_obj in function_objs:
function = function_obj['function'] function = function_obj['function']

View File

@ -82,7 +82,9 @@ async def handle(request):
message, message,
risotto_context, risotto_context,
check_role=True, check_role=True,
**kwargs) internal=False,
**kwargs,
)
except NotAllowedError as err: except NotAllowedError as err:
raise HTTPNotFound(reason=str(err)) raise HTTPNotFound(reason=str(err))
except CallError as err: except CallError as err:

View File

@ -136,6 +136,9 @@ class RegisterDispatcher:
function_args = self.get_function_args(function) function_args = self.get_function_args(function)
# compare message arguments with function parameter # compare message arguments with function parameter
# it must not have more or less arguments # it must not have more or less arguments
for arg in function_args - message_args:
if arg.startswith('_'):
message_args.add(arg)
if message_args != function_args: if message_args != function_args:
# raise if arguments are not equal # raise if arguments are not equal
msg = [] msg = []

0
tests/__init__.py Normal file
View File

View File

@ -66,22 +66,117 @@ async def onjoin(source=True):
) )
INTERNAL_SOURCE = {'source_name': 'internal', 'source_directory': '/srv/risotto/seed/internal'}
TEST_SOURCE = {'source_name': 'test', 'source_directory': 'tests/data'}
##############################################################################################################################
# Source / Release
##############################################################################################################################
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_join(): async def test_source_on_join():
# onjoin must create internal source
sources = [INTERNAL_SOURCE]
await onjoin(False) await onjoin(False)
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.list',
fake_context,
) == sources
await delete_session() await delete_session()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_source_create(): async def test_source_create():
sources = [INTERNAL_SOURCE, TEST_SOURCE]
await onjoin() await onjoin()
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
assert list(config_module.servermodel.keys()) == ['last_base'] assert list(config_module.servermodel.keys()) == ['last_base']
assert list(config_module.server) == [] assert list(config_module.server) == []
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.list',
fake_context,
) == sources
await delete_session() await delete_session()
# FIXME {source|release}.list {source|release}.describe {source|release}.delete, ...
@pytest.mark.asyncio
async def test_source_describe():
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.describe',
fake_context,
source_name='internal',
) == INTERNAL_SOURCE
assert await dispatcher.call('v1',
'setting.source.describe',
fake_context,
source_name=SOURCE_NAME,
) == TEST_SOURCE
await delete_session()
@pytest.mark.asyncio
async def test_release_internal_list():
releases = [{'release_distribution': 'last',
'release_name': 'none',
'source_name': 'internal'}]
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.release.list',
fake_context,
source_name='internal',
) == releases
await delete_session()
@pytest.mark.asyncio
async def test_release_list():
releases = [{'release_distribution': 'last',
'release_name': '1',
'source_name': 'test'}]
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.release.list',
fake_context,
source_name='test',
) == releases
await delete_session()
@pytest.mark.asyncio
async def test_release_describe():
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.release.describe',
fake_context,
source_name='internal',
release_distribution='last',
) == {'release_distribution': 'last',
'release_name': 'none',
'source_name': 'internal'}
assert await dispatcher.call('v1',
'setting.source.release.describe',
fake_context,
source_name='test',
release_distribution='last',
) == {'release_distribution': 'last',
'release_name': '1',
'source_name': 'test'}
await delete_session()
##############################################################################################################################
# Servermodel
##############################################################################################################################
async def create_servermodel(name=SERVERMODEL_NAME, async def create_servermodel(name=SERVERMODEL_NAME,
parents_name=['base'], parents_name=['base'],
): ):
@ -97,9 +192,6 @@ async def create_servermodel(name=SERVERMODEL_NAME,
) )
###################################################################################################################################
# Servermodel
###################################################################################################################################
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_servermodel_created(): async def test_servermodel_created():
await onjoin() await onjoin()
@ -111,28 +203,6 @@ async def test_servermodel_created():
assert not list(await config_module.servermodel['last_base'].config.parents()) assert not list(await config_module.servermodel['last_base'].config.parents())
assert len(list(await config_module.servermodel['last_sm1'].config.parents())) == 1 assert len(list(await config_module.servermodel['last_sm1'].config.parents())) == 1
await delete_session() await delete_session()
@pytest.mark.asyncio
async def test_servermodel_created():
await onjoin()
config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
#
assert list(config_module.servermodel) == ['last_base']
await dispatcher.call('v1',
'setting.servermodel.create',
fake_context,
servermodel_name='sm1',
servermodel_description='servermodel 1',
parents_name=['base'],
source_name=SOURCE_NAME,
release_distribution='last',
)
assert list(config_module.servermodel) == ['last_base', 'last_sm1']
assert not list(await config_module.servermodel['last_base'].config.parents())
assert len(list(await config_module.servermodel['last_sm1'].config.parents())) == 1
await delete_session()
# #
# #
#@pytest.mark.asyncio #@pytest.mark.asyncio
@ -301,9 +371,9 @@ async def test_servermodel_created():
## await delete_session() ## await delete_session()
################################################################################################################################### ##############################################################################################################################
# Server # Server
################################################################################################################################### ##############################################################################################################################
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_created_base(): async def test_server_created_base():
await onjoin() await onjoin()