Skip to content

Commit 6675966

Browse files
serprexlance6716
andauthored
CLIENT_DEPRECATE_EOF (#1059)
* wip * make it work * oops * now that tests passed, turn on by default, test with it off * fix EOF handling for COM_LIST_FIELDS * comment warning * Update client/resp.go Co-authored-by: lance6716 <[email protected]> * combine isEOFPacket, comment on why 0xffffff --------- Co-authored-by: lance6716 <[email protected]>
1 parent 90dbffc commit 6675966

File tree

5 files changed

+87
-51
lines changed

5 files changed

+87
-51
lines changed

client/auth.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (c *Conn) readInitialHandshake() error {
9292
pos += 2
9393

9494
// The upper 2 bytes of the Capabilities Flags
95-
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
95+
c.capability |= uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
9696
pos += 2
9797

9898
// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
@@ -209,10 +209,8 @@ func (c *Conn) writeAuthHandshake() error {
209209

210210
// Set default client capabilities that reflect the abilities of this library
211211
capability := mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SECURE_CONNECTION |
212-
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH
213-
// Adjust client capability flags based on server support
214-
capability |= c.capability & mysql.CLIENT_LONG_FLAG
215-
capability |= c.capability & mysql.CLIENT_QUERY_ATTRIBUTES
212+
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH |
213+
mysql.CLIENT_LONG_FLAG | mysql.CLIENT_QUERY_ATTRIBUTES | mysql.CLIENT_DEPRECATE_EOF
216214
// Adjust client capability flags on specific client requests
217215
// Only flags that would make any sense setting and aren't handled elsewhere
218216
// in the library are supported here
@@ -275,6 +273,7 @@ func (c *Conn) writeAuthHandshake() error {
275273
data := make([]byte, length+4)
276274

277275
// capability [32 bit]
276+
c.capability &= capability
278277
data[4] = byte(capability)
279278
data[5] = byte(capability >> 8)
280279
data[6] = byte(capability >> 16)

client/client_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ func (s *clientTestSuite) TestConn_Compress() {
101101
require.NoError(s.T(), err)
102102
}
103103

104+
func (s *clientTestSuite) TestConn_NoDeprecateEOF() {
105+
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
106+
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
107+
conn.UnsetCapability(mysql.CLIENT_DEPRECATE_EOF)
108+
return nil
109+
})
110+
require.NoError(s.T(), err)
111+
112+
_, err = conn.Execute("SELECT VERSION()")
113+
require.NoError(s.T(), err)
114+
}
115+
104116
func (s *clientTestSuite) TestConn_SetCapability() {
105117
caps := []uint32{
106118
mysql.CLIENT_LONG_PASSWORD,
@@ -125,6 +137,7 @@ func (s *clientTestSuite) TestConn_SetCapability() {
125137
mysql.CLIENT_PLUGIN_AUTH,
126138
mysql.CLIENT_CONNECT_ATTRS,
127139
mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA,
140+
mysql.CLIENT_DEPRECATE_EOF,
128141
}
129142

130143
for _, capI := range caps {

client/conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ func (c *Conn) UnsetCapability(cap uint32) {
252252

253253
// HasCapability returns true if the connection has the specific capability
254254
func (c *Conn) HasCapability(cap uint32) bool {
255-
return c.ccaps&cap > 0
255+
return c.ccaps&cap != 0
256256
}
257257

258258
// UseSSL: use default SSL

client/resp.go

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,11 @@ import (
1414
"github.com/go-mysql-org/go-mysql/utils"
1515
)
1616

17-
func (c *Conn) readUntilEOF() (err error) {
18-
var data []byte
19-
20-
for {
21-
data, err = c.ReadPacket()
22-
if err != nil {
23-
return err
24-
}
25-
26-
// EOF Packet
27-
if c.isEOFPacket(data) {
28-
return err
29-
}
30-
}
31-
}
32-
3317
func (c *Conn) isEOFPacket(data []byte) bool {
34-
return data[0] == mysql.EOF_HEADER && len(data) <= 5
18+
// 0xffffff due to https://dev.mysql.com/worklog/task/?id=7766
19+
// "Server will never send OK packet longer than 16777216 bytes thus limiting
20+
// size of OK packet to be 16777215 bytes"
21+
return data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff
3522
}
3623

3724
func (c *Conn) handleOKPacket(data []byte) (*mysql.Result, error) {
@@ -336,33 +323,16 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *mysql.Re
336323
}
337324

338325
func (c *Conn) readResultColumns(result *mysql.Result) (err error) {
339-
i := 0
340326
var data []byte
341327

342-
for {
328+
for i := range result.Fields {
343329
rawPkgLen := len(result.RawPkg)
344330
result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
345331
if err != nil {
346332
return err
347333
}
348334
data = result.RawPkg[rawPkgLen:]
349335

350-
// EOF Packet
351-
if c.isEOFPacket(data) {
352-
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
353-
result.Warnings = binary.LittleEndian.Uint16(data[1:])
354-
// todo add strict_mode, warning will be treat as error
355-
result.Status = binary.LittleEndian.Uint16(data[3:])
356-
c.status = result.Status
357-
}
358-
359-
if i != len(result.Fields) {
360-
err = mysql.ErrMalformPacket
361-
}
362-
363-
return err
364-
}
365-
366336
if result.Fields[i] == nil {
367337
result.Fields[i] = &mysql.Field{}
368338
}
@@ -372,8 +342,30 @@ func (c *Conn) readResultColumns(result *mysql.Result) (err error) {
372342
}
373343

374344
result.FieldNames[utils.ByteSliceToString(result.Fields[i].Name)] = i
345+
}
346+
347+
if c.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
348+
// EOF Packet
349+
rawPkgLen := len(result.RawPkg)
350+
result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
351+
if err != nil {
352+
return err
353+
}
354+
data = result.RawPkg[rawPkgLen:]
375355

376-
i++
356+
if c.isEOFPacket(data) {
357+
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
358+
result.Warnings = binary.LittleEndian.Uint16(data[1:])
359+
// todo add strict_mode, warning will be treat as error
360+
result.Status = binary.LittleEndian.Uint16(data[3:])
361+
c.status = result.Status
362+
}
363+
return nil
364+
} else {
365+
return mysql.ErrMalformPacket
366+
}
367+
} else {
368+
return nil
377369
}
378370
}
379371

@@ -388,15 +380,21 @@ func (c *Conn) readResultRows(result *mysql.Result, isBinary bool) (err error) {
388380
}
389381
data = result.RawPkg[rawPkgLen:]
390382

391-
// EOF Packet
392383
if c.isEOFPacket(data) {
393-
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
384+
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
385+
// Treat like OK
386+
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
387+
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
388+
result.Status = binary.LittleEndian.Uint16(data[1+n+m:])
389+
result.AffectedRows = affectedRows
390+
result.InsertId = insertId
391+
c.status = result.Status
392+
} else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
394393
result.Warnings = binary.LittleEndian.Uint16(data[1:])
395394
// todo add strict_mode, warning will be treat as error
396395
result.Status = binary.LittleEndian.Uint16(data[3:])
397396
c.status = result.Status
398397
}
399-
400398
break
401399
}
402400

@@ -435,9 +433,16 @@ func (c *Conn) readResultRowsStreaming(result *mysql.Result, isBinary bool, perR
435433
return err
436434
}
437435

438-
// EOF Packet
439436
if c.isEOFPacket(data) {
440-
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
437+
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
438+
// Treat like OK
439+
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
440+
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
441+
result.Status = binary.LittleEndian.Uint16(data[1+n+m:])
442+
result.AffectedRows = affectedRows
443+
result.InsertId = insertId
444+
c.status = result.Status
445+
} else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
441446
result.Warnings = binary.LittleEndian.Uint16(data[1:])
442447
// todo add strict_mode, warning will be treat as error
443448
result.Status = binary.LittleEndian.Uint16(data[3:])

client/stmt.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,33 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
275275
}
276276

277277
if s.params > 0 {
278-
if err := s.conn.readUntilEOF(); err != nil {
279-
return nil, errors.Trace(err)
278+
for range s.params {
279+
if _, err := s.conn.ReadPacket(); err != nil {
280+
return nil, errors.Trace(err)
281+
}
282+
}
283+
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
284+
if packet, err := s.conn.ReadPacket(); err != nil {
285+
return nil, errors.Trace(err)
286+
} else if !c.isEOFPacket(packet) {
287+
return nil, mysql.ErrMalformPacket
288+
}
280289
}
281290
}
282291

283292
if s.columns > 0 {
284-
if err := s.conn.readUntilEOF(); err != nil {
285-
return nil, errors.Trace(err)
293+
// TODO process when CLIENT_CACHE_METADATA enabled
294+
for range s.columns {
295+
if _, err := s.conn.ReadPacket(); err != nil {
296+
return nil, errors.Trace(err)
297+
}
298+
}
299+
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
300+
if packet, err := s.conn.ReadPacket(); err != nil {
301+
return nil, errors.Trace(err)
302+
} else if !c.isEOFPacket(packet) {
303+
return nil, mysql.ErrMalformPacket
304+
}
286305
}
287306
}
288307

0 commit comments

Comments
 (0)