diff --git a/psql/schema.go b/psql/schema.go index 389a134..b885f72 100644 --- a/psql/schema.go +++ b/psql/schema.go @@ -66,7 +66,14 @@ func NewDBSchema(info *DBInfo, aliases map[string][]string) (*DBSchema, error) { } for i, t := range info.Tables { - err := schema.updateRelationships(t, info.Columns[i]) + err := schema.firstDegreeRels(t, info.Columns[i]) + if err != nil { + return nil, err + } + } + + for i, t := range info.Tables { + err := schema.secondDegreeRels(t, info.Columns[i]) if err != nil { return nil, err } @@ -131,8 +138,7 @@ func (s *DBSchema) addTable( return nil } -func (s *DBSchema) updateRelationships(t DBTable, cols []DBColumn) error { - jcols := make([]DBColumn, 0, len(cols)) +func (s *DBSchema) firstDegreeRels(t DBTable, cols []DBColumn) error { ct := t.Key cti, ok := s.t[ct] if !ok { @@ -230,6 +236,51 @@ func (s *DBSchema) updateRelationships(t DBTable, cols []DBColumn) error { if err := s.SetRel(ft, ct, rel2); err != nil { return err } + } + + return nil +} + +func (s *DBSchema) secondDegreeRels(t DBTable, cols []DBColumn) error { + jcols := make([]DBColumn, 0, len(cols)) + ct := t.Key + cti, ok := s.t[ct] + if !ok { + return fmt.Errorf("invalid foreign key table '%s'", ct) + } + + for i := range cols { + c := cols[i] + + if len(c.FKeyTable) == 0 { + continue + } + + // Foreign key column name + ft := strings.ToLower(c.FKeyTable) + + ti, ok := s.t[ft] + if !ok { + return fmt.Errorf("invalid foreign key table '%s'", ft) + } + + // This is an embedded relationship like when a json/jsonb column + // is exposed as a table + if c.Name == c.FKeyTable && len(c.FKeyColID) == 0 { + continue + } + + if len(c.FKeyColID) == 0 { + continue + } + + // Foreign key column id + fcid := c.FKeyColID[0] + + if _, ok := ti.ColIDMap[fcid]; !ok { + return fmt.Errorf("invalid foreign key column id '%d' for table '%s'", + fcid, ti.Name) + } jcols = append(jcols, c) } @@ -322,6 +373,9 @@ func (s *DBSchema) GetTable(table string) (*DBTableInfo, error) { } func (s *DBSchema) SetRel(child, parent string, rel *DBRel) error { + sp := strings.ToLower(flect.Singularize(parent)) + pp := strings.ToLower(flect.Pluralize(parent)) + sc := strings.ToLower(flect.Singularize(child)) pc := strings.ToLower(flect.Pluralize(child)) @@ -333,9 +387,6 @@ func (s *DBSchema) SetRel(child, parent string, rel *DBRel) error { s.rm[pc] = make(map[string]*DBRel) } - sp := strings.ToLower(flect.Singularize(parent)) - pp := strings.ToLower(flect.Pluralize(parent)) - if _, ok := s.rm[sc][sp]; !ok { s.rm[sc][sp] = rel } diff --git a/psql/strings.go b/psql/strings.go index b9c4292..6213a44 100644 --- a/psql/strings.go +++ b/psql/strings.go @@ -19,6 +19,10 @@ func (rt RelType) String() string { } func (re *DBRel) String() string { + if re.Type == RelOneToManyThrough { + return fmt.Sprintf("'%s.%s' --(Through: %s)--> '%s.%s'", + re.Left.Table, re.Left.Col, re.Through, re.Right.Table, re.Right.Col) + } return fmt.Sprintf("'%s.%s' --(%s)--> '%s.%s'", re.Left.Table, re.Left.Col, re.Type, re.Right.Table, re.Right.Col) } diff --git a/psql/test_schema.go b/psql/test_schema.go index 235e8ca..32da488 100644 --- a/psql/test_schema.go +++ b/psql/test_schema.go @@ -92,7 +92,14 @@ func getTestSchema() *DBSchema { } for i, t := range tables { - err := schema.updateRelationships(t, columns[i]) + err := schema.firstDegreeRels(t, columns[i]) + if err != nil { + log.Fatal(err) + } + } + + for i, t := range tables { + err := schema.secondDegreeRels(t, columns[i]) if err != nil { log.Fatal(err) }