Compare commits

..

No commits in common. "c63170be1d047325b32a36bde3fbd8ad48832652" and "2a985757900209181fd5a4de3934c7986c535ce3" have entirely different histories.

4 changed files with 176 additions and 191 deletions

View File

@ -3,11 +3,3 @@ class Context:
self.paths = [] self.paths = []
self.context_id = None self.context_id = None
self.start_id = None self.start_id = None
def copy(self):
context = Context()
for key, value in self.__dict__.items():
if key.startswith('__'):
continue
setattr(context, key, value)
return context

View File

@ -89,6 +89,7 @@ class CallDispatcher:
if hasattr(old_risotto_context, 'connection'): if hasattr(old_risotto_context, 'connection'):
# do not start a new database connection # do not start a new database connection
risotto_context.connection = old_risotto_context.connection risotto_context.connection = old_risotto_context.connection
risotto_context.log_connection = old_risotto_context.log_connection
await log.start(risotto_context, await log.start(risotto_context,
kwargs, kwargs,
info_msg, info_msg,
@ -118,58 +119,66 @@ class CallDispatcher:
raise CallError(err) from err raise CallError(err) from err
else: else:
error = None error = None
try: async with self.pool.acquire() as log_connection:
async with self.pool.acquire() as connection: await log_connection.set_type_codec(
await connection.set_type_codec( 'json',
'json', encoder=dumps,
encoder=dumps, decoder=loads,
decoder=loads, schema='pg_catalog'
schema='pg_catalog' )
) risotto_context.log_connection = log_connection
risotto_context.connection = connection try:
async with connection.transaction(): async with self.pool.acquire() as connection:
try: await connection.set_type_codec(
await log.start(risotto_context, 'json',
kwargs, encoder=dumps,
info_msg, decoder=loads,
) schema='pg_catalog'
await self.check_message_type(risotto_context, )
kwargs, risotto_context.connection = connection
) async with connection.transaction():
config_arguments = await self.load_kwargs_to_config(risotto_context, try:
f'{version}.{message}', await log.start(risotto_context,
kwargs, kwargs,
check_role, info_msg,
internal, )
) await self.check_message_type(risotto_context,
ret = await self.launch(risotto_context, kwargs,
kwargs, )
config_arguments, config_arguments = await self.load_kwargs_to_config(risotto_context,
function_obj, f'{version}.{message}',
) kwargs,
# log the success check_role,
await log.success(risotto_context, internal,
ret, )
) ret = await self.launch(risotto_context,
if not internal and isinstance(ret, dict): kwargs,
ret['context_id'] = risotto_context.context_id config_arguments,
except CallError as err: function_obj,
if get_config()['global']['debug']: )
print_exc() # log the success
await log.failed(risotto_context, await log.success(risotto_context,
str(err), ret,
) )
raise err from err if not internal and isinstance(ret, dict):
except CallError as err: ret['context_id'] = risotto_context.context_id
error = err except CallError as err:
except Exception as err: if get_config()['global']['debug']:
# if there is a problem with arguments, just send an error and do nothing print_exc()
if get_config()['global']['debug']: await log.failed(risotto_context,
print_exc() str(err),
await log.failed(risotto_context, )
str(err), raise err from err
) except CallError as err:
error = err error = err
except Exception as err:
# if there is a problem with arguments, just send an error and do nothing
if get_config()['global']['debug']:
print_exc()
await log.failed(risotto_context,
str(err),
)
error = err
if error: if error:
if not internal: if not internal:
err = CallError(str(error)) err = CallError(str(error))
@ -189,6 +198,8 @@ class PublishDispatcher:
for message, message_infos in messages.items(): for message, message_infos in messages.items():
# event not emit locally # event not emit locally
if message_infos['pattern'] == 'event' and 'functions' in message_infos and message_infos['functions']: if message_infos['pattern'] == 'event' and 'functions' in message_infos and message_infos['functions']:
# module, submodule, submessage = message.split('.', 2)
# if f'{module}.{submodule}' not in self.injected_self:
uri = f'{version}.{message}' uri = f'{version}.{message}'
print(f' - {uri}') print(f' - {uri}')
await self.listened_connection.add_listener(uri, await self.listened_connection.add_listener(uri,
@ -224,34 +235,21 @@ class PublishDispatcher:
version, message = uri.split('.', 1) version, message = uri.split('.', 1)
loop = get_event_loop() loop = get_event_loop()
remote_kw = loads(payload) remote_kw = loads(payload)
for function_obj in self.messages[version][message]['functions']: risotto_context = self.build_new_context(remote_kw['context'],
risotto_context = self.build_new_context(remote_kw['context'], version,
version, message,
message, 'event',
'event', )
) callback = lambda: ensure_future(self._publish(version,
callback = self.get_callback(version, message, function_obj, risotto_context, remote_kw['kwargs'],) message,
loop.call_soon(callback) risotto_context,
**remote_kw['kwargs'],
def get_callback(self, ))
version, loop.call_soon(callback)
message,
function_obj,
risotto_context,
kwargs,
):
return lambda: ensure_future(self._publish(version,
message,
function_obj,
risotto_context,
**kwargs,
))
async def _publish(self, async def _publish(self,
version: str, version: str,
message: str, message: str,
function_obj,
risotto_context: Context, risotto_context: Context,
**kwargs, **kwargs,
) -> None: ) -> None:
@ -261,48 +259,66 @@ class PublishDispatcher:
False, False,
False, False,
) )
async with self.pool.acquire() as connection: for function_obj in self.messages[version][message]['functions']:
await connection.set_type_codec( async with self.pool.acquire() as log_connection:
'json', await log_connection.set_type_codec(
encoder=dumps, 'json',
decoder=loads, encoder=dumps,
schema='pg_catalog' decoder=loads,
) schema='pg_catalog'
risotto_context.connection = connection )
function_name = function_obj['function'].__name__ risotto_context.log_connection = log_connection
info_msg = _(f"call function {function_obj['full_module_name']}.{function_name}") async with self.pool.acquire() as connection:
try: await connection.set_type_codec(
async with connection.transaction(): 'json',
encoder=dumps,
decoder=loads,
schema='pg_catalog'
)
risotto_context.connection = connection
function_name = function_obj['function'].__name__
info_msg = _(f"call function {function_obj['full_module_name']}.{function_name}")
try: try:
await log.start(risotto_context, async with connection.transaction():
kwargs, try:
info_msg, await log.start(risotto_context,
) kwargs,
await self.check_message_type(risotto_context, info_msg,
kwargs, )
) await self.check_message_type(risotto_context,
await self.launch(risotto_context, kwargs,
kwargs, )
config_arguments, await self.launch(risotto_context,
function_obj, kwargs,
) config_arguments,
# log the success function_obj,
await log.success(risotto_context) )
except CallError as err: # log the success
await log.success(risotto_context)
except CallError as err:
if get_config()['global']['debug']:
print_exc()
await log.failed(risotto_context,
str(err),
)
except CallError:
pass
except Exception as err:
# if there is a problem with arguments, log and do nothing
if get_config()['global']['debug']: if get_config()['global']['debug']:
print_exc() print_exc()
await log.failed(risotto_context, async with self.pool.acquire() as connection:
str(err), await connection.set_type_codec(
) 'json',
except CallError: encoder=dumps,
pass decoder=loads,
except Exception as err: schema='pg_catalog'
# if there is a problem with arguments, log and do nothing )
if get_config()['global']['debug']: risotto_context.connection = connection
print_exc() async with connection.transaction():
await log.failed(risotto_context, await log.failed(risotto_context,
str(err), str(err),
) )
class Dispatcher(register.RegisterDispatcher, class Dispatcher(register.RegisterDispatcher,
@ -330,7 +346,6 @@ class Dispatcher(register.RegisterDispatcher,
risotto_context.type = type risotto_context.type = type
risotto_context.message = message risotto_context.message = message
risotto_context.version = version risotto_context.version = version
risotto_context.pool = self.pool
return risotto_context return risotto_context
async def check_message_type(self, async def check_message_type(self,

View File

@ -2,34 +2,15 @@ from typing import Dict, Any, Optional
from json import dumps, loads from json import dumps, loads
from asyncpg.exceptions import UndefinedTableError from asyncpg.exceptions import UndefinedTableError
from datetime import datetime from datetime import datetime
from asyncio import Lock
from .context import Context from .context import Context
from .utils import _ from .utils import _
from .config import get_config from .config import get_config
database_lock = Lock()
class Logger: class Logger:
""" An object to manager log """ An object to manager log
""" """
def __init__(self) -> None:
self.log_connection = None
async def get_connection(self,
risotto_context: Context,
):
if not self.log_connection:
self.log_connection = await risotto_context.pool.acquire()
await self.log_connection.set_type_codec(
'json',
encoder=dumps,
decoder=loads,
schema='pg_catalog'
)
return self.log_connection
async def insert(self, async def insert(self,
msg: str, msg: str,
risotto_context: Context, risotto_context: Context,
@ -57,9 +38,8 @@ class Logger:
sql = insert + ') ' + values + ') RETURNING LogId' sql = insert + ') ' + values + ') RETURNING LogId'
try: try:
async with database_lock: async with risotto_context.log_connection.transaction():
connection = await self.get_connection(risotto_context) log_id = await risotto_context.log_connection.fetchval(sql, *args)
log_id = await connection.fetchval(sql, *args)
if context_id is None and start: if context_id is None and start:
risotto_context.context_id = log_id risotto_context.context_id = log_id
if start: if start:
@ -81,9 +61,8 @@ class Logger:
sql += ' AND URI = $3' sql += ' AND URI = $3'
args.append(uri) args.append(uri)
ret = [] ret = []
async with database_lock: async with risotto_context.log_connection.transaction():
connection = await self.get_connection(risotto_context) for row in await risotto_context.log_connection.fetch(*args):
for row in await connection.fetch(*args):
d = {} d = {}
for key, value in row.items(): for key, value in row.items():
if key in ['kwargs', 'returns']: if key in ['kwargs', 'returns']:
@ -194,12 +173,11 @@ class Logger:
args.append(dumps(returns)) args.append(dumps(returns))
sql += """WHERE LogId = $1 sql += """WHERE LogId = $1
""" """
async with database_lock: async with risotto_context.log_connection.transaction():
connection = await self.get_connection(risotto_context) await risotto_context.log_connection.execute(sql,
await connection.execute(sql, risotto_context.start_id,
risotto_context.start_id, *args,
*args, )
)
async def failed(self, async def failed(self,
risotto_context: Context, risotto_context: Context,
@ -218,13 +196,12 @@ class Logger:
Msg = $3 Msg = $3
WHERE LogId = $1 WHERE LogId = $1
""" """
async with database_lock: async with risotto_context.log_connection.transaction():
connection = await self.get_connection(risotto_context) await risotto_context.log_connection.execute(sql,
await connection.execute(sql, risotto_context.start_id,
risotto_context.start_id, datetime.now(),
datetime.now(), err,
err, )
)
async def info(self, async def info(self,
risotto_context, risotto_context,

View File

@ -297,36 +297,37 @@ class RegisterDispatcher:
truncate: bool=False, truncate: bool=False,
) -> None: ) -> None:
internal_user = get_config()['global']['internal_user'] internal_user = get_config()['global']['internal_user']
async with self.pool.acquire() as connection: async with self.pool.acquire() as log_connection:
await connection.set_type_codec( async with self.pool.acquire() as connection:
'json', await connection.set_type_codec(
encoder=dumps, 'json',
decoder=loads, encoder=dumps,
schema='pg_catalog' decoder=loads,
) schema='pg_catalog'
if truncate: )
if truncate:
async with connection.transaction():
await connection.execute('TRUNCATE InfraServer, InfraSite, InfraZone, Log, ProviderDeployment, ProviderFactoryCluster, ProviderFactoryClusterNode, SettingApplicationservice, SettingApplicationServiceDependency, SettingRelease, SettingServer, SettingServermodel, SettingSource, UserRole, UserRoleURI, UserURI, UserUser, InfraServermodel, ProviderZone, ProviderServer, ProviderSource, ProviderApplicationservice, ProviderServermodel')
async with connection.transaction(): async with connection.transaction():
await connection.execute('TRUNCATE InfraServer, InfraSite, InfraZone, Log, ProviderDeployment, ProviderFactoryCluster, ProviderFactoryClusterNode, SettingApplicationservice, SettingApplicationServiceDependency, SettingRelease, SettingServer, SettingServermodel, SettingSource, UserRole, UserRoleURI, UserURI, UserUser, InfraServermodel, ProviderZone, ProviderServer, ProviderSource, ProviderApplicationservice, ProviderServermodel') for submodule_name, module in self.injected_self.items():
async with connection.transaction(): risotto_context = Context()
for submodule_name, module in self.injected_self.items(): risotto_context.username = internal_user
risotto_context = Context() risotto_context.paths.append(f'internal.{submodule_name}.on_join')
risotto_context.username = internal_user risotto_context.type = None
risotto_context.paths.append(f'internal.{submodule_name}.on_join') risotto_context.log_connection = log_connection
risotto_context.type = None risotto_context.connection = connection
risotto_context.pool = self.pool risotto_context.module = submodule_name.split('.', 1)[0]
risotto_context.connection = connection info_msg = _(f'in function risotto_{submodule_name}.on_join')
risotto_context.module = submodule_name.split('.', 1)[0] await log.info_msg(risotto_context,
info_msg = _(f'in function risotto_{submodule_name}.on_join') None,
await log.info_msg(risotto_context, info_msg)
None, try:
info_msg) await module.on_join(risotto_context)
try: except Exception as err:
await module.on_join(risotto_context) if get_config()['global']['debug']:
except Exception as err: print_exc()
if get_config()['global']['debug']: msg = _(f'on_join returns an error in module {submodule_name}: {err}')
print_exc() await log.error_msg(risotto_context, {}, msg)
msg = _(f'on_join returns an error in module {submodule_name}: {err}')
await log.error_msg(risotto_context, {}, msg)
async def load(self): async def load(self):
# valid function's arguments # valid function's arguments