Skip to content

Commit 7329bb9

Browse files
fix(traffic-tracking): enable connection wrapping across all inbound handlers
1 parent 9a74dbd commit 7329bb9

8 files changed

Lines changed: 168 additions & 7 deletions

File tree

app/connectiontracker/tracker.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import (
99
"sync/atomic"
1010
"time"
1111

12+
B "github.com/sagernet/sing/common/buf"
13+
M "github.com/sagernet/sing/common/metadata"
14+
N "github.com/sagernet/sing/common/network"
1215
"github.com/xtls/xray-core/transport/internet/stat"
1316
)
1417

@@ -314,3 +317,35 @@ func (c *TrackedConn) Write(b []byte) (int, error) {
314317
func WrapConn(conn stat.Connection, entry *ConnEntry) stat.Connection {
315318
return &TrackedConn{Connection: conn, entry: entry}
316319
}
320+
321+
// TrackedPacketConn wraps an N.PacketConn (UDP) and records per-connection
322+
// traffic counters into the associated ConnEntry.
323+
type TrackedPacketConn struct {
324+
N.PacketConn
325+
entry *ConnEntry
326+
}
327+
328+
func (c *TrackedPacketConn) ReadPacket(buffer *B.Buffer) (M.Socksaddr, error) {
329+
addr, err := c.PacketConn.ReadPacket(buffer)
330+
if err == nil && buffer.Len() > 0 {
331+
atomic.AddInt64(&c.entry.uplink, int64(buffer.Len()))
332+
atomic.StoreInt64(&c.entry.lastActivity, time.Now().UnixNano())
333+
}
334+
return addr, err
335+
}
336+
337+
func (c *TrackedPacketConn) WritePacket(buffer *B.Buffer, destination M.Socksaddr) error {
338+
err := c.PacketConn.WritePacket(buffer, destination)
339+
if err == nil && buffer.Len() > 0 {
340+
atomic.AddInt64(&c.entry.downlink, int64(buffer.Len()))
341+
atomic.StoreInt64(&c.entry.lastActivity, time.Now().UnixNano())
342+
}
343+
return err
344+
}
345+
346+
// WrapPacketConn wraps a UDP PacketConn so that every ReadPacket and WritePacket
347+
// updates the traffic counters in entry. Call after RegisterWithMeta to enable
348+
// byte-level tracking for UDP connections.
349+
func WrapPacketConn(conn N.PacketConn, entry *ConnEntry) N.PacketConn {
350+
return &TrackedPacketConn{PacketConn: conn, entry: entry}
351+
}

app/connectiontracker/tracker_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"testing"
88
"time"
99

10+
B "github.com/sagernet/sing/common/buf"
11+
M "github.com/sagernet/sing/common/metadata"
1012
"github.com/xtls/xray-core/app/connectiontracker"
1113
)
1214

@@ -370,3 +372,120 @@ func TestWrapConnUpdatesLastActivity(t *testing.T) {
370372
t.Errorf("LastActivity not updated: before=%v after=%v", before, after)
371373
}
372374
}
375+
376+
// fakePacketConn is a minimal N.PacketConn for WrapPacketConn tests.
377+
type fakePacketConn struct {
378+
readPacketData *B.Buffer
379+
readPacketErr error
380+
writePacketErr error
381+
}
382+
383+
func (f *fakePacketConn) ReadPacket(buffer *B.Buffer) (M.Socksaddr, error) {
384+
if f.readPacketErr != nil {
385+
return M.Socksaddr{}, f.readPacketErr
386+
}
387+
if f.readPacketData != nil {
388+
buffer.Write(f.readPacketData.Bytes())
389+
}
390+
return M.Socksaddr{}, nil
391+
}
392+
393+
func (f *fakePacketConn) WritePacket(buffer *B.Buffer, _ M.Socksaddr) error {
394+
return f.writePacketErr
395+
}
396+
397+
func (f *fakePacketConn) Close() error {
398+
return nil
399+
}
400+
401+
func (f *fakePacketConn) LocalAddr() net.Addr {
402+
return nil
403+
}
404+
405+
func (f *fakePacketConn) SetDeadline(_ time.Time) error {
406+
return nil
407+
}
408+
409+
func (f *fakePacketConn) SetReadDeadline(_ time.Time) error {
410+
return nil
411+
}
412+
413+
func (f *fakePacketConn) SetWriteDeadline(_ time.Time) error {
414+
return nil
415+
}
416+
417+
func TestWrapPacketConnCountsUplinkOnReadPacket(t *testing.T) {
418+
tracker := connectiontracker.New()
419+
_, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "")
420+
421+
data := B.New()
422+
data.Write([]byte("hello world"))
423+
424+
fpc := &fakePacketConn{readPacketData: data}
425+
wrapped := connectiontracker.WrapPacketConn(fpc, entry)
426+
427+
buf := B.New()
428+
defer buf.Release()
429+
if _, err := wrapped.ReadPacket(buf); err != nil {
430+
t.Fatal(err)
431+
}
432+
433+
conns := tracker.ListConnections()
434+
if len(conns) != 1 {
435+
t.Fatalf("expected 1 connection")
436+
}
437+
if conns[0].Uplink != 11 {
438+
t.Errorf("Uplink: got %d, want 11", conns[0].Uplink)
439+
}
440+
if conns[0].Downlink != 0 {
441+
t.Errorf("Downlink should be 0, got %d", conns[0].Downlink)
442+
}
443+
}
444+
445+
func TestWrapPacketConnCountsDownlinkOnWritePacket(t *testing.T) {
446+
tracker := connectiontracker.New()
447+
_, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "")
448+
449+
fpc := &fakePacketConn{}
450+
wrapped := connectiontracker.WrapPacketConn(fpc, entry)
451+
452+
buf := B.New()
453+
buf.Write([]byte("goodbye world"))
454+
if err := wrapped.WritePacket(buf, M.Socksaddr{}); err != nil {
455+
t.Fatal(err)
456+
}
457+
458+
conns := tracker.ListConnections()
459+
if len(conns) != 1 {
460+
t.Fatalf("expected 1 connection")
461+
}
462+
if conns[0].Downlink != 13 {
463+
t.Errorf("Downlink: got %d, want 13", conns[0].Downlink)
464+
}
465+
if conns[0].Uplink != 0 {
466+
t.Errorf("Uplink should be 0, got %d", conns[0].Uplink)
467+
}
468+
}
469+
470+
func TestWrapPacketConnUpdatesLastActivity(t *testing.T) {
471+
tracker := connectiontracker.New()
472+
_, entry := tracker.RegisterWithMeta("user@example.com", func() {}, "", "")
473+
474+
before := tracker.ListConnections()[0].LastActivity
475+
476+
time.Sleep(time.Millisecond)
477+
478+
data := B.New()
479+
data.Write([]byte("x"))
480+
481+
fpc := &fakePacketConn{readPacketData: data}
482+
wrapped := connectiontracker.WrapPacketConn(fpc, entry)
483+
buf := B.New()
484+
defer buf.Release()
485+
wrapped.ReadPacket(buf) //nolint:errcheck
486+
487+
after := tracker.ListConnections()[0].LastActivity
488+
if !after.After(before) {
489+
t.Errorf("LastActivity not updated: before=%v after=%v", before, after)
490+
}
491+
}

proxy/hysteria/server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
104104
ctx, connCancel := context.WithCancel(ctx)
105105
defer connCancel()
106106
if email := strings.ToLower(useremail); email != "" {
107-
connID, _ := s.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "hysteria")
107+
connID, connEntry := s.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "hysteria")
108108
defer s.connTracker.Unregister(email, connID)
109+
conn = connectiontracker.WrapConn(conn, connEntry)
109110
}
110111

111112
if _, ok := iConn.(*hysteria.InterUdpConn); ok {

proxy/shadowsocks/server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,9 @@ func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dis
241241
ctx, cancel := context.WithCancel(ctx)
242242
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
243243
if email := strings.ToLower(request.User.Email); email != "" {
244-
connID, _ := s.connTracker.RegisterWithMeta(email, cancel, inbound.Tag, "shadowsocks")
244+
connID, connEntry := s.connTracker.RegisterWithMeta(email, cancel, inbound.Tag, "shadowsocks")
245245
defer s.connTracker.Unregister(email, connID)
246+
conn = connectiontracker.WrapConn(conn, connEntry)
246247
}
247248

248249
ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)

proxy/shadowsocks_2022/inbound_multi.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,9 @@ func (i *MultiUserInbound) NewConnection(ctx context.Context, conn net.Conn, met
245245
ctx, connCancel := context.WithCancel(ctx)
246246
defer connCancel()
247247
if email := strings.ToLower(user.Email); email != "" {
248-
connID, _ := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022")
248+
connID, connEntry := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022")
249249
defer i.connTracker.Unregister(email, connID)
250+
conn = connectiontracker.WrapConn(conn, connEntry)
250251
}
251252

252253
dispatcher := session.DispatcherFromContext(ctx)
@@ -278,8 +279,9 @@ func (i *MultiUserInbound) NewPacketConnection(ctx context.Context, conn N.Packe
278279
ctx, connCancel := context.WithCancel(ctx)
279280
defer connCancel()
280281
if email := strings.ToLower(user.Email); email != "" {
281-
connID, _ := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022")
282+
connID, connEntry := i.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "shadowsocks-2022")
282283
defer i.connTracker.Unregister(email, connID)
284+
conn = connectiontracker.WrapPacketConn(conn, connEntry)
283285
}
284286

285287
dispatcher := session.DispatcherFromContext(ctx)

proxy/trojan/server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,9 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
235235
ctx, connCancel := context.WithCancel(ctx)
236236
defer connCancel()
237237
if email := strings.ToLower(user.Email); email != "" {
238-
connID, _ := s.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "trojan")
238+
connID, connEntry := s.connTracker.RegisterWithMeta(email, connCancel, inbound.Tag, "trojan")
239239
defer s.connTracker.Unregister(email, connID)
240+
conn = connectiontracker.WrapConn(conn, connEntry)
240241
}
241242

242243
if destination.Network == net.Network_UDP { // handle udp request

proxy/vless/inbound/inbound.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,8 +540,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
540540
if ib := session.InboundFromContext(ctx); ib != nil {
541541
inboundTag = ib.Tag
542542
}
543-
connID, _ := h.connTracker.RegisterWithMeta(email, connCancel, inboundTag, "vless")
543+
connID, connEntry := h.connTracker.RegisterWithMeta(email, connCancel, inboundTag, "vless")
544544
defer h.connTracker.Unregister(email, connID)
545+
connection = connectiontracker.WrapConn(connection, connEntry)
545546
}
546547

547548
inbound := session.InboundFromContext(ctx)

proxy/vmess/inbound/inbound.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s
281281
ctx, cancel := context.WithCancel(ctx)
282282
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
283283
if email := strings.ToLower(request.User.Email); email != "" {
284-
connID, _ := h.connTracker.RegisterWithMeta(email, cancel, inbound.Tag, "vmess")
284+
connID, connEntry := h.connTracker.RegisterWithMeta(email, cancel, inbound.Tag, "vmess")
285285
defer h.connTracker.Unregister(email, connID)
286+
connection = connectiontracker.WrapConn(connection, connEntry)
286287
}
287288

288289
ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)

0 commit comments

Comments
 (0)