Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion internal/codegen/golang/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ func (v QueryValue) YDBParamMapEntries() string {

// ydbBuilderMethodForColumnType maps a YDB column data type to a ParamsBuilder method name.
func ydbBuilderMethodForColumnType(dbType string) string {
switch strings.ToLower(dbType) {
baseType := extractBaseType(strings.ToLower(dbType))

switch baseType {
case "bool":
return "Bool"
case "uint64":
Expand Down
23 changes: 20 additions & 3 deletions internal/codegen/golang/ydb_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ import (

func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
columnType := strings.ToLower(sdk.DataType(col.Type))
notNull := col.NotNull || col.IsArray
notNull := (col.NotNull || col.IsArray) && !isNullableType(columnType)
emitPointersForNull := options.EmitPointersForNullTypes

columnType = extractBaseType(columnType)

// https://ydb.tech/docs/ru/yql/reference/types/
// ydb-go-sdk doesn't support sql.Null* yet
switch columnType {
Expand Down Expand Up @@ -49,7 +51,7 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col
}
// return "sql.NullInt16"
return "*int16"
case "int", "int32": //ydb doesn't have int type, but we need it to support untyped constants
case "int", "int32": //ydb doesn't have int type, but we need it to support untyped constants
if notNull {
return "int32"
}
Expand Down Expand Up @@ -159,7 +161,7 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col
return "*string"
}
return "*string"

case "date", "date32", "datetime", "timestamp", "tzdate", "tztimestamp", "tzdatetime":
if notNull {
return "time.Time"
Expand All @@ -185,3 +187,18 @@ func YDBType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Col
}

}

// This function extracts the base type from optional types
func extractBaseType(typeStr string) string {
if strings.HasPrefix(typeStr, "optional<") && strings.HasSuffix(typeStr, ">") {
return strings.TrimSuffix(strings.TrimPrefix(typeStr, "optional<"), ">")
}
if strings.HasSuffix(typeStr, "?") {
return strings.TrimSuffix(typeStr, "?")
}
return typeStr
}

func isNullableType(typeStr string) bool {
return strings.HasPrefix(typeStr, "optional<") && strings.HasSuffix(typeStr, ">") || strings.HasSuffix(typeStr, "?")
}
165 changes: 132 additions & 33 deletions internal/engine/ydb/convert.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ydb

import (
"fmt"
"log"
"strconv"
"strings"
Expand Down Expand Up @@ -1787,7 +1788,15 @@ func (c *cc) VisitType_name_or_bind(n *parser.Type_name_or_bindContext) interfac
}
return typeName
} else if b := n.Bind_parameter(); b != nil {
return &ast.TypeName{Name: "BIND:" + identifier(parseAnIdOrType(b.An_id_or_type()))}
param, ok := b.Accept(c).(ast.Node)
if !ok {
return todo("VisitType_name_or_bind", b)
}
return &ast.TypeName{
Names: &ast.List{
Items: []ast.Node{param},
},
}
}
return todo("VisitType_name_or_bind", n)
}
Expand All @@ -1797,6 +1806,8 @@ func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} {
return todo("VisitType_name", n)
}

questionCount := len(n.AllQUESTION())

if composite := n.Type_name_composite(); composite != nil {
typeName, ok := composite.Accept(c).(ast.Node)
if !ok {
Expand All @@ -1815,8 +1826,12 @@ func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} {
if !ok {
return todo("VisitType_name", decimal.Integer_or_bind(1))
}
name := "decimal"
if questionCount > 0 {
name = name + "?"
}
return &ast.TypeName{
Name: "decimal",
Name: name,
TypeOid: 0,
Names: &ast.List{
Items: []ast.Node{
Expand All @@ -1829,12 +1844,17 @@ func (c *cc) VisitType_name(n *parser.Type_nameContext) interface{} {
}

if simple := n.Type_name_simple(); simple != nil {
name := simple.GetText()
if questionCount > 0 {
name = name + "?"
}
return &ast.TypeName{
Name: simple.GetText(),
Name: name,
TypeOid: 0,
}
}

// todo: handle multiple ? suffixes
return todo("VisitType_name", n)
}

Expand Down Expand Up @@ -1868,19 +1888,7 @@ func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) inte
}

if opt := n.Type_name_optional(); opt != nil {
if typeName := opt.Type_name_or_bind(); typeName != nil {
tn, ok := typeName.Accept(c).(ast.Node)
if !ok {
return todo("VisitType_name_composite", typeName)
}
return &ast.TypeName{
Name: "Optional",
TypeOid: 0,
Names: &ast.List{
Items: []ast.Node{tn},
},
}
}
return opt.Accept(c)
}

if tuple := n.Type_name_tuple(); tuple != nil {
Expand Down Expand Up @@ -2025,6 +2033,27 @@ func (c *cc) VisitType_name_composite(n *parser.Type_name_compositeContext) inte
return todo("VisitType_name_composite", n)
}

func (c *cc) VisitType_name_optional(n *parser.Type_name_optionalContext) interface{} {
if n == nil || n.Type_name_or_bind() == nil {
return todo("VisitType_name_optional", n)
}

tn, ok := n.Type_name_or_bind().Accept(c).(ast.Node)
if !ok {
return todo("VisitType_name_optional", n.Type_name_or_bind())
}
innerTypeName, ok := tn.(*ast.TypeName)
if !ok {
return todo("VisitType_name_optional", n.Type_name_or_bind())
}
name := fmt.Sprintf("Optional<%s>", innerTypeName.Name)
return &ast.TypeName{
Name: name,
TypeOid: 0,
Names: &ast.List{},
}
}

func (c *cc) VisitSql_stmt_core(n *parser.Sql_stmt_coreContext) interface{} {
if n == nil {
return todo("VisitSql_stmt_core", n)
Expand Down Expand Up @@ -2799,13 +2828,28 @@ func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprCont
}
funcName := strings.Join(nameParts, ".")

if funcName == "coalesce" {
if funcName == "coalesce" || funcName == "nvl" {
return &ast.CoalesceExpr{
Args: funcCall.Args,
Location: baseNode.Location,
}
}

if funcName == "greatest" || funcName == "max_of" {
return &ast.MinMaxExpr{
Op: ast.MinMaxOp(1),
Args: funcCall.Args,
Location: baseNode.Location,
}
}
if funcName == "least" || funcName == "min_of" {
return &ast.MinMaxExpr{
Op: ast.MinMaxOp(2),
Args: funcCall.Args,
Location: baseNode.Location,
}
}

funcCall.Func = &ast.FuncName{Name: funcName}
funcCall.Funcname.Items = append(funcCall.Funcname.Items, &ast.String{Str: funcName})

Expand All @@ -2816,15 +2860,12 @@ func (c *cc) handleInvokeSuffix(base ast.Node, invokeCtx *parser.Invoke_exprCont
}
}

stmt := &ast.RecursiveFuncCall{
Func: base,
Funcname: funcCall.Funcname,
AggStar: funcCall.AggStar,
Location: funcCall.Location,
Args: funcCall.Args,
AggDistinct: funcCall.AggDistinct,
stmt := &ast.FuncExpr{
Xpr: base,
Args: funcCall.Args,
Location: funcCall.Location,
}
stmt.Funcname.Items = append(stmt.Funcname.Items, base)

return stmt
}

Expand Down Expand Up @@ -2943,16 +2984,42 @@ func (c *cc) VisitId_expr(n *parser.Id_exprContext) interface{} {
if n == nil {
return todo("VisitId_expr", n)
}

ref := &ast.ColumnRef{
Fields: &ast.List{},
Location: c.pos(n.GetStart()),
}

if id := n.Identifier(); id != nil {
return &ast.ColumnRef{
Fields: &ast.List{
Items: []ast.Node{
NewIdentifier(id.GetText()),
},
},
Location: c.pos(id.GetStart()),
}
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(id.GetText()))
return ref
}

if keyword := n.Keyword_compat(); keyword != nil {
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
return ref
}

if keyword := n.Keyword_alter_uncompat(); keyword != nil {
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
return ref
}

if keyword := n.Keyword_in_uncompat(); keyword != nil {
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
return ref
}

if keyword := n.Keyword_window_uncompat(); keyword != nil {
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
return ref
}

if keyword := n.Keyword_hint_uncompat(); keyword != nil {
ref.Fields.Items = append(ref.Fields.Items, NewIdentifier(keyword.GetText()))
return ref
}

return todo("VisitId_expr", n)
}

Expand All @@ -2979,12 +3046,44 @@ func (c *cc) VisitAtom_expr(n *parser.Atom_exprContext) interface{} {
return todo("VisitAtom_expr", n.Bind_parameter())
}
return expr
case n.Cast_expr() != nil:
expr, ok := n.Cast_expr().Accept(c).(ast.Node)
if !ok {
return todo("VisitAtom_expr", n.Cast_expr())
}
return expr
// TODO: check other cases
default:
return todo("VisitAtom_expr", n)
}
}

func (c *cc) VisitCast_expr(n *parser.Cast_exprContext) interface{} {
if n == nil || n.CAST() == nil || n.Expr() == nil || n.AS() == nil || n.Type_name_or_bind() == nil {
return todo("VisitCast_expr", n)
}

expr, ok := n.Expr().Accept(c).(ast.Node)
if !ok {
return todo("VisitCast_expr", n.Expr())
}

temp, ok := n.Type_name_or_bind().Accept(c).(ast.Node)
if !ok {
return todo("VisitCast_expr", n.Type_name_or_bind())
}
typeName, ok := temp.(*ast.TypeName)
if !ok {
return todo("VisitCast_expr", n.Type_name_or_bind())
}

return &ast.TypeCast{
Arg: expr,
TypeName: typeName,
Location: c.pos(n.GetStart()),
}
}

func (c *cc) VisitLiteral_value(n *parser.Literal_valueContext) interface{} {
if n == nil {
return todo("VisitLiteral_value", n)
Expand Down
Loading
Loading