tstest/natlab: refactor, expose a Packet type.
HandlePacket and Inject now receive/take Packets. This is a handy container for the packet, and the attached Trace method can be used to print traces from custom packet handlers that integrate nicely with natlab's internal traces. Signed-off-by: David Anderson <danderson@tailscale.com>reviewable/pr546/r1
parent
5eedbcedd1
commit
b3d65ba943
|
@ -43,7 +43,7 @@ func (f *Firewall) timeNow() time.Time {
|
|||
return time.Now()
|
||||
}
|
||||
|
||||
func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict {
|
||||
func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.seen == nil {
|
||||
|
@ -52,25 +52,25 @@ func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPo
|
|||
|
||||
if inIf == f.TrustedInterface {
|
||||
sess := session{
|
||||
src: src,
|
||||
dst: dst,
|
||||
src: p.Src,
|
||||
dst: p.Dst,
|
||||
}
|
||||
f.seen[sess] = f.timeNow().Add(f.SessionTimeout)
|
||||
trace(p, "mach=%s iface=%s src=%s dst=%s firewall out ok", inIf.Machine().Name, inIf.name, src, dst)
|
||||
p.Trace("firewall out ok")
|
||||
return Continue
|
||||
} else {
|
||||
// reverse src and dst because the session table is from the
|
||||
// POV of outbound packets.
|
||||
sess := session{
|
||||
src: dst,
|
||||
dst: src,
|
||||
src: p.Dst,
|
||||
dst: p.Src,
|
||||
}
|
||||
now := f.timeNow()
|
||||
if now.After(f.seen[sess]) {
|
||||
trace(p, "mach=%s iface=%s src=%s dst=%s firewall drop", inIf.Machine().Name, inIf.name, src, dst)
|
||||
p.Trace("firewall drop")
|
||||
return Drop
|
||||
}
|
||||
trace(p, "mach=%s iface=%s src=%s dst=%s firewall in ok", inIf.Machine().Name, inIf.name, src, dst)
|
||||
p.Trace("firewall in ok")
|
||||
return Continue
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,21 +30,49 @@ import (
|
|||
|
||||
var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE"))
|
||||
|
||||
func trace(p []byte, msg string, args ...interface{}) {
|
||||
// Packet represents a UDP packet flowing through the virtual network.
|
||||
type Packet struct {
|
||||
Src, Dst netaddr.IPPort
|
||||
Payload []byte
|
||||
|
||||
// Prefix set by various internal methods of natlab, to locate
|
||||
// where in the network a trace occured.
|
||||
locator string
|
||||
}
|
||||
|
||||
// Clone returns a copy of p that shares nothing with p.
|
||||
func (p *Packet) Clone() *Packet {
|
||||
return &Packet{
|
||||
Src: p.Src,
|
||||
Dst: p.Dst,
|
||||
Payload: append([]byte(nil), p.Payload...),
|
||||
locator: p.locator,
|
||||
}
|
||||
}
|
||||
|
||||
// short returns a short identifier for a packet payload,
|
||||
// suitable for printing trace information.
|
||||
func (p *Packet) short() string {
|
||||
s := sha256.Sum256(p.Payload)
|
||||
payload := base64.RawStdEncoding.EncodeToString(s[:])[:2]
|
||||
|
||||
s = sha256.Sum256([]byte(p.Src.String() + "_" + p.Dst.String()))
|
||||
tuple := base64.RawStdEncoding.EncodeToString(s[:])[:2]
|
||||
|
||||
return fmt.Sprintf("%s/%s", payload, tuple)
|
||||
}
|
||||
|
||||
func (p *Packet) Trace(msg string, args ...interface{}) {
|
||||
if !traceOn {
|
||||
return
|
||||
}
|
||||
id := packetShort(p)
|
||||
as := []interface{}{id}
|
||||
as = append(as, args...)
|
||||
fmt.Fprintf(os.Stderr, "[%s] "+msg+"\n", as...)
|
||||
allArgs := []interface{}{p.short(), p.locator, p.Src, p.Dst}
|
||||
allArgs = append(allArgs, args...)
|
||||
fmt.Fprintf(os.Stderr, "[%s]%s src=%s dst=%s "+msg+"\n", allArgs...)
|
||||
}
|
||||
|
||||
// packetShort returns a short identifier for a packet payload,
|
||||
// suitable for pritning trace information.
|
||||
func packetShort(p []byte) string {
|
||||
s := sha256.Sum256(p)
|
||||
return base64.RawStdEncoding.EncodeToString(s[:])[:4]
|
||||
func (p *Packet) setLocator(msg string, args ...interface{}) {
|
||||
p.locator = fmt.Sprintf(" "+msg, args...)
|
||||
}
|
||||
|
||||
func mustPrefix(s string) netaddr.IPPrefix {
|
||||
|
@ -79,6 +107,9 @@ type Network struct {
|
|||
func (n *Network) SetDefaultGateway(gwIf *Interface) {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
if gwIf.net != n {
|
||||
panic(fmt.Sprintf("can't set if=%s as net=%s's default gw, if not connected to net", gwIf.name, gwIf.net.Name))
|
||||
}
|
||||
n.defaultGW = gwIf
|
||||
}
|
||||
|
||||
|
@ -139,24 +170,25 @@ func addOne(a *[16]byte, index int) {
|
|||
}
|
||||
}
|
||||
|
||||
func (n *Network) write(p []byte, dst, src netaddr.IPPort) (num int, err error) {
|
||||
func (n *Network) write(p *Packet) (num int, err error) {
|
||||
p.setLocator("net=%s", n.Name)
|
||||
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
iface, ok := n.machine[dst.IP]
|
||||
iface, ok := n.machine[p.Dst.IP]
|
||||
if !ok {
|
||||
if n.defaultGW == nil {
|
||||
trace(p, "net=%s dropped, no route to %v", n.Name, dst.IP)
|
||||
return len(p), nil
|
||||
p.Trace("no route to %v", p.Dst.IP)
|
||||
return len(p.Payload), nil
|
||||
}
|
||||
iface = n.defaultGW
|
||||
}
|
||||
|
||||
// Pretend it went across the network. Make a copy so nobody
|
||||
// can later mess with caller's memory.
|
||||
trace(p, "net=%s src=%v dst=%v -> mach=%s iface=%s", n.Name, src, dst, iface.machine.Name, iface.name)
|
||||
pcopy := append([]byte(nil), p...)
|
||||
go iface.machine.deliverIncomingPacket(pcopy, iface, dst, src)
|
||||
return len(p), nil
|
||||
p.Trace("-> mach=%s if=%s", iface.machine.Name, iface.name)
|
||||
go iface.machine.deliverIncomingPacket(p, iface)
|
||||
return len(p.Payload), nil
|
||||
}
|
||||
|
||||
type Interface struct {
|
||||
|
@ -235,7 +267,7 @@ func (v PacketVerdict) String() string {
|
|||
}
|
||||
|
||||
// A PacketHandler is a function that can process packets.
|
||||
type PacketHandler func(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict
|
||||
type PacketHandler func(p *Packet, inIf *Interface) PacketVerdict
|
||||
|
||||
// A Machine is a representation of an operating system's network
|
||||
// stack. It has a network routing table and can have multiple
|
||||
|
@ -250,8 +282,9 @@ type Machine struct {
|
|||
// every packet this Machine receives. Returns a verdict for how
|
||||
// the packet should continue to be handled (or not).
|
||||
//
|
||||
// This can be used to implement things like stateful firewalls
|
||||
// and NAT boxes.
|
||||
// The packet provided to HandlePacket can safely be mutated and
|
||||
// Inject()ed if desired. This can be used to implement things
|
||||
// like stateful firewalls and NAT boxes.
|
||||
HandlePacket PacketHandler
|
||||
|
||||
mu sync.Mutex
|
||||
|
@ -264,18 +297,22 @@ type Machine struct {
|
|||
|
||||
// Inject transmits p from src to dst, without the need for a local socket.
|
||||
// It's useful for implementing e.g. NAT boxes that need to mangle IPs.
|
||||
func (m *Machine) Inject(p []byte, dst, src netaddr.IPPort) error {
|
||||
trace(p, "mach=%s src=%s dst=%s packet injected", m.Name, src, dst)
|
||||
_, err := m.writePacket(p, dst, src)
|
||||
func (m *Machine) Inject(p *Packet) error {
|
||||
p = p.Clone()
|
||||
p.setLocator("mach=%s", m.Name)
|
||||
p.Trace("Machine.Inject")
|
||||
_, err := m.writePacket(p)
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src netaddr.IPPort) {
|
||||
func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) {
|
||||
p.setLocator("mach=%s if=%s", m.Name, iface.name)
|
||||
// TODO: can't hold lock while handling packet. This is safe as
|
||||
// long as you set HandlePacket before traffic starts flowing.
|
||||
if m.HandlePacket != nil {
|
||||
verdict := m.HandlePacket(p, iface, dst, src)
|
||||
trace(p, "mach=%s src=%v dst=%v packethandler verdict=%s", m.Name, src, dst, verdict)
|
||||
p.Trace("Machine.HandlePacket")
|
||||
verdict := m.HandlePacket(p.Clone(), iface)
|
||||
p.Trace("Machine.HandlePacket verdict=%s", verdict)
|
||||
if verdict == Drop {
|
||||
// Custom packet handler ate the packet, we're done.
|
||||
return
|
||||
|
@ -286,13 +323,13 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net
|
|||
defer m.mu.Unlock()
|
||||
|
||||
conns := m.conns4
|
||||
if dst.IP.Is6() {
|
||||
if p.Dst.IP.Is6() {
|
||||
conns = m.conns6
|
||||
}
|
||||
possibleDsts := []netaddr.IPPort{
|
||||
dst,
|
||||
netaddr.IPPort{IP: v6unspec, Port: dst.Port},
|
||||
netaddr.IPPort{IP: v4unspec, Port: dst.Port},
|
||||
p.Dst,
|
||||
netaddr.IPPort{IP: v6unspec, Port: p.Dst.Port},
|
||||
netaddr.IPPort{IP: v4unspec, Port: p.Dst.Port},
|
||||
}
|
||||
for _, dest := range possibleDsts {
|
||||
c, ok := conns[dest]
|
||||
|
@ -300,15 +337,15 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net
|
|||
continue
|
||||
}
|
||||
select {
|
||||
case c.in <- incomingPacket{src: src, p: p}:
|
||||
trace(p, "mach=%s src=%v dst=%v queued to conn", m.Name, src, dst)
|
||||
case c.in <- p:
|
||||
p.Trace("queued to conn")
|
||||
default:
|
||||
trace(p, "mach=%s src=%v dst=%v dropped, queue overflow", m.Name, src, dst)
|
||||
p.Trace("dropped, queue overflow")
|
||||
// Queue overflow. Just drop it.
|
||||
}
|
||||
return
|
||||
}
|
||||
trace(p, "mach=%s src=%v dst=%v dropped, no listening conn", m.Name, src, dst)
|
||||
p.Trace("dropped, no listening conn")
|
||||
}
|
||||
|
||||
func unspecOf(ip netaddr.IP) netaddr.IP {
|
||||
|
@ -378,38 +415,43 @@ var (
|
|||
v6unspec = netaddr.IPv6Unspecified()
|
||||
)
|
||||
|
||||
func (m *Machine) writePacket(p []byte, dst, src netaddr.IPPort) (n int, err error) {
|
||||
iface, err := m.interfaceForIP(dst.IP)
|
||||
func (m *Machine) writePacket(p *Packet) (n int, err error) {
|
||||
p.setLocator("mach=%s", m.Name)
|
||||
|
||||
iface, err := m.interfaceForIP(p.Dst.IP)
|
||||
if err != nil {
|
||||
trace(p, "%v", err)
|
||||
p.Trace("%v", err)
|
||||
return 0, err
|
||||
}
|
||||
origSrcIP := src.IP
|
||||
origSrcIP := p.Src.IP
|
||||
switch {
|
||||
case src.IP == v4unspec:
|
||||
src.IP = iface.V4()
|
||||
case src.IP == v6unspec:
|
||||
case p.Src.IP == v4unspec:
|
||||
p.Trace("assigning srcIP=%s", iface.V4())
|
||||
p.Src.IP = iface.V4()
|
||||
case p.Src.IP == v6unspec:
|
||||
// v6unspec in Go means "any src, but match address families"
|
||||
if dst.IP.Is6() {
|
||||
src.IP = iface.V6()
|
||||
} else if dst.IP.Is4() {
|
||||
src.IP = iface.V4()
|
||||
if p.Dst.IP.Is6() {
|
||||
p.Trace("assigning srcIP=%s", iface.V6())
|
||||
p.Src.IP = iface.V6()
|
||||
} else if p.Dst.IP.Is4() {
|
||||
p.Trace("assigning srcIP=%s", iface.V4())
|
||||
p.Src.IP = iface.V4()
|
||||
}
|
||||
default:
|
||||
if !iface.Contains(src.IP) {
|
||||
err := fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface)
|
||||
trace(p, "%v", err)
|
||||
if !iface.Contains(p.Src.IP) {
|
||||
err := fmt.Errorf("can't send to %v with src %v on interface %v", p.Dst.IP, p.Src.IP, iface)
|
||||
p.Trace("%v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if src.IP.IsZero() {
|
||||
if p.Src.IP.IsZero() {
|
||||
err := fmt.Errorf("no matching address for address family for %v", origSrcIP)
|
||||
trace(p, "%v", err)
|
||||
p.Trace("%v", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
trace(p, "mach=%s src=%s dst=%s -> net=%s", m.Name, src, dst, iface.net.Name)
|
||||
return iface.net.write(p, dst, src)
|
||||
p.Trace("-> net=%s if=%s", iface.net.Name, iface)
|
||||
return iface.net.write(p)
|
||||
}
|
||||
|
||||
func (m *Machine) interfaceForIP(ip netaddr.IP) (*Interface, error) {
|
||||
|
@ -552,7 +594,7 @@ func (m *Machine) ListenPacket(ctx context.Context, network, address string) (ne
|
|||
m: m,
|
||||
fam: fam,
|
||||
ipp: ipp,
|
||||
in: make(chan incomingPacket, 100), // arbitrary
|
||||
in: make(chan *Packet, 100), // arbitrary
|
||||
}
|
||||
switch c.fam {
|
||||
case 0:
|
||||
|
@ -585,12 +627,7 @@ type conn struct {
|
|||
closed bool
|
||||
readDeadline time.Time
|
||||
activeReads map[*activeRead]bool
|
||||
in chan incomingPacket
|
||||
}
|
||||
|
||||
type incomingPacket struct {
|
||||
p []byte
|
||||
src netaddr.IPPort
|
||||
in chan *Packet
|
||||
}
|
||||
|
||||
type activeRead struct {
|
||||
|
@ -669,9 +706,9 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|||
|
||||
select {
|
||||
case pkt := <-c.in:
|
||||
n = copy(p, pkt.p)
|
||||
trace(pkt.p, "mach=%s src=%s PacketConn.ReadFrom", c.m.Name, pkt.src)
|
||||
return n, pkt.src.UDPAddr(), nil
|
||||
n = copy(p, pkt.Payload)
|
||||
pkt.Trace("PacketConn.ReadFrom")
|
||||
return n, pkt.Src.UDPAddr(), nil
|
||||
case <-ctx.Done():
|
||||
return 0, nil, context.DeadlineExceeded
|
||||
}
|
||||
|
@ -682,7 +719,14 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||
if err != nil {
|
||||
return 0, fmt.Errorf("bogus addr %T %q", addr, addr.String())
|
||||
}
|
||||
return c.m.writePacket(p, ipp, c.ipp)
|
||||
pkt := &Packet{
|
||||
Src: c.ipp,
|
||||
Dst: ipp,
|
||||
Payload: append([]byte(nil), p...),
|
||||
}
|
||||
pkt.setLocator("mach=%s", c.m.Name)
|
||||
pkt.Trace("PacketConn.WriteTo")
|
||||
return c.m.writePacket(pkt)
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
|
|
|
@ -175,15 +175,17 @@ func TestPacketHandler(t *testing.T) {
|
|||
// port remappings or any other things that NATs usually to. But
|
||||
// it works as a demonstrator for a single client behind the NAT,
|
||||
// where the NAT box itself doesn't also make PacketConns.
|
||||
nat.HandlePacket = func(p []byte, iface *Interface, dst, src netaddr.IPPort) PacketVerdict {
|
||||
nat.HandlePacket = func(p *Packet, iface *Interface) PacketVerdict {
|
||||
switch {
|
||||
case dst.IP.Is6():
|
||||
case p.Dst.IP.Is6():
|
||||
return Continue // no NAT for ipv6
|
||||
case iface == ifNATLAN && src.IP == ifClient.V4():
|
||||
nat.Inject(p, dst, netaddr.IPPort{IP: ifNATWAN.V4(), Port: src.Port})
|
||||
case iface == ifNATLAN && p.Src.IP == ifClient.V4():
|
||||
p.Src.IP = ifNATWAN.V4()
|
||||
nat.Inject(p)
|
||||
return Drop
|
||||
case iface == ifNATWAN && dst.IP == ifNATWAN.V4():
|
||||
nat.Inject(p, netaddr.IPPort{IP: ifClient.V4(), Port: dst.Port}, src)
|
||||
case iface == ifNATWAN && p.Dst.IP == ifNATWAN.V4():
|
||||
p.Dst.IP = ifClient.V4()
|
||||
nat.Inject(p)
|
||||
return Drop
|
||||
default:
|
||||
return Continue
|
||||
|
@ -257,7 +259,12 @@ func TestFirewall(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
clock.Advance(time.Second)
|
||||
got := f.HandlePacket(nil, test.iface, test.dst, test.src)
|
||||
p := &Packet{
|
||||
Src: test.src,
|
||||
Dst: test.dst,
|
||||
Payload: []byte{},
|
||||
}
|
||||
got := f.HandlePacket(p, test.iface)
|
||||
if got != test.want {
|
||||
t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue