net/pidlisten: new package that restricts dials to the current process

To be used in the C library wrapping tsnet to provide LocalAPI access.

This commit contains a linux implementation.
More operating systems to follow.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
crawshaw/pidlisten
David Crawshaw 2023-02-26 17:32:36 -05:00
parent e484e1c0fc
commit 5f256f114f
4 changed files with 240 additions and 0 deletions

View File

@ -0,0 +1,42 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package pidlisten implements a TCP listener that only
// accepts connections from the current process.
package pidlisten
import (
"fmt"
"net"
)
type listener struct {
ln net.Listener
}
func (pln *listener) Accept() (net.Conn, error) {
for {
conn, err := pln.ln.Accept()
if err != nil {
return nil, err
}
ok, err := checkPIDLocal(conn)
if err != nil {
conn.Close()
return nil, fmt.Errorf("pidlisten: %w", err)
}
if !ok {
conn.Close()
continue
}
return conn, nil
}
}
func (pln *listener) Close() error {
return pln.ln.Close()
}
func (pln *listener) Addr() net.Addr {
return pln.ln.Addr()
}

View File

@ -0,0 +1,63 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package pidlisten
import (
"errors"
"fmt"
"go4.org/mem"
"io/fs"
"net"
"os"
"path/filepath"
"tailscale.com/util/dirwalk"
"github.com/vishvananda/netlink"
)
// NewPIDListener wraps a net.Listener so that it only accepts connections from the current process.
func NewPIDListener(ln net.Listener) net.Listener {
return &listener{ln: ln}
}
var errFoundSocket = errors.New("found socket")
func checkPIDLocal(conn net.Conn) (bool, error) {
remoteAddr := conn.RemoteAddr()
var remoteIP net.IP
switch remoteAddr.Network() {
case "tcp":
remoteIP = remoteAddr.(*net.TCPAddr).IP
case "udp":
remoteIP = remoteAddr.(*net.UDPAddr).IP
default:
return false, nil
}
if !remoteIP.IsLoopback() {
return false, nil
}
// You can look up a net.Conn in both directions.
// There are different inodes for remote->local and local->remote.
// We want to look up the starting side of the net.Conn and check
// that its inode belongs to the current PID.
s, err := netlink.SocketGet(conn.RemoteAddr(), conn.LocalAddr())
if err != nil {
return false, err
}
want := fmt.Sprintf("socket:[%d]", s.INode)
dir := fmt.Sprintf("/proc/%d/fd", os.Getpid())
err = dirwalk.WalkShallow(mem.S(dir), func(name mem.RO, de fs.DirEntry) error {
n, err := os.Readlink(filepath.Join(dir, name.StringCopy()))
if err == nil && want == n {
return errFoundSocket
}
return nil
})
if err == errFoundSocket {
return true, nil
}
return false, err
}

View File

@ -0,0 +1,13 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !linux
// +build !linux
package pidlisten
import "net"
func checkPIDLocal(conn net.Conn) (bool, error) {
panic("not implemented")
}

View File

@ -0,0 +1,122 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
// +build linux
package pidlisten
import (
"errors"
"flag"
"fmt"
"io"
"net"
"os"
"os/exec"
"testing"
"time"
)
var flagDial = flag.String("dial", "", "if set, dials the given addr and reads until close")
func TestMain(m *testing.M) {
flag.Parse()
if *flagDial != "" {
conn, err := net.DialTimeout("tcp", *flagDial, 5*time.Second)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
conn.SetDeadline(time.Now().Add(5 * time.Second))
b, err := io.ReadAll(conn)
fmt.Fprintf(os.Stderr, "%s", b)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
os.Exit(0)
}
os.Exit(m.Run())
}
func TestPIDLocal(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
clientConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer clientConn.Close()
conn, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
ok, err := checkPIDLocal(conn)
if err != nil {
t.Fatal(err)
}
if !ok {
t.Errorf("checkPIDLocal=false, want true")
}
}
func testExternalProcess(t *testing.T, ln net.Listener) string {
go func() {
for {
c, err := ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
panic(err)
}
fmt.Fprintf(c, "hello\n")
c.Close()
}
}()
exe, err := os.Executable()
if err != nil {
t.Fatal(err)
}
out, err := exec.Command(exe, "-dial="+ln.Addr().String()).CombinedOutput()
if err != nil {
t.Fatal(err)
}
return string(out)
}
func TestExternalDialWorks(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
out := testExternalProcess(t, ln)
if out != "hello\n" {
t.Errorf("out=%q, want hello", out)
}
}
func TestPIDExternal(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
ln = NewPIDListener(ln)
out := testExternalProcess(t, ln)
if len(out) != 0 {
t.Errorf("unexpected socket output: %q", out)
}
}