can have multi domain locally

This commit is contained in:
Emmanuel Garette 2020-03-10 14:03:37 +01:00
parent 892e052969
commit cc6dd3efe3
5 changed files with 116 additions and 121 deletions

View File

@ -2,6 +2,7 @@ from .config import get_config
from .dispatcher import dispatcher from .dispatcher import dispatcher
from .context import Context from .context import Context
from .remote import remote from .remote import remote
from .services import list_modules
class Controller: class Controller:
@ -9,17 +10,17 @@ class Controller:
""" """
def __init__(self, def __init__(self,
test: bool): test: bool):
self.submodule = get_config()['global']['module_name'] self.risotto_modules = list_modules()
async def call(self, async def call(self,
uri: str, uri: str,
risotto_context: Context, risotto_context: Context,
**kwargs): **kwargs):
""" a wrapper to dispatcher's call""" """ a wrapper to dispatcher's call"""
version, submodule, message = uri.split('.', 2) version, module, message = uri.split('.', 2)
uri = submodule + '.' + message uri = module + '.' + message
if submodule != self.submodule: if module not in self.risotto_modules:
return await remote.call_or_publish(submodule, return await remote.call_or_publish(module,
version, version,
message, message,
kwargs) kwargs)
@ -33,14 +34,15 @@ class Controller:
risotto_context: Context, risotto_context: Context,
**kwargs): **kwargs):
""" a wrapper to dispatcher's publish""" """ a wrapper to dispatcher's publish"""
version, submodule, uri = uri.split('.', 2) version, module, submessage = uri.split('.', 2)
if submodule != self.submodule: version, message = uri.split('.', 1)
await remote.call_or_publish(submodule, if module not in self.risotto_modules:
await remote.call_or_publish(module,
version, version,
message, submessage,
kwargs) kwargs)
await dispatcher.publish(version, await dispatcher.publish(version,
uri, message,
risotto_context, risotto_context,
**kwargs) **kwargs)

View File

@ -11,12 +11,15 @@ from .error import CallError, NotAllowedError, RegistrationError
from .message import get_messages from .message import get_messages
from .logger import log from .logger import log
from .config import get_config from .config import get_config
from .services import load_services from .services import list_modules, load_submodules
extra_routes = {} extra_routes = {}
RISOTTO_MODULES = list_modules()
def create_context(request): def create_context(request):
risotto_context = Context() risotto_context = Context()
risotto_context.username = request.match_info.get('username', risotto_context.username = request.match_info.get('username',
@ -88,8 +91,8 @@ async def handle(request):
if get_config()['global']['debug']: if get_config()['global']['debug']:
print_exc() print_exc()
raise HTTPInternalServerError(reason=str(err)) raise HTTPInternalServerError(reason=str(err))
return Response(text=dumps({'response': text}, return Response(text=dumps({'response': text}),
content_type='application/json')) content_type='application/json')
async def api(request, async def api(request,
@ -108,7 +111,8 @@ async def api(request,
WHERE RoleURI.URIId = URI.URIId WHERE RoleURI.URIId = URI.URIId
''' '''
uris = [uri['uriname'] for uri in await connection.fetch(sql)] uris = [uri['uriname'] for uri in await connection.fetch(sql)]
async with await Config(get_messages(load_shortarg=True, async with await Config(get_messages(current_module_names=RISOTTO_MODULES,
load_shortarg=True,
current_version=risotto_context.version, current_version=risotto_context.version,
uris=uris)[1]) as config: uris=uris)[1]) as config:
await config.property.read_write() await config.property.read_write()
@ -120,7 +124,7 @@ async def get_app(loop):
""" build all routes """ build all routes
""" """
global extra_routes global extra_routes
load_services() load_submodules(dispatcher)
app = Application(loop=loop) app = Application(loop=loop)
routes = [] routes = []
default_storage.engine('dictionary') default_storage.engine('dictionary')

View File

@ -13,7 +13,6 @@ from ..config import get_config
from ..utils import _ from ..utils import _
MESSAGE_ROOT_PATH = get_config()['global']['message_root_path'] MESSAGE_ROOT_PATH = get_config()['global']['message_root_path']
MODULE_NAME = get_config()['global']['module_name']
CUSTOMTYPES = {} CUSTOMTYPES = {}
groups.addgroup('message') groups.addgroup('message')
@ -229,12 +228,12 @@ def _parse_parameters(raw_defs,
def get_message(uri: str, def get_message(uri: str,
current_module_name: str): current_module_names: str):
try: try:
version, message = uri.split('.', 1) version, message = uri.split('.', 1)
path = get_message_file_path(version, path = get_message_file_path(version,
message, message,
current_module_name) current_module_names)
with open(path, "r") as message_file: with open(path, "r") as message_file:
return MessageDefinition(load(message_file.read(), Loader=SafeLoader), return MessageDefinition(load(message_file.read(), Loader=SafeLoader),
version, version,
@ -247,21 +246,20 @@ def get_message(uri: str,
def get_message_file_path(version, def get_message_file_path(version,
message, message,
current_module_name): current_module_names):
module_name, filename = message.split('.', 1) module_name, filename = message.split('.', 1)
if current_module_name and module_name != current_module_name: if current_module_names and module_name not in current_module_names:
raise Exception(f'should only load message for {current_module_name}, not {message}') raise Exception(f'should only load message for {current_module_names}, not {message}')
return join(MESSAGE_ROOT_PATH, version, module_name, 'messages', filename + '.yml') return join(MESSAGE_ROOT_PATH, version, module_name, 'messages', filename + '.yml')
def list_messages(uris, def list_messages(uris,
current_module_name, current_module_names,
current_version): current_version):
def get_module_paths(): def get_module_paths(current_module_names):
if current_module_name is not None: if current_module_names is None:
yield current_module_name, join(MESSAGE_ROOT_PATH, version, current_module_name, 'messages') current_module_names = listdir(join(MESSAGE_ROOT_PATH, version))
else: for module_name in current_module_names:
for module_name in listdir(join(MESSAGE_ROOT_PATH, version)):
yield module_name, join(MESSAGE_ROOT_PATH, version, module_name, 'messages') yield module_name, join(MESSAGE_ROOT_PATH, version, module_name, 'messages')
if current_version: if current_version:
@ -270,7 +268,7 @@ def list_messages(uris,
versions = listdir(join(MESSAGE_ROOT_PATH)) versions = listdir(join(MESSAGE_ROOT_PATH))
versions.sort() versions.sort()
for version in versions: for version in versions:
for module_name, message_path in get_module_paths(): for module_name, message_path in get_module_paths(current_module_names):
for message in listdir(message_path): for message in listdir(message_path):
if message.endswith('.yml'): if message.endswith('.yml'):
uri = version + '.' + module_name + '.' + message.rsplit('.', 1)[0] uri = version + '.' + module_name + '.' + message.rsplit('.', 1)[0]
@ -390,30 +388,13 @@ class CustomType:
return self.title return self.title
def load_customtypes(current_module_name: str) -> None: def load_customtypes(current_module_names: str) -> None:
versions = listdir(MESSAGE_ROOT_PATH) versions = listdir(MESSAGE_ROOT_PATH)
versions.sort() versions.sort()
def convert_properties(customtype: str,
version: str) -> None:
""" if properties include an other customtype, replace it
"""
properties = {}
for key, value in customtype.properties.items():
type_ = value.type
if type_.startswith('[]'):
if type_ in CUSTOMTYPES[version]:
raise Exception(_('cannot have []CustomType'))
properties[key] = value
else:
if type_ in CUSTOMTYPES[version]:
print('====== ca existe')
properties[key] = CUSTOMTYPES[version][ttype_]
else:
properties[key] = value
customtype.properties = properties
for version in versions: for version in versions:
if version not in CUSTOMTYPES: if version not in CUSTOMTYPES:
CUSTOMTYPES[version] = {} CUSTOMTYPES[version] = {}
for current_module_name in current_module_names:
types_path = join(MESSAGE_ROOT_PATH, types_path = join(MESSAGE_ROOT_PATH,
version, version,
current_module_name, current_module_name,
@ -426,8 +407,6 @@ def load_customtypes(current_module_name: str) -> None:
with open(path, "r") as message_file: with open(path, "r") as message_file:
try: try:
custom_type = CustomType(load(message_file, Loader=SafeLoader)) custom_type = CustomType(load(message_file, Loader=SafeLoader))
convert_properties(custom_type,
version)
CUSTOMTYPES[version][custom_type.getname()] = custom_type CUSTOMTYPES[version][custom_type.getname()] = custom_type
except Exception as err: except Exception as err:
raise Exception(_(f'enable to load type {err}: {message}')) raise Exception(_(f'enable to load type {err}: {message}'))
@ -612,10 +591,10 @@ def _get_root_option(select_option, optiondescriptions):
return OptionDescription('root', 'root', options_obj) return OptionDescription('root', 'root', options_obj)
def get_messages(load_shortarg=False, def get_messages(current_module_names,
load_shortarg=False,
current_version=None, current_version=None,
uris=None, uris=None):
current_module_name=MODULE_NAME):
"""generate description from yml files """generate description from yml files
""" """
global CUSTOMTYPES global CUSTOMTYPES
@ -623,7 +602,7 @@ def get_messages(load_shortarg=False,
optiondescriptions_info = {} optiondescriptions_info = {}
needs = {} needs = {}
messages = list(list_messages(uris, messages = list(list_messages(uris,
current_module_name, current_module_names,
current_version)) current_version))
messages.sort() messages.sort()
optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages] optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages]
@ -631,18 +610,17 @@ def get_messages(load_shortarg=False,
'Nom du message.', 'Nom du message.',
tuple(optiondescriptions_name), tuple(optiondescriptions_name),
properties=frozenset(['mandatory', 'positional'])) properties=frozenset(['mandatory', 'positional']))
if current_module_name is None: if current_module_names is None:
CUSTOMTYPES = {} CUSTOMTYPES = {}
if not CUSTOMTYPES: if not CUSTOMTYPES:
if current_module_name is None: if current_module_names is None:
for version in listdir(MESSAGE_ROOT_PATH): for version in listdir(MESSAGE_ROOT_PATH):
for module_name in listdir(join(MESSAGE_ROOT_PATH, version)): load_customtypes(listdir(join(MESSAGE_ROOT_PATH, version)))
load_customtypes(module_name)
else: else:
load_customtypes(current_module_name) load_customtypes(current_module_names)
for message_name in messages: for message_name in messages:
message_def = get_message(message_name, message_def = get_message(message_name,
current_module_name) current_module_names)
optiondescriptions_info[message_def.uri] = {'pattern': message_def.pattern, optiondescriptions_info[message_def.uri] = {'pattern': message_def.pattern,
'default_roles': message_def.default_roles, 'default_roles': message_def.default_roles,
'version': message_name.split('.')[0]} 'version': message_name.split('.')[0]}
@ -659,6 +637,6 @@ def get_messages(load_shortarg=False,
load_shortarg) load_shortarg)
root = _get_root_option(select_option, optiondescriptions) root = _get_root_option(select_option, optiondescriptions)
if current_module_name is None: if current_module_names is None:
CUSTOMTYPES = {} CUSTOMTYPES = {}
return optiondescriptions_info, root return optiondescriptions_info, root

View File

@ -10,6 +10,7 @@ from .message import get_messages
from .context import Context from .context import Context
from .config import get_config from .config import get_config
from .logger import log from .logger import log
from .services import list_modules
def register(uris: str, def register(uris: str,
@ -36,22 +37,32 @@ class RegisterDispatcher:
# postgresql pool # postgresql pool
self.pool = None self.pool = None
# load tiramisu objects # load tiramisu objects
messages, self.option = get_messages() self.risotto_modules = list_modules()
messages, self.option = get_messages(self.risotto_modules)
# list of uris with informations: {"v1": {"module_name.xxxxx": yyyyyy}} # list of uris with informations: {"v1": {"module_name.xxxxx": yyyyyy}}
version = 'v1'
self.messages = {} self.messages = {}
for tiramisu_message, obj in messages.items(): for tiramisu_message, obj in messages.items():
version = obj['version'] version = obj['version']
if version not in self.messages: if version not in self.messages:
self.messages[version] = {} self.messages[version] = {}
self.messages[version][tiramisu_message] = obj self.messages[version][tiramisu_message] = obj
self.risotto_module = get_config()['global']['module_name']
def get_function_args(self, def get_function_args(self,
function: Callable): function: Callable):
# remove self # remove self and risotto_context
first_argument_index = 1 first_argument_index = 2
return [param.name for param in list(signature(function).parameters.values())[first_argument_index:]] return {param.name for param in list(signature(function).parameters.values())[first_argument_index:]}
async def get_message_args(self,
message: str):
# load config
async with await Config(self.option) as config:
await config.property.read_write()
# set message to the message name
await config.option('message').value.set(message)
# get message argument
dico = await config.option(message).value.dict()
return set(dico.keys())
async def valid_rpc_params(self, async def valid_rpc_params(self,
version: str, version: str,
@ -60,26 +71,10 @@ class RegisterDispatcher:
module_name: str): module_name: str):
""" parameters function must have strictly all arguments with the correct name """ parameters function must have strictly all arguments with the correct name
""" """
async def get_message_args():
# load config
async with await Config(self.option) as config:
await config.property.read_write()
# set message to the uri name
await config.option('message').value.set(message)
# get message argument
dico = await config.option(message).value.dict()
return set(dico.keys())
def get_function_args():
function_args = self.get_function_args(function)
# risotto_context is a special argument, remove it
function_args = function_args[1:]
return set(function_args)
# get message arguments # get message arguments
message_args = await get_message_args() message_args = await self.get_message_args(message)
# get function arguments # get function arguments
function_args = get_function_args() function_args = self.get_function_args(function)
# compare message arguments with function parameter # compare message arguments with function parameter
# it must not have more or less arguments # it must not have more or less arguments
if message_args != function_args: if message_args != function_args:
@ -102,26 +97,10 @@ class RegisterDispatcher:
module_name: str): module_name: str):
""" parameters function validation for event messages """ parameters function validation for event messages
""" """
async def get_message_args():
# load config
async with await Config(self.option) as config:
await config.property.read_write()
# set message to the message name
await config.option('message').value.set(message)
# get message argument
dico = await config.option(message).value.dict()
return set(dico.keys())
def get_function_args():
function_args = self.get_function_args(function)
# risotto_context is a special argument, remove it
function_args = function_args[1:]
return set(function_args)
# get message arguments # get message arguments
message_args = await get_message_args() message_args = await self.get_message_args(message)
# get function arguments # get function arguments
function_args = get_function_args() function_args = self.get_function_args(function)
# compare message arguments with function parameter # compare message arguments with function parameter
# it can have less arguments but not more # it can have less arguments but not more
extra_function_args = function_args - message_args extra_function_args = function_args - message_args
@ -148,14 +127,13 @@ class RegisterDispatcher:
module_name = function.__module__.split('.')[-2] module_name = function.__module__.split('.')[-2]
message_namespace = message.split('.', 1)[0] message_namespace = message.split('.', 1)[0]
message_risotto_module, message_namespace, message_name = message.split('.', 2) message_risotto_module, message_namespace, message_name = message.split('.', 2)
if message_risotto_module != self.risotto_module: if message_risotto_module not in self.risotto_modules:
raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_module}"')) raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"'))
if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name: if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name:
raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"')) raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"'))
# True if first argument is the risotto_context # True if first argument is the risotto_context
function_args = self.get_function_args(function) function_args = self.get_function_args(function)
function_args.pop(0)
# check if already register # check if already register
if 'function' in self.messages[version][message]: if 'function' in self.messages[version][message]:

View File

@ -0,0 +1,33 @@
from os import listdir
from os.path import isdir, isfile, dirname, abspath, basename, join
from importlib import import_module
# from ..dispatcher import dispatcher
def list_modules():
abs_here = dirname(abspath(__file__))
here = basename(abs_here)
return [name for name in listdir(abs_here) if not name.startswith('__') and isdir(join(abs_here, name))]
def load_submodules(dispatcher,
modules=None,
validate: bool=True,
test: bool=False):
abs_here = dirname(abspath(__file__))
here = basename(abs_here)
module = basename(dirname(abs_here))
if not modules:
modules = listdir(abs_here)
for module in modules:
absmodule = join(abs_here, module)
if isdir(absmodule):
for submodule in listdir(absmodule):
absfilename = join(absmodule, submodule)
if isdir(absfilename) and isfile(join(absfilename, '__init__.py')):
dispatcher.set_module(submodule,
import_module(f'.{here}.{module}.{submodule}',
f'risotto'),
test)
if validate:
dispatcher.validate()