Merge branch 'develop' into dist/risotto/risotto-2.8.0/develop

This commit is contained in:
Emmanuel Garette 2020-09-16 08:16:33 +02:00
commit 5653de1e99
12 changed files with 902 additions and 525 deletions

8
sql/risotto.sql Normal file
View File

@ -0,0 +1,8 @@
CREATE TABLE log(
Msg VARCHAR(255) NOT NULL,
Level VARCHAR(10) NOT NULL,
Path VARCHAR(255),
Username VARCHAR(100) NOT NULL,
Data JSON,
Date timestamp DEFAULT current_timestamp
);

View File

@ -1,21 +1,81 @@
from os import environ from os import environ
from os.path import isfile
from configobj import ConfigObj
CONFIGURATION_DIR = environ.get('CONFIGURATION_DIR', '/srv/risotto/configurations') CONFIG_FILE = environ.get('CONFIG_FILE', '/etc/risotto/risotto.conf')
PROVIDER_FACTORY_CONFIG_DIR = environ.get('PROVIDER_FACTORY_CONFIG_DIR', '/srv/factory')
TMP_DIR = '/tmp'
DEFAULT_USER = environ.get('DEFAULT_USER', 'Anonymous') if isfile(CONFIG_FILE):
RISOTTO_DB_NAME = environ.get('RISOTTO_DB_NAME', 'risotto') config = ConfigObj(CONFIG_FILE)
RISOTTO_DB_PASSWORD = environ.get('RISOTTO_DB_PASSWORD', 'risotto') else:
RISOTTO_DB_USER = environ.get('RISOTTO_DB_USER', 'risotto') config = {}
TIRAMISU_DB_NAME = environ.get('TIRAMISU_DB_NAME', 'tiramisu')
TIRAMISU_DB_PASSWORD = environ.get('TIRAMISU_DB_PASSWORD', 'tiramisu')
TIRAMISU_DB_USER = environ.get('TIRAMISU_DB_USER', 'tiramisu') if 'RISOTTO_PORT' in environ:
DB_ADDRESS = environ.get('DB_ADDRESS', 'localhost') RISOTTO_PORT = environ['RISOTTO_PORT']
MESSAGE_PATH = environ.get('MESSAGE_PATH', '/root/risotto-message/messages') else:
SQL_DIR = environ.get('SQL_DIR', './sql') RISOTTO_PORT = config.get('RISOTTO_PORT', 8080)
CACHE_ROOT_PATH = environ.get('CACHE_ROOT_PATH', '/var/cache/risotto') if 'CONFIGURATION_DIR' in environ:
SRV_SEED_PATH = environ.get('SRV_SEED_PATH', '/srv/seed') CONFIGURATION_DIR = environ['CONFIGURATION_DIR']
else:
CONFIGURATION_DIR = config.get('CONFIGURATION_DIR', '/srv/risotto/configurations')
if 'PROVIDER_FACTORY_CONFIG_DIR' in environ:
PROVIDER_FACTORY_CONFIG_DIR = environ['PROVIDER_FACTORY_CONFIG_DIR']
else:
PROVIDER_FACTORY_CONFIG_DIR = config.get('PROVIDER_FACTORY_CONFIG_DIR', '/srv/factory')
if 'DEFAULT_USER' in environ:
DEFAULT_USER = environ['DEFAULT_USER']
else:
DEFAULT_USER = config.get('DEFAULT_USER', 'Anonymous')
if 'RISOTTO_DB_NAME' in environ:
RISOTTO_DB_NAME = environ['RISOTTO_DB_NAME']
else:
RISOTTO_DB_NAME = config.get('RISOTTO_DB_NAME', 'risotto')
if 'RISOTTO_DB_PASSWORD' in environ:
RISOTTO_DB_PASSWORD = environ['RISOTTO_DB_PASSWORD']
else:
RISOTTO_DB_PASSWORD = config.get('RISOTTO_DB_PASSWORD', 'risotto')
if 'RISOTTO_DB_USER' in environ:
RISOTTO_DB_USER = environ['RISOTTO_DB_USER']
else:
RISOTTO_DB_USER = config.get('RISOTTO_DB_USER', 'risotto')
if 'TIRAMISU_DB_NAME' in environ:
TIRAMISU_DB_NAME = environ['TIRAMISU_DB_NAME']
else:
TIRAMISU_DB_NAME = config.get('TIRAMISU_DB_NAME', 'tiramisu')
if 'TIRAMISU_DB_PASSWORD' in environ:
TIRAMISU_DB_PASSWORD = environ['TIRAMISU_DB_PASSWORD']
else:
TIRAMISU_DB_PASSWORD = config.get('TIRAMISU_DB_PASSWORD', 'tiramisu')
if 'TIRAMISU_DB_USER' in environ:
TIRAMISU_DB_USER = environ['TIRAMISU_DB_USER']
else:
TIRAMISU_DB_USER = config.get('TIRAMISU_DB_USER', 'tiramisu')
if 'DB_ADDRESS' in environ:
DB_ADDRESS = environ['DB_ADDRESS']
else:
DB_ADDRESS = config.get('DB_ADDRESS', 'localhost')
if 'MESSAGE_PATH' in environ:
MESSAGE_PATH = environ['MESSAGE_PATH']
else:
MESSAGE_PATH = config.get('MESSAGE_PATH', '/root/risotto-message/messages')
if 'SQL_DIR' in environ:
SQL_DIR = environ['SQL_DIR']
else:
SQL_DIR = config.get('SQL_DIR', './sql')
if 'CACHE_ROOT_PATH' in environ:
CACHE_ROOT_PATH = environ['CACHE_ROOT_PATH']
else:
CACHE_ROOT_PATH = config.get('CACHE_ROOT_PATH', '/var/cache/risotto')
if 'SRV_SEED_PATH' in environ:
SRV_SEED_PATH = environ['SRV_SEED_PATH']
else:
SRV_SEED_PATH = config.get('SRV_SEED_PATH', '/srv/seed')
if 'TMP_DIR' in environ:
TMP_DIR = environ['TMP_DIR']
else:
TMP_DIR = config.get('TMP_DIR', '/tmp')
def dsn_factory(database, user, password, address=DB_ADDRESS): def dsn_factory(database, user, password, address=DB_ADDRESS):
@ -23,26 +83,29 @@ def dsn_factory(database, user, password, address=DB_ADDRESS):
return f'postgres:///{database}?host={mangled_address}/&user={user}&password={password}' return f'postgres:///{database}?host={mangled_address}/&user={user}&password={password}'
def get_config(): _config = {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD),
return {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD),
'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD), 'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD),
}, },
'http_server': {'port': 8080, 'http_server': {'port': RISOTTO_PORT,
'default_user': DEFAULT_USER}, 'default_user': DEFAULT_USER},
'global': {'message_root_path': MESSAGE_PATH, 'global': {'message_root_path': MESSAGE_PATH,
'configurations_dir': CONFIGURATION_DIR, 'configurations_dir': CONFIGURATION_DIR,
'debug': True, 'debug': True,
'internal_user': 'internal', 'internal_user': '_internal',
'check_role': True, 'check_role': True,
'admin_user': DEFAULT_USER, 'admin_user': DEFAULT_USER,
'sql_dir': SQL_DIR}, 'sql_dir': SQL_DIR,
'tmp_dir': TMP_DIR,
},
'cache': {'root_path': CACHE_ROOT_PATH}, 'cache': {'root_path': CACHE_ROOT_PATH},
'servermodel': {'internal_source_path': SRV_SEED_PATH, 'servermodel': {'internal_source_path': SRV_SEED_PATH,
'internal_source': 'internal', 'internal_source': 'internal'},
'internal_distribution': 'last',
'internal_release_name': 'none'},
'submodule': {'allow_insecure_https': False, 'submodule': {'allow_insecure_https': False,
'pki': '192.168.56.112'}, 'pki': '192.168.56.112'},
'provider': {'factory_configuration_dir': PROVIDER_FACTORY_CONFIG_DIR, 'provider': {'factory_configuration_dir': PROVIDER_FACTORY_CONFIG_DIR,
'factory_configuration_filename': 'infra.json'}, 'factory_configuration_filename': 'infra.json'},
} }
def get_config():
return _config

View File

@ -1,8 +1,5 @@
from .config import get_config
from .dispatcher import dispatcher from .dispatcher import dispatcher
from .context import Context from .context import Context
from .remote import remote
from . import services
from .utils import _ from .utils import _
@ -10,50 +7,48 @@ class Controller:
"""Common controller used to add a service in Risotto """Common controller used to add a service in Risotto
""" """
def __init__(self, def __init__(self,
test: bool): test: bool,
self.risotto_modules = services.get_services_list() ):
pass
async def call(self, async def call(self,
uri: str, uri: str,
risotto_context: Context, risotto_context: Context,
*args, *args,
**kwargs): **kwargs,
):
""" a wrapper to dispatcher's call""" """ a wrapper to dispatcher's call"""
version, module, message = uri.split('.', 2)
uri = module + '.' + message
if args: if args:
raise ValueError(_(f'the URI "{uri}" can only be called with keyword arguments')) raise ValueError(_(f'the URI "{uri}" can only be called with keyword arguments'))
if module not in self.risotto_modules: current_uri = risotto_context.paths[-1]
return await remote.remote_call(module, current_module = risotto_context.module
version, version, message = uri.split('.', 1)
message, module = message.split('.', 1)[0]
kwargs) if current_module != module:
raise ValueError(_(f'cannot call to external module ("{module}") to the URI "{uri}" from "{current_module}"'))
return await dispatcher.call(version, return await dispatcher.call(version,
uri, message,
risotto_context, risotto_context,
**kwargs) **kwargs,
)
async def publish(self, async def publish(self,
uri: str, uri: str,
risotto_context: Context, risotto_context: Context,
*args, *args,
**kwargs): **kwargs,
):
""" a wrapper to dispatcher's publish""" """ a wrapper to dispatcher's publish"""
version, module, submessage = uri.split('.', 2)
version, message = uri.split('.', 1) version, message = uri.split('.', 1)
if args: if args:
raise ValueError(_(f'the URI "{uri}" can only be published with keyword arguments')) raise ValueError(_(f'the URI "{uri}" can only be published with keyword arguments'))
if module not in self.risotto_modules:
await remote.remote_call(module,
version,
submessage,
kwargs)
else:
await dispatcher.publish(version, await dispatcher.publish(version,
message, message,
risotto_context, risotto_context,
**kwargs) **kwargs,
)
async def on_join(self, async def on_join(self,
risotto_context): risotto_context,
):
pass pass

View File

@ -1,7 +1,9 @@
try: try:
from tiramisu3 import Config from tiramisu3 import Config
from tiramisu3.error import ValueOptionError
except: except:
from tiramisu import Config from tiramisu import Config
from tiramisu.error import ValueOptionError
from traceback import print_exc from traceback import print_exc
from copy import copy from copy import copy
from typing import Dict, Callable, List, Optional from typing import Dict, Callable, List, Optional
@ -13,8 +15,7 @@ from .logger import log
from .config import get_config from .config import get_config
from .context import Context from .context import Context
from . import register from . import register
#from .remote import Remote from .remote import Remote
import asyncpg
class CallDispatcher: class CallDispatcher:
@ -42,28 +43,28 @@ class CallDispatcher:
raise Exception('hu?') raise Exception('hu?')
else: else:
for ret in returns: for ret in returns:
async with await Config(response, display_name=lambda self, dyn_name: self.impl_getname()) as config: async with await Config(response, display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config:
await config.property.read_write() await config.property.read_write()
try: try:
for key, value in ret.items(): for key, value in ret.items():
await config.option(key).value.set(value) await config.option(key).value.set(value)
except AttributeError: except AttributeError:
err = _(f'function {module_name}.{function_name} return the unknown parameter "{key}"') err = _(f'function {module_name}.{function_name} return the unknown parameter "{key}" for the uri "{risotto_context.version}.{risotto_context.message}"')
await log.error_msg(risotto_context, kwargs, err) await log.error_msg(risotto_context, kwargs, err)
raise CallError(str(err)) raise CallError(str(err))
except ValueError: except ValueError:
err = _(f'function {module_name}.{function_name} return the parameter "{key}" with an unvalid value "{value}"') err = _(f'function {module_name}.{function_name} return the parameter "{key}" with an unvalid value "{value}" for the uri "{risotto_context.version}.{risotto_context.message}"')
await log.error_msg(risotto_context, kwargs, err) await log.error_msg(risotto_context, kwargs, err)
raise CallError(str(err)) raise CallError(str(err))
await config.property.read_only() await config.property.read_only()
mandatories = await config.value.mandatory() mandatories = await config.value.mandatory()
if mandatories: if mandatories:
mand = [mand.split('.')[-1] for mand in mandatories] mand = [mand.split('.')[-1] for mand in mandatories]
raise ValueError(_(f'missing parameters in response: {mand} in message "{risotto_context.message}"')) raise ValueError(_(f'missing parameters in response of the uri "{risotto_context.version}.{risotto_context.message}": {mand} in message'))
try: try:
await config.value.dict() await config.value.dict()
except Exception as err: except Exception as err:
err = _(f'function {module_name}.{function_name} return an invalid response {err}') err = _(f'function {module_name}.{function_name} return an invalid response {err} for the uri "{risotto_context.version}.{risotto_context.message}"')
await log.error_msg(risotto_context, kwargs, err) await log.error_msg(risotto_context, kwargs, err)
raise CallError(str(err)) raise CallError(str(err))
@ -72,14 +73,21 @@ class CallDispatcher:
message: str, message: str,
old_risotto_context: Context, old_risotto_context: Context,
check_role: bool=False, check_role: bool=False,
**kwargs): internal: bool=True,
**kwargs,
):
""" execute the function associate with specified uri """ execute the function associate with specified uri
arguments are validate before arguments are validate before
""" """
risotto_context = self.build_new_context(old_risotto_context, risotto_context = self.build_new_context(old_risotto_context,
version, version,
message, message,
'rpc') 'rpc',
)
if version not in self.messages:
raise CallError(_(f'cannot find version of message "{version}"'))
if message not in self.messages[version]:
raise CallError(_(f'cannot find message "{version}.{message}"'))
function_objs = [self.messages[version][message]] function_objs = [self.messages[version][message]]
# do not start a new database connection # do not start a new database connection
if hasattr(old_risotto_context, 'connection'): if hasattr(old_risotto_context, 'connection'):
@ -89,7 +97,9 @@ class CallDispatcher:
risotto_context, risotto_context,
check_role, check_role,
kwargs, kwargs,
function_objs) function_objs,
internal,
)
else: else:
try: try:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
@ -106,7 +116,9 @@ class CallDispatcher:
risotto_context, risotto_context,
check_role, check_role,
kwargs, kwargs,
function_objs) function_objs,
internal,
)
except CallError as err: except CallError as err:
raise err raise err
except Exception as err: except Exception as err:
@ -132,26 +144,38 @@ class PublishDispatcher:
message: str, message: str,
old_risotto_context: Context, old_risotto_context: Context,
check_role: bool=False, check_role: bool=False,
**kwargs) -> None: internal: bool=True,
**kwargs,
) -> None:
risotto_context = self.build_new_context(old_risotto_context, risotto_context = self.build_new_context(old_risotto_context,
version, version,
message, message,
'event') 'event',
)
try: try:
function_objs = self.messages[version][message].get('functions', []) function_objs = self.messages[version][message].get('functions', [])
except KeyError: except KeyError:
raise ValueError(_(f'cannot find message {version}.{message}')) raise ValueError(_(f'cannot find message {version}.{message}'))
# do not start a new database connection # do not start a new database connection
if hasattr(old_risotto_context, 'connection'): if hasattr(old_risotto_context, 'connection'):
# publish to remove
remote_kw = dumps({'kwargs': kwargs,
'context': risotto_context.__dict__,
})
risotto_context.connection = old_risotto_context.connection risotto_context.connection = old_risotto_context.connection
# FIXME should be better :/
remote_kw = remote_kw.replace("'", "''")
await risotto_context.connection.execute(f'NOTIFY "{version}.{message}", \'{remote_kw}\'')
return await self.launch(version, return await self.launch(version,
message, message,
risotto_context, risotto_context,
check_role, check_role,
kwargs, kwargs,
function_objs) function_objs,
try: internal,
)
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
try:
await connection.set_type_codec( await connection.set_type_codec(
'json', 'json',
encoder=dumps, encoder=dumps,
@ -165,11 +189,13 @@ class PublishDispatcher:
risotto_context, risotto_context,
check_role, check_role,
kwargs, kwargs,
function_objs) function_objs,
internal,
)
except CallError as err: except CallError as err:
raise err pass
except Exception as err: except Exception as err:
# if there is a problem with arguments, just send an error and do nothing # if there is a problem with arguments, log and do nothing
if get_config()['global']['debug']: if get_config()['global']['debug']:
print_exc() print_exc()
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
@ -182,11 +208,10 @@ class PublishDispatcher:
risotto_context.connection = connection risotto_context.connection = connection
async with connection.transaction(): async with connection.transaction():
await log.error_msg(risotto_context, kwargs, err) await log.error_msg(risotto_context, kwargs, err)
raise err
class Dispatcher(register.RegisterDispatcher, class Dispatcher(register.RegisterDispatcher,
# Remote, Remote,
CallDispatcher, CallDispatcher,
PublishDispatcher): PublishDispatcher):
""" Manage message (call or publish) """ Manage message (call or publish)
@ -196,7 +221,8 @@ class Dispatcher(register.RegisterDispatcher,
old_risotto_context: Context, old_risotto_context: Context,
version: str, version: str,
message: str, message: str,
type: str): type: str,
) -> Context:
""" This is a new call or a new publish, so create a new context """ This is a new call or a new publish, so create a new context
""" """
uri = version + '.' + message uri = version + '.' + message
@ -212,7 +238,8 @@ class Dispatcher(register.RegisterDispatcher,
async def check_message_type(self, async def check_message_type(self,
risotto_context: Context, risotto_context: Context,
kwargs: Dict): kwargs: Dict,
) -> None:
if self.messages[risotto_context.version][risotto_context.message]['pattern'] != risotto_context.type: if self.messages[risotto_context.version][risotto_context.message]['pattern'] != risotto_context.type:
msg = _(f'{risotto_context.uri} is not a {risotto_context.type} message') msg = _(f'{risotto_context.uri} is not a {risotto_context.type} message')
await log.error_msg(risotto_context, kwargs, msg) await log.error_msg(risotto_context, kwargs, msg)
@ -222,23 +249,31 @@ class Dispatcher(register.RegisterDispatcher,
risotto_context: Context, risotto_context: Context,
uri: str, uri: str,
kwargs: Dict, kwargs: Dict,
check_role: bool): check_role: bool,
internal: bool,
):
""" create a new Config et set values to it """ create a new Config et set values to it
""" """
# create a new config # create a new config
async with await Config(self.option) as config: async with await Config(self.option) as config:
await config.property.read_write() await config.property.read_write()
# set message's option # set message's option
await config.option('message').value.set(risotto_context.message) await config.option('message').value.set(uri)
# store values # store values
subconfig = config.option(risotto_context.message) subconfig = config.option(uri)
extra_parameters = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
if not internal or not key.startswith('_'):
try: try:
await subconfig.option(key).value.set(value) await subconfig.option(key).value.set(value)
except AttributeError: except AttributeError:
if get_config()['global']['debug']: if get_config()['global']['debug']:
print_exc() print_exc()
raise ValueError(_(f'unknown parameter in "{uri}": "{key}"')) raise ValueError(_(f'unknown parameter in "{uri}": "{key}"'))
except ValueOptionError as err:
raise ValueError(_(f'invalid parameter in "{uri}": {err}'))
else:
extra_parameters[key] = value
# check mandatories options # check mandatories options
if check_role and get_config().get('global').get('check_role'): if check_role and get_config().get('global').get('check_role'):
await self.check_role(subconfig, await self.check_role(subconfig,
@ -250,7 +285,10 @@ class Dispatcher(register.RegisterDispatcher,
mand = [mand.split('.')[-1] for mand in mandatories] mand = [mand.split('.')[-1] for mand in mandatories]
raise ValueError(_(f'missing parameters in "{uri}": {mand}')) raise ValueError(_(f'missing parameters in "{uri}": {mand}'))
# return complete an validated kwargs # return complete an validated kwargs
return await subconfig.value.dict() parameters = await subconfig.value.dict()
if extra_parameters:
parameters.update(extra_parameters)
return parameters
def get_service(self, def get_service(self,
name: str): name: str):
@ -265,7 +303,7 @@ class Dispatcher(register.RegisterDispatcher,
# Verify if user exists and get ID # Verify if user exists and get ID
sql = ''' sql = '''
SELECT UserId SELECT UserId
FROM RisottoUser FROM UserUser
WHERE UserLogin = $1 WHERE UserLogin = $1
''' '''
user_id = await connection.fetchval(sql, user_id = await connection.fetchval(sql,
@ -283,8 +321,8 @@ class Dispatcher(register.RegisterDispatcher,
# Check role # Check role
select_role_uri = ''' select_role_uri = '''
SELECT RoleName SELECT RoleName
FROM URI, RoleURI FROM UserURI, UserRoleURI
WHERE URI.URIName = $1 AND RoleURI.URIId = URI.URIId WHERE UserURI.URIName = $1 AND UserRoleURI.URIId = UserURI.URIId
''' '''
select_role_user = ''' select_role_user = '''
SELECT RoleAttribute, RoleAttributeValue SELECT RoleAttribute, RoleAttributeValue
@ -309,19 +347,24 @@ class Dispatcher(register.RegisterDispatcher,
risotto_context: Context, risotto_context: Context,
check_role: bool, check_role: bool,
kwargs: Dict, kwargs: Dict,
function_objs: List) -> Optional[Dict]: function_objs: List,
internal: bool,
) -> Optional[Dict]:
await self.check_message_type(risotto_context, await self.check_message_type(risotto_context,
kwargs) kwargs)
config_arguments = await self.load_kwargs_to_config(risotto_context, config_arguments = await self.load_kwargs_to_config(risotto_context,
f'{version}.{message}', f'{version}.{message}',
kwargs, kwargs,
check_role) check_role,
internal,
)
# config is ok, so send the message # config is ok, so send the message
for function_obj in function_objs: for function_obj in function_objs:
function = function_obj['function'] function = function_obj['function']
module_name = function.__module__.split('.')[-2] submodule_name = function_obj['module']
function_name = function.__name__ function_name = function.__name__
info_msg = _(f'in module {module_name}.{function_name}') risotto_context.module = submodule_name.split('.', 1)[0]
info_msg = _(f'in module {submodule_name}.{function_name}')
# build argument for this function # build argument for this function
if risotto_context.type == 'rpc': if risotto_context.type == 'rpc':
kw = config_arguments kw = config_arguments

View File

@ -20,13 +20,11 @@ from . import services
extra_routes = {} extra_routes = {}
RISOTTO_MODULES = services.get_services_list()
def create_context(request): def create_context(request):
risotto_context = Context() risotto_context = Context()
risotto_context.username = request.match_info.get('username', risotto_context.username = request.match_info.get('username',
get_config()['http_server']['default_user']) get_config()['http_server']['default_user'],
)
return risotto_context return risotto_context
@ -53,8 +51,9 @@ class extra_route_handler:
function_name = cls.function.__module__ function_name = cls.function.__module__
# if not 'api' function # if not 'api' function
if function_name != 'risotto.http': if function_name != 'risotto.http':
module_name = function_name.split('.')[-2] risotto_module_name, submodule_name = function_name.split('.', 2)[:-1]
kwargs['self'] = dispatcher.injected_self[module_name] module_name = risotto_module_name.split('_')[-1]
kwargs['self'] = dispatcher.injected_self[module_name + '.' + submodule_name]
try: try:
returns = await cls.function(**kwargs) returns = await cls.function(**kwargs)
except NotAllowedError as err: except NotAllowedError as err:
@ -85,7 +84,9 @@ async def handle(request):
message, message,
risotto_context, risotto_context,
check_role=True, check_role=True,
**kwargs) internal=False,
**kwargs,
)
except NotAllowedError as err: except NotAllowedError as err:
raise HTTPNotFound(reason=str(err)) raise HTTPNotFound(reason=str(err))
except CallError as err: except CallError as err:
@ -100,8 +101,8 @@ async def handle(request):
async def api(request, async def api(request,
risotto_context): risotto_context):
global tiramisu global TIRAMISU
if not tiramisu: if not TIRAMISU:
# check all URI that have an associated role # check all URI that have an associated role
# all URI without role is concidered has a private URI # all URI without role is concidered has a private URI
uris = [] uris = []
@ -109,18 +110,21 @@ async def api(request,
async with connection.transaction(): async with connection.transaction():
# Check role with ACL # Check role with ACL
sql = ''' sql = '''
SELECT URI.URIName SELECT UserURI.URIName
FROM URI, RoleURI FROM UserURI, UserRoleURI
WHERE RoleURI.URIId = URI.URIId WHERE UserRoleURI.URIId = UserURI.URIId
''' '''
uris = [uri['uriname'] for uri in await connection.fetch(sql)] uris = [uri['uriname'] for uri in await connection.fetch(sql)]
async with await Config(get_messages(current_module_names=RISOTTO_MODULES, risotto_modules = services.get_services_list()
async with await Config(get_messages(current_module_names=risotto_modules,
load_shortarg=True, load_shortarg=True,
current_version=risotto_context.version, current_version=risotto_context.version,
uris=uris)[1]) as config: uris=uris,
)[1],
display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config:
await config.property.read_write() await config.property.read_write()
tiramisu = await config.option.dict(remotable='none') TIRAMISU = await config.option.dict(remotable='none')
return tiramisu return TIRAMISU
async def get_app(loop): async def get_app(loop):
@ -138,9 +142,9 @@ async def get_app(loop):
versions.append(version) versions.append(version)
print() print()
print(_('======== Registered messages ========')) print(_('======== Registered messages ========'))
for message in messages: for message, message_infos in messages.items():
web_message = f'/api/{version}/{message}' web_message = f'/api/{version}/{message}'
pattern = dispatcher.messages[version][message]['pattern'] pattern = message_infos['pattern']
print(f' - {web_message} ({pattern})') print(f' - {web_message} ({pattern})')
routes.append(post(web_message, handle)) routes.append(post(web_message, handle))
print() print()
@ -152,6 +156,9 @@ async def get_app(loop):
extra_handler = type(api_route['path'], (extra_route_handler,), api_route) extra_handler = type(api_route['path'], (extra_route_handler,), api_route)
routes.append(get(api_route['path'], extra_handler)) routes.append(get(api_route['path'], extra_handler))
print(f' - {api_route["path"]} (http_get)') print(f' - {api_route["path"]} (http_get)')
# last version is default version
routes.append(get('/api', extra_handler))
print(f' - /api (http_get)')
print() print()
if extra_routes: if extra_routes:
print(_('======== Registered extra routes ========')) print(_('======== Registered extra routes ========'))
@ -162,11 +169,12 @@ async def get_app(loop):
extra_handler = type(path, (extra_route_handler,), extra) extra_handler = type(path, (extra_route_handler,), extra)
routes.append(get(path, extra_handler)) routes.append(get(path, extra_handler))
print(f' - {path} (http_get)') print(f' - {path} (http_get)')
print()
del extra_routes del extra_routes
app.router.add_routes(routes) app.router.add_routes(routes)
await dispatcher.register_remote()
print()
await dispatcher.on_join() await dispatcher.on_join()
return await loop.create_server(app.make_handler(), '*', get_config()['http_server']['port']) return await loop.create_server(app.make_handler(), '*', get_config()['http_server']['port'])
tiramisu = None TIRAMISU = None

View File

@ -248,7 +248,8 @@ def get_message_file_path(version,
def list_messages(uris, def list_messages(uris,
current_module_names, current_module_names,
current_version): current_version,
):
def get_module_paths(current_module_names): def get_module_paths(current_module_names):
if current_module_names is None: if current_module_names is None:
current_module_names = listdir(join(MESSAGE_ROOT_PATH, version)) current_module_names = listdir(join(MESSAGE_ROOT_PATH, version))
@ -412,7 +413,7 @@ def load_customtypes() -> None:
custom_type = CustomType(load(message_file, Loader=SafeLoader)) custom_type = CustomType(load(message_file, Loader=SafeLoader))
ret[version][custom_type.getname()] = custom_type ret[version][custom_type.getname()] = custom_type
except Exception as err: except Exception as err:
raise Exception(_(f'enable to load type {err}: {message}')) raise Exception(_(f'enable to load type "{message}": {err}'))
return ret return ret
@ -431,9 +432,9 @@ def _get_description(description,
def _get_option(name, def _get_option(name,
arg, arg,
file_path, uri,
select_option, select_option,
optiondescription): ):
"""generate option """generate option
""" """
props = [] props = []
@ -443,7 +444,7 @@ def _get_option(name,
props.append(Calculation(calc_value, props.append(Calculation(calc_value,
Params(ParamValue('disabled'), Params(ParamValue('disabled'),
kwargs={'condition': ParamOption(select_option, todict=True), kwargs={'condition': ParamOption(select_option, todict=True),
'expected': ParamValue(optiondescription), 'expected': ParamValue(uri),
'reverse_condition': ParamValue(True)}), 'reverse_condition': ParamValue(True)}),
calc_value_property_help)) calc_value_property_help))
@ -472,25 +473,25 @@ def _get_option(name,
elif type_ == 'Float': elif type_ == 'Float':
obj = FloatOption(**kwargs) obj = FloatOption(**kwargs)
else: else:
raise Exception('unsupported type {} in {}'.format(type_, file_path)) raise Exception('unsupported type {} in {}'.format(type_, uri))
obj.impl_set_information('ref', arg.ref) obj.impl_set_information('ref', arg.ref)
return obj return obj
def get_options(message_def, def get_options(message_def,
file_path, uri,
select_option, select_option,
optiondescription, load_shortarg,
load_shortarg): ):
"""build option with args/kwargs """build option with args/kwargs
""" """
options =[] options =[]
for name, arg in message_def.parameters.items(): for name, arg in message_def.parameters.items():
current_opt = _get_option(name, current_opt = _get_option(name,
arg, arg,
file_path, uri,
select_option, select_option,
optiondescription) )
options.append(current_opt) options.append(current_opt)
if hasattr(arg, 'shortarg') and arg.shortarg and load_shortarg: if hasattr(arg, 'shortarg') and arg.shortarg and load_shortarg:
options.append(SymLinkOption(arg.shortarg, current_opt)) options.append(SymLinkOption(arg.shortarg, current_opt))
@ -498,17 +499,18 @@ def get_options(message_def,
def _parse_responses(message_def, def _parse_responses(message_def,
file_path): uri,
):
"""build option with returns """build option with returns
""" """
if message_def.response.parameters is None: if message_def.response.parameters is None:
raise Exception('message "{}" did not returned any valid parameters.'.format(message_def.message)) raise Exception(f'message "{message_def.message}" did not returned any valid parameters')
options = [] options = []
names = [] names = []
for name, obj in message_def.response.parameters.items(): for name, obj in message_def.response.parameters.items():
if name in names: if name in names:
raise Exception('multi response with name {} in {}'.format(name, file_path)) raise Exception(f'multi response with name "{name}" in "{uri}"')
names.append(name) names.append(name)
kwargs = {'name': name, kwargs = {'name': name,
@ -531,15 +533,17 @@ def _parse_responses(message_def,
else: else:
kwargs['properties'] = ('mandatory',) kwargs['properties'] = ('mandatory',)
options.append(option(**kwargs)) options.append(option(**kwargs))
od = OptionDescription(message_def.message, od = OptionDescription(uri,
message_def.response.description, message_def.response.description,
options) options,
)
od.impl_set_information('multi', message_def.response.multi) od.impl_set_information('multi', message_def.response.multi)
return od return od
def _get_root_option(select_option, def _get_root_option(select_option,
optiondescriptions): optiondescriptions,
):
"""get root option """get root option
""" """
def _get_od(curr_ods): def _get_od(curr_ods):
@ -581,19 +585,21 @@ def _get_root_option(select_option,
def get_messages(current_module_names, def get_messages(current_module_names,
load_shortarg=False, load_shortarg=False,
current_version=None, current_version=None,
uris=None): uris=None,
):
"""generate description from yml files """generate description from yml files
""" """
optiondescriptions = {} optiondescriptions = {}
optiondescriptions_info = {} optiondescriptions_info = {}
messages = list(list_messages(uris, messages = list(list_messages(uris,
current_module_names, current_module_names,
current_version)) current_version,
))
messages.sort() messages.sort()
optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages] # optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages]
select_option = ChoiceOption('message', select_option = ChoiceOption('message',
'Nom du message.', 'Nom du message.',
tuple(optiondescriptions_name), tuple(messages),
properties=frozenset(['mandatory', 'positional'])) properties=frozenset(['mandatory', 'positional']))
for uri in messages: for uri in messages:
message_def = get_message(uri, message_def = get_message(uri,
@ -601,23 +607,26 @@ def get_messages(current_module_names,
) )
optiondescriptions_info[message_def.message] = {'pattern': message_def.pattern, optiondescriptions_info[message_def.message] = {'pattern': message_def.pattern,
'default_roles': message_def.default_roles, 'default_roles': message_def.default_roles,
'version': message_def.version} 'version': message_def.version,
}
if message_def.pattern == 'rpc': if message_def.pattern == 'rpc':
if not message_def.response: if not message_def.response:
raise Exception(f'rpc without response is not allowed {uri}') raise Exception(f'rpc without response is not allowed {uri}')
optiondescriptions_info[message_def.message]['response'] = _parse_responses(message_def, optiondescriptions_info[message_def.message]['response'] = _parse_responses(message_def,
uri) uri,
)
elif message_def.response: elif message_def.response:
raise Exception(f'response is not allowed for {uri}') raise Exception(f'response is not allowed for {uri}')
message_def.options = get_options(message_def, message_def.options = get_options(message_def,
uri, uri,
select_option, select_option,
message_def.message, load_shortarg,
load_shortarg) )
optiondescriptions[message_def.message] = (message_def.description, message_def.options) optiondescriptions[uri] = (message_def.description, message_def.options)
root = _get_root_option(select_option, root = _get_root_option(select_option,
optiondescriptions) optiondescriptions,
)
return optiondescriptions_info, root return optiondescriptions_info, root

View File

@ -3,9 +3,10 @@ try:
except: except:
from tiramisu import Config from tiramisu import Config
from inspect import signature from inspect import signature
from typing import Callable, Optional from typing import Callable, Optional, List
import asyncpg from asyncpg import create_pool
from json import dumps, loads from json import dumps, loads
from pkg_resources import iter_entry_points
import risotto import risotto
from .utils import _ from .utils import _
from .error import RegistrationError from .error import RegistrationError
@ -13,7 +14,7 @@ from .message import get_messages
from .context import Context from .context import Context
from .config import get_config from .config import get_config
from .logger import log from .logger import log
from pkg_resources import iter_entry_points
class Services(): class Services():
services = {} services = {}
@ -25,9 +26,12 @@ class Services():
self.services.setdefault(entry_point.name, []) self.services.setdefault(entry_point.name, [])
self.services_loaded = True self.services_loaded = True
def load_modules(self): def load_modules(self,
limit_services: Optional[List[str]]=None,
) -> None:
for entry_point in iter_entry_points(group='risotto_modules'): for entry_point in iter_entry_points(group='risotto_modules'):
service_name, module_name = entry_point.name.split('.') service_name, module_name = entry_point.name.split('.')
if limit_services is None or service_name in limit_services:
setattr(self, module_name, entry_point.load()) setattr(self, module_name, entry_point.load())
self.services[service_name].append(module_name) self.services[service_name].append(module_name)
self.modules_loaded = True self.modules_loaded = True
@ -37,10 +41,12 @@ class Services():
self.load_services() self.load_services()
return [(s, getattr(self, s)) for s in self.services] return [(s, getattr(self, s)) for s in self.services]
def get_modules(self): def get_modules(self,
limit_services: Optional[List[str]]=None,
) -> List[str]:
if not self.modules_loaded: if not self.modules_loaded:
self.load_modules() self.load_modules(limit_services=limit_services)
return [(m, getattr(self, m)) for s in self.services.values() for m in s] return [(module + '.' + submodule, getattr(self, submodule)) for module, submodules in self.services.items() for submodule in submodules]
def get_services_list(self): def get_services_list(self):
return self.services.keys() return self.services.keys()
@ -52,11 +58,13 @@ class Services():
dispatcher, dispatcher,
validate: bool=True, validate: bool=True,
test: bool=False, test: bool=False,
limit_services: Optional[List[str]]=None,
): ):
for module_name, module in self.get_modules(): for submodule_name, module in self.get_modules(limit_services=limit_services):
dispatcher.set_module(module_name, dispatcher.set_module(submodule_name,
module, module,
test) test,
)
if validate: if validate:
dispatcher.validate() dispatcher.validate()
@ -65,8 +73,10 @@ services = Services()
services.load_services() services.load_services()
setattr(risotto, 'services', services) setattr(risotto, 'services', services)
def register(uris: str, def register(uris: str,
notification: str=None): notification: str=None,
) -> None:
""" Decorator to register function to the dispatcher """ Decorator to register function to the dispatcher
""" """
if not isinstance(uris, list): if not isinstance(uris, list):
@ -106,29 +116,38 @@ class RegisterDispatcher:
return {param.name for param in list(signature(function).parameters.values())[first_argument_index:]} return {param.name for param in list(signature(function).parameters.values())[first_argument_index:]}
async def get_message_args(self, async def get_message_args(self,
message: str): message: str,
version: str,
):
# load config # load config
async with await Config(self.option) as config: async with await Config(self.option, display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config:
uri = f'{version}.{message}'
await config.property.read_write() await config.property.read_write()
# set message to the message name # set message to the message name
await config.option('message').value.set(message) await config.option('message').value.set(uri)
# get message argument # get message argument
dico = await config.option(message).value.dict() dico = await config.option(uri).value.dict()
return set(dico.keys()) return set(dico.keys())
async def valid_rpc_params(self, async def valid_rpc_params(self,
version: str, version: str,
message: str, message: str,
function: Callable, function: Callable,
module_name: str): module_name: str,
):
""" parameters function must have strictly all arguments with the correct name """ parameters function must have strictly all arguments with the correct name
""" """
# get message arguments # get message arguments
message_args = await self.get_message_args(message) message_args = await self.get_message_args(message,
version,
)
# get function arguments # get function arguments
function_args = self.get_function_args(function) function_args = self.get_function_args(function)
# compare message arguments with function parameter # compare message arguments with function parameter
# it must not have more or less arguments # it must not have more or less arguments
for arg in function_args - message_args:
if arg.startswith('_'):
message_args.add(arg)
if message_args != function_args: if message_args != function_args:
# raise if arguments are not equal # raise if arguments are not equal
msg = [] msg = []
@ -146,11 +165,14 @@ class RegisterDispatcher:
version: str, version: str,
message: str, message: str,
function: Callable, function: Callable,
module_name: str): module_name: str,
):
""" parameters function validation for event messages """ parameters function validation for event messages
""" """
# get message arguments # get message arguments
message_args = await self.get_message_args(message) message_args = await self.get_message_args(message,
version,
)
# get function arguments # get function arguments
function_args = self.get_function_args(function) function_args = self.get_function_args(function)
# compare message arguments with function parameter # compare message arguments with function parameter
@ -166,7 +188,8 @@ class RegisterDispatcher:
version: str, version: str,
message: str, message: str,
notification: str, notification: str,
function: Callable): function: Callable,
):
""" register a function to an URI """ register a function to an URI
URI is a message URI is a message
""" """
@ -175,14 +198,16 @@ class RegisterDispatcher:
if message not in self.messages[version]: if message not in self.messages[version]:
raise RegistrationError(_(f'the message {message} not exists')) raise RegistrationError(_(f'the message {message} not exists'))
# xxx module can only be register with v1.xxxx..... message # xxx submodule can only be register with v1.yyy.xxx..... message
module_name = function.__module__.split('.')[-2] risotto_module_name, submodule_name = function.__module__.split('.')[-3:-1]
message_namespace = message.split('.', 1)[0] module_name = risotto_module_name.split('_')[-1]
message_risotto_module, message_namespace, message_name = message.split('.', 2) message_module, message_submodule, message_name = message.split('.', 2)
if message_risotto_module not in self.risotto_modules: if message_module not in self.risotto_modules:
raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"')) raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"'))
if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name: if self.messages[version][message]['pattern'] == 'rpc' and \
raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"')) module_name != message_module and \
message_submodule != submodule_name:
raise RegistrationError(_(f'cannot registered the "{message}" message in submodule "{module_name}.{submodule_name}"'))
# True if first argument is the risotto_context # True if first argument is the risotto_context
function_args = self.get_function_args(function) function_args = self.get_function_args(function)
@ -198,10 +223,11 @@ class RegisterDispatcher:
register = self.register_event register = self.register_event
register(version, register(version,
message, message,
module_name, f'{module_name}.{submodule_name}',
function, function,
function_args, function_args,
notification) notification,
)
def register_rpc(self, def register_rpc(self,
version: str, version: str,
@ -209,7 +235,8 @@ class RegisterDispatcher:
module_name: str, module_name: str,
function: Callable, function: Callable,
function_args: list, function_args: list,
notification: Optional[str]): notification: Optional[str],
):
self.messages[version][message]['module'] = module_name self.messages[version][message]['module'] = module_name
self.messages[version][message]['function'] = function self.messages[version][message]['function'] = function
self.messages[version][message]['arguments'] = function_args self.messages[version][message]['arguments'] = function_args
@ -222,7 +249,8 @@ class RegisterDispatcher:
module_name: str, module_name: str,
function: Callable, function: Callable,
function_args: list, function_args: list,
notification: Optional[str]): notification: Optional[str],
):
if 'functions' not in self.messages[version][message]: if 'functions' not in self.messages[version][message]:
self.messages[version][message]['functions'] = [] self.messages[version][message]['functions'] = []
@ -233,13 +261,17 @@ class RegisterDispatcher:
dico['notification'] = notification dico['notification'] = notification
self.messages[version][message]['functions'].append(dico) self.messages[version][message]['functions'].append(dico)
def set_module(self, module_name, module, test): def set_module(self,
submodule_name,
module,
test,
):
""" register and instanciate a new module """ register and instanciate a new module
""" """
try: try:
self.injected_self[module_name] = module.Risotto(test) self.injected_self[submodule_name] = module.Risotto(test)
except AttributeError as err: except AttributeError as err:
raise RegistrationError(_(f'unable to register the module {module_name}, this module must have Risotto class')) raise RegistrationError(_(f'unable to register the module {submodule_name}, this module must have Risotto class'))
def validate(self): def validate(self):
""" check if all messages have a function """ check if all messages have a function
@ -255,7 +287,9 @@ class RegisterDispatcher:
if missing_messages: if missing_messages:
raise RegistrationError(_(f'no matching function for uri {missing_messages}')) raise RegistrationError(_(f'no matching function for uri {missing_messages}'))
async def on_join(self): async def on_join(self,
truncate: bool=False,
) -> None:
internal_user = get_config()['global']['internal_user'] internal_user = get_config()['global']['internal_user']
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
await connection.set_type_codec( await connection.set_type_codec(
@ -264,14 +298,18 @@ class RegisterDispatcher:
decoder=loads, decoder=loads,
schema='pg_catalog' schema='pg_catalog'
) )
if truncate:
async with connection.transaction(): async with connection.transaction():
for module_name, module in self.injected_self.items(): await connection.execute('TRUNCATE InfraServer, InfraSite, InfraZone, Log, ProviderDeployment, ProviderFactoryCluster, ProviderFactoryClusterNode, SettingApplicationservice, SettingApplicationServiceDependency, SettingRelease, SettingServer, SettingServermodel, SettingSource, UserRole, UserRoleURI, UserURI, UserUser, InfraServermodel, ProviderZone, ProviderServer, ProviderSource, ProviderApplicationservice ProviderServermodel')
async with connection.transaction():
for submodule_name, module in self.injected_self.items():
risotto_context = Context() risotto_context = Context()
risotto_context.username = internal_user risotto_context.username = internal_user
risotto_context.paths.append(f'{module_name}.on_join') risotto_context.paths.append(f'internal.{submodule_name}.on_join')
risotto_context.type = None risotto_context.type = None
risotto_context.connection = connection risotto_context.connection = connection
info_msg = _(f'in module {module_name}.on_join') risotto_context.module = submodule_name.split('.', 1)[0]
info_msg = _(f'in module risotto_{submodule_name}.on_join')
await log.info_msg(risotto_context, await log.info_msg(risotto_context,
None, None,
info_msg) info_msg)
@ -280,12 +318,14 @@ class RegisterDispatcher:
async def load(self): async def load(self):
# valid function's arguments # valid function's arguments
db_conf = get_config()['database']['dsn'] db_conf = get_config()['database']['dsn']
self.pool = await asyncpg.create_pool(db_conf) self.pool = await create_pool(db_conf)
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
async with connection.transaction(): async with connection.transaction():
for version, messages in self.messages.items(): for version, messages in self.messages.items():
for message, message_infos in messages.items(): for message, message_infos in messages.items():
if message_infos['pattern'] == 'rpc': if message_infos['pattern'] == 'rpc':
# module not available during test
if 'module' in message_infos:
module_name = message_infos['module'] module_name = message_infos['module']
function = message_infos['function'] function = message_infos['function']
await self.valid_rpc_params(version, await self.valid_rpc_params(version,

View File

@ -1,61 +1,42 @@
from aiohttp import ClientSession from asyncio import get_event_loop, ensure_future
from requests import get, post from json import loads
from json import dumps
#from tiramisu_api import Config
from .context import Context
from .config import get_config from .config import get_config
from .utils import _ from .utils import _
#
#
# ALLOW_INSECURE_HTTPS = get_config()['module']['allow_insecure_https']
class Remote: class Remote:
submodules = {} async def register_remote(self) -> None:
print()
print(_('======== Registered remote event ========'))
self.listened_connection = await self.pool.acquire()
for version, messages in self.messages.items():
for message, message_infos in messages.items():
# event not emit locally
if message_infos['pattern'] == 'event':
module, submodule, submessage = message.split('.', 2)
if f'{module}.{submodule}' not in self.injected_self:
uri = f'{version}.{message}'
print(f' - {uri}')
await self.listened_connection.add_listener(uri, self.to_async_publish)
async def _get_config(self, def to_async_publish(self,
module: str, con: 'asyncpg.connection.Connection',
url: str) -> None: pid: int,
if module not in self.submodules: uri: str,
session = ClientSession() payload: str,
async with session.get(url) as resp: ) -> None:
if resp.status != 200: version, message = uri.split('.', 1)
try: loop = get_event_loop()
json = await resp.json() remote_kw = loads(payload)
err = json['error']['kwargs']['reason'] context = Context()
except: for key, value in remote_kw['context'].items():
err = await resp.text() setattr(context, key, value)
raise Exception(err) callback = lambda: ensure_future(self.publish(version,
json = await resp.json() message,
self.submodules[module] = json context,
return Config(self.submodules[module]) **remote_kw['kwargs'],
))
async def remote_call(self, loop.call_soon(callback)
module: str,
version: str,
submessage: str,
payload) -> dict:
try:
domain_name = get_config()['module'][module]
except KeyError:
raise ValueError(_(f'cannot find information of remote module "{module}" to access to "{version}.{module}.{submessage}"'))
remote_url = f'http://{domain_name}:8080/api/{version}'
message_url = f'{remote_url}/{submessage}'
config = await self._get_config(module,
remote_url)
for key, value in payload.items():
path = submessage + '.' + key
config.option(path).value.set(value)
session = ClientSession()
async with session.post(message_url, data=dumps(payload)) as resp:
response = await resp.json()
if 'error' in response:
if 'reason' in response['error']['kwargs']:
raise Exception("{}".format(response['error']['kwargs']['reason']))
raise Exception('erreur inconnue')
return response['response']
remote = Remote()

0
tests/__init__.py Normal file
View File

View File

@ -1,5 +1,15 @@
try:
from tiramisu3 import Storage
except:
from tiramisu import Storage from tiramisu import Storage
from risotto.config import DATABASE_DIR from os.path import isfile as _isfile
import os as _os
_envfile = '/etc/risotto/risotto.conf'
if _isfile(_envfile):
with open(_envfile, 'r') as fh_env:
for line in fh_env.readlines():
key, value = line.strip().split('=')
_os.environ[key] = value
STORAGE = Storage(engine='sqlite3', dir_database=DATABASE_DIR, name='test') STORAGE = Storage(engine='sqlite3')

View File

@ -1,20 +1,29 @@
from importlib import import_module from importlib import import_module
import pytest import pytest
from tiramisu import list_sessions, delete_session try:
from tiramisu3 import list_sessions, delete_session as _delete_session
except:
from tiramisu import list_sessions, delete_session as _delete_session
from .storage import STORAGE from .storage import STORAGE
from risotto import services
from risotto.context import Context from risotto.context import Context
from risotto.services import load_services #from risotto.services import load_services
from risotto.dispatcher import dispatcher from risotto.dispatcher import dispatcher
SOURCE_NAME = 'test'
SERVERMODEL_NAME = 'sm1'
def setup_module(module): def setup_module(module):
load_services(['config'], # load_services(['config'],
validate=False) # validate=False)
services.link_to_dispatcher(dispatcher, limit_services=['setting'], validate=False)
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
config_module.save_storage = STORAGE config_module.save_storage = STORAGE
dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True) #dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True)
dispatcher.set_module('servermodel', import_module(f'.servermodel', 'fake_services'), True) #dispatcher.set_module('servermodel', import_module(f'.servermodel', 'fake_services'), True)
def setup_function(function): def setup_function(function):
@ -23,11 +32,11 @@ def setup_function(function):
config_module.servermodel = {} config_module.servermodel = {}
def teardown_function(function): async def delete_session():
# delete all sessions # delete all sessions
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
for session in list_sessions(storage=config_module.save_storage): for session in await list_sessions(storage=config_module.save_storage):
delete_session(storage=config_module.save_storage, session_id=session) await _delete_session(storage=config_module.save_storage, session_id=session)
def get_fake_context(module_name): def get_fake_context(module_name):
@ -38,127 +47,166 @@ def get_fake_context(module_name):
return risotto_context return risotto_context
@pytest.mark.asyncio async def onjoin(source=True):
async def test_on_join():
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
assert config_module.servermodel == {} assert config_module.servermodel == {}
assert config_module.server == {} assert config_module.server == {}
await delete_session()
# #
#config_module.cache_root_path = 'tests/data'
await dispatcher.load()
await dispatcher.on_join(truncate=True)
if source:
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' await dispatcher.call('v1',
await config_module.on_join(fake_context) 'setting.source.create',
assert list(config_module.servermodel.keys()) == [1, 2] fake_context,
assert list(config_module.server) == [3] source_name=SOURCE_NAME,
assert set(config_module.server[3]) == {'server', 'server_to_deploy', 'funcs_file'} source_directory='tests/data',
assert config_module.server[3]['funcs_file'] == 'tests/data/1/funcs.py' )
INTERNAL_SOURCE = {'source_name': 'internal', 'source_directory': '/srv/risotto/seed/internal'}
TEST_SOURCE = {'source_name': 'test', 'source_directory': 'tests/data'}
##############################################################################################################################
# Source / Release
##############################################################################################################################
@pytest.mark.asyncio
async def test_source_on_join():
# onjoin must create internal source
sources = [INTERNAL_SOURCE]
await onjoin(False)
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.list',
fake_context,
) == sources
await delete_session()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_created(): async def test_source_create():
sources = [INTERNAL_SOURCE, TEST_SOURCE]
await onjoin()
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
assert list(config_module.servermodel.keys()) == ['last_base']
assert list(config_module.server) == []
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' assert await dispatcher.call('v1',
await config_module.on_join(fake_context) 'setting.source.list',
#
assert list(config_module.server) == [3]
await dispatcher.publish('v1',
'server.created',
fake_context, fake_context,
server_id=4, ) == sources
server_name='name3', await delete_session()
server_description='description3',
server_servermodel_id=2)
assert list(config_module.server) == [3, 4]
assert set(config_module.server[4]) == {'server', 'server_to_deploy', 'funcs_file'}
assert config_module.server[4]['funcs_file'] == 'tests/data/2/funcs.py'
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_deleted(): async def test_source_describe():
config_module = dispatcher.get_service('config') await onjoin()
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' assert await dispatcher.call('v1',
await config_module.on_join(fake_context) 'setting.source.describe',
#
assert list(config_module.server) == [3]
await dispatcher.publish('v1',
'server.created',
fake_context, fake_context,
server_id=4, source_name='internal',
server_name='name4', ) == INTERNAL_SOURCE
server_description='description4', assert await dispatcher.call('v1',
server_servermodel_id=2) 'setting.source.describe',
assert list(config_module.server) == [3, 4]
await dispatcher.publish('v1',
'server.deleted',
fake_context, fake_context,
server_id=4) source_name=SOURCE_NAME,
assert list(config_module.server) == [3] ) == TEST_SOURCE
await delete_session()
@pytest.mark.asyncio
async def test_release_internal_list():
releases = [{'release_distribution': 'last',
'release_name': 'none',
'source_name': 'internal'}]
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.release.list',
fake_context,
source_name='internal',
) == releases
await delete_session()
@pytest.mark.asyncio
async def test_release_list():
releases = [{'release_distribution': 'last',
'release_name': '1',
'source_name': 'test'}]
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.release.list',
fake_context,
source_name='test',
) == releases
await delete_session()
@pytest.mark.asyncio
async def test_release_describe():
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.release.describe',
fake_context,
source_name='internal',
release_distribution='last',
) == {'release_distribution': 'last',
'release_name': 'none',
'source_name': 'internal'}
assert await dispatcher.call('v1',
'setting.source.release.describe',
fake_context,
source_name='test',
release_distribution='last',
) == {'release_distribution': 'last',
'release_name': '1',
'source_name': 'test'}
await delete_session()
##############################################################################################################################
# Servermodel
##############################################################################################################################
async def create_servermodel(name=SERVERMODEL_NAME,
parents_name=['base'],
):
fake_context = get_fake_context('config')
await dispatcher.call('v1',
'setting.servermodel.create',
fake_context,
servermodel_name=name,
servermodel_description='servermodel 1',
parents_name=parents_name,
source_name=SOURCE_NAME,
release_distribution='last',
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_servermodel_created(): async def test_servermodel_created():
await onjoin()
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data'
await config_module.on_join(fake_context)
# #
assert list(config_module.servermodel) == [1, 2] assert list(config_module.servermodel) == ['last_base']
servermodel = {'servermodeid': 3, await create_servermodel()
'servermodelname': 'name3'} assert list(config_module.servermodel) == ['last_base', 'last_sm1']
await dispatcher.publish('v1', assert not list(await config_module.servermodel['last_base'].config.parents())
'servermodel.created', assert len(list(await config_module.servermodel['last_sm1'].config.parents())) == 1
fake_context, await delete_session()
servermodel_id=3,
servermodel_description='name3',
release_id=1,
servermodel_name='name3')
assert list(config_module.servermodel) == [1, 2, 3]
assert not list(await config_module.servermodel[3].config.parents())
@pytest.mark.asyncio
async def test_servermodel_herited_created():
config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data'
await config_module.on_join(fake_context)
# #
assert list(config_module.servermodel) == [1, 2]
await dispatcher.publish('v1',
'servermodel.created',
fake_context,
servermodel_id=3,
servermodel_name='name3',
release_id=1,
servermodel_description='name3',
servermodel_parents_id=[1])
assert list(config_module.servermodel) == [1, 2, 3]
assert len(list(await config_module.servermodel[3].config.parents())) == 1
@pytest.mark.asyncio
async def test_servermodel_multi_herited_created():
config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data'
await config_module.on_join(fake_context)
# #
assert list(config_module.servermodel) == [1, 2]
await dispatcher.publish('v1',
'servermodel.created',
fake_context,
servermodel_id=3,
servermodel_name='name3',
release_id=1,
servermodel_description='name3',
servermodel_parents_id=[1, 2])
assert list(config_module.servermodel) == [1, 2, 3]
assert len(list(await config_module.servermodel[3].config.parents())) == 2
#@pytest.mark.asyncio #@pytest.mark.asyncio
#async def test_servermodel_updated_not_exists(): #async def test_servermodel_herited_created():
# config_module = dispatcher.get_service('config') # config_module = dispatcher.get_service('config')
# fake_context = get_fake_context('config') # fake_context = get_fake_context('config')
# config_module.cache_root_path = 'tests/data' # config_module.cache_root_path = 'tests/data'
@ -166,68 +214,6 @@ async def test_servermodel_multi_herited_created():
# # # #
# assert list(config_module.servermodel) == [1, 2] # assert list(config_module.servermodel) == [1, 2]
# await dispatcher.publish('v1', # await dispatcher.publish('v1',
# 'servermodel.updated',
# fake_context,
# servermodel_id=3,
# servermodel_name='name3',
# release_id=1,
# servermodel_description='name3',
# servermodel_parents_id=[1, 2])
# assert list(config_module.servermodel) == [1, 2, 3]
# assert len(list(await config_module.servermodel[3].config.parents())) == 2
#
#
# @pytest.mark.asyncio
# async def test_servermodel_updated1():
# config_module = dispatcher.get_service('config')
# fake_context = get_fake_context('config')
# config_module.cache_root_path = 'tests/data'
# await config_module.on_join(fake_context)
# #
# assert list(config_module.servermodel) == [1, 2]
# metaconfig1 = config_module.servermodel[1]
# metaconfig2 = config_module.servermodel[2]
# mixconfig1 = (await metaconfig1.config.list())[0]
# mixconfig2 = (await metaconfig2.config.list())[0]
# assert len(list(await metaconfig1.config.parents())) == 0
# assert len(list(await metaconfig2.config.parents())) == 1
# assert len(list(await mixconfig1.config.list())) == 1
# assert len(list(await mixconfig2.config.list())) == 0
# #
# await dispatcher.publish('v1',
# 'servermodel.updated',
# fake_context,
# servermodel_id=1,
# servermodel_name='name1-1',
# release_id=1,
# servermodel_description='name1-1')
# assert set(config_module.servermodel) == {1, 2}
# assert config_module.servermodel[1].information.get('servermodel_name') == 'name1-1'
# assert metaconfig1 != config_module.servermodel[1]
# assert metaconfig2 == config_module.servermodel[2]
# metaconfig1 = config_module.servermodel[1]
# assert mixconfig1 != next(metaconfig1.config.list())
# mixconfig1 = next(metaconfig1.config.list())
# #
# assert len(list(await metaconfig1.config.parents())) == 0
# assert len(list(await metaconfig2.config.parents())) == 1
# assert len(list(await mixconfig1.config.list())) == 1
# assert len(list(await mixconfig2.config.list())) == 0
#
#
# @pytest.mark.asyncio
# async def test_servermodel_updated2():
# config_module = dispatcher.get_service('config')
# fake_context = get_fake_context('config')
# config_module.cache_root_path = 'tests/data'
# await config_module.on_join(fake_context)
# # create a new servermodel
# assert list(config_module.servermodel) == [1, 2]
# mixconfig1 = next(config_module.servermodel[1].config.list())
# mixconfig2 = next(config_module.servermodel[2].config.list())
# assert len(list(mixconfig1.config.list())) == 1
# assert len(list(mixconfig2.config.list())) == 0
# await dispatcher.publish('v1',
# 'servermodel.created', # 'servermodel.created',
# fake_context, # fake_context,
# servermodel_id=3, # servermodel_id=3,
@ -237,102 +223,336 @@ async def test_servermodel_multi_herited_created():
# servermodel_parents_id=[1]) # servermodel_parents_id=[1])
# assert list(config_module.servermodel) == [1, 2, 3] # assert list(config_module.servermodel) == [1, 2, 3]
# assert len(list(await config_module.servermodel[3].config.parents())) == 1 # assert len(list(await config_module.servermodel[3].config.parents())) == 1
# assert await config_module.servermodel[3].information.get('servermodel_name') == 'name3' # await delete_session()
# assert len(list(await mixconfig1.config.list())) == 2
# assert len(list(await mixconfig2.config.list())) == 0
# #
# await dispatcher.publish('v1',
# 'servermodel.updated',
# fake_context,
# servermodel_id=3,
# servermodel_name='name3-1',
# release_id=1,
# servermodel_description='name3-1',
# servermodel_parents_id=[1, 2])
# assert list(config_module.servermodel) == [1, 2, 3]
# assert config_module.servermodel[3].information.get('servermodel_name') == 'name3-1'
# assert len(list(mixconfig1.config.list())) == 2
# assert len(list(mixconfig2.config.list())) == 1
# #
# #
#@pytest.mark.asyncio #@pytest.mark.asyncio
# async def test_servermodel_updated_config(): #async def test_servermodel_multi_herited_created():
# config_module = dispatcher.get_service('config') # config_module = dispatcher.get_service('config')
# fake_context = get_fake_context('config') # fake_context = get_fake_context('config')
# config_module.cache_root_path = 'tests/data' # config_module.cache_root_path = 'tests/data'
# await config_module.on_join(fake_context) # await config_module.on_join(fake_context)
# # # #
# config_module.servermodel[1].property.read_write() # assert list(config_module.servermodel) == [1, 2]
# assert config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.get() == 'non'
# config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.set('oui')
# assert config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.get() == 'oui'
# #
# await dispatcher.publish('v1', # await dispatcher.publish('v1',
# 'servermodel.updated', # 'servermodel.created',
# fake_context, # fake_context,
# servermodel_id=1, # servermodel_id=3,
# servermodel_name='name1-1', # servermodel_name='name3',
# release_id=1, # release_id=1,
# servermodel_description='name1-1') # servermodel_description='name3',
# assert config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.get() == 'oui' # servermodel_parents_id=[1, 2])
# assert list(config_module.servermodel) == [1, 2, 3]
# assert len(list(await config_module.servermodel[3].config.parents())) == 2
# await delete_session()
#
#
##@pytest.mark.asyncio
##async def test_servermodel_updated_not_exists():
## config_module = dispatcher.get_service('config')
## fake_context = get_fake_context('config')
## config_module.cache_root_path = 'tests/data'
## await config_module.on_join(fake_context)
## #
## assert list(config_module.servermodel) == [1, 2]
## await dispatcher.publish('v1',
## 'servermodel.updated',
## fake_context,
## servermodel_id=3,
## servermodel_name='name3',
## release_id=1,
## servermodel_description='name3',
## servermodel_parents_id=[1, 2])
## assert list(config_module.servermodel) == [1, 2, 3]
## assert len(list(await config_module.servermodel[3].config.parents())) == 2
## await delete_session()
##
##
## @pytest.mark.asyncio
## async def test_servermodel_updated1():
## config_module = dispatcher.get_service('config')
## fake_context = get_fake_context('config')
## config_module.cache_root_path = 'tests/data'
## await config_module.on_join(fake_context)
## #
## assert list(config_module.servermodel) == [1, 2]
## metaconfig1 = config_module.servermodel[1]
## metaconfig2 = config_module.servermodel[2]
## mixconfig1 = (await metaconfig1.config.list())[0]
## mixconfig2 = (await metaconfig2.config.list())[0]
## assert len(list(await metaconfig1.config.parents())) == 0
## assert len(list(await metaconfig2.config.parents())) == 1
## assert len(list(await mixconfig1.config.list())) == 1
## assert len(list(await mixconfig2.config.list())) == 0
## #
## await dispatcher.publish('v1',
## 'servermodel.updated',
## fake_context,
## servermodel_id=1,
## servermodel_name='name1-1',
## release_id=1,
## servermodel_description='name1-1')
## assert set(config_module.servermodel) == {1, 2}
## assert config_module.servermodel[1].information.get('servermodel_name') == 'name1-1'
## assert metaconfig1 != config_module.servermodel[1]
## assert metaconfig2 == config_module.servermodel[2]
## metaconfig1 = config_module.servermodel[1]
## assert mixconfig1 != next(metaconfig1.config.list())
## mixconfig1 = next(metaconfig1.config.list())
## #
## assert len(list(await metaconfig1.config.parents())) == 0
## assert len(list(await metaconfig2.config.parents())) == 1
## assert len(list(await mixconfig1.config.list())) == 1
## assert len(list(await mixconfig2.config.list())) == 0
## await delete_session()
##
##
## @pytest.mark.asyncio
## async def test_servermodel_updated2():
## config_module = dispatcher.get_service('config')
## fake_context = get_fake_context('config')
## config_module.cache_root_path = 'tests/data'
## await config_module.on_join(fake_context)
## # create a new servermodel
## assert list(config_module.servermodel) == [1, 2]
## mixconfig1 = next(config_module.servermodel[1].config.list())
## mixconfig2 = next(config_module.servermodel[2].config.list())
## assert len(list(mixconfig1.config.list())) == 1
## assert len(list(mixconfig2.config.list())) == 0
## await dispatcher.publish('v1',
## 'servermodel.created',
## fake_context,
## servermodel_id=3,
## servermodel_name='name3',
## release_id=1,
## servermodel_description='name3',
## servermodel_parents_id=[1])
## assert list(config_module.servermodel) == [1, 2, 3]
## assert len(list(await config_module.servermodel[3].config.parents())) == 1
## assert await config_module.servermodel[3].information.get('servermodel_name') == 'name3'
## assert len(list(await mixconfig1.config.list())) == 2
## assert len(list(await mixconfig2.config.list())) == 0
## #
## await dispatcher.publish('v1',
## 'servermodel.updated',
## fake_context,
## servermodel_id=3,
## servermodel_name='name3-1',
## release_id=1,
## servermodel_description='name3-1',
## servermodel_parents_id=[1, 2])
## assert list(config_module.servermodel) == [1, 2, 3]
## assert config_module.servermodel[3].information.get('servermodel_name') == 'name3-1'
## assert len(list(mixconfig1.config.list())) == 2
## assert len(list(mixconfig2.config.list())) == 1
## await delete_session()
##
##
## @pytest.mark.asyncio
## async def test_servermodel_updated_config():
## config_module = dispatcher.get_service('config')
## fake_context = get_fake_context('config')
## config_module.cache_root_path = 'tests/data'
## await config_module.on_join(fake_context)
## #
## config_module.servermodel[1].property.read_write()
## assert config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.get() == 'non'
## config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.set('oui')
## assert config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.get() == 'oui'
## #
## await dispatcher.publish('v1',
## 'servermodel.updated',
## fake_context,
## servermodel_id=1,
## servermodel_name='name1-1',
## release_id=1,
## servermodel_description='name1-1')
## assert config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.get() == 'oui'
## await delete_session()
##############################################################################################################################
# Server
##############################################################################################################################
@pytest.mark.asyncio
async def test_server_created_base():
await onjoin()
config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
#
assert list(config_module.server) == []
await dispatcher.on_join(truncate=True)
server_name = 'dns.test.lan'
await dispatcher.publish('v1',
'infra.server.created',
fake_context,
server_name=server_name,
server_description='description_created',
servermodel_name='base',
release_distribution='last',
site_name='site_1',
zones_name=['zones'],
zones_ip=['1.1.1.1'],
)
assert list(config_module.server) == [server_name]
assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'}
assert config_module.server[server_name]['funcs_file'] == '/var/cache/risotto/servermodel/last/base/funcs.py'
await delete_session()
@pytest.mark.asyncio
async def test_server_created_own_sm():
await onjoin()
config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
await create_servermodel()
#
assert list(config_module.server) == []
await dispatcher.on_join(truncate=True)
server_name = 'dns.test.lan'
await dispatcher.publish('v1',
'infra.server.created',
fake_context,
server_name=server_name,
server_description='description_created',
servermodel_name=SERVERMODEL_NAME,
source_name=SOURCE_NAME,
release_distribution='last',
site_name='site_1',
zones_name=['zones'],
zones_ip=['1.1.1.1'],
)
assert list(config_module.server) == [server_name]
assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'}
assert config_module.server[server_name]['funcs_file'] == '/var/cache/risotto/servermodel/last/sm1/funcs.py'
await delete_session()
#@pytest.mark.asyncio
#async def test_server_deleted():
# config_module = dispatcher.get_service('config')
# config_module.cache_root_path = 'tests/data'
# await config_module.on_join(fake_context)
# #
# assert list(config_module.server) == [3]
# await dispatcher.publish('v1',
# 'server.created',
# fake_context,
# server_id=4,
# server_name='name4',
# server_description='description4',
# server_servermodel_id=2)
# assert list(config_module.server) == [3, 4]
# await dispatcher.publish('v1',
# 'server.deleted',
# fake_context,
# server_id=4)
# assert list(config_module.server) == [3]
# await delete_session()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_configuration_get(): async def test_server_configuration_get():
await onjoin()
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' await create_servermodel()
await config_module.on_join(fake_context) await dispatcher.on_join(truncate=True)
# server_name = 'dns.test.lan'
await config_module.server[3]['server_to_deploy'].property.read_write() await dispatcher.publish('v1',
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'non' 'infra.server.created',
await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.set('oui')
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'oui'
assert await config_module.server[3]['server'].option('creole.general.mode_conteneur_actif').value.get() == 'non'
#
values = await dispatcher.call('v1',
'config.configuration.server.get',
fake_context, fake_context,
server_id=3) server_name=server_name,
configuration = {'configuration': server_description='description_created',
{'creole.general.mode_conteneur_actif': 'non', servermodel_name=SERVERMODEL_NAME,
'creole.general.master.master': [], source_name=SOURCE_NAME,
'creole.general.master.slave1': [], release_distribution='last',
'creole.general.master.slave2': [], site_name='site_1',
'containers.container0.files.file0.mkdir': False, zones_name=['zones'],
'containers.container0.files.file0.name': '/etc/mailname', zones_ip=['1.1.1.1'],
'containers.container0.files.file0.rm': False, )
'containers.container0.files.file0.source': 'mailname', #
'containers.container0.files.file0.activate': True}, await config_module.server[server_name]['server'].property.read_write()
'server_id': 3, assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 1
'deployed': True} await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.set(2)
assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 2
assert await config_module.server[server_name]['server_to_deploy'].option('configuration.general.number_of_interfaces').value.get() == 1
#
configuration = {'server_name': server_name,
'deployed': False,
'configuration': {'configuration.general.number_of_interfaces': 1,
'configuration.general.interfaces_list': [0],
'configuration.interface_0.domain_name_eth0': 'dns.test.lan'
}
}
values = await dispatcher.call('v1',
'setting.config.configuration.server.get',
fake_context,
server_name=server_name,
deployed=False,
)
assert values == configuration assert values == configuration
# #
values = await dispatcher.call('v1', await delete_session()
'config.configuration.server.get',
fake_context,
server_id=3,
deployed=False)
configuration['configuration']['creole.general.mode_conteneur_actif'] = 'oui'
configuration['deployed'] = False
assert values == configuration
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_config_deployed(): async def test_server_configuration_deployed():
await onjoin()
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' await create_servermodel()
await config_module.on_join(fake_context) await dispatcher.on_join(truncate=True)
# server_name = 'dns.test.lan'
await config_module.server[3]['server_to_deploy'].property.read_write() await dispatcher.publish('v1',
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'non' 'infra.server.created',
await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.set('oui')
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'oui'
assert await config_module.server[3]['server'].option('creole.general.mode_conteneur_actif').value.get() == 'non'
values = await dispatcher.publish('v1',
'config.configuration.server.deploy',
fake_context, fake_context,
server_id=3) server_name=server_name,
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' server_description='description_created',
assert await config_module.server[3]['server'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' servermodel_name=SERVERMODEL_NAME,
source_name=SOURCE_NAME,
release_distribution='last',
site_name='site_1',
zones_name=['zones'],
zones_ip=['1.1.1.1'],
)
#
await config_module.server[server_name]['server'].property.read_write()
assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 1
await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.set(2)
assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 2
assert await config_module.server[server_name]['server_to_deploy'].option('configuration.general.number_of_interfaces').value.get() == 1
#
configuration = {'server_name': server_name,
'deployed': False,
'configuration': {'configuration.general.number_of_interfaces': 1,
'configuration.general.interfaces_list': [0],
'configuration.interface_0.domain_name_eth0': 'dns.test.lan'
}
}
try:
await dispatcher.call('v1',
'setting.config.configuration.server.get',
fake_context,
server_name=server_name,
)
except:
pass
else:
raise Exception('should raise propertyerror')
values = await dispatcher.call('v1',
'setting.config.configuration.server.deploy',
fake_context,
server_name=server_name,
)
assert values == {'server_name': 'dns.test.lan', 'deployed': True}
await dispatcher.call('v1',
'setting.config.configuration.server.get',
fake_context,
server_name=server_name,
)
#
await delete_session()

View File

@ -2,7 +2,7 @@ from importlib import import_module
import pytest import pytest
from .storage import STORAGE from .storage import STORAGE
from risotto.context import Context from risotto.context import Context
from risotto.services import load_services #from risotto.services import load_services
from risotto.dispatcher import dispatcher from risotto.dispatcher import dispatcher
from risotto.services.session.storage import storage_server, storage_servermodel from risotto.services.session.storage import storage_server, storage_servermodel
@ -16,9 +16,9 @@ def get_fake_context(module_name):
def setup_module(module): def setup_module(module):
load_services(['config', 'session'], #load_services(['config', 'session'],
validate=False, # validate=False,
test=True) # test=True)
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
config_module.save_storage = STORAGE config_module.save_storage = STORAGE
dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True) dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True)