Skip to content

Commit 3252fbd

Browse files
committed
fix closing the connection when OpClose receive, add test to this case
1 parent d7633a5 commit 3252fbd

File tree

5 files changed

+89
-3
lines changed

5 files changed

+89
-3
lines changed

channel.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ func newChannel(id string) *Channel {
2525
select {
2626
case conn := <-c.delConn:
2727
c.mu.Lock()
28+
_ = conn.Close()
2829
delete(c.connections, conn)
2930
c.mu.Unlock()
3031
}

channel_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ func TestChannel_Emit(t *testing.T) {
8181

8282
func TestChannel_Remove(t *testing.T) {
8383
ts, wsServer := wsServer()
84-
defer ts.Close()
8584
defer func() {
8685
require.NoError(t, wsServer.Shutdown())
86+
ts.Close()
8787
}()
8888

8989
ch := wsServer.NewChannel("test-channel-add")

conn.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ func (c *Conn) Send(data interface{}) error {
9292

9393
// Close closing websocket connection.
9494
func (c *Conn) Close() error {
95+
if c.conn == nil {
96+
return nil
97+
}
98+
9599
c.done <- true
96100

97101
err := c.conn.Close()

websocket.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,9 @@ func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
207207
for {
208208
header, _ := ws.ReadHeader(conn)
209209
if err = ws.CheckHeader(header, state); err != nil {
210+
log.Printf("drop connection: CheckHeader: %v", err)
210211
s.dropConn(connection)
211-
return
212+
break
212213
}
213214

214215
cipherReader.Reset(io.LimitReader(conn, header.Length), header.Mask)
@@ -261,8 +262,9 @@ func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
261262
}
262263

263264
if err != nil || header.OpCode == ws.OpClose {
265+
log.Printf("drop connection: %v or OpClose", err)
264266
s.dropConn(connection)
265-
return
267+
break
266268
}
267269

268270
header.Masked = false

websocket_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/gobwas/ws"
1010
"github.com/gobwas/ws/wsutil"
1111
"github.com/stretchr/testify/require"
12+
"log"
1213
"math/rand"
1314
"net"
1415
"net/http"
@@ -529,6 +530,84 @@ func TestServerProcessMessage(t *testing.T) {
529530
}
530531
}
531532

533+
func TestServer_ConnectionClose(t *testing.T) {
534+
ts, wsServer := wsServer()
535+
defer func() {
536+
require.NoError(t, wsServer.Shutdown())
537+
ts.Close()
538+
}()
539+
540+
ch := wsServer.NewChannel("test-channel-add")
541+
msg := Message{
542+
Name: "test",
543+
Data: []byte("Hello World"),
544+
}
545+
messageBytes, err := json.Marshal(msg)
546+
require.NoError(t, err)
547+
548+
ticker := time.NewTicker(time.Millisecond * 1)
549+
done := make(chan bool, 1)
550+
551+
wsServer.OnConnect(func(c *Conn) {
552+
ch.Add(c)
553+
require.Equal(t, 1, ch.Count(), "channel must contain only 1 connection")
554+
log.Print("Connected")
555+
})
556+
wsServer.OnDisconnect(func(c *Conn) {
557+
log.Print("Disconnected")
558+
})
559+
wsServer.On("test", func(c *Conn, msg *Message) {
560+
log.Printf("message: %s", msg.Name)
561+
})
562+
563+
u := url.URL{Scheme: "ws", Host: strings.Replace(ts.URL, "http://", "", 1), Path: "/ws"}
564+
c, _, _, err := ws.Dial(context.Background(), u.String())
565+
require.NoError(t, err)
566+
defer func() {
567+
require.NoError(t, c.Close())
568+
}()
569+
570+
go func() {
571+
for {
572+
select {
573+
case <-ticker.C:
574+
err = ws.WriteHeader(c, ws.Header{
575+
Fin: true,
576+
OpCode: ws.OpText,
577+
Masked: true,
578+
Length: int64(len(messageBytes)),
579+
})
580+
require.NoError(t, err)
581+
582+
n, err := c.Write(messageBytes)
583+
require.NoError(t, err)
584+
require.Equal(t, len(messageBytes), n)
585+
case <-done:
586+
ticker.Stop()
587+
return
588+
}
589+
}
590+
}()
591+
592+
time.Sleep(time.Millisecond * 50)
593+
594+
done <- true
595+
596+
time.Sleep(time.Millisecond * 5)
597+
598+
err = ws.WriteHeader(c, ws.Header{
599+
Fin: true,
600+
OpCode: ws.OpClose,
601+
Masked: true,
602+
Length: 0,
603+
})
604+
require.NoError(t, err)
605+
606+
time.Sleep(time.Millisecond * 5)
607+
608+
require.Equal(t, 0, ch.Count())
609+
}
610+
532611
func wsServer() (*httptest.Server, *Server) {
533612
wsServer := Start(context.Background())
534613
r := chi.NewRouter()

0 commit comments

Comments
 (0)