diff --git a/config/dev.yml b/config/dev.yml index 2fe6357..56b4764 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -79,7 +79,7 @@ database: # Define defaults to for the field key and values below defaults: - #filter: ["{ user_id: { eq: $user_id } }"] + # filter: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block blacklist: diff --git a/jsn/json_test.go b/jsn/json_test.go index ba545e6..5a1fbec 100644 --- a/jsn/json_test.go +++ b/jsn/json_test.go @@ -262,6 +262,24 @@ func TestStrip(t *testing.T) { } } +func TestValidateTrue(t *testing.T) { + json := []byte(` [{"id":1,"embed":{"id":8}},{"id":2},{"id":3},{"id":4},{"id":5},{"id":6},{"id":7},{"id":8},{"id":9},{"id":10},{"id":11},{"id":12},{"id":13}]`) + + err := Validate(string(json)) + if err != nil { + t.Error(err) + } +} + +func TestValidateFalse(t *testing.T) { + json := []byte(` [{ "hello": 123"}]`) + + err := Validate(string(json)) + if err == nil { + t.Error("JSON validation failed to detect invalid json") + } +} + func TestReplace(t *testing.T) { var buf bytes.Buffer diff --git a/jsn/validate.go b/jsn/validate.go new file mode 100644 index 0000000..731b6ba --- /dev/null +++ b/jsn/validate.go @@ -0,0 +1,386 @@ +package jsn + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "unsafe" +) + +// Validate validates JSON s. +func Validate(s string) error { + s = skipWS(s) + + tail, err := validateValue(s) + if err != nil { + return fmt.Errorf("cannot parse JSON: %s; unparsed tail: %q", err, startEndString(tail)) + } + tail = skipWS(tail) + if len(tail) > 0 { + return fmt.Errorf("unexpected tail: %q", startEndString(tail)) + } + return nil +} + +// ValidateBytes validates JSON b. +func ValidateBytes(b []byte) error { + return Validate(b2s(b)) +} + +func validateValue(s string) (string, error) { + if len(s) == 0 { + return s, fmt.Errorf("cannot parse empty string") + } + + if s[0] == '{' { + tail, err := validateObject(s[1:]) + if err != nil { + return tail, fmt.Errorf("cannot parse object: %s", err) + } + return tail, nil + } + if s[0] == '[' { + tail, err := validateArray(s[1:]) + if err != nil { + return tail, fmt.Errorf("cannot parse array: %s", err) + } + return tail, nil + } + if s[0] == '"' { + sv, tail, err := validateString(s[1:]) + if err != nil { + return tail, fmt.Errorf("cannot parse string: %s", err) + } + // Scan the string for control chars. + for i := 0; i < len(sv); i++ { + if sv[i] < 0x20 { + return tail, fmt.Errorf("string cannot contain control char 0x%02X", sv[i]) + } + } + return tail, nil + } + if s[0] == 't' { + if len(s) < len("true") || s[:len("true")] != "true" { + return s, fmt.Errorf("unexpected value found: %q", s) + } + return s[len("true"):], nil + } + if s[0] == 'f' { + if len(s) < len("false") || s[:len("false")] != "false" { + return s, fmt.Errorf("unexpected value found: %q", s) + } + return s[len("false"):], nil + } + if s[0] == 'n' { + if len(s) < len("null") || s[:len("null")] != "null" { + return s, fmt.Errorf("unexpected value found: %q", s) + } + return s[len("null"):], nil + } + + tail, err := validateNumber(s) + if err != nil { + return tail, fmt.Errorf("cannot parse number: %s", err) + } + return tail, nil +} + +func validateArray(s string) (string, error) { + s = skipWS(s) + if len(s) == 0 { + return s, fmt.Errorf("missing ']'") + } + if s[0] == ']' { + return s[1:], nil + } + + for { + var err error + + s = skipWS(s) + s, err = validateValue(s) + if err != nil { + return s, fmt.Errorf("cannot parse array value: %s", err) + } + + s = skipWS(s) + if len(s) == 0 { + return s, fmt.Errorf("unexpected end of array") + } + if s[0] == ',' { + s = s[1:] + continue + } + if s[0] == ']' { + s = s[1:] + return s, nil + } + return s, fmt.Errorf("missing ',' after array value") + } +} + +func validateObject(s string) (string, error) { + s = skipWS(s) + if len(s) == 0 { + return s, fmt.Errorf("missing '}'") + } + if s[0] == '}' { + return s[1:], nil + } + + for { + var err error + + // Parse key. + s = skipWS(s) + if len(s) == 0 || s[0] != '"' { + return s, fmt.Errorf(`cannot find opening '"" for object key`) + } + + var key string + key, s, err = validateKey(s[1:]) + if err != nil { + return s, fmt.Errorf("cannot parse object key: %s", err) + } + // Scan the key for control chars. + for i := 0; i < len(key); i++ { + if key[i] < 0x20 { + return s, fmt.Errorf("object key cannot contain control char 0x%02X", key[i]) + } + } + s = skipWS(s) + if len(s) == 0 || s[0] != ':' { + return s, fmt.Errorf("missing ':' after object key") + } + s = s[1:] + + // Parse value + s = skipWS(s) + s, err = validateValue(s) + if err != nil { + return s, fmt.Errorf("cannot parse object value: %s", err) + } + s = skipWS(s) + if len(s) == 0 { + return s, fmt.Errorf("unexpected end of object") + } + if s[0] == ',' { + s = s[1:] + continue + } + if s[0] == '}' { + return s[1:], nil + } + return s, fmt.Errorf("missing ',' after object value") + } +} + +// validateKey is similar to validateString, but is optimized +// for typical object keys, which are quite small and have no escape sequences. +func validateKey(s string) (string, string, error) { + for i := 0; i < len(s); i++ { + if s[i] == '"' { + // Fast path - the key doesn't contain escape sequences. + return s[:i], s[i+1:], nil + } + if s[i] == '\\' { + // Slow path - the key contains escape sequences. + return validateString(s) + } + } + return "", s, fmt.Errorf(`missing closing '"'`) +} + +func validateString(s string) (string, string, error) { + // Try fast path - a string without escape sequences. + if n := strings.IndexByte(s, '"'); n >= 0 && strings.IndexByte(s[:n], '\\') < 0 { + return s[:n], s[n+1:], nil + } + + // Slow path - escape sequences are present. + rs, tail, err := parseRawString(s) + if err != nil { + return rs, tail, err + } + for { + n := strings.IndexByte(rs, '\\') + if n < 0 { + return rs, tail, nil + } + n++ + if n >= len(rs) { + return rs, tail, fmt.Errorf("BUG: parseRawString returned invalid string with trailing backslash: %q", rs) + } + ch := rs[n] + rs = rs[n+1:] + switch ch { + case '"', '\\', '/', 'b', 'f', 'n', 'r', 't': + // Valid escape sequences - see http://json.org/ + break + case 'u': + if len(rs) < 4 { + return rs, tail, fmt.Errorf(`too short escape sequence: \u%s`, rs) + } + xs := rs[:4] + _, err := strconv.ParseUint(xs, 16, 16) + if err != nil { + return rs, tail, fmt.Errorf(`invalid escape sequence \u%s: %s`, xs, err) + } + rs = rs[4:] + default: + return rs, tail, fmt.Errorf(`unknown escape sequence \%c`, ch) + } + } +} + +func validateNumber(s string) (string, error) { + if len(s) == 0 { + return s, fmt.Errorf("zero-length number") + } + if s[0] == '-' { + s = s[1:] + if len(s) == 0 { + return s, fmt.Errorf("missing number after minus") + } + } + i := 0 + for i < len(s) { + if s[i] < '0' || s[i] > '9' { + break + } + i++ + } + if i <= 0 { + return s, fmt.Errorf("expecting 0..9 digit, got %c", s[0]) + } + if s[0] == '0' && i != 1 { + return s, fmt.Errorf("unexpected number starting from 0") + } + if i >= len(s) { + return "", nil + } + if s[i] == '.' { + // Validate fractional part + s = s[i+1:] + if len(s) == 0 { + return s, fmt.Errorf("missing fractional part") + } + i = 0 + for i < len(s) { + if s[i] < '0' || s[i] > '9' { + break + } + i++ + } + if i == 0 { + return s, fmt.Errorf("expecting 0..9 digit in fractional part, got %c", s[0]) + } + if i >= len(s) { + return "", nil + } + } + if s[i] == 'e' || s[i] == 'E' { + // Validate exponent part + s = s[i+1:] + if len(s) == 0 { + return s, fmt.Errorf("missing exponent part") + } + if s[0] == '-' || s[0] == '+' { + s = s[1:] + if len(s) == 0 { + return s, fmt.Errorf("missing exponent part") + } + } + i = 0 + for i < len(s) { + if s[i] < '0' || s[i] > '9' { + break + } + i++ + } + if i == 0 { + return s, fmt.Errorf("expecting 0..9 digit in exponent part, got %c", s[0]) + } + if i >= len(s) { + return "", nil + } + } + return s[i:], nil +} + +func skipWS(s string) string { + if len(s) == 0 || s[0] > 0x20 { + // Fast path. + return s + } + return skipWSSlow(s) +} + +func skipWSSlow(s string) string { + if len(s) == 0 || s[0] != 0x20 && s[0] != 0x0A && s[0] != 0x09 && s[0] != 0x0D { + return s + } + for i := 1; i < len(s); i++ { + if s[i] != 0x20 && s[i] != 0x0A && s[i] != 0x09 && s[i] != 0x0D { + return s[i:] + } + } + return "" +} + +func b2s(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +func s2b(s string) []byte { + strh := (*reflect.StringHeader)(unsafe.Pointer(&s)) + var sh reflect.SliceHeader + sh.Data = strh.Data + sh.Len = strh.Len + sh.Cap = strh.Len + return *(*[]byte)(unsafe.Pointer(&sh)) +} + +const maxStartEndStringLen = 80 + +func startEndString(s string) string { + if len(s) <= maxStartEndStringLen { + return s + } + start := s[:40] + end := s[len(s)-40:] + return start + "..." + end +} + +func parseRawString(s string) (string, string, error) { + n := strings.IndexByte(s, '"') + if n < 0 { + return s, "", fmt.Errorf(`missing closing '"'`) + } + if n == 0 || s[n-1] != '\\' { + // Fast path. No escaped ". + return s[:n], s[n+1:], nil + } + + // Slow path - possible escaped " found. + ss := s + for { + i := n - 1 + for i > 0 && s[i-1] == '\\' { + i-- + } + if uint(n-i)%2 == 0 { + return ss[:len(ss)-len(s)+n], s[n+1:], nil + } + s = s[n+1:] + + n = strings.IndexByte(s, '"') + if n < 0 { + return ss, "", fmt.Errorf(`missing closing '"'`) + } + if n == 0 || s[n-1] != '\\' { + return ss[:len(ss)-len(s)+n], s[n+1:], nil + } + } +} diff --git a/psql/psql.go b/psql/psql.go index 27e8a23..de111a1 100644 --- a/psql/psql.go +++ b/psql/psql.go @@ -132,6 +132,7 @@ func (c *Compiler) getTable(sel *qcode.Select) (*DBTableInfo, error) { if tn, ok := c.tmap[sel.Table]; ok { return c.schema.GetTable(tn) } + return c.schema.GetTable(sel.Table) } diff --git a/rails-app/Gemfile b/rails-app/Gemfile index 966b6f6..f823a14 100644 --- a/rails-app/Gemfile +++ b/rails-app/Gemfile @@ -4,7 +4,7 @@ git_source(:github) { |repo| "https://github.com/#{repo}.git" } ruby '2.5.5' # Bundle edge Rails instead: gem 'rails', github: 'rails/rails' -gem 'rails', '~> 5.2.2', '>= 5.2.2.1' +gem 'rails', '~> 6.0.0.rc1' # Use postgresql as the database for Active Record gem 'pg', '>= 0.18', '< 2.0' # Use Puma as the app server diff --git a/rails-app/Gemfile.lock b/rails-app/Gemfile.lock index d924251..f739603 100644 --- a/rails-app/Gemfile.lock +++ b/rails-app/Gemfile.lock @@ -14,52 +14,65 @@ GIT GEM remote: https://rubygems.org/ specs: - actioncable (5.2.2.1) - actionpack (= 5.2.2.1) + actioncable (6.0.0.rc1) + actionpack (= 6.0.0.rc1) nio4r (~> 2.0) websocket-driver (>= 0.6.1) - actionmailer (5.2.2.1) - actionpack (= 5.2.2.1) - actionview (= 5.2.2.1) - activejob (= 5.2.2.1) + actionmailbox (6.0.0.rc1) + actionpack (= 6.0.0.rc1) + activejob (= 6.0.0.rc1) + activerecord (= 6.0.0.rc1) + activestorage (= 6.0.0.rc1) + activesupport (= 6.0.0.rc1) + mail (>= 2.7.1) + actionmailer (6.0.0.rc1) + actionpack (= 6.0.0.rc1) + actionview (= 6.0.0.rc1) + activejob (= 6.0.0.rc1) mail (~> 2.5, >= 2.5.4) rails-dom-testing (~> 2.0) - actionpack (5.2.2.1) - actionview (= 5.2.2.1) - activesupport (= 5.2.2.1) + actionpack (6.0.0.rc1) + actionview (= 6.0.0.rc1) + activesupport (= 6.0.0.rc1) rack (~> 2.0) rack-test (>= 0.6.3) rails-dom-testing (~> 2.0) rails-html-sanitizer (~> 1.0, >= 1.0.2) - actionview (5.2.2.1) - activesupport (= 5.2.2.1) + actiontext (6.0.0.rc1) + actionpack (= 6.0.0.rc1) + activerecord (= 6.0.0.rc1) + activestorage (= 6.0.0.rc1) + activesupport (= 6.0.0.rc1) + nokogiri (>= 1.8.5) + actionview (6.0.0.rc1) + activesupport (= 6.0.0.rc1) builder (~> 3.1) erubi (~> 1.4) rails-dom-testing (~> 2.0) rails-html-sanitizer (~> 1.0, >= 1.0.3) - activejob (5.2.2.1) - activesupport (= 5.2.2.1) + activejob (6.0.0.rc1) + activesupport (= 6.0.0.rc1) globalid (>= 0.3.6) - activemodel (5.2.2.1) - activesupport (= 5.2.2.1) - activerecord (5.2.2.1) - activemodel (= 5.2.2.1) - activesupport (= 5.2.2.1) - arel (>= 9.0) - activestorage (5.2.2.1) - actionpack (= 5.2.2.1) - activerecord (= 5.2.2.1) + activemodel (6.0.0.rc1) + activesupport (= 6.0.0.rc1) + activerecord (6.0.0.rc1) + activemodel (= 6.0.0.rc1) + activesupport (= 6.0.0.rc1) + activestorage (6.0.0.rc1) + actionpack (= 6.0.0.rc1) + activejob (= 6.0.0.rc1) + activerecord (= 6.0.0.rc1) marcel (~> 0.3.1) - activesupport (5.2.2.1) + activesupport (6.0.0.rc1) concurrent-ruby (~> 1.0, >= 1.0.2) i18n (>= 0.7, < 2) minitest (~> 5.1) tzinfo (~> 1.1) + zeitwerk (~> 2.1, >= 2.1.4) addressable (2.6.0) public_suffix (>= 2.0.2, < 4.0) archive-zip (0.12.0) io-like (~> 0.3.0) - arel (9.0.0) bcrypt (3.1.12) bindex (0.5.0) bootsnap (1.4.1) @@ -125,7 +138,7 @@ GEM msgpack (1.2.9) multi_json (1.13.1) nio4r (2.3.1) - nokogiri (1.10.1) + nokogiri (1.10.3) mini_portile2 (~> 2.4.0) orm_adapter (0.5.0) pastel (0.7.2) @@ -134,33 +147,35 @@ GEM pg (1.1.4) public_suffix (3.0.3) puma (3.12.1) - rack (2.0.6) + rack (2.0.7) rack-test (1.1.0) rack (>= 1.0, < 3) - rails (5.2.2.1) - actioncable (= 5.2.2.1) - actionmailer (= 5.2.2.1) - actionpack (= 5.2.2.1) - actionview (= 5.2.2.1) - activejob (= 5.2.2.1) - activemodel (= 5.2.2.1) - activerecord (= 5.2.2.1) - activestorage (= 5.2.2.1) - activesupport (= 5.2.2.1) + rails (6.0.0.rc1) + actioncable (= 6.0.0.rc1) + actionmailbox (= 6.0.0.rc1) + actionmailer (= 6.0.0.rc1) + actionpack (= 6.0.0.rc1) + actiontext (= 6.0.0.rc1) + actionview (= 6.0.0.rc1) + activejob (= 6.0.0.rc1) + activemodel (= 6.0.0.rc1) + activerecord (= 6.0.0.rc1) + activestorage (= 6.0.0.rc1) + activesupport (= 6.0.0.rc1) bundler (>= 1.3.0) - railties (= 5.2.2.1) + railties (= 6.0.0.rc1) sprockets-rails (>= 2.0.0) rails-dom-testing (2.0.3) activesupport (>= 4.2.0) nokogiri (>= 1.6) rails-html-sanitizer (1.0.4) loofah (~> 2.2, >= 2.2.2) - railties (5.2.2.1) - actionpack (= 5.2.2.1) - activesupport (= 5.2.2.1) + railties (6.0.0.rc1) + actionpack (= 6.0.0.rc1) + activesupport (= 6.0.0.rc1) method_source rake (>= 0.8.7) - thor (>= 0.19.0, < 2.0) + thor (>= 0.20.3, < 2.0) rake (12.3.2) rb-fsevent (0.10.3) rb-inotify (0.10.0) @@ -251,6 +266,7 @@ GEM websocket-extensions (0.1.3) xpath (3.2.0) nokogiri (~> 1.8) + zeitwerk (2.1.6) PLATFORMS ruby @@ -267,7 +283,7 @@ DEPENDENCIES listen (>= 3.0.5, < 3.2) pg (>= 0.18, < 2.0) puma (~> 3.11) - rails (~> 5.2.2, >= 5.2.2.1) + rails (~> 6.0.0.rc1) redis-rails sass-rails (~> 5.0) selenium-webdriver diff --git a/serv/config.go b/serv/config.go index 6d00cbc..c425c57 100644 --- a/serv/config.go +++ b/serv/config.go @@ -1,5 +1,9 @@ package serv +import ( + "github.com/gobuffalo/flect" +) + type config struct { AppName string `mapstructure:"app_name"` Env string @@ -88,7 +92,7 @@ func (c *config) getAliasMap() map[string]string { if len(t.Table) == 0 { continue } - m[t.Name] = t.Table + m[flect.Pluralize(t.Name)] = t.Table } return m } @@ -102,11 +106,12 @@ func (c *config) getFilterMap() map[string][]string { if len(t.Filter) == 0 { continue } + name := flect.Pluralize(t.Name) if t.Filter[0] == "none" { - m[t.Name] = []string{} + m[name] = []string{} } else { - m[t.Name] = t.Filter + m[name] = t.Filter } } diff --git a/serv/core.go b/serv/core.go index 381bddb..118acea 100644 --- a/serv/core.go +++ b/serv/core.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "net/http" "os" @@ -75,6 +76,10 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { return errors.New("something wrong no remote ids found in db response") } + if err != nil { + return err + } + var ob bytes.Buffer err = jsn.Replace(&ob, data, from, to) @@ -192,14 +197,14 @@ func (c *coreContext) resolveRemotes( return nil, nil } - go func(n int) { + go func(n int, s *qcode.Select) { defer wg.Done() st := time.Now() b, err := r.Fn(req, id) if err != nil { - cerr = err + cerr = fmt.Errorf("%s: %s", s.Table, err) return } @@ -216,7 +221,7 @@ func (c *coreContext) resolveRemotes( if len(s.Cols) != 0 { err = jsn.Filter(&ob, b, colsToList(s.Cols)) if err != nil { - cerr = err + cerr = fmt.Errorf("%s: %s", s.Table, err) return } @@ -225,7 +230,7 @@ func (c *coreContext) resolveRemotes( } to[n] = jsn.Field{[]byte(s.FieldName), ob.Bytes()} - }(i) + }(i, s) } wg.Wait() diff --git a/serv/reso.go b/serv/reso.go index 76da0d2..a3ac702 100644 --- a/serv/reso.go +++ b/serv/reso.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/cespare/xxhash/v2" + "github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/psql" ) @@ -112,6 +113,10 @@ func buildFn(r configRemote) func(*http.Request, []byte) ([]byte, error) { return nil, err } + if err := jsn.ValidateBytes(b); err != nil { + return nil, err + } + return b, nil }