package psql

import (
	"database/sql"
	"fmt"
	"strconv"
	"strings"

	"github.com/jackc/pgtype"
)

// DBInfo holds the database metadata loaded by GetDBInfo and extended by
// AddTable.
type DBInfo struct {
	// Version is the server's server_version_num, parsed as an integer.
	Version int
	// Tables are the known tables; Columns is kept parallel to it.
	Tables []DBTable
	// Columns[i] holds the columns of Tables[i].
	Columns [][]DBColumn
	// Functions are the routines loaded by GetFunctions.
	Functions []DBFunction
	// colMap indexes columns by table Key, then column Key. Its values point
	// into the Columns slices; it backs GetColumn.
	colMap map[string]map[string]*DBColumn
}

// GetDBInfo loads the database metadata used by this package:
//   - the server version (server_version_num),
//   - the tables returned by GetTables,
//   - the columns of each of those tables (via GetColumns),
//   - the functions returned by GetFunctions.
//
// On any error it returns a nil *DBInfo and the error. Nothing is partially
// populated.
func GetDBInfo(db *sql.DB) (*DBInfo, error) {
	var version string
	if err := db.QueryRow(`SHOW server_version_num`).Scan(&version); err != nil {
		return nil, fmt.Errorf("error fetching version: %w", err)
	}

	ver, err := strconv.Atoi(version)
	if err != nil {
		return nil, fmt.Errorf("error parsing server version %q: %w", version, err)
	}

	tables, err := GetTables(db)
	if err != nil {
		return nil, err
	}

	// columns must stay parallel to tables: columns[i] belongs to tables[i].
	// newColMap and DBInfo.Columns rely on that alignment.
	columns := make([][]DBColumn, 0, len(tables))
	for _, t := range tables {
		// NOTE(review): GetTables returns every table visible on the search_path,
		// but columns are only looked up in the "public" schema, so tables from
		// other schemas end up with no columns. Confirm this is intended.
		cols, err := GetColumns(db, "public", t.Name)
		if err != nil {
			return nil, fmt.Errorf("table %q: %w", t.Name, err)
		}
		columns = append(columns, cols)
	}

	funcs, err := GetFunctions(db)
	if err != nil {
		return nil, err
	}

	return &DBInfo{
		Version:   ver,
		Tables:    tables,
		Columns:   columns,
		Functions: funcs,
		colMap:    newColMap(tables, columns),
	}, nil
}

// newColMap builds the two-level lookup used by GetColumn: table Key, then
// column Key, to a *DBColumn.
//
// tables and columns are parallel slices, with columns[i] holding the columns
// of tables[i]. The returned pointers address the elements of columns
// directly; they are not copies. If two tables share a Key, the later one
// replaces the earlier.
func newColMap(tables []DBTable, columns [][]DBColumn) map[string]map[string]*DBColumn {
	byTable := make(map[string]map[string]*DBColumn, len(tables))

	for ti := range tables {
		tableCols := columns[ti]
		byColumn := make(map[string]*DBColumn, len(tableCols))
		for ci := range tableCols {
			col := &tableCols[ci]
			byColumn[col.Key] = col
		}
		byTable[tables[ti].Key] = byColumn
	}

	return byTable
}

// AddTable registers an additional table and its columns. Afterwards the table
// is visible through di.Tables, di.Columns and GetColumn.
//
// ID assignment:
//   - The table receives one past the largest existing table ID, or 0 when no
//     tables are loaded yet.
//   - Each column's ID is set to its 0-based position in cols.
//
// Empty Key fields on t and on each column are filled with the lower-cased
// Name. GetColumn lower-cases its arguments before looking them up.
//
// cols is modified in place and retained: GetColumn returns pointers into it.
func (di *DBInfo) AddTable(t DBTable, cols []DBColumn) {
	// Derive a fresh ID instead of reusing an existing one, so IDs stay unique.
	nextID := 0
	for i := range di.Tables {
		if id := di.Tables[i].ID + 1; id > nextID {
			nextID = id
		}
	}
	t.ID = nextID
	if t.Key == "" {
		t.Key = strings.ToLower(t.Name)
	}

	// A zero-value DBInfo has no colMap; writing to a nil map would panic.
	if di.colMap == nil {
		di.colMap = make(map[string]map[string]*DBColumn)
	}

	byKey := make(map[string]*DBColumn, len(cols))
	for i := range cols {
		c := &cols[i]
		c.ID = int16(i)
		if c.Key == "" {
			c.Key = strings.ToLower(c.Name)
		}
		byKey[c.Key] = c
	}

	di.Tables = append(di.Tables, t)
	di.colMap[t.Key] = byKey
	di.Columns = append(di.Columns, cols)
}

// GetColumn looks up a column by table name and column name.
//
// Both names are lower-cased before the lookup, to match the lower-cased Key
// values that GetTables and GetColumns store. The boolean reports whether the
// column was found. An unknown table simply yields (nil, false).
func (di *DBInfo) GetColumn(table, column string) (*DBColumn, bool) {
	tableCols := di.colMap[strings.ToLower(table)]
	col, found := tableCols[strings.ToLower(column)]
	return col, found
}

// DBTable describes one relation: a table, view, materialized view or
// foreign table.
type DBTable struct {
	// ID identifies the table; it is assigned by GetTables and AddTable.
	ID   int
	// Name is the relation name (pg_class.relname).
	Name string
	// Key is the lookup key used by GetColumn. GetTables sets it to the
	// lower-cased Name.
	Key  string
	// Type is set by GetTables to "table", "view", "materialized view" or
	// "foreign table".
	Type string
}

// GetTables lists the tables, views, materialized views and foreign tables
// visible on the current search_path (pg_table_is_visible).
//
// Excluded from the result:
//   - relations in pg_catalog, information_schema and the pg_toast schemas;
//   - the tables schema_migrations and ar_internal_metadata (presumably Rails
//     migration bookkeeping).
//
// IDs are assigned densely in result order, so a table's ID equals its index
// in the returned slice. Filtered-out rows do not consume an ID.
func GetTables(db *sql.DB) ([]DBTable, error) {
	sqlStmt := `
SELECT
	c.relname as "name",
	CASE c.relkind WHEN 'r' THEN 'table'
		WHEN 'v' THEN 'view'
		WHEN 'm' THEN 'materialized view'
		WHEN 'f' THEN 'foreign table' 
	END as "type"
FROM pg_catalog.pg_class c
	LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind IN ('r','v','m','f','')
	AND n.nspname <> ('pg_catalog')
	AND n.nspname <> ('information_schema')
	AND n.nspname !~ ('^pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid);`

	rows, err := db.Query(sqlStmt)
	if err != nil {
		return nil, fmt.Errorf("Error fetching tables: %w", err)
	}
	defer rows.Close()

	var tables []DBTable
	for rows.Next() {
		var t DBTable
		if err := rows.Scan(&t.Name, &t.Type); err != nil {
			return nil, err
		}

		t.Key = strings.ToLower(t.Name)
		if t.Key == "schema_migrations" || t.Key == "ar_internal_metadata" {
			continue
		}

		t.ID = len(tables)
		tables = append(tables, t)
	}

	// rows.Next returns false on iteration errors too; without this check a
	// failure mid-stream would silently return a truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("Error fetching tables: %w", err)
	}

	return tables, nil
}

// DBColumn describes one column of a table.
type DBColumn struct {
	// ID is the column's attnum when loaded by GetColumns. AddTable instead
	// sets it to the column's 0-based position.
	ID         int16
	Name       string
	// Key is the lookup key used by GetColumn. GetColumns sets it to the
	// lower-cased Name.
	Key        string
	// Type is the column type as rendered by pg_catalog.format_type.
	Type       string
	// Array is true when the column has array dimensions or its rendered type
	// ends in "[]".
	Array      bool
	NotNull    bool
	// PrimaryKey and UniqueKey are set when the column takes part in a
	// primary-key or unique constraint. GetColumns also sets UniqueKey for
	// every PrimaryKey column.
	// NOTE(review): a column in a multi-column constraint is flagged too, so
	// UniqueKey does not guarantee the column is unique on its own.
	PrimaryKey bool
	UniqueKey  bool
	// FKeyTable is the referenced table name when the column is part of a
	// foreign key, or "" otherwise.
	FKeyTable  string
	// FKeyColID holds the foreign key's referenced column numbers
	// (pg_constraint.confkey).
	FKeyColID  []int16
	// fKeyColID is the raw scan target that GetColumns converts into FKeyColID.
	fKeyColID  pgtype.Int2Array
}

// GetColumns returns the columns of table in schema, ordered by attnum
// (their position in the table). Dropped and system columns are skipped.
//
// How rows are merged into columns:
//   - The query yields one row per (column, constraint) pair, so a column in
//     several constraints appears several times; those rows are merged into a
//     single DBColumn.
//   - Boolean flags are OR-ed across rows, and PrimaryKey implies UniqueKey.
//   - If a column has several foreign keys, the one from the row seen last
//     wins. The query does not specify the order between constraint rows.
//
// The returned slice is never nil. It is empty when the table has no columns
// or does not exist.
func GetColumns(db *sql.DB, schema, table string) ([]DBColumn, error) {
	sqlStmt := `
SELECT  
	f.attnum AS id,  
	f.attname AS name,  
	f.attnotnull AS notnull,  
	pg_catalog.format_type(f.atttypid,f.atttypmod) AS type,  
	CASE
	 WHEN f.attndims != 0 THEN true
	 WHEN right(pg_catalog.format_type(f.atttypid,f.atttypmod), 2) = '[]' THEN true
	 ELSE false
	END AS array,
	CASE  
		WHEN p.contype = ('p'::char) THEN true  
		ELSE false 
	END AS primarykey,  
	CASE  
		WHEN p.contype = ('u'::char) THEN true  
		ELSE false
	END AS uniquekey,
	CASE
		WHEN p.contype = ('f'::char) THEN g.relname 
		ELSE ''::text
	END AS foreignkey,
	CASE
		WHEN p.contype = ('f'::char) THEN p.confkey::int2[]
		ELSE ARRAY[]::int2[]
	END AS foreignkey_fieldnum
FROM pg_attribute f
	JOIN pg_class c ON c.oid = f.attrelid  
	LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum  
	LEFT JOIN pg_namespace n ON n.oid = c.relnamespace  
	LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)  
	LEFT JOIN pg_class AS g ON p.confrelid = g.oid  
WHERE c.relkind IN ('r', 'v', 'm', 'f')
	AND n.nspname = $1  -- Replace with Schema name  
	AND c.relname = $2  -- Replace with table name  
	AND f.attnum > 0
	AND f.attisdropped = false
ORDER BY id;`

	rows, err := db.Query(sqlStmt, schema, table)
	if err != nil {
		return nil, fmt.Errorf("error fetching columns: %w", err)
	}
	defer rows.Close()

	// cols keeps columns in first-seen order, which is attnum order thanks to
	// ORDER BY. Collecting into a map and ranging over it would return the
	// columns in random order.
	// pos maps a column's attnum to its index in cols, so later rows for the
	// same column can be merged in place.
	cols := make([]DBColumn, 0)
	pos := make(map[int16]int)

	for rows.Next() {
		var c DBColumn
		err := rows.Scan(&c.ID, &c.Name, &c.NotNull, &c.Type, &c.Array, &c.PrimaryKey, &c.UniqueKey, &c.FKeyTable, &c.fKeyColID)
		if err != nil {
			return nil, err
		}

		i, seen := pos[c.ID]
		if !seen {
			if err := c.fKeyColID.AssignTo(&c.FKeyColID); err != nil {
				return nil, err
			}
			c.Key = strings.ToLower(c.Name)
			if c.PrimaryKey {
				c.UniqueKey = true
			}
			pos[c.ID] = len(cols)
			cols = append(cols, c)
			continue
		}

		// Merge another constraint row into the column already collected.
		v := &cols[i]
		if c.PrimaryKey {
			v.PrimaryKey = true
			v.UniqueKey = true
		}
		v.NotNull = v.NotNull || c.NotNull
		v.UniqueKey = v.UniqueKey || c.UniqueKey
		v.Array = v.Array || c.Array
		if c.FKeyTable != "" {
			v.FKeyTable = c.FKeyTable
		}
		// Non-foreign-key rows carry an empty array (nil Elements). Only a row
		// that actually has foreign-key columns replaces the stored ones.
		if c.fKeyColID.Elements != nil {
			v.fKeyColID = c.fKeyColID
			if err := v.fKeyColID.AssignTo(&v.FKeyColID); err != nil {
				return nil, err
			}
		}
	}

	// rows.Next also stops on iteration errors; surface them instead of
	// returning a truncated column list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error fetching columns: %w", err)
	}

	return cols, nil
}

// DBFunction is a routine from the public schema together with its
// parameters, as collected by GetFunctions.
type DBFunction struct {
	// Name is information_schema.routines.routine_name.
	Name   string
	// Params are in ordinal-position order.
	Params []DBFuncParam
}

// DBFuncParam is one parameter of a DBFunction.
type DBFuncParam struct {
	// ID is information_schema.parameters.ordinal_position.
	ID   int
	// Name is information_schema.parameters.parameter_name.
	Name string
	// Type is information_schema.parameters.data_type.
	Type string
}

// GetFunctions lists the routines in the public schema together with their
// parameters.
//
// Behaviour:
//   - Overloads are kept apart by their specific_name, so each overload
//     becomes its own DBFunction.
//   - Parameters keep their ordinal order.
//   - Unnamed parameters get an empty Name.
//   - Only routines that have at least one parameter row are returned,
//     because the query is driven by information_schema.parameters.
func GetFunctions(db *sql.DB) ([]DBFunction, error) {
	sqlStmt := `
SELECT 
	routines.routine_name, 
	parameters.specific_name,
	parameters.data_type, 
	parameters.parameter_name,
	parameters.ordinal_position	
FROM 
	information_schema.routines
RIGHT JOIN 
	information_schema.parameters 
	ON (routines.specific_name = parameters.specific_name and parameters.ordinal_position IS NOT NULL)	
WHERE 
	routines.specific_schema = 'public'
ORDER BY 
	routines.routine_name, parameters.ordinal_position;`

	rows, err := db.Query(sqlStmt)
	if err != nil {
		return nil, fmt.Errorf("Error fetching functions: %w", err)
	}
	defer rows.Close()

	var funcs []DBFunction
	// fm maps a routine's specific_name to its index in funcs.
	fm := make(map[string]int)

	for rows.Next() {
		var fn, fid string
		// parameter_name is NULL for unnamed parameters (e.g. f(int)).
		// Scanning NULL into a plain string fails, and that would abort the
		// whole metadata load.
		var pname sql.NullString
		fp := DBFuncParam{}

		if err := rows.Scan(&fn, &fid, &fp.Type, &pname, &fp.ID); err != nil {
			return nil, err
		}
		fp.Name = pname.String

		if i, ok := fm[fid]; ok {
			funcs[i].Params = append(funcs[i].Params, fp)
			continue
		}
		funcs = append(funcs, DBFunction{Name: fn, Params: []DBFuncParam{fp}})
		fm[fid] = len(funcs) - 1
	}

	// Surface iteration errors instead of returning a truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("Error fetching functions: %w", err)
	}

	return funcs, nil
}

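// NOTE: unfinished sketch kept for reference. It is not valid Go as written:
// "type" is a keyword, and the switch has no tag expression.
// func GetValType(type string) qcode.ValType {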
// 	switch {
// 		case "bigint", "integer", "smallint", "numeric", "bigserial":
// 			return qcode.ValInt
// 		case "double precision", "real":
// 			return qcode.ValFloat
// 		case ""
// 	}
// }