diff --git a/src/risotto/__init__.py b/src/risotto/__init__.py index 5756fda..71a5903 100644 --- a/src/risotto/__init__.py +++ b/src/risotto/__init__.py @@ -1,13 +1,3 @@ -from pkg_resources import iter_entry_points +from .http import get_app, services -class Service: - pass - -services = Service() -for ep in iter_entry_points(group='risotto_services'): - setattr(services, ep.name, ep.load()) - -def list_modules(): - return services -from .http import get_app __ALL__ = ('get_app', 'services') diff --git a/src/risotto/controller.py b/src/risotto/controller.py index e85524b..459751b 100644 --- a/src/risotto/controller.py +++ b/src/risotto/controller.py @@ -2,7 +2,7 @@ from .config import get_config from .dispatcher import dispatcher from .context import Context from .remote import remote -from .services import list_modules +from . import services from .utils import _ @@ -11,7 +11,7 @@ class Controller: """ def __init__(self, test: bool): - self.risotto_modules = list_modules() + self.risotto_modules = services.get_list() async def call(self, uri: str, diff --git a/src/risotto/http.py b/src/risotto/http.py index 5a366ff..30f79a5 100644 --- a/src/risotto/http.py +++ b/src/risotto/http.py @@ -11,13 +11,13 @@ from .error import CallError, NotAllowedError, RegistrationError from .message import get_messages from .logger import log from .config import get_config -from .services import list_modules, load_submodules +from . import services extra_routes = {} -RISOTTO_MODULES = list_modules() +RISOTTO_MODULES = services.get_list() def create_context(request): @@ -124,7 +124,7 @@ async def get_app(loop): """ build all routes """ global extra_routes - load_submodules(dispatcher) + services.link_to_dispatcher(dispatcher) app = Application(loop=loop) routes = [] default_storage.engine('dictionary') diff --git a/src/risotto/register.py b/src/risotto/register.py index 80bf713..fc055ca 100644 --- a/src/risotto/register.py +++ b/src/risotto/register.py @@ -3,15 +3,52 @@ from inspect import signature from typing import Callable, Optional import asyncpg from json import dumps, loads - +import risotto from .utils import _ from .error import RegistrationError from .message import get_messages from .context import Context from .config import get_config from .logger import log -from .services import list_modules +from pkg_resources import iter_entry_points +class Services(): + modules_list = [] + modules_loaded = False + + def load_modules(self): + for entry_point in iter_entry_points(group='risotto_services'): + setattr(self, entry_point.name, entry_point.load()) + self.modules_loaded = True + + def list_modules(self): + for entry_point in iter_entry_points(group='risotto_services'): + self.modules_list.append(entry_point.name) + + def get_modules(self): + if not self.modules_loaded: + self.load_modules() + return [(m, getattr(self, m)) for m in self.modules_list] + + def get_list(self): + return self.modules_list + + def link_to_dispatcher(self, + dispatcher, + validate: bool=True, + test: bool=False, + ): + for module_name, module in self.get_modules(): + dispatcher.set_module(module_name, + module, + test) + if validate: + dispatcher.validate() + + +services = Services() +services.list_modules() +setattr(risotto, 'services', services) def register(uris: str, notification: str=None): @@ -37,7 +74,7 @@ class RegisterDispatcher: # postgresql pool self.pool = None # load tiramisu objects - self.risotto_modules = list_modules() + self.risotto_modules = services.get_list() messages, self.option = get_messages(self.risotto_modules) # list of uris with informations: {"v1": {"module_name.xxxxx": yyyyyy}} self.messages = {}