From b3d65ba9434efd20a60ce7bef60cbef73d6a3b3b Mon Sep 17 00:00:00 2001 From: David Anderson Date: Sat, 11 Jul 2020 06:31:46 +0000 Subject: [PATCH] tstest/natlab: refactor, expose a Packet type. HandlePacket and Inject now receive/take Packets. This is a handy container for the packet, and the attached Trace method can be used to print traces from custom packet handlers that integrate nicely with natlab's internal traces. Signed-off-by: David Anderson --- tstest/natlab/firewall.go | 16 ++-- tstest/natlab/natlab.go | 172 ++++++++++++++++++++++------------- tstest/natlab/natlab_test.go | 21 +++-- 3 files changed, 130 insertions(+), 79 deletions(-) diff --git a/tstest/natlab/firewall.go b/tstest/natlab/firewall.go index 9a6140e8e..1030e0e27 100644 --- a/tstest/natlab/firewall.go +++ b/tstest/natlab/firewall.go @@ -43,7 +43,7 @@ func (f *Firewall) timeNow() time.Time { return time.Now() } -func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict { +func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict { f.mu.Lock() defer f.mu.Unlock() if f.seen == nil { @@ -52,25 +52,25 @@ func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPo if inIf == f.TrustedInterface { sess := session{ - src: src, - dst: dst, + src: p.Src, + dst: p.Dst, } f.seen[sess] = f.timeNow().Add(f.SessionTimeout) - trace(p, "mach=%s iface=%s src=%s dst=%s firewall out ok", inIf.Machine().Name, inIf.name, src, dst) + p.Trace("firewall out ok") return Continue } else { // reverse src and dst because the session table is from the // POV of outbound packets. sess := session{ - src: dst, - dst: src, + src: p.Dst, + dst: p.Src, } now := f.timeNow() if now.After(f.seen[sess]) { - trace(p, "mach=%s iface=%s src=%s dst=%s firewall drop", inIf.Machine().Name, inIf.name, src, dst) + p.Trace("firewall drop") return Drop } - trace(p, "mach=%s iface=%s src=%s dst=%s firewall in ok", inIf.Machine().Name, inIf.name, src, dst) + p.Trace("firewall in ok") return Continue } } diff --git a/tstest/natlab/natlab.go b/tstest/natlab/natlab.go index 514829e8b..b4ea783ed 100644 --- a/tstest/natlab/natlab.go +++ b/tstest/natlab/natlab.go @@ -30,21 +30,49 @@ import ( var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE")) -func trace(p []byte, msg string, args ...interface{}) { +// Packet represents a UDP packet flowing through the virtual network. +type Packet struct { + Src, Dst netaddr.IPPort + Payload []byte + + // Prefix set by various internal methods of natlab, to locate + // where in the network a trace occured. + locator string +} + +// Clone returns a copy of p that shares nothing with p. +func (p *Packet) Clone() *Packet { + return &Packet{ + Src: p.Src, + Dst: p.Dst, + Payload: append([]byte(nil), p.Payload...), + locator: p.locator, + } +} + +// short returns a short identifier for a packet payload, +// suitable for printing trace information. +func (p *Packet) short() string { + s := sha256.Sum256(p.Payload) + payload := base64.RawStdEncoding.EncodeToString(s[:])[:2] + + s = sha256.Sum256([]byte(p.Src.String() + "_" + p.Dst.String())) + tuple := base64.RawStdEncoding.EncodeToString(s[:])[:2] + + return fmt.Sprintf("%s/%s", payload, tuple) +} + +func (p *Packet) Trace(msg string, args ...interface{}) { if !traceOn { return } - id := packetShort(p) - as := []interface{}{id} - as = append(as, args...) - fmt.Fprintf(os.Stderr, "[%s] "+msg+"\n", as...) + allArgs := []interface{}{p.short(), p.locator, p.Src, p.Dst} + allArgs = append(allArgs, args...) + fmt.Fprintf(os.Stderr, "[%s]%s src=%s dst=%s "+msg+"\n", allArgs...) } -// packetShort returns a short identifier for a packet payload, -// suitable for pritning trace information. -func packetShort(p []byte) string { - s := sha256.Sum256(p) - return base64.RawStdEncoding.EncodeToString(s[:])[:4] +func (p *Packet) setLocator(msg string, args ...interface{}) { + p.locator = fmt.Sprintf(" "+msg, args...) } func mustPrefix(s string) netaddr.IPPrefix { @@ -79,6 +107,9 @@ type Network struct { func (n *Network) SetDefaultGateway(gwIf *Interface) { n.mu.Lock() defer n.mu.Unlock() + if gwIf.net != n { + panic(fmt.Sprintf("can't set if=%s as net=%s's default gw, if not connected to net", gwIf.name, gwIf.net.Name)) + } n.defaultGW = gwIf } @@ -139,24 +170,25 @@ func addOne(a *[16]byte, index int) { } } -func (n *Network) write(p []byte, dst, src netaddr.IPPort) (num int, err error) { +func (n *Network) write(p *Packet) (num int, err error) { + p.setLocator("net=%s", n.Name) + n.mu.Lock() defer n.mu.Unlock() - iface, ok := n.machine[dst.IP] + iface, ok := n.machine[p.Dst.IP] if !ok { if n.defaultGW == nil { - trace(p, "net=%s dropped, no route to %v", n.Name, dst.IP) - return len(p), nil + p.Trace("no route to %v", p.Dst.IP) + return len(p.Payload), nil } iface = n.defaultGW } // Pretend it went across the network. Make a copy so nobody // can later mess with caller's memory. - trace(p, "net=%s src=%v dst=%v -> mach=%s iface=%s", n.Name, src, dst, iface.machine.Name, iface.name) - pcopy := append([]byte(nil), p...) - go iface.machine.deliverIncomingPacket(pcopy, iface, dst, src) - return len(p), nil + p.Trace("-> mach=%s if=%s", iface.machine.Name, iface.name) + go iface.machine.deliverIncomingPacket(p, iface) + return len(p.Payload), nil } type Interface struct { @@ -235,7 +267,7 @@ func (v PacketVerdict) String() string { } // A PacketHandler is a function that can process packets. -type PacketHandler func(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict +type PacketHandler func(p *Packet, inIf *Interface) PacketVerdict // A Machine is a representation of an operating system's network // stack. It has a network routing table and can have multiple @@ -250,8 +282,9 @@ type Machine struct { // every packet this Machine receives. Returns a verdict for how // the packet should continue to be handled (or not). // - // This can be used to implement things like stateful firewalls - // and NAT boxes. + // The packet provided to HandlePacket can safely be mutated and + // Inject()ed if desired. This can be used to implement things + // like stateful firewalls and NAT boxes. HandlePacket PacketHandler mu sync.Mutex @@ -264,18 +297,22 @@ type Machine struct { // Inject transmits p from src to dst, without the need for a local socket. // It's useful for implementing e.g. NAT boxes that need to mangle IPs. -func (m *Machine) Inject(p []byte, dst, src netaddr.IPPort) error { - trace(p, "mach=%s src=%s dst=%s packet injected", m.Name, src, dst) - _, err := m.writePacket(p, dst, src) +func (m *Machine) Inject(p *Packet) error { + p = p.Clone() + p.setLocator("mach=%s", m.Name) + p.Trace("Machine.Inject") + _, err := m.writePacket(p) return err } -func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src netaddr.IPPort) { +func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) { + p.setLocator("mach=%s if=%s", m.Name, iface.name) // TODO: can't hold lock while handling packet. This is safe as // long as you set HandlePacket before traffic starts flowing. if m.HandlePacket != nil { - verdict := m.HandlePacket(p, iface, dst, src) - trace(p, "mach=%s src=%v dst=%v packethandler verdict=%s", m.Name, src, dst, verdict) + p.Trace("Machine.HandlePacket") + verdict := m.HandlePacket(p.Clone(), iface) + p.Trace("Machine.HandlePacket verdict=%s", verdict) if verdict == Drop { // Custom packet handler ate the packet, we're done. return @@ -286,13 +323,13 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net defer m.mu.Unlock() conns := m.conns4 - if dst.IP.Is6() { + if p.Dst.IP.Is6() { conns = m.conns6 } possibleDsts := []netaddr.IPPort{ - dst, - netaddr.IPPort{IP: v6unspec, Port: dst.Port}, - netaddr.IPPort{IP: v4unspec, Port: dst.Port}, + p.Dst, + netaddr.IPPort{IP: v6unspec, Port: p.Dst.Port}, + netaddr.IPPort{IP: v4unspec, Port: p.Dst.Port}, } for _, dest := range possibleDsts { c, ok := conns[dest] @@ -300,15 +337,15 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net continue } select { - case c.in <- incomingPacket{src: src, p: p}: - trace(p, "mach=%s src=%v dst=%v queued to conn", m.Name, src, dst) + case c.in <- p: + p.Trace("queued to conn") default: - trace(p, "mach=%s src=%v dst=%v dropped, queue overflow", m.Name, src, dst) + p.Trace("dropped, queue overflow") // Queue overflow. Just drop it. } return } - trace(p, "mach=%s src=%v dst=%v dropped, no listening conn", m.Name, src, dst) + p.Trace("dropped, no listening conn") } func unspecOf(ip netaddr.IP) netaddr.IP { @@ -378,38 +415,43 @@ var ( v6unspec = netaddr.IPv6Unspecified() ) -func (m *Machine) writePacket(p []byte, dst, src netaddr.IPPort) (n int, err error) { - iface, err := m.interfaceForIP(dst.IP) +func (m *Machine) writePacket(p *Packet) (n int, err error) { + p.setLocator("mach=%s", m.Name) + + iface, err := m.interfaceForIP(p.Dst.IP) if err != nil { - trace(p, "%v", err) + p.Trace("%v", err) return 0, err } - origSrcIP := src.IP + origSrcIP := p.Src.IP switch { - case src.IP == v4unspec: - src.IP = iface.V4() - case src.IP == v6unspec: + case p.Src.IP == v4unspec: + p.Trace("assigning srcIP=%s", iface.V4()) + p.Src.IP = iface.V4() + case p.Src.IP == v6unspec: // v6unspec in Go means "any src, but match address families" - if dst.IP.Is6() { - src.IP = iface.V6() - } else if dst.IP.Is4() { - src.IP = iface.V4() + if p.Dst.IP.Is6() { + p.Trace("assigning srcIP=%s", iface.V6()) + p.Src.IP = iface.V6() + } else if p.Dst.IP.Is4() { + p.Trace("assigning srcIP=%s", iface.V4()) + p.Src.IP = iface.V4() } default: - if !iface.Contains(src.IP) { - err := fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface) - trace(p, "%v", err) + if !iface.Contains(p.Src.IP) { + err := fmt.Errorf("can't send to %v with src %v on interface %v", p.Dst.IP, p.Src.IP, iface) + p.Trace("%v", err) return 0, err } } - if src.IP.IsZero() { + if p.Src.IP.IsZero() { err := fmt.Errorf("no matching address for address family for %v", origSrcIP) - trace(p, "%v", err) + p.Trace("%v", err) return 0, err } - trace(p, "mach=%s src=%s dst=%s -> net=%s", m.Name, src, dst, iface.net.Name) - return iface.net.write(p, dst, src) + p.Trace("-> net=%s if=%s", iface.net.Name, iface) + return iface.net.write(p) } func (m *Machine) interfaceForIP(ip netaddr.IP) (*Interface, error) { @@ -552,7 +594,7 @@ func (m *Machine) ListenPacket(ctx context.Context, network, address string) (ne m: m, fam: fam, ipp: ipp, - in: make(chan incomingPacket, 100), // arbitrary + in: make(chan *Packet, 100), // arbitrary } switch c.fam { case 0: @@ -585,12 +627,7 @@ type conn struct { closed bool readDeadline time.Time activeReads map[*activeRead]bool - in chan incomingPacket -} - -type incomingPacket struct { - p []byte - src netaddr.IPPort + in chan *Packet } type activeRead struct { @@ -669,9 +706,9 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { select { case pkt := <-c.in: - n = copy(p, pkt.p) - trace(pkt.p, "mach=%s src=%s PacketConn.ReadFrom", c.m.Name, pkt.src) - return n, pkt.src.UDPAddr(), nil + n = copy(p, pkt.Payload) + pkt.Trace("PacketConn.ReadFrom") + return n, pkt.Src.UDPAddr(), nil case <-ctx.Done(): return 0, nil, context.DeadlineExceeded } @@ -682,7 +719,14 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) { if err != nil { return 0, fmt.Errorf("bogus addr %T %q", addr, addr.String()) } - return c.m.writePacket(p, ipp, c.ipp) + pkt := &Packet{ + Src: c.ipp, + Dst: ipp, + Payload: append([]byte(nil), p...), + } + pkt.setLocator("mach=%s", c.m.Name) + pkt.Trace("PacketConn.WriteTo") + return c.m.writePacket(pkt) } func (c *conn) SetDeadline(t time.Time) error { diff --git a/tstest/natlab/natlab_test.go b/tstest/natlab/natlab_test.go index 6da0b29f1..10e335faa 100644 --- a/tstest/natlab/natlab_test.go +++ b/tstest/natlab/natlab_test.go @@ -175,15 +175,17 @@ func TestPacketHandler(t *testing.T) { // port remappings or any other things that NATs usually to. But // it works as a demonstrator for a single client behind the NAT, // where the NAT box itself doesn't also make PacketConns. - nat.HandlePacket = func(p []byte, iface *Interface, dst, src netaddr.IPPort) PacketVerdict { + nat.HandlePacket = func(p *Packet, iface *Interface) PacketVerdict { switch { - case dst.IP.Is6(): + case p.Dst.IP.Is6(): return Continue // no NAT for ipv6 - case iface == ifNATLAN && src.IP == ifClient.V4(): - nat.Inject(p, dst, netaddr.IPPort{IP: ifNATWAN.V4(), Port: src.Port}) + case iface == ifNATLAN && p.Src.IP == ifClient.V4(): + p.Src.IP = ifNATWAN.V4() + nat.Inject(p) return Drop - case iface == ifNATWAN && dst.IP == ifNATWAN.V4(): - nat.Inject(p, netaddr.IPPort{IP: ifClient.V4(), Port: dst.Port}, src) + case iface == ifNATWAN && p.Dst.IP == ifNATWAN.V4(): + p.Dst.IP = ifClient.V4() + nat.Inject(p) return Drop default: return Continue @@ -257,7 +259,12 @@ func TestFirewall(t *testing.T) { for _, test := range tests { clock.Advance(time.Second) - got := f.HandlePacket(nil, test.iface, test.dst, test.src) + p := &Packet{ + Src: test.src, + Dst: test.dst, + Payload: []byte{}, + } + got := f.HandlePacket(p, test.iface) if got != test.want { t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want) }