229 lines
8.5 KiB
Python
229 lines
8.5 KiB
Python
from aiohttp.web import Application, Response, get, post, HTTPBadRequest, HTTPInternalServerError, HTTPNotFound, HTTPUnauthorized
|
|
from aiohttp import BasicAuth, RequestInfo
|
|
from json import dumps
|
|
from traceback import print_exc
|
|
from tiramisu import Config
|
|
import datetime
|
|
import jwt
|
|
|
|
from .dispatcher import dispatcher
|
|
from .utils import _
|
|
from .context import Context
|
|
from .error import CallError, NotAllowedError, RegistrationError
|
|
from .message import get_messages
|
|
from .logger import log
|
|
from .config import get_config
|
|
from .services import load_services
|
|
|
|
|
|
def create_context(request):
    """Build the ``Context`` identifying the caller of *request*.

    With an ``Authorization: Bearer <jwt>`` header, the token is checked by
    ``verify_token`` and its ``user`` claim becomes the context username.
    Without that header, the username is taken from the ``username`` URL match
    parameter, falling back to ``http_server.default_user`` from the config.

    :param request: the incoming aiohttp request
    :returns: a ``Context`` whose ``username`` is set
    :raises HTTPBadRequest: if the Authorization header is not a Bearer token
    :raises HTTPUnauthorized: if the token is rejected by ``verify_token`` or
        carries no ``user`` claim
    """
    risotto_context = Context()
    if 'Authorization' not in request.headers:
        risotto_context.username = request.match_info.get('username',
                                                          get_config()['http_server']['default_user'])
        return risotto_context
    token = request.headers['Authorization']
    if not token.startswith("Bearer "):
        raise HTTPBadRequest(reason='Unexpected bearer format')
    decoded = verify_token(token[len("Bearer "):])
    if 'user' not in decoded:
        # A token that names no user cannot identify the caller: reject it
        # explicitly instead of handing callers a missing (None) context.
        raise HTTPUnauthorized(reason='Token does not identify a user')
    risotto_context.username = decoded['user']
    return risotto_context
|
|
|
|
|
|
def register(version: str,
             path: str):
    """Decorator registering a coroutine as an extra HTTP GET route.

    The function is stored in ``extra_routes`` and is later served by
    ``get_app`` at ``/api/<version><path>``.  Note that routes are keyed by
    *path* only, so the same *path* cannot be registered for two versions.

    :param version: API version prefix used to build the URL (e.g. ``'v1'``)
    :param path: path suffix appended after ``/api/<version>``
    :returns: the decorator
    :raises RegistrationError: when the decorator is applied to a *path* that
        is already registered
    """
    def decorator(function):
        if path in extra_routes:
            raise RegistrationError(f'the route {path} is already registered')
        extra_routes[path] = {'function': function,
                              'version': version}
        # Give the function back so the decorated name stays usable
        # instead of being rebound to None.
        return function
    return decorator
|
|
|
|
|
|
class extra_route_handler:
    """Base class for the GET handlers that ``get_app`` generates.

    ``get_app`` derives one subclass per entry of ``extra_routes`` with
    ``type(...)``, providing the class attributes ``function`` (coroutine to
    run), ``version`` and ``path``.  Because ``__new__`` is a coroutine,
    calling the class with a request yields an awaitable that produces the
    HTTP response, which lets the class itself act as the route handler.
    """

    async def __new__(cls, request):
        kwargs = dict(request.match_info)
        kwargs['request'] = request
        kwargs['risotto_context'] = create_context(request)
        kwargs['risotto_context'].version = cls.version
        kwargs['risotto_context'].paths.append(cls.path)
        kwargs['risotto_context'].type = 'http_get'
        function_name = cls.function.__module__
        # Functions not defined in this module come from a service module and
        # receive that module's injected ``self`` object.
        if function_name != 'risotto.http':
            module_name = function_name.split('.')[-2]
            kwargs['self'] = dispatcher.injected_self[module_name]
        try:
            returns = await cls.function(**kwargs)
        except NotAllowedError as err:
            # aiohttp refuses HTTP reasons containing '\n'; flatten them.
            raise HTTPUnauthorized(reason=str(err).replace('\n', ' ')) from err
        except CallError as err:
            raise HTTPBadRequest(reason=str(err).replace('\n', ' ')) from err
        except Exception as err:
            if get_config()['global']['debug']:
                print_exc()
            raise HTTPInternalServerError(reason=str(err).replace('\n', ' ')) from err
        return Response(text=dumps(returns))
|
|
|
|
|
|
async def handle(request):
    """Dispatch a ``POST /api/<version>/<message>`` request to the dispatcher.

    The JSON body supplies the message arguments.  Messages whose pattern is
    ``'rpc'`` go through ``dispatcher.call``; all others go through
    ``dispatcher.publish``.  Role checking is always enabled.

    :param request: the incoming aiohttp request
    :returns: a ``Response`` whose text is ``{"response": <result>}`` as JSON
    :raises HTTPBadRequest: if the body is not a JSON object, or on ``CallError``
    :raises HTTPUnauthorized: on ``NotAllowedError``
    :raises HTTPInternalServerError: on any other failure while dispatching
    """
    version, message = request.match_info.get_info()['path'].rsplit('/', 2)[-2:]
    risotto_context = create_context(request)
    try:
        kwargs = await request.json()
    except ValueError as err:
        # Covers JSONDecodeError and undecodable bytes: a client error, not a 500.
        raise HTTPBadRequest(reason='Request body is not valid JSON') from err
    if not isinstance(kwargs, dict):
        # Arguments are passed as **kwargs, which requires a JSON object.
        raise HTTPBadRequest(reason='Request body must be a JSON object')
    try:
        pattern = dispatcher.messages[version][message]['pattern']
        if pattern == 'rpc':
            method = dispatcher.call
        else:
            method = dispatcher.publish
        text = await method(version,
                            message,
                            risotto_context,
                            check_role=True,
                            **kwargs)
    except NotAllowedError as err:
        # aiohttp refuses HTTP reasons containing '\n'; flatten them.
        raise HTTPUnauthorized(reason=str(err).replace('\n', ' ')) from err
    except CallError as err:
        raise HTTPBadRequest(reason=str(err).replace('\n', ' ')) from err
    except Exception as err:
        if get_config()['global']['debug']:
            print_exc()
        raise HTTPInternalServerError(reason=str(err).replace('\n', ' ')) from err
    return Response(text=dumps({'response': text}))
|
|
|
|
|
|
async def api(request, risotto_context):
    """Return the tiramisu description of the messages exposed over HTTP.

    The description is computed on the first call and kept in the
    module-level ``tiramisu`` variable; later calls reuse it as long as it is
    truthy.  Only URIs that have an associated role (a row in ``RoleURI``)
    are included: a URI without a role is considered private.

    :param request: the aiohttp request (not used here)
    :param risotto_context: the caller's context (not used here)
    :returns: the result of ``config.option.dict(remotable='none')``
    """
    global tiramisu
    if tiramisu:
        return tiramisu
    role_uris_query = '''
SELECT URI.URIName
FROM URI, RoleURI
WHERE RoleURI.URIId = URI.URIId
'''
    async with dispatcher.pool.acquire() as connection:
        async with connection.transaction():
            records = await connection.fetch(role_uris_query)
            public_uris = [record['uriname'] for record in records]
    messages_config = await Config(get_messages(load_shortarg=True, uris=public_uris)[1])
    await messages_config.property.read_write()
    tiramisu = await messages_config.option.dict(remotable='none')
    return tiramisu
|
|
|
|
|
|
# Extra GET routes: maps a path suffix (appended to ``/api/<version>``) to the
# coroutine serving it and its API version.  The empty suffix exposes ``api``
# at ``/api/v1``.  ``register`` adds entries; ``get_app`` turns them into
# routes and then deletes this name.
extra_routes = {
    '': dict(function=api, version='v1'),
}
|
|
|
|
|
|
async def get_app(loop):
    """Build the aiohttp application, register every route and start serving.

    Registered routes:

    * ``POST /api/<version>/<message>`` for each dispatcher message, served
      by ``handle``;
    * ``GET /api/<version><path>`` for each ``extra_routes`` entry, served by
      a dedicated ``extra_route_handler`` subclass;
    * ``POST /auth`` and ``POST /access_token``.

    ``extra_routes`` is deleted once consumed, so later ``register`` calls,
    or a second call to this function, fail because the name no longer exists.

    :param loop: the asyncio event loop to serve on
    :returns: the server returned by ``loop.create_server``
    """
    global extra_routes
    load_services()
    app = Application(loop=loop)
    route_table = []
    await dispatcher.load()

    for version, version_messages in dispatcher.messages.items():
        print()
        print(_('======== Registered messages ========'))
        for message_name in version_messages:
            uri = f'/api/{version}/{message_name}'
            message_pattern = version_messages[message_name]['pattern']
            print(f' - {uri} ({message_pattern})')
            route_table.append(post(uri, handle))

    print()
    print(_('======== Registered extra routes ========'))
    for suffix, route_info in extra_routes.items():
        full_path = f'/api/{route_info["version"]}{suffix}'
        route_info['path'] = full_path
        # One handler class per route, carrying function/version/path as
        # class attributes read by extra_route_handler.__new__.
        handler_class = type(full_path, (extra_route_handler,), route_info)
        route_table.append(get(full_path, handler_class))
        print(f' - {full_path} (http_get)')
    print()
    del extra_routes

    app.add_routes(route_table)
    app.router.add_post('/auth', auth)
    app.router.add_post('/access_token', access_token)
    await dispatcher.on_join()
    listen_port = get_config()['http_server']['port']
    return await loop.create_server(app.make_handler(), '*', listen_port)
|
|
|
|
async def auth(request):
    """Exchange HTTP Basic credentials for a signed JWT.

    The login/password from the ``Authorization: Basic ...`` header is checked
    against ``RisottoUser``. On success, a token from ``gen_token`` is
    returned as plain text.

    :param request: the incoming aiohttp request
    :returns: a ``Response`` whose text is the encoded token
    :raises HTTPBadRequest: if the Authorization header is missing, is not a
        Basic header, or cannot be decoded
    :raises HTTPUnauthorized: if the credentials do not match a user
    :raises HTTPInternalServerError: if the new token fails verification
    """
    auth_header = request.headers.get('Authorization', '')
    if not auth_header.startswith("Basic "):
        raise HTTPBadRequest(reason='Unexpected basic authorization format')
    try:
        credentials = BasicAuth.decode(auth_header)
    except ValueError as err:
        # Bad base64 or a missing ':' separator is a client error, not a 500.
        raise HTTPBadRequest(reason='Unexpected basic authorization format') from err
    async with dispatcher.pool.acquire() as connection:
        async with connection.transaction():
            # The password is compared in SQL against the stored hash via
            # crypt($2, UserPassword); values are bound, never interpolated.
            sql = '''
SELECT UserName
FROM RisottoUser
WHERE UserLogin = $1
AND UserPassword = crypt($2, UserPassword);
'''
            res = await connection.fetch(sql, credentials.login, credentials.password)
    if not res:
        raise HTTPUnauthorized(reason='Unauthorized')
    token = gen_token(credentials)
    if not verify_token(token):
        raise HTTPInternalServerError(reason='Token could not be verified just after creation')
    # jwt.encode returns bytes with PyJWT < 2 and str with PyJWT >= 2.
    if isinstance(token, bytes):
        token = token.decode('utf-8')
    return Response(text=token)
|
|
|
|
def gen_token(auth):
    """Create a signed HS256 JWT for the user described by *auth*.

    Claims: ``user`` (``auth.login``), ``exp`` (now + ``jwt.token_expire``
    seconds), ``iss`` and ``aud`` (from the ``jwt`` config section).

    :param auth: credentials object; only its ``login`` attribute is read
    :returns: the token exactly as returned by ``jwt.encode`` (``bytes`` with
        PyJWT < 2, ``str`` with PyJWT >= 2)
    """
    jwt_config = get_config()['jwt']
    # Timezone-aware "now": datetime.utcnow() is deprecated, and PyJWT
    # encodes both forms to the same UTC timestamp.
    expires_at = (datetime.datetime.now(datetime.timezone.utc)
                  + datetime.timedelta(seconds=jwt_config['token_expire']))
    payload = {
        'user': auth.login,
        'exp': expires_at,
        'iss': jwt_config['issuer'],
        'aud': jwt_config['audience'],
    }
    return jwt.encode(payload, jwt_config['secret'], algorithm='HS256')
|
|
|
|
async def access_token(request):
    """Re-issue a still-valid JWT with a renewed expiry.

    The bearer token is verified with ``verify_token``. Its claims are then
    re-signed with ``exp`` reset to now + ``jwt.token_expire`` seconds.
    Defined as a coroutine because aiohttp deprecates synchronous handlers.

    :param request: the incoming aiohttp request
    :returns: a ``Response`` whose text is the new encoded token
    :raises HTTPBadRequest: if the Authorization header is missing or is not a
        Bearer token
    :raises HTTPUnauthorized: if the token cannot be verified
    """
    jwt_config = get_config()['jwt']
    bearer = request.headers.get('Authorization', '')
    if not bearer.startswith("Bearer "):
        raise HTTPBadRequest(reason='Unexpected bearer format')
    decoded = verify_token(bearer[len("Bearer "):])
    if not decoded:
        raise HTTPUnauthorized(reason='Token could not be verified')
    decoded['exp'] = (datetime.datetime.now(datetime.timezone.utc)
                      + datetime.timedelta(seconds=jwt_config['token_expire']))
    token = jwt.encode(decoded, jwt_config['secret'], algorithm='HS256')
    # jwt.encode returns bytes with PyJWT < 2 and str with PyJWT >= 2.
    if isinstance(token, bytes):
        token = token.decode('utf-8')
    return Response(text=token)
|
|
|
|
def verify_token(token):
    """Decode and validate an HS256 JWT.

    The signature is checked against ``jwt.secret``, and the issuer and
    audience against ``jwt.issuer`` / ``jwt.audience`` from the config.

    :param token: the encoded token, without the ``Bearer `` prefix
    :returns: the decoded claims dictionary
    :raises HTTPUnauthorized: if the token is expired, has the wrong issuer or
        audience, or is otherwise invalid (malformed, bad signature, ...)
    """
    jwt_config = get_config()['jwt']
    try:
        decoded = jwt.decode(token, jwt_config['secret'],
                             issuer=jwt_config['issuer'],
                             audience=jwt_config['audience'],
                             algorithms=['HS256'])
    except jwt.ExpiredSignatureError as err:
        raise HTTPUnauthorized(reason='Token Expired') from err
    except jwt.InvalidIssuerError as err:
        raise HTTPUnauthorized(reason='Token could not be verified') from err
    except jwt.InvalidAudienceError as err:
        raise HTTPUnauthorized(reason='Token audience not match') from err
    except jwt.InvalidTokenError as err:
        # Base class of every PyJWT validation error (DecodeError,
        # InvalidSignatureError, ...): a bad token is a 401, never a 500.
        raise HTTPUnauthorized(reason='Token could not be verified') from err
    return decoded
|
|
|
|
# Lazily filled cache of the tiramisu description returned by ``api``;
# None until the first successful call.
tiramisu = None
|