diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index d926f2b1fc..be23357a33 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -28,7 +28,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, // TODO: Deprecate defaultTable var defaultTable *ast.TableName var tables []*ast.TableName + var ctes []*ast.TableName + typeMapCte := map[string]map[string]map[string]*catalog.Column{} typeMap := map[string]map[string]map[string]*catalog.Column{} indexTable := func(table catalog.Table) error { tables = append(tables, table.Rel) @@ -67,9 +69,35 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } // If the table name doesn't exist, first check if it's a CTE - if _, qcerr := qc.GetTable(fqn); qcerr != nil { + cte, qcerr := qc.GetTable(fqn) + if qcerr != nil { return nil, err } + // duplicated logic from indexTable() + schema := fqn.Schema + if schema == "" { + schema = c.DefaultSchema + } + if _, exists := typeMapCte[schema]; !exists { + typeMapCte[schema] = map[string]map[string]*catalog.Column{} + } + typeMapCte[schema][fqn.Name] = map[string]*catalog.Column{} + for _, col := range cte.Columns { + cc := &catalog.Column{ + Name: col.Name, + IsNotNull: col.NotNull, + IsUnsigned: col.Unsigned, + IsArray: col.IsArray, + ArrayDims: col.ArrayDims, + Comment: col.Comment, + Length: col.Length, + } + if col.Type != nil { + cc.Type = *col.Type + } + typeMapCte[schema][fqn.Name][col.Name] = cc + } + ctes = append(ctes, fqn) continue } err = indexTable(table) @@ -195,7 +223,61 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, panic("too many field items: " + strconv.Itoa(len(items))) } - search := tables + search := ctes + if alias != "" { + if _, ok := aliasMap[alias]; ok { + // alias maps to a real table, skip CTE search entirely + search = []*ast.TableName{} + } + } + + // resolve using ctes first + var found int + for _, table := range search { + schema := table.Schema + if schema == "" { + schema = c.DefaultSchema + } + if c, ok := typeMapCte[schema][table.Name][key]; ok { + found += 1 + if ref.name != "" { + key = ref.name + } + + defaultP := named.NewInferredParam(key, c.IsNotNull) + p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) + a = append(a, Parameter{ + Number: ref.ref.Number, + Column: &Column{ + Name: p.Name(), + OriginalName: c.Name, + DataType: dataType(&c.Type), + NotNull: p.NotNull(), + Unsigned: c.IsUnsigned, + IsArray: c.IsArray, + ArrayDims: c.ArrayDims, + Length: c.Length, + Table: table, + IsNamedParam: isNamed, + IsSqlcSlice: p.IsSqlcSlice(), + }, + }) + } + } + + if found == 1 { + continue + } + if found > 1 { + return nil, &sqlerr.Error{ + Code: "42703", + Message: fmt.Sprintf("column reference %q is ambiguous", key), + Location: node.Location, + } + } + + + search = tables if alias != "" { if original, ok := aliasMap[alias]; ok { search = []*ast.TableName{original} @@ -217,7 +299,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } } - var found int + + // resolve using regular tables + found = 0 for _, table := range search { schema := table.Schema if schema == "" { @@ -286,6 +370,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, schema = c.DefaultSchema } + // should we use also look in CTEs? if c, ok := typeMap[schema][table.Name][key]; ok { defaultP := named.NewInferredParam(key, c.IsNotNull) p, isNamed := params.FetchMerge(ref.ref.Number, defaultP) @@ -475,6 +560,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, schema = c.DefaultSchema } + // should we use also look in CTEs? tableMap, ok := typeMap[schema][rel] if !ok { return nil, sqlerr.RelationNotFound(rel) @@ -583,6 +669,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } } + // should we use also look in CTEs? for _, table := range search { schema := table.Schema if schema == "" { diff --git a/internal/endtoend/testdata/cte_left_join/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/cte_left_join/postgresql/pgx/go/query.sql.go index 64c2320add..4e0516e8f8 100644 --- a/internal/endtoend/testdata/cte_left_join/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/cte_left_join/postgresql/pgx/go/query.sql.go @@ -7,8 +7,6 @@ package querytest import ( "context" - - "github.com/jackc/pgx/v5/pgtype" ) const badQuery = `-- name: BadQuery :exec @@ -28,7 +26,7 @@ FROM WHERE c1.name = $1 ` -func (q *Queries) BadQuery(ctx context.Context, dollar_1 pgtype.Text) error { - _, err := q.db.Exec(ctx, badQuery, dollar_1) +func (q *Queries) BadQuery(ctx context.Context, name string) error { + _, err := q.db.Exec(ctx, badQuery, name) return err } diff --git a/internal/endtoend/testdata/cte_renamed_column/issue.md b/internal/endtoend/testdata/cte_renamed_column/issue.md new file mode 100644 index 0000000000..36544f66f6 --- /dev/null +++ b/internal/endtoend/testdata/cte_renamed_column/issue.md @@ -0,0 +1 @@ +https://github.com/sqlc-dev/sqlc/issues/4288 diff --git a/internal/endtoend/testdata/cte_renamed_column/postgresql/go/db.go b/internal/endtoend/testdata/cte_renamed_column/postgresql/go/db.go new file mode 100644 index 0000000000..0057c62319 --- /dev/null +++ b/internal/endtoend/testdata/cte_renamed_column/postgresql/go/db.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/cte_renamed_column/postgresql/go/models.go b/internal/endtoend/testdata/cte_renamed_column/postgresql/go/models.go new file mode 100644 index 0000000000..e6de975872 --- /dev/null +++ b/internal/endtoend/testdata/cte_renamed_column/postgresql/go/models.go @@ -0,0 +1,10 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package querytest + +type User struct { + ID int32 + Name string +} diff --git a/internal/endtoend/testdata/cte_renamed_column/postgresql/go/query.sql.go b/internal/endtoend/testdata/cte_renamed_column/postgresql/go/query.sql.go new file mode 100644 index 0000000000..23107f13e7 --- /dev/null +++ b/internal/endtoend/testdata/cte_renamed_column/postgresql/go/query.sql.go @@ -0,0 +1,29 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package querytest + +import ( + "context" +) + +const getUser = `-- name: GetUser :one +WITH found AS ( + SELECT id, id AS id2 FROM users WHERE id = $1 +) +SELECT id2 + $2 AS result FROM found +` + +type GetUserParams struct { + ID int32 + Id2 int32 +} + +func (q *Queries) GetUser(ctx context.Context, arg GetUserParams) (int32, error) { + row := q.db.QueryRow(ctx, getUser, arg.ID, arg.Id2) + var result int32 + err := row.Scan(&result) + return result, err +} diff --git a/internal/endtoend/testdata/cte_renamed_column/postgresql/query.sql b/internal/endtoend/testdata/cte_renamed_column/postgresql/query.sql new file mode 100644 index 0000000000..f3972e258d --- /dev/null +++ b/internal/endtoend/testdata/cte_renamed_column/postgresql/query.sql @@ -0,0 +1,5 @@ +-- name: GetUser :one +WITH found AS ( + SELECT id, id AS id2 FROM users WHERE id = $1 +) +SELECT id2 + $2 AS result FROM found; diff --git a/internal/endtoend/testdata/cte_renamed_column/postgresql/schema.sql b/internal/endtoend/testdata/cte_renamed_column/postgresql/schema.sql new file mode 100644 index 0000000000..7777c29178 --- /dev/null +++ b/internal/endtoend/testdata/cte_renamed_column/postgresql/schema.sql @@ -0,0 +1,5 @@ +-- Schema +CREATE TABLE users ( + id INT PRIMARY KEY, + name TEXT NOT NULL +); diff --git a/internal/endtoend/testdata/cte_renamed_column/postgresql/sqlc.yaml b/internal/endtoend/testdata/cte_renamed_column/postgresql/sqlc.yaml new file mode 100644 index 0000000000..5dc63e3f91 --- /dev/null +++ b/internal/endtoend/testdata/cte_renamed_column/postgresql/sqlc.yaml @@ -0,0 +1,10 @@ +version: "2" +sql: + - engine: "postgresql" + schema: "schema.sql" + queries: "query.sql" + gen: + go: + package: "querytest" + out: "go" + sql_package: "pgx/v5"