Compare commits

..

11 Commits

11 changed files with 214 additions and 310 deletions

5
debian/changelog vendored
View File

@ -1,5 +0,0 @@
risotto (0.1) unstable; urgency=low
* first version
-- Cadoles <contact@cadoles.com> Fri, 20 Mar 2020 15:18:25 +0100

10
debian/control vendored
View File

@ -2,13 +2,19 @@ Source: risotto
Section: admin Section: admin
Priority: extra Priority: extra
Maintainer: Cadoles <contact@cadoles.com> Maintainer: Cadoles <contact@cadoles.com>
Build-depends: debhelper (>=11), python3-all, python3-setuptools Build-depends: debhelper (>=11), python3-all, python3-setuptools, dh-python
Standards-Version: 3.9.4 Standards-Version: 3.9.4
Homepage: https://forge.cadoles.com/Infra/risotto Homepage: https://forge.cadoles.com/Infra/risotto
Package: python3-risotto
Architecture: any
Pre-Depends: dpkg, python3, ${misc:Pre-Depends}
Depends: ${python:Depends}, ${misc:Depends}, python3-asyncpg, python3-rougail, python3-aiohttp
Description: configuration manager libraries
Package: risotto Package: risotto
Architecture: any Architecture: any
Pre-Depends: dpkg, python3, ${misc:Pre-Depends} Pre-Depends: dpkg, python3, ${misc:Pre-Depends}
Depends: ${python:Depends}, ${misc:Depends} Depends: ${python:Depends}, ${misc:Depends}, python3-risotto
Description: configuration manager Description: configuration manager

2
debian/risotto.install vendored Normal file
View File

@ -0,0 +1,2 @@
script/risotto-server usr/bin/
sql/risotto.sql usr/share/eole/db/eole-risotto/gen/

5
script/risotto-server Normal file → Executable file
View File

@ -1,16 +1,13 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from sdnotify import SystemdNotifier
from asyncio import get_event_loop from asyncio import get_event_loop
from risotto import get_app from risotto import get_app
if __name__ == '__main__': if __name__ == '__main__':
notifier = SystemdNotifier()
loop = get_event_loop() loop = get_event_loop()
loop.run_until_complete(get_app(loop)) loop.run_until_complete(get_app(loop))
print('HTTP server ready')
notifier.notify("READY=1")
try: try:
print('HTTP server ready')
loop.run_forever() loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass

View File

@ -1,7 +1,6 @@
from os import environ from os import environ
from os.path import isfile from os.path import isfile
from configobj import ConfigObj from configobj import ConfigObj
from uuid import uuid4
CONFIG_FILE = environ.get('CONFIG_FILE', '/etc/risotto/risotto.conf') CONFIG_FILE = environ.get('CONFIG_FILE', '/etc/risotto/risotto.conf')
@ -21,6 +20,10 @@ if 'CONFIGURATION_DIR' in environ:
CONFIGURATION_DIR = environ['CONFIGURATION_DIR'] CONFIGURATION_DIR = environ['CONFIGURATION_DIR']
else: else:
CONFIGURATION_DIR = config.get('CONFIGURATION_DIR', '/srv/risotto/configurations') CONFIGURATION_DIR = config.get('CONFIGURATION_DIR', '/srv/risotto/configurations')
if 'PROVIDER_FACTORY_CONFIG_DIR' in environ:
PROVIDER_FACTORY_CONFIG_DIR = environ['PROVIDER_FACTORY_CONFIG_DIR']
else:
PROVIDER_FACTORY_CONFIG_DIR = config.get('PROVIDER_FACTORY_CONFIG_DIR', '/srv/factory')
if 'DEFAULT_USER' in environ: if 'DEFAULT_USER' in environ:
DEFAULT_USER = environ['DEFAULT_USER'] DEFAULT_USER = environ['DEFAULT_USER']
else: else:
@ -49,18 +52,6 @@ if 'TIRAMISU_DB_USER' in environ:
TIRAMISU_DB_USER = environ['TIRAMISU_DB_USER'] TIRAMISU_DB_USER = environ['TIRAMISU_DB_USER']
else: else:
TIRAMISU_DB_USER = config.get('TIRAMISU_DB_USER', 'tiramisu') TIRAMISU_DB_USER = config.get('TIRAMISU_DB_USER', 'tiramisu')
if 'CELERYRISOTTO_DB_NAME' in environ:
CELERYRISOTTO_DB_NAME = environ['CELERYRISOTTO_DB_NAME']
else:
CELERYRISOTTO_DB_NAME = config.get('CELERYRISOTTO_DB_NAME', None)
if 'CELERYRISOTTO_DB_PASSWORD' in environ:
CELERYRISOTTO_DB_PASSWORD = environ['CELERYRISOTTO_DB_PASSWORD']
else:
CELERYRISOTTO_DB_PASSWORD = config.get('CELERYRISOTTO_DB_PASSWORD', None)
if 'CELERYRISOTTO_DB_USER' in environ:
CELERYRISOTTO_DB_USER = environ['CELERYRISOTTO_DB_USER']
else:
CELERYRISOTTO_DB_USER = config.get('CELERYRISOTTO_DB_USER', None)
if 'DB_ADDRESS' in environ: if 'DB_ADDRESS' in environ:
DB_ADDRESS = environ['DB_ADDRESS'] DB_ADDRESS = environ['DB_ADDRESS']
else: else:
@ -85,44 +76,6 @@ if 'TMP_DIR' in environ:
TMP_DIR = environ['TMP_DIR'] TMP_DIR = environ['TMP_DIR']
else: else:
TMP_DIR = config.get('TMP_DIR', '/tmp') TMP_DIR = config.get('TMP_DIR', '/tmp')
if 'IMAGE_PATH' in environ:
IMAGE_PATH = environ['IMAGE_PATH']
else:
IMAGE_PATH = config.get('IMAGE_PATH', '/tmp')
if 'PASSWORD_ADMIN_USERNAME' in environ:
PASSWORD_ADMIN_USERNAME = environ['PASSWORD_ADMIN_USERNAME']
else:
PASSWORD_ADMIN_USERNAME = config.get('PASSWORD_ADMIN_USERNAME', 'risotto')
if 'PASSWORD_ADMIN_EMAIL' in environ:
PASSWORD_ADMIN_EMAIL = environ['PASSWORD_ADMIN_EMAIL']
else:
# this parameter is mandatory
PASSWORD_ADMIN_EMAIL = config['PASSWORD_ADMIN_EMAIL']
if 'PASSWORD_ADMIN_PASSWORD' in environ:
PASSWORD_ADMIN_PASSWORD = environ['PASSWORD_ADMIN_PASSWORD']
else:
# this parameter is mandatory
PASSWORD_ADMIN_PASSWORD = config['PASSWORD_ADMIN_PASSWORD']
if 'PASSWORD_DEVICE_IDENTIFIER' in environ:
PASSWORD_DEVICE_IDENTIFIER = environ['PASSWORD_DEVICE_IDENTIFIER']
else:
PASSWORD_DEVICE_IDENTIFIER = config.get('PASSWORD_DEVICE_IDENTIFIER', uuid4())
if 'PASSWORD_URL' in environ:
PASSWORD_URL = environ['PASSWORD_URL']
else:
PASSWORD_URL = config.get('PASSWORD_URL', 'https://localhost:8001/')
if 'PKI_ADMIN_PASSWORD' in environ:
PKI_ADMIN_PASSWORD = environ['PKI_ADMIN_PASSWORD']
else:
PKI_ADMIN_PASSWORD = config['PKI_ADMIN_PASSWORD']
if 'PKI_ADMIN_EMAIL' in environ:
PKI_ADMIN_EMAIL = environ['PKI_ADMIN_EMAIL']
else:
PKI_ADMIN_EMAIL = config['PKI_ADMIN_EMAIL']
if 'PKI_URL' in environ:
PKI_URL = environ['PKI_URL']
else:
PKI_URL = config.get('PKI_URL', 'http://localhost:8002')
def dsn_factory(database, user, password, address=DB_ADDRESS): def dsn_factory(database, user, password, address=DB_ADDRESS):
@ -132,7 +85,6 @@ def dsn_factory(database, user, password, address=DB_ADDRESS):
_config = {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD), _config = {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD),
'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD), 'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD),
'celery_dsn': dsn_factory(CELERYRISOTTO_DB_NAME, CELERYRISOTTO_DB_USER, CELERYRISOTTO_DB_PASSWORD)
}, },
'http_server': {'port': RISOTTO_PORT, 'http_server': {'port': RISOTTO_PORT,
'default_user': DEFAULT_USER}, 'default_user': DEFAULT_USER},
@ -145,24 +97,13 @@ _config = {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RIS
'sql_dir': SQL_DIR, 'sql_dir': SQL_DIR,
'tmp_dir': TMP_DIR, 'tmp_dir': TMP_DIR,
}, },
'password': {'admin_username': PASSWORD_ADMIN_USERNAME,
'admin_email': PASSWORD_ADMIN_EMAIL,
'admin_password': PASSWORD_ADMIN_PASSWORD,
'device_identifier': PASSWORD_DEVICE_IDENTIFIER,
'service_url': PASSWORD_URL,
},
'pki': {'admin_password': PKI_ADMIN_PASSWORD,
'owner': PKI_ADMIN_EMAIL,
'url': PKI_URL,
},
'cache': {'root_path': CACHE_ROOT_PATH}, 'cache': {'root_path': CACHE_ROOT_PATH},
'servermodel': {'internal_source_path': SRV_SEED_PATH, 'servermodel': {'internal_source_path': SRV_SEED_PATH,
'internal_source': 'internal'}, 'internal_source': 'internal'},
'submodule': {'allow_insecure_https': False, 'submodule': {'allow_insecure_https': False,
'pki': '192.168.56.112'}, 'pki': '192.168.56.112'},
'provider': {'factory_configuration_filename': 'infra.json', 'provider': {'factory_configuration_dir': PROVIDER_FACTORY_CONFIG_DIR,
'packer_filename': 'recipe.json', 'factory_configuration_filename': 'infra.json'},
'risotto_images_dir': IMAGE_PATH},
} }

View File

@ -39,40 +39,15 @@ class Controller:
**kwargs, **kwargs,
): ):
""" a wrapper to dispatcher's publish""" """ a wrapper to dispatcher's publish"""
version, message = uri.split('.', 1)
if args: if args:
raise ValueError(_(f'the URI "{uri}" can only be published with keyword arguments')) raise ValueError(_(f'the URI "{uri}" can only be published with keyword arguments'))
version, message = uri.split('.', 1)
await dispatcher.publish(version, await dispatcher.publish(version,
message, message,
risotto_context, risotto_context,
**kwargs, **kwargs,
) )
@staticmethod
async def check_role(self,
uri: str,
username: str,
**kwargs: dict,
) -> None:
# create a new config
async with await Config(dispatcher.option) as config:
await config.property.read_write()
await config.option('message').value.set(uri)
subconfig = config.option(uri)
for key, value in kwargs.items():
try:
await subconfig.option(key).value.set(value)
except AttributeError:
if get_config()['global']['debug']:
print_exc()
raise ValueError(_(f'unknown parameter in "{uri}": "{key}"'))
except ValueOptionError as err:
raise ValueError(_(f'invalid parameter in "{uri}": {err}'))
await dispatcher.check_role(subconfig,
username,
uri,
)
async def on_join(self, async def on_join(self,
risotto_context, risotto_context,
): ):

View File

@ -4,7 +4,6 @@ try:
except: except:
from tiramisu import Config from tiramisu import Config
from tiramisu.error import ValueOptionError from tiramisu.error import ValueOptionError
from asyncio import get_event_loop, ensure_future
from traceback import print_exc from traceback import print_exc
from copy import copy from copy import copy
from typing import Dict, Callable, List, Optional from typing import Dict, Callable, List, Optional
@ -16,6 +15,7 @@ from .logger import log
from .config import get_config from .config import get_config
from .context import Context from .context import Context
from . import register from . import register
from .remote import Remote
class CallDispatcher: class CallDispatcher:
@ -79,7 +79,7 @@ class CallDispatcher:
""" execute the function associate with specified uri """ execute the function associate with specified uri
arguments are validate before arguments are validate before
""" """
risotto_context = self.build_new_context(old_risotto_context.__dict__, risotto_context = self.build_new_context(old_risotto_context,
version, version,
message, message,
'rpc', 'rpc',
@ -88,35 +88,20 @@ class CallDispatcher:
raise CallError(_(f'cannot find version of message "{version}"')) raise CallError(_(f'cannot find version of message "{version}"'))
if message not in self.messages[version]: if message not in self.messages[version]:
raise CallError(_(f'cannot find message "{version}.{message}"')) raise CallError(_(f'cannot find message "{version}.{message}"'))
function_obj = self.messages[version][message] function_objs = [self.messages[version][message]]
# do not start a new database connection # do not start a new database connection
if hasattr(old_risotto_context, 'connection'): if hasattr(old_risotto_context, 'connection'):
risotto_context.connection = old_risotto_context.connection risotto_context.connection = old_risotto_context.connection
await self.check_message_type(risotto_context, return await self.launch(version,
kwargs, message,
) risotto_context,
config_arguments = await self.load_kwargs_to_config(risotto_context, check_role,
f'{version}.{message}',
kwargs,
check_role,
internal,
)
return await self.launch(risotto_context,
kwargs, kwargs,
config_arguments, function_objs,
function_obj, internal,
) )
else: else:
try: try:
await self.check_message_type(risotto_context,
kwargs,
)
config_arguments = await self.load_kwargs_to_config(risotto_context,
f'{version}.{message}',
kwargs,
check_role,
internal,
)
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
await connection.set_type_codec( await connection.set_type_codec(
'json', 'json',
@ -126,10 +111,13 @@ class CallDispatcher:
) )
risotto_context.connection = connection risotto_context.connection = connection
async with connection.transaction(): async with connection.transaction():
return await self.launch(risotto_context, return await self.launch(version,
message,
risotto_context,
check_role,
kwargs, kwargs,
config_arguments, function_objs,
function_obj, internal,
) )
except CallError as err: except CallError as err:
raise err raise err
@ -151,80 +139,66 @@ class CallDispatcher:
class PublishDispatcher: class PublishDispatcher:
async def register_remote(self) -> None:
print()
print(_('======== Registered remote event ========'))
self.listened_connection = await self.pool.acquire()
for version, messages in self.messages.items():
for message, message_infos in messages.items():
# event not emit locally
if message_infos['pattern'] == 'event' and 'functions' in message_infos and message_infos['functions']:
# module, submodule, submessage = message.split('.', 2)
# if f'{module}.{submodule}' not in self.injected_self:
uri = f'{version}.{message}'
print(f' - {uri}')
await self.listened_connection.add_listener(uri,
self.to_async_publish,
)
async def publish(self, async def publish(self,
version: str, version: str,
message: str, message: str,
risotto_context: Context, old_risotto_context: Context,
check_role: bool=False,
internal: bool=True,
**kwargs, **kwargs,
) -> None: ) -> None:
if version not in self.messages or message not in self.messages[version]: risotto_context = self.build_new_context(old_risotto_context,
raise ValueError(_(f'cannot find URI "{version}.{message}"'))
# publish to remote
remote_kw = dumps({'kwargs': kwargs,
'context': {'username': risotto_context.username,
'paths': risotto_context.paths,
}
})
# FIXME should be better :/
remote_kw = remote_kw.replace("'", "''")
await risotto_context.connection.execute(f'NOTIFY "{version}.{message}", \'{remote_kw}\'')
def to_async_publish(self,
con: 'asyncpg.connection.Connection',
pid: int,
uri: str,
payload: str,
) -> None:
version, message = uri.split('.', 1)
loop = get_event_loop()
remote_kw = loads(payload)
risotto_context = self.build_new_context(remote_kw['context'],
version, version,
message, message,
'event', 'event',
) )
callback = lambda: ensure_future(self._publish(version, try:
message, function_objs = self.messages[version][message].get('functions', [])
risotto_context, except KeyError:
**remote_kw['kwargs'], raise ValueError(_(f'cannot find message {version}.{message}'))
)) # do not start a new database connection
loop.call_soon(callback) if hasattr(old_risotto_context, 'connection'):
# publish to remove
async def _publish(self, remote_kw = dumps({'kwargs': kwargs,
version: str, 'context': risotto_context.__dict__,
message: str, })
risotto_context: Context, risotto_context.connection = old_risotto_context.connection
**kwargs, # FIXME should be better :/
) -> None: remote_kw = remote_kw.replace("'", "''")
config_arguments = await self.load_kwargs_to_config(risotto_context, await risotto_context.connection.execute(f'NOTIFY "{version}.{message}", \'{remote_kw}\'')
f'{version}.{message}', return await self.launch(version,
kwargs, message,
False, risotto_context,
False, check_role,
) kwargs,
for function_obj in self.messages[version][message]['functions']: function_objs,
async with self.pool.acquire() as connection: internal,
try: )
await self.check_message_type(risotto_context, async with self.pool.acquire() as connection:
kwargs, try:
) await connection.set_type_codec(
'json',
encoder=dumps,
decoder=loads,
schema='pg_catalog'
)
risotto_context.connection = connection
async with connection.transaction():
return await self.launch(version,
message,
risotto_context,
check_role,
kwargs,
function_objs,
internal,
)
except CallError as err:
pass
except Exception as err:
# if there is a problem with arguments, log and do nothing
if get_config()['global']['debug']:
print_exc()
async with self.pool.acquire() as connection:
await connection.set_type_codec( await connection.set_type_codec(
'json', 'json',
encoder=dumps, encoder=dumps,
@ -233,37 +207,18 @@ class PublishDispatcher:
) )
risotto_context.connection = connection risotto_context.connection = connection
async with connection.transaction(): async with connection.transaction():
await self.launch(risotto_context, await log.error_msg(risotto_context, kwargs, err)
kwargs,
config_arguments,
function_obj,
)
except CallError as err:
pass
except Exception as err:
# if there is a problem with arguments, log and do nothing
if get_config()['global']['debug']:
print_exc()
async with self.pool.acquire() as connection:
await connection.set_type_codec(
'json',
encoder=dumps,
decoder=loads,
schema='pg_catalog'
)
risotto_context.connection = connection
async with connection.transaction():
await log.error_msg(risotto_context, kwargs, err)
class Dispatcher(register.RegisterDispatcher, class Dispatcher(register.RegisterDispatcher,
Remote,
CallDispatcher, CallDispatcher,
PublishDispatcher): PublishDispatcher):
""" Manage message (call or publish) """ Manage message (call or publish)
so launch a function when a message is called so launch a function when a message is called
""" """
def build_new_context(self, def build_new_context(self,
context: dict, old_risotto_context: Context,
version: str, version: str,
message: str, message: str,
type: str, type: str,
@ -272,8 +227,8 @@ class Dispatcher(register.RegisterDispatcher,
""" """
uri = version + '.' + message uri = version + '.' + message
risotto_context = Context() risotto_context = Context()
risotto_context.username = context['username'] risotto_context.username = old_risotto_context.username
risotto_context.paths = copy(context['paths']) risotto_context.paths = copy(old_risotto_context.paths)
risotto_context.paths.append(uri) risotto_context.paths.append(uri)
risotto_context.uri = uri risotto_context.uri = uri
risotto_context.type = type risotto_context.type = type
@ -333,7 +288,7 @@ class Dispatcher(register.RegisterDispatcher,
parameters = await subconfig.value.dict() parameters = await subconfig.value.dict()
if extra_parameters: if extra_parameters:
parameters.update(extra_parameters) parameters.update(extra_parameters)
return parameters return parameters
def get_service(self, def get_service(self,
name: str): name: str):
@ -342,15 +297,14 @@ class Dispatcher(register.RegisterDispatcher,
async def check_role(self, async def check_role(self,
config: Config, config: Config,
user_login: str, user_login: str,
uri: str, uri: str) -> None:
) -> None:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
async with connection.transaction(): async with connection.transaction():
# Verify if user exists and get ID # Verify if user exists and get ID
sql = ''' sql = '''
SELECT UserId SELECT UserId
FROM UserUser FROM UserUser
WHERE Login = $1 WHERE UserLogin = $1
''' '''
user_id = await connection.fetchval(sql, user_id = await connection.fetchval(sql,
user_login) user_login)
@ -388,55 +342,65 @@ class Dispatcher(register.RegisterDispatcher,
raise NotAllowedError(_(f'You ({user_login}) don\'t have any authorisation to access to "{uri}"')) raise NotAllowedError(_(f'You ({user_login}) don\'t have any authorisation to access to "{uri}"'))
async def launch(self, async def launch(self,
version: str,
message: str,
risotto_context: Context, risotto_context: Context,
check_role: bool,
kwargs: Dict, kwargs: Dict,
config_arguments: dict, function_objs: List,
function_obj: Callable, internal: bool,
) -> Optional[Dict]: ) -> Optional[Dict]:
# so send the message await self.check_message_type(risotto_context,
function = function_obj['function'] kwargs)
risotto_context.module = function_obj['module'].split('.', 1)[0] config_arguments = await self.load_kwargs_to_config(risotto_context,
function_name = function.__name__ f'{version}.{message}',
info_msg = _(f"in function {function_obj['full_module_name']}.{function_name}") kwargs,
# build argument for this function check_role,
if risotto_context.type == 'rpc': internal,
kw = config_arguments )
else: # config is ok, so send the message
kw = {} for function_obj in function_objs:
for key, value in config_arguments.items(): function = function_obj['function']
if key in function_obj['arguments']: submodule_name = function_obj['module']
kw[key] = value function_name = function.__name__
risotto_context.module = submodule_name.split('.', 1)[0]
kw['risotto_context'] = risotto_context info_msg = _(f'in module {submodule_name}.{function_name}')
returns = await function(self.get_service(function_obj['module']), **kw) # build argument for this function
if risotto_context.type == 'rpc': if risotto_context.type == 'rpc':
# valid returns kw = config_arguments
await self.valid_call_returns(risotto_context,
function,
returns,
kwargs,
)
# log the success
await log.info_msg(risotto_context,
{'arguments': kwargs,
'returns': returns},
info_msg,
)
# notification
if function_obj.get('notification'):
notif_version, notif_message = function_obj['notification'].split('.', 1)
if not isinstance(returns, list):
send_returns = [returns]
else: else:
send_returns = returns kw = {}
for ret in send_returns: for key, value in config_arguments.items():
await self.publish(notif_version, if key in function_obj['arguments']:
notif_message, kw[key] = value
risotto_context,
**ret, kw['risotto_context'] = risotto_context
) returns = await function(self.injected_self[function_obj['module']], **kw)
if risotto_context.type == 'rpc': if risotto_context.type == 'rpc':
return returns # valid returns
await self.valid_call_returns(risotto_context,
function,
returns,
kwargs)
# log the success
await log.info_msg(risotto_context,
{'arguments': kwargs,
'returns': returns},
info_msg)
# notification
if function_obj.get('notification'):
notif_version, notif_message = function_obj['notification'].split('.', 1)
if not isinstance(returns, list):
send_returns = [returns]
else:
send_returns = returns
for ret in send_returns:
await self.publish(notif_version,
notif_message,
risotto_context,
**ret)
if risotto_context.type == 'rpc':
return returns
dispatcher = Dispatcher() dispatcher = Dispatcher()

View File

@ -29,8 +29,7 @@ def create_context(request):
def register(version: str, def register(version: str,
path: str, path: str):
):
""" Decorator to register function to the http route """ Decorator to register function to the http route
""" """
def decorator(function): def decorator(function):
@ -42,9 +41,7 @@ def register(version: str,
class extra_route_handler: class extra_route_handler:
async def __new__(cls, async def __new__(cls, request):
request,
):
kwargs = dict(request.match_info) kwargs = dict(request.match_info)
kwargs['request'] = request kwargs['request'] = request
kwargs['risotto_context'] = create_context(request) kwargs['risotto_context'] = create_context(request)
@ -99,13 +96,11 @@ async def handle(request):
print_exc() print_exc()
raise HTTPInternalServerError(reason=str(err)) raise HTTPInternalServerError(reason=str(err))
return Response(text=dumps({'response': text}), return Response(text=dumps({'response': text}),
content_type='application/json', content_type='application/json')
)
async def api(request, async def api(request,
risotto_context, risotto_context):
):
global TIRAMISU global TIRAMISU
if not TIRAMISU: if not TIRAMISU:
# check all URI that have an associated role # check all URI that have an associated role
@ -157,8 +152,7 @@ async def get_app(loop):
for version in versions: for version in versions:
api_route = {'function': api, api_route = {'function': api,
'version': version, 'version': version,
'path': f'/api/{version}', 'path': f'/api/{version}'}
}
extra_handler = type(api_route['path'], (extra_route_handler,), api_route) extra_handler = type(api_route['path'], (extra_route_handler,), api_route)
routes.append(get(api_route['path'], extra_handler)) routes.append(get(api_route['path'], extra_handler))
print(f' - {api_route["path"]} (http_get)') print(f' - {api_route["path"]} (http_get)')
@ -180,10 +174,7 @@ async def get_app(loop):
await dispatcher.register_remote() await dispatcher.register_remote()
print() print()
await dispatcher.on_join() await dispatcher.on_join()
return await loop.create_server(app.make_handler(), return await loop.create_server(app.make_handler(), '*', get_config()['http_server']['port'])
'*',
get_config()['http_server']['port'],
)
TIRAMISU = None TIRAMISU = None

View File

@ -7,7 +7,6 @@ from typing import Callable, Optional, List
from asyncpg import create_pool from asyncpg import create_pool
from json import dumps, loads from json import dumps, loads
from pkg_resources import iter_entry_points from pkg_resources import iter_entry_points
from traceback import print_exc
import risotto import risotto
from .utils import _ from .utils import _
from .error import RegistrationError from .error import RegistrationError
@ -24,7 +23,7 @@ class Services():
def load_services(self): def load_services(self):
for entry_point in iter_entry_points(group='risotto_services'): for entry_point in iter_entry_points(group='risotto_services'):
self.services.setdefault(entry_point.name, {}) self.services.setdefault(entry_point.name, [])
self.services_loaded = True self.services_loaded = True
def load_modules(self, def load_modules(self,
@ -33,20 +32,21 @@ class Services():
for entry_point in iter_entry_points(group='risotto_modules'): for entry_point in iter_entry_points(group='risotto_modules'):
service_name, module_name = entry_point.name.split('.') service_name, module_name = entry_point.name.split('.')
if limit_services is None or service_name in limit_services: if limit_services is None or service_name in limit_services:
self.services[service_name][module_name] = entry_point.load() setattr(self, module_name, entry_point.load())
self.services[service_name].append(module_name)
self.modules_loaded = True self.modules_loaded = True
#
# def get_services(self): def get_services(self):
# if not self.services_loaded: if not self.services_loaded:
# self.load_services() self.load_services()
# return [(service, getattr(self, service)) for service in self.services] return [(s, getattr(self, s)) for s in self.services]
def get_modules(self, def get_modules(self,
limit_services: Optional[List[str]]=None, limit_services: Optional[List[str]]=None,
) -> List[str]: ) -> List[str]:
if not self.modules_loaded: if not self.modules_loaded:
self.load_modules(limit_services=limit_services) self.load_modules(limit_services=limit_services)
return [(module + '.' + submodule, entry_point) for module, submodules in self.services.items() for submodule, entry_point in submodules.items()] return [(module + '.' + submodule, getattr(self, submodule)) for module, submodules in self.services.items() for submodule in submodules]
def get_services_list(self): def get_services_list(self):
return self.services.keys() return self.services.keys()
@ -199,8 +199,7 @@ class RegisterDispatcher:
raise RegistrationError(_(f'the message {message} not exists')) raise RegistrationError(_(f'the message {message} not exists'))
# xxx submodule can only be register with v1.yyy.xxx..... message # xxx submodule can only be register with v1.yyy.xxx..... message
full_module_name = function.__module__ risotto_module_name, submodule_name = function.__module__.split('.')[-3:-1]
risotto_module_name, submodule_name = full_module_name.split('.')[-3:-1]
module_name = risotto_module_name.split('_')[-1] module_name = risotto_module_name.split('_')[-1]
message_module, message_submodule, message_name = message.split('.', 2) message_module, message_submodule, message_name = message.split('.', 2)
if message_module not in self.risotto_modules: if message_module not in self.risotto_modules:
@ -225,7 +224,6 @@ class RegisterDispatcher:
register(version, register(version,
message, message,
f'{module_name}.{submodule_name}', f'{module_name}.{submodule_name}',
full_module_name,
function, function,
function_args, function_args,
notification, notification,
@ -235,13 +233,11 @@ class RegisterDispatcher:
version: str, version: str,
message: str, message: str,
module_name: str, module_name: str,
full_module_name: str,
function: Callable, function: Callable,
function_args: list, function_args: list,
notification: Optional[str], notification: Optional[str],
): ):
self.messages[version][message]['module'] = module_name self.messages[version][message]['module'] = module_name
self.messages[version][message]['full_module_name'] = full_module_name
self.messages[version][message]['function'] = function self.messages[version][message]['function'] = function
self.messages[version][message]['arguments'] = function_args self.messages[version][message]['arguments'] = function_args
if notification: if notification:
@ -251,7 +247,6 @@ class RegisterDispatcher:
version: str, version: str,
message: str, message: str,
module_name: str, module_name: str,
full_module_name: str,
function: Callable, function: Callable,
function_args: list, function_args: list,
notification: Optional[str], notification: Optional[str],
@ -260,10 +255,8 @@ class RegisterDispatcher:
self.messages[version][message]['functions'] = [] self.messages[version][message]['functions'] = []
dico = {'module': module_name, dico = {'module': module_name,
'full_module_name': full_module_name,
'function': function, 'function': function,
'arguments': function_args, 'arguments': function_args}
}
if notification and notification: if notification and notification:
dico['notification'] = notification dico['notification'] = notification
self.messages[version][message]['functions'].append(dico) self.messages[version][message]['functions'].append(dico)
@ -278,7 +271,7 @@ class RegisterDispatcher:
try: try:
self.injected_self[submodule_name] = module.Risotto(test) self.injected_self[submodule_name] = module.Risotto(test)
except AttributeError as err: except AttributeError as err:
print(_(f'unable to register the module {submodule_name}, this module must have Risotto class')) raise RegistrationError(_(f'unable to register the module {submodule_name}, this module must have Risotto class'))
def validate(self): def validate(self):
""" check if all messages have a function """ check if all messages have a function
@ -307,7 +300,7 @@ class RegisterDispatcher:
) )
if truncate: if truncate:
async with connection.transaction(): async with connection.transaction():
await connection.execute('TRUNCATE InfraServer, InfraSite, InfraZone, Log, ProviderDeployment, ProviderFactoryCluster, ProviderFactoryClusterNode, SettingApplicationservice, SettingApplicationServiceDependency, SettingRelease, SettingServer, SettingServermodel, SettingSource, UserRole, UserRoleURI, UserURI, UserUser, InfraServermodel, ProviderZone, ProviderServer, ProviderSource, ProviderApplicationservice, ProviderServermodel') await connection.execute('TRUNCATE InfraServer, InfraSite, InfraZone, Log, ProviderDeployment, ProviderFactoryCluster, ProviderFactoryClusterNode, SettingApplicationservice, SettingApplicationServiceDependency, SettingRelease, SettingServer, SettingServermodel, SettingSource, UserRole, UserRoleURI, UserURI, UserUser, InfraServermodel, ProviderZone, ProviderServer, ProviderSource, ProviderApplicationservice ProviderServermodel')
async with connection.transaction(): async with connection.transaction():
for submodule_name, module in self.injected_self.items(): for submodule_name, module in self.injected_self.items():
risotto_context = Context() risotto_context = Context()
@ -316,17 +309,11 @@ class RegisterDispatcher:
risotto_context.type = None risotto_context.type = None
risotto_context.connection = connection risotto_context.connection = connection
risotto_context.module = submodule_name.split('.', 1)[0] risotto_context.module = submodule_name.split('.', 1)[0]
info_msg = _(f'in function risotto_{submodule_name}.on_join') info_msg = _(f'in module risotto_{submodule_name}.on_join')
await log.info_msg(risotto_context, await log.info_msg(risotto_context,
None, None,
info_msg) info_msg)
try: await module.on_join(risotto_context)
await module.on_join(risotto_context)
except Exception as err:
if get_config()['global']['debug']:
print_exc()
msg = _(f'on_join returns an error in module {submodule_name}: {err}')
await log.error_msg(risotto_context, {}, msg)
async def load(self): async def load(self):
# valid function's arguments # valid function's arguments

42
src/risotto/remote.py Normal file
View File

@ -0,0 +1,42 @@
from asyncio import get_event_loop, ensure_future
from json import loads
from .context import Context
from .config import get_config
from .utils import _
class Remote:
async def register_remote(self) -> None:
print()
print(_('======== Registered remote event ========'))
self.listened_connection = await self.pool.acquire()
for version, messages in self.messages.items():
for message, message_infos in messages.items():
# event not emit locally
if message_infos['pattern'] == 'event':
module, submodule, submessage = message.split('.', 2)
if f'{module}.{submodule}' not in self.injected_self:
uri = f'{version}.{message}'
print(f' - {uri}')
await self.listened_connection.add_listener(uri, self.to_async_publish)
def to_async_publish(self,
con: 'asyncpg.connection.Connection',
pid: int,
uri: str,
payload: str,
) -> None:
version, message = uri.split('.', 1)
loop = get_event_loop()
remote_kw = loads(payload)
context = Context()
for key, value in remote_kw['context'].items():
setattr(context, key, value)
callback = lambda: ensure_future(self.publish(version,
message,
context,
**remote_kw['kwargs'],
))
loop.call_soon(callback)

View File

@ -392,6 +392,7 @@ async def test_server_created_base():
release_distribution='last', release_distribution='last',
site_name='site_1', site_name='site_1',
zones_name=['zones'], zones_name=['zones'],
zones_ip=['1.1.1.1'],
) )
assert list(config_module.server) == [server_name] assert list(config_module.server) == [server_name]
assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'} assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'}
@ -419,6 +420,7 @@ async def test_server_created_own_sm():
release_distribution='last', release_distribution='last',
site_name='site_1', site_name='site_1',
zones_name=['zones'], zones_name=['zones'],
zones_ip=['1.1.1.1'],
) )
assert list(config_module.server) == [server_name] assert list(config_module.server) == [server_name]
assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'} assert set(config_module.server[server_name]) == {'server', 'server_to_deploy', 'funcs_file'}
@ -467,6 +469,7 @@ async def test_server_configuration_get():
release_distribution='last', release_distribution='last',
site_name='site_1', site_name='site_1',
zones_name=['zones'], zones_name=['zones'],
zones_ip=['1.1.1.1'],
) )
# #
await config_module.server[server_name]['server'].property.read_write() await config_module.server[server_name]['server'].property.read_write()
@ -512,6 +515,7 @@ async def test_server_configuration_deployed():
release_distribution='last', release_distribution='last',
site_name='site_1', site_name='site_1',
zones_name=['zones'], zones_name=['zones'],
zones_ip=['1.1.1.1'],
) )
# #
await config_module.server[server_name]['server'].property.read_write() await config_module.server[server_name]['server'].property.read_write()