2019-12-13 13:55:30 +01:00
|
|
|
import asyncpg
|
|
|
|
import asyncio
|
2020-03-10 14:01:21 +01:00
|
|
|
from os import listdir
|
|
|
|
from os.path import isdir, join
|
2020-03-06 07:33:20 +01:00
|
|
|
from sys import exit
|
2019-12-13 13:55:30 +01:00
|
|
|
|
|
|
|
|
2020-03-13 12:24:33 +01:00
|
|
|
from risotto.utils import _
|
2020-03-04 15:15:07 +01:00
|
|
|
from risotto.config import get_config
|
2019-12-27 15:09:38 +01:00
|
|
|
|
2019-12-13 13:55:30 +01:00
|
|
|
|
|
|
|
async def main():
    """Import every ``*.sql`` file found in the configured SQL directory.

    The directory is read from ``get_config()['global']['sql_dir']`` and the
    database DSN from ``get_config()['database']['dsn']``. If ``sql_dir`` is
    not a directory, a message is printed and the process exits via
    ``exit()`` (status 0) without connecting to the database.

    All ``.sql`` files are executed on a single pooled connection inside one
    transaction, in sorted filename order. If any file cannot be read or
    executed, an error message is printed and ``exit(1)`` is called. The
    resulting ``SystemExit`` unwinds through the ``transaction()``,
    ``acquire()`` and pool context managers, so the transaction is not
    committed and the pool is closed.
    """
    config = get_config()
    sql_dir = config['global']['sql_dir']
    if not isdir(sql_dir):
        print('no sql file to import')
        exit()

    db_conf = config['database']['dsn']
    # Using the pool as an async context manager guarantees it is closed,
    # both on success and when exit() raises SystemExit below.
    async with asyncpg.create_pool(db_conf) as pool:
        async with pool.acquire() as connection:
            async with connection.transaction():
                # listdir() returns entries in arbitrary order; sort them so
                # files are applied in a reproducible order.
                for filename in sorted(listdir(sql_dir)):
                    if not filename.endswith('.sql'):
                        continue
                    sql_filename = join(sql_dir, filename)
                    try:
                        with open(sql_filename, 'r', encoding='utf-8') as sql:
                            await connection.execute(sql.read())
                    except Exception as err:
                        # Translate the template first, then format it: an
                        # already-formatted f-string passed to _() can never
                        # match a translation catalog entry.
                        print(_('unable to import {filename}: {err}').format(
                            filename=filename, err=err))
                        exit(1)
|
2020-03-10 14:01:21 +01:00
|
|
|
|
2019-12-13 16:42:10 +01:00
|
|
|
|
2019-12-19 12:25:16 +01:00
|
|
|
if __name__ == '__main__':
    # Create the event loop explicitly instead of relying on
    # asyncio.get_event_loop(), whose implicit loop creation outside a running
    # loop is deprecated in recent Python versions. Close the loop when done,
    # even if main() exits via SystemExit.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        loop.close()
|