Compare commits

...

6 Commits

3 changed files with 37 additions and 4 deletions

View File

@ -48,6 +48,31 @@ class Controller:
**kwargs, **kwargs,
) )
@staticmethod
async def check_role(self,
uri: str,
username: str,
**kwargs: dict,
) -> None:
# create a new config
async with await Config(dispatcher.option) as config:
await config.property.read_write()
await config.option('message').value.set(uri)
subconfig = config.option(uri)
for key, value in kwargs.items():
try:
await subconfig.option(key).value.set(value)
except AttributeError:
if get_config()['global']['debug']:
print_exc()
raise ValueError(_(f'unknown parameter in "{uri}": "{key}"'))
except ValueOptionError as err:
raise ValueError(_(f'invalid parameter in "{uri}": {err}'))
await dispatcher.check_role(subconfig,
username,
uri,
)
async def on_join(self, async def on_join(self,
risotto_context, risotto_context,
): ):

View File

@ -333,7 +333,7 @@ class Dispatcher(register.RegisterDispatcher,
parameters = await subconfig.value.dict() parameters = await subconfig.value.dict()
if extra_parameters: if extra_parameters:
parameters.update(extra_parameters) parameters.update(extra_parameters)
return parameters return parameters
def get_service(self, def get_service(self,
name: str): name: str):
@ -342,7 +342,8 @@ class Dispatcher(register.RegisterDispatcher,
async def check_role(self, async def check_role(self,
config: Config, config: Config,
user_login: str, user_login: str,
uri: str) -> None: uri: str,
) -> None:
async with self.pool.acquire() as connection: async with self.pool.acquire() as connection:
async with connection.transaction(): async with connection.transaction():
# Verify if user exists and get ID # Verify if user exists and get ID

View File

@ -7,6 +7,7 @@ from typing import Callable, Optional, List
from asyncpg import create_pool from asyncpg import create_pool
from json import dumps, loads from json import dumps, loads
from pkg_resources import iter_entry_points from pkg_resources import iter_entry_points
from traceback import print_exc
import risotto import risotto
from .utils import _ from .utils import _
from .error import RegistrationError from .error import RegistrationError
@ -277,7 +278,7 @@ class RegisterDispatcher:
try: try:
self.injected_self[submodule_name] = module.Risotto(test) self.injected_self[submodule_name] = module.Risotto(test)
except AttributeError as err: except AttributeError as err:
raise RegistrationError(_(f'unable to register the module {submodule_name}, this module must have Risotto class')) print(_(f'unable to register the module {submodule_name}, this module must have Risotto class'))
def validate(self): def validate(self):
""" check if all messages have a function """ check if all messages have a function
@ -319,7 +320,13 @@ class RegisterDispatcher:
await log.info_msg(risotto_context, await log.info_msg(risotto_context,
None, None,
info_msg) info_msg)
await module.on_join(risotto_context) try:
await module.on_join(risotto_context)
except Exception as err:
if get_config()['global']['debug']:
print_exc()
msg = _(f'on_join returns an error in module {submodule_name}: {err}')
await log.error_msg(risotto_context, {}, msg)
async def load(self): async def load(self):
# valid function's arguments # valid function's arguments