import asyncpg
import asyncio
from os import listdir
from os.path import isdir, join
from sys import exit, stderr
from risotto.utils import _
from risotto.config import get_config


async def main():
    """Import every ``*.sql`` file of the configured SQL directory.

    Reads ``global.sql_dir`` and ``database.dsn`` from ``get_config()``.
    All files are executed on one connection inside a single transaction,
    in sorted filename order, so a failure in any file rolls back every
    file imported before it.

    Exits the process with status 0 when the SQL directory does not exist,
    and with status 1 (after printing the error to stderr) when a file
    cannot be read or executed.
    """
    config = get_config()
    sql_dir = config['global']['sql_dir']
    if not isdir(sql_dir):
        print('no sql file to import')
        exit()
    # os.listdir() returns entries in arbitrary order; sort so that files
    # meant to be applied in sequence (e.g. 01_schema.sql, 02_data.sql)
    # are always executed deterministically.
    sql_filenames = sorted(
        filename for filename in listdir(sql_dir) if filename.endswith('.sql')
    )
    # A single connection is enough; close it explicitly so it is not
    # leaked (the previous pool was never closed).
    connection = await asyncpg.connect(config['database']['dsn'])
    try:
        async with connection.transaction():
            for filename in sql_filenames:
                sql_filename = join(sql_dir, filename)
                with open(sql_filename, 'r', encoding='utf-8') as sql:
                    try:
                        await connection.execute(sql.read())
                    except Exception as err:
                        # Translate the constant template first, then fill it
                        # in, so the catalog can match the msgid (presumably
                        # `_` is a gettext-style lookup — confirm in
                        # risotto.utils).
                        message = _('unable to import {filename}: {err}')
                        print(message.format(filename=filename, err=err),
                              file=stderr)
                        # SystemExit propagates through the transaction
                        # context manager, which rolls back, then through
                        # `finally`, which closes the connection.
                        exit(1)
    finally:
        await connection.close()


if __name__ == '__main__':
    asyncio.run(main())