From 4edc15eb9895a5c84f68dade668c1a7137b8b4ed Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Fri, 25 Oct 2019 00:01:22 -0400 Subject: [PATCH] Optimize prepared statement flow for RBAC --- psql/mutate.go | 2 -- psql/query_test.go | 32 ++++++++++++------------- qcode/parse.go | 2 ++ qcode/qcode.go | 2 ++ serv/allow.go | 2 ++ serv/auth.go | 6 ++--- serv/auth_rails.go | 4 ++-- serv/cmd.go | 2 ++ serv/config.go | 26 ++++++++++++++++++++ serv/core.go | 14 ++++++++--- serv/http.go | 8 ++++--- serv/prepare.go | 9 +++++-- serv/utils.go | 14 +++++++++++ serv/utils_test.go | 59 ++++++++++++++++++++++++++++++++++------------ 14 files changed, 136 insertions(+), 46 deletions(-) diff --git a/psql/mutate.go b/psql/mutate.go index 067270e..84bb122 100644 --- a/psql/mutate.go +++ b/psql/mutate.go @@ -61,8 +61,6 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w io.Writer, vars Variables root.Where = nil root.Args = nil - qc.Type = qcode.QTQuery - return c.compileQuery(qc, w) } diff --git a/psql/query_test.go b/psql/query_test.go index ea65645..c24b439 100644 --- a/psql/query_test.go +++ b/psql/query_test.go @@ -201,7 +201,7 @@ func withComplexArgs(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -229,7 +229,7 @@ func withWhereMultiOr(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -255,7 +255,7 @@ func withWhereIsNull(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -281,7 +281,7 @@ func withWhereAndList(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -301,7 +301,7 @@ func fetchByID(t *testing.T) { } }` - sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -321,7 +321,7 @@ func searchQuery(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -344,7 +344,7 @@ func oneToMany(t *testing.T) { } }` - sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -367,7 +367,7 @@ func belongsTo(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -390,7 +390,7 @@ func manyToMany(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -413,7 +413,7 @@ func manyToManyReverse(t *testing.T) { } }` - sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -433,7 +433,7 @@ func aggFunction(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -453,7 +453,7 @@ func aggFunctionBlockedByCol(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "anon") if err != nil { @@ -473,7 +473,7 @@ func aggFunctionDisabled(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "anon1") if err != nil { @@ -493,7 +493,7 @@ func aggFunctionWithFilter(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -513,7 +513,7 @@ func queryWithVariables(t *testing.T) { } }` - sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -532,7 +532,7 @@ func syntheticTables(t *testing.T) { } }` - sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { diff --git a/qcode/parse.go b/qcode/parse.go index c07ab45..0fe6c34 100644 --- a/qcode/parse.go +++ b/qcode/parse.go @@ -18,7 +18,9 @@ type parserType int32 const ( maxFields = 100 maxArgs = 10 +) +const ( parserError parserType = iota parserEOF opQuery diff --git a/qcode/qcode.go b/qcode/qcode.go index 30bc724..a1d55e3 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -16,7 +16,9 @@ type Action int const ( maxSelectors = 30 +) +const ( QTQuery QType = iota + 1 QTInsert QTUpdate diff --git a/serv/allow.go b/serv/allow.go index 1279170..f960fc0 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -182,6 +182,8 @@ func (al *allowList) load() { item.vars = varBytes } + //fmt.Println("%%", item.gql, string(item.vars)) + al.list[gqlHash(q, varBytes, "")] = item varBytes = nil diff --git a/serv/auth.go b/serv/auth.go index 77942eb..22ab698 100644 --- a/serv/auth.go +++ b/serv/auth.go @@ -7,9 +7,9 @@ import ( ) var ( - userIDProviderKey = struct{}{} - userIDKey = struct{}{} - userRoleKey = struct{}{} + userIDProviderKey = "user_id_provider" + userIDKey = "user_id" + userRoleKey = "user_role" ) func headerAuth(next http.HandlerFunc) http.HandlerFunc { diff --git a/serv/auth_rails.go b/serv/auth_rails.go index cd0b327..7f78da0 100644 --- a/serv/auth_rails.go +++ b/serv/auth_rails.go @@ -122,14 +122,14 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ck, err := r.Cookie(cookie) if err != nil { - logger.Warn().Err(err).Send() + logger.Warn().Err(err).Msg("rails cookie missing") next.ServeHTTP(w, r) return } userID, err := ra.ParseCookie(ck.Value) if err != nil { - logger.Warn().Err(err).Send() + logger.Warn().Err(err).Msg("failed to parse rails cookie") next.ServeHTTP(w, r) return } diff --git a/serv/cmd.go b/serv/cmd.go index 12b3ce7..6fbc09f 100644 --- a/serv/cmd.go +++ b/serv/cmd.go @@ -210,6 +210,8 @@ func initConf() (*config, error) { c.Roles = append(c.Roles, configRole{Name: "anon"}) } + c.Validate() + return c, nil } diff --git a/serv/config.go b/serv/config.go index 8420c66..39e3d53 100644 --- a/serv/config.go +++ b/serv/config.go @@ -168,6 +168,32 @@ func newConfig() *viper.Viper { return vi } +func (c *config) Validate() { + rm := make(map[string]struct{}) + + for i := range c.Roles { + name := strings.ToLower(c.Roles[i].Name) + if _, ok := rm[name]; ok { + logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name) + } + rm[name] = struct{}{} + } + + tm := make(map[string]struct{}) + + for i := range c.Tables { + name := strings.ToLower(c.Tables[i].Name) + if _, ok := tm[name]; ok { + logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name) + } + tm[name] = struct{}{} + } + + if len(c.RolesQuery) == 0 { + logger.Warn().Msgf("no 'roles_query' defined.") + } +} + func (c *config) getAliasMap() map[string][]string { m := make(map[string][]string, len(c.Tables)) diff --git a/serv/core.go b/serv/core.go index 8a007ac..dc5d42b 100644 --- a/serv/core.go +++ b/serv/core.go @@ -131,16 +131,20 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { } var role string - useRoleQuery := len(conf.RolesQuery) != 0 && isMutation(c.req.Query) + mutation := isMutation(c.req.Query) + useRoleQuery := len(conf.RolesQuery) != 0 && mutation if useRoleQuery { if role, err = c.executeRoleQuery(tx); err != nil { return nil, nil, err } + } else if v := c.Value(userRoleKey); v != nil { role = v.(string) - } else { + + } else if mutation { role = c.req.role + } ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)] @@ -151,7 +155,11 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { var root []byte vars := varList(c, ps.args) - err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) + if mutation { + err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) + } else { + err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&c.req.role, &root) + } if err != nil { return nil, nil, err } diff --git a/serv/http.go b/serv/http.go index ae8ff84..737006d 100644 --- a/serv/http.go +++ b/serv/http.go @@ -37,8 +37,8 @@ type gqlReq struct { type variables map[string]json.RawMessage type gqlResp struct { - Error string `json:"error,omitempty"` - Data json.RawMessage `json:"data"` + Error string `json:"message,omitempty"` + Data json.RawMessage `json:"data,omitempty"` Extensions *extensions `json:"extensions,omitempty"` } @@ -102,7 +102,9 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { err = ctx.handleReq(w, r) if err == errUnauthorized { - http.Error(w, "Not authorized", 401) + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(gqlResp{Error: err.Error()}) + return } if err != nil { diff --git a/serv/prepare.go b/serv/prepare.go index 8415397..2329578 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -31,6 +31,7 @@ func initPreparedList() { } for _, v := range _allowList.list { + err := prepareStmt(v.gql, v.vars) if err != nil { logger.Warn().Str("gql", v.gql).Err(err).Send() @@ -52,6 +53,10 @@ func prepareStmt(gql string, varBytes json.RawMessage) error { return err } + if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery { + c.req.Vars = nil + } + for _, s := range stmts { if len(s.sql) == 0 { continue @@ -75,9 +80,9 @@ func prepareStmt(gql string, varBytes json.RawMessage) error { var key string if s.role == nil { - key = gqlHash(gql, varBytes, "") + key = gqlHash(gql, c.req.Vars, "") } else { - key = gqlHash(gql, varBytes, s.role.Name) + key = gqlHash(gql, c.req.Vars, s.role.Name) } _preparedList[key] = &preparedItem{ diff --git a/serv/utils.go b/serv/utils.go index 9e095e8..b59dded 100644 --- a/serv/utils.go +++ b/serv/utils.go @@ -24,13 +24,26 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { func gqlHash(b string, vars []byte, role string) string { b = strings.TrimSpace(b) h := sha1.New() + query := "query" s, e := 0, 0 space := []byte{' '} + starting := true var b0, b1 byte for { + if starting && b[e] == 'q' { + n := 0 + se := e + for e < len(b) && n < len(query) && b[e] == query[n] { + n++ + e++ + } + if n != len(query) { + io.WriteString(h, strings.ToLower(b[se:e])) + } + } if ws(b[e]) { for e < len(b) && ws(b[e]) { e++ @@ -42,6 +55,7 @@ func gqlHash(b string, vars []byte, role string) string { h.Write(space) } } else { + starting = false s = e for e < len(b) && ws(b[e]) == false { e++ diff --git a/serv/utils_test.go b/serv/utils_test.go index 17d91b7..b8babeb 100644 --- a/serv/utils_test.go +++ b/serv/utils_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestRelaxHash1(t *testing.T) { +func TestGQLHash1(t *testing.T) { var v1 = ` products( limit: 30, @@ -24,15 +24,15 @@ func TestRelaxHash1(t *testing.T) { price } ` - h1 := gqlHash(v1, nil) - h2 := gqlHash(v2, nil) + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHash2(t *testing.T) { +func TestGQLHash2(t *testing.T) { var v1 = ` { products( @@ -53,15 +53,15 @@ func TestRelaxHash2(t *testing.T) { var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` - h1 := gqlHash(v1, nil) - h2 := gqlHash(v2, nil) + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHash3(t *testing.T) { +func TestGQLHash3(t *testing.T) { var v1 = `users { id email @@ -86,15 +86,44 @@ func TestRelaxHash3(t *testing.T) { } ` - h1 := gqlHash(v1, nil) - h2 := gqlHash(v2, nil) + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHashWithVars1(t *testing.T) { +func TestGQLHash4(t *testing.T) { + var v1 = ` + query { + products( + limit: 30 + order_by: { price: desc } + distinct: [price] + where: { id: { and: { greater_or_equals: 20, lt: 28 } } } + ) { + id + name + price + user { + id + email + } + } + }` + + var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` + + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") + + if strings.Compare(h1, h2) != 0 { + t.Fatal("Hashes don't match they should") + } +} + +func TestGQLHashWithVars1(t *testing.T) { var q1 = ` products( limit: 30, @@ -136,15 +165,15 @@ func TestRelaxHashWithVars1(t *testing.T) { "user": 123 }` - h1 := gqlHash(q1, []byte(v1)) - h2 := gqlHash(q2, []byte(v2)) + h1 := gqlHash(q1, []byte(v1), "user") + h2 := gqlHash(q2, []byte(v2), "user") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHashWithVars2(t *testing.T) { +func TestGQLHashWithVars2(t *testing.T) { var q1 = ` products( limit: 30, @@ -193,8 +222,8 @@ func TestRelaxHashWithVars2(t *testing.T) { "user": 123 }` - h1 := gqlHash(q1, []byte(v1)) - h2 := gqlHash(q2, []byte(v2)) + h1 := gqlHash(q1, []byte(v1), "user") + h2 := gqlHash(q2, []byte(v2), "user") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should")