"""HTTP front-end for risotto.

Builds the aiohttp application (one POST route per dispatcher message plus
the GET "extra routes" registered with :func:`register`) and implements the
JWT based ``/auth`` and ``/access_token`` endpoints.
"""
from aiohttp.web import (Application, Response, get, post, HTTPBadRequest,
                         HTTPException, HTTPInternalServerError, HTTPNotFound,
                         HTTPUnauthorized)
from aiohttp import BasicAuth, RequestInfo
from json import dumps
from traceback import print_exc
from tiramisu import Config
import datetime
import jwt


from .dispatcher import dispatcher
from .utils import _
from .context import Context
from .error import CallError, NotAllowedError, RegistrationError
from .message import get_messages
from .logger import log
from .config import get_config
from .services import load_services


def _reason(err):
    """Return ``err`` as a single-line HTTP reason phrase.

    Reason phrases are sent in the HTTP status line, which cannot contain
    line breaks, so CR/LF are replaced by spaces.
    """
    return str(err).replace('\r', ' ').replace('\n', ' ')


def _bearer_token(header):
    """Extract the token from an ``Authorization: Bearer <token>`` value.

    Raises:
        HTTPBadRequest: the value does not use the Bearer scheme.
    """
    if not header.startswith('Bearer '):
        raise HTTPBadRequest(reason='Unexpected bearer format')
    return header[len('Bearer '):]


def _token_to_str(token):
    """Normalise a ``jwt.encode`` result to ``str``.

    PyJWT < 2 returns ``bytes`` from ``jwt.encode``; PyJWT >= 2 returns
    ``str``. Calling ``.decode`` unconditionally breaks on PyJWT >= 2.
    """
    if isinstance(token, bytes):
        return token.decode('utf-8')
    return token


def create_context(request):
    """Build the risotto Context for an incoming HTTP request.

    If an ``Authorization`` header is present, it must carry a Bearer JWT.
    The token is verified and its ``user`` claim becomes the context
    username. Without the header, the username is the ``username`` URL match
    info, or ``http_server.default_user`` from the config when there is none.

    Raises:
        HTTPBadRequest: the Authorization header is not a Bearer token.
        HTTPUnauthorized: the token is rejected by verify_token, or it carries
            no ``user`` claim.
    """
    risotto_context = Context()
    header = request.headers.get('Authorization')
    if header is None:
        default_user = get_config()['http_server']['default_user']
        risotto_context.username = request.match_info.get('username',
                                                           default_user)
    else:
        decoded = verify_token(_bearer_token(header))
        if 'user' not in decoded:
            # A token that names no user cannot identify the caller; reject
            # it rather than returning no context at all.
            raise HTTPUnauthorized(reason='Token does not identify a user')
        risotto_context.username = decoded['user']
    return risotto_context


def register(version: str, path: str):
    """Decorator registering ``function`` as an extra GET route.

    The route is exposed by get_app at ``/api/<version><path>``.

    Raises:
        RegistrationError: ``path`` is already registered.
    """
    def decorator(function):
        if path in extra_routes:
            raise RegistrationError(f'the route {path} is already registered')
        extra_routes[path] = {'function': function, 'version': version}
        # Return the function so the decorated name stays usable in its
        # module (it previously became None).
        return function
    return decorator


class extra_route_handler:
    """Base class for the GET handlers that get_app builds for extra routes.

    get_app derives one subclass per registered route with ``type()`` and
    sets the class attributes ``function``, ``version`` and ``path``.
    ``__new__`` is a coroutine, so calling the subclass with a request returns
    an awaitable that produces the aiohttp Response.
    """
    async def __new__(cls, request):
        risotto_context = create_context(request)
        risotto_context.version = cls.version
        risotto_context.paths.append(cls.path)
        risotto_context.type = 'http_get'

        kwargs = dict(request.match_info)
        kwargs['request'] = request
        kwargs['risotto_context'] = risotto_context
        module_path = cls.function.__module__
        # Functions defined in this module (e.g. ``api``) take no ``self``.
        # Others receive dispatcher.injected_self[<second-to-last module
        # path component>] as ``self``.
        if module_path != 'risotto.http':
            module_name = module_path.split('.')[-2]
            kwargs['self'] = dispatcher.injected_self[module_name]
        try:
            returns = await cls.function(**kwargs)
        except HTTPException:
            # Let route functions answer with a deliberate HTTP error
            # (e.g. HTTPNotFound) instead of masking it as a 500.
            raise
        except NotAllowedError as err:
            raise HTTPUnauthorized(reason=_reason(err))
        except CallError as err:
            raise HTTPBadRequest(reason=_reason(err))
        except Exception as err:
            if get_config()['global']['debug']:
                print_exc()
            raise HTTPInternalServerError(reason=_reason(err))
        return Response(text=dumps(returns))


async def handle(request):
    """POST handler for ``/api/<version>/<message>`` dispatcher messages.

    The JSON object body is passed as keyword arguments. It goes to
    ``dispatcher.call`` for 'rpc' messages and to ``dispatcher.publish``
    otherwise. The result is returned as ``{"response": <result>}``.

    Raises:
        HTTPBadRequest: invalid or non-object JSON body, or CallError.
        HTTPUnauthorized: NotAllowedError from the dispatcher.
        HTTPInternalServerError: any other failure.
    """
    route_path = request.match_info.get_info()['path']
    version, message = route_path.rsplit('/', 2)[-2:]
    risotto_context = create_context(request)
    try:
        kwargs = await request.json()
    except ValueError as err:
        # Malformed JSON (or undecodable body) is a client error.
        raise HTTPBadRequest(reason=_reason(f'Invalid JSON body: {err}'))
    if not isinstance(kwargs, dict):
        raise HTTPBadRequest(reason='JSON body must be an object')
    try:
        if dispatcher.messages[version][message]['pattern'] == 'rpc':
            method = dispatcher.call
        else:
            method = dispatcher.publish
        text = await method(version,
                            message,
                            risotto_context,
                            check_role=True,
                            **kwargs)
    except HTTPException:
        raise
    except NotAllowedError as err:
        raise HTTPUnauthorized(reason=_reason(err))
    except CallError as err:
        raise HTTPBadRequest(reason=_reason(err))
    except Exception as err:
        if get_config()['global']['debug']:
            print_exc()
        raise HTTPInternalServerError(reason=_reason(err))
    return Response(text=dumps({'response': text}))


# Cache for the description returned by ``api``. It is built on the first
# call and reused afterwards.
tiramisu = None


async def api(request, risotto_context):
    """Return the tiramisu option description of the HTTP-exposed messages.

    Builds the list of URIs that have at least one associated role (a row in
    RoleURI) and passes it to get_messages. URIs without a role are
    considered private. The result is computed once and cached in the
    module-level ``tiramisu`` variable.
    """
    global tiramisu
    if not tiramisu:
        async with dispatcher.pool.acquire() as connection:
            async with connection.transaction():
                # URIs that have at least one role
                sql = '''
                    SELECT URI.URIName
                    FROM URI, RoleURI
                    WHERE RoleURI.URIId = URI.URIId
                '''
                uris = [uri['uriname'] for uri in await connection.fetch(sql)]
        config = await Config(get_messages(load_shortarg=True, uris=uris)[1])
        await config.property.read_write()
        tiramisu = await config.option.dict(remotable='none')
    return tiramisu


# Extra GET routes keyed by path suffix (mounted under /api/<version>).
# They are filled by @register, then consumed and deleted by get_app.
extra_routes = {'': {'function': api, 'version': 'v1'}}


async def get_app(loop):
    """Build the aiohttp application and start the HTTP server.

    Loads the services and the dispatcher, then adds these routes:
    - one POST route per dispatcher message;
    - one GET route per extra route;
    - ``/auth`` and ``/access_token``.
    Returns the server created by ``loop.create_server`` on
    ``http_server.port``.

    Note: the module-level ``extra_routes`` is deleted once the routes are
    built. After that, ``register`` can no longer be used and get_app cannot
    be called a second time.
    """
    global extra_routes
    load_services()
    # NOTE(review): Application(loop=...) and make_handler() are deprecated in
    # aiohttp 3.x (AppRunner replaces them); kept to preserve the return value.
    app = Application(loop=loop)
    routes = []
    await dispatcher.load()
    for version, messages in dispatcher.messages.items():
        print()
        print(_('======== Registered messages ========'))
        for message in messages:
            web_message = f'/api/{version}/{message}'
            pattern = messages[message]['pattern']
            print(f' - {web_message} ({pattern})')
            routes.append(post(web_message, handle))
    print()
    print(_('======== Registered extra routes ========'))
    for suffix, extra in extra_routes.items():
        path = f"/api/{extra['version']}{suffix}"
        extra['path'] = path
        # One handler subclass per route, carrying function/version/path as
        # class attributes.
        extra_handler = type(path, (extra_route_handler,), extra)
        routes.append(get(path, extra_handler))
        print(f' - {path} (http_get)')
    print()
    del extra_routes
    app.add_routes(routes)
    app.router.add_post('/auth', auth)
    app.router.add_post('/access_token', access_token)
    await dispatcher.on_join()
    return await loop.create_server(app.make_handler(),
                                    '*',
                                    get_config()['http_server']['port'])


async def auth(request):
    """POST /auth: exchange HTTP Basic credentials for a JWT.

    The login and password are checked against RisottoUser; the password is
    compared in SQL with ``crypt``. On success the response body is a freshly
    signed token (see gen_token).

    Raises:
        HTTPUnauthorized: missing Authorization header or wrong credentials.
        HTTPBadRequest: the header is not valid Basic authentication.
        HTTPInternalServerError: the new token fails verification.
    """
    auth_code = request.headers.get('Authorization')
    if auth_code is None:
        raise HTTPUnauthorized(reason='Unauthorized')
    if not auth_code.startswith('Basic '):
        raise HTTPBadRequest(reason='Unexpected basic auth format')
    try:
        credentials = BasicAuth.decode(auth_code)
    except ValueError as err:
        # bad base64 / missing ':' in the credentials
        raise HTTPBadRequest(reason=_reason(err))
    async with dispatcher.pool.acquire() as connection:
        async with connection.transaction():
            sql = '''
                SELECT UserName
                FROM RisottoUser
                WHERE UserLogin = $1 AND UserPassword = crypt($2, UserPassword);
            '''
            rows = await connection.fetch(sql,
                                          credentials.login,
                                          credentials.password)
    if not rows:
        raise HTTPUnauthorized(reason='Unauthorized')
    token = _token_to_str(gen_token(credentials))
    if not verify_token(token):
        raise HTTPInternalServerError(
            reason='Token could not be verified just after creation')
    return Response(text=token)


def gen_token(auth):
    """Return a HS256-signed JWT for ``auth.login``.

    Claims:
    - ``user``: the login;
    - ``exp``: now plus ``jwt.token_expire`` seconds;
    - ``iss`` and ``aud``: taken from the ``jwt`` config section.
    The value is whatever ``jwt.encode`` returns: ``bytes`` with PyJWT < 2,
    ``str`` with PyJWT >= 2.
    """
    jwt_config = get_config()['jwt']
    expires_at = (datetime.datetime.now(datetime.timezone.utc)
                  + datetime.timedelta(seconds=jwt_config['token_expire']))
    payload = {
        'user': auth.login,
        'exp': expires_at,
        'iss': jwt_config['issuer'],
        'aud': jwt_config['audience'],
    }
    return jwt.encode(payload, jwt_config['secret'], algorithm='HS256')


async def access_token(request):
    """POST /access_token: re-issue a valid Bearer JWT with a new expiry.

    The claims of the presented token are kept. Only ``exp`` is reset to now
    plus ``jwt.token_expire`` seconds.

    Raises:
        HTTPUnauthorized: missing header, or a token rejected by verify_token.
        HTTPBadRequest: the header is not a Bearer token.
    """
    header = request.headers.get('Authorization')
    if header is None:
        raise HTTPUnauthorized(reason='Token could not be verified')
    decoded = verify_token(_bearer_token(header))
    if not decoded:
        raise HTTPUnauthorized(reason='Token could not be verified')
    jwt_config = get_config()['jwt']
    decoded['exp'] = (datetime.datetime.now(datetime.timezone.utc)
                      + datetime.timedelta(seconds=jwt_config['token_expire']))
    token = jwt.encode(decoded, jwt_config['secret'], algorithm='HS256')
    return Response(text=_token_to_str(token))


def verify_token(token):
    """Decode and validate a HS256 JWT.

    Checks the signature against the ``jwt.secret`` config value, and the
    ``iss``/``aud`` claims against ``jwt.issuer``/``jwt.audience``.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPUnauthorized: the token is expired, has a wrong issuer or
            audience, or is otherwise invalid (bad signature, malformed,
            missing claim).
    """
    jwt_config = get_config()['jwt']
    try:
        return jwt.decode(token,
                          jwt_config['secret'],
                          issuer=jwt_config['issuer'],
                          audience=jwt_config['audience'],
                          algorithms=['HS256'])
    # The specific errors below subclass jwt.InvalidTokenError, so they must
    # be listed before it.
    except jwt.ExpiredSignatureError:
        raise HTTPUnauthorized(reason='Token Expired')
    except jwt.InvalidIssuerError:
        raise HTTPUnauthorized(reason='Token could not be verified')
    except jwt.InvalidAudienceError:
        raise HTTPUnauthorized(reason='Token audience not match')
    except jwt.InvalidTokenError:
        # Bad signature, malformed token, missing required claim, ...:
        # a client error (401), not a server failure (500).
        raise HTTPUnauthorized(reason='Token could not be verified')