Compare commits

...

3 Commits

Author SHA1 Message Date
Emmanuel Garette cc6dd3efe3 can have multi domain locally 2020-03-10 14:03:51 +01:00
Emmanuel Garette 892e052969 sql_filename => sql_dir 2020-03-10 14:01:45 +01:00
Emmanuel Garette 5103b7bd28 update vocabulary 2020-03-10 14:00:39 +01:00
8 changed files with 134 additions and 132 deletions

View File

@ -1,6 +1,10 @@
Message
=======
message: config.session.server.start
version: v1
uri: v1.config.session.server.start
version: v1
module: config
submodule: session
message: config.session.server.start
submessage: session.server.start
subsubmessage: server.start

View File

@ -1,6 +1,7 @@
import asyncpg
import asyncio
from os.path import isfile
from os import listdir
from os.path import isdir, join
from sys import exit
@ -8,16 +9,20 @@ from risotto.config import get_config
async def main():
sql_filename = get_config()['global']['sql_filename']
if not isfile(sql_filename):
sql_dir = get_config()['global']['sql_dir']
if not isdir(sql_dir):
print('no sql file to import')
exit()
db_conf = get_config()['database']['dsn']
pool = await asyncpg.create_pool(db_conf)
async with pool.acquire() as connection:
async with connection.transaction():
with open(sql_filename, 'r') as sql:
await connection.execute(sql.read())
for filename in listdir(sql_dir):
if filename.endswith('.sql'):
sql_filename = join(sql_dir, filename)
with open(sql_filename, 'r') as sql:
await connection.execute(sql.read())
if __name__ == '__main__':
loop = asyncio.get_event_loop()

View File

@ -7,8 +7,7 @@ DEFAULT_USER = environ.get('DEFAULT_USER', 'Anonymous')
DEFAULT_DSN = environ.get('RISOTTO_DSN', 'postgres:///risotto?host=/var/run/postgresql/&user=risotto')
DEFAULT_TIRAMISU_DSN = environ.get('DEFAULT_TIRAMISU_DSN', 'postgres:///tiramisu?host=/var/run/postgresql/&user=tiramisu')
MESSAGE_PATH = environ.get('MESSAGE_PATH', '/root/risotto-message/messages')
MODULE_NAME = environ.get('MODULE_NAME', 'test')
SQL_FILENAME = f'/root/risotto-{MODULE_NAME}/sql/init.sql'
SQL_DIR = environ.get('SQL_DIR', './sql')
def get_config():
@ -22,8 +21,7 @@ def get_config():
'internal_user': 'internal',
'check_role': True,
'admin_user': DEFAULT_USER,
'module_name': MODULE_NAME,
'sql_filename': SQL_FILENAME},
'sql_dir': SQL_DIR},
'source': {'root_path': '/srv/seed'},
'cache': {'root_path': '/var/cache/risotto'},
'servermodel': {'internal_source': 'internal',

View File

@ -2,6 +2,7 @@ from .config import get_config
from .dispatcher import dispatcher
from .context import Context
from .remote import remote
from .services import list_modules
class Controller:
@ -9,17 +10,17 @@ class Controller:
"""
def __init__(self,
test: bool):
self.submodule = get_config()['global']['module_name']
self.risotto_modules = list_modules()
async def call(self,
uri: str,
risotto_context: Context,
**kwargs):
""" a wrapper to dispatcher's call"""
version, submodule, message = uri.split('.', 2)
uri = submodule + '.' + message
if submodule != self.submodule:
return await remote.call_or_publish(submodule,
version, module, message = uri.split('.', 2)
uri = module + '.' + message
if module not in self.risotto_modules:
return await remote.call_or_publish(module,
version,
message,
kwargs)
@ -33,14 +34,15 @@ class Controller:
risotto_context: Context,
**kwargs):
""" a wrapper to dispatcher's publish"""
version, submodule, uri = uri.split('.', 2)
if submodule != self.submodule:
await remote.call_or_publish(submodule,
version, module, submessage = uri.split('.', 2)
version, message = uri.split('.', 1)
if module not in self.risotto_modules:
await remote.call_or_publish(module,
version,
message,
submessage,
kwargs)
await dispatcher.publish(version,
uri,
message,
risotto_context,
**kwargs)

View File

@ -11,12 +11,15 @@ from .error import CallError, NotAllowedError, RegistrationError
from .message import get_messages
from .logger import log
from .config import get_config
from .services import load_services
from .services import list_modules, load_submodules
extra_routes = {}
RISOTTO_MODULES = list_modules()
def create_context(request):
risotto_context = Context()
risotto_context.username = request.match_info.get('username',
@ -88,8 +91,8 @@ async def handle(request):
if get_config()['global']['debug']:
print_exc()
raise HTTPInternalServerError(reason=str(err))
return Response(text=dumps({'response': text},
content_type='application/json'))
return Response(text=dumps({'response': text}),
content_type='application/json')
async def api(request,
@ -108,7 +111,8 @@ async def api(request,
WHERE RoleURI.URIId = URI.URIId
'''
uris = [uri['uriname'] for uri in await connection.fetch(sql)]
async with await Config(get_messages(load_shortarg=True,
async with await Config(get_messages(current_module_names=RISOTTO_MODULES,
load_shortarg=True,
current_version=risotto_context.version,
uris=uris)[1]) as config:
await config.property.read_write()
@ -120,7 +124,7 @@ async def get_app(loop):
""" build all routes
"""
global extra_routes
load_services()
load_submodules(dispatcher)
app = Application(loop=loop)
routes = []
default_storage.engine('dictionary')

View File

@ -13,7 +13,6 @@ from ..config import get_config
from ..utils import _
MESSAGE_ROOT_PATH = get_config()['global']['message_root_path']
MODULE_NAME = get_config()['global']['module_name']
CUSTOMTYPES = {}
groups.addgroup('message')
@ -229,12 +228,12 @@ def _parse_parameters(raw_defs,
def get_message(uri: str,
current_module_name: str):
current_module_names: str):
try:
version, message = uri.split('.', 1)
path = get_message_file_path(version,
message,
current_module_name)
current_module_names)
with open(path, "r") as message_file:
return MessageDefinition(load(message_file.read(), Loader=SafeLoader),
version,
@ -247,22 +246,21 @@ def get_message(uri: str,
def get_message_file_path(version,
message,
current_module_name):
current_module_names):
module_name, filename = message.split('.', 1)
if current_module_name and module_name != current_module_name:
raise Exception(f'should only load message for {current_module_name}, not {message}')
if current_module_names and module_name not in current_module_names:
raise Exception(f'should only load message for {current_module_names}, not {message}')
return join(MESSAGE_ROOT_PATH, version, module_name, 'messages', filename + '.yml')
def list_messages(uris,
current_module_name,
current_module_names,
current_version):
def get_module_paths():
if current_module_name is not None:
yield current_module_name, join(MESSAGE_ROOT_PATH, version, current_module_name, 'messages')
else:
for module_name in listdir(join(MESSAGE_ROOT_PATH, version)):
yield module_name, join(MESSAGE_ROOT_PATH, version, module_name, 'messages')
def get_module_paths(current_module_names):
if current_module_names is None:
current_module_names = listdir(join(MESSAGE_ROOT_PATH, version))
for module_name in current_module_names:
yield module_name, join(MESSAGE_ROOT_PATH, version, module_name, 'messages')
if current_version:
versions = [current_version]
@ -270,7 +268,7 @@ def list_messages(uris,
versions = listdir(join(MESSAGE_ROOT_PATH))
versions.sort()
for version in versions:
for module_name, message_path in get_module_paths():
for module_name, message_path in get_module_paths(current_module_names):
for message in listdir(message_path):
if message.endswith('.yml'):
uri = version + '.' + module_name + '.' + message.rsplit('.', 1)[0]
@ -390,47 +388,28 @@ class CustomType:
return self.title
def load_customtypes(current_module_name: str) -> None:
def load_customtypes(current_module_names: str) -> None:
versions = listdir(MESSAGE_ROOT_PATH)
versions.sort()
def convert_properties(customtype: str,
version: str) -> None:
""" if properties include an other customtype, replace it
"""
properties = {}
for key, value in customtype.properties.items():
type_ = value.type
if type_.startswith('[]'):
if type_ in CUSTOMTYPES[version]:
raise Exception(_('cannot have []CustomType'))
properties[key] = value
else:
if type_ in CUSTOMTYPES[version]:
print('====== ca existe')
properties[key] = CUSTOMTYPES[version][ttype_]
else:
properties[key] = value
customtype.properties = properties
for version in versions:
if version not in CUSTOMTYPES:
CUSTOMTYPES[version] = {}
types_path = join(MESSAGE_ROOT_PATH,
version,
current_module_name,
'types')
for message in listdir(types_path):
if message.endswith('.yml'):
path = join(types_path, message)
# remove extension
message = message.rsplit('.', 1)[0]
with open(path, "r") as message_file:
try:
custom_type = CustomType(load(message_file, Loader=SafeLoader))
convert_properties(custom_type,
version)
CUSTOMTYPES[version][custom_type.getname()] = custom_type
except Exception as err:
raise Exception(_(f'enable to load type {err}: {message}'))
for current_module_name in current_module_names:
types_path = join(MESSAGE_ROOT_PATH,
version,
current_module_name,
'types')
for message in listdir(types_path):
if message.endswith('.yml'):
path = join(types_path, message)
# remove extension
message = message.rsplit('.', 1)[0]
with open(path, "r") as message_file:
try:
custom_type = CustomType(load(message_file, Loader=SafeLoader))
CUSTOMTYPES[version][custom_type.getname()] = custom_type
except Exception as err:
raise Exception(_(f'enable to load type {err}: {message}'))
@ -612,10 +591,10 @@ def _get_root_option(select_option, optiondescriptions):
return OptionDescription('root', 'root', options_obj)
def get_messages(load_shortarg=False,
def get_messages(current_module_names,
load_shortarg=False,
current_version=None,
uris=None,
current_module_name=MODULE_NAME):
uris=None):
"""generate description from yml files
"""
global CUSTOMTYPES
@ -623,7 +602,7 @@ def get_messages(load_shortarg=False,
optiondescriptions_info = {}
needs = {}
messages = list(list_messages(uris,
current_module_name,
current_module_names,
current_version))
messages.sort()
optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages]
@ -631,18 +610,17 @@ def get_messages(load_shortarg=False,
'Nom du message.',
tuple(optiondescriptions_name),
properties=frozenset(['mandatory', 'positional']))
if current_module_name is None:
if current_module_names is None:
CUSTOMTYPES = {}
if not CUSTOMTYPES:
if current_module_name is None:
if current_module_names is None:
for version in listdir(MESSAGE_ROOT_PATH):
for module_name in listdir(join(MESSAGE_ROOT_PATH, version)):
load_customtypes(module_name)
load_customtypes(listdir(join(MESSAGE_ROOT_PATH, version)))
else:
load_customtypes(current_module_name)
load_customtypes(current_module_names)
for message_name in messages:
message_def = get_message(message_name,
current_module_name)
current_module_names)
optiondescriptions_info[message_def.uri] = {'pattern': message_def.pattern,
'default_roles': message_def.default_roles,
'version': message_name.split('.')[0]}
@ -659,6 +637,6 @@ def get_messages(load_shortarg=False,
load_shortarg)
root = _get_root_option(select_option, optiondescriptions)
if current_module_name is None:
if current_module_names is None:
CUSTOMTYPES = {}
return optiondescriptions_info, root

View File

@ -10,6 +10,7 @@ from .message import get_messages
from .context import Context
from .config import get_config
from .logger import log
from .services import list_modules
def register(uris: str,
@ -36,22 +37,32 @@ class RegisterDispatcher:
# postgresql pool
self.pool = None
# load tiramisu objects
messages, self.option = get_messages()
self.risotto_modules = list_modules()
messages, self.option = get_messages(self.risotto_modules)
# list of uris with informations: {"v1": {"module_name.xxxxx": yyyyyy}}
version = 'v1'
self.messages = {}
for tiramisu_message, obj in messages.items():
version = obj['version']
if version not in self.messages:
self.messages[version] = {}
self.messages[version][tiramisu_message] = obj
self.risotto_module = get_config()['global']['module_name']
def get_function_args(self,
function: Callable):
# remove self
first_argument_index = 1
return [param.name for param in list(signature(function).parameters.values())[first_argument_index:]]
# remove self and risotto_context
first_argument_index = 2
return {param.name for param in list(signature(function).parameters.values())[first_argument_index:]}
async def get_message_args(self,
message: str):
# load config
async with await Config(self.option) as config:
await config.property.read_write()
# set message to the message name
await config.option('message').value.set(message)
# get message argument
dico = await config.option(message).value.dict()
return set(dico.keys())
async def valid_rpc_params(self,
version: str,
@ -60,26 +71,10 @@ class RegisterDispatcher:
module_name: str):
""" parameters function must have strictly all arguments with the correct name
"""
async def get_message_args():
# load config
async with await Config(self.option) as config:
await config.property.read_write()
# set message to the uri name
await config.option('message').value.set(message)
# get message argument
dico = await config.option(message).value.dict()
return set(dico.keys())
def get_function_args():
function_args = self.get_function_args(function)
# risotto_context is a special argument, remove it
function_args = function_args[1:]
return set(function_args)
# get message arguments
message_args = await get_message_args()
message_args = await self.get_message_args(message)
# get function arguments
function_args = get_function_args()
function_args = self.get_function_args(function)
# compare message arguments with function parameter
# it must not have more or less arguments
if message_args != function_args:
@ -102,26 +97,10 @@ class RegisterDispatcher:
module_name: str):
""" parameters function validation for event messages
"""
async def get_message_args():
# load config
async with await Config(self.option) as config:
await config.property.read_write()
# set message to the message name
await config.option('message').value.set(message)
# get message argument
dico = await config.option(message).value.dict()
return set(dico.keys())
def get_function_args():
function_args = self.get_function_args(function)
# risotto_context is a special argument, remove it
function_args = function_args[1:]
return set(function_args)
# get message arguments
message_args = await get_message_args()
message_args = await self.get_message_args(message)
# get function arguments
function_args = get_function_args()
function_args = self.get_function_args(function)
# compare message arguments with function parameter
# it can have less arguments but not more
extra_function_args = function_args - message_args
@ -148,14 +127,13 @@ class RegisterDispatcher:
module_name = function.__module__.split('.')[-2]
message_namespace = message.split('.', 1)[0]
message_risotto_module, message_namespace, message_name = message.split('.', 2)
if message_risotto_module != self.risotto_module:
raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_module}"'))
if message_risotto_module not in self.risotto_modules:
raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"'))
if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name:
raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"'))
# True if first argument is the risotto_context
function_args = self.get_function_args(function)
function_args.pop(0)
# check if already register
if 'function' in self.messages[version][message]:

View File

@ -0,0 +1,33 @@
from os import listdir
from os.path import isdir, isfile, dirname, abspath, basename, join
from importlib import import_module
# from ..dispatcher import dispatcher
def list_modules():
abs_here = dirname(abspath(__file__))
here = basename(abs_here)
return [name for name in listdir(abs_here) if not name.startswith('__') and isdir(join(abs_here, name))]
def load_submodules(dispatcher,
modules=None,
validate: bool=True,
test: bool=False):
abs_here = dirname(abspath(__file__))
here = basename(abs_here)
module = basename(dirname(abs_here))
if not modules:
modules = listdir(abs_here)
for module in modules:
absmodule = join(abs_here, module)
if isdir(absmodule):
for submodule in listdir(absmodule):
absfilename = join(absmodule, submodule)
if isdir(absfilename) and isfile(join(absfilename, '__init__.py')):
dispatcher.set_module(submodule,
import_module(f'.{here}.{module}.{submodule}',
f'risotto'),
test)
if validate:
dispatcher.validate()