@@ -3,6 +3,7 @@ package mysql
33import (
44 "context"
55 "database/sql"
6+ "errors"
67 "fmt"
78 "reflect"
89
@@ -12,7 +13,7 @@ import (
1213
1314// Connection is a connection to a MySQL database.
1415type Connection struct {
15- conn * sql.DB
16+ conn * sql.Conn
1617 isClosed bool
1718 config * Config
1819 driver.DefaultFilterBuilder
@@ -43,7 +44,16 @@ func (m *Connection) Ping() error {
4344 if m .isClosed {
4445 return driver .ErrConnectionIsClosed
4546 }
46- return m .conn .Ping ()
47+
48+ err := m .conn .PingContext (context .Background ())
49+ if err != nil {
50+ if errors .Is (err , sql .ErrConnDone ) {
51+ m .isClosed = true
52+ return driver .ErrConnectionIsClosed
53+ }
54+ return err
55+ }
56+ return nil
4757}
4858
4959// GetDetails returns the details of the database.
@@ -58,15 +68,22 @@ func (m *Connection) GetDetails(ctx context.Context) (driver.DatabaseDetail, err
5868
5969 tables , err := m .conn .QueryContext (ctx , "SHOW TABLES" )
6070 if err != nil {
71+ if errors .Is (err , sql .ErrConnDone ) {
72+ m .isClosed = true
73+ return driver.DatabaseDetail {}, driver .ErrConnectionIsClosed
74+ }
6175 return driver.DatabaseDetail {}, err
6276 }
6377 defer tables .Close ()
6478
6579 for tables .Next () {
6680 var tableName string
6781 err = tables .Scan (& tableName )
68-
6982 if err != nil {
83+ if errors .Is (err , sql .ErrConnDone ) {
84+ m .isClosed = true
85+ return driver.DatabaseDetail {}, driver .ErrConnectionIsClosed
86+ }
7087 return driver.DatabaseDetail {}, err
7188 }
7289
@@ -78,17 +95,25 @@ func (m *Connection) GetDetails(ctx context.Context) (driver.DatabaseDetail, err
7895 }
7996
8097 for i , table := range databaseInfo .DataCollections {
81- columns , err := m .conn .QueryContext (ctx , "SHOW COLUMNS FROM ? " , table .Name )
98+ columns , err := m .conn .QueryContext (ctx , fmt . Sprintf ( "SHOW COLUMNS FROM %s " , table .Name ) )
8299 if err != nil {
100+ if errors .Is (err , sql .ErrConnDone ) {
101+ m .isClosed = true
102+ return driver.DatabaseDetail {}, driver .ErrConnectionIsClosed
103+ }
83104 return driver.DatabaseDetail {}, err
84105 }
85106
86107 for columns .Next () {
87108 var columnName , columnType string
88- var columnNullable bool
109+ var columnNullable string
89110 var null any
90111 err = columns .Scan (& columnName , & columnType , & columnNullable , & null , & null , & null )
91112 if err != nil {
113+ if errors .Is (err , sql .ErrConnDone ) {
114+ m .isClosed = true
115+ return driver.DatabaseDetail {}, driver .ErrConnectionIsClosed
116+ }
92117 return driver.DatabaseDetail {}, err
93118 }
94119
@@ -97,18 +122,26 @@ func (m *Connection) GetDetails(ctx context.Context) (driver.DatabaseDetail, err
97122 return driver.DatabaseDetail {}, err
98123 }
99124
100- databaseInfo .DataCollections [i ].DataMap .Set (columnName , t , columnNullable )
125+ databaseInfo .DataCollections [i ].DataMap .Set (columnName , t , columnNullable == "YES" )
101126 }
102127
103- rows , err := m .conn .QueryContext (ctx , "SELECT COUNT(*) FROM ? " + m . BuildFilterSQL ( table .Name ), table .Name )
128+ rows , err := m .conn .QueryContext (ctx , fmt . Sprintf ( "SELECT COUNT(*) FROM %s %s" , table .Name , m . BuildFilterSQL ( table .Name )) )
104129 if err != nil {
130+ if errors .Is (err , sql .ErrConnDone ) {
131+ m .isClosed = true
132+ return driver.DatabaseDetail {}, driver .ErrConnectionIsClosed
133+ }
105134 return driver.DatabaseDetail {}, err
106135 }
107136
108137 for rows .Next () {
109138 var count int
110139 err = rows .Scan (& count )
111140 if err != nil {
141+ if errors .Is (err , sql .ErrConnDone ) {
142+ m .isClosed = true
143+ return driver.DatabaseDetail {}, driver .ErrConnectionIsClosed
144+ }
112145 return driver.DatabaseDetail {}, err
113146 }
114147 databaseInfo .DataCollections [i ].DataSetCount = count
@@ -124,7 +157,7 @@ func (m *Connection) Read(ctx context.Context, dataCollection string, startOffse
124157 if m .isClosed {
125158 return nil , driver .ErrConnectionIsClosed
126159 }
127- fmt .Println ("Reading from" , startOffset , "to" , endOffset )
160+ // fmt.Println("Reading from", startOffset, "to", endOffset)
128161
129162 batch := data .NewDataBatch ()
130163
@@ -140,6 +173,10 @@ func (m *Connection) Read(ctx context.Context, dataCollection string, startOffse
140173 fmt .Sprint (endOffset - startOffset ),
141174 )
142175 if err != nil {
176+ if errors .Is (err , sql .ErrConnDone ) {
177+ m .isClosed = true
178+ return nil , driver .ErrConnectionIsClosed
179+ }
143180 return nil , err
144181 }
145182
@@ -156,6 +193,10 @@ func (m *Connection) Read(ctx context.Context, dataCollection string, startOffse
156193 }
157194 err = rows .Scan (row ... )
158195 if err != nil {
196+ if errors .Is (err , sql .ErrConnDone ) {
197+ m .isClosed = true
198+ return nil , driver .ErrConnectionIsClosed
199+ }
159200 return nil , err
160201 }
161202
0 commit comments