From 89d435640bcc6f44274dca4a40422a9e8757ad7d Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Sun, 31 Mar 2019 11:18:33 -0400 Subject: [PATCH] Add aggregrate functions to GQL queries --- README.md | 28 ++++++++ psql/psql.go | 163 ++++++++++++++++++++++++++++++++-------------- psql/psql_test.go | 20 ++++++ qcode/qcode.go | 6 +- 4 files changed, 167 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 17bdbe6..0b926bc 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,34 @@ contains | column: { contains: [1, 2, 4] } | Is this array/json column a subset contained_in | column: { contains: "{'a':1, 'b':2}" } | Is this array/json column a subset of these value is_null | column: { is_null: true } | Is column value null or not +#### Aggregation + +You will often find the need to fetch aggregated values from the database such as `count`, `max`, `min`, etc. This is simple to go with GraphQL just prefix the aggregation name to the field name that you want to aggregrate. The below query will group products by name and find the minimum price for each group. Notice the `min_price` field we're adding `min_` to price. + +```gql +query { + products { + name + min_price + } +} +``` + +Name | Explained | +--- | --- | +avg | Average value +count | Count the values +max | Maximum value +min | Minimum value +stddev | [Standard Deviation](https://en.wikipedia.org/wiki/Standard_deviation) +stddev_pop | Population Standard Deviation +stddev_samp | Sample Standard Deviation +variance | [Variance](https://en.wikipedia.org/wiki/Variance) +var_pop | Population Standard Variance +var_samp | Sample Standard variance + +All kinds of quries are possible with GraphQL below is an example that uses a lot of the features available to web devs using GraphQL to get the exact data they need. Comments are also valid within queries. + ```javascript query { products( diff --git a/psql/psql.go b/psql/psql.go index 93adb53..f5f9fd3 100644 --- a/psql/psql.go +++ b/psql/psql.go @@ -109,8 +109,6 @@ type selectBlock struct { func (v *selectBlock) render(w io.Writer, schema *DBSchema, childCols []*qcode.Column, childIDs []int) error { - isNotRoot := (v.parent != nil) - hasFilters := (v.sel.Where != nil) hasOrder := len(v.sel.OrderBy) != 0 // SELECT @@ -156,46 +154,11 @@ func (v *selectBlock) render(w io.Writer, } // END-SELECT - // FROM - io.WriteString(w, " FROM (SELECT ") - - // Local column names - v.renderLocalColumns(w, append(v.sel.Cols, childCols...)) - - fmt.Fprintf(w, ` FROM "%s"`, v.sel.Table) - - if isNotRoot || hasFilters { - if isNotRoot { - v.renderJoinTable(w, schema, childIDs) - } - - io.WriteString(w, ` WHERE (`) - - if isNotRoot { - v.renderRelationship(w, schema) - } - - if hasFilters { - err := v.renderWhere(w) - if err != nil { - return err - } - } - - io.WriteString(w, `)`) + // FROM (SELECT .... ) + err = v.renderBaseSelect(w, schema, childCols, childIDs) + if err != nil { + return err } - - if len(v.sel.Paging.Limit) != 0 { - fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, v.sel.Paging.Limit) - } else { - io.WriteString(w, ` LIMIT ('20') :: integer`) - } - - if len(v.sel.Paging.Offset) != 0 { - fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, v.sel.Paging.Offset) - } - - fmt.Fprintf(w, `) AS "%s_%d"`, v.sel.Table, v.sel.ID) // END-FROM return nil @@ -297,18 +260,94 @@ func (v *selectBlock) renderJoinedColumns(w io.Writer, childIDs []int) error { return nil } -func (v *selectBlock) renderLocalColumns(w io.Writer, columns []*qcode.Column) { - for i, col := range columns { - if len(col.Table) != 0 { - fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) - } else { - fmt.Fprintf(w, `"%s"."%s"`, v.sel.Table, col.Name) +func (v *selectBlock) renderBaseSelect(w io.Writer, schema *DBSchema, childCols []*qcode.Column, childIDs []int) error { + var groupBy []int + + isNotRoot := (v.parent != nil) + hasFilters := (v.sel.Where != nil) + + io.WriteString(w, " FROM (SELECT ") + + for i, col := range v.sel.Cols { + cn := col.Name + fn := "" + + if _, ok := v.schema.ColMap[TCKey{v.sel.Table, cn}]; !ok { + pl := funcPrefixLen(cn) + if pl == 0 { + continue + } + fn = cn[0 : pl-1] + cn = cn[pl:] } - if i < len(columns)-1 { + if len(fn) != 0 { + fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, v.sel.Table, cn, col.Name) + } else { + groupBy = append(groupBy, i) + fmt.Fprintf(w, `"%s"."%s"`, v.sel.Table, cn) + } + + if i < len(v.sel.Cols)-1 || len(childCols) != 0 { io.WriteString(w, ", ") } } + + for i, col := range childCols { + fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) + + if i < len(childCols)-1 { + io.WriteString(w, ", ") + } + } + + fmt.Fprintf(w, ` FROM "%s"`, v.sel.Table) + + if isNotRoot || hasFilters { + if isNotRoot { + v.renderJoinTable(w, schema, childIDs) + } + + io.WriteString(w, ` WHERE (`) + + if isNotRoot { + v.renderRelationship(w, schema) + } + + if hasFilters { + err := v.renderWhere(w) + if err != nil { + return err + } + } + + io.WriteString(w, `)`) + } + + if len(groupBy) != 0 { + fmt.Fprintf(w, ` GROUP BY `) + + for i, id := range groupBy { + fmt.Fprintf(w, `"%s"."%s"`, v.sel.Table, v.sel.Cols[id].Name) + + if i < len(groupBy)-1 { + io.WriteString(w, ", ") + } + } + } + + if len(v.sel.Paging.Limit) != 0 { + fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, v.sel.Paging.Limit) + } else { + io.WriteString(w, ` LIMIT ('20') :: integer`) + } + + if len(v.sel.Paging.Offset) != 0 { + fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, v.sel.Paging.Offset) + } + + fmt.Fprintf(w, `) AS "%s_%d"`, v.sel.Table, v.sel.ID) + return nil } func (v *selectBlock) renderOrderByColumns(w io.Writer) { @@ -549,3 +588,31 @@ func renderVal(w io.Writer, ex *qcode.Exp, vars Variables) { } io.WriteString(w, `)`) } + +func funcPrefixLen(fn string) int { + switch { + case strings.HasPrefix(fn, "avg_"): + return 4 + case strings.HasPrefix(fn, "count_"): + return 6 + case strings.HasPrefix(fn, "max_"): + return 4 + case strings.HasPrefix(fn, "min_"): + return 4 + case strings.HasPrefix(fn, "sum_"): + return 4 + case strings.HasPrefix(fn, "stddev_"): + return 7 + case strings.HasPrefix(fn, "stddev_pop_"): + return 11 + case strings.HasPrefix(fn, "stddev_samp_"): + return 12 + case strings.HasPrefix(fn, "variance_"): + return 9 + case strings.HasPrefix(fn, "var_pop_"): + return 8 + case strings.HasPrefix(fn, "var_samp_"): + return 9 + } + return 0 +} diff --git a/psql/psql_test.go b/psql/psql_test.go index 3a4784f..1438532 100644 --- a/psql/psql_test.go +++ b/psql/psql_test.go @@ -317,6 +317,26 @@ func TestCompileGQLManyToManyReverse(t *testing.T) { } } +func TestCompileGQLAggFunction(t *testing.T) { + gql := `query { + products { + name + count_price + } + }` + + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS count_price FROM "products" GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + + resSQL, err := compileGQLToPSQL(gql) + if err != nil { + t.Fatal(err) + } + + if resSQL != sql { + t.Fatal(errNotExpected) + } +} + func BenchmarkCompileGQLToSQL(b *testing.B) { gql := `query { products( diff --git a/qcode/qcode.go b/qcode/qcode.go index 2285976..da97633 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -284,18 +284,20 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) { for i := range field.Children { f := field.Children[i] + fn := strings.ToLower(f.Name) - if com.bl != nil && com.bl.MatchString(f.Name) { + if com.bl != nil && com.bl.MatchString(fn) { continue } if f.Children == nil { - col := &Column{Name: f.Name} + col := &Column{Name: fn} if len(f.Alias) != 0 { col.FieldName = f.Alias } else { col.FieldName = f.Name } + s.Cols = append(s.Cols, col) } else { st.Push(f)