diff --git a/script/database_manager.py b/script/database_manager.py index 8fa8db9..944d96f 100644 --- a/script/database_manager.py +++ b/script/database_manager.py @@ -52,6 +52,36 @@ CREATE TABLE Server ( ServerServermodelId INTEGER NOT NULL ); +-- User, Role and ACL table creation + +CREATE TABLE RisottoUser ( + UserId SERIAL PRIMARY KEY, + UserLogin VARCHAR(100) NOT NULL UNIQUE, + UserName VARCHAR(100) NOT NULL, + UserSurname VARCHAR(100) NOT NULL +); + +CREATE TABLE UserRole ( + RoleId SERIAL PRIMARY KEY, + RoleUserId INTEGER NOT NULL, + RoleName VARCHAR(255) NOT NULL, + RoleAttribute VARCHAR(255), + RoleAttributeValue VARCHAR(255), + FOREIGN KEY (RoleUserId) REFERENCES RisottoUser(UserId) +); + +CREATE TABLE URI ( + URIId SERIAL PRIMARY KEY, + URIName VARCHAR(255) NOT NULL UNIQUE +); + +CREATE TABLE RoleURI ( + RoleName VARCHAR(255) NOT NULL, + URIId INTEGER NOT NULL, + FOREIGN KEY (URIId) REFERENCES URI(URIId), + PRIMARY KEY (RoleName, URIId) +); + """ async def main(): diff --git a/src/risotto/config.py b/src/risotto/config.py index 65ec8ff..75108c8 100644 --- a/src/risotto/config.py +++ b/src/risotto/config.py @@ -25,6 +25,7 @@ def get_config(): 'global': {'message_root_path': CURRENT_PATH.parents[2] / 'messages', 'debug': DEBUG, 'internal_user': 'internal', + 'check_role': False, 'rougail_dtd_path': '../rougail/data/creole.dtd'}, 'source': {'root_path': '/srv/seed'}, 'cache': {'root_path': '/var/cache/risotto'} diff --git a/src/risotto/dispatcher.py b/src/risotto/dispatcher.py index 3414015..532a156 100644 --- a/src/risotto/dispatcher.py +++ b/src/risotto/dispatcher.py @@ -15,15 +15,6 @@ import asyncpg class CallDispatcher: - def valid_public_function(self, - risotto_context: Context, - kwargs: Dict, - public_only: bool): - if public_only and not self.messages[risotto_context.version][risotto_context.message]['public']: - msg = _(f'the message {risotto_context.message} is private') - log.error_msg(risotto_context, kwargs, msg) - raise NotAllowedError(msg) - async def valid_call_returns(self, risotto_context: Context, returns: Dict, @@ -77,7 +68,7 @@ class CallDispatcher: version: str, message: str, old_risotto_context: Context, - public_only: bool=False, + check_role: bool=False, **kwargs): """ execute the function associate with specified uri arguments are validate before @@ -86,16 +77,14 @@ class CallDispatcher: version, message, 'rpc') - self.valid_public_function(risotto_context, - kwargs, - public_only) self.check_message_type(risotto_context, kwargs) try: - tiramisu_config = await self.load_kwargs_to_config(risotto_context, - kwargs) + kw = await self.load_kwargs_to_config(risotto_context, + f'{version}.{message}', + kwargs, + check_role) function_obj = self.messages[version][message] - kw = await tiramisu_config.option(message).value.dict() risotto_context.function = function_obj['function'] if function_obj['risotto_context']: kw['risotto_context'] = risotto_context @@ -150,7 +139,12 @@ class CallDispatcher: class PublishDispatcher: - async def publish(self, version, message, old_risotto_context, public_only=False, **kwargs): + async def publish(self, + version: str, + message: str, + old_risotto_context: Context, + check_role: bool=False, + **kwargs) -> None: risotto_context = self.build_new_context(old_risotto_context, version, message, @@ -158,9 +152,8 @@ class PublishDispatcher: self.check_message_type(risotto_context, kwargs) try: - config = await self.load_kwargs_to_config(risotto_context, - kwargs) - config_arguments = await config.option(message).value.dict() + config_arguments = await self.load_kwargs_to_config(risotto_context, + kwargs) except CallError as err: return except Exception as err: @@ -250,7 +243,9 @@ class Dispatcher(register.RegisterDispatcher, CallDispatcher, PublishDispatcher) async def load_kwargs_to_config(self, risotto_context: Context, - kwargs: Dict): + uri: str, + kwargs: Dict, + check_role: bool): """ create a new Config et set values to it """ # create a new config @@ -266,20 +261,73 @@ class Dispatcher(register.RegisterDispatcher, CallDispatcher, PublishDispatcher) except AttributeError: if DEBUG: print_exc() - raise ValueError(_(f'unknown parameter "{key}"')) + raise ValueError(_(f'unknown parameter in "{uri}": "{key}"')) # check mandatories options + if check_role and get_config().get('global').get('check_role'): + await self.check_role(subconfig, + risotto_context.username, + uri) await config.property.read_only() mandatories = await config.value.mandatory() if mandatories: mand = [mand.split('.')[-1] for mand in mandatories] - raise ValueError(_(f'missing parameters: {mand}')) - # return the config - return config + raise ValueError(_(f'missing parameters in "{uri}": {mand}')) + # return complete an validated kwargs + return await subconfig.value.dict() def get_service(self, name: str): return self.injected_self[name] + async def check_role(self, + config: Config, + user_login: str, + uri: str) -> None: + db_conf = get_config().get('database') + pool = await asyncpg.create_pool(database=db_conf.get('dbname'), user=db_conf.get('user')) + async with pool.acquire() as connection: + async with connection.transaction(): + # Verify if user exists and get ID + sql = ''' + SELECT UserId + FROM RisottoUser + WHERE UserLogin = $1 + ''' + user_id = await connection.fetchval(sql, + user_login) + if user_id is None: + raise NotAllowedError(_(f"You ({user_login}) don't have any account in this application")) + + # Get all references for this message + refs = {} + for option in await config.list('all'): + ref = await option.information.get('ref') + if ref: + refs[ref] = str(await option.value.get()) + + # Check role + select_role_uri = ''' + SELECT RoleName + FROM URI, RoleURI + WHERE URI.URIName = $1 AND RoleURI.URIId = URI.URIId + ''' + select_role_user = ''' + SELECT RoleAttribute, RoleAttributeValue + FROM UserRole + WHERE RoleUserId = $1 AND RoleName = $2 + ''' + for uri_role in await connection.fetch(select_role_uri, + uri): + for user_role in await connection.fetch(select_role_user, + user_id, + uri_role['rolename']): + if not user_role['roleattribute']: + return + if user_role['roleattribute'] in refs and \ + user_role['roleattributevalue'] == refs[user_role['roleattribute']]: + return + raise NotAllowedError(_(f'You ({user_login}) don\'t have any authorisation to access to "{uri}"')) + dispatcher = Dispatcher() register.dispatcher = dispatcher diff --git a/src/risotto/http.py b/src/risotto/http.py index c7a5b54..f94726e 100644 --- a/src/risotto/http.py +++ b/src/risotto/http.py @@ -73,7 +73,7 @@ async def handle(request): text = await method(version, message, risotto_context, - public_only=True, + check_role=True, **kwargs) except NotAllowedError as err: raise HTTPNotFound(reason=str(err)) diff --git a/src/risotto/message/message.py b/src/risotto/message/message.py index 6e523bb..4871179 100644 --- a/src/risotto/message/message.py +++ b/src/risotto/message/message.py @@ -440,16 +440,19 @@ def _get_option(name, kwargs['multi'] = True type_ = type_[2:] if type_ == 'Dict': - return DictOption(**kwargs) + obj = DictOption(**kwargs) elif type_ == 'String': - return StrOption(**kwargs) + obj = StrOption(**kwargs) elif type_ == 'Any': - return AnyOption(**kwargs) + obj = AnyOption(**kwargs) elif 'Number' in type_ or type_ == 'ID' or type_ == 'Integer': - return IntOption(**kwargs) + obj = IntOption(**kwargs) elif type_ == 'Boolean': - return BoolOption(**kwargs) - raise Exception('unsupported type {} in {}'.format(type_, file_path)) + obj = BoolOption(**kwargs) + else: + raise Exception('unsupported type {} in {}'.format(type_, file_path)) + obj.impl_set_information('ref', arg.ref) + return obj def _parse_args(message_def, @@ -463,11 +466,15 @@ def _parse_args(message_def, """ new_options = OrderedDict() for name, arg in message_def.parameters.items(): - new_options[name] = arg - if arg.ref: - needs.setdefault(message_def.uri, {}).setdefault(arg.ref, []).append(name) - for name, arg in new_options.items(): - current_opt = _get_option(name, arg, file_path, select_option, optiondescription) + #new_options[name] = arg + # if arg.ref: + # needs.setdefault(message_def.uri, {}).setdefault(arg.ref, []).append(name) + #for name, arg in new_options.items(): + current_opt = _get_option(name, + arg, + file_path, + select_option, + optiondescription) options.append(current_opt) if hasattr(arg, 'shortarg') and arg.shortarg and load_shortarg: options.append(SymLinkOption(arg.shortarg, current_opt)) diff --git a/src/risotto/register.py b/src/risotto/register.py index 4290691..9089d8d 100644 --- a/src/risotto/register.py +++ b/src/risotto/register.py @@ -1,12 +1,13 @@ from tiramisu import Config from inspect import signature from typing import Callable, Optional +import asyncpg from .utils import undefined, _ from .error import RegistrationError from .message import get_messages from .context import Context -from .config import INTERNAL_USER +from .config import INTERNAL_USER, get_config def register(uris: str, @@ -248,23 +249,38 @@ class RegisterDispatcher: risotto_context.type = None await module.on_join(risotto_context) + async def insert_message(self, + connection, + uri): + sql = """INSERT INTO URI(URIName) VALUES ($1) + ON CONFLICT (URIName) DO NOTHING + """ + await connection.fetchval(sql, + uri) + async def load(self): # valid function's arguments - for version, messages in self.messages.items(): - for message, message_infos in messages.items(): - if message_infos['pattern'] == 'rpc': - module_name = message_infos['module'] - function = message_infos['function'] - await self.valid_rpc_params(version, - message, - function, - module_name) - else: - if 'functions' in message_infos: - for function_infos in message_infos['functions']: - module_name = function_infos['module'] - function = function_infos['function'] - await self.valid_event_params(version, - message, - function, - module_name) + db_conf = get_config().get('database') + pool = await asyncpg.create_pool(database=db_conf.get('dbname'), user=db_conf.get('user')) + async with pool.acquire() as connection: + async with connection.transaction(): + for version, messages in self.messages.items(): + for message, message_infos in messages.items(): + if message_infos['pattern'] == 'rpc': + module_name = message_infos['module'] + function = message_infos['function'] + await self.valid_rpc_params(version, + message, + function, + module_name) + else: + if 'functions' in message_infos: + for function_infos in message_infos['functions']: + module_name = function_infos['module'] + function = function_infos['function'] + await self.valid_event_params(version, + message, + function, + module_name) + await self.insert_message(connection, + f'{version}.{message}') diff --git a/src/risotto/services/applicationservice/applicationservice.py b/src/risotto/services/applicationservice/applicationservice.py index 9d6ad1e..664a08e 100644 --- a/src/risotto/services/applicationservice/applicationservice.py +++ b/src/risotto/services/applicationservice/applicationservice.py @@ -12,7 +12,8 @@ from ...context import Context from ...utils import _ class Risotto(Controller): - def __init__(self): + def __init__(self, + test: bool) -> None: self.source_root_path = get_config().get('source').get('root_path') async def _applicationservice_create(self, diff --git a/src/risotto/services/servermodel/servermodel.py b/src/risotto/services/servermodel/servermodel.py index 844d0c7..4b5b66f 100644 --- a/src/risotto/services/servermodel/servermodel.py +++ b/src/risotto/services/servermodel/servermodel.py @@ -16,7 +16,8 @@ from ...logger import log class Risotto(Controller): - def __init__(self): + def __init__(self, + test: bool) -> None: self.source_root_path = get_config().get('source').get('root_path') self.cache_root_path = join(get_config().get('cache').get('root_path'), 'servermodel') if not isdir(self.cache_root_path): diff --git a/src/risotto/services/template/template.py b/src/risotto/services/template/template.py index 3c61c26..63efaaf 100644 --- a/src/risotto/services/template/template.py +++ b/src/risotto/services/template/template.py @@ -12,7 +12,8 @@ from ...utils import _ class Risotto(Controller): - def __init__(self): + def __init__(self, + test: bool) -> None: self.storage = Storage(engine='dictionary') self.cache_root_path = join(get_config().get('cache').get('root_path'), 'servermodel')