SQLAlchemyを用いてCSVからデータベースへのインポート・エクスポート

CSVファイルをデータベースにインポートする

pandas の to_sql() がとても便利です。

df = pd.read_csv(csvfile, encoding='utf8', sep=",", dtype=str)
df.to_sql(tablename, con=engine, if_exists='append', index=False)

テーブルをCSVファイルにエクスポートする

  • Flask-SQLAlchemy を使っている。

まずは、モデルクラス名を取得する。

models = {
    idx: mapper.class_.__name__
    for idx, mapper in enumerate(db.Model.registry.mappers)
}

そして、importlib.import_module()getattr() メソッドによって動的にモデルクラスをインポートする。

これで query.all() によってテーブルのすべてのレコードを取得することができる。

module = import_module('app.models')
for name in target_models:
    model = getattr(module, name)
    records = model.query.all()

CLIツールにする

click を利用して、CLI ツールを作った。

CSV のインポートとエクスポートを分けて、サブコマンド(load と export)として扱っている。

インポートの場合は -i オプションで CSV ファイルを指定する、-t オプションでテーブル名を指定する必要がある。

エクスポートの場合は、カレントディレクトリがデフォルトの出力ディレクトリになっている。

import csv
import click
from pathlib import Path
from importlib import import_module

import pandas as pd

from app import db, create_app

app = create_app()
app.app_context().push()


@click.group()
def cli():
    pass


@cli.command()
@click.option('-i', '--input', required=True)
@click.option('-t', '--tablename', required=True)
def load(input, tablename):
    """load a csv file into database."""
    input_csv = Path(input).absolute()
    if input_csv.is_file() and input_csv.suffix == '.csv':
        df = pd.read_csv(input_csv, encoding='utf8', sep=",", dtype=str)
        df.to_sql(tablename, con=db.engine, if_exists='append', index=False)
        click.echo('Data has been saved to DB.')
    else:
        raise ValueError('Please specify a file with extension .csv!')


@cli.command()
@click.option('-o', '--output', default='.')
def export(output):
    """export database table to csv file."""
    out_dir = Path(output).absolute()

    models = {
        idx: mapper.class_.__name__
        for idx, mapper in enumerate(db.Model.registry.mappers)
    }
    click.echo(f'Models: {models}')

    input_models = click.prompt(
        'Please enter the number of models(separated by space), default is',
        default='all')
    input_models = input_models.split(' ')
    if input_models == ['all']:
        target_models = models.values()
    else:
        target_models = [models.get(int(i)) for i in input_models]

    module = import_module('app.models')
    for name in target_models:
        click.echo(f'Dump {name} to csv...')
        # get model
        model = getattr(module, name)
        # get all records of model
        records = model.query.all()
        # get column names
        columns = [col.name for col in model.__mapper__.columns]

        filename = out_dir / f'{name}.csv'
        csvfile = open(filename, 'w', encoding='utf-8')
        outcsv = csv.writer(csvfile)

        # write header row
        outcsv.writerow(columns)
        for record in records:
            outcsv.writerow([getattr(record, column) for column in columns])
        csvfile.close()
    click.echo('Done.')


if __name__ == '__main__':
    cli()