can have multi domain locally
This commit is contained in:
parent
892e052969
commit
cc6dd3efe3
|
@ -2,6 +2,7 @@ from .config import get_config
|
|||
from .dispatcher import dispatcher
|
||||
from .context import Context
|
||||
from .remote import remote
|
||||
from .services import list_modules
|
||||
|
||||
|
||||
class Controller:
|
||||
|
@ -9,17 +10,17 @@ class Controller:
|
|||
"""
|
||||
def __init__(self,
|
||||
test: bool):
|
||||
self.submodule = get_config()['global']['module_name']
|
||||
self.risotto_modules = list_modules()
|
||||
|
||||
async def call(self,
|
||||
uri: str,
|
||||
risotto_context: Context,
|
||||
**kwargs):
|
||||
""" a wrapper to dispatcher's call"""
|
||||
version, submodule, message = uri.split('.', 2)
|
||||
uri = submodule + '.' + message
|
||||
if submodule != self.submodule:
|
||||
return await remote.call_or_publish(submodule,
|
||||
version, module, message = uri.split('.', 2)
|
||||
uri = module + '.' + message
|
||||
if module not in self.risotto_modules:
|
||||
return await remote.call_or_publish(module,
|
||||
version,
|
||||
message,
|
||||
kwargs)
|
||||
|
@ -33,14 +34,15 @@ class Controller:
|
|||
risotto_context: Context,
|
||||
**kwargs):
|
||||
""" a wrapper to dispatcher's publish"""
|
||||
version, submodule, uri = uri.split('.', 2)
|
||||
if submodule != self.submodule:
|
||||
await remote.call_or_publish(submodule,
|
||||
version, module, submessage = uri.split('.', 2)
|
||||
version, message = uri.split('.', 1)
|
||||
if module not in self.risotto_modules:
|
||||
await remote.call_or_publish(module,
|
||||
version,
|
||||
message,
|
||||
submessage,
|
||||
kwargs)
|
||||
await dispatcher.publish(version,
|
||||
uri,
|
||||
message,
|
||||
risotto_context,
|
||||
**kwargs)
|
||||
|
||||
|
|
|
@ -11,12 +11,15 @@ from .error import CallError, NotAllowedError, RegistrationError
|
|||
from .message import get_messages
|
||||
from .logger import log
|
||||
from .config import get_config
|
||||
from .services import load_services
|
||||
from .services import list_modules, load_submodules
|
||||
|
||||
|
||||
extra_routes = {}
|
||||
|
||||
|
||||
RISOTTO_MODULES = list_modules()
|
||||
|
||||
|
||||
def create_context(request):
|
||||
risotto_context = Context()
|
||||
risotto_context.username = request.match_info.get('username',
|
||||
|
@ -88,8 +91,8 @@ async def handle(request):
|
|||
if get_config()['global']['debug']:
|
||||
print_exc()
|
||||
raise HTTPInternalServerError(reason=str(err))
|
||||
return Response(text=dumps({'response': text},
|
||||
content_type='application/json'))
|
||||
return Response(text=dumps({'response': text}),
|
||||
content_type='application/json')
|
||||
|
||||
|
||||
async def api(request,
|
||||
|
@ -108,7 +111,8 @@ async def api(request,
|
|||
WHERE RoleURI.URIId = URI.URIId
|
||||
'''
|
||||
uris = [uri['uriname'] for uri in await connection.fetch(sql)]
|
||||
async with await Config(get_messages(load_shortarg=True,
|
||||
async with await Config(get_messages(current_module_names=RISOTTO_MODULES,
|
||||
load_shortarg=True,
|
||||
current_version=risotto_context.version,
|
||||
uris=uris)[1]) as config:
|
||||
await config.property.read_write()
|
||||
|
@ -120,7 +124,7 @@ async def get_app(loop):
|
|||
""" build all routes
|
||||
"""
|
||||
global extra_routes
|
||||
load_services()
|
||||
load_submodules(dispatcher)
|
||||
app = Application(loop=loop)
|
||||
routes = []
|
||||
default_storage.engine('dictionary')
|
||||
|
|
|
@ -13,7 +13,6 @@ from ..config import get_config
|
|||
from ..utils import _
|
||||
|
||||
MESSAGE_ROOT_PATH = get_config()['global']['message_root_path']
|
||||
MODULE_NAME = get_config()['global']['module_name']
|
||||
CUSTOMTYPES = {}
|
||||
|
||||
groups.addgroup('message')
|
||||
|
@ -229,12 +228,12 @@ def _parse_parameters(raw_defs,
|
|||
|
||||
|
||||
def get_message(uri: str,
|
||||
current_module_name: str):
|
||||
current_module_names: str):
|
||||
try:
|
||||
version, message = uri.split('.', 1)
|
||||
path = get_message_file_path(version,
|
||||
message,
|
||||
current_module_name)
|
||||
current_module_names)
|
||||
with open(path, "r") as message_file:
|
||||
return MessageDefinition(load(message_file.read(), Loader=SafeLoader),
|
||||
version,
|
||||
|
@ -247,21 +246,20 @@ def get_message(uri: str,
|
|||
|
||||
def get_message_file_path(version,
|
||||
message,
|
||||
current_module_name):
|
||||
current_module_names):
|
||||
module_name, filename = message.split('.', 1)
|
||||
if current_module_name and module_name != current_module_name:
|
||||
raise Exception(f'should only load message for {current_module_name}, not {message}')
|
||||
if current_module_names and module_name not in current_module_names:
|
||||
raise Exception(f'should only load message for {current_module_names}, not {message}')
|
||||
return join(MESSAGE_ROOT_PATH, version, module_name, 'messages', filename + '.yml')
|
||||
|
||||
|
||||
def list_messages(uris,
|
||||
current_module_name,
|
||||
current_module_names,
|
||||
current_version):
|
||||
def get_module_paths():
|
||||
if current_module_name is not None:
|
||||
yield current_module_name, join(MESSAGE_ROOT_PATH, version, current_module_name, 'messages')
|
||||
else:
|
||||
for module_name in listdir(join(MESSAGE_ROOT_PATH, version)):
|
||||
def get_module_paths(current_module_names):
|
||||
if current_module_names is None:
|
||||
current_module_names = listdir(join(MESSAGE_ROOT_PATH, version))
|
||||
for module_name in current_module_names:
|
||||
yield module_name, join(MESSAGE_ROOT_PATH, version, module_name, 'messages')
|
||||
|
||||
if current_version:
|
||||
|
@ -270,7 +268,7 @@ def list_messages(uris,
|
|||
versions = listdir(join(MESSAGE_ROOT_PATH))
|
||||
versions.sort()
|
||||
for version in versions:
|
||||
for module_name, message_path in get_module_paths():
|
||||
for module_name, message_path in get_module_paths(current_module_names):
|
||||
for message in listdir(message_path):
|
||||
if message.endswith('.yml'):
|
||||
uri = version + '.' + module_name + '.' + message.rsplit('.', 1)[0]
|
||||
|
@ -390,30 +388,13 @@ class CustomType:
|
|||
return self.title
|
||||
|
||||
|
||||
def load_customtypes(current_module_name: str) -> None:
|
||||
def load_customtypes(current_module_names: str) -> None:
|
||||
versions = listdir(MESSAGE_ROOT_PATH)
|
||||
versions.sort()
|
||||
def convert_properties(customtype: str,
|
||||
version: str) -> None:
|
||||
""" if properties include an other customtype, replace it
|
||||
"""
|
||||
properties = {}
|
||||
for key, value in customtype.properties.items():
|
||||
type_ = value.type
|
||||
if type_.startswith('[]'):
|
||||
if type_ in CUSTOMTYPES[version]:
|
||||
raise Exception(_('cannot have []CustomType'))
|
||||
properties[key] = value
|
||||
else:
|
||||
if type_ in CUSTOMTYPES[version]:
|
||||
print('====== ca existe')
|
||||
properties[key] = CUSTOMTYPES[version][ttype_]
|
||||
else:
|
||||
properties[key] = value
|
||||
customtype.properties = properties
|
||||
for version in versions:
|
||||
if version not in CUSTOMTYPES:
|
||||
CUSTOMTYPES[version] = {}
|
||||
for current_module_name in current_module_names:
|
||||
types_path = join(MESSAGE_ROOT_PATH,
|
||||
version,
|
||||
current_module_name,
|
||||
|
@ -426,8 +407,6 @@ def load_customtypes(current_module_name: str) -> None:
|
|||
with open(path, "r") as message_file:
|
||||
try:
|
||||
custom_type = CustomType(load(message_file, Loader=SafeLoader))
|
||||
convert_properties(custom_type,
|
||||
version)
|
||||
CUSTOMTYPES[version][custom_type.getname()] = custom_type
|
||||
except Exception as err:
|
||||
raise Exception(_(f'enable to load type {err}: {message}'))
|
||||
|
@ -612,10 +591,10 @@ def _get_root_option(select_option, optiondescriptions):
|
|||
return OptionDescription('root', 'root', options_obj)
|
||||
|
||||
|
||||
def get_messages(load_shortarg=False,
|
||||
def get_messages(current_module_names,
|
||||
load_shortarg=False,
|
||||
current_version=None,
|
||||
uris=None,
|
||||
current_module_name=MODULE_NAME):
|
||||
uris=None):
|
||||
"""generate description from yml files
|
||||
"""
|
||||
global CUSTOMTYPES
|
||||
|
@ -623,7 +602,7 @@ def get_messages(load_shortarg=False,
|
|||
optiondescriptions_info = {}
|
||||
needs = {}
|
||||
messages = list(list_messages(uris,
|
||||
current_module_name,
|
||||
current_module_names,
|
||||
current_version))
|
||||
messages.sort()
|
||||
optiondescriptions_name = [message_name.split('.', 1)[1] for message_name in messages]
|
||||
|
@ -631,18 +610,17 @@ def get_messages(load_shortarg=False,
|
|||
'Nom du message.',
|
||||
tuple(optiondescriptions_name),
|
||||
properties=frozenset(['mandatory', 'positional']))
|
||||
if current_module_name is None:
|
||||
if current_module_names is None:
|
||||
CUSTOMTYPES = {}
|
||||
if not CUSTOMTYPES:
|
||||
if current_module_name is None:
|
||||
if current_module_names is None:
|
||||
for version in listdir(MESSAGE_ROOT_PATH):
|
||||
for module_name in listdir(join(MESSAGE_ROOT_PATH, version)):
|
||||
load_customtypes(module_name)
|
||||
load_customtypes(listdir(join(MESSAGE_ROOT_PATH, version)))
|
||||
else:
|
||||
load_customtypes(current_module_name)
|
||||
load_customtypes(current_module_names)
|
||||
for message_name in messages:
|
||||
message_def = get_message(message_name,
|
||||
current_module_name)
|
||||
current_module_names)
|
||||
optiondescriptions_info[message_def.uri] = {'pattern': message_def.pattern,
|
||||
'default_roles': message_def.default_roles,
|
||||
'version': message_name.split('.')[0]}
|
||||
|
@ -659,6 +637,6 @@ def get_messages(load_shortarg=False,
|
|||
load_shortarg)
|
||||
|
||||
root = _get_root_option(select_option, optiondescriptions)
|
||||
if current_module_name is None:
|
||||
if current_module_names is None:
|
||||
CUSTOMTYPES = {}
|
||||
return optiondescriptions_info, root
|
||||
|
|
|
@ -10,6 +10,7 @@ from .message import get_messages
|
|||
from .context import Context
|
||||
from .config import get_config
|
||||
from .logger import log
|
||||
from .services import list_modules
|
||||
|
||||
|
||||
def register(uris: str,
|
||||
|
@ -36,22 +37,32 @@ class RegisterDispatcher:
|
|||
# postgresql pool
|
||||
self.pool = None
|
||||
# load tiramisu objects
|
||||
messages, self.option = get_messages()
|
||||
self.risotto_modules = list_modules()
|
||||
messages, self.option = get_messages(self.risotto_modules)
|
||||
# list of uris with informations: {"v1": {"module_name.xxxxx": yyyyyy}}
|
||||
version = 'v1'
|
||||
self.messages = {}
|
||||
for tiramisu_message, obj in messages.items():
|
||||
version = obj['version']
|
||||
if version not in self.messages:
|
||||
self.messages[version] = {}
|
||||
self.messages[version][tiramisu_message] = obj
|
||||
self.risotto_module = get_config()['global']['module_name']
|
||||
|
||||
def get_function_args(self,
|
||||
function: Callable):
|
||||
# remove self
|
||||
first_argument_index = 1
|
||||
return [param.name for param in list(signature(function).parameters.values())[first_argument_index:]]
|
||||
# remove self and risotto_context
|
||||
first_argument_index = 2
|
||||
return {param.name for param in list(signature(function).parameters.values())[first_argument_index:]}
|
||||
|
||||
async def get_message_args(self,
|
||||
message: str):
|
||||
# load config
|
||||
async with await Config(self.option) as config:
|
||||
await config.property.read_write()
|
||||
# set message to the message name
|
||||
await config.option('message').value.set(message)
|
||||
# get message argument
|
||||
dico = await config.option(message).value.dict()
|
||||
return set(dico.keys())
|
||||
|
||||
async def valid_rpc_params(self,
|
||||
version: str,
|
||||
|
@ -60,26 +71,10 @@ class RegisterDispatcher:
|
|||
module_name: str):
|
||||
""" parameters function must have strictly all arguments with the correct name
|
||||
"""
|
||||
async def get_message_args():
|
||||
# load config
|
||||
async with await Config(self.option) as config:
|
||||
await config.property.read_write()
|
||||
# set message to the uri name
|
||||
await config.option('message').value.set(message)
|
||||
# get message argument
|
||||
dico = await config.option(message).value.dict()
|
||||
return set(dico.keys())
|
||||
|
||||
def get_function_args():
|
||||
function_args = self.get_function_args(function)
|
||||
# risotto_context is a special argument, remove it
|
||||
function_args = function_args[1:]
|
||||
return set(function_args)
|
||||
|
||||
# get message arguments
|
||||
message_args = await get_message_args()
|
||||
message_args = await self.get_message_args(message)
|
||||
# get function arguments
|
||||
function_args = get_function_args()
|
||||
function_args = self.get_function_args(function)
|
||||
# compare message arguments with function parameter
|
||||
# it must not have more or less arguments
|
||||
if message_args != function_args:
|
||||
|
@ -102,26 +97,10 @@ class RegisterDispatcher:
|
|||
module_name: str):
|
||||
""" parameters function validation for event messages
|
||||
"""
|
||||
async def get_message_args():
|
||||
# load config
|
||||
async with await Config(self.option) as config:
|
||||
await config.property.read_write()
|
||||
# set message to the message name
|
||||
await config.option('message').value.set(message)
|
||||
# get message argument
|
||||
dico = await config.option(message).value.dict()
|
||||
return set(dico.keys())
|
||||
|
||||
def get_function_args():
|
||||
function_args = self.get_function_args(function)
|
||||
# risotto_context is a special argument, remove it
|
||||
function_args = function_args[1:]
|
||||
return set(function_args)
|
||||
|
||||
# get message arguments
|
||||
message_args = await get_message_args()
|
||||
message_args = await self.get_message_args(message)
|
||||
# get function arguments
|
||||
function_args = get_function_args()
|
||||
function_args = self.get_function_args(function)
|
||||
# compare message arguments with function parameter
|
||||
# it can have less arguments but not more
|
||||
extra_function_args = function_args - message_args
|
||||
|
@ -148,14 +127,13 @@ class RegisterDispatcher:
|
|||
module_name = function.__module__.split('.')[-2]
|
||||
message_namespace = message.split('.', 1)[0]
|
||||
message_risotto_module, message_namespace, message_name = message.split('.', 2)
|
||||
if message_risotto_module != self.risotto_module:
|
||||
raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_module}"'))
|
||||
if message_risotto_module not in self.risotto_modules:
|
||||
raise RegistrationError(_(f'cannot registered the "{message}" is not "{self.risotto_modules}"'))
|
||||
if self.messages[version][message]['pattern'] == 'rpc' and message_namespace != module_name:
|
||||
raise RegistrationError(_(f'cannot registered the "{message}" message in module "{module_name}"'))
|
||||
|
||||
# True if first argument is the risotto_context
|
||||
function_args = self.get_function_args(function)
|
||||
function_args.pop(0)
|
||||
|
||||
# check if already register
|
||||
if 'function' in self.messages[version][message]:
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
from os import listdir
|
||||
from os.path import isdir, isfile, dirname, abspath, basename, join
|
||||
from importlib import import_module
|
||||
# from ..dispatcher import dispatcher
|
||||
|
||||
|
||||
def list_modules():
|
||||
abs_here = dirname(abspath(__file__))
|
||||
here = basename(abs_here)
|
||||
return [name for name in listdir(abs_here) if not name.startswith('__') and isdir(join(abs_here, name))]
|
||||
|
||||
|
||||
def load_submodules(dispatcher,
|
||||
modules=None,
|
||||
validate: bool=True,
|
||||
test: bool=False):
|
||||
abs_here = dirname(abspath(__file__))
|
||||
here = basename(abs_here)
|
||||
module = basename(dirname(abs_here))
|
||||
if not modules:
|
||||
modules = listdir(abs_here)
|
||||
for module in modules:
|
||||
absmodule = join(abs_here, module)
|
||||
if isdir(absmodule):
|
||||
for submodule in listdir(absmodule):
|
||||
absfilename = join(absmodule, submodule)
|
||||
if isdir(absfilename) and isfile(join(absfilename, '__init__.py')):
|
||||
dispatcher.set_module(submodule,
|
||||
import_module(f'.{here}.{module}.{submodule}',
|
||||
f'risotto'),
|
||||
test)
|
||||
if validate:
|
||||
dispatcher.validate()
|
Loading…
Reference in New Issue