forked from Infra/risotto
change for tests
This commit is contained in:
@ -42,7 +42,7 @@ class CallDispatcher:
|
||||
raise Exception('hu?')
|
||||
else:
|
||||
for ret in returns:
|
||||
async with await Config(response) as config:
|
||||
async with await Config(response, display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config:
|
||||
await config.property.read_write()
|
||||
try:
|
||||
for key, value in ret.items():
|
||||
|
@ -20,9 +20,6 @@ from . import services
|
||||
extra_routes = {}
|
||||
|
||||
|
||||
RISOTTO_MODULES = services.get_services_list()
|
||||
|
||||
|
||||
def create_context(request):
|
||||
risotto_context = Context()
|
||||
risotto_context.username = request.match_info.get('username',
|
||||
@ -100,8 +97,8 @@ async def handle(request):
|
||||
|
||||
async def api(request,
|
||||
risotto_context):
|
||||
global tiramisu
|
||||
if not tiramisu:
|
||||
global TIRAMISU
|
||||
if not TIRAMISU:
|
||||
# check all URI that have an associated role
|
||||
# all URI without role is concidered has a private URI
|
||||
uris = []
|
||||
@ -114,13 +111,14 @@ async def api(request,
|
||||
WHERE RoleURI.URIId = URI.URIId
|
||||
'''
|
||||
uris = [uri['uriname'] for uri in await connection.fetch(sql)]
|
||||
async with await Config(get_messages(current_module_names=RISOTTO_MODULES,
|
||||
risotto_modules = services.get_services_list()
|
||||
async with await Config(get_messages(current_module_names=risotto_modules,
|
||||
load_shortarg=True,
|
||||
current_version=risotto_context.version,
|
||||
uris=uris)[1]) as config:
|
||||
await config.property.read_write()
|
||||
tiramisu = await config.option.dict(remotable='none')
|
||||
return tiramisu
|
||||
TIRAMISU = await config.option.dict(remotable='none')
|
||||
return TIRAMISU
|
||||
|
||||
|
||||
async def get_app(loop):
|
||||
@ -169,4 +167,4 @@ async def get_app(loop):
|
||||
return await loop.create_server(app.make_handler(), '*', get_config()['http_server']['port'])
|
||||
|
||||
|
||||
tiramisu = None
|
||||
TIRAMISU = None
|
||||
|
@ -3,7 +3,7 @@ try:
|
||||
except:
|
||||
from tiramisu import Config
|
||||
from inspect import signature
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, List
|
||||
import asyncpg
|
||||
from json import dumps, loads
|
||||
import risotto
|
||||
@ -25,11 +25,14 @@ class Services():
|
||||
self.services.setdefault(entry_point.name, [])
|
||||
self.services_loaded = True
|
||||
|
||||
def load_modules(self):
|
||||
def load_modules(self,
|
||||
limit_services: Optional[List[str]]=None,
|
||||
) -> None:
|
||||
for entry_point in iter_entry_points(group='risotto_modules'):
|
||||
service_name, module_name = entry_point.name.split('.')
|
||||
setattr(self, module_name, entry_point.load())
|
||||
self.services[service_name].append(module_name)
|
||||
if limit_services is None or service_name in limit_services:
|
||||
setattr(self, module_name, entry_point.load())
|
||||
self.services[service_name].append(module_name)
|
||||
self.modules_loaded = True
|
||||
|
||||
def get_services(self):
|
||||
@ -37,9 +40,11 @@ class Services():
|
||||
self.load_services()
|
||||
return [(s, getattr(self, s)) for s in self.services]
|
||||
|
||||
def get_modules(self):
|
||||
def get_modules(self,
|
||||
limit_services: Optional[List[str]]=None,
|
||||
) -> List[str]:
|
||||
if not self.modules_loaded:
|
||||
self.load_modules()
|
||||
self.load_modules(limit_services=limit_services)
|
||||
return [(m, getattr(self, m)) for s in self.services.values() for m in s]
|
||||
|
||||
def get_services_list(self):
|
||||
@ -52,8 +57,9 @@ class Services():
|
||||
dispatcher,
|
||||
validate: bool=True,
|
||||
test: bool=False,
|
||||
limit_services: Optional[List[str]]=None,
|
||||
):
|
||||
for module_name, module in self.get_modules():
|
||||
for module_name, module in self.get_modules(limit_services=limit_services):
|
||||
dispatcher.set_module(module_name,
|
||||
module,
|
||||
test)
|
||||
@ -65,6 +71,7 @@ services = Services()
|
||||
services.load_services()
|
||||
setattr(risotto, 'services', services)
|
||||
|
||||
|
||||
def register(uris: str,
|
||||
notification: str=None):
|
||||
""" Decorator to register function to the dispatcher
|
||||
@ -255,7 +262,9 @@ class RegisterDispatcher:
|
||||
if missing_messages:
|
||||
raise RegistrationError(_(f'no matching function for uri {missing_messages}'))
|
||||
|
||||
async def on_join(self):
|
||||
async def on_join(self,
|
||||
truncate: bool=False,
|
||||
) -> None:
|
||||
internal_user = get_config()['global']['internal_user']
|
||||
async with self.pool.acquire() as connection:
|
||||
await connection.set_type_codec(
|
||||
@ -264,6 +273,9 @@ class RegisterDispatcher:
|
||||
decoder=loads,
|
||||
schema='pg_catalog'
|
||||
)
|
||||
if truncate:
|
||||
async with connection.transaction():
|
||||
await connection.execute('TRUNCATE applicationservicedependency, deployment, factoryclusternode, factorycluster, log, release, userrole, risottouser, roleuri, infraserver, settingserver, servermodel, site, source, uri, userrole, zone, applicationservice')
|
||||
async with connection.transaction():
|
||||
for module_name, module in self.injected_self.items():
|
||||
risotto_context = Context()
|
||||
@ -286,12 +298,14 @@ class RegisterDispatcher:
|
||||
for version, messages in self.messages.items():
|
||||
for message, message_infos in messages.items():
|
||||
if message_infos['pattern'] == 'rpc':
|
||||
module_name = message_infos['module']
|
||||
function = message_infos['function']
|
||||
await self.valid_rpc_params(version,
|
||||
message,
|
||||
function,
|
||||
module_name)
|
||||
# module not available during test
|
||||
if 'module' in message_infos:
|
||||
module_name = message_infos['module']
|
||||
function = message_infos['function']
|
||||
await self.valid_rpc_params(version,
|
||||
message,
|
||||
function,
|
||||
module_name)
|
||||
elif 'functions' in message_infos:
|
||||
# event with functions
|
||||
for function_infos in message_infos['functions']:
|
||||
|
Reference in New Issue
Block a user