tstest/natlab: correctly handle dual-stacked PacketConns.
Adds a test with multiple networks, one of which is v4-only. Signed-off-by: David Anderson <danderson@tailscale.com>reviewable/pr519/r2
parent
771eb05bcb
commit
1d4f9852a7
|
@ -184,27 +184,33 @@ type Machine struct {
|
||||||
interfaces []*Interface
|
interfaces []*Interface
|
||||||
routes []routeEntry // sorted by longest prefix to shortest
|
routes []routeEntry // sorted by longest prefix to shortest
|
||||||
|
|
||||||
conns map[netaddr.IPPort]*conn
|
conns4 map[netaddr.IPPort]*conn
|
||||||
|
conns6 map[netaddr.IPPort]*conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) deliverIncomingPacket(p []byte, dst, src netaddr.IPPort) {
|
func (m *Machine) deliverIncomingPacket(p []byte, dst, src netaddr.IPPort) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
// TODO(danderson): check behavior of dual stack sockets
|
conns := m.conns4
|
||||||
c, ok := m.conns[dst]
|
if dst.IP.Is6() {
|
||||||
if !ok {
|
conns = m.conns6
|
||||||
dst = netaddr.IPPort{IP: unspecOf(dst.IP), Port: dst.Port}
|
|
||||||
c, ok = m.conns[dst]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
possibleDsts := []netaddr.IPPort{
|
||||||
select {
|
dst,
|
||||||
case c.in <- incomingPacket{src: src, p: p}:
|
netaddr.IPPort{IP: v6unspec, Port: dst.Port},
|
||||||
default:
|
netaddr.IPPort{IP: v4unspec, Port: dst.Port},
|
||||||
// Queue overflow. Just drop it.
|
}
|
||||||
|
for _, dst := range possibleDsts {
|
||||||
|
c, ok := conns[dst]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case c.in <- incomingPacket{src: src, p: p}:
|
||||||
|
default:
|
||||||
|
// Queue overflow. Just drop it.
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -284,7 +290,12 @@ func (m *Machine) writePacket(p []byte, dst, src netaddr.IPPort) (n int, err err
|
||||||
case src.IP == v4unspec:
|
case src.IP == v4unspec:
|
||||||
src.IP = iface.V4()
|
src.IP = iface.V4()
|
||||||
case src.IP == v6unspec:
|
case src.IP == v6unspec:
|
||||||
src.IP = iface.V6()
|
// v6unspec in Go means "any src, but match address families"
|
||||||
|
if dst.IP.Is6() {
|
||||||
|
src.IP = iface.V6()
|
||||||
|
} else if dst.IP.Is4() {
|
||||||
|
src.IP = iface.V4()
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
if !iface.Contains(src.IP) {
|
if !iface.Contains(src.IP) {
|
||||||
return 0, fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface)
|
return 0, fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface)
|
||||||
|
@ -321,59 +332,86 @@ func (m *Machine) hasv6() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) registerConn(c *conn) error {
|
func (m *Machine) registerConn4(c *conn) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
if _, ok := m.conns[c.ipp]; ok {
|
if c.ipp.IP.Is6() && c.ipp.IP != v6unspec {
|
||||||
|
return fmt.Errorf("registerConn4 got IPv6 %s", c.ipp)
|
||||||
|
}
|
||||||
|
if _, ok := m.conns4[c.ipp]; ok {
|
||||||
return fmt.Errorf("duplicate conn listening on %v", c.ipp)
|
return fmt.Errorf("duplicate conn listening on %v", c.ipp)
|
||||||
}
|
}
|
||||||
if m.conns == nil {
|
if m.conns4 == nil {
|
||||||
m.conns = map[netaddr.IPPort]*conn{}
|
m.conns4 = map[netaddr.IPPort]*conn{}
|
||||||
}
|
}
|
||||||
m.conns[c.ipp] = c
|
m.conns4[c.ipp] = c
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) unregisterConn(c *conn) {
|
func (m *Machine) unregisterConn4(c *conn) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
delete(m.conns, c.ipp)
|
delete(m.conns4, c.ipp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Machine) registerConn6(c *conn) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if c.ipp.IP.Is4() {
|
||||||
|
return fmt.Errorf("registerConn6 got IPv4 %s", c.ipp)
|
||||||
|
}
|
||||||
|
if _, ok := m.conns6[c.ipp]; ok {
|
||||||
|
return fmt.Errorf("duplicate conn listening on %v", c.ipp)
|
||||||
|
}
|
||||||
|
if m.conns6 == nil {
|
||||||
|
m.conns6 = map[netaddr.IPPort]*conn{}
|
||||||
|
}
|
||||||
|
m.conns6[c.ipp] = c
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Machine) unregisterConn6(c *conn) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.conns6, c.ipp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) AddNetwork(n *Network) {}
|
func (m *Machine) AddNetwork(n *Network) {}
|
||||||
|
|
||||||
func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error) {
|
func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error) {
|
||||||
// if udp4, udp6, etc... look at address IP vs unspec
|
// if udp4, udp6, etc... look at address IP vs unspec
|
||||||
var fam uint8
|
var (
|
||||||
|
fam uint8
|
||||||
|
ip netaddr.IP
|
||||||
|
)
|
||||||
switch network {
|
switch network {
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported network type %q", network)
|
return nil, fmt.Errorf("unsupported network type %q", network)
|
||||||
case "udp":
|
case "udp":
|
||||||
|
fam = 0
|
||||||
|
ip = v6unspec
|
||||||
case "udp4":
|
case "udp4":
|
||||||
fam = 4
|
fam = 4
|
||||||
|
ip = v4unspec
|
||||||
case "udp6":
|
case "udp6":
|
||||||
fam = 6
|
fam = 6
|
||||||
|
ip = v6unspec
|
||||||
}
|
}
|
||||||
|
|
||||||
host, portStr, err := net.SplitHostPort(address)
|
host, portStr, err := net.SplitHostPort(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if host == "" {
|
if host != "" {
|
||||||
if m.hasv6() {
|
ip, err = netaddr.ParseIP(host)
|
||||||
host = "::"
|
if err != nil {
|
||||||
} else {
|
return nil, err
|
||||||
host = "0.0.0.0"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ip, err := netaddr.ParseIP(host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ipp := netaddr.IPPort{IP: ip, Port: uint16(port)}
|
ipp := netaddr.IPPort{IP: ip, Port: uint16(port)}
|
||||||
|
|
||||||
c := &conn{
|
c := &conn{
|
||||||
|
@ -382,8 +420,22 @@ func (m *Machine) ListenPacket(network, address string) (net.PacketConn, error)
|
||||||
ipp: ipp,
|
ipp: ipp,
|
||||||
in: make(chan incomingPacket, 100), // arbitrary
|
in: make(chan incomingPacket, 100), // arbitrary
|
||||||
}
|
}
|
||||||
if err := m.registerConn(c); err != nil {
|
switch c.fam {
|
||||||
return nil, err
|
case 0:
|
||||||
|
if err := m.registerConn4(c); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := m.registerConn6(c); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case 4:
|
||||||
|
if err := m.registerConn4(c); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case 6:
|
||||||
|
if err := m.registerConn6(c); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
@ -437,7 +489,15 @@ func (c *conn) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c.closed = true
|
c.closed = true
|
||||||
c.m.unregisterConn(c)
|
switch c.fam {
|
||||||
|
case 0:
|
||||||
|
c.m.unregisterConn4(c)
|
||||||
|
c.m.unregisterConn6(c)
|
||||||
|
case 4:
|
||||||
|
c.m.unregisterConn4(c)
|
||||||
|
case 6:
|
||||||
|
c.m.unregisterConn6(c)
|
||||||
|
}
|
||||||
c.breakActiveReadsLocked()
|
c.breakActiveReadsLocked()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,45 +77,68 @@ func TestSendPacket(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLAN(t *testing.T) {
|
func TestMultiNetwork(t *testing.T) {
|
||||||
// TODO: very duplicate-ey with the previous test, but important
|
|
||||||
// right now to test explicit construction of Networks.
|
|
||||||
lan := Network{
|
lan := Network{
|
||||||
Name: "lan1",
|
Name: "lan",
|
||||||
Prefix4: mustPrefix("192.168.0.0/24"),
|
Prefix4: mustPrefix("192.168.0.0/24"),
|
||||||
}
|
}
|
||||||
|
internet := NewInternet()
|
||||||
|
|
||||||
foo := NewMachine("foo")
|
client := NewMachine("client")
|
||||||
bar := NewMachine("bar")
|
nat := NewMachine("nat")
|
||||||
ifFoo := foo.Attach("eth0", &lan)
|
server := NewMachine("server")
|
||||||
ifBar := bar.Attach("eth0", &lan)
|
|
||||||
|
|
||||||
fooPC, err := foo.ListenPacket("udp4", ":123")
|
ifClient := client.Attach("eth0", &lan)
|
||||||
|
ifNATWAN := nat.Attach("ethwan", internet)
|
||||||
|
ifNATLAN := nat.Attach("ethlan", &lan)
|
||||||
|
ifServer := server.Attach("eth0", internet)
|
||||||
|
|
||||||
|
clientPC, err := client.ListenPacket("udp", ":123")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
barPC, err := bar.ListenPacket("udp4", ":456")
|
natPC, err := nat.ListenPacket("udp", ":456")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
serverPC, err := server.ListenPacket("udp", ":789")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
const msg = "message"
|
clientAddr := netaddr.IPPort{IP: ifClient.V4(), Port: 123}
|
||||||
barAddr := netaddr.IPPort{IP: ifBar.V4(), Port: 456}
|
natLANAddr := netaddr.IPPort{IP: ifNATLAN.V4(), Port: 456}
|
||||||
if _, err := fooPC.WriteTo([]byte(msg), barAddr.UDPAddr()); err != nil {
|
natWANAddr := netaddr.IPPort{IP: ifNATWAN.V4(), Port: 456}
|
||||||
|
serverAddr := netaddr.IPPort{IP: ifServer.V4(), Port: 789}
|
||||||
|
|
||||||
|
const msg1, msg2 = "hello", "world"
|
||||||
|
if _, err := natPC.WriteTo([]byte(msg1), clientAddr.UDPAddr()); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := natPC.WriteTo([]byte(msg2), serverAddr.UDPAddr()); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := make([]byte, 1500)
|
buf := make([]byte, 1500)
|
||||||
n, addr, err := barPC.ReadFrom(buf)
|
n, addr, err := clientPC.ReadFrom(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
buf = buf[:n]
|
if string(buf[:n]) != msg1 {
|
||||||
if string(buf) != msg {
|
t.Errorf("read %q; want %q", buf[:n], msg1)
|
||||||
t.Errorf("read %q; want %q", buf, msg)
|
|
||||||
}
|
}
|
||||||
fooAddr := netaddr.IPPort{IP: ifFoo.V4(), Port: 123}
|
if addr.String() != natLANAddr.String() {
|
||||||
if addr.String() != fooAddr.String() {
|
t.Errorf("addr = %q; want %q", addr, natLANAddr)
|
||||||
t.Errorf("addr = %q; want %q", addr, fooAddr)
|
}
|
||||||
|
|
||||||
|
n, addr, err = serverPC.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if string(buf[:n]) != msg2 {
|
||||||
|
t.Errorf("read %q; want %q", buf[:n], msg2)
|
||||||
|
}
|
||||||
|
if addr.String() != natWANAddr.String() {
|
||||||
|
t.Errorf("addr = %q; want %q", addr, natLANAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue