@@ -15,6 +15,7 @@ import (
1515 "database/sql"
1616 "database/sql/driver"
1717 "fmt"
18+ "math"
1819 "reflect"
1920 "testing"
2021 "time"
3536 _ driver.StmtQueryContext = & mysqlStmt {}
3637)
3738
39+ // Ensure that all the driver interfaces are implemented
40+ var (
41+ // _ driver.RowsColumnTypeLength = &binaryRows{}
42+ // _ driver.RowsColumnTypeLength = &textRows{}
43+ _ driver.RowsColumnTypeDatabaseTypeName = & binaryRows {}
44+ _ driver.RowsColumnTypeDatabaseTypeName = & textRows {}
45+ _ driver.RowsColumnTypeNullable = & binaryRows {}
46+ _ driver.RowsColumnTypeNullable = & textRows {}
47+ _ driver.RowsColumnTypePrecisionScale = & binaryRows {}
48+ _ driver.RowsColumnTypePrecisionScale = & textRows {}
49+ _ driver.RowsColumnTypeScanType = & binaryRows {}
50+ _ driver.RowsColumnTypeScanType = & textRows {}
51+ _ driver.RowsNextResultSet = & binaryRows {}
52+ _ driver.RowsNextResultSet = & textRows {}
53+ )
54+
3855func TestMultiResultSet (t * testing.T ) {
3956 type result struct {
4057 values [][]int
@@ -558,3 +575,206 @@ func TestContextBeginReadOnly(t *testing.T) {
558575 }
559576 })
560577}
578+
579+ func TestRowsColumnTypes (t * testing.T ) {
580+ niNULL := sql.NullInt64 {Int64 : 0 , Valid : false }
581+ ni0 := sql.NullInt64 {Int64 : 0 , Valid : true }
582+ ni1 := sql.NullInt64 {Int64 : 1 , Valid : true }
583+ ni42 := sql.NullInt64 {Int64 : 42 , Valid : true }
584+ nfNULL := sql.NullFloat64 {Float64 : 0.0 , Valid : false }
585+ nf0 := sql.NullFloat64 {Float64 : 0.0 , Valid : true }
586+ nf1337 := sql.NullFloat64 {Float64 : 13.37 , Valid : true }
587+ nt0 := NullTime {Time : time .Date (2006 , 01 , 02 , 15 , 04 , 05 , 0 , time .UTC ), Valid : true }
588+ nt1 := NullTime {Time : time .Date (2006 , 01 , 02 , 15 , 04 , 05 , 100000000 , time .UTC ), Valid : true }
589+ nt2 := NullTime {Time : time .Date (2006 , 01 , 02 , 15 , 04 , 05 , 110000000 , time .UTC ), Valid : true }
590+ nt6 := NullTime {Time : time .Date (2006 , 01 , 02 , 15 , 04 , 05 , 111111000 , time .UTC ), Valid : true }
591+ rbNULL := sql .RawBytes (nil )
592+ rb0 := sql .RawBytes ("0" )
593+ rb42 := sql .RawBytes ("42" )
594+ rbTest := sql .RawBytes ("Test" )
595+
596+ var columns = []struct {
597+ name string
598+ fieldType string // type used when creating table schema
599+ databaseTypeName string // actual type used by MySQL
600+ scanType reflect.Type
601+ nullable bool
602+ precision int64 // 0 if not ok
603+ scale int64
604+ valuesIn [3 ]string
605+ valuesOut [3 ]interface {}
606+ }{
607+ {"boolnull" , "BOOL" , "TINYINT" , scanTypeNullInt , true , 0 , 0 , [3 ]string {"NULL" , "true" , "0" }, [3 ]interface {}{niNULL , ni1 , ni0 }},
608+ {"bool" , "BOOL NOT NULL" , "TINYINT" , scanTypeInt8 , false , 0 , 0 , [3 ]string {"1" , "0" , "FALSE" }, [3 ]interface {}{int8 (1 ), int8 (0 ), int8 (0 )}},
609+ {"intnull" , "INTEGER" , "INT" , scanTypeNullInt , true , 0 , 0 , [3 ]string {"0" , "NULL" , "42" }, [3 ]interface {}{ni0 , niNULL , ni42 }},
610+ {"smallint" , "SMALLINT NOT NULL" , "SMALLINT" , scanTypeInt16 , false , 0 , 0 , [3 ]string {"0" , "-32768" , "32767" }, [3 ]interface {}{int16 (0 ), int16 (- 32768 ), int16 (32767 )}},
611+ {"smallintnull" , "SMALLINT" , "SMALLINT" , scanTypeNullInt , true , 0 , 0 , [3 ]string {"0" , "NULL" , "42" }, [3 ]interface {}{ni0 , niNULL , ni42 }},
612+ {"int3null" , "INT(3)" , "INT" , scanTypeNullInt , true , 0 , 0 , [3 ]string {"0" , "NULL" , "42" }, [3 ]interface {}{ni0 , niNULL , ni42 }},
613+ {"int7" , "INT(7) NOT NULL" , "INT" , scanTypeInt32 , false , 0 , 0 , [3 ]string {"0" , "-1337" , "42" }, [3 ]interface {}{int32 (0 ), int32 (- 1337 ), int32 (42 )}},
614+ {"bigint" , "BIGINT NOT NULL" , "BIGINT" , scanTypeInt64 , false , 0 , 0 , [3 ]string {"0" , "65535" , "-42" }, [3 ]interface {}{int64 (0 ), int64 (65535 ), int64 (- 42 )}},
615+ {"bigintnull" , "BIGINT" , "BIGINT" , scanTypeNullInt , true , 0 , 0 , [3 ]string {"NULL" , "1" , "42" }, [3 ]interface {}{niNULL , ni1 , ni42 }},
616+ {"tinyuint" , "TINYINT UNSIGNED NOT NULL" , "TINYINT" , scanTypeUint8 , false , 0 , 0 , [3 ]string {"0" , "255" , "42" }, [3 ]interface {}{uint8 (0 ), uint8 (255 ), uint8 (42 )}},
617+ {"smalluint" , "SMALLINT UNSIGNED NOT NULL" , "SMALLINT" , scanTypeUint16 , false , 0 , 0 , [3 ]string {"0" , "65535" , "42" }, [3 ]interface {}{uint16 (0 ), uint16 (65535 ), uint16 (42 )}},
618+ {"biguint" , "BIGINT UNSIGNED NOT NULL" , "BIGINT" , scanTypeUint64 , false , 0 , 0 , [3 ]string {"0" , "65535" , "42" }, [3 ]interface {}{uint64 (0 ), uint64 (65535 ), uint64 (42 )}},
619+ {"uint13" , "INT(13) UNSIGNED NOT NULL" , "INT" , scanTypeUint32 , false , 0 , 0 , [3 ]string {"0" , "1337" , "42" }, [3 ]interface {}{uint32 (0 ), uint32 (1337 ), uint32 (42 )}},
620+ {"float" , "FLOAT NOT NULL" , "FLOAT" , scanTypeFloat32 , false , math .MaxInt64 , math .MaxInt64 , [3 ]string {"0" , "42" , "13.37" }, [3 ]interface {}{float32 (0 ), float32 (42 ), float32 (13.37 )}},
621+ {"floatnull" , "FLOAT" , "FLOAT" , scanTypeNullFloat , true , math .MaxInt64 , math .MaxInt64 , [3 ]string {"0" , "NULL" , "13.37" }, [3 ]interface {}{nf0 , nfNULL , nf1337 }},
622+ {"float74null" , "FLOAT(7,4)" , "FLOAT" , scanTypeNullFloat , true , math .MaxInt64 , 4 , [3 ]string {"0" , "NULL" , "13.37" }, [3 ]interface {}{nf0 , nfNULL , nf1337 }},
623+ {"double" , "DOUBLE NOT NULL" , "DOUBLE" , scanTypeFloat64 , false , math .MaxInt64 , math .MaxInt64 , [3 ]string {"0" , "42" , "13.37" }, [3 ]interface {}{float64 (0 ), float64 (42 ), float64 (13.37 )}},
624+ {"doublenull" , "DOUBLE" , "DOUBLE" , scanTypeNullFloat , true , math .MaxInt64 , math .MaxInt64 , [3 ]string {"0" , "NULL" , "13.37" }, [3 ]interface {}{nf0 , nfNULL , nf1337 }},
625+ {"decimal1" , "DECIMAL(10,6) NOT NULL" , "DECIMAL" , scanTypeRawBytes , false , 10 , 6 , [3 ]string {"0" , "13.37" , "1234.123456" }, [3 ]interface {}{sql .RawBytes ("0.000000" ), sql .RawBytes ("13.370000" ), sql .RawBytes ("1234.123456" )}},
626+ {"decimal1null" , "DECIMAL(10,6)" , "DECIMAL" , scanTypeRawBytes , true , 10 , 6 , [3 ]string {"0" , "NULL" , "1234.123456" }, [3 ]interface {}{sql .RawBytes ("0.000000" ), rbNULL , sql .RawBytes ("1234.123456" )}},
627+ {"decimal2" , "DECIMAL(8,4) NOT NULL" , "DECIMAL" , scanTypeRawBytes , false , 8 , 4 , [3 ]string {"0" , "13.37" , "1234.123456" }, [3 ]interface {}{sql .RawBytes ("0.0000" ), sql .RawBytes ("13.3700" ), sql .RawBytes ("1234.1235" )}},
628+ {"decimal2null" , "DECIMAL(8,4)" , "DECIMAL" , scanTypeRawBytes , true , 8 , 4 , [3 ]string {"0" , "NULL" , "1234.123456" }, [3 ]interface {}{sql .RawBytes ("0.0000" ), rbNULL , sql .RawBytes ("1234.1235" )}},
629+ {"decimal3" , "DECIMAL(5,0) NOT NULL" , "DECIMAL" , scanTypeRawBytes , false , 5 , 0 , [3 ]string {"0" , "13.37" , "-12345.123456" }, [3 ]interface {}{rb0 , sql .RawBytes ("13" ), sql .RawBytes ("-12345" )}},
630+ {"decimal3null" , "DECIMAL(5,0)" , "DECIMAL" , scanTypeRawBytes , true , 5 , 0 , [3 ]string {"0" , "NULL" , "-12345.123456" }, [3 ]interface {}{rb0 , rbNULL , sql .RawBytes ("-12345" )}},
631+ {"char25null" , "CHAR(25)" , "CHAR" , scanTypeRawBytes , true , 0 , 0 , [3 ]string {"0" , "NULL" , "'Test'" }, [3 ]interface {}{rb0 , rbNULL , rbTest }},
632+ {"varchar42" , "VARCHAR(42) NOT NULL" , "VARCHAR" , scanTypeRawBytes , false , 0 , 0 , [3 ]string {"0" , "'Test'" , "42" }, [3 ]interface {}{rb0 , rbTest , rb42 }},
633+ {"textnull" , "TEXT" , "BLOB" , scanTypeRawBytes , true , 0 , 0 , [3 ]string {"0" , "NULL" , "'Test'" }, [3 ]interface {}{rb0 , rbNULL , rbTest }},
634+ {"longtext" , "LONGTEXT NOT NULL" , "BLOB" , scanTypeRawBytes , false , 0 , 0 , [3 ]string {"0" , "'Test'" , "42" }, [3 ]interface {}{rb0 , rbTest , rb42 }},
635+ {"datetime" , "DATETIME" , "DATETIME" , scanTypeNullTime , true , 0 , 0 , [3 ]string {"'2006-01-02 15:04:05'" , "'2006-01-02 15:04:05.1'" , "'2006-01-02 15:04:05.111111'" }, [3 ]interface {}{nt0 , nt0 , nt0 }},
636+ {"datetime2" , "DATETIME(2)" , "DATETIME" , scanTypeNullTime , true , 2 , 2 , [3 ]string {"'2006-01-02 15:04:05'" , "'2006-01-02 15:04:05.1'" , "'2006-01-02 15:04:05.111111'" }, [3 ]interface {}{nt0 , nt1 , nt2 }},
637+ {"datetime6" , "DATETIME(6)" , "DATETIME" , scanTypeNullTime , true , 6 , 6 , [3 ]string {"'2006-01-02 15:04:05'" , "'2006-01-02 15:04:05.1'" , "'2006-01-02 15:04:05.111111'" }, [3 ]interface {}{nt0 , nt1 , nt6 }},
638+ }
639+
640+ schema := ""
641+ values1 := ""
642+ values2 := ""
643+ values3 := ""
644+ for _ , column := range columns {
645+ schema += fmt .Sprintf ("`%s` %s, " , column .name , column .fieldType )
646+ values1 += column .valuesIn [0 ] + ", "
647+ values2 += column .valuesIn [1 ] + ", "
648+ values3 += column .valuesIn [2 ] + ", "
649+ }
650+ schema = schema [:len (schema )- 2 ]
651+ values1 = values1 [:len (values1 )- 2 ]
652+ values2 = values2 [:len (values2 )- 2 ]
653+ values3 = values3 [:len (values3 )- 2 ]
654+
655+ dsns := []string {
656+ dsn + "&parseTime=true" ,
657+ dsn + "&parseTime=false" ,
658+ }
659+ for _ , testdsn := range dsns {
660+ runTests (t , testdsn , func (dbt * DBTest ) {
661+ dbt .mustExec ("CREATE TABLE test (" + schema + ")" )
662+ dbt .mustExec ("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")" )
663+
664+ rows , err := dbt .db .Query ("SELECT * FROM test" )
665+ if err != nil {
666+ t .Fatalf ("Query: %v" , err )
667+ }
668+
669+ tt , err := rows .ColumnTypes ()
670+ if err != nil {
671+ t .Fatalf ("ColumnTypes: %v" , err )
672+ }
673+
674+ if len (tt ) != len (columns ) {
675+ t .Fatalf ("unexpected number of columns: expected %d, got %d" , len (columns ), len (tt ))
676+ }
677+
678+ types := make ([]reflect.Type , len (tt ))
679+ for i , tp := range tt {
680+ column := columns [i ]
681+
682+ // Name
683+ name := tp .Name ()
684+ if name != column .name {
685+ t .Errorf ("column name mismatch %s != %s" , name , column .name )
686+ continue
687+ }
688+
689+ // DatabaseTypeName
690+ databaseTypeName := tp .DatabaseTypeName ()
691+ if databaseTypeName != column .databaseTypeName {
692+ t .Errorf ("databasetypename name mismatch for column %q: %s != %s" , name , databaseTypeName , column .databaseTypeName )
693+ continue
694+ }
695+
696+ // ScanType
697+ scanType := tp .ScanType ()
698+ if scanType != column .scanType {
699+ if scanType == nil {
700+ t .Errorf ("scantype is null for column %q" , name )
701+ } else {
702+ t .Errorf ("scantype mismatch for column %q: %s != %s" , name , scanType .Name (), column .scanType .Name ())
703+ }
704+ continue
705+ }
706+ types [i ] = scanType
707+
708+ // Nullable
709+ nullable , ok := tp .Nullable ()
710+ if ! ok {
711+ t .Errorf ("nullable not ok %q" , name )
712+ continue
713+ }
714+ if nullable != column .nullable {
715+ t .Errorf ("nullable mismatch for column %q: %t != %t" , name , nullable , column .nullable )
716+ }
717+
718+ // Length
719+ // length, ok := tp.Length()
720+ // if length != column.length {
721+ // if !ok {
722+ // t.Errorf("length not ok for column %q", name)
723+ // } else {
724+ // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
725+ // }
726+ // continue
727+ // }
728+
729+ // Precision and Scale
730+ precision , scale , ok := tp .DecimalSize ()
731+ if precision != column .precision {
732+ if ! ok {
733+ t .Errorf ("precision not ok for column %q" , name )
734+ } else {
735+ t .Errorf ("precision mismatch for column %q: %d != %d" , name , precision , column .precision )
736+ }
737+ continue
738+ }
739+ if scale != column .scale {
740+ if ! ok {
741+ t .Errorf ("scale not ok for column %q" , name )
742+ } else {
743+ t .Errorf ("scale mismatch for column %q: %d != %d" , name , scale , column .scale )
744+ }
745+ continue
746+ }
747+ }
748+
749+ values := make ([]interface {}, len (tt ))
750+ for i := range values {
751+ values [i ] = reflect .New (types [i ]).Interface ()
752+ }
753+ i := 0
754+ for rows .Next () {
755+ err = rows .Scan (values ... )
756+ if err != nil {
757+ t .Fatalf ("failed to scan values in %v" , err )
758+ }
759+ for j := range values {
760+ value := reflect .ValueOf (values [j ]).Elem ().Interface ()
761+ if ! reflect .DeepEqual (value , columns [j ].valuesOut [i ]) {
762+ if columns [j ].scanType == scanTypeRawBytes {
763+ t .Errorf ("row %d, column %d: %v != %v" , i , j , string (value .(sql.RawBytes )), string (columns [j ].valuesOut [i ].(sql.RawBytes )))
764+ } else {
765+ t .Errorf ("row %d, column %d: %v != %v" , i , j , value , columns [j ].valuesOut [i ])
766+ }
767+ }
768+ }
769+ i ++
770+ }
771+ if i != 3 {
772+ t .Errorf ("expected 3 rows, got %d" , i )
773+ }
774+
775+ if err := rows .Close (); err != nil {
776+ t .Errorf ("error closing rows: %s" , err )
777+ }
778+ })
779+ }
780+ }
0 commit comments