Test d’un service utilisant une base de données.

This commit is contained in:
2019-12-02 10:48:24 +01:00
parent 8c91e01a2b
commit dcaf7da3bc
7 changed files with 99 additions and 35 deletions

View File

@ -1,21 +0,0 @@
import psycopg2
KEY_DATABASE_NAME = 'dbname'
KEY_DATABASE_USER = 'user'
KEY_DATABASE_PASSWORD = 'password'
KEY_DATABASE_HOST = 'host'
def connect(conf, database_options=None):
if database_options is None:
option = conf.option['database']
database_options = {
'host': option[KEY_DATABASE_HOST],
'dbname': option[KEY_DATABASE_NAME],
'user': option[KEY_DATABASE_USER],
'password': option[KEY_DATABASE_PASSWORD]
}
if not database_options['host']:
raise Exception('cannot find postgresql')
return psycopg2.connect(**database_options)

View File

@ -10,9 +10,8 @@ from .config import DEBUG
from .config import get_config
from .context import Context
from . import register
import asyncpg
def connect(db_conf):
return psycopg2.connect(**db_conf)
class CallDispatcher:
def valid_public_function(self,
@ -38,6 +37,7 @@ class CallDispatcher:
raise CallError(str(err))
else:
if not isinstance(returns, dict):
log.error_msg(risotto_context, kwargs, returns)
err = _(f'function {module_name}.{function_name} has to return a dict')
log.error_msg(risotto_context, kwargs, err)
raise CallError(str(err))
@ -97,10 +97,15 @@ class CallDispatcher:
risotto_context.function = obj['function']
if obj['risotto_context']:
kw['risotto_context'] = risotto_context
if obj['database']:
db_conf = get_config.get('database')
risotto_context.db_cursor = await connect(db_conf).cursor()
returns = await risotto_context.function(self.injected_self[obj['module']], **kw)
if 'database' in obj and obj['database']:
db_conf = get_config().get('database')
pool = await asyncpg.create_pool(database=db_conf.get('dbname'), user=db_conf.get('user'))
async with pool.acquire() as connection:
risotto_context.connection = connection
async with connection.transaction():
returns = await risotto_context.function(self.injected_self[obj['module']], **kw)
else:
returns = await risotto_context.function(self.injected_self[obj['module']], **kw)
except CallError as err:
raise err
except Exception as err:

View File

@ -9,7 +9,8 @@ from .config import INTERNAL_USER
def register(uris: str,
notification: str=undefined):
notification: str=undefined,
database: bool=False):
""" Decorator to register function to the dispatcher
"""
if not isinstance(uris, list):
@ -21,6 +22,7 @@ def register(uris: str,
dispatcher.set_function(version,
message,
notification,
database,
function)
return decorator
@ -129,6 +131,7 @@ class RegisterDispatcher:
version: str,
message: str,
notification: str,
database: bool,
function: Callable):
""" register a function to an URI
URI is a message
@ -180,6 +183,7 @@ class RegisterDispatcher:
function,
function_args,
inject_risotto_context,
database,
notification)
def register_rpc(self,
@ -189,11 +193,13 @@ class RegisterDispatcher:
function: Callable,
function_args: list,
inject_risotto_context: bool,
database: bool,
notification: Optional[str]):
self.messages[version][message]['module'] = module_name
self.messages[version][message]['function'] = function
self.messages[version][message]['arguments'] = function_args
self.messages[version][message]['risotto_context'] = inject_risotto_context
self.messages[version][message]['database'] = database
if notification:
self.messages[version][message]['notification'] = notification
@ -204,6 +210,7 @@ class RegisterDispatcher:
function: Callable,
function_args: list,
inject_risotto_context: bool,
database: bool,
notification: Optional[str]):
if 'functions' not in self.messages[version][message]:
self.messages[version][message]['functions'] = []

View File

@ -1,14 +1,35 @@
from ...controller import Controller
from ...register import register
sql_init = """
-- Création de la table ServerModel
CREATE TABLE ServerModel (
ServerModelId SERIAL PRIMARY KEY,
ServerModelName VARCHAR(255) NOT NULL,
ServerModelDescription VARCHAR(255) NOT NULL,
ServerModelSourceId INTEGER NOT NULL,
ServerModelParentId INTEGER,
ServerModelSubReleaseId INTEGER NOT NULL,
ServerModelSubReleaseName VARCHAR(255) NOT NULL,
UNIQUE (ServerModelName, ServerModelSubReleaseId)
);
"""
class Risotto(Controller):
@register('v1.servermodel.list', None)
async def servermodel_list(self, sourceid):
return [{'servermodelid': 1,
'servermodelname': 'name',
'subreleasename': 'name',
'sourceid': 1,
'servermodeldescription': 'description'}]
@register('v1.servermodel.init', None, database=True)
async def servermodel_init(self, risotto_context):
result = await risotto_context.connection.execute(sql_init)
return {'retcode': 0, 'return': result}
@register('v1.servermodel.list', None, database=True)
async def servermodel_list(self, risotto_context, sourceid):
sql = '''
SELECT * FROM ServerModel
'''
servermodels = await risotto_context.connection.fetch(sql)
return [dict(r) for r in servermodels]
@register('v1.servermodel.describe', None)
async def servermodel_describe(self, inheritance, creolefuncs, servermodelid, schema, conffiles, resolvdepends, probes):