super-graph/psql/mutate.go

603 lines
12 KiB
Go
Raw Normal View History

//nolint:errcheck
2019-09-05 06:09:56 +02:00
package psql
import (
2019-12-25 07:24:30 +01:00
"encoding/json"
2019-09-05 06:09:56 +02:00
"errors"
2019-09-06 06:34:23 +02:00
"fmt"
2019-09-05 06:09:56 +02:00
"io"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/qcode"
2019-12-25 07:24:30 +01:00
"github.com/dosco/super-graph/util"
2019-09-05 06:09:56 +02:00
)
2019-12-25 07:24:30 +01:00
// itemType identifies the kind of statement a mutation item is rendered as.
type itemType int

// The values start at 1 (iota + 1), so the zero value of itemType does not
// correspond to any of the named kinds.
const (
	// itemInsert marks an item that belongs to an insert mutation.
	itemInsert itemType = iota + 1
	// itemUpdate marks an item that belongs to an update mutation.
	itemUpdate
	// itemConnect marks an item created from a "connect" key.
	itemConnect
	// itemDisconnect marks an item created from a "disconnect" key.
	itemDisconnect
	// itemUnion marks a parent item that had at least one connect or
	// disconnect child attached to it.
	itemUnion
)
var insertTypes = map[string]itemType{
"connect": itemConnect,
2019-12-25 07:24:30 +01:00
}
var updateTypes = map[string]itemType{
"connect": itemConnect,
"disconnect": itemDisconnect,
2019-12-25 07:24:30 +01:00
}
2019-10-14 08:51:36 +02:00
var noLimit = qcode.Paging{NoLimit: true}
2019-10-03 09:08:01 +02:00
func (co *Compiler) compileMutation(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) {
2019-09-05 06:09:56 +02:00
if len(qc.Selects) == 0 {
return 0, errors.New("empty query")
}
c := &compilerContext{w, qc.Selects, co}
root := &qc.Selects[0]
ti, err := c.schema.GetTable(root.Name)
2019-10-05 08:17:08 +02:00
if err != nil {
return 0, err
}
2019-10-14 08:51:36 +02:00
switch qc.Type {
case qcode.QTInsert:
2019-10-05 08:17:08 +02:00
if _, err := c.renderInsert(qc, w, vars, ti); err != nil {
2019-09-06 06:34:23 +02:00
return 0, err
}
2019-10-14 08:51:36 +02:00
case qcode.QTUpdate:
2019-10-05 08:17:08 +02:00
if _, err := c.renderUpdate(qc, w, vars, ti); err != nil {
return 0, err
}
2019-10-14 08:51:36 +02:00
case qcode.QTUpsert:
2019-10-05 08:17:08 +02:00
if _, err := c.renderUpsert(qc, w, vars, ti); err != nil {
2019-09-06 06:34:23 +02:00
return 0, err
}
2019-10-14 08:51:36 +02:00
case qcode.QTDelete:
2019-10-05 08:17:08 +02:00
if _, err := c.renderDelete(qc, w, vars, ti); err != nil {
2019-09-06 07:17:45 +02:00
return 0, err
}
2019-09-06 06:34:23 +02:00
default:
return 0, errors.New("valid mutations are 'insert', 'update', 'upsert' and 'delete'")
2019-09-05 06:09:56 +02:00
}
2019-10-14 08:51:36 +02:00
root.Paging = noLimit
2019-10-03 09:08:01 +02:00
root.DistinctOn = root.DistinctOn[:]
root.OrderBy = root.OrderBy[:]
root.Where = nil
root.Args = nil
2019-09-05 06:09:56 +02:00
return c.compileQuery(qc, w)
}
2019-12-25 07:24:30 +01:00
type kvitem struct {
id int32
_type itemType
_ctype int
key string
path []string
val json.RawMessage
data map[string]json.RawMessage
array bool
ti *DBTableInfo
relCP *DBRel
relPC *DBRel
items []kvitem
2019-12-25 07:24:30 +01:00
}
2019-09-05 06:09:56 +02:00
2019-12-25 07:24:30 +01:00
// renitem is a kvitem that is ready to be rendered. handleKVItem pushes one
// for the item it was given, after that item's children have been collected.
type renitem struct {
	kvitem
	// array and data shadow the fields of the same name on the embedded
	// kvitem; handleKVItem fills them with the parsed payload it resolved
	// for the item (either freshly parsed or taken from the kvitem cache).
	array bool
	data  map[string]json.RawMessage
}
2019-09-06 06:34:23 +02:00
2019-12-25 07:24:30 +01:00
func (c *compilerContext) handleKVItem(st *util.Stack, item kvitem) error {
var data map[string]json.RawMessage
var array bool
var err error
if item.data == nil {
data, array, err = jsn.Tree(item.val)
if err != nil {
return err
}
} else {
data, array = item.data, item.array
2019-09-05 06:09:56 +02:00
}
2019-12-25 07:24:30 +01:00
var unionize bool
id := item.id + 1
item.items = make([]kvitem, 0, len(data))
2019-09-05 06:09:56 +02:00
2019-12-25 07:24:30 +01:00
for k, v := range data {
if v[0] != '{' && v[0] != '[' {
continue
}
if _, ok := item.ti.ColMap[k]; ok {
continue
}
// Get child-to-parent relationship
relCP, err := c.schema.GetRel(k, item.key)
if err != nil {
var ty itemType
var ok bool
switch item._type {
case itemInsert:
ty, ok = insertTypes[k]
case itemUpdate:
ty, ok = updateTypes[k]
}
if ok {
unionize = true
item1 := item
item1._type = ty
item1.id = id
item1.val = v
2019-09-05 06:09:56 +02:00
2019-12-25 07:24:30 +01:00
item.items = append(item.items, item1)
id++
}
} else {
ti, err := c.schema.GetTable(k)
if err != nil {
return err
}
// Get parent-to-child relationship
relPC, err := c.schema.GetRel(item.key, k)
if err != nil {
return err
}
item1 := kvitem{
2019-12-25 07:24:30 +01:00
id: id,
_type: item._type,
key: k,
val: v,
path: append(item.path, k),
ti: ti,
relCP: relCP,
relPC: relPC,
}
if v[0] == '{' {
item1.data, item1.array, err = jsn.Tree(v)
if err != nil {
return err
}
if v1, ok := item1.data["connect"]; ok && (v1[0] == '{' || v1[0] == '[') {
item1._ctype |= (1 << itemConnect)
}
if v1, ok := item1.data["disconnect"]; ok && (v1[0] == '{' || v1[0] == '[') {
item1._ctype |= (1 << itemDisconnect)
}
}
item.items = append(item.items, item1)
2019-12-25 07:24:30 +01:00
id++
}
2019-09-05 06:09:56 +02:00
}
2019-12-25 07:24:30 +01:00
if unionize {
item._type = itemUnion
}
// For inserts order the children according to
// the creation order required by the parent-to-child
// relationships. For example users need to be created
// before the products they own.
// For updates the order defined in the query must be
// the order used.
switch item._type {
case itemInsert:
for _, v := range item.items {
if v.relPC.Type == RelOneToMany {
st.Push(v)
}
}
st.Push(renitem{kvitem: item, array: array, data: data})
for _, v := range item.items {
if v.relPC.Type == RelOneToOne {
st.Push(v)
}
}
2019-09-05 06:09:56 +02:00
case itemUpdate:
for _, v := range item.items {
if !(v._ctype > 0 && v.relPC.Type == RelOneToOne) {
st.Push(v)
}
}
st.Push(renitem{kvitem: item, array: array, data: data})
for _, v := range item.items {
if v._ctype > 0 && v.relPC.Type == RelOneToOne {
st.Push(v)
}
}
2019-12-25 07:24:30 +01:00
case itemUnion:
st.Push(renitem{kvitem: item, array: array, data: data})
for _, v := range item.items {
st.Push(v)
}
2019-12-25 07:24:30 +01:00
default:
for _, v := range item.items {
st.Push(v)
}
st.Push(renitem{kvitem: item, array: array, data: data})
}
2019-12-25 07:24:30 +01:00
return nil
2019-09-05 06:09:56 +02:00
}
// renderUnionStmt writes the CTE(s) that apply the connect and/or disconnect
// children of item to w. It renders only for a one-to-many parent-to-child
// relationship and writes nothing for any other relationship type.
//
// With only connect or only disconnect children a single CTE named after the
// item's table is written. With both, two CTEs suffixed "_c" and "_d" are
// written, followed by a CTE named after the table that is the UNION ALL of
// the two. Every CTE is preceded by ", " and none is followed by a separator,
// so the output can be appended directly to an existing WITH list.
//
// An error is returned when the item has no parent-to-child relationship or
// when the WHERE clause of a child cannot be rendered.
func (c *compilerContext) renderUnionStmt(w io.Writer, item renitem) error {
	if item.relPC == nil {
		return errors.New("invalid connect or disconnect command")
	}

	// Render only for parent-to-child relationship of one-to-many
	if item.relPC.Type != RelOneToMany {
		return nil
	}

	var connect, disconnect bool
	for _, v := range item.items {
		if v._type == itemConnect {
			connect = true
		} else if v._type == itemDisconnect {
			disconnect = true
		}
		if connect && disconnect {
			break
		}
	}

	// renderWhere writes the OR-joined filters of every child of kind ty.
	renderWhere := func(ty itemType) error {
		n := 0
		for _, v := range item.items {
			if v._type != ty {
				continue
			}
			if n != 0 {
				io.WriteString(w, ` OR (`)
			} else {
				io.WriteString(w, ` (`)
			}
			if err := renderKVItemWhere(w, v); err != nil {
				return err
			}
			io.WriteString(w, `)`)
			n++
		}
		return nil
	}

	if connect {
		io.WriteString(w, `, `)
		if disconnect {
			renderCteNameWithSuffix(w, item.kvitem, "c")
		} else {
			quoted(w, item.ti.Name)
		}
		io.WriteString(w, ` AS ( UPDATE `)
		quoted(w, item.ti.Name)
		io.WriteString(w, ` SET `)
		quoted(w, item.relPC.Right.Col)
		io.WriteString(w, ` = `)
		colWithTable(w, item.relPC.Left.Table, item.relPC.Left.Col)
		io.WriteString(w, ` FROM `)
		quoted(w, item.relPC.Left.Table)
		io.WriteString(w, ` WHERE`)
		if err := renderWhere(itemConnect); err != nil {
			return err
		}
		io.WriteString(w, ` RETURNING `)
		quoted(w, item.ti.Name)
		io.WriteString(w, `.*)`)
	}

	if disconnect {
		io.WriteString(w, `, `)
		if connect {
			renderCteNameWithSuffix(w, item.kvitem, "d")
		} else {
			quoted(w, item.ti.Name)
		}
		io.WriteString(w, ` AS ( UPDATE `)
		quoted(w, item.ti.Name)
		io.WriteString(w, ` SET `)
		quoted(w, item.relPC.Right.Col)
		io.WriteString(w, ` = NULL`)
		io.WriteString(w, ` FROM `)
		quoted(w, item.relPC.Left.Table)
		io.WriteString(w, ` WHERE`)
		if err := renderWhere(itemDisconnect); err != nil {
			return err
		}
		io.WriteString(w, ` RETURNING `)
		quoted(w, item.ti.Name)
		io.WriteString(w, `.*)`)
	}

	if connect && disconnect {
		// The separator belongs to the union CTE, so a disconnect-only
		// statement does not end in a dangling comma.
		io.WriteString(w, `, `)
		quoted(w, item.ti.Name)
		io.WriteString(w, ` AS (`)
		io.WriteString(w, `SELECT * FROM `)
		renderCteNameWithSuffix(w, item.kvitem, "c")
		io.WriteString(w, ` UNION ALL `)
		io.WriteString(w, `SELECT * FROM `)
		renderCteNameWithSuffix(w, item.kvitem, "d")
		io.WriteString(w, `)`)
	}

	return nil
}
2019-12-25 07:24:30 +01:00
func renderInsertUpdateColumns(w io.Writer,
qc *qcode.QCode,
jt map[string]json.RawMessage,
ti *DBTableInfo,
skipcols map[string]struct{},
values bool) (uint32, error) {
2019-10-14 08:51:36 +02:00
root := &qc.Selects[0]
2019-09-06 06:34:23 +02:00
2019-12-25 07:24:30 +01:00
n := 0
for _, cn := range ti.Columns {
2019-12-25 07:24:30 +01:00
if _, ok := skipcols[cn.Name]; ok {
continue
}
if _, ok := jt[cn.Key]; !ok {
2019-09-06 06:34:23 +02:00
continue
}
if _, ok := root.PresetMap[cn.Key]; ok {
continue
}
2019-10-14 08:51:36 +02:00
if len(root.Allowed) != 0 {
if _, ok := root.Allowed[cn.Key]; !ok {
2019-10-14 08:51:36 +02:00
continue
}
}
2019-12-25 07:24:30 +01:00
if n != 0 {
io.WriteString(w, `, `)
2019-09-06 06:34:23 +02:00
}
2019-12-25 07:24:30 +01:00
if values {
colWithTable(w, "t", cn.Name)
} else {
quoted(w, cn.Name)
}
n++
}
for i := range root.PresetList {
2019-11-07 08:37:24 +01:00
cn := root.PresetList[i]
col, ok := ti.ColMap[cn]
2019-11-07 08:37:24 +01:00
if !ok {
continue
}
2019-12-25 07:24:30 +01:00
if _, ok := skipcols[col.Name]; ok {
continue
}
if i != 0 || n != 0 {
io.WriteString(w, `, `)
}
2019-12-25 07:24:30 +01:00
if values {
2019-12-25 07:24:30 +01:00
io.WriteString(w, `'`)
io.WriteString(w, root.PresetMap[cn])
io.WriteString(w, `' :: `)
io.WriteString(w, col.Type)
2019-11-07 08:37:24 +01:00
} else {
2019-12-25 07:24:30 +01:00
quoted(w, cn)
}
}
2019-09-06 06:34:23 +02:00
return 0, nil
}
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
2019-10-05 08:17:08 +02:00
vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
2019-09-06 06:34:23 +02:00
2019-10-14 08:51:36 +02:00
upsert, ok := vars[qc.ActionVar]
2019-10-05 08:17:08 +02:00
if !ok {
2019-10-14 08:51:36 +02:00
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
2019-10-05 08:17:08 +02:00
}
if ti.PrimaryCol == nil {
return 0, fmt.Errorf("no primary key column found")
}
2019-10-05 08:17:08 +02:00
jt, _, err := jsn.Tree(upsert)
2019-09-06 07:17:45 +02:00
if err != nil {
return 0, err
2019-09-05 06:09:56 +02:00
}
2019-10-05 08:17:08 +02:00
if _, err := c.renderInsert(qc, w, vars, ti); err != nil {
return 0, err
}
io.WriteString(c.w, ` ON CONFLICT (`)
2019-10-05 08:17:08 +02:00
i := 0
2019-09-06 07:17:45 +02:00
for _, cn := range ti.Columns {
if _, ok := jt[cn.Key]; !ok {
2019-10-05 08:17:08 +02:00
continue
}
if col, ok := ti.ColMap[cn.Key]; !ok || !(col.UniqueKey || col.PrimaryKey) {
2019-10-05 08:17:08 +02:00
continue
}
if i != 0 {
io.WriteString(c.w, `, `)
}
io.WriteString(c.w, cn.Name)
2019-10-05 08:17:08 +02:00
i++
}
if i == 0 {
io.WriteString(c.w, ti.PrimaryCol.Name)
2019-09-06 07:17:45 +02:00
}
io.WriteString(c.w, `)`)
if root.Where != nil {
io.WriteString(c.w, ` WHERE `)
if err := c.renderWhere(root, ti); err != nil {
return 0, err
}
}
2019-09-06 07:17:45 +02:00
io.WriteString(c.w, ` DO UPDATE SET `)
2019-10-05 08:17:08 +02:00
i = 0
for _, cn := range ti.Columns {
if _, ok := jt[cn.Key]; !ok {
2019-10-05 08:17:08 +02:00
continue
}
if i != 0 {
io.WriteString(c.w, `, `)
}
io.WriteString(c.w, cn.Name)
2019-10-05 08:17:08 +02:00
io.WriteString(c.w, ` = EXCLUDED.`)
io.WriteString(c.w, cn.Name)
2019-10-05 08:17:08 +02:00
i++
}
2019-09-06 07:17:45 +02:00
2019-12-25 07:24:30 +01:00
io.WriteString(c.w, ` RETURNING *) `)
2019-09-05 06:09:56 +02:00
return 0, nil
}
2019-09-20 06:19:11 +02:00
2019-12-25 07:24:30 +01:00
func (c *compilerContext) renderConnectStmt(qc *qcode.QCode, w io.Writer,
item renitem) error {
rel := item.relPC
// Render only for parent-to-child relationship of one-to-one
if rel.Type != RelOneToOne {
return nil
}
2019-12-25 07:24:30 +01:00
io.WriteString(w, `, `)
quoted(w, item.ti.Name)
io.WriteString(c.w, ` AS (`)
2019-12-25 07:24:30 +01:00
io.WriteString(c.w, `SELECT * FROM `)
quoted(c.w, item.ti.Name)
io.WriteString(c.w, ` WHERE `)
if err := renderKVItemWhere(c.w, item.kvitem); err != nil {
return err
2019-12-25 07:24:30 +01:00
}
io.WriteString(c.w, ` LIMIT 1)`)
2019-12-25 07:24:30 +01:00
return nil
}
func (c *compilerContext) renderDisconnectStmt(qc *qcode.QCode, w io.Writer,
item renitem) error {
rel := item.relPC
2019-12-25 07:24:30 +01:00
// Render only for parent-to-child relationship of one-to-one
if rel.Type != RelOneToOne {
return nil
2019-12-25 07:24:30 +01:00
}
io.WriteString(w, `, `)
quoted(w, item.ti.Name)
io.WriteString(c.w, ` AS (`)
2019-12-25 07:24:30 +01:00
io.WriteString(c.w, `SELECT * FROM (VALUES(NULL::`)
io.WriteString(w, rel.Right.col.Type)
io.WriteString(c.w, `)) AS LOOKUP(`)
quoted(w, rel.Right.Col)
io.WriteString(c.w, `))`)
2019-12-25 07:24:30 +01:00
return nil
}
func renderKVItemWhere(w io.Writer, item kvitem) error {
return renderWhereFromJSON(w, item.ti.Name, item.val)
2019-12-25 07:24:30 +01:00
}
func renderWhereFromJSON(w io.Writer, table string, val []byte) error {
2019-12-25 07:24:30 +01:00
var kv map[string]json.RawMessage
if err := json.Unmarshal(val, &kv); err != nil {
return err
}
i := 0
for k, v := range kv {
if i != 0 {
io.WriteString(w, ` AND `)
}
colWithTable(w, table, k)
2019-12-25 07:24:30 +01:00
io.WriteString(w, ` = '`)
switch v[0] {
case '"':
w.Write(v[1 : len(v)-1])
default:
w.Write(v)
}
io.WriteString(w, `'`)
i++
}
return nil
}
// renderCteName writes the double-quoted CTE name for item: the table name,
// followed by "_" and the item id when the item is a connect or disconnect.
// The returned error is always nil.
func renderCteName(w io.Writer, item kvitem) error {
	io.WriteString(w, `"`)
	io.WriteString(w, item.ti.Name)

	// Connect/disconnect items get their id appended to the name.
	switch item._type {
	case itemConnect, itemDisconnect:
		io.WriteString(w, `_`)
		int2string(w, item.id)
	}

	io.WriteString(w, `"`)
	return nil
}
// renderCteNameWithSuffix writes the double-quoted CTE name made of the
// item's table name, an underscore and suffix. The returned error is always
// nil.
func renderCteNameWithSuffix(w io.Writer, item kvitem, suffix string) error {
	pieces := [...]string{`"`, item.ti.Name, `_`, suffix, `"`}
	for _, piece := range pieces {
		io.WriteString(w, piece)
	}
	return nil
}
// quoted writes identifier to w wrapped in double quotes. The identifier is
// written as-is; embedded quotes are not escaped.
func quoted(w io.Writer, identifier string) {
	for _, piece := range [...]string{`"`, identifier, `"`} {
		io.WriteString(w, piece)
	}
}
// joinPath writes the elements of path to w, each wrapped in single quotes
// and separated by "->". Nothing is written for an empty path.
func joinPath(w io.Writer, path []string) {
	for idx, segment := range path {
		if idx > 0 {
			io.WriteString(w, `->`)
		}
		for _, piece := range [...]string{`'`, segment, `'`} {
			io.WriteString(w, piece)
		}
	}
}