add session tests

This commit is contained in:
Emmanuel Garette 2019-12-07 16:21:20 +01:00
parent 3c5285a7d2
commit 3b31f092bd
25 changed files with 705 additions and 193 deletions

View File

@ -18,9 +18,10 @@ parameters:
ref: Server.ServerId ref: Server.ServerId
description: | description: |
Identifiant du serveur. Identifiant du serveur.
deploy: deployed:
type: Boolean type: Boolean
description: Configuration de type déployée. description: Configuration de type déployée.
default: true
response: response:
type: ConfigConfiguration type: ConfigConfiguration

View File

@ -17,6 +17,6 @@ parameters:
type: Number type: Number
description: | description: |
Identifiant du serveur. Identifiant du serveur.
deploy: deployed:
type: Boolean type: Boolean
description: Configuration de type déployée. description: Configuration de type déployée.

View File

@ -10,6 +10,5 @@ public: false
domain: servermodel-domain domain: servermodel-domain
parameters: parameters:
servermodels: type: Servermodel
type: '[]Servermodel' description: Informations sur les modèles de serveur créés.
description: Informations sur les modèles de serveur créés.

View File

@ -10,6 +10,5 @@ public: false
domain: servermodel-domain domain: servermodel-domain
parameters: parameters:
servermodels: type: 'Servermodel'
type: '[]Servermodel' description: Informations sur les modèles de serveur modifiés.
description: Informations sur les modèles de serveur modifiés.

View File

@ -2,7 +2,7 @@
uri: session.server.get uri: session.server.get
description: | description: |
Configure le server. Récupérer la configuration du server.
pattern: rpc pattern: rpc
@ -16,6 +16,11 @@ parameters:
ref: Config.SessionId ref: Config.SessionId
shortarg: s shortarg: s
description: Identifiant de la configuration. description: Identifiant de la configuration.
name:
type: String
shortarg: n
description: Nom de la variable.
default: null
response: response:
type: Session type: Session

View File

@ -18,6 +18,6 @@ parameters:
description: Identifiant de la session. description: Identifiant de la session.
response: response:
type: SessionConfigurationStatus type: Session
description: Statut de la configuration. description: Statut de la configuration.

View File

@ -16,6 +16,11 @@ parameters:
ref: Config.SessionId ref: Config.SessionId
shortarg: s shortarg: s
description: Identifiant de la configuration. description: Identifiant de la configuration.
name:
type: String
shortarg: n
description: Nom de la variable.
default: null
response: response:
type: Session type: Session

View File

@ -18,6 +18,6 @@ parameters:
description: Identifiant de la session. description: Identifiant de la session.
response: response:
type: SessionConfigurationStatus type: Session
description: Statut de la configuration. description: Statut de la configuration.

View File

@ -3,8 +3,15 @@ title: ConfigConfiguration
type: object type: object
description: Description de la configuration. description: Description de la configuration.
properties: properties:
server_id:
type: number
description: Identifiant du serveur.
ref: Server.ServerId
deployed:
type: boolean
description: La configuration est déployée.
configuration: configuration:
type: File type: object
description: Détail de la configuration au format JSON. description: Détail de la configuration au format JSON.
required: required:
- configuration - configuration

View File

@ -37,16 +37,16 @@ properties:
type: object type: object
description: Liste des services applicatifs déclarés pour ce modèle de serveur. description: Liste des services applicatifs déclarés pour ce modèle de serveur.
schema: schema:
type: File type: string
description: Contenu du schema. description: Contenu du schema.
probes: probes:
type: File type: string
description: Informations sur les sondes. description: Informations sur les sondes.
creolefuncs: creolefuncs:
type: File type: string
description: Fonctions Creole. description: Fonctions Creole.
conffiles: conffiles:
type: File type: string
description: Fichiers creole au format tar encodé base64 description: Fichiers creole au format tar encodé base64
required: required:
- servermodelid - servermodelid

View File

@ -1,24 +0,0 @@
---
title: SessionConfigurationStatus
type: object
description: Statut de la configuration.
properties:
session_id:
type: string
description: ID de la session.
ref: Config.SessionId
status:
type: string
description: Statut de la configuration (peut être ok, error, incomplete)
message:
type: string
description: Message d'erreur si la configuration a le statut error.
mandatories:
type: array
items:
type: string
description: Liste des variables obligatoires non renseignées si la configuration a le statut incomplete.
required:
- session_id
- status

View File

@ -12,9 +12,6 @@ properties:
index: index:
type: number type: number
description: Index de la variable a modifier. description: Index de la variable a modifier.
status:
type: string
description: Status de la modification.
message: message:
type: string type: string
description: Message d'erreur. description: Message d'erreur.

View File

@ -27,7 +27,7 @@ properties:
type: boolean type: boolean
description: La configuration est en mode debug. description: La configuration est en mode debug.
content: content:
type: file type: object
description: Contenu de la configuration. description: Contenu de la configuration.
required: required:
- session_id - session_id

View File

@ -1,6 +1,4 @@
from .http import get_app from .http import get_app
# just to register every route
from . import services as _services
__ALL__ = ('get_app',) __ALL__ = ('get_app',)

View File

@ -1,7 +1,7 @@
HTTP_PORT = 8080 HTTP_PORT = 8080
MESSAGE_ROOT_PATH = 'messages' MESSAGE_ROOT_PATH = 'messages'
ROOT_CACHE_DIR = 'cache' ROOT_CACHE_DIR = 'cache'
DEBUG = True DEBUG = False
DATABASE_DIR = 'database' DATABASE_DIR = 'database'
INTERNAL_USER = 'internal' INTERNAL_USER = 'internal'
CONFIGURATION_DIR = 'configurations' CONFIGURATION_DIR = 'configurations'

View File

@ -149,7 +149,7 @@ class PublishDispatcher:
return return
# config is ok, so publish the message # config is ok, so publish the message
for function_obj in self.messages[version][message]['functions']: for function_obj in self.messages[version][message].get('functions', []):
function = function_obj['function'] function = function_obj['function']
module_name = function.__module__.split('.')[-2] module_name = function.__module__.split('.')[-2]
function_name = function.__name__ function_name = function.__name__
@ -163,21 +163,21 @@ class PublishDispatcher:
if function_obj['risotto_context']: if function_obj['risotto_context']:
kw['risotto_context'] = risotto_context kw['risotto_context'] = risotto_context
# send event # send event
await function(self.injected_self[function_obj['module']], **kw) returns = await function(self.injected_self[function_obj['module']], **kw)
except Exception as err: except Exception as err:
if DEBUG: if DEBUG:
print_exc() print_exc()
log.error_msg(risotto_context, kwargs, err, info_msg) log.error_msg(risotto_context, kwargs, err, info_msg)
continue
else: else:
log.info_msg(risotto_context, kwargs, info_msg) log.info_msg(risotto_context, kwargs, info_msg)
# notification
# notification if function_obj.get('notification'):
if obj.get('notification'): notif_version, notif_message = function_obj['notification'].split('.', 1)
notif_version, notif_message = obj['notification'].split('.', 1) await self.publish(notif_version,
await self.publish(notif_version, notif_message,
notif_message, risotto_context,
risotto_context, **returns)
**returns)
class Dispatcher(register.RegisterDispatcher, CallDispatcher, PublishDispatcher): class Dispatcher(register.RegisterDispatcher, CallDispatcher, PublishDispatcher):

View File

@ -1,6 +1,9 @@
from aiohttp.web import Application, Response, get, post, HTTPBadRequest, HTTPInternalServerError, HTTPNotFound from aiohttp.web import Application, Response, get, post, HTTPBadRequest, HTTPInternalServerError, HTTPNotFound
from tiramisu import Config
from json import dumps from json import dumps
from traceback import print_exc
from tiramisu import Config
from .dispatcher import dispatcher from .dispatcher import dispatcher
from .utils import _ from .utils import _
from .context import Context from .context import Context
@ -8,7 +11,7 @@ from .error import CallError, NotAllowedError, RegistrationError
from .message import get_messages from .message import get_messages
from .logger import log from .logger import log
from .config import DEBUG, HTTP_PORT from .config import DEBUG, HTTP_PORT
from traceback import print_exc from .services import load_services
def create_context(request): def create_context(request):
@ -96,6 +99,7 @@ async def get_app(loop):
""" build all routes """ build all routes
""" """
global extra_routes global extra_routes
load_services()
app = Application(loop=loop) app = Application(loop=loop)
routes = [] routes = []
for version, messages in dispatcher.messages.items(): for version, messages in dispatcher.messages.items():

View File

@ -1,6 +1,7 @@
from typing import Dict from typing import Dict
from .context import Context from .context import Context
from .utils import _ from .utils import _
from .config import DEBUG
class Logger: class Logger:
@ -30,6 +31,7 @@ class Logger:
""" send message when an error append """ send message when an error append
""" """
paths_msg = self._get_message_paths(risotto_context) paths_msg = self._get_message_paths(risotto_context)
# if DEBUG:
print(_(f'{risotto_context.username}: ERROR: {error} ({paths_msg} with arguments "{arguments}": {msg})')) print(_(f'{risotto_context.username}: ERROR: {error} ({paths_msg} with arguments "{arguments}": {msg})'))
def info_msg(self, def info_msg(self,
@ -48,7 +50,8 @@ class Logger:
if msg: if msg:
tmsg += f' {msg}' tmsg += f' {msg}'
print(tmsg) if DEBUG:
print(tmsg)
log = Logger() log = Logger()

View File

@ -209,10 +209,10 @@ class RegisterDispatcher:
self.messages[version][message]['functions'] = [] self.messages[version][message]['functions'] = []
dico = {'module': module_name, dico = {'module': module_name,
'functions': function, 'function': function,
'arguments': function_args, 'arguments': function_args,
'risotto_context': inject_risotto_context} 'risotto_context': inject_risotto_context}
if notification: if notification and notification is not undefined:
dico['notification'] = notification dico['notification'] = notification
self.messages[version][message]['functions'].append(dico) self.messages[version][message]['functions'].append(dico)

View File

@ -4,15 +4,16 @@ from importlib import import_module
from ..dispatcher import dispatcher from ..dispatcher import dispatcher
def list_import(): def load_services(modules=None,
validate: bool=True):
abs_here = dirname(abspath(__file__)) abs_here = dirname(abspath(__file__))
here = basename(abs_here) here = basename(abs_here)
module = basename(dirname(abs_here)) module = basename(dirname(abs_here))
for filename in listdir(abs_here): if not modules:
modules = listdir(abs_here)
for filename in modules:
absfilename = join(abs_here, filename) absfilename = join(abs_here, filename)
if isdir(absfilename) and isfile(join(absfilename, '__init__.py')): if isdir(absfilename) and isfile(join(absfilename, '__init__.py')):
dispatcher.set_module(filename, import_module(f'.{here}.{filename}', module)) dispatcher.set_module(filename, import_module(f'.{here}.{filename}', module))
dispatcher.validate() if validate:
dispatcher.validate()
list_import()

View File

@ -1,9 +1,9 @@
from lxml.etree import parse from lxml.etree import parse
from io import BytesIO from io import BytesIO
from os import unlink
from os.path import isdir, isfile, join from os.path import isdir, isfile, join
from traceback import print_exc from traceback import print_exc
from json import dumps from typing import Dict, List
from typing import Dict
from tiramisu import Storage, delete_session, MetaConfig, MixConfig from tiramisu import Storage, delete_session, MetaConfig, MixConfig
from rougail import load as rougail_load from rougail import load as rougail_load
@ -22,7 +22,7 @@ class Risotto(Controller):
server = {} server = {}
def __init__(self) -> None: def __init__(self) -> None:
for dirname in [ROOT_CACHE_DIR, DATABASE_DIR, ROUGAIL_DTD_PATH]: for dirname in [ROOT_CACHE_DIR, DATABASE_DIR]:
if not isdir(dirname): if not isdir(dirname):
raise RegistrationError(_(f'unable to find the cache dir "{dirname}"')) raise RegistrationError(_(f'unable to find the cache dir "{dirname}"'))
self.save_storage = Storage(engine='sqlite3', dir_database=DATABASE_DIR) self.save_storage = Storage(engine='sqlite3', dir_database=DATABASE_DIR)
@ -58,7 +58,8 @@ class Risotto(Controller):
for servermodel in servermodels: for servermodel in servermodels:
if 'servermodelparentsid' in servermodel: if 'servermodelparentsid' in servermodel:
for servermodelparentid in servermodel['servermodelparentsid']: for servermodelparentid in servermodel['servermodelparentsid']:
self.servermodel_legacy(servermodel['servermodelname'], self.servermodel_legacy(risotto_context,
servermodel['servermodelname'],
servermodel['servermodelid'], servermodel['servermodelid'],
servermodelparentid) servermodelparentid)
@ -156,6 +157,7 @@ class Risotto(Controller):
return metaconfig return metaconfig
def servermodel_legacy(self, def servermodel_legacy(self,
risotto_context: Context,
servermodel_name: str, servermodel_name: str,
servermodel_id: int, servermodel_id: int,
servermodel_parent_id: int) -> None: servermodel_parent_id: int) -> None:
@ -284,99 +286,106 @@ class Risotto(Controller):
async def server_deleted(self, async def server_deleted(self,
server_id: int) -> None: server_id: int) -> None:
# delete config to it's parents # delete config to it's parents
for config in self.server[server_id].values(): for server_type in ['server', 'server_to_deploy']:
config = self.server[server_id]['server']
for parent in config.config.parents(): for parent in config.config.parents():
parent.config.pop(config.config.name()) parent.config.pop(config.config.name())
delete_session(config.config.name()) delete_session(storage=self.save_storage,
session_id=config.config.name())
# delete metaconfig # delete metaconfig
del self.server[server_id] del self.server[server_id]
@register('v1.servermodel.created') @register('v1.servermodel.created')
async def servermodel_created(self, async def servermodel_created(self,
servermodels) -> None: risotto_context: Context,
servermodelid: int,
servermodelname: str,
servermodelparentsid: List[int]) -> None:
""" when servermodels are created, load it and do link """ when servermodels are created, load it and do link
""" """
for servermodel in servermodels: await self.load_and_link_servermodel(risotto_context,
await self.load_servermodel(servermodel['servermodelid'], servermodel['servermodelname']) servermodelid,
for servermodel in servermodels: servermodelname,
if 'servermodelparentsid' in servermodel: servermodelparentsid)
for servermodelparentid in servermodel['servermodelparentsid']:
self.servermodel_legacy(servermodel['servermodelname'], servermodel['servermodelid'], servermodelparentid)
async def load_and_link_servermodel(self,
risotto_context: Context,
servermodelid: int,
servermodelname: str,
servermodelparentsid: List[int]) -> None:
await self.load_servermodel(risotto_context,
servermodelid,
servermodelname)
if servermodelparentsid is not None:
for servermodelparentid in servermodelparentsid:
self.servermodel_legacy(risotto_context,
servermodelname,
servermodelid,
servermodelparentid)
def servermodel_delete(self,
servermodelid: int) -> List[MetaConfig]:
metaconfig = self.servermodel.pop(servermodelid)
mixconfig = next(metaconfig.config.list())
children = []
for child in mixconfig.config.list():
children.append(child)
mixconfig.config.pop(child.config.name())
metaconfig.config.pop(mixconfig.config.name())
delete_session(storage=self.save_storage,
session_id=mixconfig.config.name())
del mixconfig
for parent in metaconfig.config.parents():
parent.config.pop(metaconfig.config.name())
delete_session(storage=self.save_storage,
session_id=metaconfig.config.name())
return children
@register('v1.servermodel.updated') @register('v1.servermodel.updated')
async def servermodel_updated(self, async def servermodel_updated(self,
risotto_context: Context, risotto_context: Context,
servermodels) -> None: servermodelid: int,
for servermodel in servermodels: servermodelname: str,
servermodelid = servermodel['servermodelid'] servermodelparentsid: List[int]) -> None:
servermodelname = servermodel['servermodelname'] log.info_msg(risotto_context,
servermodelparentsid = servermodel.get('servermodelparentsid') None,
log.info_msg(risotto_context, f'Reload servermodel {servermodelname} ({servermodelid})')
None, # unlink cache to force download new aggregated file
f'Reload servermodel {servermodelname} ({servermodelid})') cache_file = join(ROOT_CACHE_DIR, str(servermodelid)+".xml")
# unlink cache to force download new aggregated file if isfile(cache_file):
cache_file = join(ROOT_CACHE_DIR, str(servermodelid)+".xml") unlink(cache_file)
if isfile(cache_file):
unlink(cache_file)
# get current servermodel # store all informations
old_servermodel = self.servermodel[servermodelid] if servermodelid in self.servermodel:
old_values = self.servermodel[servermodelid].value.exportation()
old_permissives = self.servermodel[servermodelid].permissive.exportation()
old_properties = self.servermodel[servermodelid].property.exportation()
children = self.servermodel_delete(servermodelid)
else:
old_values = None
# create new one # create new one
await self.load_servermodel(servermodelid, servermodelname) await self.load_and_link_servermodel(risotto_context,
servermodelid,
servermodelname,
servermodelparentsid)
# migrate all informations # migrates informations
self.servermodel[servermodelid].value.importation(old_servermodel.value.exportation()) if old_values is not None:
self.servermodel[servermodelid].permissive.importation(old_servermodel.permissive.exportation()) self.servermodel[servermodelid].value.importation(old_values)
self.servermodel[servermodelid].property.importation(old_servermodel.property.exportation()) self.servermodel[servermodelid].permissive.importation(old_permissives)
self.servermodel[servermodelid].property.importation(old_properties)
# remove link to legacy for child in children:
if servermodelparentsid: self.servermodel_legacy(risotto_context,
for servermodelparentid in servermodelparentsid: child.information.get('servermodel_name'),
mix = self.servermodel[servermodelparentid].config.get('m_v_' + str(servermodelparentid)) child.information.get('servermodel_id'),
try: servermodelid)
mix.config.pop(old_servermodel.config.name())
except:
# if mix config is reloaded too
pass
# add new link
self.servermodel_legacy(servermodelname, servermodelid, servermodelparentid)
# reload servers or servermodels in servermodel
for subconfig in old_servermodel.config.list():
if not isinstance(subconfig, MixConfig):
# a server
name = subconfig.config.name()
if name.startswith('str_'):
continue
server_id = subconfig.information.get('server_id')
server_name = subconfig.information.get('server_name')
try:
old_servermodel.config.pop(name)
old_servermodel.config.pop(f'std_{server_id}')
except:
pass
del self.server[server_id]
self.load_server(risotto_context,
server_id,
server_name,
servermodelid)
else:
# a servermodel
for subsubconfig in subconfig.config.list():
name = subsubconfig.config.name()
try:
subconfig.config.pop(name)
except:
pass
self.servermodel_legacy(subsubconfig.information.get('servermodel_name'),
subsubconfig.information.get('servermodel_id'),
servermodelid)
@register('v1.config.configuration.server.get', None) @register('v1.config.configuration.server.get', None)
async def get_configuration(self, async def get_configuration(self,
server_id: int, server_id: int,
deploy: bool) -> bytes: deployed: bool) -> bytes:
if server_id not in self.server: if server_id not in self.server:
msg = _(f'cannot find server with id {server_id}') msg = _(f'cannot find server with id {server_id}')
log.error_msg(risotto_context, log.error_msg(risotto_context,
@ -384,16 +393,16 @@ class Risotto(Controller):
msg) msg)
raise CallError(msg) raise CallError(msg)
if deploy: if deployed:
server = self.server[server_id]['server'] server = self.server[server_id]['server']
else: else:
server = self.server[server_id]['server_to_deploy'] server = self.server[server_id]['server_to_deploy']
server.property.read_only() server.property.read_only()
try: try:
dico = server.value.dict(fullpath=True) configuration = server.value.dict(fullpath=True)
except: except:
if deploy: if deployed:
msg = _(f'No configuration available for server {server_id}') msg = _(f'No configuration available for server {server_id}')
else: else:
msg = _(f'No undeployed configuration available for server {server_id}') msg = _(f'No undeployed configuration available for server {server_id}')
@ -401,8 +410,9 @@ class Risotto(Controller):
None, None,
msg) msg)
raise CallError(msg) raise CallError(msg)
return dumps(dico).encode() return {'server_id': server_id,
'deployed': deployed,
'configuration': configuration}
@register('v1.config.configuration.server.deploy', 'v1.config.configuration.server.updated') @register('v1.config.configuration.server.deploy', 'v1.config.configuration.server.updated')
async def deploy_configuration(self, async def deploy_configuration(self,
@ -428,4 +438,4 @@ class Risotto(Controller):
config.property.importation(config_std.property.exportation()) config.property.importation(config_std.property.exportation())
return {'server_id': server_id, return {'server_id': server_id,
'deploy': True} 'deployed': True}

View File

@ -1,7 +1,6 @@
from os import urandom # , unlink from os import urandom # , unlink
from binascii import hexlify from binascii import hexlify
from traceback import print_exc from traceback import print_exc
from json import dumps
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
from tiramisu import Storage from tiramisu import Storage
@ -28,6 +27,7 @@ class Risotto(Controller):
return storage_servermodel return storage_servermodel
def get_session(self, def get_session(self,
risotto_context: Context,
session_id: str, session_id: str,
type: str) -> Dict: type: str) -> Dict:
""" Get session information from storage """ Get session information from storage
@ -36,7 +36,8 @@ class Risotto(Controller):
storage = storage_server storage = storage_server
else: else:
storage = storage_servermodel storage = storage_servermodel
return storage.get_session(session_id) return storage.get_session(session_id,
risotto_context.username)
def get_session_informations(self, def get_session_informations(self,
risotto_context: Context, risotto_context: Context,
@ -44,9 +45,9 @@ class Risotto(Controller):
type: str) -> Dict: type: str) -> Dict:
""" format session with a session ID name """ format session with a session ID name
""" """
session = self.get_session(session_id, session = self.get_session(risotto_context,
type, session_id,
risotto_context.username) type)
return self.format_session(session_id, return self.format_session(session_id,
session) session)
@ -104,16 +105,16 @@ class Risotto(Controller):
self.modify_storage) self.modify_storage)
# return session's information # return session's information
return self.get_session_informations(session_id, return self.get_session_informations(risotto_context,
session_id,
type) type)
@register(['v1.session.server.list', 'v1.session.servermodel.list'], None) @register(['v1.session.server.list', 'v1.session.servermodel.list'], None)
async def list_session_server(self, async def list_session_server(self,
risotto_context: Context) -> Dict: risotto_context: Context) -> Dict:
type = risotto_context.message.rsplit('.', 2)[-2] type = risotto_context.message.rsplit('.', 2)[-2]
storage = self.get_storage(type, storage = self.get_storage(type)
risotto_context.username) return [self.format_session(session_id, session) for session_id, session in storage.get_sessions().items()]
return [self.format_session(session_id, session) or session_id, session in storage.get_sessions().items()]
@register(['v1.session.server.filter', 'v1.session.servermodel.filter'], None) @register(['v1.session.server.filter', 'v1.session.servermodel.filter'], None)
@ -127,7 +128,7 @@ class Risotto(Controller):
storage = self.get_storage(type) storage = self.get_storage(type)
# to validate the session right # to validate the session right
storage.get_session(session_id, storage.get_session(session_id,
username) risotto_context.username)
if namespace is not None: if namespace is not None:
storage.set_namespace(session_id, storage.set_namespace(session_id,
namespace) namespace)
@ -139,7 +140,8 @@ class Risotto(Controller):
if debug is not None: if debug is not None:
storage.set_config_debug(session_id, storage.set_config_debug(session_id,
debug) debug)
return self.get_session_informations(session_id, return self.get_session_informations(risotto_context,
session_id,
type) type)
@register(['v1.session.server.configure', 'v1.session.servermodel.configure'], None) @register(['v1.session.server.configure', 'v1.session.servermodel.configure'], None)
@ -152,25 +154,25 @@ class Risotto(Controller):
value: Any, value: Any,
value_multi: Optional[List]) -> Dict: value_multi: Optional[List]) -> Dict:
type = risotto_context.message.rsplit('.', 2)[-2] type = risotto_context.message.rsplit('.', 2)[-2]
session = self.get_session(session_id, session = self.get_session(risotto_context,
type, session_id,
risotto_context.username) type)
# if multi and not follower the value is in fact in value_multi # if multi and not follower the value is in fact in value_multi
option = session['config'].option(name).option option = session['option'].option(name).option
if option.ismulti() and not option.isfollower(): if option.ismulti() and not option.isfollower():
value = value_multi value = value_multi
try: namespace = session['namespace']
update = {'name': name, update = {'name': f'{namespace}.{name}',
'action': action, 'action': action,
'value': value} 'value': value}
if index is not None: if index is not None:
update['index'] = index update['index'] = index
updates = {'updates': [update]} updates = {'updates': [update]}
session['option'].updates(updates) ret = session['option'].updates(updates)
except Exception as err: if update['name'] in ret:
if DEBUG: for val in ret[update['name']][index]:
print_exc() if isinstance(val, ValueError):
raise CallError(str(err)) raise CallError(val)
ret = {'session_id': session_id, ret = {'session_id': session_id,
'name': name} 'name': name}
if index is not None: if index is not None:
@ -182,9 +184,9 @@ class Risotto(Controller):
risotto_context: Context, risotto_context: Context,
session_id: str) -> Dict: session_id: str) -> Dict:
type = risotto_context.message.rsplit('.', 2)[-2] type = risotto_context.message.rsplit('.', 2)[-2]
session = self.get_session(session_id, session = self.get_session(risotto_context,
type, session_id,
risotto_context.username) type)
try: try:
session['config'].forcepermissive.option(session['namespace']).value.dict() session['config'].forcepermissive.option(session['namespace']).value.dict()
except Exception as err: except Exception as err:
@ -205,13 +207,17 @@ class Risotto(Controller):
@register(['v1.session.server.get', 'v1.session.servermodel.get'], None) @register(['v1.session.server.get', 'v1.session.servermodel.get'], None)
async def get_session_server(self, async def get_session_server(self,
risotto_context: Context, risotto_context: Context,
session_id: str) -> Dict: session_id: str,
name: Optional[str]) -> Dict:
type = risotto_context.message.rsplit('.', 2)[-2] type = risotto_context.message.rsplit('.', 2)[-2]
session = self.get_session(session_id, session = self.get_session(risotto_context,
type, session_id,
risotto_context.username) type)
info = self.format_session(session_id, session) info = self.format_session(session_id, session)
info['content'] = dumps(session['option'].value.dict(fullpath=True)) if name is not None:
info['content'] = {name: session['option'].option(name).value.get()}
else:
info['content'] = session['option'].value.dict()
return info return info
@register(['v1.session.server.stop', 'v1.session.servermodel.stop'], None) @register(['v1.session.server.stop', 'v1.session.servermodel.stop'], None)

View File

@ -66,7 +66,7 @@ class Storage(object):
mode: str): mode: str):
""" Define which edition mode to select """ Define which edition mode to select
""" """
config = self.session[id]['config'] config = self.sessions[id]['config']
for mode_level in modes.values(): for mode_level in modes.values():
if modes[mode] < mode_level: if modes[mode] < mode_level:
config.property.add(mode_level.name) config.property.add(mode_level.name)
@ -77,7 +77,7 @@ class Storage(object):
def set_config_debug(self, id_, is_debug): def set_config_debug(self, id_, is_debug):
""" Enable/Disable debug mode """ Enable/Disable debug mode
""" """
config = self.session[id_]['config'] config = self.sessions[id_]['config']
if is_debug: if is_debug:
config.property.pop('hidden') config.property.pop('hidden')
else: else:
@ -94,13 +94,14 @@ class Storage(object):
return self.sessions; return self.sessions;
def get_session(self, def get_session(self,
id: int, session_id: int,
username: str) -> Dict: username: str) -> Dict:
if id not in self.sessions: if session_id not in self.sessions:
raise Exception(f'the session {id} not exists') raise Exception(f'the session {id} not exists')
if username != storage.get_session(session_id)['username']: session = self.sessions[session_id]
if username != session['username']:
raise NotAllowedError() raise NotAllowedError()
return self.sessions[id] return session
def del_session(self, def del_session(self,
id: int): id: int):

View File

@ -13,7 +13,6 @@ def setup_module(module):
validate=False) validate=False)
config_module = dispatcher.get_service('config') config_module = dispatcher.get_service('config')
config_module.save_storage = Storage(engine='sqlite3', dir_database=DATABASE_DIR, name='test') config_module.save_storage = Storage(engine='sqlite3', dir_database=DATABASE_DIR, name='test')
config_module.save_persistent = False
dispatcher.set_module('server', import_module(f'.server', 'fake_services')) dispatcher.set_module('server', import_module(f'.server', 'fake_services'))
dispatcher.set_module('servermodel', import_module(f'.servermodel', 'fake_services')) dispatcher.set_module('servermodel', import_module(f'.servermodel', 'fake_services'))

501
tests/test_session.py Normal file
View File

@ -0,0 +1,501 @@
from importlib import import_module
import pytest
from tiramisu import Storage
from risotto.context import Context
from risotto.services import load_services
from risotto.dispatcher import dispatcher
from risotto.config import DATABASE_DIR
from risotto.services.session.storage import storage_server, storage_servermodel
def get_fake_context(module_name):
risotto_context = Context()
risotto_context.username = 'test'
risotto_context.paths.append(f'{module_name}.on_join')
risotto_context.type = None
return risotto_context
def setup_module(module):
load_services(['config', 'session'],
validate=False)
config_module = dispatcher.get_service('config')
config_module.save_storage = Storage(engine='sqlite3', dir_database=DATABASE_DIR, name='test')
dispatcher.set_module('server', import_module(f'.server', 'fake_services'))
dispatcher.set_module('servermodel', import_module(f'.servermodel', 'fake_services'))
def teardown_function(function):
config_module = dispatcher.get_service('session')
storage_server.sessions = {}
storage_servermodel.sessions = {}
@pytest.mark.asyncio
async def test_server_start():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
@pytest.mark.asyncio
async def test_server_list():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
assert not await dispatcher.call('v1',
'session.server.list',
fake_context)
await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
assert await dispatcher.call('v1',
'session.server.list',
fake_context)
@pytest.mark.asyncio
async def test_server_filter_namespace():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
session_id = session['session_id']
namespace = 'containers'
await dispatcher.call('v1',
'session.server.filter',
fake_context,
session_id=session_id,
namespace=namespace)
list_result = await dispatcher.call('v1',
'session.server.list',
fake_context)
assert list_result[0]['namespace'] == namespace
@pytest.mark.asyncio
async def test_server_filter_mode():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
session_id = session['session_id']
assert session['mode'] == 'normal'
mode = 'expert'
await dispatcher.call('v1',
'session.server.filter',
fake_context,
session_id=session_id,
mode=mode)
list_result = await dispatcher.call('v1',
'session.server.list',
fake_context)
assert list_result[0]['mode'] == mode
@pytest.mark.asyncio
async def test_server_filter_debug():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
session_id = session['session_id']
assert session['debug'] == False
debug = True
await dispatcher.call('v1',
'session.server.filter',
fake_context,
session_id=session_id,
debug=debug)
list_result = await dispatcher.call('v1',
'session.server.list',
fake_context)
assert list_result[0]['debug'] == debug
#FIXME
@pytest.mark.asyncio
async def test_server_filter_get():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
session_id = session['session_id']
values = await dispatcher.call('v1',
'session.server.get',
fake_context,
session_id=session_id)
assert values == {'content': {"general.mode_conteneur_actif": "non",
"general.master.master": [],
"general.master.slave1": [],
"general.master.slave2": []},
'debug': False,
'id': 3,
'mode': 'normal',
'namespace': 'creole',
'session_id': session_id,
'timestamp': values['timestamp'],
'username': 'test'}
@pytest.mark.asyncio
async def test_server_filter_get_one_value():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
session_id = session['session_id']
values = await dispatcher.call('v1',
'session.server.get',
fake_context,
session_id=session_id,
name="general.mode_conteneur_actif")
assert values == {'content': {"general.mode_conteneur_actif": "non"},
'debug': False,
'id': 3,
'mode': 'normal',
'namespace': 'creole',
'session_id': session_id,
'timestamp': values['timestamp'],
'username': 'test'}
@pytest.mark.asyncio
async def test_server_filter_configure():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
session_id = session['session_id']
await dispatcher.call('v1',
'session.server.configure',
fake_context,
session_id=session_id,
action='modify',
name='general.mode_conteneur_actif',
value='oui')
list_result = await dispatcher.call('v1',
'session.server.list',
fake_context)
values = await dispatcher.call('v1',
'session.server.get',
fake_context,
session_id=session_id,
name="general.mode_conteneur_actif")
assert values == {'content': {"general.mode_conteneur_actif": "oui"},
'debug': False,
'id': 3,
'mode': 'normal',
'namespace': 'creole',
'session_id': session_id,
'timestamp': values['timestamp'],
'username': 'test'}
@pytest.mark.asyncio
async def test_server_filter_validate():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
session_id = session['session_id']
await dispatcher.call('v1',
'session.server.validate',
fake_context,
session_id=session_id)
@pytest.mark.asyncio
async def test_server_stop():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.server:
await config_module.on_join(fake_context)
assert not await dispatcher.call('v1',
'session.server.list',
fake_context)
start = await dispatcher.call('v1',
'session.server.start',
fake_context,
id=3)
session_id = start['session_id']
assert await dispatcher.call('v1',
'session.server.list',
fake_context)
await dispatcher.call('v1',
'session.server.stop',
fake_context,
session_id=session_id)
assert not await dispatcher.call('v1',
'session.server.list',
fake_context)
# servermodel
@pytest.mark.asyncio
async def test_servermodel_start():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
@pytest.mark.asyncio
async def test_servermodel_list():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
assert not await dispatcher.call('v1',
'session.servermodel.list',
fake_context)
await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
assert await dispatcher.call('v1',
'session.servermodel.list',
fake_context)
@pytest.mark.asyncio
async def test_servermodel_filter_namespace():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
session_id = session['session_id']
namespace = 'containers'
await dispatcher.call('v1',
'session.servermodel.filter',
fake_context,
session_id=session_id,
namespace=namespace)
list_result = await dispatcher.call('v1',
'session.servermodel.list',
fake_context)
assert list_result[0]['namespace'] == namespace
@pytest.mark.asyncio
async def test_servermodel_filter_mode():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
session_id = session['session_id']
assert session['mode'] == 'normal'
mode = 'expert'
await dispatcher.call('v1',
'session.servermodel.filter',
fake_context,
session_id=session_id,
mode=mode)
list_result = await dispatcher.call('v1',
'session.servermodel.list',
fake_context)
assert list_result[0]['mode'] == mode
@pytest.mark.asyncio
async def test_servermodel_filter_debug():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
session_id = session['session_id']
assert session['debug'] == False
debug = True
await dispatcher.call('v1',
'session.servermodel.filter',
fake_context,
session_id=session_id,
debug=debug)
list_result = await dispatcher.call('v1',
'session.servermodel.list',
fake_context)
assert list_result[0]['debug'] == debug
@pytest.mark.asyncio
async def test_servermodel_filter_get():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
session_id = session['session_id']
values = await dispatcher.call('v1',
'session.servermodel.get',
fake_context,
session_id=session_id)
assert values == {'content': {"general.mode_conteneur_actif": "non",
"general.master.master": [],
"general.master.slave1": [],
"general.master.slave2": []},
'debug': False,
'id': 1,
'mode': 'normal',
'namespace': 'creole',
'session_id': session_id,
'timestamp': values['timestamp'],
'username': 'test'}
@pytest.mark.asyncio
async def test_servermodel_filter_get_one_value():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
session_id = session['session_id']
values = await dispatcher.call('v1',
'session.servermodel.get',
fake_context,
session_id=session_id,
name="general.mode_conteneur_actif")
assert values == {'content': {"general.mode_conteneur_actif": "non"},
'debug': False,
'id': 1,
'mode': 'normal',
'namespace': 'creole',
'session_id': session_id,
'timestamp': values['timestamp'],
'username': 'test'}
@pytest.mark.asyncio
async def test_servermodel_filter_configure():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
session_id = session['session_id']
await dispatcher.call('v1',
'session.servermodel.configure',
fake_context,
session_id=session_id,
action='modify',
name='general.mode_conteneur_actif',
value='oui')
list_result = await dispatcher.call('v1',
'session.servermodel.list',
fake_context)
values = await dispatcher.call('v1',
'session.servermodel.get',
fake_context,
session_id=session_id,
name="general.mode_conteneur_actif")
assert values == {'content': {"general.mode_conteneur_actif": "oui"},
'debug': False,
'id': 1,
'mode': 'normal',
'namespace': 'creole',
'session_id': session_id,
'timestamp': values['timestamp'],
'username': 'test'}
@pytest.mark.asyncio
async def test_servermodel_filter_validate():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
session = await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
session_id = session['session_id']
await dispatcher.call('v1',
'session.servermodel.validate',
fake_context,
session_id=session_id)
@pytest.mark.asyncio
async def test_servermodel_stop():
fake_context = get_fake_context('session')
config_module = dispatcher.get_service('config')
if not config_module.servermodel:
await config_module.on_join(fake_context)
assert not await dispatcher.call('v1',
'session.servermodel.list',
fake_context)
start = await dispatcher.call('v1',
'session.servermodel.start',
fake_context,
id=1)
session_id = start['session_id']
assert await dispatcher.call('v1',
'session.servermodel.list',
fake_context)
await dispatcher.call('v1',
'session.servermodel.stop',
fake_context,
session_id=session_id)
assert not await dispatcher.call('v1',
'session.servermodel.list',
fake_context)