sql_filename => sql_dir
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import asyncpg
|
||||
import asyncio
|
||||
from os.path import isfile
|
||||
from os import listdir
|
||||
from os.path import isdir, join
|
||||
from sys import exit
|
||||
|
||||
|
||||
@ -8,16 +9,20 @@ from risotto.config import get_config
|
||||
|
||||
|
||||
async def main():
|
||||
sql_filename = get_config()['global']['sql_filename']
|
||||
if not isfile(sql_filename):
|
||||
sql_dir = get_config()['global']['sql_dir']
|
||||
if not isdir(sql_dir):
|
||||
print('no sql file to import')
|
||||
exit()
|
||||
db_conf = get_config()['database']['dsn']
|
||||
pool = await asyncpg.create_pool(db_conf)
|
||||
async with pool.acquire() as connection:
|
||||
async with connection.transaction():
|
||||
with open(sql_filename, 'r') as sql:
|
||||
await connection.execute(sql.read())
|
||||
for filename in listdir(sql_dir):
|
||||
if filename.endswith('.sql'):
|
||||
sql_filename = join(sql_dir, filename)
|
||||
with open(sql_filename, 'r') as sql:
|
||||
await connection.execute(sql.read())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
loop = asyncio.get_event_loop()
|
||||
|
Reference in New Issue
Block a user