add remote support

This commit is contained in:
Emmanuel Garette 2020-09-12 16:05:17 +02:00
parent 3823eedd02
commit e664dd6174
6 changed files with 242 additions and 177 deletions

View File

@ -1,21 +1,81 @@
from os import environ from os import environ
from os.path import isfile
from configobj import ConfigObj
CONFIGURATION_DIR = environ.get('CONFIGURATION_DIR', '/srv/risotto/configurations') CONFIG_FILE = environ.get('CONFIG_FILE', '/etc/risotto/risotto.conf')
PROVIDER_FACTORY_CONFIG_DIR = environ.get('PROVIDER_FACTORY_CONFIG_DIR', '/srv/factory')
TMP_DIR = '/tmp'
DEFAULT_USER = environ.get('DEFAULT_USER', 'Anonymous') if isfile(CONFIG_FILE):
RISOTTO_DB_NAME = environ.get('RISOTTO_DB_NAME', 'risotto') config = ConfigObj(CONFIG_FILE)
RISOTTO_DB_PASSWORD = environ.get('RISOTTO_DB_PASSWORD', 'risotto') else:
RISOTTO_DB_USER = environ.get('RISOTTO_DB_USER', 'risotto') config = {}
TIRAMISU_DB_NAME = environ.get('TIRAMISU_DB_NAME', 'tiramisu')
TIRAMISU_DB_PASSWORD = environ.get('TIRAMISU_DB_PASSWORD', 'tiramisu')
TIRAMISU_DB_USER = environ.get('TIRAMISU_DB_USER', 'tiramisu') if 'RISOTTO_PORT' in environ:
DB_ADDRESS = environ.get('DB_ADDRESS', 'localhost') RISOTTO_PORT = environ['RISOTTO_PORT']
MESSAGE_PATH = environ.get('MESSAGE_PATH', '/root/risotto-message/messages') else:
SQL_DIR = environ.get('SQL_DIR', './sql') RISOTTO_PORT = config.get('RISOTTO_PORT', 8080)
CACHE_ROOT_PATH = environ.get('CACHE_ROOT_PATH', '/var/cache/risotto') if 'CONFIGURATION_DIR' in environ:
SRV_SEED_PATH = environ.get('SRV_SEED_PATH', '/srv/seed') CONFIGURATION_DIR = environ['CONFIGURATION_DIR']
else:
CONFIGURATION_DIR = config.get('CONFIGURATION_DIR', '/srv/risotto/configurations')
if 'PROVIDER_FACTORY_CONFIG_DIR' in environ:
PROVIDER_FACTORY_CONFIG_DIR = environ['PROVIDER_FACTORY_CONFIG_DIR']
else:
PROVIDER_FACTORY_CONFIG_DIR = config.get('PROVIDER_FACTORY_CONFIG_DIR', '/srv/factory')
if 'DEFAULT_USER' in environ:
DEFAULT_USER = environ['DEFAULT_USER']
else:
DEFAULT_USER = config.get('DEFAULT_USER', 'Anonymous')
if 'RISOTTO_DB_NAME' in environ:
RISOTTO_DB_NAME = environ['RISOTTO_DB_NAME']
else:
RISOTTO_DB_NAME = config.get('RISOTTO_DB_NAME', 'risotto')
if 'RISOTTO_DB_PASSWORD' in environ:
RISOTTO_DB_PASSWORD = environ['RISOTTO_DB_PASSWORD']
else:
RISOTTO_DB_PASSWORD = config.get('RISOTTO_DB_PASSWORD', 'risotto')
if 'RISOTTO_DB_USER' in environ:
RISOTTO_DB_USER = environ['RISOTTO_DB_USER']
else:
RISOTTO_DB_USER = config.get('RISOTTO_DB_USER', 'risotto')
if 'TIRAMISU_DB_NAME' in environ:
TIRAMISU_DB_NAME = environ['TIRAMISU_DB_NAME']
else:
TIRAMISU_DB_NAME = config.get('TIRAMISU_DB_NAME', 'tiramisu')
if 'TIRAMISU_DB_PASSWORD' in environ:
TIRAMISU_DB_PASSWORD = environ['TIRAMISU_DB_PASSWORD']
else:
TIRAMISU_DB_PASSWORD = config.get('TIRAMISU_DB_PASSWORD', 'tiramisu')
if 'TIRAMISU_DB_USER' in environ:
TIRAMISU_DB_USER = environ['TIRAMISU_DB_USER']
else:
TIRAMISU_DB_USER = config.get('TIRAMISU_DB_USER', 'tiramisu')
if 'DB_ADDRESS' in environ:
DB_ADDRESS = environ['DB_ADDRESS']
else:
DB_ADDRESS = config.get('DB_ADDRESS', 'localhost')
if 'MESSAGE_PATH' in environ:
MESSAGE_PATH = environ['MESSAGE_PATH']
else:
MESSAGE_PATH = config.get('MESSAGE_PATH', '/root/risotto-message/messages')
if 'SQL_DIR' in environ:
SQL_DIR = environ['SQL_DIR']
else:
SQL_DIR = config.get('SQL_DIR', './sql')
if 'CACHE_ROOT_PATH' in environ:
CACHE_ROOT_PATH = environ['CACHE_ROOT_PATH']
else:
CACHE_ROOT_PATH = config.get('CACHE_ROOT_PATH', '/var/cache/risotto')
if 'SRV_SEED_PATH' in environ:
SRV_SEED_PATH = environ['SRV_SEED_PATH']
else:
SRV_SEED_PATH = config.get('SRV_SEED_PATH', '/srv/seed')
if 'TMP_DIR' in environ:
TMP_DIR = environ['TMP_DIR']
else:
TMP_DIR = config.get('TMP_DIR', '/tmp')
def dsn_factory(database, user, password, address=DB_ADDRESS): def dsn_factory(database, user, password, address=DB_ADDRESS):
@ -23,19 +83,20 @@ def dsn_factory(database, user, password, address=DB_ADDRESS):
return f'postgres:///{database}?host={mangled_address}/&user={user}&password={password}' return f'postgres:///{database}?host={mangled_address}/&user={user}&password={password}'
def get_config(): _config = {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD),
return {'database': {'dsn': dsn_factory(RISOTTO_DB_NAME, RISOTTO_DB_USER, RISOTTO_DB_PASSWORD),
'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD), 'tiramisu_dsn': dsn_factory(TIRAMISU_DB_NAME, TIRAMISU_DB_USER, TIRAMISU_DB_PASSWORD),
}, },
'http_server': {'port': 8080, 'http_server': {'port': RISOTTO_PORT,
'default_user': DEFAULT_USER}, 'default_user': DEFAULT_USER},
'global': {'message_root_path': MESSAGE_PATH, 'global': {'message_root_path': MESSAGE_PATH,
'configurations_dir': CONFIGURATION_DIR, 'configurations_dir': CONFIGURATION_DIR,
'debug': True, 'debug': True,
'internal_user': 'internal', 'internal_user': '_internal',
'check_role': True, 'check_role': True,
'admin_user': DEFAULT_USER, 'admin_user': DEFAULT_USER,
'sql_dir': SQL_DIR}, 'sql_dir': SQL_DIR,
'tmp_dir': TMP_DIR,
},
'cache': {'root_path': CACHE_ROOT_PATH}, 'cache': {'root_path': CACHE_ROOT_PATH},
'servermodel': {'internal_source_path': SRV_SEED_PATH, 'servermodel': {'internal_source_path': SRV_SEED_PATH,
'internal_source': 'internal'}, 'internal_source': 'internal'},
@ -44,3 +105,7 @@ def get_config():
'provider': {'factory_configuration_dir': PROVIDER_FACTORY_CONFIG_DIR, 'provider': {'factory_configuration_dir': PROVIDER_FACTORY_CONFIG_DIR,
'factory_configuration_filename': 'infra.json'}, 'factory_configuration_filename': 'infra.json'},
} }
def get_config():
return _config

View File

@ -1,8 +1,5 @@
from .config import get_config
from .dispatcher import dispatcher from .dispatcher import dispatcher
from .context import Context from .context import Context
from .remote import remote
from . import services
from .utils import _ from .utils import _
@ -10,50 +7,48 @@ class Controller:
"""Common controller used to add a service in Risotto """Common controller used to add a service in Risotto
""" """
def __init__(self, def __init__(self,
test: bool): test: bool,
self.risotto_modules = services.get_services_list() ):
pass
async def call(self, async def call(self,
uri: str, uri: str,
risotto_context: Context, risotto_context: Context,
*args, *args,
**kwargs): **kwargs,
):
""" a wrapper to dispatcher's call""" """ a wrapper to dispatcher's call"""
version, module, message = uri.split('.', 2)
uri = module + '.' + message
if args: if args:
raise ValueError(_(f'the URI "{uri}" can only be called with keyword arguments')) raise ValueError(_(f'the URI "{uri}" can only be called with keyword arguments'))
if module not in self.risotto_modules: current_uri = risotto_context.paths[-1]
return await remote.remote_call(module, current_module = risotto_context.module
version, version, message = uri.split('.', 1)
message, module = message.split('.', 1)[0]
kwargs) if current_module != module:
raise ValueError(_(f'cannot call to external module ("{module}") to the URI "{uri}" from "{current_module}"'))
return await dispatcher.call(version, return await dispatcher.call(version,
uri, message,
risotto_context, risotto_context,
**kwargs) **kwargs,
)
async def publish(self, async def publish(self,
uri: str, uri: str,
risotto_context: Context, risotto_context: Context,
*args, *args,
**kwargs): **kwargs,
):
""" a wrapper to dispatcher's publish""" """ a wrapper to dispatcher's publish"""
version, module, submessage = uri.split('.', 2)
version, message = uri.split('.', 1) version, message = uri.split('.', 1)
if args: if args:
raise ValueError(_(f'the URI "{uri}" can only be published with keyword arguments')) raise ValueError(_(f'the URI "{uri}" can only be published with keyword arguments'))
if module not in self.risotto_modules:
await remote.remote_call(module,
version,
submessage,
kwargs)
else:
await dispatcher.publish(version, await dispatcher.publish(version,
message, message,
risotto_context, risotto_context,
**kwargs) **kwargs,
)
async def on_join(self, async def on_join(self,
risotto_context): risotto_context,
):
pass pass

View File

@ -15,8 +15,7 @@ from .logger import log
from .config import get_config from .config import get_config
from .context import Context from .context import Context
from . import register from . import register
#from .remote import Remote from .remote import Remote
import asyncpg
class CallDispatcher: class CallDispatcher:
@ -83,7 +82,8 @@ class CallDispatcher:
risotto_context = self.build_new_context(old_risotto_context, risotto_context = self.build_new_context(old_risotto_context,
version, version,
message, message,
'rpc') 'rpc',
)
if version not in self.messages: if version not in self.messages:
raise CallError(_(f'cannot find version of message "{version}"')) raise CallError(_(f'cannot find version of message "{version}"'))
if message not in self.messages[version]: if message not in self.messages[version]:
@ -150,14 +150,20 @@ class PublishDispatcher:
risotto_context = self.build_new_context(old_risotto_context, risotto_context = self.build_new_context(old_risotto_context,
version, version,
message, message,
'event') 'event',
)
try: try:
function_objs = self.messages[version][message].get('functions', []) function_objs = self.messages[version][message].get('functions', [])
except KeyError: except KeyError:
raise ValueError(_(f'cannot find message {version}.{message}')) raise ValueError(_(f'cannot find message {version}.{message}'))
# do not start a new database connection # do not start a new database connection
if hasattr(old_risotto_context, 'connection'): if hasattr(old_risotto_context, 'connection'):
# publish to remove
remote_kw = dumps({'kwargs': kwargs,
'context': risotto_context.__dict__,
})
risotto_context.connection = old_risotto_context.connection risotto_context.connection = old_risotto_context.connection
await risotto_context.connection.execute(f'NOTIFY "{version}.{message}", \'{remote_kw}\'')
return await self.launch(version, return await self.launch(version,
message, message,
risotto_context, risotto_context,
@ -166,8 +172,8 @@ class PublishDispatcher:
function_objs, function_objs,
internal, internal,
) )
try:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
try:
await connection.set_type_codec( await connection.set_type_codec(
'json', 'json',
encoder=dumps, encoder=dumps,
@ -185,9 +191,9 @@ class PublishDispatcher:
internal, internal,
) )
except CallError as err: except CallError as err:
raise err pass
except Exception as err: except Exception as err:
# if there is a problem with arguments, just send an error and do nothing # if there is a problem with arguments, log and do nothing
if get_config()['global']['debug']: if get_config()['global']['debug']:
print_exc() print_exc()
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
@ -200,11 +206,10 @@ class PublishDispatcher:
risotto_context.connection = connection risotto_context.connection = connection
async with connection.transaction(): async with connection.transaction():
await log.error_msg(risotto_context, kwargs, err) await log.error_msg(risotto_context, kwargs, err)
raise err
class Dispatcher(register.RegisterDispatcher, class Dispatcher(register.RegisterDispatcher,
# Remote, Remote,
CallDispatcher, CallDispatcher,
PublishDispatcher): PublishDispatcher):
""" Manage message (call or publish) """ Manage message (call or publish)
@ -214,7 +219,8 @@ class Dispatcher(register.RegisterDispatcher,
old_risotto_context: Context, old_risotto_context: Context,
version: str, version: str,
message: str, message: str,
type: str): type: str,
) -> Context:
""" This is a new call or a new publish, so create a new context """ This is a new call or a new publish, so create a new context
""" """
uri = version + '.' + message uri = version + '.' + message
@ -230,7 +236,8 @@ class Dispatcher(register.RegisterDispatcher,
async def check_message_type(self, async def check_message_type(self,
risotto_context: Context, risotto_context: Context,
kwargs: Dict): kwargs: Dict,
) -> None:
if self.messages[risotto_context.version][risotto_context.message]['pattern'] != risotto_context.type: if self.messages[risotto_context.version][risotto_context.message]['pattern'] != risotto_context.type:
msg = _(f'{risotto_context.uri} is not a {risotto_context.type} message') msg = _(f'{risotto_context.uri} is not a {risotto_context.type} message')
await log.error_msg(risotto_context, kwargs, msg) await log.error_msg(risotto_context, kwargs, msg)
@ -352,9 +359,10 @@ class Dispatcher(register.RegisterDispatcher,
# config is ok, so send the message # config is ok, so send the message
for function_obj in function_objs: for function_obj in function_objs:
function = function_obj['function'] function = function_obj['function']
module_name = function.__module__.split('.')[-2] submodule_name = function_obj['module']
function_name = function.__name__ function_name = function.__name__
info_msg = _(f'in module {module_name}.{function_name}') risotto_context.module = submodule_name.split('.', 1)[0]
info_msg = _(f'in module {submodule_name}.{function_name}')
# build argument for this function # build argument for this function
if risotto_context.type == 'rpc': if risotto_context.type == 'rpc':
kw = config_arguments kw = config_arguments

View File

@ -51,8 +51,9 @@ class extra_route_handler:
function_name = cls.function.__module__ function_name = cls.function.__module__
# if not 'api' function # if not 'api' function
if function_name != 'risotto.http': if function_name != 'risotto.http':
module_name = function_name.split('.')[-2] risotto_module_name, submodule_name = function_name.split('.', 2)[:-1]
kwargs['self'] = dispatcher.injected_self[module_name] module_name = risotto_module_name.split('_')[-1]
kwargs['self'] = dispatcher.injected_self[module_name + '.' + submodule_name]
try: try:
returns = await cls.function(**kwargs) returns = await cls.function(**kwargs)
except NotAllowedError as err: except NotAllowedError as err:
@ -141,9 +142,9 @@ async def get_app(loop):
versions.append(version) versions.append(version)
print() print()
print(_('======== Registered messages ========')) print(_('======== Registered messages ========'))
for message in messages: for message, message_infos in messages.items():
web_message = f'/api/{version}/{message}' web_message = f'/api/{version}/{message}'
pattern = dispatcher.messages[version][message]['pattern'] pattern = message_infos['pattern']
print(f' - {web_message} ({pattern})') print(f' - {web_message} ({pattern})')
routes.append(post(web_message, handle)) routes.append(post(web_message, handle))
print() print()
@ -168,9 +169,10 @@ async def get_app(loop):
extra_handler = type(path, (extra_route_handler,), extra) extra_handler = type(path, (extra_route_handler,), extra)
routes.append(get(path, extra_handler)) routes.append(get(path, extra_handler))
print(f' - {path} (http_get)') print(f' - {path} (http_get)')
print()
del extra_routes del extra_routes
app.router.add_routes(routes) app.router.add_routes(routes)
await dispatcher.register_remote()
print()
await dispatcher.on_join() await dispatcher.on_join()
return await loop.create_server(app.make_handler(), '*', get_config()['http_server']['port']) return await loop.create_server(app.make_handler(), '*', get_config()['http_server']['port'])

View File

@ -4,8 +4,9 @@ except:
from tiramisu import Config from tiramisu import Config
from inspect import signature from inspect import signature
from typing import Callable, Optional, List from typing import Callable, Optional, List
import asyncpg from asyncpg import create_pool
from json import dumps, loads from json import dumps, loads
from pkg_resources import iter_entry_points
import risotto import risotto
from .utils import _ from .utils import _
from .error import RegistrationError from .error import RegistrationError
@ -13,7 +14,7 @@ from .message import get_messages
from .context import Context from .context import Context
from .config import get_config from .config import get_config
from .logger import log from .logger import log
from pkg_resources import iter_entry_points
class Services(): class Services():
services = {} services = {}
@ -45,7 +46,7 @@ class Services():
) -> List[str]: ) -> List[str]:
if not self.modules_loaded: if not self.modules_loaded:
self.load_modules(limit_services=limit_services) self.load_modules(limit_services=limit_services)
return [(m, getattr(self, m)) for s in self.services.values() for m in s] return [(module + '.' + submodule, getattr(self, submodule)) for module, submodules in self.services.items() for submodule in submodules]
def get_services_list(self): def get_services_list(self):
return self.services.keys() return self.services.keys()
@ -59,10 +60,11 @@ class Services():
test: bool=False, test: bool=False,
limit_services: Optional[List[str]]=None, limit_services: Optional[List[str]]=None,
): ):
for module_name, module in self.get_modules(limit_services=limit_services): for submodule_name, module in self.get_modules(limit_services=limit_services):
dispatcher.set_module(module_name, dispatcher.set_module(submodule_name,
module, module,
test) test,
)
if validate: if validate:
dispatcher.validate() dispatcher.validate()
@ -73,7 +75,8 @@ setattr(risotto, 'services', services)
def register(uris: str, def register(uris: str,
notification: str=None): notification: str=None,
) -> None:
""" Decorator to register function to the dispatcher """ Decorator to register function to the dispatcher
""" """
if not isinstance(uris, list): if not isinstance(uris, list):
@ -185,7 +188,8 @@ class RegisterDispatcher:
version: str, version: str,
message: str, message: str,
notification: str, notification: str,
function: Callable): function: Callable,
):
""" register a function to an URI """ register a function to an URI
URI is a message URI is a message
""" """
@ -194,14 +198,16 @@ class RegisterDispatcher:
if message not in self.messages[version]: if message not in self.messages[version]:
raise RegistrationError(_(f'the message {message} not exists')) raise RegistrationError(_(f'the message {message} not exists'))
# xxx module can only be register with v1.xxxx..... message # xxx submodule can only be register with v1.yyy.xxx..... message
module_name = function.__module__.split('.')[-2] risotto_module_name, submodule_name = function.__module__.split('.')[-3:-1]
message_namespace = message.split('.', 1)[0] module_name = risotto_module_name.split('_')[-1]
message_risotto_module, message_namespace, message_name = message.split('.', 2) message_module, message_submodule, message_name = message.split('.', 2)
if message_risotto_module not in self.risotto_modules: if message_module not in self.risotto_modules:
raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"')) raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"'))
if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name: if self.messages[version][message]['pattern'] == 'rpc' and \
raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"')) module_name != message_module and \
message_submodule != submodule_name:
raise RegistrationError(_(f'cannot registered the "{message}" message in submodule "{module_name}.{submodule_name}"'))
# True if first argument is the risotto_context # True if first argument is the risotto_context
function_args = self.get_function_args(function) function_args = self.get_function_args(function)
@ -217,10 +223,11 @@ class RegisterDispatcher:
register = self.register_event register = self.register_event
register(version, register(version,
message, message,
module_name, f'{module_name}.{submodule_name}',
function, function,
function_args, function_args,
notification) notification,
)
def register_rpc(self, def register_rpc(self,
version: str, version: str,
@ -228,7 +235,8 @@ class RegisterDispatcher:
module_name: str, module_name: str,
function: Callable, function: Callable,
function_args: list, function_args: list,
notification: Optional[str]): notification: Optional[str],
):
self.messages[version][message]['module'] = module_name self.messages[version][message]['module'] = module_name
self.messages[version][message]['function'] = function self.messages[version][message]['function'] = function
self.messages[version][message]['arguments'] = function_args self.messages[version][message]['arguments'] = function_args
@ -241,7 +249,8 @@ class RegisterDispatcher:
module_name: str, module_name: str,
function: Callable, function: Callable,
function_args: list, function_args: list,
notification: Optional[str]): notification: Optional[str],
):
if 'functions' not in self.messages[version][message]: if 'functions' not in self.messages[version][message]:
self.messages[version][message]['functions'] = [] self.messages[version][message]['functions'] = []
@ -252,13 +261,17 @@ class RegisterDispatcher:
dico['notification'] = notification dico['notification'] = notification
self.messages[version][message]['functions'].append(dico) self.messages[version][message]['functions'].append(dico)
def set_module(self, module_name, module, test): def set_module(self,
submodule_name,
module,
test,
):
""" register and instanciate a new module """ register and instanciate a new module
""" """
try: try:
self.injected_self[module_name] = module.Risotto(test) self.injected_self[submodule_name] = module.Risotto(test)
except AttributeError as err: except AttributeError as err:
raise RegistrationError(_(f'unable to register the module {module_name}, this module must have Risotto class')) raise RegistrationError(_(f'unable to register the module {submodule_name}, this module must have Risotto class'))
def validate(self): def validate(self):
""" check if all messages have a function """ check if all messages have a function
@ -287,15 +300,16 @@ class RegisterDispatcher:
) )
if truncate: if truncate:
async with connection.transaction(): async with connection.transaction():
await connection.execute('TRUNCATE applicationservicedependency, deployment, factoryclusternode, factorycluster, log, release, userrole, risottouser, roleuri, infraserver, settingserver, servermodel, site, source, uri, userrole, zone, applicationservice') await connection.execute('TRUNCATE InfraServer, InfraSite, InfraZone, Log, ProviderDeployment, ProviderFactoryCluster, ProviderFactoryClusterNode, SettingApplicationservice, SettingApplicationServiceDependency, SettingRelease, SettingServer, SettingServermodel, SettingSource, UserRole, UserRoleURI, UserURI, UserUser, InfraServermodel, ProviderZone, ProviderServer, ProviderServermodel')
async with connection.transaction(): async with connection.transaction():
for module_name, module in self.injected_self.items(): for submodule_name, module in self.injected_self.items():
risotto_context = Context() risotto_context = Context()
risotto_context.username = internal_user risotto_context.username = internal_user
risotto_context.paths.append(f'{module_name}.on_join') risotto_context.paths.append(f'internal.{submodule_name}.on_join')
risotto_context.type = None risotto_context.type = None
risotto_context.connection = connection risotto_context.connection = connection
info_msg = _(f'in module {module_name}.on_join') risotto_context.module = submodule_name.split('.', 1)[0]
info_msg = _(f'in module risotto_{submodule_name}.on_join')
await log.info_msg(risotto_context, await log.info_msg(risotto_context,
None, None,
info_msg) info_msg)
@ -304,7 +318,7 @@ class RegisterDispatcher:
async def load(self): async def load(self):
# valid function's arguments # valid function's arguments
db_conf = get_config()['database']['dsn'] db_conf = get_config()['database']['dsn']
self.pool = await asyncpg.create_pool(db_conf) self.pool = await create_pool(db_conf)
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
async with connection.transaction(): async with connection.transaction():
for version, messages in self.messages.items(): for version, messages in self.messages.items():

View File

@ -1,61 +1,42 @@
from aiohttp import ClientSession from asyncio import get_event_loop, ensure_future
from requests import get, post from json import loads
from json import dumps
#from tiramisu_api import Config
from .context import Context
from .config import get_config from .config import get_config
from .utils import _ from .utils import _
#
#
# ALLOW_INSECURE_HTTPS = get_config()['module']['allow_insecure_https']
class Remote: class Remote:
submodules = {} async def register_remote(self) -> None:
print()
print(_('======== Registered remote event ========'))
self.listened_connection = await self.pool.acquire()
for version, messages in self.messages.items():
for message, message_infos in messages.items():
# event not emit locally
if message_infos['pattern'] == 'event':
module, submodule, submessage = message.split('.', 2)
if f'{module}.{submodule}' not in self.injected_self:
uri = f'{version}.{message}'
print(f' - {uri}')
await self.listened_connection.add_listener(uri, self.to_async_publish)
async def _get_config(self, def to_async_publish(self,
module: str, con: 'asyncpg.connection.Connection',
url: str) -> None: pid: int,
if module not in self.submodules: uri: str,
session = ClientSession() payload: str,
async with session.get(url) as resp: ) -> None:
if resp.status != 200: version, message = uri.split('.', 1)
try: loop = get_event_loop()
json = await resp.json() remote_kw = loads(payload)
err = json['error']['kwargs']['reason'] context = Context()
except: for key, value in remote_kw['context'].items():
err = await resp.text() setattr(context, key, value)
raise Exception(err) callback = lambda: ensure_future(self.publish(version,
json = await resp.json() message,
self.submodules[module] = json context,
return Config(self.submodules[module]) **remote_kw['kwargs'],
))
async def remote_call(self, loop.call_soon(callback)
module: str,
version: str,
submessage: str,
payload) -> dict:
try:
domain_name = get_config()['module'][module]
except KeyError:
raise ValueError(_(f'cannot find information of remote module "{module}" to access to "{version}.{module}.{submessage}"'))
remote_url = f'http://{domain_name}:8080/api/{version}'
message_url = f'{remote_url}/{submessage}'
config = await self._get_config(module,
remote_url)
for key, value in payload.items():
path = submessage + '.' + key
config.option(path).value.set(value)
session = ClientSession()
async with session.post(message_url, data=dumps(payload)) as resp:
response = await resp.json()
if 'error' in response:
if 'reason' in response['error']['kwargs']:
raise Exception("{}".format(response['error']['kwargs']['reason']))
raise Exception('erreur inconnue')
return response['response']
remote = Remote()