Skip to content

Commit 324cbb3

Browse files
authored
[1.9] fix PING on compressed connections (#1723)
Add missing mc.syncSequence() Fix #1718
1 parent dfd973a commit 324cbb3

File tree

2 files changed

+19
-22
lines changed

2 files changed

+19
-22
lines changed

compress.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,11 @@ func (c *compIO) readCompressedPacket() error {
113113
// Server may return error packet (e.g. 1153 Got a packet bigger than 'max_allowed_packet' bytes)
114114
// before receiving all packets from client. In this case, seqnr is younger than expected.
115115
// NOTE: Both of mariadbclient and mysqlclient do not check seqnr. Only server checks it.
116-
if debug && compressionSequence != c.mc.sequence {
116+
if debug && compressionSequence != c.mc.compressSequence {
117117
fmt.Printf("WARN: unexpected cmpress seq nr: expected %v, got %v",
118-
c.mc.sequence, compressionSequence)
118+
c.mc.compressSequence, compressionSequence)
119119
}
120-
c.mc.sequence = compressionSequence + 1
121-
c.mc.compressSequence = c.mc.sequence
120+
c.mc.compressSequence = compressionSequence + 1
122121

123122
comprData, err := c.mc.readNext(comprLength)
124123
if err != nil {
@@ -200,7 +199,7 @@ func (c *compIO) writeCompressedPacket(data []byte, uncompressedLen int) (int, e
200199
comprLength := len(data) - 7
201200
if debug {
202201
fmt.Printf(
203-
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v",
202+
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v\n",
204203
comprLength, uncompressedLen, mc.compressSequence)
205204
}
206205

packets.go

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"fmt"
1818
"io"
1919
"math"
20+
"os"
2021
"strconv"
2122
"time"
2223
)
@@ -62,26 +63,20 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
6263
pktLen := getUint24(data[:3])
6364
seq := data[3]
6465

65-
if mc.compress {
66+
// check packet sync [8 bit]
67+
if seq != mc.sequence {
68+
mc.log(fmt.Sprintf("[warn] unexpected sequence nr: expected %v, got %v", mc.sequence, seq))
6669
// MySQL and MariaDB doesn't check packet nr in compressed packet.
67-
if debug && seq != mc.compressSequence {
68-
fmt.Printf("[debug] mismatched compression sequence nr: expected: %v, got %v",
69-
mc.compressSequence, seq)
70-
}
71-
mc.compressSequence = seq + 1
72-
} else {
73-
// check packet sync [8 bit]
74-
if seq != mc.sequence {
75-
mc.log(fmt.Sprintf("[warn] unexpected seq nr: expected %v, got %v", mc.sequence, seq))
70+
if !mc.compress {
7671
// For large packets, we stop reading as soon as sync error.
7772
if len(prevData) > 0 {
7873
mc.close()
7974
return nil, ErrPktSyncMul
8075
}
8176
invalidSequence = true
8277
}
83-
mc.sequence++
8478
}
79+
mc.sequence = seq + 1
8580

8681
// packets with length 0 terminate a previous packet which is a
8782
// multiple of (2^24)-1 bytes long
@@ -146,7 +141,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
146141

147142
// Write packet
148143
if debug {
149-
fmt.Printf("writePacket: size=%v seq=%v", size, mc.sequence)
144+
fmt.Fprintf(os.Stderr, "writePacket: size=%v seq=%v\n", size, mc.sequence)
150145
}
151146

152147
n, err := writeFunc(data[:4+size])
@@ -445,7 +440,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
445440
data[4] = command
446441

447442
// Send CMD packet
448-
return mc.writePacket(data)
443+
err = mc.writePacket(data)
444+
mc.syncSequence()
445+
return err
449446
}
450447

451448
func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
@@ -486,7 +483,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
486483
binary.LittleEndian.PutUint32(data[5:], arg)
487484

488485
// Send CMD packet
489-
return mc.writePacket(data)
486+
err = mc.writePacket(data)
487+
mc.syncSequence()
488+
return err
490489
}
491490

492491
/******************************************************************************
@@ -956,7 +955,6 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
956955
pktLen = dataOffset + argLen
957956
}
958957

959-
stmt.mc.resetSequence()
960958
// Add command byte [1 byte]
961959
data[4] = comStmtSendLongData
962960

@@ -968,15 +966,15 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
968966

969967
// Send CMD packet
970968
err := stmt.mc.writePacket(data[:4+pktLen])
969+
// Every COM_LONG_DATA packet reset Packet Sequence
970+
stmt.mc.resetSequence()
971971
if err == nil {
972972
data = data[pktLen-dataOffset:]
973973
continue
974974
}
975975
return err
976976
}
977977

978-
// Reset Packet Sequence
979-
stmt.mc.resetSequence()
980978
return nil
981979
}
982980

0 commit comments

Comments
 (0)