risotto/src/risotto/http.py

232 lines
8.4 KiB
Python

from aiohttp.web import Application, Response, get, post, HTTPBadRequest, HTTPInternalServerError, HTTPNotFound, static
from json import dumps
from traceback import print_exc
try:
from tiramisu3 import Config, default_storage
except:
from tiramisu import Config, default_storage
from .dispatcher import get_dispatcher
from .utils import _
from .context import Context
from .error import CallError, NotAllowedError, RegistrationError
from .message import get_messages
#from .logger import log
from .config import get_config
from . import services
# Extra HTTP GET routes collected by the ``register`` decorator, keyed by the
# route path; each value holds the handler function and its API version.
extra_routes: dict = {}
# Static directories collected by ``register_static``, keyed by URL path.
extra_statics: dict = {}
def create_context(request):
    """Build the risotto ``Context`` identifying the user behind *request*.

    The username is looked up, in order, in:

    1. the parameters matched in the route path,
    2. the ``username`` request header,
    3. the ``default_user`` entry of the ``http_server`` configuration.
    """
    context = Context()
    route_params = dict(request.match_info)
    if 'username' in route_params:
        context.username = route_params['username']
    elif 'username' in request.headers:
        context.username = request.headers['username']
    else:
        context.username = get_config()['http_server']['default_user']
    return context
def register(version: str,
             path: str,
             ):
    """Decorator factory registering a function as an extra HTTP GET route.

    :param version: API version the route belongs to
    :param path: route path, used as the registry key (``get_app`` later
                 prefixes it with ``/api/<version>``)
    :raises RegistrationError: when the decorator is applied and *path* is
                               already registered
    """
    def decorator(function):
        if path in extra_routes:
            raise RegistrationError(f'the route "{path}" is already registered')
        extra_routes[path] = {'function': function,
                              'version': version,
                              }
        # Hand the function back so the decorated name stays bound to it
        # instead of being silently replaced by None.
        return function
    return decorator
def register_static(path: str,
                    directory: str,
                    ) -> None:
    """Record *directory* in the static registry under the URL *path*.

    :raises RegistrationError: if *path* has already been registered
    """
    if path not in extra_statics:
        extra_statics[path] = directory
        return
    raise RegistrationError(f'the static path "{path}" is already registered')
class extra_route_handler:
    """Request handler used for the extra HTTP GET routes.

    ``get_app`` derives one subclass per route with ``type()``, providing the
    ``function``, ``version`` and ``path`` class attributes.  ``__new__`` is a
    coroutine function, so calling the class with a request does not create an
    instance: it returns a coroutine that resolves to the HTTP response.
    """

    @staticmethod
    def _http_reason(err):
        """Return *err* as a single-line string usable as an HTTP reason phrase."""
        # A reason phrase is part of the status line and must not span lines.
        return ' '.join(str(err).splitlines())

    async def __new__(cls,
                      request,
                      ):
        """Call the registered function for *request* and wrap its result as JSON.

        The function receives the matched path parameters plus ``request`` and
        ``risotto_context`` as keyword arguments.

        :raises HTTPNotFound: the function raised ``NotAllowedError``
        :raises HTTPBadRequest: the function raised ``CallError``
        :raises HTTPInternalServerError: the function raised anything else
        """
        kwargs = dict(request.match_info)
        kwargs['request'] = request
        risotto_context = create_context(request)
        risotto_context.version = cls.version
        risotto_context.paths.append(cls.path)
        risotto_context.type = 'http_get'
        kwargs['risotto_context'] = risotto_context
        function_name = cls.function.__module__
        # if not 'api' function
        if function_name != 'risotto.http':
            # The first two dotted components of the module name select the
            # object the dispatcher injects as ``self``.
            risotto_module_name, submodule_name = function_name.split('.', 2)[:-1]
            module_name = risotto_module_name.split('_')[-1]
            dispatcher = get_dispatcher()
            kwargs['self'] = dispatcher.injected_self[module_name + '.' + submodule_name]
        try:
            returns = await cls.function(**kwargs)
        except NotAllowedError as err:
            raise HTTPNotFound(reason=cls._http_reason(err)) from err
        except CallError as err:
            raise HTTPBadRequest(reason=cls._http_reason(err)) from err
        except Exception as err:
            if get_config()['global']['debug']:
                print_exc()
            raise HTTPInternalServerError(reason=cls._http_reason(err)) from err
        # await log.info_msg(kwargs['risotto_context'],
        #                    dict(request.match_info))
        return Response(text=dumps(returns),
                        content_type='application/json',
                        )
def _error_response(error_type, err, context_id=None):
    """Build the HTTP exception *error_type* carrying the JSON error envelope for *err*."""
    response = {'type': 'error',
                'reason': str(err).replace('\n', ' '),
                }
    if context_id is not None:
        response['context_id'] = context_id
    return error_type(text=dumps({'response': response,
                                  'type': 'error',
                                  }),
                      content_type='application/json',
                      )


async def handle(request):
    """Dispatch a POSTed message to the dispatcher and answer in JSON.

    The version and message name are the last two components of the route
    path; the JSON body supplies the keyword arguments of the message.
    Messages whose pattern is ``rpc`` go through ``dispatcher.call``, all
    others through ``dispatcher.publish``.

    :raises HTTPBadRequest: the body is not valid JSON, or the dispatcher
                            raised ``CallError``
    :raises HTTPNotFound: the dispatcher raised ``NotAllowedError``
    :raises HTTPInternalServerError: any other failure
    """
    version, message = request.match_info.get_info()['path'].rsplit('/', 2)[-2:]
    risotto_context = create_context(request)
    try:
        kwargs = await request.json()
    except ValueError as err:
        # json.JSONDecodeError derives from ValueError: report a missing or
        # malformed body as a client error in the usual JSON envelope instead
        # of letting it escape as an unformatted server error.
        raise _error_response(HTTPBadRequest, err) from err
    try:
        dispatcher = get_dispatcher()
        pattern = dispatcher.messages[version][message]['pattern']
        if pattern == 'rpc':
            method = dispatcher.call
        else:
            method = dispatcher.publish
        text = await method(version,
                            message,
                            risotto_context,
                            check_role=True,
                            internal=False,
                            **kwargs,
                            )
    except NotAllowedError as err:
        raise _error_response(HTTPNotFound, err) from err
    except CallError as err:
        raise _error_response(HTTPBadRequest, err, err.context_id) from err
    except Exception as err:
        if get_config()['global']['debug']:
            print_exc()
        raise _error_response(HTTPInternalServerError, err) from err
    return Response(text=dumps({'response': text,
                                'type': 'success',
                                }),
                    content_type='application/json',
                    )
async def api(request,
              risotto_context,
              ):
    """Return the API description for the version of *risotto_context*.

    The description is the dictionary exported from a tiramisu ``Config``
    built over the messages of the current services.  It is built on first
    request and then cached in the module-level ``TIRAMISU`` mapping, keyed
    by version, so that each version gets its own description rather than
    whichever one happened to be requested first.

    :param request: the HTTP request (unused here)
    :param risotto_context: context whose ``version`` selects the description
    """
    global TIRAMISU
    if TIRAMISU is None:
        TIRAMISU = {}
    version = risotto_context.version
    description = TIRAMISU.get(version)
    if not description:
        # check all URI that have an associated role
        # all URI without role is considered as a private URI
        dispatcher = get_dispatcher()
        async with dispatcher.pool.acquire() as connection:
            async with connection.transaction():
                # Check role with ACL
                sql = '''
SELECT UserURI.URIName
FROM UserURI, UserRoleURI
WHERE UserRoleURI.URIId = UserURI.URIId
'''
                uris = [uri['uriname'] for uri in await connection.fetch(sql)]
        risotto_modules = services.get_services_list()
        async with await Config(get_messages(current_module_names=risotto_modules,
                                             load_shortarg=True,
                                             current_version=version,
                                             uris=uris,
                                             )[1],
                                display_name=lambda self, dyn_name, suffix: self.impl_getname()) as config:
            await config.property.read_write()
            description = await config.option.dict(remotable='none')
        TIRAMISU[version] = description
    return description
async def get_app(loop):
    """Build all routes and start the HTTP server on *loop*.

    Routes are collected in this order: one POST route per dispatcher
    message, one GET route per API version plus ``/api`` (served by the
    handler of the last version), the extra routes recorded by ``register``
    and the static directories recorded by ``register_static``.  Each route
    is printed as it is registered.  The two module-level registries are
    deleted once consumed.

    Returns the server created by ``loop.create_server``.
    """
    global extra_routes, extra_statics
    dispatcher = get_dispatcher()
    services.link_to_dispatcher(dispatcher)
    app = Application(loop=loop)
    route_defs = []
    default_storage.engine('dictionary')
    await dispatcher.load()

    # One POST route per message of every version.
    versions = []
    for version, messages in dispatcher.messages.items():
        if version not in versions:
            versions.append(version)
        print()
        print(_('======== Registered messages ========'))
        for message, message_infos in messages.items():
            web_message = f'/api/{version}/{message}'
            pattern = message_infos['pattern']
            print(f' - {web_message} ({pattern})')
            route_defs.append(post(web_message, handle))

    # One GET route per version, exposing the API description.
    print()
    print(_('======== Registered api routes ========'))
    for version in versions:
        api_path = f'/api/{version}'
        api_route = dict(function=api, version=version, path=api_path)
        extra_handler = type(api_path, (extra_route_handler,), api_route)
        route_defs.append(get(api_path, extra_handler))
        print(f' - {api_path} (http_get)')
    # last version is default version
    route_defs.append(get('/api', extra_handler))
    print(' - /api (http_get)')
    print()

    # Routes recorded through the ``register`` decorator.
    if extra_routes:
        print(_('======== Registered extra routes ========'))
        for route_path, extra in extra_routes.items():
            full_path = f'/api/{extra["version"]}{route_path}'
            extra['path'] = full_path
            extra_handler = type(full_path, (extra_route_handler,), extra)
            route_defs.append(get(full_path, extra_handler))
            print(f' - {full_path} (http_get)')

    # Directories recorded through ``register_static``; the header is only
    # printed when no extra route was listed just above.
    if extra_statics:
        if not extra_routes:
            print(_('======== Registered static routes ========'))
        for static_path, directory in extra_statics.items():
            route_defs.append(static(static_path, directory))
            print(f' - {static_path} (static)')

    del extra_routes
    del extra_statics
    app.router.add_routes(route_defs)
    await dispatcher.register_remote()
    print()
    await dispatcher.on_join()
    request_handler = app.make_handler()
    port = get_config()['http_server']['port']
    return await loop.create_server(request_handler, '*', port)
# Module-level cache filled by ``api()`` on first use; ``None`` until then.
TIRAMISU = None