diff --git a/config/allow.list b/config/allow.list index a0befc5..5a989a9 100644 --- a/config/allow.list +++ b/config/allow.list @@ -1,5 +1,68 @@ # http://localhost:8080/ +variables { + "update": { + "name": "Hellooooo", + "description": "World", + "created_at": "now", + "updated_at": "now" + }, + "user": 123 +} + +mutation { + products(update: $update, where: {id: {eq: 134}}) { + id + name + description + } +} + +variables { + "update": { + "name": "Hellooooo", + "description": "World !!!!!" + }, + "user": 123 +} + +mutation { + products(id: 5, update: $update) { + id + name + description + } +} + +variables { + "id": 5 +} + +{ + products(id: $ID) { + id + name + description + } +} + + +variables { + "update": { + "name": "Hellooooo", + "description": "World" + }, + "user": 123 +} + +mutation { + products(update: $update, where: {id: {eq: 134}}) { + id + name + description + } +} + query { me { id @@ -32,94 +95,3 @@ query { } } -query { - products( - limit: 30 - order_by: { price: desc } - distinct: [price] - where: { id: { and: { greater_or_equals: 20, lt: 28 } } } - ) { - id - name - price - user { - id - email - } - } - -variables { - "insert": { - "name": "Hello", - "description": "World", - "created_at": "now", - "updated_at": "now" - }, - "user": 123 -} - -mutation { - products(insert: $insert) { - id - name - description - } -} - -variables { - "insert": { - "name": "Hello", - "description": "World", - "created_at": "now", - "updated_at": "now" - }, - "user": 123 -} - -mutation { - products(insert: $insert) { - id - } -} - -variables { - "insert": { - "description": "World3", - "name": "Hello3", - "created_at": "now", - "updated_at": "now" - }, - "user": 123 -} - -{ - customers { - id - email - payments { - customer_id - amount - billing_details - } - } -} - - -variables { - "insert": { - "description": "World3", - "name": "Hello3", - "created_at": "now", - "updated_at": "now" - }, - "user": 123 -} - -{ - me { - id - full_name - } -} - - diff --git a/psql/insert.go b/psql/insert.go index 9bc9d0c..ca3118e 100644 --- a/psql/insert.go +++ b/psql/insert.go @@ -3,6 +3,7 @@ package psql import ( "bytes" "errors" + "fmt" "io" "github.com/dosco/super-graph/jsn" @@ -21,8 +22,19 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia c.w.WriteString(root.Table) c.w.WriteString(` AS (`) - if _, err := c.renderInsert(qc, w, vars); err != nil { - return 0, err + switch root.Action { + case qcode.ActionInsert: + if _, err := c.renderInsert(qc, w, vars); err != nil { + return 0, err + } + + case qcode.ActionUpdate: + if _, err := c.renderUpdate(qc, w, vars); err != nil { + return 0, err + } + + default: + return 0, errors.New("valid mutations are 'insert' and 'update'") } c.w.WriteString(`) `) @@ -33,9 +45,14 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { root := &qc.Selects[0] - insert, ok := vars["insert"] + insert, ok := vars[root.ActionVar] if !ok { - return 0, errors.New("Variable 'insert' not defined") + return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + } + + ti, err := c.schema.GetTable(root.Table) + if err != nil { + return 0, err } jt, array, err := jsn.Tree(insert) @@ -45,12 +62,12 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Va c.w.WriteString(`WITH input AS (SELECT {{insert}}::json AS j) INSERT INTO `) c.w.WriteString(root.Table) - io.WriteString(c.w, " (") - c.renderInsertColumns(qc, w, jt) - io.WriteString(c.w, ")") + io.WriteString(c.w, ` (`) + c.renderInsertColumns(qc, w, jt, ti) + io.WriteString(c.w, `)`) c.w.WriteString(` SELECT `) - c.renderInsertColumns(qc, w, jt) + c.renderInsertColumns(qc, w, jt, ti) c.w.WriteString(` FROM input i, `) if array { @@ -67,12 +84,7 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Va } func (c *compilerContext) renderInsertColumns(qc *qcode.QCode, w *bytes.Buffer, - jt map[string]interface{}) (uint32, error) { - - ti, err := c.schema.GetTable(qc.Selects[0].Table) - if err != nil { - return 0, err - } + jt map[string]interface{}, ti *DBTableInfo) (uint32, error) { i := 0 for _, cn := range ti.ColumnNames { @@ -80,7 +92,7 @@ func (c *compilerContext) renderInsertColumns(qc *qcode.QCode, w *bytes.Buffer, continue } if i != 0 { - io.WriteString(c.w, ", ") + io.WriteString(c.w, `, `) } c.w.WriteString(cn) i++ @@ -88,3 +100,72 @@ func (c *compilerContext) renderInsertColumns(qc *qcode.QCode, w *bytes.Buffer, return 0, nil } + +func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { + root := &qc.Selects[0] + + update, ok := vars[root.ActionVar] + if !ok { + return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + } + + ti, err := c.schema.GetTable(root.Table) + if err != nil { + return 0, err + } + + jt, array, err := jsn.Tree(update) + if err != nil { + return 0, err + } + + c.w.WriteString(`WITH input AS (SELECT {{update}}::json AS j) UPDATE `) + c.w.WriteString(root.Table) + io.WriteString(c.w, ` SET (`) + c.renderInsertColumns(qc, w, jt, ti) + + c.w.WriteString(`) = (SELECT `) + c.renderInsertColumns(qc, w, jt, ti) + c.w.WriteString(` FROM input i, `) + + if array { + c.w.WriteString(`json_populate_recordset`) + } else { + c.w.WriteString(`json_populate_record`) + } + + c.w.WriteString(`(NULL::`) + c.w.WriteString(root.Table) + c.w.WriteString(`, i.j) t)`) + + io.WriteString(c.w, ` WHERE `) + + if err := c.renderWhere(root, ti); err != nil { + return 0, err + } + + io.WriteString(c.w, ` RETURNING *`) + + return 0, nil +} + +func (c *compilerContext) renderUpdateColumns(qc *qcode.QCode, w *bytes.Buffer, + jt map[string]interface{}, ti *DBTableInfo) (uint32, error) { + + i := 0 + for _, cn := range ti.ColumnNames { + if _, ok := jt[cn]; !ok { + continue + } + if i != 0 { + io.WriteString(c.w, `, `) + } + c.w.WriteString(cn) + c.w.WriteString(` = {{`) + c.w.WriteString(cn) + c.w.WriteString(`}}`) + i++ + } + + return 0, nil +} diff --git a/psql/insert_test.go b/psql/insert_test.go index a5d1a25..6542361 100644 --- a/psql/insert_test.go +++ b/psql/insert_test.go @@ -14,7 +14,7 @@ func singleInsert(t *testing.T) { } }` - sql := `test` + sql := `WITH product AS (WITH input AS (SELECT {{insert}}::json AS j) INSERT INTO product (name, description) SELECT name, description FROM input i, json_populate_record(NULL::product, i.j) t RETURNING * ) SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > 0) AND (("product"."price") < 8) AND (("id") = 15)) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";` vars := map[string]json.RawMessage{ "insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), @@ -25,8 +25,6 @@ func singleInsert(t *testing.T) { t.Fatal(err) } - fmt.Println(">", string(resSQL)) - if string(resSQL) != sql { t.Fatal(errNotExpected) } @@ -40,7 +38,7 @@ func bulkInsert(t *testing.T) { } }` - sql := `test` + sql := `WITH product AS (WITH input AS (SELECT {{insert}}::json AS j) INSERT INTO product (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::product, i.j) t RETURNING * ) SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > 0) AND (("product"."price") < 8) AND (("id") = 15)) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";` vars := map[string]json.RawMessage{ "insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), @@ -51,7 +49,31 @@ func bulkInsert(t *testing.T) { t.Fatal(err) } - fmt.Println(">", string(resSQL)) + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func singleUpdate(t *testing.T) { + gql := `mutation { + product(id: 15, update: $update, where: { id: { eq: 1 } }) { + id + name + } + }` + + sql := `test` + + vars := map[string]json.RawMessage{ + "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), + } + + resSQL, err := compileGQLToPSQL(gql, vars) + if err != nil { + t.Fatal(err) + } + + fmt.Println(string(resSQL)) if string(resSQL) != sql { t.Fatal(errNotExpected) @@ -61,5 +83,5 @@ func bulkInsert(t *testing.T) { func TestCompileInsert(t *testing.T) { t.Run("singleInsert", singleInsert) t.Run("bulkInsert", bulkInsert) - + t.Run("singleUpdate", singleUpdate) } diff --git a/psql/select_test.go b/psql/select_test.go index 77b83a8..acbbfed 100644 --- a/psql/select_test.go +++ b/psql/select_test.go @@ -42,7 +42,7 @@ func TestMain(m *testing.M) { "secret", "password", "token", - }, + } }) if err != nil { diff --git a/qcode/qcode.go b/qcode/qcode.go index 306f745..47a1d78 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -11,12 +11,17 @@ import ( ) type QType int +type Action int const ( maxSelectors = 30 QTQuery QType = iota + 1 QTMutation + + ActionInsert Action = iota + 1 + ActionUpdate + ActionDelete ) type QCode struct { @@ -41,6 +46,8 @@ type Select struct { OrderBy []*OrderBy DistinctOn []string Paging Paging + Action Action + ActionVar string Children []int32 } @@ -360,6 +367,15 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { err = com.compileArgLimit(sel, arg) case "offset": err = com.compileArgOffset(sel, arg) + case "insert": + sel.Action = ActionInsert + err = com.compileArgAction(sel, arg) + case "update": + sel.Action = ActionUpdate + err = com.compileArgAction(sel, arg) + case "delete": + sel.Action = ActionDelete + err = com.compileArgAction(sel, arg) } if err != nil { @@ -647,6 +663,15 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error { return nil } +func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error { + if arg.Val.Type != nodeVar { + return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) + } + sel.ActionVar = arg.Val.Val + + return nil +} + func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { name := node.Name if name[0] == '_' {