35 lines
1.0 KiB
Python
35 lines
1.0 KiB
Python
import asyncpg
|
|
import asyncio
|
|
from os import listdir
|
|
from os.path import isdir, join
|
|
from sys import exit
|
|
|
|
|
|
from risotto.utils import _
|
|
from risotto.config import get_config
|
|
|
|
|
|
async def main():
    """Import every ``*.sql`` file found in the configured SQL directory.

    Reads ``['global']['sql_dir']`` and ``['database']['dsn']`` from
    ``get_config()``. Each ``.sql`` file in that directory is executed, in
    sorted filename order, on a single pooled connection inside one
    transaction.

    Exits (status 0) after printing a notice when ``sql_dir`` is not a
    directory. Exits with status 1 on the first file that cannot be read
    or executed. The ``SystemExit`` leaves the transaction block, so the
    statements executed so far in this run are rolled back.
    """
    config = get_config()
    sql_dir = config['global']['sql_dir']
    if not isdir(sql_dir):
        print('no sql file to import')
        exit()

    # listdir() returns entries in arbitrary order; sort so the files are
    # always applied in the same, predictable sequence.
    filenames = sorted(
        filename for filename in listdir(sql_dir) if filename.endswith('.sql')
    )

    # Used as an async context manager the pool is closed on exit, including
    # when exit(1) raises SystemExit below.
    async with asyncpg.create_pool(config['database']['dsn']) as pool:
        async with pool.acquire() as connection:
            async with connection.transaction():
                for filename in filenames:
                    try:
                        # Read and close the file before the database round-trip.
                        with open(join(sql_dir, filename), 'r') as sql:
                            statements = sql.read()
                        await connection.execute(statements)
                    except Exception as err:
                        # Translate the template first, then interpolate: passing an
                        # f-string to _() looks up already-formatted text, which can
                        # never match a translation catalog entry.
                        print(_('unable to import {filename}: {err}').format(
                            filename=filename, err=err))
                        exit(1)
|
|
|
|
|
|
if __name__ == '__main__':
    # asyncio.run() creates a fresh event loop, runs main() to completion and
    # closes the loop afterwards. Calling asyncio.get_event_loop() with no
    # running loop is deprecated, and the loop it returned was never closed.
    asyncio.run(main())
|