Compare commits

..

48 Commits

Author SHA1 Message Date
8e0fe77274 test if a source is loaded 2021-09-13 14:58:11 +02:00
8dca850683 build image is now in risotto 2021-08-28 07:34:31 +02:00
1def6e4e4c add preprocessors function 2021-05-26 20:19:08 +02:00
2ea04e708d fix 2021-05-24 22:24:15 +02:00
4853bb47f0 risotto is now a lib 2021-05-24 20:41:04 +02:00
9d4644bedf fix 2021-05-24 16:53:50 +02:00
c0244eac8c add image file 2021-05-24 16:22:39 +02:00
b6c5dccf17 staticmethod function has no self 2021-05-23 21:39:09 +02:00
ef43b197a1 better lemur integration 2021-05-22 16:37:01 +02:00
94b6563d8f import/export informations 2021-05-18 18:55:33 +02:00
88c2c168ac add v1.user.log.query message 2021-05-12 18:36:59 +02:00
09cd0a4e4c variable to personalise password length 2021-05-11 18:58:43 +02:00
3085bf67d6 variable to personalise password length 2021-05-11 18:28:57 +02:00
1063d2e735 on connection to database to log only 2021-04-25 20:32:02 +02:00
ed51bc483d corrections in log 2021-04-24 17:11:06 +02:00
0442e772c2 support str value 2021-04-24 14:35:51 +02:00
27031dbf0e log_connexion 2021-04-24 14:15:54 +02:00
9ebe79d533 special connexion for log (do not rollback if error) 2021-04-24 12:56:44 +02:00
4c83e6d89d better log support 2021-04-24 10:12:39 +02:00
19240489db add http static support 2021-04-24 10:12:32 +02:00
30a267bf4a add TiramisuController 2021-04-24 10:12:13 +02:00
f88bcef5c0 do not stop daemon when on_join failed 2021-04-16 09:33:25 +02:00
5663b2768b if not Risotto module, do not failed 2021-04-13 10:31:14 +02:00
01834c6ba7 add check_role to dispatcher 2021-04-12 15:11:46 +02:00
8fdc34c4d3 fix 2021-03-27 10:59:10 +01:00
f623feb8a8 add systemd notifier 2020-11-14 19:01:28 +01:00
46f8a4323b add pki informations 2020-11-14 08:12:39 +01:00
6c4bbb3dca add password support 2020-10-14 18:30:05 +02:00
279e3a7c4c better debugging 2020-09-20 21:33:04 +02:00
13c7d5816c update config 2020-09-19 10:33:27 +02:00
a89e512266 update config 2020-09-19 09:18:28 +02:00
7afccab9b1 publish use now postgresql 2020-09-16 17:37:46 +02:00
c84b9435b0 better debug 2020-09-16 08:03:30 +02:00
e664dd6174 add remote support 2020-09-12 16:05:17 +02:00
3823eedd02 update sql file 2020-09-06 09:46:11 +02:00
a12e679b3c Changement de nom de table SQL 2020-09-05 16:41:30 +02:00
32b83f496b simplifier l'API 2020-09-02 09:17:09 +02:00
dc7d081608 better error message 2020-08-26 15:31:54 +02:00
ca101cf094 update tests 2020-08-26 10:56:34 +02:00
c309ebbd56 update tests 2020-08-19 17:15:04 +02:00
a64131fb03 change for tests 2020-08-19 11:36:18 +02:00
e787eb2ef5 explicite error when value in param is invalid 2020-08-19 11:20:46 +02:00
10637656fa explicite error when value in param is invalid 2020-08-13 09:13:00 +02:00
d3a5c99e51 typos 2020-08-12 15:09:31 +02:00
c3d25b4aff separate risotto and python3-risotto package 2020-08-12 10:45:15 +02:00
46ea792c5e import tiramisu3 first 2020-08-12 08:30:03 +02:00
5708cb1ea9 tmp => /tmp 2020-08-05 17:07:44 +02:00
664a2404fa simplify publish function 2020-04-23 07:39:22 +02:00
19 changed files with 2177 additions and 751 deletions

View File

@ -36,12 +36,19 @@ Chacun de ces services documente la structure de la table mais ne se charge pas
La création de la table, selon le schéma fournit dans la documentation, est à la charge de ladministrateur du système. La création de la table, selon le schéma fournit dans la documentation, est à la charge de ladministrateur du système.
# Empty database: # Empty database:
su - postgres
psql -U postgres risotto
drop table log; drop table userrole; drop table release; drop table source; drop table server; drop table servermodel; drop table applicationservice; drop table roleuri; drop table risottouser; drop table uri;
````
psql -U postgres
drop database risotto;
drop user risotto;
\q
reconfigure
```
```
psql -U postgres tiramisu psql -U postgres tiramisu
drop table value; drop table property; drop table permissive; drop table information; drop table session; drop table value; drop table property; drop table permissive; drop table information; drop table session;
```
# Import EOLE # Import EOLE
./script/cucchiaiata source.create -n eole -u http://localhost ./script/cucchiaiata source.create -n eole -u http://localhost

View File

@ -1,13 +1,16 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from sdnotify import SystemdNotifier
from asyncio import get_event_loop from asyncio import get_event_loop
from risotto import get_app from risotto import get_app
if __name__ == '__main__': if __name__ == '__main__':
notifier = SystemdNotifier()
loop = get_event_loop() loop = get_event_loop()
loop.run_until_complete(get_app(loop)) loop.run_until_complete(get_app(loop))
print('HTTP server ready')
notifier.notify("READY=1")
try: try:
print('HTTP server ready')
loop.run_forever() loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass

View File

@ -4,6 +4,5 @@ setup(
name='risotto', name='risotto',
version='0.1', version='0.1',
packages=['risotto' ], packages=['risotto' ],
scripts=['script/risotto-server'],
package_dir={"": "src"}, package_dir={"": "src"},
) )

16
sql/risotto.sql Normal file
View File

@ -0,0 +1,16 @@
CREATE TABLE RisottoLog(
LogId SERIAL PRIMARY KEY,
ContextId INTEGER,
Msg VARCHAR(255) NOT NULL,
URI VARCHAR(255),
URIS VARCHAR(255),
UserLogin VARCHAR(100) NOT NULL,
Status INTEGER NOT NULL,
Kwargs JSON,
Returns JSON,
StartDate timestamp DEFAULT current_timestamp,
StopDate timestamp
);
CREATE INDEX RisottoLog_ContextId_index ON RisottoLog(ContextId);
CREATE INDEX RisottoLog_Login_index ON RisottoLog(UserLogin);
CREATE INDEX RisottoLog_URI_index ON RisottoLog(URI);

View File

@ -1,48 +1,194 @@
from os import environ from os import environ
from os.path import isfile
from configobj import ConfigObj
from uuid import uuid4
CONFIGURATION_DIR = environ.get('CONFIGURATION_DIR', '/srv/risotto/configurations') CONFIG_FILE = environ.get('CONFIG_FILE', '/etc/risotto/risotto.conf')
PROVIDER_FACTORY_CONFIG_DIR = environ.get('PROVIDER_FACTORY_CONFIG_DIR', '/srv/factory')
TMP_DIR = 'tmp'
DEFAULT_USER = environ.get('DEFAULT_USER', 'Anonymous') if isfile(CONFIG_FILE):
RISOTTO_DB_NAME = environ.get('RISOTTO_DB_NAME', 'risotto') config = ConfigObj(CONFIG_FILE)
RISOTTO_DB_PASSWORD = environ.get('RISOTTO_DB_PASSWORD', 'risotto') else:
RISOTTO_DB_USER = environ.get('RISOTTO_DB_USER', 'risotto') config = {}
TIRAMISU_DB_NAME = environ.get('TIRAMISU_DB_NAME', 'tiramisu')
TIRAMISU_DB_PASSWORD = environ.get('TIRAMISU_DB_PASSWORD', 'tiramisu')
TIRAMISU_DB_USER = environ.get('TIRAMISU_DB_USER', 'tiramisu') if 'RISOTTO_PORT' in environ:
DB_ADDRESS = environ.get('DB_ADDRESS', 'localhost') RISOTTO_PORT = environ['RISOTTO_PORT']
MESSAGE_PATH = environ.get('MESSAGE_PATH', '/root/risotto-message/messages') else:
SQL_DIR = environ.get('SQL_DIR', './sql') RISOTTO_PORT = config.get('RISOTTO_PORT', 8080)
CACHE_ROOT_PATH = environ.get('CACHE_ROOT_PATH', '/var/cache/risotto') if 'RISOTTO_URL' in environ:
SRV_SEED_PATH = environ.get('SRV_SEED_PATH', '/srv/seed') RISOTTO_URL = environ['RISOTTO_URL']
else:
RISOTTO_URL = config.get('RISOTTO_URL', 'http://localhost:8080/')
if 'CONFIGURATION_DIR' in environ:
CONFIGURATION_DIR = environ['CONFIGURATION_DIR']
else:
CONFIGURATION_DIR = config.get('CONFIGURATION_DIR', '/srv/risotto/configurations')
if 'DEFAULT_USER' in environ:
DEFAULT_USER = environ['DEFAULT_USER']
else:
DEFAULT_USER = config.get('DEFAULT_USER', 'Anonymous')
if 'RISOTTO_DB_NAME' in environ:
RISOTTO_DB_NAME = environ['RISOTTO_DB_NAME']
else:
RISOTTO_DB_NAME = config.get('RISOTTO_DB_NAME', 'risotto')
if 'RISOTTO_DB_PASSWORD' in environ:
RISOTTO_DB_PASSWORD = environ['RISOTTO_DB_PASSWORD']
else:
RISOTTO_DB_PASSWORD = config.get('RISOTTO_DB_PASSWORD', 'risotto')
if 'RISOTTO_DB_USER' in environ:
RISOTTO_DB_USER = environ['RISOTTO_DB_USER']
else:
RISOTTO_DB_USER = config.get('RISOTTO_DB_USER', 'risotto')
if 'TIRAMISU_DB_NAME' in environ:
TIRAMISU_DB_NAME = environ['TIRAMISU_DB_NAME']
else:
TIRAMISU_DB_NAME = config.get('TIRAMISU_DB_NAME', 'tiramisu')
if 'TIRAMISU_DB_PASSWORD' in environ:
TIRAMISU_DB_PASSWORD = environ['TIRAMISU_DB_PASSWORD']
else:
TIRAMISU_DB_PASSWORD = config.get('TIRAMISU_DB_PASSWORD', 'tiramisu')
if 'TIRAMISU_DB_USER' in environ:
TIRAMISU_DB_USER = environ['TIRAMISU_DB_USER']
else:
TIRAMISU_DB_USER = config.get('TIRAMISU_DB_USER', 'tiramisu')
if 'CELERYRISOTTO_DB_NAME' in environ:
CELERYRISOTTO_DB_NAME = environ['CELERYRISOTTO_DB_NAME']
else:
CELERYRISOTTO_DB_NAME = config.get('CELERYRISOTTO_DB_NAME', None)
if 'CELERYRISOTTO_DB_PASSWORD' in environ:
CELERYRISOTTO_DB_PASSWORD = environ['CELERYRISOTTO_DB_PASSWORD']
else:
CELERYRISOTTO_DB_PASSWORD = config.get('CELERYRISOTTO_DB_PASSWORD', None)
if 'CELERYRISOTTO_DB_USER' in environ:
CELERYRISOTTO_DB_USER = environ['CELERYRISOTTO_DB_USER']
else:
CELERYRISOTTO_DB_USER = config.get('CELERYRISOTTO_DB_USER', None)
if 'LEMUR_DB_NAME' in environ:
LEMUR_DB_NAME = environ['LEMUR_DB_NAME']
else:
LEMUR_DB_NAME = config.get('LEMUR_DB_NAME', None)
if 'LEMUR_DB_PASSWORD' in environ:
LEMUR_DB_PASSWORD = environ['LEMUR_DB_PASSWORD']
else:
LEMUR_DB_PASSWORD = config.get('LEMUR_DB_PASSWORD', None)
if 'LEMUR_DB_USER' in environ:
LEMUR_DB_USER = environ['LEMUR_DB_USER']
else:
LEMUR_DB_USER = config.get('LEMUR_DB_USER', None)
if 'DB_ADDRESS' in environ:
DB_ADDRESS = environ['DB_ADDRESS']
else:
DB_ADDRESS = config.get('DB_ADDRESS', 'localhost')
if 'MESSAGE_PATH' in environ:
MESSAGE_PATH = environ['MESSAGE_PATH']
else:
MESSAGE_PATH = config.get('MESSAGE_PATH', '/root/risotto-message/messages')
if 'SQL_DIR' in environ:
SQL_DIR = environ['SQL_DIR']
else:
SQL_DIR = config.get('SQL_DIR', './sql')
if 'CACHE_ROOT_PATH' in environ:
CACHE_ROOT_PATH = environ['CACHE_ROOT_PATH']
else:
CACHE_ROOT_PATH = config.get('CACHE_ROOT_PATH', '/var/cache/risotto')
if 'SRV_SEED_PATH' in environ:
SRV_SEED_PATH = environ['SRV_SEED_PATH']
else:
SRV_SEED_PATH = config.get('SRV_SEED_PATH', '/srv/seed')
if 'TMP_DIR' in environ:
TMP_DIR = environ['TMP_DIR']
else:
TMP_DIR = config.get('TMP_DIR', '/tmp')
if 'IMAGE_PATH' in environ:
IMAGE_PATH = environ['IMAGE_PATH']
else:
IMAGE_PATH = config.get('IMAGE_PATH', '/tmp')
if 'PASSWORD_ADMIN_USERNAME' in environ:
PASSWORD_ADMIN_USERNAME = environ['PASSWORD_ADMIN_USERNAME']
else:
PASSWORD_ADMIN_USERNAME = config.get('PASSWORD_ADMIN_USERNAME', 'risotto')
if 'PASSWORD_ADMIN_EMAIL' in environ:
PASSWORD_ADMIN_EMAIL = environ['PASSWORD_ADMIN_EMAIL']
else:
# this parameter is mandatory
PASSWORD_ADMIN_EMAIL = config.get('PASSWORD_ADMIN_EMAIL', 'XXX')
if 'PASSWORD_ADMIN_PASSWORD' in environ:
PASSWORD_ADMIN_PASSWORD = environ['PASSWORD_ADMIN_PASSWORD']
else:
# this parameter is mandatory
PASSWORD_ADMIN_PASSWORD = config.get('PASSWORD_ADMIN_PASSWORD', 'XXX')
if 'PASSWORD_DEVICE_IDENTIFIER' in environ:
PASSWORD_DEVICE_IDENTIFIER = environ['PASSWORD_DEVICE_IDENTIFIER']
else:
PASSWORD_DEVICE_IDENTIFIER = config.get('PASSWORD_DEVICE_IDENTIFIER', uuid4())
if 'PASSWORD_URL' in environ:
PASSWORD_URL = environ['PASSWORD_URL']
else:
PASSWORD_URL = config.get('PASSWORD_URL', 'https://localhost:8001/')
if 'PASSWORD_LENGTH' in environ:
PASSWORD_LENGTH = int(environ['PASSWORD_LENGTH'])
else:
PASSWORD_LENGTH = int(config.get('PASSWORD_LENGTH', 20))
if 'PKI_ADMIN_PASSWORD' in environ:
PKI_ADMIN_PASSWORD = environ['PKI_ADMIN_PASSWORD']
else:
PKI_ADMIN_PASSWORD = config.get('PKI_ADMIN_PASSWORD', 'XXX')
if 'PKI_ADMIN_EMAIL' in environ:
PKI_ADMIN_EMAIL = environ['PKI_ADMIN_EMAIL']
else:
PKI_ADMIN_EMAIL = config.get('PKI_ADMIN_EMAIL', 'XXX')
if 'PKI_URL' in environ:
PKI_URL = environ['PKI_URL']
else:
PKI_URL = config.get('PKI_URL', 'http://localhost:8002')
def dsn_factory(database, user, password, address=DB_ADDRESS): def dsn_factory(database, user, password, address=DB_ADDRESS):
mangled_address = '/var/run/postgresql' if address == 'localhost' else address mangled_address = '/var/run/postgresql' if address == 'localhost' else address
return f'postgres:///{database}?host={mangled_address}/&user={user}&password={password}' return f'postgres:///{database}?host={mangled_address}/&user={user}&password={password}'
_config = {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD),
'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD),
'celery_dsn': dsn_factory(CELERYRISOTTO_DB_NAME, CELERYRISOTTO_DB_USER, CELERYRISOTTO_DB_PASSWORD),
'lemur_dns': dsn_factory(LEMUR_DB_NAME, LEMUR_DB_USER, LEMUR_DB_PASSWORD),
},
'http_server': {'port': RISOTTO_PORT,
'default_user': DEFAULT_USER,
'url': RISOTTO_URL},
'global': {'message_root_path': MESSAGE_PATH,
'configurations_dir': CONFIGURATION_DIR,
'debug': True,
'internal_user': '_internal',
'check_role': True,
'admin_user': DEFAULT_USER,
'sql_dir': SQL_DIR,
'tmp_dir': TMP_DIR,
},
'password': {'admin_username': PASSWORD_ADMIN_USERNAME,
'admin_email': PASSWORD_ADMIN_EMAIL,
'admin_password': PASSWORD_ADMIN_PASSWORD,
'device_identifier': PASSWORD_DEVICE_IDENTIFIER,
'service_url': PASSWORD_URL,
'length': PASSWORD_LENGTH,
},
'pki': {'admin_password': PKI_ADMIN_PASSWORD,
'owner': PKI_ADMIN_EMAIL,
'url': PKI_URL,
},
'cache': {'root_path': CACHE_ROOT_PATH},
'servermodel': {'internal_source_path': SRV_SEED_PATH,
'internal_source': 'internal'},
'submodule': {'allow_insecure_https': False,
'pki': '192.168.56.112'},
'provider': {'factory_configuration_filename': 'infra.json',
'packer_filename': 'recipe.json',
'risotto_images_dir': IMAGE_PATH},
}
def get_config(): def get_config():
return {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD), return _config
'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD),
},
'http_server': {'port': 8080,
'default_user': DEFAULT_USER},
'global': {'message_root_path': MESSAGE_PATH,
'configurations_dir': CONFIGURATION_DIR,
'debug': True,
'internal_user': 'internal',
'check_role': True,
'admin_user': DEFAULT_USER,
'sql_dir': SQL_DIR},
'source': {'root_path': SRV_SEED_PATH},
'cache': {'root_path': CACHE_ROOT_PATH},
'servermodel': {'internal_source': 'internal',
'internal_distribution': 'last',
'internal_release_name': 'none'},
'submodule': {'allow_insecure_https': False,
'pki': '192.168.56.112'},
'provider': {'factory_configuration_dir': PROVIDER_FACTORY_CONFIG_DIR,
'factory_configuration_filename': 'infra.json'},
}

View File

@ -1,3 +1,13 @@
class Context: class Context:
def __init__(self): def __init__(self):
self.paths = [] self.paths = []
self.context_id = None
self.start_id = None
def copy(self):
context = Context()
for key, value in self.__dict__.items():
if key.startswith('__'):
continue
setattr(context, key, value)
return context

View File

@ -1,59 +1,342 @@
from os import listdir, makedirs
from os.path import join, isdir, isfile
from shutil import rmtree
from traceback import print_exc
from typing import Dict
from rougail import RougailConvert, RougailConfig, RougailUpgrade
try:
from tiramisu3 import Storage, Config
except:
from tiramisu import Storage, Config
from .config import get_config from .config import get_config
from .dispatcher import dispatcher from .utils import _, tiramisu_display_name
from .logger import log
from .dispatcher import get_dispatcher
from .context import Context from .context import Context
from .remote import remote
from . import services
from .utils import _ RougailConfig['variable_namespace'] = 'configuration'
class Controller: class Controller:
"""Common controller used to add a service in Risotto """Common controller used to add a service in Risotto
""" """
def __init__(self, def __init__(self,
test: bool): test: bool,
self.risotto_modules = services.get_services_list() ) -> None:
self.dispatcher = get_dispatcher()
async def call(self, async def call(self,
uri: str, uri: str,
risotto_context: Context, risotto_context: Context,
*args, *args,
**kwargs): **kwargs,
):
""" a wrapper to dispatcher's call""" """ a wrapper to dispatcher's call"""
version, module, message = uri.split('.', 2)
uri = module + '.' + message
if args: if args:
raise ValueError(_(f'the URI "{uri}" can only be called with keyword arguments')) raise ValueError(_(f'the URI "{uri}" can only be called with keyword arguments'))
if module not in self.risotto_modules: current_uri = risotto_context.paths[-1]
return await remote.remove_call(module, current_module = risotto_context.module
version, version, message = uri.split('.', 1)
message, module = message.split('.', 1)[0]
kwargs) if current_module != module:
return await dispatcher.call(version, raise ValueError(_(f'cannot call to external module ("{module}") to the URI "{uri}" from "{current_module}"'))
uri, return await self.dispatcher.call(version,
risotto_context, message,
**kwargs) risotto_context,
**kwargs,
)
async def publish(self, async def publish(self,
uri: str, uri: str,
risotto_context: Context, risotto_context: Context,
*args, *args,
**kwargs): **kwargs,
):
""" a wrapper to dispatcher's publish""" """ a wrapper to dispatcher's publish"""
version, module, submessage = uri.split('.', 2)
version, message = uri.split('.', 1)
if args: if args:
raise ValueError(_(f'the URI "{uri}" can only be published with keyword arguments')) raise ValueError(_(f'the URI "{uri}" can only be published with keyword arguments'))
if module not in self.risotto_modules: version, message = uri.split('.', 1)
await remote.remove_call(module, await self.dispatcher.publish(version,
version, message,
submessage, risotto_context,
kwargs) **kwargs,
else: )
await dispatcher.publish(version,
message, async def check_role(self,
risotto_context, uri: str,
**kwargs) username: str,
**kwargs: dict,
) -> None:
# create a new config
async with await Config(self.dispatcher.option) as config:
await config.property.read_write()
await config.option('message').value.set(uri)
subconfig = config.option(uri)
for key, value in kwargs.items():
try:
await subconfig.option(key).value.set(value)
except AttributeError:
if get_config()['global']['debug']:
print_exc()
raise ValueError(_(f'unknown parameter in "{uri}": "{key}"'))
except ValueOptionError as err:
raise ValueError(_(f'invalid parameter in "{uri}": {err}'))
await self.dispatcher.check_role(subconfig,
username,
uri,
)
async def on_join(self, async def on_join(self,
risotto_context): risotto_context,
):
pass pass
class TiramisuController(Controller):
def __init__(self,
test: bool,
) -> None:
self.source_imported = None
if not 'dataset_name' in vars(self):
raise Exception(f'please specify "dataset_name" to "{self.__class__.__name__}"')
self.tiramisu_cache_root_path = join(get_config()['cache']['root_path'], self.dataset_name)
super().__init__(test)
self.internal_source_name = get_config()['servermodel']['internal_source']
if not test:
db_conf = get_config()['database']['tiramisu_dsn']
self.save_storage = Storage(engine='postgres')
self.save_storage.setting(dsn=db_conf)
if self.dataset_name != 'servermodel':
self.optiondescription = None
self.dispatcher.set_function('v1.setting.dataset.updated',
None,
TiramisuController.dataset_updated,
self.__class__.__module__,
)
async def on_join(self,
risotto_context: Context,
) -> None:
if isdir(self.tiramisu_cache_root_path):
await self.load_datas(risotto_context)
async def dataset_updated(self,
risotto_context: Context,
) -> Dict:
await self.gen_dictionaries(risotto_context)
await self.load_datas(risotto_context)
async def gen_dictionaries(self,
risotto_context: Context,
) -> None:
sources = await self.get_sources(risotto_context)
source_imported = sources != [self.internal_source_name]
if source_imported and self.source_imported is False:
await self.load_datas(risotto_context)
self.source_imported = source_imported
if not self.source_imported:
return
self._aggregate_tiramisu_funcs(sources)
self._convert_dictionaries_to_tiramisu(sources)
async def get_sources(self,
risotto_context: Context,
) -> None:
return await self.call('v1.setting.source.list',
risotto_context,
)
def _aggregate_tiramisu_funcs(self,
sources: list,
) -> None:
dest_file = join(self.tiramisu_cache_root_path, 'funcs.py')
if not isdir(self.tiramisu_cache_root_path):
makedirs(self.tiramisu_cache_root_path)
with open(dest_file, 'wb') as funcs:
funcs.write(b"""try:
from tiramisu3 import valid_network_netmask, valid_ip_netmask, valid_broadcast, valid_in_network, valid_not_equal as valid_differ, valid_not_equal, calc_value
except:
from tiramisu import valid_network_netmask, valid_ip_netmask, valid_broadcast, valid_in_network, valid_not_equal as valid_differ, valid_not_equal, calc_value
""")
for source in sources:
root_path = join(source['source_directory'],
self.dataset_name,
)
if not isdir(root_path):
continue
for service in listdir(root_path):
path = join(root_path,
service,
'funcs',
)
if not isdir(path):
continue
for filename in listdir(path):
if not filename.endswith('.py'):
continue
filename_path = join(path, filename)
with open(filename_path, 'rb') as fh:
funcs.write(f'# {filename_path}\n'.encode())
funcs.write(fh.read())
funcs.write(b'\n')
def _convert_dictionaries_to_tiramisu(self,
sources: list,
) -> None:
funcs_file = join(self.tiramisu_cache_root_path, 'funcs.py')
tiramisu_file = join(self.tiramisu_cache_root_path, 'tiramisu.py')
dictionaries_dir = join(self.tiramisu_cache_root_path, 'dictionaries')
extras_dictionaries_dir = join(self.tiramisu_cache_root_path, 'extra_dictionaries')
if isdir(dictionaries_dir):
rmtree(dictionaries_dir)
makedirs(dictionaries_dir)
if isdir(extras_dictionaries_dir):
rmtree(extras_dictionaries_dir)
makedirs(extras_dictionaries_dir)
extras = []
upgrade = RougailUpgrade()
for source in sources:
root_path = join(source['source_directory'],
self.dataset_name,
)
if not isdir(root_path):
continue
for service in listdir(root_path):
# upgrade dictionaries
path = join(root_path,
service,
'dictionaries',
)
if not isdir(path):
continue
upgrade.load_xml_from_folders(path,
dictionaries_dir,
RougailConfig['variable_namespace'],
)
for service in listdir(root_path):
# upgrade extra dictionaries
path = join(root_path,
service,
'extras',
)
if not isdir(path):
continue
for namespace in listdir(path):
extra_dir = join(path,
namespace,
)
if not isdir(extra_dir):
continue
extra_dictionaries_dir = join(extras_dictionaries_dir,
namespace,
)
if not isdir(extra_dictionaries_dir):
makedirs(extra_dictionaries_dir)
extras.append((namespace, [extra_dictionaries_dir]))
upgrade.load_xml_from_folders(extra_dir,
extra_dictionaries_dir,
namespace,
)
del upgrade
config = RougailConfig.copy()
config['functions_file'] = funcs_file
config['dictionaries_dir'] = [dictionaries_dir]
config['extra_dictionaries'] = {}
for extra in extras:
config['extra_dictionaries'][extra[0]] = extra[1]
eolobj = RougailConvert(rougailconfig=config)
eolobj.save(tiramisu_file)
async def load(self,
risotto_context: Context,
name: str,
to_deploy: bool=False,
) -> Config:
if self.optiondescription is None:
# use file in cache
tiramisu_file = join(self.tiramisu_cache_root_path, 'tiramisu.py')
if not isfile(tiramisu_file):
raise Exception(_(f'unable to load the "{self.dataset_name}" configuration, is dataset loaded?'))
with open(tiramisu_file) as fileio:
tiramisu_locals = {}
try:
exec(fileio.read(), None, tiramisu_locals)
except Exception as err:
raise Exception(_(f'unable to load tiramisu file {tiramisu_file}: {err}'))
self.optiondescription = tiramisu_locals['option_0']
del tiramisu_locals
try:
letter = self.dataset_name[0]
if not to_deploy:
session_id = f'{letter}_{name}'
else:
session_id = f'{letter}td_{name}'
config = await Config(self.optiondescription,
session_id=session_id,
storage=self.save_storage,
display_name=tiramisu_display_name,
)
# change default rights
await config.property.read_only()
await config.permissive.add('basic')
await config.permissive.add('normal')
await config.permissive.add('expert')
# set information and owner
await config.owner.set(session_id)
await config.information.set(f'{self.dataset_name}_name', name)
except Exception as err:
if get_config()['global']['debug']:
print_exc()
msg = _(f'unable to load config for {self.dataset_name} "{name}": {err}')
await log.error_msg(risotto_context,
None,
msg,
)
return config
async def _deploy_configuration(self,
dico: dict,
) -> None:
config_std = dico['config_to_deploy']
config = dico['config']
# when deploy, calculate force_store_value
ro = await config_std.property.getdefault('read_only', 'append')
if 'force_store_value' not in ro:
await config_std.property.read_write()
if self.dataset_name == 'servermodel':
# server_deployed should be hidden
await config_std.forcepermissive.option('configuration.general.server_deployed').value.set(True)
ro = frozenset(list(ro) + ['force_store_value'])
rw = await config_std.property.getdefault('read_write', 'append')
rw = frozenset(list(rw) + ['force_store_value'])
await config_std.property.setdefault(ro, 'read_only', 'append')
await config_std.property.setdefault(rw, 'read_write', 'append')
await config_std.property.read_only()
# copy informations from 'to deploy' configuration to configuration
await config.information.importation(await config_std.information.exportation())
await config.value.importation(await config_std.value.exportation())
await config.permissive.importation(await config_std.permissive.exportation())
await config.property.importation(await config_std.property.exportation())
async def build_configuration(self,
config: Config,
) -> dict:
configuration = {}
for option in await config.option.list('optiondescription'):
name = await option.option.name()
if name == 'services':
continue
if name == RougailConfig['variable_namespace']:
fullpath = False
flatten = True
else:
fullpath = True
flatten = False
configuration.update(await option.value.dict(leader_to_list=True, fullpath=fullpath, flatten=flatten))
return configuration

View File

@ -1,4 +1,10 @@
from tiramisu import Config try:
from tiramisu3 import Config
from tiramisu3.error import ValueOptionError
except:
from tiramisu import Config
from tiramisu.error import ValueOptionError
from asyncio import get_event_loop, ensure_future
from traceback import print_exc from traceback import print_exc
from copy import copy from copy import copy
from typing import Dict, Callable, List, Optional from typing import Dict, Callable, List, Optional
@ -10,8 +16,9 @@ from .logger import log
from .config import get_config from .config import get_config
from .context import Context from .context import Context
from . import register from . import register
from .remote import Remote
import asyncpg
DISPATCHER = None
class CallDispatcher: class CallDispatcher:
@ -26,68 +33,101 @@ class CallDispatcher:
if response.impl_get_information('multi'): if response.impl_get_information('multi'):
if not isinstance(returns, list): if not isinstance(returns, list):
err = _(f'function {module_name}.{function_name} has to return a list') err = _(f'function {module_name}.{function_name} has to return a list')
await log.error_msg(risotto_context, kwargs, err) raise CallError(err)
raise CallError(str(err))
else: else:
if not isinstance(returns, dict): if not isinstance(returns, dict):
await log.error_msg(risotto_context, kwargs, returns)
err = _(f'function {module_name}.{function_name} has to return a dict') err = _(f'function {module_name}.{function_name} has to return a dict')
await log.error_msg(risotto_context, kwargs, err) raise CallError(err)
raise CallError(str(err))
returns = [returns] returns = [returns]
if response is None: if response is None:
raise Exception('hu?') raise Exception('hu?')
else: else:
for ret in returns: for ret in returns:
async with await Config(response, display_name=lambda self, dyn_name: self.impl_getname()) as config: async with await Config(response, display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config:
await config.property.read_write() await config.property.read_write()
key = None
try: try:
for key, value in ret.items(): for key, value in ret.items():
await config.option(key).value.set(value) await config.option(key).value.set(value)
except AttributeError: except AttributeError as err:
err = _(f'function {module_name}.{function_name} return the unknown parameter "{key}"') if key is not None:
await log.error_msg(risotto_context, kwargs, err) err = _(f'function {module_name}.{function_name} return the unknown parameter "{key}" for the uri "{risotto_context.version}.{risotto_context.message}"')
raise CallError(str(err)) else:
except ValueError: err = _(f'function {module_name}.{function_name} return unconsistency data "{err}" for the uri "{risotto_context.version}.{risotto_context.message}"')
err = _(f'function {module_name}.{function_name} return the parameter "{key}" with an unvalid value "{value}"') raise CallError(err)
await log.error_msg(risotto_context, kwargs, err) except ValueError as err:
raise CallError(str(err)) if key is not None:
err = _(f'function {module_name}.{function_name} return the invalid parameter "{key}" for the uri "{risotto_context.version}.{risotto_context.message}": {err}')
else:
err = _(f'function {module_name}.{function_name} return unconsistency error for the uri "{risotto_context.version}.{risotto_context.message}": {err}')
raise CallError(err)
await config.property.read_only() await config.property.read_only()
mandatories = await config.value.mandatory() mandatories = await config.value.mandatory()
if mandatories: if mandatories:
mand = [mand.split('.')[-1] for mand in mandatories] mand = [mand.split('.')[-1] for mand in mandatories]
raise ValueError(_(f'missing parameters in response: {mand} in message "{risotto_context.message}"')) raise ValueError(_(f'missing parameters in response of the uri "{risotto_context.version}.{risotto_context.message}": {mand} in message'))
try: try:
await config.value.dict() await config.value.dict()
except Exception as err: except Exception as err:
err = _(f'function {module_name}.{function_name} return an invalid response {err}') err = _(f'function {module_name}.{function_name} return an invalid response {err} for the uri "{risotto_context.version}.{risotto_context.message}"')
await log.error_msg(risotto_context, kwargs, err) raise CallError(err)
raise CallError(str(err))
async def call(self, async def call(self,
version: str, version: str,
message: str, message: str,
old_risotto_context: Context, old_risotto_context: Context,
check_role: bool=False, check_role: bool=False,
**kwargs): internal: bool=True,
**kwargs,
):
""" execute the function associate with specified uri """ execute the function associate with specified uri
arguments are validate before arguments are validate before
""" """
risotto_context = self.build_new_context(old_risotto_context, risotto_context = self.build_new_context(old_risotto_context.__dict__,
version, version,
message, message,
'rpc') 'rpc',
function_objs = [self.messages[version][message]] )
# do not start a new database connection if version not in self.messages:
raise CallError(_(f'cannot find version of message "{version}"'))
if message not in self.messages[version]:
raise CallError(_(f'cannot find message "{version}.{message}"'))
function_obj = self.messages[version][message]
# log
function_name = function_obj['function'].__name__
info_msg = _(f"call function {function_obj['full_module_name']}.{function_name}")
if hasattr(old_risotto_context, 'connection'): if hasattr(old_risotto_context, 'connection'):
# do not start a new database connection
risotto_context.connection = old_risotto_context.connection risotto_context.connection = old_risotto_context.connection
return await self.launch(version, await log.start(risotto_context,
message, kwargs,
risotto_context, info_msg,
check_role, )
kwargs, await self.check_message_type(risotto_context,
function_objs) kwargs,
)
config_arguments = await self.load_kwargs_to_config(risotto_context,
f'{version}.{message}',
kwargs,
check_role,
internal,
)
try:
ret = await self.launch(risotto_context,
kwargs,
config_arguments,
function_obj,
)
await log.success(risotto_context,
ret,
)
except Exception as err:
await log.failed(risotto_context,
str(err),
)
raise CallError(err) from err
else: else:
error = None
try: try:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
await connection.set_type_codec( await connection.set_type_codec(
@ -98,145 +138,248 @@ class CallDispatcher:
) )
risotto_context.connection = connection risotto_context.connection = connection
async with connection.transaction(): async with connection.transaction():
return await self.launch(version, try:
message, await log.start(risotto_context,
risotto_context, kwargs,
check_role, info_msg,
kwargs, )
function_objs) await self.check_message_type(risotto_context,
kwargs,
)
config_arguments = await self.load_kwargs_to_config(risotto_context,
f'{version}.{message}',
kwargs,
check_role,
internal,
)
ret = await self.launch(risotto_context,
kwargs,
config_arguments,
function_obj,
)
# log the success
await log.success(risotto_context,
ret,
)
if not internal and isinstance(ret, dict):
ret['context_id'] = risotto_context.context_id
except CallError as err:
if get_config()['global']['debug']:
print_exc()
await log.failed(risotto_context,
str(err),
)
raise err from err
except CallError as err: except CallError as err:
raise err error = err
except Exception as err: except Exception as err:
# if there is a problem with arguments, just send an error and do nothing # if there is a problem with arguments, just send an error and do nothing
if get_config()['global']['debug']: if get_config()['global']['debug']:
print_exc() print_exc()
async with self.pool.acquire() as connection: await log.failed(risotto_context,
await connection.set_type_codec( str(err),
'json', )
encoder=dumps, error = err
decoder=loads, if error:
schema='pg_catalog' if not internal:
) err = CallError(str(error))
risotto_context.connection = connection err.context_id = risotto_context.context_id
async with connection.transaction(): else:
await log.error_msg(risotto_context, kwargs, err) err = error
raise err raise err from error
return ret
class PublishDispatcher: class PublishDispatcher:
async def register_remote(self) -> None:
print()
print(_('======== Registered remote event ========'))
self.listened_connection = await self.pool.acquire()
for version, messages in self.messages.items():
for message, message_infos in messages.items():
# event not emit locally
if message_infos['pattern'] == 'event' and 'functions' in message_infos and message_infos['functions']:
uri = f'{version}.{message}'
print(f' - {uri}')
await self.listened_connection.add_listener(uri,
self.to_async_publish,
)
async def publish(self, async def publish(self,
version: str, version: str,
message: str, message: str,
old_risotto_context: Context, risotto_context: Context,
check_role: bool=False, **kwargs,
**kwargs) -> None: ) -> None:
risotto_context = self.build_new_context(old_risotto_context, if version not in self.messages or message not in self.messages[version]:
version, raise ValueError(_(f'cannot find URI "{version}.{message}"'))
message,
'event') # publish to remote
try: remote_kw = dumps({'kwargs': kwargs,
function_objs = self.messages[version][message].get('functions', []) 'context': {'username': risotto_context.username,
except KeyError: 'paths': risotto_context.paths,
raise ValueError(_(f'cannot find message {version}.{message}')) 'context_id': risotto_context.context_id,
# do not start a new database connection }
if hasattr(old_risotto_context, 'connection'): })
risotto_context.connection = old_risotto_context.connection # FIXME should be better :/
return await self.launch(version, remote_kw = remote_kw.replace("'", "''")
message, await risotto_context.connection.execute(f'NOTIFY "{version}.{message}", \'{remote_kw}\'')
risotto_context,
check_role, def to_async_publish(self,
kwargs, con: 'asyncpg.connection.Connection',
function_objs) pid: int,
else: uri: str,
payload: str,
) -> None:
version, message = uri.split('.', 1)
loop = get_event_loop()
remote_kw = loads(payload)
for function_obj in self.messages[version][message]['functions']:
risotto_context = self.build_new_context(remote_kw['context'],
version,
message,
'event',
)
callback = self.get_callback(version, message, function_obj, risotto_context, remote_kw['kwargs'],)
loop.call_soon(callback)
def get_callback(self,
version,
message,
function_obj,
risotto_context,
kwargs,
):
return lambda: ensure_future(self._publish(version,
message,
function_obj,
risotto_context,
**kwargs,
))
async def _publish(self,
version: str,
message: str,
function_obj,
risotto_context: Context,
**kwargs,
) -> None:
config_arguments = await self.load_kwargs_to_config(risotto_context,
f'{version}.{message}',
kwargs,
False,
False,
)
async with self.pool.acquire() as connection:
await connection.set_type_codec(
'json',
encoder=dumps,
decoder=loads,
schema='pg_catalog'
)
risotto_context.connection = connection
function_name = function_obj['function'].__name__
info_msg = _(f"call function {function_obj['full_module_name']}.{function_name}")
try: try:
async with self.pool.acquire() as connection: async with connection.transaction():
await connection.set_type_codec( try:
'json', await log.start(risotto_context,
encoder=dumps, kwargs,
decoder=loads, info_msg,
schema='pg_catalog' )
) await self.check_message_type(risotto_context,
risotto_context.connection = connection kwargs,
async with connection.transaction(): )
return await self.launch(version, await self.launch(risotto_context,
message, kwargs,
risotto_context, config_arguments,
check_role, function_obj,
kwargs, )
function_objs) # log the success
except CallError as err: await log.success(risotto_context)
raise err except CallError as err:
if get_config()['global']['debug']:
print_exc()
await log.failed(risotto_context,
str(err),
)
except CallError:
pass
except Exception as err: except Exception as err:
# if there is a problem with arguments, just send an error and do nothing # if there is a problem with arguments, log and do nothing
if get_config()['global']['debug']: if get_config()['global']['debug']:
print_exc() print_exc()
async with self.pool.acquire() as connection: await log.failed(risotto_context,
await connection.set_type_codec( str(err),
'json', )
encoder=dumps,
decoder=loads,
schema='pg_catalog'
)
risotto_context.connection = connection
async with connection.transaction():
await log.error_msg(risotto_context, kwargs, err)
raise err
class Dispatcher(register.RegisterDispatcher, class Dispatcher(register.RegisterDispatcher,
Remote,
CallDispatcher, CallDispatcher,
PublishDispatcher): PublishDispatcher,
):
""" Manage message (call or publish) """ Manage message (call or publish)
so launch a function when a message is called so launch a function when a message is called
""" """
def build_new_context(self, def build_new_context(self,
old_risotto_context: Context, context: dict,
version: str, version: str,
message: str, message: str,
type: str): type: str,
) -> Context:
""" This is a new call or a new publish, so create a new context """ This is a new call or a new publish, so create a new context
""" """
uri = version + '.' + message uri = version + '.' + message
risotto_context = Context() risotto_context = Context()
risotto_context.username = old_risotto_context.username risotto_context.username = context['username']
risotto_context.paths = copy(old_risotto_context.paths) risotto_context.paths = copy(context['paths'])
risotto_context.context_id = context['context_id']
risotto_context.paths.append(uri) risotto_context.paths.append(uri)
risotto_context.uri = uri risotto_context.uri = uri
risotto_context.type = type risotto_context.type = type
risotto_context.message = message risotto_context.message = message
risotto_context.version = version risotto_context.version = version
risotto_context.pool = self.pool
return risotto_context return risotto_context
async def check_message_type(self, async def check_message_type(self,
risotto_context: Context, risotto_context: Context,
kwargs: Dict): kwargs: Dict,
) -> None:
if self.messages[risotto_context.version][risotto_context.message]['pattern'] != risotto_context.type: if self.messages[risotto_context.version][risotto_context.message]['pattern'] != risotto_context.type:
msg = _(f'{risotto_context.uri} is not a {risotto_context.type} message') msg = _(f'{risotto_context.uri} is not a {risotto_context.type} message')
await log.error_msg(risotto_context, kwargs, msg)
raise CallError(msg) raise CallError(msg)
async def load_kwargs_to_config(self, async def load_kwargs_to_config(self,
risotto_context: Context, risotto_context: Context,
uri: str, uri: str,
kwargs: Dict, kwargs: Dict,
check_role: bool): check_role: bool,
internal: bool,
):
""" create a new Config et set values to it """ create a new Config et set values to it
""" """
# create a new config # create a new config
async with await Config(self.option) as config: async with await Config(self.option) as config:
await config.property.read_write() await config.property.read_write()
# set message's option # set message's option
await config.option('message').value.set(risotto_context.message) await config.option('message').value.set(uri)
# store values # store values
subconfig = config.option(risotto_context.message) subconfig = config.option(uri)
extra_parameters = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
try: if not internal or not key.startswith('_'):
await subconfig.option(key).value.set(value) try:
except AttributeError: await subconfig.option(key).value.set(value)
if get_config()['global']['debug']: except AttributeError:
print_exc() if get_config()['global']['debug']:
raise ValueError(_(f'unknown parameter in "{uri}": "{key}"')) print_exc()
raise ValueError(_(f'unknown parameter in "{uri}": "{key}"'))
except ValueOptionError as err:
raise ValueError(_(f'invalid parameter in "{uri}": {err}'))
else:
extra_parameters[key] = value
# check mandatories options # check mandatories options
if check_role and get_config().get('global').get('check_role'): if check_role and get_config().get('global').get('check_role'):
await self.check_role(subconfig, await self.check_role(subconfig,
@ -248,7 +391,10 @@ class Dispatcher(register.RegisterDispatcher,
mand = [mand.split('.')[-1] for mand in mandatories] mand = [mand.split('.')[-1] for mand in mandatories]
raise ValueError(_(f'missing parameters in "{uri}": {mand}')) raise ValueError(_(f'missing parameters in "{uri}": {mand}'))
# return complete an validated kwargs # return complete an validated kwargs
return await subconfig.value.dict() parameters = await subconfig.value.dict()
if extra_parameters:
parameters.update(extra_parameters)
return parameters
def get_service(self, def get_service(self,
name: str): name: str):
@ -257,14 +403,15 @@ class Dispatcher(register.RegisterDispatcher,
async def check_role(self, async def check_role(self,
config: Config, config: Config,
user_login: str, user_login: str,
uri: str) -> None: uri: str,
) -> None:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
async with connection.transaction(): async with connection.transaction():
# Verify if user exists and get ID # Verify if user exists and get ID
sql = ''' sql = '''
SELECT UserId SELECT UserId
FROM RisottoUser FROM UserUser
WHERE UserLogin = $1 WHERE Login = $1
''' '''
user_id = await connection.fetchval(sql, user_id = await connection.fetchval(sql,
user_login) user_login)
@ -281,8 +428,8 @@ class Dispatcher(register.RegisterDispatcher,
# Check role # Check role
select_role_uri = ''' select_role_uri = '''
SELECT RoleName SELECT RoleName
FROM URI, RoleURI FROM UserURI, UserRoleURI
WHERE URI.URIName = $1 AND RoleURI.URIId = URI.URIId WHERE UserURI.URIName = $1 AND UserRoleURI.URIId = UserURI.URIId
''' '''
select_role_user = ''' select_role_user = '''
SELECT RoleAttribute, RoleAttributeValue SELECT RoleAttribute, RoleAttributeValue
@ -302,63 +449,55 @@ class Dispatcher(register.RegisterDispatcher,
raise NotAllowedError(_(f'You ({user_login}) don\'t have any authorisation to access to "{uri}"')) raise NotAllowedError(_(f'You ({user_login}) don\'t have any authorisation to access to "{uri}"'))
async def launch(self, async def launch(self,
version: str,
message: str,
risotto_context: Context, risotto_context: Context,
check_role: bool,
kwargs: Dict, kwargs: Dict,
function_objs: List) -> Optional[Dict]: config_arguments: dict,
await self.check_message_type(risotto_context, function_obj: Callable,
kwargs) ) -> Optional[Dict]:
config_arguments = await self.load_kwargs_to_config(risotto_context, # so send the message
f'{version}.{message}', function = function_obj['function']
kwargs, risotto_context.module = function_obj['module'].split('.', 1)[0]
check_role) # build argument for this function
# config is ok, so send the message if risotto_context.type == 'rpc':
for function_obj in function_objs: kw = config_arguments
function = function_obj['function'] else:
module_name = function.__module__.split('.')[-2] kw = {}
function_name = function.__name__ for key, value in config_arguments.items():
info_msg = _(f'in module {module_name}.{function_name}') if key in function_obj['arguments']:
# build argument for this function kw[key] = value
if risotto_context.type == 'rpc':
kw = config_arguments kw['risotto_context'] = risotto_context
# launch
returns = await function(self.get_service(function_obj['module']), **kw)
if risotto_context.type == 'rpc':
# valid returns
await self.valid_call_returns(risotto_context,
function,
returns,
kwargs,
)
# notification
if function_obj.get('notification'):
if returns is None:
raise Exception(_(f'function "{function_obj["full_module_name"]}.{function_obj["function"].__name__}" must returns something for {function_obj["notification"]}!'))
notif_version, notif_message = function_obj['notification'].split('.', 1)
if not isinstance(returns, list):
send_returns = [returns]
else: else:
kw = {} send_returns = returns
for key, value in config_arguments.items(): for ret in send_returns:
if key in function_obj['arguments']: await self.publish(notif_version,
kw[key] = value notif_message,
risotto_context,
**ret,
)
if risotto_context.type == 'rpc':
return returns
def get_dispatcher():
kw['risotto_context'] = risotto_context global DISPATCHER
returns = await function(self.injected_self[function_obj['module']], **kw) if DISPATCHER is None:
if risotto_context.type == 'rpc': DISPATCHER = Dispatcher()
# valid returns register.dispatcher = DISPATCHER
await self.valid_call_returns(risotto_context, return DISPATCHER
function,
returns,
kwargs)
# log the success
await log.info_msg(risotto_context,
{'arguments': kwargs,
'returns': returns},
info_msg)
# notification
if function_obj.get('notification'):
notif_version, notif_message = function_obj['notification'].split('.', 1)
if not isinstance(returns, list):
send_returns = [returns]
else:
send_returns = returns
for ret in send_returns:
await self.publish(notif_version,
notif_message,
risotto_context,
**ret)
if risotto_context.type == 'rpc':
return returns
dispatcher = Dispatcher()
register.dispatcher = dispatcher

View File

@ -1,46 +1,64 @@
from aiohttp.web import Application, Response, get, post, HTTPBadRequest, HTTPInternalServerError, HTTPNotFound from aiohttp.web import Application, Response, get, post, HTTPBadRequest, HTTPInternalServerError, HTTPNotFound, static
from json import dumps from json import dumps
from traceback import print_exc from traceback import print_exc
from tiramisu import Config, default_storage try:
from tiramisu3 import Config, default_storage
except:
from tiramisu import Config, default_storage
from .dispatcher import dispatcher from .dispatcher import get_dispatcher
from .utils import _ from .utils import _
from .context import Context from .context import Context
from .error import CallError, NotAllowedError, RegistrationError from .error import CallError, NotAllowedError, RegistrationError
from .message import get_messages from .message import get_messages
from .logger import log #from .logger import log
from .config import get_config from .config import get_config
from . import services from . import services
extra_routes = {} extra_routes = {}
extra_statics = {}
RISOTTO_MODULES = services.get_services_list()
def create_context(request): def create_context(request):
risotto_context = Context() risotto_context = Context()
risotto_context.username = request.match_info.get('username', if 'username' in dict(request.match_info):
get_config()['http_server']['default_user']) username = request.match_info['username']
elif 'username' in request.headers:
username = request.headers['username']
else:
username = get_config()['http_server']['default_user']
risotto_context.username = username
return risotto_context return risotto_context
def register(version: str, def register(version: str,
path: str): path: str,
):
""" Decorator to register function to the http route """ Decorator to register function to the http route
""" """
def decorator(function): def decorator(function):
if path in extra_routes: if path in extra_routes:
raise RegistrationError(f'the route {path} is already registered') raise RegistrationError(f'the route "{path}" is already registered')
extra_routes[path] = {'function': function, extra_routes[path] = {'function': function,
'version': version} 'version': version,
}
return decorator return decorator
def register_static(path: str,
directory: str,
) -> None:
if path in extra_statics:
raise RegistrationError(f'the static path "{path}" is already registered')
extra_statics[path] = directory
class extra_route_handler: class extra_route_handler:
async def __new__(cls, request): async def __new__(cls,
request,
):
kwargs = dict(request.match_info) kwargs = dict(request.match_info)
kwargs['request'] = request kwargs['request'] = request
kwargs['risotto_context'] = create_context(request) kwargs['risotto_context'] = create_context(request)
@ -50,8 +68,10 @@ class extra_route_handler:
function_name = cls.function.__module__ function_name = cls.function.__module__
# if not 'api' function # if not 'api' function
if function_name != 'risotto.http': if function_name != 'risotto.http':
module_name = function_name.split('.')[-2] risotto_module_name, submodule_name = function_name.split('.', 2)[:-1]
kwargs['self'] = dispatcher.injected_self[module_name] module_name = risotto_module_name.split('_')[-1]
dispatcher = get_dispatcher()
kwargs['self'] = dispatcher.injected_self[module_name + '.' + submodule_name]
try: try:
returns = await cls.function(**kwargs) returns = await cls.function(**kwargs)
except NotAllowedError as err: except NotAllowedError as err:
@ -65,7 +85,8 @@ class extra_route_handler:
# await log.info_msg(kwargs['risotto_context'], # await log.info_msg(kwargs['risotto_context'],
# dict(request.match_info)) # dict(request.match_info))
return Response(text=dumps(returns), return Response(text=dumps(returns),
content_type='application/json') content_type='application/json',
)
async def handle(request): async def handle(request):
@ -73,6 +94,7 @@ async def handle(request):
risotto_context = create_context(request) risotto_context = create_context(request)
kwargs = await request.json() kwargs = await request.json()
try: try:
dispatcher = get_dispatcher()
pattern = dispatcher.messages[version][message]['pattern'] pattern = dispatcher.messages[version][message]['pattern']
if pattern == 'rpc': if pattern == 'rpc':
method = dispatcher.call method = dispatcher.call
@ -82,48 +104,73 @@ async def handle(request):
message, message,
risotto_context, risotto_context,
check_role=True, check_role=True,
**kwargs) internal=False,
except NotAllowedError as err: **kwargs,
raise HTTPNotFound(reason=str(err)) )
except CallError as err:
raise HTTPBadRequest(reason=str(err).replace('\n', ' '))
except Exception as err: except Exception as err:
if get_config()['global']['debug']: context_id = None
print_exc() if isinstance(err, NotAllowedError):
raise HTTPInternalServerError(reason=str(err)) error_type = HTTPNotFound
return Response(text=dumps({'response': text}), elif isinstance(err, CallError):
content_type='application/json') error_type = HTTPBadRequest
context_id = err.context_id
else:
if get_config()['global']['debug']:
print_exc()
error_type = HTTPInternalServerError
response = {'type': 'error',
'reason': str(err).replace('\n', ' '),
}
if context_id is not None:
response['context_id'] = context_id
err = dumps({'response': response,
'type': 'error',
})
raise error_type(text=err,
content_type='application/json',
)
return Response(text=dumps({'response': text,
'type': 'success',
}),
content_type='application/json',
)
async def api(request, async def api(request,
risotto_context): risotto_context,
global tiramisu ):
if not tiramisu: global TIRAMISU
if not TIRAMISU:
# check all URI that have an associated role # check all URI that have an associated role
# all URI without role is concidered has a private URI # all URI without role is concidered has a private URI
uris = [] uris = []
dispatcher = get_dispatcher()
async with dispatcher.pool.acquire() as connection: async with dispatcher.pool.acquire() as connection:
async with connection.transaction(): async with connection.transaction():
# Check role with ACL # Check role with ACL
sql = ''' sql = '''
SELECT URI.URIName SELECT UserURI.URIName
FROM URI, RoleURI FROM UserURI, UserRoleURI
WHERE RoleURI.URIId = URI.URIId WHERE UserRoleURI.URIId = UserURI.URIId
''' '''
uris = [uri['uriname'] for uri in await connection.fetch(sql)] uris = [uri['uriname'] for uri in await connection.fetch(sql)]
async with await Config(get_messages(current_module_names=RISOTTO_MODULES, risotto_modules = services.get_services_list()
async with await Config(get_messages(current_module_names=risotto_modules,
load_shortarg=True, load_shortarg=True,
current_version=risotto_context.version, current_version=risotto_context.version,
uris=uris)[1]) as config: uris=uris,
)[1],
display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config:
await config.property.read_write() await config.property.read_write()
tiramisu = await config.option.dict(remotable='none') TIRAMISU = await config.option.dict(remotable='none')
return tiramisu return TIRAMISU
async def get_app(loop): async def get_app(loop):
""" build all routes """ build all routes
""" """
global extra_routes global extra_routes, extra_statics
dispatcher = get_dispatcher()
services.link_to_dispatcher(dispatcher) services.link_to_dispatcher(dispatcher)
app = Application(loop=loop) app = Application(loop=loop)
routes = [] routes = []
@ -135,9 +182,9 @@ async def get_app(loop):
versions.append(version) versions.append(version)
print() print()
print(_('======== Registered messages ========')) print(_('======== Registered messages ========'))
for message in messages: for message, message_infos in messages.items():
web_message = f'/api/{version}/{message}' web_message = f'/api/{version}/{message}'
pattern = dispatcher.messages[version][message]['pattern'] pattern = message_infos['pattern']
print(f' - {web_message} ({pattern})') print(f' - {web_message} ({pattern})')
routes.append(post(web_message, handle)) routes.append(post(web_message, handle))
print() print()
@ -145,10 +192,14 @@ async def get_app(loop):
for version in versions: for version in versions:
api_route = {'function': api, api_route = {'function': api,
'version': version, 'version': version,
'path': f'/api/{version}'} 'path': f'/api/{version}',
}
extra_handler = type(api_route['path'], (extra_route_handler,), api_route) extra_handler = type(api_route['path'], (extra_route_handler,), api_route)
routes.append(get(api_route['path'], extra_handler)) routes.append(get(api_route['path'], extra_handler))
print(f' - {api_route["path"]} (http_get)') print(f' - {api_route["path"]} (http_get)')
# last version is default version
routes.append(get('/api', extra_handler))
print(f' - /api (http_get)')
print() print()
if extra_routes: if extra_routes:
print(_('======== Registered extra routes ========')) print(_('======== Registered extra routes ========'))
@ -159,11 +210,22 @@ async def get_app(loop):
extra_handler = type(path, (extra_route_handler,), extra) extra_handler = type(path, (extra_route_handler,), extra)
routes.append(get(path, extra_handler)) routes.append(get(path, extra_handler))
print(f' - {path} (http_get)') print(f' - {path} (http_get)')
print() if extra_statics:
if not extra_routes:
print(_('======== Registered static routes ========'))
for path, directory in extra_statics.items():
routes.append(static(path, directory))
print(f' - {path} (static)')
del extra_routes del extra_routes
del extra_statics
app.router.add_routes(routes) app.router.add_routes(routes)
await dispatcher.register_remote()
print()
await dispatcher.on_join() await dispatcher.on_join()
return await loop.create_server(app.make_handler(), '*', get_config()['http_server']['port']) return await loop.create_server(app.make_handler(),
'*',
get_config()['http_server']['port'],
)
tiramisu = None TIRAMISU = None

378
src/risotto/image.py Normal file
View File

@ -0,0 +1,378 @@
from os import listdir, walk, makedirs
from os.path import isfile, isdir, join, dirname
from yaml import load, SafeLoader
from json import load as jload, dump as jdump
from time import time
from shutil import copy2, rmtree, move
from hashlib import sha512
from subprocess import Popen
from rougail import RougailConvert, RougailConfig, RougailUpgrade
try:
from tiramisu3 import Config
except:
from tiramisu import Config
from .utils import _
DATASET_PATH = '/usr/share/risotto/'
TMP_DIRECTORY = '/tmp'
PACKER_TMP_DIRECTORY = join(TMP_DIRECTORY, 'packer')
PACKER_FILE_NAME = 'recipe.json'
IMAGES_DIRECTORY = join(TMP_DIRECTORY, 'images')
FUNCTIONS = b"""try:
from tiramisu3 import valid_network_netmask, valid_ip_netmask, valid_broadcast, valid_in_network, valid_not_equal as valid_differ, valid_not_equal, calc_value
except:
from tiramisu import valid_network_netmask, valid_ip_netmask, valid_broadcast, valid_in_network, valid_not_equal as valid_differ, valid_not_equal, calc_value
# =============================================================
# fork of risotto-setting/src/risotto_setting/config/config.py
def get_password(**kwargs):
return 'password'
def get_ip(**kwargs):
return '1.1.1.1'
def get_chain(**kwargs):
return 'chain'
def get_certificates(**kwargs):
return []
def get_certificate(**kwargs):
return 'certificate'
def get_private_key(**kwargs):
return 'private_key'
def get_linked_configuration(**kwargs):
if 'test' in kwargs and kwargs['test']:
return kwargs['test'][0]
return 'configuration'
def zone_information(**kwargs):
return 'zone'
# =============================================================
"""
class Images:
def __init__(self,
image_dir: str=None,
tmp_dir: str=None,
):
if image_dir is None:
image_dir = IMAGES_DIRECTORY
self.image_dir = image_dir
if isdir(self.image_dir):
rmtree(self.image_dir)
if tmp_dir is None:
tmp_dir = PACKER_TMP_DIRECTORY
self.tmp_dir = tmp_dir
self.load_applications()
def load_applications(self) -> None:
self.build_images = []
self.applications = {}
for distrib in listdir(join(DATASET_PATH, 'seed')):
distrib_dir = join(DATASET_PATH, 'seed', distrib, 'applicationservice')
if not isdir(distrib_dir):
continue
for release in listdir(distrib_dir):
release_dir = join(distrib_dir, release)
if not isdir(release_dir):
continue
for applicationservice in listdir(release_dir):
applicationservice_dir = join(release_dir, applicationservice)
if not isdir(applicationservice_dir):
continue
if applicationservice in self.applications:
raise Exception('multi applicationservice')
with open(join(applicationservice_dir, 'applicationservice.yml')) as yaml:
app = load(yaml, Loader=SafeLoader)
self.applications[applicationservice] = {'path': applicationservice_dir,
'yml': app,
}
if 'service' in app and app['service']:
self.build_images.append(applicationservice)
def calc_depends(self,
dependencies: list,
appname,
key_is_name=False,
):
app = self.applications[appname]['yml']
if not 'depends' in app or not app['depends']:
return
for dependency in app['depends']:
if key_is_name:
key = appname
else:
key = self.applications[dependency]['path']
if key not in dependencies:
dependencies.insert(0, key)
self.calc_depends(dependencies, dependency, key_is_name)
def list_oses(self):
oses = set()
for build in self.build_images:
dependencies = [build]
self.calc_depends(dependencies, build, True)
for dependency in dependencies:
if isdir(join(self.applications[dependency]['path'], 'packer', 'os')):
oses.add(dependency)
break
for os in oses:
dependencies = [self.applications[os]['path']]
self.calc_depends(dependencies, os)
yield os, dependencies
def list_images(self):
for build in self.build_images:
dependencies = [self.applications[build]['path']]
self.calc_depends(dependencies, build)
yield build, dependencies
async def build(self) -> None:
if isdir(self.tmp_dir):
rmtree(self.tmp_dir)
image = Image(self.image_dir,
self.tmp_dir,
)
print(_('Build OSes'))
if not isdir(join(self.image_dir, 'os')):
makedirs(join(self.image_dir, 'os'))
for application, dependencies_path in self.list_oses():
print(_(f'Build OS {application}'))
await image.build_os(application,
dependencies_path,
)
print(_('Build images'))
for application, dependencies_path in self.list_images():
print(_(f'Build image {application}'))
await image.build_image(application,
dependencies_path,
)
class Image:
def __init__(self,
image_dir: str,
tmp_dir: str,
):
self.image_dir = image_dir
self.tmp_dir = tmp_dir
@staticmethod
def copy_files(dependencies_path: list,
dst_path: str,
element: str,
) -> None:
for dependency_path in dependencies_path:
src_path = join(dependency_path,
'packer',
element,
)
root_len = len(src_path) + 1
for dir_name, subdir_names, filenames in walk(src_path):
subdir = join(dst_path, dir_name[root_len:])
if not isdir(subdir):
makedirs(subdir)
for filename in filenames:
path = join(dir_name, filename)
sub_dst_path = join(subdir, filename)
if isfile(sub_dst_path):
raise Exception(_(f'Try to copy {sub_dst_path} which is already exists'))
copy2(path, sub_dst_path)
async def load_configuration(self,
dependencies_path: list,
packer_tmp_directory: str,
) -> dict:
config = RougailConfig.copy()
dictionaries = [join(dependency_path, 'dictionaries') for dependency_path in dependencies_path if isdir(join(dependency_path, 'dictionaries'))]
upgrade = RougailUpgrade()
dest_dictionaries = join(packer_tmp_directory, 'dictionaries')
makedirs(dest_dictionaries)
dest_dictionaries_extras = join(packer_tmp_directory, 'dictionaries_extras')
makedirs(dest_dictionaries_extras)
for dependency_path in dependencies_path:
dictionaries_dir = join(dependency_path, 'dictionaries')
if isdir(dictionaries_dir):
upgrade.load_xml_from_folders(dictionaries_dir,
dest_dictionaries,
RougailConfig['variable_namespace'],
)
extra_dir = join(dependency_path, 'extras', 'packer')
if isdir(extra_dir):
upgrade.load_xml_from_folders(extra_dir,
dest_dictionaries_extras,
'packer',
)
config['dictionaries_dir'] = [dest_dictionaries]
config['extra_dictionaries'] = {'packer': [dest_dictionaries_extras]}
self.merge_funcs(config, dependencies_path, packer_tmp_directory)
packer_configuration = await self.get_packer_information(config, packer_tmp_directory)
return packer_configuration
@staticmethod
def merge_funcs(config: RougailConfig,
dependencies_path: list,
packer_tmp_directory: str,
):
functions = FUNCTIONS
for dependency_path in dependencies_path:
funcs_dir = join(dependency_path, 'funcs')
if not isdir(funcs_dir):
continue
for func in listdir(funcs_dir):
with open(join(funcs_dir, func), 'rb') as fh:
functions += fh.read()
func_name = join(packer_tmp_directory, 'func.py')
with open(func_name, 'wb') as fh:
fh.write(functions)
config['functions_file'] = func_name
@staticmethod
async def get_packer_information(config: RougailConfig,
packer_tmp_directory: str,
) -> dict:
eolobj = RougailConvert(config)
xml = eolobj.save(join(packer_tmp_directory, 'tiramisu.py'))
optiondescription = {}
exec(xml, None, optiondescription)
config = await Config(optiondescription['option_0'])
return await config.option('packer').value.dict(leader_to_list=True, flatten=True)
@staticmethod
def do_recipe_checksum(path: str,
) -> str:
files = []
root_len = len(path) + 1
for dir_name, subdir_names, filenames in walk(path):
subpath = dir_name[root_len:]
for filename in filenames:
with open(join(dir_name, filename), 'rb') as fh:
ctl_sum = sha512(fh.read()).hexdigest()
abs_path = join(subpath, filename)
files.append(f'{abs_path}/{ctl_sum}')
files.sort()
print(files, sha512('\n'.join(files).encode()).hexdigest())
return sha512('\n'.join(files).encode()).hexdigest()
def get_tmp_directory(self,
application: str,
) -> str:
return join(self.tmp_dir,
application + '_' + str(time()),
)
def get_os_filename(self,
packer_configuration: dict,
) -> str:
return join(self.image_dir,
'os',
packer_configuration['os_name'] + '_' + packer_configuration['os_version'] + '.img',
)
def get_image_filename(self,
recipe_checksum: str,
) -> str:
return join(self.image_dir,
f'{recipe_checksum}.img',
)
async def build_os(self,
application: str,
dependencies_path: list,
) -> None:
packer_tmp_directory = self.get_tmp_directory(application)
packer_configuration = await self.load_configuration(dependencies_path, packer_tmp_directory)
packer_dst_os_filename = self.get_os_filename(packer_configuration)
self.copy_files(dependencies_path,
packer_tmp_directory,
'os',
)
packer_configuration['tmp_directory'] = packer_tmp_directory
recipe = {'variables': packer_configuration}
self.build(packer_dst_os_filename,
packer_tmp_directory,
recipe,
)
async def build_image(self,
application: str,
dependencies_path: list,
) -> None:
packer_tmp_directory = self.get_tmp_directory(application)
makedirs(packer_tmp_directory)
self.copy_files(dependencies_path,
packer_tmp_directory,
'image',
)
recipe_checksum = self.do_recipe_checksum(packer_tmp_directory)
packer_dst_filename = self.get_image_filename(recipe_checksum)
packer_configuration = await self.load_configuration(dependencies_path, packer_tmp_directory)
packer_dst_os_filename = join(self.image_dir,
'os',
packer_configuration['os_name'] + '_' + packer_configuration['os_version'] + '.img',
)
packer_configuration['tmp_directory'] = packer_tmp_directory
recipe = {'variables': packer_configuration}
recipe['variables']['iso_url'] = packer_dst_os_filename
self.build(packer_dst_filename,
packer_tmp_directory,
recipe,
f'{packer_dst_os_filename}.sha256',
)
@staticmethod
def build(packer_dst_filename: str,
tmp_directory: str,
recipe: dict,
sha_file: str=None,
) -> None:
packer_filename = join(tmp_directory, PACKER_FILE_NAME)
if sha_file is not None:
with open(sha_file, 'r') as fh:
sha256 = fh.read().split(' ', 1)[0]
recipe['variables']['iso_checksum'] = sha256
with open(packer_filename, 'r') as recipe_fd:
for key, value in jload(recipe_fd).items():
recipe[key] = value
with open(packer_filename, 'w') as recipe_fd:
jdump(recipe, recipe_fd, indent=2)
preprocessors = join(tmp_directory, 'preprocessors')
if isfile(preprocessors):
proc = Popen([preprocessors],
#stdout=PIPE,
#stderr=PIPE,
cwd=tmp_directory,
)
proc.wait()
if proc.returncode:
raise Exception(_(f'error when executing {preprocessors}'))
proc = Popen(['packer', 'build', packer_filename],
#stdout=PIPE,
#stderr=PIPE,
cwd=tmp_directory,
)
proc.wait()
if proc.returncode:
raise Exception(_(f'cannot build {packer_dst_filename} with {packer_filename}'))
if not isdir(dirname(packer_dst_filename)):
makedirs(dirname(packer_dst_filename))
move(join(tmp_directory, 'image.img'), packer_dst_filename)
move(join(tmp_directory, 'image.sha256'), f'{packer_dst_filename}.sha256')
print(_(f'Image {packer_dst_filename} created'))
rmtree(tmp_directory)

View File

@ -1,38 +1,87 @@
from typing import Dict, Any from typing import Dict, Any, Optional
from json import dumps from json import dumps, loads
from asyncpg.exceptions import UndefinedTableError from asyncpg.exceptions import UndefinedTableError
from datetime import datetime
from asyncio import Lock
from .context import Context from .context import Context
from .utils import _ from .utils import _
from .config import get_config from .config import get_config
database_lock = Lock()
LEVELS = ['Error', 'Info', 'Success', 'Started', 'Failure']
class Logger: class Logger:
""" An object to manager log """ An object to manager log
""" """
def __init__(self) -> None:
self.log_connection = None
async def get_connection(self,
risotto_context: Context,
):
if not self.log_connection:
self.log_connection = await risotto_context.pool.acquire()
await self.log_connection.set_type_codec(
'json',
encoder=dumps,
decoder=loads,
schema='pg_catalog'
)
return self.log_connection
async def insert(self, async def insert(self,
msg: str, msg: str,
path: str, risotto_context: Context,
risotto_context: str,
level: str, level: str,
data: Any= None) -> None: kwargs: Any=None,
insert = 'INSERT INTO log(Msg, Path, Username, Level' start: bool=False,
values = 'VALUES($1,$2,$3,$4' ) -> None:
args = [msg, path, risotto_context.username, level] uri = self._get_last_uri(risotto_context)
if data: uris = " ".join(risotto_context.paths)
insert += ', Data' insert = 'INSERT INTO RisottoLog(Msg, URI, URIS, UserLogin, Status'
values += ',$5' values = 'VALUES($1,$2,$3,$4,$5'
args.append(dumps(data)) args = [msg, uri, uris, risotto_context.username, LEVELS.index(level)]
if kwargs:
insert += ', Kwargs'
values += ',$6'
args.append(dumps(kwargs))
context_id = risotto_context.context_id
if context_id is not None:
insert += ', ContextId'
if kwargs:
values += ',$7'
else:
values += ',$6'
args.append(context_id)
sql = insert + ') ' + values + ')' sql = insert + ') ' + values + ') RETURNING LogId'
try: try:
await risotto_context.connection.fetch(sql, *args) async with database_lock:
connection = await self.get_connection(risotto_context)
log_id = await connection.fetchval(sql, *args)
if context_id is None and start:
risotto_context.context_id = log_id
if start:
risotto_context.start_id = log_id
except UndefinedTableError as err: except UndefinedTableError as err:
raise Exception(_(f'cannot access to database ({err}), was the database really created?')) raise Exception(_(f'cannot access to database ({err}), was the database really created?'))
def _get_last_uri(self,
risotto_context: Context,
) -> str:
if risotto_context.paths:
return risotto_context.paths[-1]
return ''
def _get_message_paths(self, def _get_message_paths(self,
risotto_context: Context): risotto_context: Context,
) -> str:
if not risotto_context.paths:
return ''
paths = risotto_context.paths paths = risotto_context.paths
if risotto_context.type: if risotto_context.type:
paths_msg = f' {risotto_context.type} ' paths_msg = f' {risotto_context.type} '
@ -49,44 +98,114 @@ class Logger:
risotto_context: Context, risotto_context: Context,
arguments, arguments,
error: str, error: str,
msg: str=''): msg: str='',
):
""" send message when an error append """ send message when an error append
""" """
paths_msg = self._get_message_paths(risotto_context) paths_msg = self._get_message_paths(risotto_context)
print(_(f'{risotto_context.username}: ERROR: {error} ({paths_msg} with arguments "{arguments}": {msg})')) print(_(f'{risotto_context.username}: ERROR: {error} ({paths_msg} with arguments "{arguments}": {msg})'))
await self.insert(msg, await self.insert(msg,
paths_msg,
risotto_context, risotto_context,
'Error', 'Error',
arguments) arguments,
)
async def info_msg(self, async def info_msg(self,
risotto_context: Context, risotto_context: Context,
arguments: Dict, arguments: Dict,
msg: str=''): msg: str='',
) -> None:
""" send message with common information """ send message with common information
""" """
if risotto_context.paths: paths_msg = self._get_message_paths(risotto_context)
paths_msg = self._get_message_paths(risotto_context)
else:
paths_msg = ''
if get_config()['global']['debug']: if get_config()['global']['debug']:
print(_(f'{risotto_context.username}: INFO:{paths_msg}: {msg}')) print(_(f'{risotto_context.username}: INFO:{paths_msg}: {msg}'))
await self.insert(msg, await self.insert(msg,
paths_msg,
risotto_context, risotto_context,
'Info', 'Info',
arguments) arguments,
)
async def start(self,
risotto_context: Context,
arguments: dict,
msg: str,
) -> None:
paths_msg = self._get_message_paths(risotto_context)
if get_config()['global']['debug']:
if risotto_context.context_id != None:
context = f'({risotto_context.context_id})'
else:
context = ''
print(_(f'{risotto_context.username}: START{context}:{paths_msg}: {msg}'))
await self.insert(msg,
risotto_context,
'Started',
arguments,
start=True,
)
async def success(self,
risotto_context: Context,
returns: Optional[dict]=None,
) -> None:
if get_config()['global']['debug']:
paths_msg = self._get_message_paths(risotto_context)
print(_(f'{risotto_context.username}: SUCCESS({risotto_context.context_id}):{paths_msg}'))
sql = """UPDATE RisottoLog
SET StopDate = $2,
Status = $3
"""
args = [datetime.now(), LEVELS.index('Success')]
if returns:
sql += """, Returns = $4
"""
args.append(dumps(returns))
sql += """WHERE LogId = $1
"""
async with database_lock:
connection = await self.get_connection(risotto_context)
await connection.execute(sql,
risotto_context.start_id,
*args,
)
async def failed(self,
risotto_context: Context,
err: str,
) -> None:
if get_config()['global']['debug']:
paths_msg = self._get_message_paths(risotto_context)
if risotto_context.context_id != None:
context = f'({risotto_context.context_id})'
else:
context = ''
print(_(f'{risotto_context.username}: FAILED({risotto_context.context_id}):{paths_msg}: {err}'))
sql = """UPDATE RisottoLog
SET StopDate = $2,
Status = $4,
Msg = $3
WHERE LogId = $1
"""
async with database_lock:
connection = await self.get_connection(risotto_context)
await connection.execute(sql,
risotto_context.start_id,
datetime.now(),
err[:254],
LEVELS.index('Failure'),
)
async def info(self, async def info(self,
risotto_context, risotto_context,
msg): msg,
):
if get_config()['global']['debug']: if get_config()['global']['debug']:
print(msg) print(msg)
await self.insert(msg, await self.insert(msg,
None,
risotto_context, risotto_context,
'Info') 'Info',
)
log = Logger() log = Logger()

View File

@ -2,9 +2,14 @@ from os import listdir
from os.path import join, basename, dirname, isfile from os.path import join, basename, dirname, isfile
from glob import glob from glob import glob
from gettext import translation from gettext import translation
from tiramisu import StrOption, IntOption, BoolOption, ChoiceOption, OptionDescription, SymLinkOption, FloatOption, \ try:
Calculation, Params, ParamOption, ParamValue, calc_value, calc_value_property_help, \ from tiramisu3 import StrOption, IntOption, BoolOption, ChoiceOption, OptionDescription, \
groups, Option SymLinkOption, FloatOption, Calculation, Params, ParamOption, \
ParamValue, calc_value, calc_value_property_help, groups, Option
except:
from tiramisu import StrOption, IntOption, BoolOption, ChoiceOption, OptionDescription, \
SymLinkOption, FloatOption, Calculation, Params, ParamOption, \
ParamValue, calc_value, calc_value_property_help, groups, Option
from yaml import load, SafeLoader from yaml import load, SafeLoader
@ -14,8 +19,8 @@ from .utils import _
MESSAGE_ROOT_PATH = get_config()['global']['message_root_path'] MESSAGE_ROOT_PATH = get_config()['global']['message_root_path']
groups.addgroup('message') groups.addgroup('message')
MESSAGE_TRANSLATION = translation('risotto-message', join(MESSAGE_ROOT_PATH, '..', 'locale')).gettext CUSTOMTYPES = None
MESSAGE_TRANSLATION = None
class DictOption(Option): class DictOption(Option):
@ -243,7 +248,8 @@ def get_message_file_path(version,
def list_messages(uris, def list_messages(uris,
current_module_names, current_module_names,
current_version): current_version,
):
def get_module_paths(current_module_names): def get_module_paths(current_module_names):
if current_module_names is None: if current_module_names is None:
current_module_names = listdir(join(MESSAGE_ROOT_PATH, version)) current_module_names = listdir(join(MESSAGE_ROOT_PATH, version))
@ -307,6 +313,7 @@ class CustomParam:
'string': 'String', 'string': 'String',
'number': 'Number', 'number': 'Number',
'object': 'Dict', 'object': 'Dict',
'any': 'Any',
'array': 'Array', 'array': 'Array',
'file': 'File', 'file': 'File',
'float': 'Float'} 'float': 'Float'}
@ -407,7 +414,7 @@ def load_customtypes() -> None:
custom_type = CustomType(load(message_file, Loader=SafeLoader)) custom_type = CustomType(load(message_file, Loader=SafeLoader))
ret[version][custom_type.getname()] = custom_type ret[version][custom_type.getname()] = custom_type
except Exception as err: except Exception as err:
raise Exception(_(f'enable to load type {err}: {message}')) raise Exception(_(f'enable to load type "{message}": {err}'))
return ret return ret
@ -426,9 +433,9 @@ def _get_description(description,
def _get_option(name, def _get_option(name,
arg, arg,
file_path, uri,
select_option, select_option,
optiondescription): ):
"""generate option """generate option
""" """
props = [] props = []
@ -438,10 +445,11 @@ def _get_option(name,
props.append(Calculation(calc_value, props.append(Calculation(calc_value,
Params(ParamValue('disabled'), Params(ParamValue('disabled'),
kwargs={'condition': ParamOption(select_option, todict=True), kwargs={'condition': ParamOption(select_option, todict=True),
'expected': ParamValue(optiondescription), 'expected': ParamValue(uri),
'reverse_condition': ParamValue(True)}), 'reverse_condition': ParamValue(True)}),
calc_value_property_help)) calc_value_property_help))
props.append('notunique')
description = arg.description.strip().rstrip() description = arg.description.strip().rstrip()
kwargs = {'name': name, kwargs = {'name': name,
'doc': _get_description(description, name), 'doc': _get_description(description, name),
@ -467,25 +475,25 @@ def _get_option(name,
elif type_ == 'Float': elif type_ == 'Float':
obj = FloatOption(**kwargs) obj = FloatOption(**kwargs)
else: else:
raise Exception('unsupported type {} in {}'.format(type_, file_path)) raise Exception('unsupported type {} in {}'.format(type_, uri))
obj.impl_set_information('ref', arg.ref) obj.impl_set_information('ref', arg.ref)
return obj return obj
def get_options(message_def, def get_options(message_def,
file_path, uri,
select_option, select_option,
optiondescription, load_shortarg,
load_shortarg): ):
"""build option with args/kwargs """build option with args/kwargs
""" """
options =[] options =[]
for name, arg in message_def.parameters.items(): for name, arg in message_def.parameters.items():
current_opt = _get_option(name, current_opt = _get_option(name,
arg, arg,
file_path, uri,
select_option, select_option,
optiondescription) )
options.append(current_opt) options.append(current_opt)
if hasattr(arg, 'shortarg') and arg.shortarg and load_shortarg: if hasattr(arg, 'shortarg') and arg.shortarg and load_shortarg:
options.append(SymLinkOption(arg.shortarg, current_opt)) options.append(SymLinkOption(arg.shortarg, current_opt))
@ -493,17 +501,18 @@ def get_options(message_def,
def _parse_responses(message_def, def _parse_responses(message_def,
file_path): uri,
):
"""build option with returns """build option with returns
""" """
if message_def.response.parameters is None: if message_def.response.parameters is None:
raise Exception('message "{}" did not returned any valid parameters.'.format(message_def.message)) raise Exception(f'message "{message_def.message}" did not returned any valid parameters')
options = [] options = []
names = [] names = []
for name, obj in message_def.response.parameters.items(): for name, obj in message_def.response.parameters.items():
if name in names: if name in names:
raise Exception('multi response with name {} in {}'.format(name, file_path)) raise Exception(f'multi response with name "{name}" in "{uri}"')
names.append(name) names.append(name)
kwargs = {'name': name, kwargs = {'name': name,
@ -516,6 +525,7 @@ def _parse_responses(message_def,
'Number': IntOption, 'Number': IntOption,
'Boolean': BoolOption, 'Boolean': BoolOption,
'Dict': DictOption, 'Dict': DictOption,
'Any': AnyOption,
'Float': FloatOption, 'Float': FloatOption,
# FIXME # FIXME
'File': StrOption}.get(type_) 'File': StrOption}.get(type_)
@ -523,18 +533,21 @@ def _parse_responses(message_def,
raise Exception(f'unknown param type {obj.type} in responses of message {message_def.message}') raise Exception(f'unknown param type {obj.type} in responses of message {message_def.message}')
if hasattr(obj, 'default'): if hasattr(obj, 'default'):
kwargs['default'] = obj.default kwargs['default'] = obj.default
kwargs['properties'] = ('notunique',)
else: else:
kwargs['properties'] = ('mandatory',) kwargs['properties'] = ('mandatory', 'notunique')
options.append(option(**kwargs)) options.append(option(**kwargs))
od = OptionDescription(message_def.message, od = OptionDescription(uri,
message_def.response.description, message_def.response.description,
options) options,
)
od.impl_set_information('multi', message_def.response.multi) od.impl_set_information('multi', message_def.response.multi)
return od return od
def _get_root_option(select_option, def _get_root_option(select_option,
optiondescriptions): optiondescriptions,
):
"""get root option """get root option
""" """
def _get_od(curr_ods): def _get_od(curr_ods):
@ -576,44 +589,51 @@ def _get_root_option(select_option,
def get_messages(current_module_names, def get_messages(current_module_names,
load_shortarg=False, load_shortarg=False,
current_version=None, current_version=None,
uris=None): uris=None,
):
"""generate description from yml files """generate description from yml files
""" """
global MESSAGE_TRANSLATION, CUSTOMTYPES
if MESSAGE_TRANSLATION is None:
MESSAGE_TRANSLATION = translation('risotto-message', join(MESSAGE_ROOT_PATH, '..', 'locale')).gettext
if CUSTOMTYPES is None:
CUSTOMTYPES = load_customtypes()
optiondescriptions = {} optiondescriptions = {}
optiondescriptions_info = {} optiondescriptions_info = {}
messages = list(list_messages(uris, messages = list(list_messages(uris,
current_module_names, current_module_names,
current_version)) current_version,
))
messages.sort() messages.sort()
optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages] # optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages]
select_option = ChoiceOption('message', select_option = ChoiceOption('message',
'Nom du message.', 'Nom du message.',
tuple(optiondescriptions_name), tuple(messages),
properties=frozenset(['mandatory', 'positional'])) properties=frozenset(['mandatory', 'positional', 'notunique']))
for uri in messages: for uri in messages:
message_def = get_message(uri, message_def = get_message(uri,
current_module_names, current_module_names,
) )
optiondescriptions_info[message_def.message] = {'pattern': message_def.pattern, optiondescriptions_info[message_def.message] = {'pattern': message_def.pattern,
'default_roles': message_def.default_roles, 'default_roles': message_def.default_roles,
'version': message_def.version} 'version': message_def.version,
}
if message_def.pattern == 'rpc': if message_def.pattern == 'rpc':
if not message_def.response: if not message_def.response:
raise Exception(f'rpc without response is not allowed {uri}') raise Exception(f'rpc without response is not allowed {uri}')
optiondescriptions_info[message_def.message]['response'] = _parse_responses(message_def, optiondescriptions_info[message_def.message]['response'] = _parse_responses(message_def,
uri) uri,
)
elif message_def.response: elif message_def.response:
raise Exception(f'response is not allowed for {uri}') raise Exception(f'response is not allowed for {uri}')
message_def.options = get_options(message_def, message_def.options = get_options(message_def,
uri, uri,
select_option, select_option,
message_def.message, load_shortarg,
load_shortarg) )
optiondescriptions[message_def.message] = (message_def.description, message_def.options) optiondescriptions[uri] = (message_def.description, message_def.options)
root = _get_root_option(select_option, root = _get_root_option(select_option,
optiondescriptions) optiondescriptions,
)
return optiondescriptions_info, root return optiondescriptions_info, root
CUSTOMTYPES = load_customtypes()

View File

@ -1,8 +1,13 @@
from tiramisu import Config try:
from tiramisu3 import Config
except:
from tiramisu import Config
from inspect import signature from inspect import signature
from typing import Callable, Optional from typing import Callable, Optional, List
import asyncpg from asyncpg import create_pool
from json import dumps, loads from json import dumps, loads
from pkg_resources import iter_entry_points
from traceback import print_exc
import risotto import risotto
from .utils import _ from .utils import _
from .error import RegistrationError from .error import RegistrationError
@ -10,7 +15,7 @@ from .message import get_messages
from .context import Context from .context import Context
from .config import get_config from .config import get_config
from .logger import log from .logger import log
from pkg_resources import iter_entry_points
class Services(): class Services():
services = {} services = {}
@ -19,25 +24,29 @@ class Services():
def load_services(self): def load_services(self):
for entry_point in iter_entry_points(group='risotto_services'): for entry_point in iter_entry_points(group='risotto_services'):
self.services.setdefault(entry_point.name, []) self.services.setdefault(entry_point.name, {})
self.services_loaded = True self.services_loaded = True
def load_modules(self): def load_modules(self,
limit_services: Optional[List[str]]=None,
) -> None:
for entry_point in iter_entry_points(group='risotto_modules'): for entry_point in iter_entry_points(group='risotto_modules'):
service_name, module_name = entry_point.name.split('.') service_name, module_name = entry_point.name.split('.')
setattr(self, module_name, entry_point.load()) if limit_services is None or service_name in limit_services:
self.services[service_name].append(module_name) self.services[service_name][module_name] = entry_point.load()
self.modules_loaded = True self.modules_loaded = True
#
# def get_services(self):
# if not self.services_loaded:
# self.load_services()
# return [(service, getattr(self, service)) for service in self.services]
def get_services(self): def get_modules(self,
if not self.services_loaded: limit_services: Optional[List[str]]=None,
self.load_services() ) -> List[str]:
return [(s, getattr(self, s)) for s in self.services]
def get_modules(self):
if not self.modules_loaded: if not self.modules_loaded:
self.load_modules() self.load_modules(limit_services=limit_services)
return [(m, getattr(self, m)) for s in self.services.values() for m in s] return [(module + '.' + submodule, entry_point) for module, submodules in self.services.items() for submodule, entry_point in submodules.items()]
def get_services_list(self): def get_services_list(self):
return self.services.keys() return self.services.keys()
@ -49,11 +58,13 @@ class Services():
dispatcher, dispatcher,
validate: bool=True, validate: bool=True,
test: bool=False, test: bool=False,
limit_services: Optional[List[str]]=None,
): ):
for module_name, module in self.get_modules(): for submodule_name, module in self.get_modules(limit_services=limit_services):
dispatcher.set_module(module_name, dispatcher.set_module(submodule_name,
module, module,
test) test,
)
if validate: if validate:
dispatcher.validate() dispatcher.validate()
@ -62,20 +73,26 @@ services = Services()
services.load_services() services.load_services()
setattr(risotto, 'services', services) setattr(risotto, 'services', services)
def register(uris: str, def register(uris: str,
notification: str=None): notification: str=None,
) -> None:
""" Decorator to register function to the dispatcher """ Decorator to register function to the dispatcher
""" """
if not isinstance(uris, list): if not isinstance(uris, list):
uris = [uris] uris = [uris]
def decorator(function): def decorator(function):
for uri in uris: try:
version, message = uri.split('.', 1) for uri in uris:
dispatcher.set_function(version, dispatcher.set_function(uri,
message, notification,
notification, function,
function) function.__module__
)
except NameError:
# if you when register uri, please use get_dispatcher before registered uri
pass
return decorator return decorator
@ -94,6 +111,7 @@ class RegisterDispatcher:
version = obj['version'] version = obj['version']
if version not in self.messages: if version not in self.messages:
self.messages[version] = {} self.messages[version] = {}
obj['message'] = tiramisu_message
self.messages[version][tiramisu_message] = obj self.messages[version][tiramisu_message] = obj
def get_function_args(self, def get_function_args(self,
@ -103,29 +121,38 @@ class RegisterDispatcher:
return {param.name for param in list(signature(function).parameters.values())[first_argument_index:]} return {param.name for param in list(signature(function).parameters.values())[first_argument_index:]}
async def get_message_args(self, async def get_message_args(self,
message: str): message: str,
version: str,
):
# load config # load config
async with await Config(self.option) as config: async with await Config(self.option, display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config:
uri = f'{version}.{message}'
await config.property.read_write() await config.property.read_write()
# set message to the message name # set message to the message name
await config.option('message').value.set(message) await config.option('message').value.set(uri)
# get message argument # get message argument
dico = await config.option(message).value.dict() dico = await config.option(uri).value.dict()
return set(dico.keys()) return set(dico.keys())
async def valid_rpc_params(self, async def valid_rpc_params(self,
version: str, version: str,
message: str, message: str,
function: Callable, function: Callable,
module_name: str): module_name: str,
):
""" parameters function must have strictly all arguments with the correct name """ parameters function must have strictly all arguments with the correct name
""" """
# get message arguments # get message arguments
message_args = await self.get_message_args(message) message_args = await self.get_message_args(message,
version,
)
# get function arguments # get function arguments
function_args = self.get_function_args(function) function_args = self.get_function_args(function)
# compare message arguments with function parameter # compare message arguments with function parameter
# it must not have more or less arguments # it must not have more or less arguments
for arg in function_args - message_args:
if arg.startswith('_'):
message_args.add(arg)
if message_args != function_args: if message_args != function_args:
# raise if arguments are not equal # raise if arguments are not equal
msg = [] msg = []
@ -143,11 +170,14 @@ class RegisterDispatcher:
version: str, version: str,
message: str, message: str,
function: Callable, function: Callable,
module_name: str): module_name: str,
):
""" parameters function validation for event messages """ parameters function validation for event messages
""" """
# get message arguments # get message arguments
message_args = await self.get_message_args(message) message_args = await self.get_message_args(message,
version,
)
# get function arguments # get function arguments
function_args = self.get_function_args(function) function_args = self.get_function_args(function)
# compare message arguments with function parameter # compare message arguments with function parameter
@ -160,33 +190,36 @@ class RegisterDispatcher:
raise RegistrationError(_(f'error with {module_name}.{function_name} arguments: {msg}')) raise RegistrationError(_(f'error with {module_name}.{function_name} arguments: {msg}'))
def set_function(self, def set_function(self,
version: str, uri: str,
message: str,
notification: str, notification: str,
function: Callable): function: Callable,
full_module_name: str,
):
""" register a function to an URI """ register a function to an URI
URI is a message URI is a message
""" """
version, message = uri.split('.', 1)
# check if message exists # check if message exists
if message not in self.messages[version]: if message not in self.messages[version]:
raise RegistrationError(_(f'the message {message} not exists')) raise RegistrationError(_(f'the message {message} not exists'))
# xxx module can only be register with v1.xxxx..... message # xxx submodule can only be register with v1.yyy.xxx..... message
module_name = function.__module__.split('.')[-2] risotto_module_name, submodule_name = full_module_name.split('.')[-3:-1]
message_namespace = message.split('.', 1)[0] module_name = risotto_module_name.split('_')[-1]
message_risotto_module, message_namespace, message_name = message.split('.', 2) message_module, message_submodule, message_name = message.split('.', 2)
if message_risotto_module not in self.risotto_modules: if message_module not in self.risotto_modules:
raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"')) raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"'))
if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name: if self.messages[version][message]['pattern'] == 'rpc' and \
raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"')) module_name != message_module and \
message_submodule != submodule_name:
raise RegistrationError(_(f'cannot registered the "{message}" message in submodule "{module_name}.{submodule_name}"'))
# True if first argument is the risotto_context # True if first argument is the risotto_context
function_args = self.get_function_args(function) function_args = self.get_function_args(function)
# check if already register # check if already register
if 'function' in self.messages[version][message]: if 'function' in self.messages[version][message]:
raise RegistrationError(_(f'uri {version}.{message} already registered')) raise RegistrationError(_(f'uri {uri} already registered'))
# register # register
if self.messages[version][message]['pattern'] == 'rpc': if self.messages[version][message]['pattern'] == 'rpc':
@ -195,19 +228,24 @@ class RegisterDispatcher:
register = self.register_event register = self.register_event
register(version, register(version,
message, message,
module_name, f'{module_name}.{submodule_name}',
full_module_name,
function, function,
function_args, function_args,
notification) notification,
)
def register_rpc(self, def register_rpc(self,
version: str, version: str,
message: str, message: str,
module_name: str, module_name: str,
full_module_name: str,
function: Callable, function: Callable,
function_args: list, function_args: list,
notification: Optional[str]): notification: Optional[str],
):
self.messages[version][message]['module'] = module_name self.messages[version][message]['module'] = module_name
self.messages[version][message]['full_module_name'] = full_module_name
self.messages[version][message]['function'] = function self.messages[version][message]['function'] = function
self.messages[version][message]['arguments'] = function_args self.messages[version][message]['arguments'] = function_args
if notification: if notification:
@ -217,26 +255,34 @@ class RegisterDispatcher:
version: str, version: str,
message: str, message: str,
module_name: str, module_name: str,
full_module_name: str,
function: Callable, function: Callable,
function_args: list, function_args: list,
notification: Optional[str]): notification: Optional[str],
):
if 'functions' not in self.messages[version][message]: if 'functions' not in self.messages[version][message]:
self.messages[version][message]['functions'] = [] self.messages[version][message]['functions'] = []
dico = {'module': module_name, dico = {'module': module_name,
'full_module_name': full_module_name,
'function': function, 'function': function,
'arguments': function_args} 'arguments': function_args,
}
if notification and notification: if notification and notification:
dico['notification'] = notification dico['notification'] = notification
self.messages[version][message]['functions'].append(dico) self.messages[version][message]['functions'].append(dico)
def set_module(self, module_name, module, test): def set_module(self,
submodule_name,
module,
test,
):
""" register and instanciate a new module """ register and instanciate a new module
""" """
try: try:
self.injected_self[module_name] = module.Risotto(test) self.injected_self[submodule_name] = module.Risotto(test)
except AttributeError as err: except AttributeError as err:
raise RegistrationError(_(f'unable to register the module {module_name}, this module must have Risotto class')) print(_(f'unable to register the module {submodule_name}, this module must have Risotto class'))
def validate(self): def validate(self):
""" check if all messages have a function """ check if all messages have a function
@ -246,13 +292,15 @@ class RegisterDispatcher:
for message, message_obj in messages.items(): for message, message_obj in messages.items():
if not 'functions' in message_obj and not 'function' in message_obj: if not 'functions' in message_obj and not 'function' in message_obj:
if message_obj['pattern'] == 'event': if message_obj['pattern'] == 'event':
print(f'{message} prêche dans le désert') print(f'{version}.{message} prêche dans le désert')
else: else:
missing_messages.append(message) missing_messages.append(f'{version}.{message}')
if missing_messages: if missing_messages:
raise RegistrationError(_(f'no matching function for uri {missing_messages}')) raise RegistrationError(_(f'no matching function for uri {missing_messages}'))
async def on_join(self): async def on_join(self,
truncate: bool=False,
) -> None:
internal_user = get_config()['global']['internal_user'] internal_user = get_config()['global']['internal_user']
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
await connection.set_type_codec( await connection.set_type_codec(
@ -261,34 +309,47 @@ class RegisterDispatcher:
decoder=loads, decoder=loads,
schema='pg_catalog' schema='pg_catalog'
) )
if truncate:
async with connection.transaction():
await connection.execute('TRUNCATE InfraServer, InfraSite, InfraZone, Log, ProviderDeployment, ProviderFactoryCluster, ProviderFactoryClusterNode, SettingApplicationservice, SettingApplicationServiceDependency, SettingRelease, SettingServer, SettingServermodel, SettingSource, UserRole, UserRoleURI, UserURI, UserUser, InfraServermodel, ProviderZone, ProviderServer, ProviderSource, ProviderApplicationservice, ProviderServermodel')
async with connection.transaction(): async with connection.transaction():
for module_name, module in self.injected_self.items(): for submodule_name, module in self.injected_self.items():
risotto_context = Context() risotto_context = Context()
risotto_context.username = internal_user risotto_context.username = internal_user
risotto_context.paths.append(f'{module_name}.on_join') risotto_context.paths.append(f'internal.{submodule_name}.on_join')
risotto_context.type = None risotto_context.type = None
risotto_context.pool = self.pool
risotto_context.connection = connection risotto_context.connection = connection
info_msg = _(f'in module {module_name}.on_join') risotto_context.module = submodule_name.split('.', 1)[0]
info_msg = _(f'in function risotto_{submodule_name}.on_join')
await log.info_msg(risotto_context, await log.info_msg(risotto_context,
None, None,
info_msg) info_msg)
await module.on_join(risotto_context) try:
await module.on_join(risotto_context)
except Exception as err:
if get_config()['global']['debug']:
print_exc()
msg = _(f'on_join returns an error in module {submodule_name}: {err}')
await log.error_msg(risotto_context, {}, msg)
async def load(self): async def load(self):
# valid function's arguments # valid function's arguments
db_conf = get_config()['database']['dsn'] db_conf = get_config()['database']['dsn']
self.pool = await asyncpg.create_pool(db_conf) self.pool = await create_pool(db_conf)
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
async with connection.transaction(): async with connection.transaction():
for version, messages in self.messages.items(): for version, messages in self.messages.items():
for message, message_infos in messages.items(): for message, message_infos in messages.items():
if message_infos['pattern'] == 'rpc': if message_infos['pattern'] == 'rpc':
module_name = message_infos['module'] # module not available during test
function = message_infos['function'] if 'module' in message_infos:
await self.valid_rpc_params(version, module_name = message_infos['module']
message, function = message_infos['function']
function, await self.valid_rpc_params(version,
module_name) message,
function,
module_name)
elif 'functions' in message_infos: elif 'functions' in message_infos:
# event with functions # event with functions
for function_infos in message_infos['functions']: for function_infos in message_infos['functions']:

View File

@ -1,61 +0,0 @@
from aiohttp import ClientSession
from requests import get, post
from json import dumps
#from tiramisu_api import Config
from .config import get_config
from .utils import _
#
#
# ALLOW_INSECURE_HTTPS = get_config()['module']['allow_insecure_https']
class Remote:
submodules = {}
async def _get_config(self,
module: str,
url: str) -> None:
if module not in self.submodules:
session = ClientSession()
async with session.get(url) as resp:
if resp.status != 200:
try:
json = await resp.json()
err = json['error']['kwargs']['reason']
except:
err = await resp.text()
raise Exception(err)
json = await resp.json()
self.submodules[module] = json
return Config(self.submodules[module])
async def remove_call(self,
module: str,
version: str,
submessage: str,
payload) -> dict:
try:
domain_name = get_config()['module'][module]
except KeyError:
raise ValueError(_(f'cannot find information of remote module "{module}" to access to "{version}.{module}.{submessage}"'))
remote_url = f'http://{domain_name}:8080/api/{version}'
message_url = f'{remote_url}/{submessage}'
config = await self._get_config(module,
remote_url)
for key, value in payload.items():
path = submessage + '.' + key
config.option(path).value.set(value)
session = ClientSession()
async with session.post(message_url, data=dumps(payload)) as resp:
response = await resp.json()
if 'error' in response:
if 'reason' in response['error']['kwargs']:
raise Exception("{}".format(response['error']['kwargs']['reason']))
raise Exception('erreur inconnue')
return response['response']
remote = Remote()

View File

@ -1,9 +1,27 @@
class Undefined: class Undefined:
pass pass
undefined = Undefined()
def _(s): def _(s):
return s return s
undefined = Undefined() def tiramisu_display_name(kls,
dyn_name: 'Base'=None,
suffix: str=None,
) -> str:
if dyn_name is not None:
name = dyn_name
else:
name = kls.impl_getname()
doc = kls.impl_get_information('doc', None)
if doc:
doc = str(doc)
if doc.endswith('.'):
doc = doc[:-1]
if suffix:
doc += suffix
if name != doc:
name += f'" "{doc}'
return name

0
tests/__init__.py Normal file
View File

View File

@ -1,5 +1,15 @@
from tiramisu import Storage try:
from risotto.config import DATABASE_DIR from tiramisu3 import Storage
except:
from tiramisu import Storage
from os.path import isfile as _isfile
import os as _os
_envfile = '/etc/risotto/risotto.conf'
if _isfile(_envfile):
with open(_envfile, 'r') as fh_env:
for line in fh_env.readlines():
key, value = line.strip().split('=')
_os.environ[key] = value
STORAGE = Storage(engine='sqlite3', dir_database=DATABASE_DIR, name='test') STORAGE = Storage(engine='sqlite3')

View File

@ -1,20 +1,29 @@
from importlib import import_module from importlib import import_module
import pytest import pytest
from tiramisu import list_sessions, delete_session try:
from tiramisu3 import list_sessions, delete_session as _delete_session
except:
from tiramisu import list_sessions, delete_session as _delete_session
from .storage import STORAGE from .storage import STORAGE
from risotto import services
from risotto.context import Context from risotto.context import Context
from risotto.services import load_services #from risotto.services import load_services
from risotto.dispatcher import dispatcher from risotto.dispatcher import dispatcher
SOURCE_NAME = 'test'
SERVERMODEL_NAME = 'sm1'
def setup_module(module): def setup_module(module):
load_services(['config'], # load_services(['config'],
validate=False) # validate=False)
services.link_to_dispatcher(dispatcher, limit_services=['setting'], validate=False)
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
config_module.save_storage = STORAGE config_module.save_storage = STORAGE
dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True) #dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True)
dispatcher.set_module('servermodel', import_module(f'.servermodel', 'fake_services'), True) #dispatcher.set_module('servermodel', import_module(f'.servermodel', 'fake_services'), True)
def setup_function(function): def setup_function(function):
@ -23,11 +32,11 @@ def setup_function(function):
config_module.servermodel = {} config_module.servermodel = {}
def teardown_function(function): async def delete_session():
# delete all sessions # delete all sessions
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
for session in list_sessions(storage=config_module.save_storage): for session in await list_sessions(storage=config_module.save_storage):
delete_session(storage=config_module.save_storage, session_id=session) await _delete_session(storage=config_module.save_storage, session_id=session)
def get_fake_context(module_name): def get_fake_context(module_name):
@ -38,127 +47,166 @@ def get_fake_context(module_name):
return risotto_context return risotto_context
@pytest.mark.asyncio async def onjoin(source=True):
async def test_on_join():
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
assert config_module.servermodel == {} assert config_module.servermodel == {}
assert config_module.server == {} assert config_module.server == {}
await delete_session()
# #
#config_module.cache_root_path = 'tests/data'
await dispatcher.load()
await dispatcher.on_join(truncate=True)
if source:
fake_context = get_fake_context('config')
await dispatcher.call('v1',
'setting.source.create',
fake_context,
source_name=SOURCE_NAME,
source_directory='tests/data',
)
INTERNAL_SOURCE = {'source_name': 'internal', 'source_directory': '/srv/risotto/seed/internal'}
TEST_SOURCE = {'source_name': 'test', 'source_directory': 'tests/data'}
##############################################################################################################################
# Source / Release
##############################################################################################################################
@pytest.mark.asyncio
async def test_source_on_join():
# onjoin must create internal source
sources = [INTERNAL_SOURCE]
await onjoin(False)
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' assert await dispatcher.call('v1',
await config_module.on_join(fake_context) 'setting.source.list',
assert list(config_module.servermodel.keys()) == [1, 2] fake_context,
assert list(config_module.server) == [3] ) == sources
assert set(config_module.server[3]) == {'server', 'server_to_deploy', 'funcs_file'} await delete_session()
assert config_module.server[3]['funcs_file'] == 'tests/data/1/funcs.py'
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_created(): async def test_source_create():
sources = [INTERNAL_SOURCE, TEST_SOURCE]
await onjoin()
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
assert list(config_module.servermodel.keys()) == ['last_base']
assert list(config_module.server) == []
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' assert await dispatcher.call('v1',
await config_module.on_join(fake_context) 'setting.source.list',
# fake_context,
assert list(config_module.server) == [3] ) == sources
await dispatcher.publish('v1', await delete_session()
'server.created',
fake_context,
server_id=4,
server_name='name3',
server_description='description3',
server_servermodel_id=2)
assert list(config_module.server) == [3, 4]
assert set(config_module.server[4]) == {'server', 'server_to_deploy', 'funcs_file'}
assert config_module.server[4]['funcs_file'] == 'tests/data/2/funcs.py'
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_deleted(): async def test_source_describe():
config_module = dispatcher.get_service('config') await onjoin()
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' assert await dispatcher.call('v1',
await config_module.on_join(fake_context) 'setting.source.describe',
# fake_context,
assert list(config_module.server) == [3] source_name='internal',
await dispatcher.publish('v1', ) == INTERNAL_SOURCE
'server.created', assert await dispatcher.call('v1',
fake_context, 'setting.source.describe',
server_id=4, fake_context,
server_name='name4', source_name=SOURCE_NAME,
server_description='description4', ) == TEST_SOURCE
server_servermodel_id=2) await delete_session()
assert list(config_module.server) == [3, 4]
await dispatcher.publish('v1',
'server.deleted', @pytest.mark.asyncio
fake_context, async def test_release_internal_list():
server_id=4) releases = [{'release_distribution': 'last',
assert list(config_module.server) == [3] 'release_name': 'none',
'source_name': 'internal'}]
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.release.list',
fake_context,
source_name='internal',
) == releases
await delete_session()
@pytest.mark.asyncio
async def test_release_list():
releases = [{'release_distribution': 'last',
'release_name': '1',
'source_name': 'test'}]
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.release.list',
fake_context,
source_name='test',
) == releases
await delete_session()
@pytest.mark.asyncio
async def test_release_describe():
await onjoin()
fake_context = get_fake_context('config')
assert await dispatcher.call('v1',
'setting.source.release.describe',
fake_context,
source_name='internal',
release_distribution='last',
) == {'release_distribution': 'last',
'release_name': 'none',
'source_name': 'internal'}
assert await dispatcher.call('v1',
'setting.source.release.describe',
fake_context,
source_name='test',
release_distribution='last',
) == {'release_distribution': 'last',
'release_name': '1',
'source_name': 'test'}
await delete_session()
##############################################################################################################################
# Servermodel
##############################################################################################################################
async def create_servermodel(name=SERVERMODEL_NAME,
parents_name=['base'],
):
fake_context = get_fake_context('config')
await dispatcher.call('v1',
'setting.servermodel.create',
fake_context,
servermodel_name=name,
servermodel_description='servermodel 1',
parents_name=parents_name,
source_name=SOURCE_NAME,
release_distribution='last',
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_servermodel_created(): async def test_servermodel_created():
await onjoin()
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data'
await config_module.on_join(fake_context)
# #
assert list(config_module.servermodel) == [1, 2] assert list(config_module.servermodel) == ['last_base']
servermodel = {'servermodeid': 3, await create_servermodel()
'servermodelname': 'name3'} assert list(config_module.servermodel) == ['last_base', 'last_sm1']
await dispatcher.publish('v1', assert not list(await config_module.servermodel['last_base'].config.parents())
'servermodel.created', assert len(list(await config_module.servermodel['last_sm1'].config.parents())) == 1
fake_context, await delete_session()
servermodel_id=3, #
servermodel_description='name3', #
release_id=1,
servermodel_name='name3')
assert list(config_module.servermodel) == [1, 2, 3]
assert not list(await config_module.servermodel[3].config.parents())
@pytest.mark.asyncio
async def test_servermodel_herited_created():
config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data'
await config_module.on_join(fake_context)
#
assert list(config_module.servermodel) == [1, 2]
await dispatcher.publish('v1',
'servermodel.created',
fake_context,
servermodel_id=3,
servermodel_name='name3',
release_id=1,
servermodel_description='name3',
servermodel_parents_id=[1])
assert list(config_module.servermodel) == [1, 2, 3]
assert len(list(await config_module.servermodel[3].config.parents())) == 1
@pytest.mark.asyncio
async def test_servermodel_multi_herited_created():
config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data'
await config_module.on_join(fake_context)
#
assert list(config_module.servermodel) == [1, 2]
await dispatcher.publish('v1',
'servermodel.created',
fake_context,
servermodel_id=3,
servermodel_name='name3',
release_id=1,
servermodel_description='name3',
servermodel_parents_id=[1, 2])
assert list(config_module.servermodel) == [1, 2, 3]
assert len(list(await config_module.servermodel[3].config.parents())) == 2
#@pytest.mark.asyncio #@pytest.mark.asyncio
#async def test_servermodel_updated_not_exists(): #async def test_servermodel_herited_created():
# config_module = dispatcher.get_service('config') # config_module = dispatcher.get_service('config')
# fake_context = get_fake_context('config') # fake_context = get_fake_context('config')
# config_module.cache_root_path = 'tests/data' # config_module.cache_root_path = 'tests/data'
@ -166,7 +214,28 @@ async def test_servermodel_multi_herited_created():
# # # #
# assert list(config_module.servermodel) == [1, 2] # assert list(config_module.servermodel) == [1, 2]
# await dispatcher.publish('v1', # await dispatcher.publish('v1',
# 'servermodel.updated', # 'servermodel.created',
# fake_context,
# servermodel_id=3,
# servermodel_name='name3',
# release_id=1,
# servermodel_description='name3',
# servermodel_parents_id=[1])
# assert list(config_module.servermodel) == [1, 2, 3]
# assert len(list(await config_module.servermodel[3].config.parents())) == 1
# await delete_session()
#
#
#@pytest.mark.asyncio
#async def test_servermodel_multi_herited_created():
# config_module = dispatcher.get_service('config')
# fake_context = get_fake_context('config')
# config_module.cache_root_path = 'tests/data'
# await config_module.on_join(fake_context)
# #
# assert list(config_module.servermodel) == [1, 2]
# await dispatcher.publish('v1',
# 'servermodel.created',
# fake_context, # fake_context,
# servermodel_id=3, # servermodel_id=3,
# servermodel_name='name3', # servermodel_name='name3',
@ -175,164 +244,311 @@ async def test_servermodel_multi_herited_created():
# servermodel_parents_id=[1, 2]) # servermodel_parents_id=[1, 2])
# assert list(config_module.servermodel) == [1, 2, 3] # assert list(config_module.servermodel) == [1, 2, 3]
# assert len(list(await config_module.servermodel[3].config.parents())) == 2 # assert len(list(await config_module.servermodel[3].config.parents())) == 2
# await delete_session()
# #
# #
# @pytest.mark.asyncio ##@pytest.mark.asyncio
# async def test_servermodel_updated1(): ##async def test_servermodel_updated_not_exists():
# config_module = dispatcher.get_service('config') ## config_module = dispatcher.get_service('config')
# fake_context = get_fake_context('config') ## fake_context = get_fake_context('config')
# config_module.cache_root_path = 'tests/data' ## config_module.cache_root_path = 'tests/data'
# await config_module.on_join(fake_context) ## await config_module.on_join(fake_context)
# # ## #
# assert list(config_module.servermodel) == [1, 2] ## assert list(config_module.servermodel) == [1, 2]
# metaconfig1 = config_module.servermodel[1] ## await dispatcher.publish('v1',
# metaconfig2 = config_module.servermodel[2] ## 'servermodel.updated',
# mixconfig1 = (await metaconfig1.config.list())[0] ## fake_context,
# mixconfig2 = (await metaconfig2.config.list())[0] ## servermodel_id=3,
# assert len(list(await metaconfig1.config.parents())) == 0 ## servermodel_name='name3',
# assert len(list(await metaconfig2.config.parents())) == 1 ## release_id=1,
# assert len(list(await mixconfig1.config.list())) == 1 ## servermodel_description='name3',
# assert len(list(await mixconfig2.config.list())) == 0 ## servermodel_parents_id=[1, 2])
# # ## assert list(config_module.servermodel) == [1, 2, 3]
# await dispatcher.publish('v1', ## assert len(list(await config_module.servermodel[3].config.parents())) == 2
# 'servermodel.updated', ## await delete_session()
# fake_context, ##
# servermodel_id=1, ##
# servermodel_name='name1-1', ## @pytest.mark.asyncio
# release_id=1, ## async def test_servermodel_updated1():
# servermodel_description='name1-1') ## config_module = dispatcher.get_service('config')
# assert set(config_module.servermodel) == {1, 2} ## fake_context = get_fake_context('config')
# assert config_module.servermodel[1].information.get('servermodel_name') == 'name1-1' ## config_module.cache_root_path = 'tests/data'
# assert metaconfig1 != config_module.servermodel[1] ## await config_module.on_join(fake_context)
# assert metaconfig2 == config_module.servermodel[2] ## #
# metaconfig1 = config_module.servermodel[1] ## assert list(config_module.servermodel) == [1, 2]
# assert mixconfig1 != next(metaconfig1.config.list()) ## metaconfig1 = config_module.servermodel[1]
# mixconfig1 = next(metaconfig1.config.list()) ## metaconfig2 = config_module.servermodel[2]
# # ## mixconfig1 = (await metaconfig1.config.list())[0]
# assert len(list(await metaconfig1.config.parents())) == 0 ## mixconfig2 = (await metaconfig2.config.list())[0]
# assert len(list(await metaconfig2.config.parents())) == 1 ## assert len(list(await metaconfig1.config.parents())) == 0
# assert len(list(await mixconfig1.config.list())) == 1 ## assert len(list(await metaconfig2.config.parents())) == 1
# assert len(list(await mixconfig2.config.list())) == 0 ## assert len(list(await mixconfig1.config.list())) == 1
# ## assert len(list(await mixconfig2.config.list())) == 0
# ## #
# @pytest.mark.asyncio ## await dispatcher.publish('v1',
# async def test_servermodel_updated2(): ## 'servermodel.updated',
# config_module = dispatcher.get_service('config') ## fake_context,
# fake_context = get_fake_context('config') ## servermodel_id=1,
# config_module.cache_root_path = 'tests/data' ## servermodel_name='name1-1',
# await config_module.on_join(fake_context) ## release_id=1,
# # create a new servermodel ## servermodel_description='name1-1')
# assert list(config_module.servermodel) == [1, 2] ## assert set(config_module.servermodel) == {1, 2}
# mixconfig1 = next(config_module.servermodel[1].config.list()) ## assert config_module.servermodel[1].information.get('servermodel_name') == 'name1-1'
# mixconfig2 = next(config_module.servermodel[2].config.list()) ## assert metaconfig1 != config_module.servermodel[1]
# assert len(list(mixconfig1.config.list())) == 1 ## assert metaconfig2 == config_module.servermodel[2]
# assert len(list(mixconfig2.config.list())) == 0 ## metaconfig1 = config_module.servermodel[1]
# await dispatcher.publish('v1', ## assert mixconfig1 != next(metaconfig1.config.list())
# 'servermodel.created', ## mixconfig1 = next(metaconfig1.config.list())
# fake_context, ## #
# servermodel_id=3, ## assert len(list(await metaconfig1.config.parents())) == 0
# servermodel_name='name3', ## assert len(list(await metaconfig2.config.parents())) == 1
# release_id=1, ## assert len(list(await mixconfig1.config.list())) == 1
# servermodel_description='name3', ## assert len(list(await mixconfig2.config.list())) == 0
# servermodel_parents_id=[1]) ## await delete_session()
# assert list(config_module.servermodel) == [1, 2, 3] ##
# assert len(list(await config_module.servermodel[3].config.parents())) == 1 ##
# assert await config_module.servermodel[3].information.get('servermodel_name') == 'name3' ## @pytest.mark.asyncio
# assert len(list(await mixconfig1.config.list())) == 2 ## async def test_servermodel_updated2():
# assert len(list(await mixconfig2.config.list())) == 0 ## config_module = dispatcher.get_service('config')
# # ## fake_context = get_fake_context('config')
# await dispatcher.publish('v1', ## config_module.cache_root_path = 'tests/data'
# 'servermodel.updated', ## await config_module.on_join(fake_context)
# fake_context, ## # create a new servermodel
# servermodel_id=3, ## assert list(config_module.servermodel) == [1, 2]
# servermodel_name='name3-1', ## mixconfig1 = next(config_module.servermodel[1].config.list())
# release_id=1, ## mixconfig2 = next(config_module.servermodel[2].config.list())
# servermodel_description='name3-1', ## assert len(list(mixconfig1.config.list())) == 1
# servermodel_parents_id=[1, 2]) ## assert len(list(mixconfig2.config.list())) == 0
# assert list(config_module.servermodel) == [1, 2, 3] ## await dispatcher.publish('v1',
# assert config_module.servermodel[3].information.get('servermodel_name') == 'name3-1' ## 'servermodel.created',
# assert len(list(mixconfig1.config.list())) == 2 ## fake_context,
# assert len(list(mixconfig2.config.list())) == 1 ## servermodel_id=3,
# ## servermodel_name='name3',
# ## release_id=1,
# @pytest.mark.asyncio ## servermodel_description='name3',
# async def test_servermodel_updated_config(): ## servermodel_parents_id=[1])
# config_module = dispatcher.get_service('config') ## assert list(config_module.servermodel) == [1, 2, 3]
# fake_context = get_fake_context('config') ## assert len(list(await config_module.servermodel[3].config.parents())) == 1
# config_module.cache_root_path = 'tests/data' ## assert await config_module.servermodel[3].information.get('servermodel_name') == 'name3'
# await config_module.on_join(fake_context) ## assert len(list(await mixconfig1.config.list())) == 2
# # ## assert len(list(await mixconfig2.config.list())) == 0
# config_module.servermodel[1].property.read_write() ## #
# assert config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.get() == 'non' ## await dispatcher.publish('v1',
# config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.set('oui') ## 'servermodel.updated',
# assert config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.get() == 'oui' ## fake_context,
# # ## servermodel_id=3,
# await dispatcher.publish('v1', ## servermodel_name='name3-1',
# 'servermodel.updated', ## release_id=1,
# fake_context, ## servermodel_description='name3-1',
# servermodel_id=1, ## servermodel_parents_id=[1, 2])
# servermodel_name='name1-1', ## assert list(config_module.servermodel) == [1, 2, 3]
# release_id=1, ## assert config_module.servermodel[3].information.get('servermodel_name') == 'name3-1'
# servermodel_description='name1-1') ## assert len(list(mixconfig1.config.list())) == 2
# assert config_module.servermodel[1].option('creole.general.mode_conteneur_actif').value.get() == 'oui' ## assert len(list(mixconfig2.config.list())) == 1
## await delete_session()
##
##
## @pytest.mark.asyncio
## async def test_servermodel_updated_config():
## config_module = dispatcher.get_service('config')
## fake_context = get_fake_context('config')
## config_module.cache_root_path = 'tests/data'
## await config_module.on_join(fake_context)
## #
## config_module.servermodel[1].property.read_write()
## assert config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.get() == 'non'
## config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.set('oui')
## assert config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.get() == 'oui'
## #
## await dispatcher.publish('v1',
## 'servermodel.updated',
## fake_context,
## servermodel_id=1,
## servermodel_name='name1-1',
## release_id=1,
## servermodel_description='name1-1')
## assert config_module.servermodel[1].option('configuration.general.mode_conteneur_actif').value.get() == 'oui'
## await delete_session()
##############################################################################################################################
# Server
##############################################################################################################################
@pytest.mark.asyncio
async def test_server_created_base():
await onjoin()
config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
#
assert list(config_module.server) == []
await dispatcher.on_join(truncate=True)
server_name = 'dns.test.lan'
await dispatcher.publish('v1',
'infra.server.created',
fake_context,
server_name=server_name,
server_description='description_created',
servermodel_name='base',
release_distribution='last',
site_name='site_1',
zones_name=['zones'],
)
assert list(config_module.server) == [server_name]
assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'}
assert config_module.server[server_name]['funcs_file'] == '/var/cache/risotto/servermodel/last/base/funcs.py'
await delete_session()
@pytest.mark.asyncio
async def test_server_created_own_sm():
await onjoin()
config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config')
await create_servermodel()
#
assert list(config_module.server) == []
await dispatcher.on_join(truncate=True)
server_name = 'dns.test.lan'
await dispatcher.publish('v1',
'infra.server.created',
fake_context,
server_name=server_name,
server_description='description_created',
servermodel_name=SERVERMODEL_NAME,
source_name=SOURCE_NAME,
release_distribution='last',
site_name='site_1',
zones_name=['zones'],
)
assert list(config_module.server) == [server_name]
assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'}
assert config_module.server[server_name]['funcs_file'] == '/var/cache/risotto/servermodel/last/sm1/funcs.py'
await delete_session()
#@pytest.mark.asyncio
#async def test_server_deleted():
# config_module = dispatcher.get_service('config')
# config_module.cache_root_path = 'tests/data'
# await config_module.on_join(fake_context)
# #
# assert list(config_module.server) == [3]
# await dispatcher.publish('v1',
# 'server.created',
# fake_context,
# server_id=4,
# server_name='name4',
# server_description='description4',
# server_servermodel_id=2)
# assert list(config_module.server) == [3, 4]
# await dispatcher.publish('v1',
# 'server.deleted',
# fake_context,
# server_id=4)
# assert list(config_module.server) == [3]
# await delete_session()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server_configuration_get(): async def test_server_configuration_get():
await onjoin()
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' await create_servermodel()
await config_module.on_join(fake_context) await dispatcher.on_join(truncate=True)
server_name = 'dns.test.lan'
await dispatcher.publish('v1',
'infra.server.created',
fake_context,
server_name=server_name,
server_description='description_created',
servermodel_name=SERVERMODEL_NAME,
source_name=SOURCE_NAME,
release_distribution='last',
site_name='site_1',
zones_name=['zones'],
)
# #
await config_module.server[3]['server_to_deploy'].property.read_write() await config_module.server[server_name]['server'].property.read_write()
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'non' assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 1
await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.set('oui') await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.set(2)
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 2
assert await config_module.server[3]['server'].option('creole.general.mode_conteneur_actif').value.get() == 'non' assert await config_module.server[server_name]['server_to_deploy'].option('configuration.general.number_of_interfaces').value.get() == 1
# #
configuration = {'server_name': server_name,
'deployed': False,
'configuration': {'configuration.general.number_of_interfaces': 1,
'configuration.general.interfaces_list': [0],
'configuration.interface_0.domain_name_eth0': 'dns.test.lan'
}
}
values = await dispatcher.call('v1', values = await dispatcher.call('v1',
'config.configuration.server.get', 'setting.config.configuration.server.get',
fake_context, fake_context,
server_id=3) server_name=server_name,
configuration = {'configuration': deployed=False,
{'creole.general.mode_conteneur_actif': 'non', )
'creole.general.master.master': [],
'creole.general.master.slave1': [],
'creole.general.master.slave2': [],
'containers.container0.files.file0.mkdir': False,
'containers.container0.files.file0.name': '/etc/mailname',
'containers.container0.files.file0.rm': False,
'containers.container0.files.file0.source': 'mailname',
'containers.container0.files.file0.activate': True},
'server_id': 3,
'deployed': True}
assert values == configuration assert values == configuration
# #
values = await dispatcher.call('v1', await delete_session()
'config.configuration.server.get',
fake_context,
server_id=3,
deployed=False)
configuration['configuration']['creole.general.mode_conteneur_actif'] = 'oui'
configuration['deployed'] = False
assert values == configuration
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_config_deployed(): async def test_server_configuration_deployed():
await onjoin()
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
fake_context = get_fake_context('config') fake_context = get_fake_context('config')
config_module.cache_root_path = 'tests/data' await create_servermodel()
await config_module.on_join(fake_context) await dispatcher.on_join(truncate=True)
server_name = 'dns.test.lan'
await dispatcher.publish('v1',
'infra.server.created',
fake_context,
server_name=server_name,
server_description='description_created',
servermodel_name=SERVERMODEL_NAME,
source_name=SOURCE_NAME,
release_distribution='last',
site_name='site_1',
zones_name=['zones'],
)
# #
await config_module.server[3]['server_to_deploy'].property.read_write() await config_module.server[server_name]['server'].property.read_write()
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'non' assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 1
await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.set('oui') await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.set(2)
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' assert await config_module.server[server_name]['server'].option('configuration.general.number_of_interfaces').value.get() == 2
assert await config_module.server[3]['server'].option('creole.general.mode_conteneur_actif').value.get() == 'non' assert await config_module.server[server_name]['server_to_deploy'].option('configuration.general.number_of_interfaces').value.get() == 1
values = await dispatcher.publish('v1', #
'config.configuration.server.deploy', configuration = {'server_name': server_name,
fake_context, 'deployed': False,
server_id=3) 'configuration': {'configuration.general.number_of_interfaces': 1,
assert await config_module.server[3]['server_to_deploy'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' 'configuration.general.interfaces_list': [0],
assert await config_module.server[3]['server'].option('creole.general.mode_conteneur_actif').value.get() == 'oui' 'configuration.interface_0.domain_name_eth0': 'dns.test.lan'
}
}
try:
await dispatcher.call('v1',
'setting.config.configuration.server.get',
fake_context,
server_name=server_name,
)
except:
pass
else:
raise Exception('should raise propertyerror')
values = await dispatcher.call('v1',
'setting.config.configuration.server.deploy',
fake_context,
server_name=server_name,
)
assert values == {'server_name': 'dns.test.lan', 'deployed': True}
await dispatcher.call('v1',
'setting.config.configuration.server.get',
fake_context,
server_name=server_name,
)
#
await delete_session()

View File

@ -2,7 +2,7 @@ from importlib import import_module
import pytest import pytest
from .storage import STORAGE from .storage import STORAGE
from risotto.context import Context from risotto.context import Context
from risotto.services import load_services #from risotto.services import load_services
from risotto.dispatcher import dispatcher from risotto.dispatcher import dispatcher
from risotto.services.session.storage import storage_server, storage_servermodel from risotto.services.session.storage import storage_server, storage_servermodel
@ -16,9 +16,9 @@ def get_fake_context(module_name):
def setup_module(module): def setup_module(module):
load_services(['config', 'session'], #load_services(['config', 'session'],
validate=False, # validate=False,
test=True) # test=True)
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
config_module.save_storage = STORAGE config_module.save_storage = STORAGE
dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True) dispatcher.set_module('server', import_module(f'.server', 'fake_services'), True)