|
9 | 9 | "github.com/gobwas/ws" |
10 | 10 | "github.com/gobwas/ws/wsutil" |
11 | 11 | "github.com/stretchr/testify/require" |
| 12 | + "log" |
12 | 13 | "math/rand" |
13 | 14 | "net" |
14 | 15 | "net/http" |
@@ -529,6 +530,84 @@ func TestServerProcessMessage(t *testing.T) { |
529 | 530 | } |
530 | 531 | } |
531 | 532 |
|
| 533 | +func TestServer_ConnectionClose(t *testing.T) { |
| 534 | + ts, wsServer := wsServer() |
| 535 | + defer func() { |
| 536 | + require.NoError(t, wsServer.Shutdown()) |
| 537 | + ts.Close() |
| 538 | + }() |
| 539 | + |
| 540 | + ch := wsServer.NewChannel("test-channel-add") |
| 541 | + msg := Message{ |
| 542 | + Name: "test", |
| 543 | + Data: []byte("Hello World"), |
| 544 | + } |
| 545 | + messageBytes, err := json.Marshal(msg) |
| 546 | + require.NoError(t, err) |
| 547 | + |
| 548 | + ticker := time.NewTicker(time.Millisecond * 1) |
| 549 | + done := make(chan bool, 1) |
| 550 | + |
| 551 | + wsServer.OnConnect(func(c *Conn) { |
| 552 | + ch.Add(c) |
| 553 | + require.Equal(t, 1, ch.Count(), "channel must contain only 1 connection") |
| 554 | + log.Print("Connected") |
| 555 | + }) |
| 556 | + wsServer.OnDisconnect(func(c *Conn) { |
| 557 | + log.Print("Disconnected") |
| 558 | + }) |
| 559 | + wsServer.On("test", func(c *Conn, msg *Message) { |
| 560 | + log.Printf("message: %s", msg.Name) |
| 561 | + }) |
| 562 | + |
| 563 | + u := url.URL{Scheme: "ws", Host: strings.Replace(ts.URL, "http://", "", 1), Path: "/ws"} |
| 564 | + c, _, _, err := ws.Dial(context.Background(), u.String()) |
| 565 | + require.NoError(t, err) |
| 566 | + defer func() { |
| 567 | + require.NoError(t, c.Close()) |
| 568 | + }() |
| 569 | + |
| 570 | + go func() { |
| 571 | + for { |
| 572 | + select { |
| 573 | + case <-ticker.C: |
| 574 | + err = ws.WriteHeader(c, ws.Header{ |
| 575 | + Fin: true, |
| 576 | + OpCode: ws.OpText, |
| 577 | + Masked: true, |
| 578 | + Length: int64(len(messageBytes)), |
| 579 | + }) |
| 580 | + require.NoError(t, err) |
| 581 | + |
| 582 | + n, err := c.Write(messageBytes) |
| 583 | + require.NoError(t, err) |
| 584 | + require.Equal(t, len(messageBytes), n) |
| 585 | + case <-done: |
| 586 | + ticker.Stop() |
| 587 | + return |
| 588 | + } |
| 589 | + } |
| 590 | + }() |
| 591 | + |
| 592 | + time.Sleep(time.Millisecond * 50) |
| 593 | + |
| 594 | + done <- true |
| 595 | + |
| 596 | + time.Sleep(time.Millisecond * 5) |
| 597 | + |
| 598 | + err = ws.WriteHeader(c, ws.Header{ |
| 599 | + Fin: true, |
| 600 | + OpCode: ws.OpClose, |
| 601 | + Masked: true, |
| 602 | + Length: 0, |
| 603 | + }) |
| 604 | + require.NoError(t, err) |
| 605 | + |
| 606 | + time.Sleep(time.Millisecond * 5) |
| 607 | + |
| 608 | + require.Equal(t, 0, ch.Count()) |
| 609 | +} |
| 610 | + |
532 | 611 | func wsServer() (*httptest.Server, *Server) { |
533 | 612 | wsServer := Start(context.Background()) |
534 | 613 | r := chi.NewRouter() |
|
0 commit comments