role support

This commit is contained in:
Emmanuel Garette 2019-12-27 15:09:38 +01:00
parent 50aa8019ab
commit 94168554f2
9 changed files with 164 additions and 59 deletions

View File

@ -52,6 +52,36 @@ CREATE TABLE Server (
ServerServermodelId INTEGER NOT NULL ServerServermodelId INTEGER NOT NULL
); );
-- User, Role and ACL table creation
CREATE TABLE RisottoUser (
UserId SERIAL PRIMARY KEY,
UserLogin VARCHAR(100) NOT NULL UNIQUE,
UserName VARCHAR(100) NOT NULL,
UserSurname VARCHAR(100) NOT NULL
);
CREATE TABLE UserRole (
RoleId SERIAL PRIMARY KEY,
RoleUserId INTEGER NOT NULL,
RoleName VARCHAR(255) NOT NULL,
RoleAttribute VARCHAR(255),
RoleAttributeValue VARCHAR(255),
FOREIGN KEY (RoleUserId) REFERENCES RisottoUser(UserId)
);
CREATE TABLE URI (
URIId SERIAL PRIMARY KEY,
URIName VARCHAR(255) NOT NULL UNIQUE
);
CREATE TABLE RoleURI (
RoleName VARCHAR(255) NOT NULL,
URIId INTEGER NOT NULL,
FOREIGN KEY (URIId) REFERENCES URI(URIId),
PRIMARY KEY (RoleName, URIId)
);
""" """
async def main(): async def main():

View File

@ -25,6 +25,7 @@ def get_config():
'global': {'message_root_path': CURRENT_PATH.parents[2] / 'messages', 'global': {'message_root_path': CURRENT_PATH.parents[2] / 'messages',
'debug': DEBUG, 'debug': DEBUG,
'internal_user': 'internal', 'internal_user': 'internal',
'check_role': False,
'rougail_dtd_path': '../rougail/data/creole.dtd'}, 'rougail_dtd_path': '../rougail/data/creole.dtd'},
'source': {'root_path': '/srv/seed'}, 'source': {'root_path': '/srv/seed'},
'cache': {'root_path': '/var/cache/risotto'} 'cache': {'root_path': '/var/cache/risotto'}

View File

@ -15,15 +15,6 @@ import asyncpg
class CallDispatcher: class CallDispatcher:
def valid_public_function(self,
risotto_context: Context,
kwargs: Dict,
public_only: bool):
if public_only and not self.messages[risotto_context.version][risotto_context.message]['public']:
msg = _(f'the message {risotto_context.message} is private')
log.error_msg(risotto_context, kwargs, msg)
raise NotAllowedError(msg)
async def valid_call_returns(self, async def valid_call_returns(self,
risotto_context: Context, risotto_context: Context,
returns: Dict, returns: Dict,
@ -77,7 +68,7 @@ class CallDispatcher:
version: str, version: str,
message: str, message: str,
old_risotto_context: Context, old_risotto_context: Context,
public_only: bool=False, check_role: bool=False,
**kwargs): **kwargs):
""" execute the function associate with specified uri """ execute the function associate with specified uri
arguments are validate before arguments are validate before
@ -86,16 +77,14 @@ class CallDispatcher:
version, version,
message, message,
'rpc') 'rpc')
self.valid_public_function(risotto_context,
kwargs,
public_only)
self.check_message_type(risotto_context, self.check_message_type(risotto_context,
kwargs) kwargs)
try: try:
tiramisu_config = await self.load_kwargs_to_config(risotto_context, kw = await self.load_kwargs_to_config(risotto_context,
kwargs) f'{version}.{message}',
kwargs,
check_role)
function_obj = self.messages[version][message] function_obj = self.messages[version][message]
kw = await tiramisu_config.option(message).value.dict()
risotto_context.function = function_obj['function'] risotto_context.function = function_obj['function']
if function_obj['risotto_context']: if function_obj['risotto_context']:
kw['risotto_context'] = risotto_context kw['risotto_context'] = risotto_context
@ -150,7 +139,12 @@ class CallDispatcher:
class PublishDispatcher: class PublishDispatcher:
async def publish(self, version, message, old_risotto_context, public_only=False, **kwargs): async def publish(self,
version: str,
message: str,
old_risotto_context: Context,
check_role: bool=False,
**kwargs) -> None:
risotto_context = self.build_new_context(old_risotto_context, risotto_context = self.build_new_context(old_risotto_context,
version, version,
message, message,
@ -158,9 +152,8 @@ class PublishDispatcher:
self.check_message_type(risotto_context, self.check_message_type(risotto_context,
kwargs) kwargs)
try: try:
config = await self.load_kwargs_to_config(risotto_context, config_arguments = await self.load_kwargs_to_config(risotto_context,
kwargs) kwargs)
config_arguments = await config.option(message).value.dict()
except CallError as err: except CallError as err:
return return
except Exception as err: except Exception as err:
@ -250,7 +243,9 @@ class Dispatcher(register.RegisterDispatcher, CallDispatcher, PublishDispatcher)
async def load_kwargs_to_config(self, async def load_kwargs_to_config(self,
risotto_context: Context, risotto_context: Context,
kwargs: Dict): uri: str,
kwargs: Dict,
check_role: bool):
""" create a new Config et set values to it """ create a new Config et set values to it
""" """
# create a new config # create a new config
@ -266,20 +261,73 @@ class Dispatcher(register.RegisterDispatcher, CallDispatcher, PublishDispatcher)
except AttributeError: except AttributeError:
if DEBUG: if DEBUG:
print_exc() print_exc()
raise ValueError(_(f'unknown parameter "{key}"')) raise ValueError(_(f'unknown parameter in "{uri}": "{key}"'))
# check mandatories options # check mandatories options
if check_role and get_config().get('global').get('check_role'):
await self.check_role(subconfig,
risotto_context.username,
uri)
await config.property.read_only() await config.property.read_only()
mandatories = await config.value.mandatory() mandatories = await config.value.mandatory()
if mandatories: if mandatories:
mand = [mand.split('.')[-1] for mand in mandatories] mand = [mand.split('.')[-1] for mand in mandatories]
raise ValueError(_(f'missing parameters: {mand}')) raise ValueError(_(f'missing parameters in "{uri}": {mand}'))
# return the config # return complete an validated kwargs
return config return await subconfig.value.dict()
def get_service(self, def get_service(self,
name: str): name: str):
return self.injected_self[name] return self.injected_self[name]
async def check_role(self,
config: Config,
user_login: str,
uri: str) -> None:
db_conf = get_config().get('database')
pool = await asyncpg.create_pool(database=db_conf.get('dbname'), user=db_conf.get('user'))
async with pool.acquire() as connection:
async with connection.transaction():
# Verify if user exists and get ID
sql = '''
SELECT UserId
FROM RisottoUser
WHERE UserLogin = $1
'''
user_id = await connection.fetchval(sql,
user_login)
if user_id is None:
raise NotAllowedError(_(f"You ({user_login}) don't have any account in this application"))
# Get all references for this message
refs = {}
for option in await config.list('all'):
ref = await option.information.get('ref')
if ref:
refs[ref] = str(await option.value.get())
# Check role
select_role_uri = '''
SELECT RoleName
FROM URI, RoleURI
WHERE URI.URIName = $1 AND RoleURI.URIId = URI.URIId
'''
select_role_user = '''
SELECT RoleAttribute, RoleAttributeValue
FROM UserRole
WHERE RoleUserId = $1 AND RoleName = $2
'''
for uri_role in await connection.fetch(select_role_uri,
uri):
for user_role in await connection.fetch(select_role_user,
user_id,
uri_role['rolename']):
if not user_role['roleattribute']:
return
if user_role['roleattribute'] in refs and \
user_role['roleattributevalue'] == refs[user_role['roleattribute']]:
return
raise NotAllowedError(_(f'You ({user_login}) don\'t have any authorisation to access to "{uri}"'))
dispatcher = Dispatcher() dispatcher = Dispatcher()
register.dispatcher = dispatcher register.dispatcher = dispatcher

View File

@ -73,7 +73,7 @@ async def handle(request):
text = await method(version, text = await method(version,
message, message,
risotto_context, risotto_context,
public_only=True, check_role=True,
**kwargs) **kwargs)
except NotAllowedError as err: except NotAllowedError as err:
raise HTTPNotFound(reason=str(err)) raise HTTPNotFound(reason=str(err))

View File

@ -440,16 +440,19 @@ def _get_option(name,
kwargs['multi'] = True kwargs['multi'] = True
type_ = type_[2:] type_ = type_[2:]
if type_ == 'Dict': if type_ == 'Dict':
return DictOption(**kwargs) obj = DictOption(**kwargs)
elif type_ == 'String': elif type_ == 'String':
return StrOption(**kwargs) obj = StrOption(**kwargs)
elif type_ == 'Any': elif type_ == 'Any':
return AnyOption(**kwargs) obj = AnyOption(**kwargs)
elif 'Number' in type_ or type_ == 'ID' or type_ == 'Integer': elif 'Number' in type_ or type_ == 'ID' or type_ == 'Integer':
return IntOption(**kwargs) obj = IntOption(**kwargs)
elif type_ == 'Boolean': elif type_ == 'Boolean':
return BoolOption(**kwargs) obj = BoolOption(**kwargs)
raise Exception('unsupported type {} in {}'.format(type_, file_path)) else:
raise Exception('unsupported type {} in {}'.format(type_, file_path))
obj.impl_set_information('ref', arg.ref)
return obj
def _parse_args(message_def, def _parse_args(message_def,
@ -463,11 +466,15 @@ def _parse_args(message_def,
""" """
new_options = OrderedDict() new_options = OrderedDict()
for name, arg in message_def.parameters.items(): for name, arg in message_def.parameters.items():
new_options[name] = arg #new_options[name] = arg
if arg.ref: # if arg.ref:
needs.setdefault(message_def.uri, {}).setdefault(arg.ref, []).append(name) # needs.setdefault(message_def.uri, {}).setdefault(arg.ref, []).append(name)
for name, arg in new_options.items(): #for name, arg in new_options.items():
current_opt = _get_option(name, arg, file_path, select_option, optiondescription) current_opt = _get_option(name,
arg,
file_path,
select_option,
optiondescription)
options.append(current_opt) options.append(current_opt)
if hasattr(arg, 'shortarg') and arg.shortarg and load_shortarg: if hasattr(arg, 'shortarg') and arg.shortarg and load_shortarg:
options.append(SymLinkOption(arg.shortarg, current_opt)) options.append(SymLinkOption(arg.shortarg, current_opt))

View File

@ -1,12 +1,13 @@
from tiramisu import Config from tiramisu import Config
from inspect import signature from inspect import signature
from typing import Callable, Optional from typing import Callable, Optional
import asyncpg
from .utils import undefined, _ from .utils import undefined, _
from .error import RegistrationError from .error import RegistrationError
from .message import get_messages from .message import get_messages
from .context import Context from .context import Context
from .config import INTERNAL_USER from .config import INTERNAL_USER, get_config
def register(uris: str, def register(uris: str,
@ -248,23 +249,38 @@ class RegisterDispatcher:
risotto_context.type = None risotto_context.type = None
await module.on_join(risotto_context) await module.on_join(risotto_context)
async def insert_message(self,
connection,
uri):
sql = """INSERT INTO URI(URIName) VALUES ($1)
ON CONFLICT (URIName) DO NOTHING
"""
await connection.fetchval(sql,
uri)
async def load(self): async def load(self):
# valid function's arguments # valid function's arguments
for version, messages in self.messages.items(): db_conf = get_config().get('database')
for message, message_infos in messages.items(): pool = await asyncpg.create_pool(database=db_conf.get('dbname'), user=db_conf.get('user'))
if message_infos['pattern'] == 'rpc': async with pool.acquire() as connection:
module_name = message_infos['module'] async with connection.transaction():
function = message_infos['function'] for version, messages in self.messages.items():
await self.valid_rpc_params(version, for message, message_infos in messages.items():
message, if message_infos['pattern'] == 'rpc':
function, module_name = message_infos['module']
module_name) function = message_infos['function']
else: await self.valid_rpc_params(version,
if 'functions' in message_infos: message,
for function_infos in message_infos['functions']: function,
module_name = function_infos['module'] module_name)
function = function_infos['function'] else:
await self.valid_event_params(version, if 'functions' in message_infos:
message, for function_infos in message_infos['functions']:
function, module_name = function_infos['module']
module_name) function = function_infos['function']
await self.valid_event_params(version,
message,
function,
module_name)
await self.insert_message(connection,
f'{version}.{message}')

View File

@ -12,7 +12,8 @@ from ...context import Context
from ...utils import _ from ...utils import _
class Risotto(Controller): class Risotto(Controller):
def __init__(self): def __init__(self,
test: bool) -> None:
self.source_root_path = get_config().get('source').get('root_path') self.source_root_path = get_config().get('source').get('root_path')
async def _applicationservice_create(self, async def _applicationservice_create(self,

View File

@ -16,7 +16,8 @@ from ...logger import log
class Risotto(Controller): class Risotto(Controller):
def __init__(self): def __init__(self,
test: bool) -> None:
self.source_root_path = get_config().get('source').get('root_path') self.source_root_path = get_config().get('source').get('root_path')
self.cache_root_path = join(get_config().get('cache').get('root_path'), 'servermodel') self.cache_root_path = join(get_config().get('cache').get('root_path'), 'servermodel')
if not isdir(self.cache_root_path): if not isdir(self.cache_root_path):

View File

@ -12,7 +12,8 @@ from ...utils import _
class Risotto(Controller): class Risotto(Controller):
def __init__(self): def __init__(self,
test: bool) -> None:
self.storage = Storage(engine='dictionary') self.storage = Storage(engine='dictionary')
self.cache_root_path = join(get_config().get('cache').get('root_path'), 'servermodel') self.cache_root_path = join(get_config().get('cache').get('root_path'), 'servermodel')