Add config driven custom table relationships

This commit is contained in:
Vikram Rangnekar 2019-12-09 01:48:18 -05:00
parent 679dd1fc83
commit 0e16eee93b
18 changed files with 790 additions and 500 deletions

View File

@ -134,6 +134,12 @@ tables:
- name: deals - name: deals
table: products table: products
- name: users
columns:
- name: email
foreign_key: products.name
roles_query: "SELECT * FROM users WHERE id = $user_id" roles_query: "SELECT * FROM users WHERE id = $user_id"
roles: roles:

View File

@ -1173,6 +1173,23 @@ Even tracing data is availble in the Super Graph web UI if tracing is enabled in
![Query Tracing](/tracing.png "Super Graph Web UI Query Tracing") ![Query Tracing](/tracing.png "Super Graph Web UI Query Tracing")
## Configure Database Relationships
In most cases you don't need this configuration, Super Graph will discover and learn
the relationship graph within your database automatically. It does this using `Foreign Key` relationships that you have defined in your database schema.
The below configs are only needed in special cases such as when you don't use foreign keys or when you want to create a relationship between two tables where a foreign key is not defined or cannot be defined.
For example in the sample below a relationship is defined between the `tags` column on the `posts` table with the `slug` column on the `tags` table. This cannot be defined as using foreign keys since the `tags` column is of type array `text[]` and Postgres for one does not allow foreign keys with array columns.
```yaml
tables:
- name: posts
columns:
- name: tags
related_to: tags.slug
```
## Configuration files ## Configuration files

View File

@ -112,15 +112,15 @@ func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer
root := &qc.Selects[0] root := &qc.Selects[0]
i := 0 i := 0
for _, cn := range ti.ColumnNames { for _, cn := range ti.Columns {
if _, ok := jt[cn]; !ok { if _, ok := jt[cn.Key]; !ok {
continue continue
} }
if _, ok := root.PresetMap[cn]; ok { if _, ok := root.PresetMap[cn.Key]; ok {
continue continue
} }
if len(root.Allowed) != 0 { if len(root.Allowed) != 0 {
if _, ok := root.Allowed[cn]; !ok { if _, ok := root.Allowed[cn.Key]; !ok {
continue continue
} }
} }
@ -128,7 +128,7 @@ func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
io.WriteString(c.w, `"`) io.WriteString(c.w, `"`)
io.WriteString(c.w, cn) io.WriteString(c.w, cn.Name)
io.WriteString(c.w, `"`) io.WriteString(c.w, `"`)
i++ i++
} }
@ -139,7 +139,7 @@ func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer
for i := range root.PresetList { for i := range root.PresetList {
cn := root.PresetList[i] cn := root.PresetList[i]
col, ok := ti.Columns[cn] col, ok := ti.ColMap[cn]
if !ok { if !ok {
continue continue
} }
@ -229,6 +229,10 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
} }
if ti.PrimaryCol == nil {
return 0, fmt.Errorf("no primary key column found")
}
jt, _, err := jsn.Tree(upsert) jt, _, err := jsn.Tree(upsert)
if err != nil { if err != nil {
return 0, err return 0, err
@ -241,23 +245,23 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
io.WriteString(c.w, ` ON CONFLICT (`) io.WriteString(c.w, ` ON CONFLICT (`)
i := 0 i := 0
for _, cn := range ti.ColumnNames { for _, cn := range ti.Columns {
if _, ok := jt[cn]; !ok { if _, ok := jt[cn.Key]; !ok {
continue continue
} }
if col, ok := ti.Columns[cn]; !ok || !(col.UniqueKey || col.PrimaryKey) { if col, ok := ti.ColMap[cn.Key]; !ok || !(col.UniqueKey || col.PrimaryKey) {
continue continue
} }
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
io.WriteString(c.w, cn) io.WriteString(c.w, cn.Name)
i++ i++
} }
if i == 0 { if i == 0 {
io.WriteString(c.w, ti.PrimaryCol) io.WriteString(c.w, ti.PrimaryCol.Name)
} }
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
@ -272,16 +276,16 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
io.WriteString(c.w, ` DO UPDATE SET `) io.WriteString(c.w, ` DO UPDATE SET `)
i = 0 i = 0
for _, cn := range ti.ColumnNames { for _, cn := range ti.Columns {
if _, ok := jt[cn]; !ok { if _, ok := jt[cn.Key]; !ok {
continue continue
} }
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
io.WriteString(c.w, cn) io.WriteString(c.w, cn.Name)
io.WriteString(c.w, ` = EXCLUDED.`) io.WriteString(c.w, ` = EXCLUDED.`)
io.WriteString(c.w, cn) io.WriteString(c.w, cn.Name)
i++ i++
} }

View File

@ -3,6 +3,7 @@ package psql
import ( import (
"log" "log"
"os" "os"
"strings"
"testing" "testing"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
@ -127,61 +128,73 @@ func TestMain(m *testing.M) {
log.Fatal(err) log.Fatal(err)
} }
tables := []*DBTable{ tables := []DBTable{
&DBTable{Name: "customers", Type: "table"}, DBTable{Name: "customers", Type: "table"},
&DBTable{Name: "users", Type: "table"}, DBTable{Name: "users", Type: "table"},
&DBTable{Name: "products", Type: "table"}, DBTable{Name: "products", Type: "table"},
&DBTable{Name: "purchases", Type: "table"}, DBTable{Name: "purchases", Type: "table"},
DBTable{Name: "tags", Type: "table"},
} }
columns := [][]*DBColumn{ columns := [][]DBColumn{
[]*DBColumn{ []DBColumn{
&DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: false},
&DBColumn{ID: 2, Name: "full_name", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 2, Name: "full_name", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 3, Name: "phone", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 3, Name: "phone", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 4, Name: "email", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 4, Name: "email", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 5, Name: "encrypted_password", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 5, Name: "encrypted_password", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 6, Name: "reset_password_token", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 6, Name: "reset_password_token", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 7, Name: "reset_password_sent_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 7, Name: "reset_password_sent_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 8, Name: "remember_created_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 8, Name: "remember_created_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 9, Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 9, Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 10, Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}}, DBColumn{ID: 10, Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}},
[]*DBColumn{ []DBColumn{
&DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: false},
&DBColumn{ID: 2, Name: "full_name", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 2, Name: "full_name", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 3, Name: "phone", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 3, Name: "phone", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 4, Name: "avatar", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 4, Name: "avatar", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 5, Name: "email", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 5, Name: "email", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 6, Name: "encrypted_password", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 6, Name: "encrypted_password", Type: "character varying", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 7, Name: "reset_password_token", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 7, Name: "reset_password_token", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 8, Name: "reset_password_sent_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 8, Name: "reset_password_sent_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 9, Name: "remember_created_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 9, Name: "remember_created_at", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 10, Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 10, Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 11, Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}}, DBColumn{ID: 11, Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false}},
[]*DBColumn{ []DBColumn{
&DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: false},
&DBColumn{ID: 2, Name: "name", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 2, Name: "name", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 3, Name: "description", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 3, Name: "description", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 4, Name: "price", Type: "numeric(7,2)", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 4, Name: "price", Type: "numeric(7,2)", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 5, Name: "user_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "users", FKeyColID: []int16{1}}, DBColumn{ID: 5, Name: "user_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "users", FKeyColID: []int16{1}},
&DBColumn{ID: 6, Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 6, Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 7, Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 7, Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 8, Name: "tsv", Type: "tsvector", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}}, DBColumn{ID: 8, Name: "tsv", Type: "tsvector", NotNull: false, PrimaryKey: false, UniqueKey: false},
[]*DBColumn{ DBColumn{ID: 9, Name: "tags", Type: "text[]", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "tags", FKeyColID: []int16{3}, Array: true}},
&DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, []DBColumn{
&DBColumn{ID: 2, Name: "customer_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "customers", FKeyColID: []int16{1}}, DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: false},
&DBColumn{ID: 3, Name: "product_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "products", FKeyColID: []int16{1}}, DBColumn{ID: 2, Name: "customer_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "customers", FKeyColID: []int16{1}},
&DBColumn{ID: 4, Name: "sale_type", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 3, Name: "product_id", Type: "bigint", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "products", FKeyColID: []int16{1}},
&DBColumn{ID: 5, Name: "quantity", Type: "integer", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 4, Name: "sale_type", Type: "character varying", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 6, Name: "due_date", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}, DBColumn{ID: 5, Name: "quantity", Type: "integer", NotNull: false, PrimaryKey: false, UniqueKey: false},
&DBColumn{ID: 7, Name: "returned", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false, FKeyTable: "", FKeyColID: []int16(nil)}}, DBColumn{ID: 6, Name: "due_date", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false},
DBColumn{ID: 7, Name: "returned", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, UniqueKey: false}},
[]DBColumn{
DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, UniqueKey: false},
DBColumn{ID: 2, Name: "name", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false},
DBColumn{ID: 3, Name: "slug", Type: "text", NotNull: false, PrimaryKey: false, UniqueKey: false}},
}
for i := range tables {
tables[i].Key = strings.ToLower(tables[i].Name)
for n := range columns[i] {
columns[i][n].Key = strings.ToLower(columns[i][n].Name)
}
} }
schema := &DBSchema{ schema := &DBSchema{
ver: 110000, ver: 110000,
t: make(map[string]*DBTableInfo), t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel), rm: make(map[string]map[string]*DBRel),
al: make(map[string]struct{}),
} }
aliases := map[string][]string{ aliases := map[string][]string{
@ -189,7 +202,15 @@ func TestMain(m *testing.M) {
} }
for i, t := range tables { for i, t := range tables {
if err := schema.updateSchema(t, columns[i], aliases); err != nil { err := schema.addTable(t, columns[i], aliases)
if err != nil {
log.Fatal(err)
}
}
for i, t := range tables {
err := schema.updateRelationships(t, columns[i])
if err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }

View File

@ -14,7 +14,6 @@ import (
) )
const ( const (
empty = ""
closeBlock = 500 closeBlock = 500
) )
@ -38,13 +37,17 @@ func (c *Compiler) AddRelationship(child, parent string, rel *DBRel) error {
return c.schema.SetRel(child, parent, rel) return c.schema.SetRel(child, parent, rel)
} }
func (c *Compiler) IDColumn(table string) (string, error) { func (c *Compiler) IDColumn(table string) (*DBColumn, error) {
t, err := c.schema.GetTable(table) ti, err := c.schema.GetTable(table)
if err != nil { if err != nil {
return empty, err return nil, err
} }
return t.PrimaryCol, nil if ti.PrimaryCol == nil {
return nil, fmt.Errorf("no primary key column found")
}
return ti.PrimaryCol, nil
} }
type compilerContext struct { type compilerContext struct {
@ -225,18 +228,16 @@ func (c *compilerContext) processChildren(sel *qcode.Select, ti *DBTableInfo) (u
switch rel.Type { switch rel.Type {
case RelOneToMany: case RelOneToMany:
fallthrough if _, ok := colmap[rel.Right.Col]; !ok {
case RelBelongTo: cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Right.Col, FieldName: rel.Right.Col})
if _, ok := colmap[rel.Col2]; !ok {
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Col2, FieldName: rel.Col2})
} }
case RelOneToManyThrough: case RelOneToManyThrough:
if _, ok := colmap[rel.Col1]; !ok { if _, ok := colmap[rel.Left.Col]; !ok {
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Col1, FieldName: rel.Col1}) cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Left.Col, FieldName: rel.Left.Col})
} }
case RelRemote: case RelRemote:
if _, ok := colmap[rel.Col1]; !ok { if _, ok := colmap[rel.Left.Col]; !ok {
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Col1, FieldName: rel.Col2}) cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Left.Col, FieldName: rel.Right.Col})
} }
skipped |= (1 << uint(id)) skipped |= (1 << uint(id))
@ -400,17 +401,13 @@ func (c *compilerContext) renderJoinByName(table, parent string, id int32) error
} }
//fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`,
//rel.Through, rel.Through, rel.ColT, c.parent.Name, c.parent.ID, rel.Col1) //rel.Through, rel.Through, rel.ColT, c.parent.Name, c.parent.ID, rel.Left.Col)
io.WriteString(c.w, ` LEFT OUTER JOIN "`) io.WriteString(c.w, ` LEFT OUTER JOIN "`)
io.WriteString(c.w, rel.Through) io.WriteString(c.w, rel.Through)
io.WriteString(c.w, `" ON ((`) io.WriteString(c.w, `" ON ((`)
colWithTable(c.w, rel.Through, rel.ColT) colWithTable(c.w, rel.Through, rel.ColT)
io.WriteString(c.w, `) = (`) io.WriteString(c.w, `) = (`)
if id != -1 { colWithTableID(c.w, pt.Name, id, rel.Left.Col)
colWithTableID(c.w, pt.Name, id, rel.Col1)
} else {
colWithTable(c.w, pt.Name, rel.Col1)
}
io.WriteString(c.w, `))`) io.WriteString(c.w, `))`)
return nil return nil
@ -461,9 +458,9 @@ func (c *compilerContext) renderRemoteRelColumns(sel *qcode.Select, ti *DBTableI
io.WriteString(c.w, ", ") io.WriteString(c.w, ", ")
} }
//fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`, //fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`,
//c.sel.Name, c.sel.ID, rel.Col1, rel.Col2) //c.sel.Name, c.sel.ID, rel.Left.Col, rel.Right.Col)
colWithTableID(c.w, ti.Name, sel.ID, rel.Col1) colWithTableID(c.w, ti.Name, sel.ID, rel.Left.Col)
alias(c.w, rel.Col2) alias(c.w, rel.Right.Col)
i++ i++
} }
} }
@ -514,7 +511,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
for n, col := range sel.Cols { for n, col := range sel.Cols {
cn := col.Name cn := col.Name
_, isRealCol := ti.Columns[cn] _, isRealCol := ti.ColMap[cn]
if !isRealCol { if !isRealCol {
if isSearch { if isSearch {
@ -525,7 +522,10 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
continue continue
} }
} }
cn = ti.TSVCol if ti.TSVCol == nil {
return errors.New("no ts_vector column found")
}
cn = ti.TSVCol.Name
arg := sel.Args["search"] arg := sel.Args["search"]
if i != 0 { if i != 0 {
@ -741,7 +741,13 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) error { func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) error {
parent := c.s[sel.ParentID] parent := c.s[sel.ParentID]
return c.renderRelationshipByName(ti.Name, parent.Name, parent.ID)
pti, err := c.schema.GetTable(parent.Name)
if err != nil {
return err
}
return c.renderRelationshipByName(ti.Name, pti.Name, parent.ID)
} }
func (c *compilerContext) renderRelationshipByName(table, parent string, id int32) error { func (c *compilerContext) renderRelationshipByName(table, parent string, id int32) error {
@ -750,43 +756,54 @@ func (c *compilerContext) renderRelationshipByName(table, parent string, id int3
return err return err
} }
switch rel.Type { io.WriteString(c.w, `((`)
case RelBelongTo:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Name, rel.Col1, c.parent.Name, c.parent.ID, rel.Col2)
io.WriteString(c.w, `((`)
colWithTable(c.w, table, rel.Col1)
io.WriteString(c.w, `) = (`)
if id != -1 {
colWithTableID(c.w, parent, id, rel.Col2)
} else {
colWithTable(c.w, parent, rel.Col2)
}
io.WriteString(c.w, `))`)
switch rel.Type {
case RelOneToMany: case RelOneToMany:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Name, rel.Col1, c.parent.Name, c.parent.ID, rel.Col2) //c.sel.Name, rel.Left.Col, c.parent.Name, c.parent.ID, rel.Right.Col)
io.WriteString(c.w, `((`)
colWithTable(c.w, table, rel.Col1) switch {
io.WriteString(c.w, `) = (`) case !rel.Left.Array && rel.Right.Array:
if id != -1 { colWithTable(c.w, table, rel.Left.Col)
colWithTableID(c.w, parent, id, rel.Col2) io.WriteString(c.w, `) = any (`)
} else { colWithTableID(c.w, parent, id, rel.Right.Col)
colWithTable(c.w, parent, rel.Col2)
case rel.Left.Array && !rel.Right.Array:
colWithTableID(c.w, parent, id, rel.Right.Col)
io.WriteString(c.w, `) = any (`)
colWithTable(c.w, table, rel.Left.Col)
default:
colWithTable(c.w, table, rel.Left.Col)
io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent, id, rel.Right.Col)
} }
io.WriteString(c.w, `))`)
case RelOneToManyThrough: case RelOneToManyThrough:
// This requires the through table to be joined onto this select // This requires the through table to be joined onto this select
//fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`,
//c.sel.Name, rel.Col1, rel.Through, rel.Col2) //c.sel.Name, rel.Left.Col, rel.Through, rel.Right.Col)
io.WriteString(c.w, `((`)
colWithTable(c.w, table, rel.Col1) switch {
io.WriteString(c.w, `) = (`) case !rel.Left.Array && rel.Right.Array:
colWithTable(c.w, rel.Through, rel.Col2) colWithTable(c.w, table, rel.Left.Col)
io.WriteString(c.w, `))`) io.WriteString(c.w, `) = any (`)
colWithTable(c.w, rel.Through, rel.Right.Col)
case rel.Left.Array && !rel.Right.Array:
colWithTable(c.w, rel.Through, rel.Right.Col)
io.WriteString(c.w, `) = any (`)
colWithTable(c.w, table, rel.Left.Col)
default:
colWithTable(c.w, table, rel.Left.Col)
io.WriteString(c.w, `) = (`)
colWithTable(c.w, rel.Through, rel.Right.Col)
}
} }
io.WriteString(c.w, `))`)
return nil return nil
} }
@ -908,7 +925,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
} }
if len(ex.Col) != 0 { if len(ex.Col) != 0 {
if col, ok = ti.Columns[ex.Col]; !ok { if col, ok = ti.ColMap[ex.Col]; !ok {
return fmt.Errorf("no column '%s' found ", ex.Col) return fmt.Errorf("no column '%s' found ", ex.Col)
} }
@ -965,28 +982,23 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
return nil return nil
case qcode.OpEqID: case qcode.OpEqID:
if len(ti.PrimaryCol) == 0 { if ti.PrimaryCol == nil {
return fmt.Errorf("no primary key column defined for %s", ti.Name) return fmt.Errorf("no primary key column defined for %s", ti.Name)
} }
if col, ok = ti.Columns[ti.PrimaryCol]; !ok { col = ti.PrimaryCol
return fmt.Errorf("no primary key column '%s' found ", ti.PrimaryCol)
}
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol) //fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
io.WriteString(c.w, `((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.PrimaryCol) colWithTable(c.w, ti.Name, ti.PrimaryCol.Name)
//io.WriteString(c.w, ti.PrimaryCol) //io.WriteString(c.w, ti.PrimaryCol)
io.WriteString(c.w, `) =`) io.WriteString(c.w, `) =`)
case qcode.OpTsQuery: case qcode.OpTsQuery:
if len(ti.TSVCol) == 0 { if ti.PrimaryCol == nil {
return fmt.Errorf("no tsv column defined for %s", ti.Name) return fmt.Errorf("no tsv column defined for %s", ti.Name)
} }
if _, ok = ti.Columns[ti.TSVCol]; !ok {
return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol)
}
//fmt.Fprintf(w, `(("%s") @@ websearch_to_tsquery('%s'))`, c.ti.TSVCol, val.Val) //fmt.Fprintf(w, `(("%s") @@ websearch_to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.TSVCol) colWithTable(c.w, ti.Name, ti.TSVCol.Name)
if c.schema.ver >= 110000 { if c.schema.ver >= 110000 {
io.WriteString(c.w, `) @@ websearch_to_tsquery('`) io.WriteString(c.w, `) @@ websearch_to_tsquery('`)
} else { } else {
@ -1003,6 +1015,9 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if ex.Type == qcode.ValList { if ex.Type == qcode.ValList {
c.renderList(ex) c.renderList(ex)
} else { } else {
if col == nil {
return errors.New("no column found for expression value")
}
c.renderVal(ex, c.vars, col) c.renderVal(ex, c.vars, col)
} }
@ -1179,8 +1194,10 @@ func colWithTable(w io.Writer, table, col string) {
func colWithTableID(w io.Writer, table string, id int32, col string) { func colWithTableID(w io.Writer, table string, id int32, col string) {
io.WriteString(w, `"`) io.WriteString(w, `"`)
io.WriteString(w, table) io.WriteString(w, table)
io.WriteString(w, `_`) if id >= 0 {
int2string(w, id) io.WriteString(w, `_`)
int2string(w, id)
}
io.WriteString(w, `"."`) io.WriteString(w, `"."`)
io.WriteString(w, col) io.WriteString(w, col)
io.WriteString(w, `"`) io.WriteString(w, `"`)

View File

@ -185,7 +185,7 @@ func oneToMany(t *testing.T) {
} }
} }
func belongsTo(t *testing.T) { func oneToManyReverse(t *testing.T) {
gql := `query { gql := `query {
products { products {
name name
@ -208,6 +208,37 @@ func belongsTo(t *testing.T) {
} }
} }
func oneToManyArray(t *testing.T) {
gql := `query {
product {
name
price
tags {
id
name
}
}
tags {
name
product {
name
}
}
}
}`
sql := `SELECT row_to_json("json_root") FROM (SELECT "sel_0"."json_0" AS "tags", "sel_2"."json_2" AS "product" FROM (SELECT row_to_json((SELECT "json_row_2" FROM (SELECT "products_2"."name" AS "name", "products_2"."price" AS "price", "tags_3_join"."json_3" AS "tags") AS "json_row_2")) AS "json_2" FROM (SELECT "products"."name", "products"."price", "products"."tags" FROM "products" LIMIT ('1') :: integer) AS "products_2" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("json_3"), '[]') AS "json_3" FROM (SELECT row_to_json((SELECT "json_row_3" FROM (SELECT "tags_3"."id" AS "id", "tags_3"."name" AS "name") AS "json_row_3")) AS "json_3" FROM (SELECT "tags"."id", "tags"."name" FROM "tags" WHERE ((("tags"."slug") = any ("products_2"."tags"))) LIMIT ('20') :: integer) AS "tags_3" LIMIT ('20') :: integer) AS "json_agg_3") AS "tags_3_join" ON ('true') LIMIT ('1') :: integer) AS "sel_2", (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "tags_0"."name" AS "name", "product_1_join"."json_1" AS "product") AS "json_row_0")) AS "json_0" FROM (SELECT "tags"."name", "tags"."slug" FROM "tags" LIMIT ('20') :: integer) AS "tags_0" LEFT OUTER JOIN LATERAL (SELECT row_to_json((SELECT "json_row_1" FROM (SELECT "products_1"."name" AS "name") AS "json_row_1")) AS "json_1" FROM (SELECT "products"."name" FROM "products" WHERE ((("tags_0"."slug") = any ("products"."tags"))) LIMIT ('1') :: integer) AS "products_1" LIMIT ('1') :: integer) AS "product_1_join" ON ('true') LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0") AS "json_root"`
resSQL, err := compileGQLToPSQL(gql, nil, "admin")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func manyToMany(t *testing.T) { func manyToMany(t *testing.T) {
gql := `query { gql := `query {
products { products {
@ -480,8 +511,9 @@ func TestCompileQuery(t *testing.T) {
t.Run("withWhereMultiOr", withWhereMultiOr) t.Run("withWhereMultiOr", withWhereMultiOr)
t.Run("fetchByID", fetchByID) t.Run("fetchByID", fetchByID)
t.Run("searchQuery", searchQuery) t.Run("searchQuery", searchQuery)
t.Run("belongsTo", belongsTo)
t.Run("oneToMany", oneToMany) t.Run("oneToMany", oneToMany)
t.Run("oneToManyReverse", oneToManyReverse)
t.Run("oneToManyArray", oneToManyArray)
t.Run("manyToMany", manyToMany) t.Run("manyToMany", manyToMany)
t.Run("manyToManyReverse", manyToManyReverse) t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction) t.Run("aggFunction", aggFunction)

286
psql/schema.go Normal file
View File

@ -0,0 +1,286 @@
package psql
import (
"fmt"
"strings"
"github.com/gobuffalo/flect"
"github.com/jackc/pgx/v4/pgxpool"
)
type DBSchema struct {
ver int
t map[string]*DBTableInfo
rm map[string]map[string]*DBRel
}
type DBTableInfo struct {
Name string
Singular bool
Columns []DBColumn
PrimaryCol *DBColumn
TSVCol *DBColumn
ColMap map[string]*DBColumn
ColIDMap map[int16]*DBColumn
}
type RelType int
const (
RelOneToMany RelType = iota + 1
RelOneToManyThrough
RelRemote
)
type DBRel struct {
Type RelType
Through string
ColT string
Left struct {
Col string
Array bool
}
Right struct {
Col string
Array bool
}
}
func NewDBSchema(db *pgxpool.Pool,
info *DBInfo, aliases map[string][]string) (*DBSchema, error) {
schema := &DBSchema{
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
}
for i, t := range info.Tables {
err := schema.addTable(t, info.Columns[i], aliases)
if err != nil {
return nil, err
}
}
for i, t := range info.Tables {
err := schema.updateRelationships(t, info.Columns[i])
if err != nil {
return nil, err
}
}
return schema, nil
}
func (s *DBSchema) addTable(
t DBTable, cols []DBColumn, aliases map[string][]string) error {
colmap := make(map[string]*DBColumn, len(cols))
colidmap := make(map[int16]*DBColumn, len(cols))
singular := flect.Singularize(t.Key)
s.t[singular] = &DBTableInfo{
Name: t.Name,
Singular: true,
Columns: cols,
ColMap: colmap,
ColIDMap: colidmap,
}
plural := flect.Pluralize(t.Key)
s.t[plural] = &DBTableInfo{
Name: t.Name,
Singular: false,
Columns: cols,
ColMap: colmap,
ColIDMap: colidmap,
}
if al, ok := aliases[t.Key]; ok {
for i := range al {
k1 := flect.Singularize(al[i])
s.t[k1] = s.t[singular]
k2 := flect.Pluralize(al[i])
s.t[k2] = s.t[plural]
}
}
for i := range cols {
c := &cols[i]
switch {
case c.Type == "tsvector":
s.t[singular].TSVCol = c
s.t[plural].TSVCol = c
case c.PrimaryKey:
s.t[singular].PrimaryCol = c
s.t[plural].PrimaryCol = c
}
colmap[c.Key] = c
colidmap[c.ID] = c
}
return nil
}
func (s *DBSchema) updateRelationships(t DBTable, cols []DBColumn) error {
jcols := make([]DBColumn, 0, len(cols))
ct := t.Key
cti, ok := s.t[ct]
if !ok {
return fmt.Errorf("invalid foreign key table '%s'", ct)
}
for _, c := range cols {
if len(c.FKeyTable) == 0 || len(c.FKeyColID) == 0 {
continue
}
// Foreign key column name
ft := strings.ToLower(c.FKeyTable)
fcid := c.FKeyColID[0]
ti, ok := s.t[ft]
if !ok {
return fmt.Errorf("invalid foreign key table '%s'", ft)
}
fc, ok := ti.ColIDMap[fcid]
if !ok {
return fmt.Errorf("invalid foreign key column id '%d' for table '%s'",
fcid, ti.Name)
}
// One-to-many relation between current table and the
// table in the foreign key
rel1 := &DBRel{Type: RelOneToMany}
rel1.Left.Col = c.Name
rel1.Left.Array = c.Array
rel1.Right.Col = fc.Name
rel1.Right.Array = fc.Array
if err := s.SetRel(ct, ft, rel1); err != nil {
return err
}
// One-to-many reverse relation between the foreign key table and the
// the current table
rel2 := &DBRel{Type: RelOneToMany}
rel2.Left.Col = fc.Name
rel2.Left.Array = fc.Array
rel2.Right.Col = c.Name
rel2.Right.Array = c.Array
if err := s.SetRel(ft, ct, rel2); err != nil {
return err
}
jcols = append(jcols, c)
}
// If table contains multiple foreign key columns it's a possible
// join table for many-to-many relationships or multiple one-to-many
// relations
// Below one-to-many relations use the current table as the
// join table aka through table.
if len(jcols) > 1 {
for i := range jcols {
for n := range jcols {
if n == i {
continue
}
err := s.updateSchemaOTMT(cti, jcols[i], jcols[n])
if err != nil {
return err
}
}
}
}
return nil
}
func (s *DBSchema) updateSchemaOTMT(
ti *DBTableInfo, col1, col2 DBColumn) error {
t1 := strings.ToLower(col1.FKeyTable)
t2 := strings.ToLower(col2.FKeyTable)
fc1, ok := ti.ColIDMap[col1.FKeyColID[0]]
if !ok {
return fmt.Errorf("invalid foreign key column id '%d' for table '%s'",
col1.FKeyColID[0], ti.Name)
}
fc2, ok := ti.ColIDMap[col2.FKeyColID[0]]
if !ok {
return fmt.Errorf("invalid foreign key column id '%d' for table '%s'",
col2.FKeyColID[0], ti.Name)
}
// One-to-many-through relation between 1nd foreign key table and the
// 2nd foreign key table
rel1 := &DBRel{Type: RelOneToManyThrough}
rel1.Through = ti.Name
rel1.ColT = col2.Name
rel1.Left.Col = fc2.Name
rel1.Right.Col = col1.Name
if err := s.SetRel(t1, t2, rel1); err != nil {
return err
}
// One-to-many-through relation between 2nd foreign key table and the
// 1nd foreign key table
rel2 := &DBRel{Type: RelOneToManyThrough}
rel2.Through = ti.Name
rel2.ColT = col1.Name
rel2.Left.Col = fc1.Name
rel2.Right.Col = col2.Name
if err := s.SetRel(t2, t1, rel2); err != nil {
return err
}
return nil
}
func (s *DBSchema) GetTable(table string) (*DBTableInfo, error) {
t, ok := s.t[table]
if !ok {
return nil, fmt.Errorf("unknown table '%s'", table)
}
return t, nil
}
func (s *DBSchema) SetRel(child, parent string, rel *DBRel) error {
sc := strings.ToLower(flect.Singularize(child))
pc := strings.ToLower(flect.Pluralize(child))
if _, ok := s.rm[sc]; !ok {
s.rm[sc] = make(map[string]*DBRel)
}
if _, ok := s.rm[pc]; !ok {
s.rm[pc] = make(map[string]*DBRel)
}
sp := strings.ToLower(flect.Singularize(parent))
pp := strings.ToLower(flect.Pluralize(parent))
s.rm[sc][sp] = rel
s.rm[sc][pp] = rel
s.rm[pc][sp] = rel
s.rm[pc][pp] = rel
return nil
}
func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) {
rel, ok := s.rm[child][parent]
if !ok {
return nil, fmt.Errorf("unknown relationship '%s' -> '%s'",
child, parent)
}
return rel, nil
}

View File

@ -6,17 +6,75 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/gobuffalo/flect"
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
) )
type DBInfo struct {
Version int
Tables []DBTable
Columns [][]DBColumn
colmap map[string]map[string]*DBColumn
}
func GetDBInfo(db *pgxpool.Pool) (*DBInfo, error) {
di := &DBInfo{}
dbc, err := db.Acquire(context.Background())
if err != nil {
return nil, fmt.Errorf("error acquiring connection from pool: %w", err)
}
defer dbc.Release()
var version string
err = dbc.QueryRow(context.Background(), `SHOW server_version_num`).Scan(&version)
if err != nil {
return nil, fmt.Errorf("error fetching version: %w", err)
}
di.Version, err = strconv.Atoi(version)
if err != nil {
return nil, err
}
di.Tables, err = GetTables(dbc)
if err != nil {
return nil, err
}
di.colmap = make(map[string]map[string]*DBColumn, len(di.Tables))
for i, t := range di.Tables {
cols, err := GetColumns(dbc, "public", t.Name)
if err != nil {
return nil, err
}
di.Columns = append(di.Columns, cols)
di.colmap[t.Key] = make(map[string]*DBColumn, len(cols))
for n, c := range di.Columns[i] {
di.colmap[t.Key][c.Key] = &di.Columns[i][n]
}
}
return di, nil
}
func (di *DBInfo) GetColumn(table, column string) (*DBColumn, bool) {
v, ok := di.colmap[strings.ToLower(table)][strings.ToLower(column)]
return v, ok
}
type DBTable struct { type DBTable struct {
ID int
Name string Name string
Key string
Type string Type string
} }
func GetTables(dbc *pgxpool.Conn) ([]*DBTable, error) { func GetTables(dbc *pgxpool.Conn) ([]DBTable, error) {
sqlStmt := ` sqlStmt := `
SELECT SELECT
c.relname as "name", c.relname as "name",
@ -33,7 +91,7 @@ WHERE c.relkind IN ('r','v','m','f','')
AND n.nspname !~ ('^pg_toast') AND n.nspname !~ ('^pg_toast')
AND pg_catalog.pg_table_is_visible(c.oid);` AND pg_catalog.pg_table_is_visible(c.oid);`
var tables []*DBTable var tables []DBTable
rows, err := dbc.Query(context.Background(), sqlStmt) rows, err := dbc.Query(context.Background(), sqlStmt)
if err != nil { if err != nil {
@ -41,13 +99,14 @@ AND pg_catalog.pg_table_is_visible(c.oid);`
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for i := 0; rows.Next(); i++ {
t := DBTable{} t := DBTable{ID: i}
err = rows.Scan(&t.Name, &t.Type) err = rows.Scan(&t.Name, &t.Type)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables = append(tables, &t) t.Key = strings.ToLower(t.Name)
tables = append(tables, t)
} }
return tables, nil return tables, nil
@ -56,7 +115,9 @@ AND pg_catalog.pg_table_is_visible(c.oid);`
type DBColumn struct { type DBColumn struct {
ID int16 ID int16
Name string Name string
Key string
Type string Type string
Array bool
NotNull bool NotNull bool
PrimaryKey bool PrimaryKey bool
UniqueKey bool UniqueKey bool
@ -65,13 +126,17 @@ type DBColumn struct {
fKeyColID pgtype.Int2Array fKeyColID pgtype.Int2Array
} }
func GetColumns(dbc *pgxpool.Conn, schema, table string) ([]*DBColumn, error) { func GetColumns(dbc *pgxpool.Conn, schema, table string) ([]DBColumn, error) {
sqlStmt := ` sqlStmt := `
SELECT SELECT
f.attnum AS id, f.attnum AS id,
f.attname AS name, f.attname AS name,
f.attnotnull AS notnull, f.attnotnull AS notnull,
pg_catalog.format_type(f.atttypid,f.atttypmod) AS type, pg_catalog.format_type(f.atttypid,f.atttypmod) AS type,
CASE
WHEN f.attndims != 0 THEN true
ELSE false
END AS array,
CASE CASE
WHEN p.contype = ('p'::char) THEN true WHEN p.contype = ('p'::char) THEN true
ELSE false ELSE false
@ -107,12 +172,11 @@ ORDER BY id;`
} }
defer rows.Close() defer rows.Close()
cmap := make(map[int16]*DBColumn) cmap := make(map[int16]DBColumn)
for rows.Next() { for rows.Next() {
c := DBColumn{} c := DBColumn{}
err = rows.Scan(&c.ID, &c.Name, &c.NotNull, &c.Type, &c.PrimaryKey, &c.UniqueKey, err = rows.Scan(&c.ID, &c.Name, &c.NotNull, &c.Type, &c.Array, &c.PrimaryKey, &c.UniqueKey, &c.FKeyTable, &c.fKeyColID)
&c.FKeyTable, &c.fKeyColID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -127,291 +191,23 @@ ORDER BY id;`
if c.UniqueKey { if c.UniqueKey {
v.UniqueKey = true v.UniqueKey = true
} }
if c.Array {
v.Array = true
}
} else { } else {
err := c.fKeyColID.AssignTo(&c.FKeyColID) err := c.fKeyColID.AssignTo(&c.FKeyColID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cmap[c.ID] = &c c.Key = strings.ToLower(c.Name)
cmap[c.ID] = c
} }
} }
cols := make([]*DBColumn, 0, len(cmap)) cols := make([]DBColumn, 0, len(cmap))
for _, v := range cmap { for _, v := range cmap {
cols = append(cols, v) cols = append(cols, v)
} }
return cols, nil return cols, nil
} }
type DBSchema struct {
ver int
t map[string]*DBTableInfo
rm map[string]map[string]*DBRel
al map[string]struct{}
}
type DBTableInfo struct {
Name string
Singular bool
PrimaryCol string
TSVCol string
Columns map[string]*DBColumn
ColumnNames []string
}
type RelType int
const (
RelBelongTo RelType = iota + 1
RelOneToMany
RelOneToManyThrough
RelRemote
)
type DBRel struct {
Type RelType
Through string
ColT string
Col1 string
Col2 string
}
func NewDBSchema(db *pgxpool.Pool, aliases map[string][]string) (*DBSchema, error) {
schema := &DBSchema{
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
al: make(map[string]struct{}),
}
dbc, err := db.Acquire(context.Background())
if err != nil {
return nil, fmt.Errorf("error acquiring connection from pool: %w", err)
}
defer dbc.Release()
var version string
err = dbc.QueryRow(context.Background(), `SHOW server_version_num`).Scan(&version)
if err != nil {
return nil, fmt.Errorf("error fetching version: %w", err)
}
schema.ver, err = strconv.Atoi(version)
if err != nil {
return nil, err
}
tables, err := GetTables(dbc)
if err != nil {
return nil, err
}
for _, t := range tables {
cols, err := GetColumns(dbc, "public", t.Name)
if err != nil {
return nil, err
}
if err := schema.updateSchema(t, cols, aliases); err != nil {
return nil, err
}
}
return schema, nil
}
func (s *DBSchema) updateSchema(
t *DBTable,
cols []*DBColumn,
aliases map[string][]string) error {
// Foreign key columns in current table
colByID := make(map[int16]*DBColumn)
columns := make(map[string]*DBColumn, len(cols))
colNames := make([]string, 0, len(cols))
for i := range cols {
c := cols[i]
name := strings.ToLower(c.Name)
columns[name] = c
colNames = append(colNames, name)
colByID[c.ID] = c
}
singular := strings.ToLower(flect.Singularize(t.Name))
s.t[singular] = &DBTableInfo{
Name: t.Name,
Singular: true,
Columns: columns,
ColumnNames: colNames,
}
plural := strings.ToLower(flect.Pluralize(t.Name))
s.t[plural] = &DBTableInfo{
Name: t.Name,
Singular: false,
Columns: columns,
ColumnNames: colNames,
}
ct := strings.ToLower(t.Name)
if al, ok := aliases[ct]; ok {
for i := range al {
k1 := flect.Singularize(al[i])
s.t[k1] = s.t[singular]
k2 := flect.Pluralize(al[i])
s.t[k2] = s.t[plural]
s.al[k1] = struct{}{}
s.al[k2] = struct{}{}
}
}
jcols := make([]*DBColumn, 0, len(cols))
for _, c := range cols {
switch {
case c.Type == "tsvector":
s.t[singular].TSVCol = c.Name
s.t[plural].TSVCol = c.Name
case c.PrimaryKey:
s.t[singular].PrimaryCol = c.Name
s.t[plural].PrimaryCol = c.Name
case len(c.FKeyTable) != 0:
if len(c.FKeyColID) == 0 {
continue
}
// Foreign key column name
ft := strings.ToLower(c.FKeyTable)
fc, ok := colByID[c.FKeyColID[0]]
if !ok {
continue
}
// Belongs-to relation between current table and the
// table in the foreign key
rel1 := &DBRel{RelBelongTo, "", "", c.Name, fc.Name}
if err := s.SetRel(ct, ft, rel1); err != nil {
return err
}
// One-to-many relation between the foreign key table and the
// the current table
rel2 := &DBRel{RelOneToMany, "", "", fc.Name, c.Name}
if err := s.SetRel(ft, ct, rel2); err != nil {
return err
}
jcols = append(jcols, c)
}
}
// If table contains multiple foreign key columns it's a possible
// join table for many-to-many relationships or multiple one-to-many
// relations
// Below one-to-many relations use the current table as the
// join table aka through table.
if len(jcols) > 1 {
for i := range jcols {
for n := range jcols {
if n == i {
continue
}
err := s.updateSchemaOTMT(ct, jcols[i], jcols[n], colByID)
if err != nil {
return err
}
}
}
}
return nil
}
func (s *DBSchema) updateSchemaOTMT(
ct string,
col1, col2 *DBColumn,
colByID map[int16]*DBColumn) error {
t1 := strings.ToLower(col1.FKeyTable)
t2 := strings.ToLower(col2.FKeyTable)
fc1, ok := colByID[col1.FKeyColID[0]]
if !ok {
return fmt.Errorf("expected column id '%d' not found", col1.FKeyColID[0])
}
fc2, ok := colByID[col2.FKeyColID[0]]
if !ok {
return fmt.Errorf("expected column id '%d' not found", col2.FKeyColID[0])
}
// One-to-many-through relation between 1nd foreign key table and the
// 2nd foreign key table
//rel1 := &DBRel{RelOneToManyThrough, ct, fc1.Name, col1.Name}
rel1 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name, col1.Name}
if err := s.SetRel(t1, t2, rel1); err != nil {
return err
}
// One-to-many-through relation between 2nd foreign key table and the
// 1nd foreign key table
//rel2 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name}
rel2 := &DBRel{RelOneToManyThrough, ct, col1.Name, fc1.Name, col2.Name}
if err := s.SetRel(t2, t1, rel2); err != nil {
return err
}
return nil
}
func (s *DBSchema) GetTable(table string) (*DBTableInfo, error) {
t, ok := s.t[table]
if !ok {
return nil, fmt.Errorf("unknown table '%s'", table)
}
return t, nil
}
func (s *DBSchema) SetRel(child, parent string, rel *DBRel) error {
sc := strings.ToLower(flect.Singularize(child))
pc := strings.ToLower(flect.Pluralize(child))
if _, ok := s.rm[sc]; !ok {
s.rm[sc] = make(map[string]*DBRel)
}
if _, ok := s.rm[pc]; !ok {
s.rm[pc] = make(map[string]*DBRel)
}
sp := strings.ToLower(flect.Singularize(parent))
pp := strings.ToLower(flect.Pluralize(parent))
s.rm[sc][sp] = rel
s.rm[sc][pp] = rel
s.rm[pc][sp] = rel
s.rm[pc][pp] = rel
return nil
}
func (s *DBSchema) GetRel(child, parent string) (*DBRel, error) {
rel, ok := s.rm[child][parent]
if !ok {
return nil, fmt.Errorf("unknown relationship '%s' -> '%s'",
child, parent)
}
return rel, nil
}
func (s *DBSchema) IsAlias(name string) bool {
_, ok := s.al[name]
return ok
}

View File

@ -295,6 +295,10 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
continue continue
} }
if field.ParentID == -1 {
parentID = -1
}
trv := com.getRole(role, field.Name) trv := com.getRole(role, field.Name)
selects = append(selects, Select{ selects = append(selects, Select{

View File

@ -96,7 +96,7 @@ func initAllowList(cpath string) {
} }
func (al *allowList) add(req *gqlReq) { func (al *allowList) add(req *gqlReq) {
if len(req.ref) == 0 || len(req.Query) == 0 { if al.saveChan == nil || len(req.ref) == 0 || len(req.Query) == 0 {
return return
} }

View File

@ -36,6 +36,10 @@ func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int
if len(fields) == 0 { if len(fields) == 0 {
return 0, nil return 0, nil
} }
v := fields[0].Value
if len(v) >= 2 && v[0] == '"' && v[len(v)-1] == '"' {
fields[0].Value = v[1 : len(v)-1]
}
return w.Write(fields[0].Value) return w.Write(fields[0].Value)
} }

View File

@ -81,11 +81,17 @@ type config struct {
roles map[string]*configRole roles map[string]*configRole
} }
type configColumn struct {
Name string
ForeignKey string `mapstructure:"related_to"`
}
type configTable struct { type configTable struct {
Name string Name string
Table string Table string
Blocklist []string Blocklist []string
Remotes []configRemote Remotes []configRemote
Columns []configColumn
} }
type configRemote struct { type configRemote struct {
@ -226,6 +232,7 @@ func (c *config) Init(vi *viper.Viper) error {
if _, ok := c.roles[role.Name]; ok { if _, ok := c.roles[role.Name]; ok {
errlog.Fatal().Msgf("duplicate role '%s' found", role.Name) errlog.Fatal().Msgf("duplicate role '%s' found", role.Name)
} }
role.Name = strings.ToLower(role.Name) role.Name = strings.ToLower(role.Name)
role.Match = sanitize(role.Match) role.Match = sanitize(role.Match)
role.tablesMap = make(map[string]*configRoleTable) role.tablesMap = make(map[string]*configRoleTable)
@ -296,6 +303,28 @@ func (c *config) getAliasMap() map[string][]string {
return m return m
} }
func (c *config) isABCLEnabled() bool {
if len(c.RolesQuery) == 0 {
return false
}
switch len(c.Roles) {
case 0, 1:
return false
case 2:
_, ok1 := c.roles["anon"]
_, ok2 := c.roles["user"]
return !(ok1 && ok2)
}
return true
}
func (c *config) isAnonRoleDefined() bool {
_, ok := c.roles["anon"]
return ok
}
var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`) var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`) var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`)

112
serv/config_compile.go Normal file
View File

@ -0,0 +1,112 @@
package serv
import (
"fmt"
"strings"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
)
func addForeignKeys(c *config, di *psql.DBInfo) error {
for _, t := range c.Tables {
for _, c := range t.Columns {
if err := addForeignKey(di, c, t); err != nil {
return err
}
}
}
return nil
}
func addForeignKey(di *psql.DBInfo, c configColumn, t configTable) error {
c1, ok := di.GetColumn(t.Name, c.Name)
if !ok {
return fmt.Errorf(
"Invalid table '%s' or column '%s in config",
t.Name, c.Name)
}
v := strings.SplitN(c.ForeignKey, ".", 2)
if len(v) != 2 {
return fmt.Errorf(
"Invalid foreign_key in config for table '%s' and column '%s",
t.Name, c.Name)
}
fkt, fkc := v[0], v[1]
c2, ok := di.GetColumn(fkt, fkc)
if !ok {
return fmt.Errorf(
"Invalid foreign_key in config for table '%s' and column '%s",
t.Name, c.Name)
}
c1.FKeyTable = fkt
c1.FKeyColID = []int16{c2.ID}
return nil
}
func addRoles(c *config, qc *qcode.Compiler) error {
for _, r := range c.Roles {
for _, t := range r.Tables {
if err := addRole(qc, r, t); err != nil {
return err
}
}
}
return nil
}
func addRole(qc *qcode.Compiler, r configRole, t configRoleTable) error {
blockFilter := []string{"false"}
query := qcode.QueryConfig{
Limit: t.Query.Limit,
Filters: t.Query.Filters,
Columns: t.Query.Columns,
DisableFunctions: t.Query.DisableFunctions,
}
if t.Query.Block {
query.Filters = blockFilter
}
insert := qcode.InsertConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
Presets: t.Insert.Presets,
}
if t.Query.Block {
insert.Filters = blockFilter
}
update := qcode.UpdateConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
Presets: t.Insert.Presets,
}
if t.Query.Block {
update.Filters = blockFilter
}
delete := qcode.DeleteConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
}
if t.Query.Block {
delete.Filters = blockFilter
}
return qc.AddRole(r.Name, t.Name, qcode.TRConfig{
Query: query,
Insert: insert,
Update: update,
Delete: delete,
})
}

View File

@ -77,7 +77,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
mutation := (qt == qcode.QTMutation) mutation := (qt == qcode.QTMutation)
anonQuery := (qt == qcode.QTQuery && c.req.role == "anon") anonQuery := (qt == qcode.QTQuery && c.req.role == "anon")
useRoleQuery := len(conf.RolesQuery) != 0 && mutation useRoleQuery := conf.isABCLEnabled() && mutation
useTx := useRoleQuery || conf.DB.SetUserID useTx := useRoleQuery || conf.DB.SetUserID
if useTx { if useTx {
@ -127,7 +127,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
row = db.QueryRow(c.Context, ps.sd.SQL, vars...) row = db.QueryRow(c.Context, ps.sd.SQL, vars...)
} }
if mutation || anonQuery { if mutation || anonQuery || !conf.isABCLEnabled() {
err = row.Scan(&root) err = row.Scan(&root)
} else { } else {
err = row.Scan(&role, &root) err = row.Scan(&role, &root)

View File

@ -24,14 +24,16 @@ func buildStmt(qt qcode.QType, gql, vars []byte, role string) ([]stmt, error) {
return buildRoleStmt(gql, vars, role) return buildRoleStmt(gql, vars, role)
case qcode.QTQuery: case qcode.QTQuery:
switch { if role == "anon" {
case role == "anon": return buildRoleStmt(gql, vars, "anon")
return buildRoleStmt(gql, vars, role) }
default: if conf.isABCLEnabled() {
return buildMultiStmt(gql, vars) return buildMultiStmt(gql, vars)
} }
return buildRoleStmt(gql, vars, "user")
default: default:
return nil, fmt.Errorf("unknown query type '%d'", qt) return nil, fmt.Errorf("unknown query type '%d'", qt)
} }

View File

@ -77,7 +77,15 @@ func prepareStmt(gql string, vars []byte) error {
switch qt { switch qt {
case qcode.QTQuery: case qcode.QTQuery:
stmts1, err := buildMultiStmt(q, vars) var stmts1 []stmt
var err error
if conf.isABCLEnabled() {
stmts1, err = buildMultiStmt(q, vars)
} else {
stmts1, err = buildRoleStmt(q, vars, "user")
}
if err != nil { if err != nil {
return err return err
} }
@ -87,14 +95,16 @@ func prepareStmt(gql string, vars []byte) error {
return err return err
} }
stmts2, err := buildRoleStmt(q, vars, "anon") if conf.isAnonRoleDefined() {
if err != nil { stmts2, err := buildRoleStmt(q, vars, "anon")
return err if err != nil {
} return err
}
err = prepare(tx, &stmts2[0], gqlHash(gql, vars, "anon")) err = prepare(tx, &stmts2[0], gqlHash(gql, vars, "anon"))
if err != nil { if err != nil {
return err return err
}
} }
case qcode.QTMutation: case qcode.QTMutation:
@ -142,7 +152,7 @@ func prepare(tx pgx.Tx, st *stmt, key string) error {
// nolint: errcheck // nolint: errcheck
func prepareRoleStmt(tx pgx.Tx) error { func prepareRoleStmt(tx pgx.Tx) error {
if len(conf.RolesQuery) == 0 { if !conf.isABCLEnabled() {
return nil return nil
} }

View File

@ -36,7 +36,6 @@ func initResolvers() error {
func initRemotes(t configTable) error { func initRemotes(t configTable) error {
h := xxhash.New() h := xxhash.New()
var err error
for _, r := range t.Remotes { for _, r := range t.Remotes {
// defines the table column to be used as an id in the // defines the table column to be used as an id in the
@ -46,21 +45,20 @@ func initRemotes(t configTable) error {
// if no table column specified in the config then // if no table column specified in the config then
// use the primary key of the table as the id // use the primary key of the table as the id
if len(idcol) == 0 { if len(idcol) == 0 {
idcol, err = pcompile.IDColumn(t.Name) pcol, err := pcompile.IDColumn(t.Name)
if err != nil { if err != nil {
return err return err
} }
idcol = pcol.Key
} }
idk := fmt.Sprintf("__%s_%s", t.Name, idcol) idk := fmt.Sprintf("__%s_%s", t.Name, idcol)
// register a relationship between the remote data // register a relationship between the remote data
// and the database table // and the database table
val := &psql.DBRel{ val := &psql.DBRel{Type: psql.RelRemote}
Type: psql.RelRemote, val.Left.Col = idcol
Col1: idcol, val.Right.Col = idk
Col2: idk,
}
err := pcompile.AddRelationship(strings.ToLower(r.Name), t.Name, val) err := pcompile.AddRelationship(strings.ToLower(r.Name), t.Name, val)
if err != nil { if err != nil {

View File

@ -15,77 +15,29 @@ import (
) )
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
var err error di, err := psql.GetDBInfo(db)
schema, err = psql.NewDBSchema(db, c.getAliasMap())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
conf := qcode.Config{ if err = addForeignKeys(c, di); err != nil {
return nil, nil, err
}
schema, err = psql.NewDBSchema(db, di, c.getAliasMap())
if err != nil {
return nil, nil, err
}
qc, err := qcode.NewCompiler(qcode.Config{
Blocklist: c.DB.Blocklist, Blocklist: c.DB.Blocklist,
} })
qc, err := qcode.NewCompiler(conf)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
blockFilter := []string{"false"} if err := addRoles(c, qc); err != nil {
return nil, nil, err
for _, r := range c.Roles {
for _, t := range r.Tables {
query := qcode.QueryConfig{
Limit: t.Query.Limit,
Filters: t.Query.Filters,
Columns: t.Query.Columns,
DisableFunctions: t.Query.DisableFunctions,
}
if t.Query.Block {
query.Filters = blockFilter
}
insert := qcode.InsertConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
Presets: t.Insert.Presets,
}
if t.Query.Block {
insert.Filters = blockFilter
}
update := qcode.UpdateConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
Presets: t.Insert.Presets,
}
if t.Query.Block {
update.Filters = blockFilter
}
delete := qcode.DeleteConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
}
if t.Query.Block {
delete.Filters = blockFilter
}
err := qc.AddRole(r.Name, t.Name, qcode.TRConfig{
Query: query,
Insert: insert,
Update: update,
Delete: delete,
})
if err != nil {
return nil, nil, err
}
}
} }
pc := psql.NewCompiler(psql.Config{ pc := psql.NewCompiler(psql.Config{