net/pidlisten: new package that restricts dials to the current process
To be used in the C library wrapping tsnet to provide LocalAPI access. This commit contains a linux implementation. More operating systems to follow. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>crawshaw/pidlisten
parent
e484e1c0fc
commit
5f256f114f
|
@ -0,0 +1,42 @@
|
||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// Package pidlisten implements a TCP listener that only
|
||||||
|
// accepts connections from the current process.
|
||||||
|
package pidlisten
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type listener struct {
|
||||||
|
ln net.Listener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pln *listener) Accept() (net.Conn, error) {
|
||||||
|
for {
|
||||||
|
conn, err := pln.ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ok, err := checkPIDLocal(conn)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("pidlisten: %w", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
conn.Close()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pln *listener) Close() error {
|
||||||
|
return pln.ln.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pln *listener) Addr() net.Addr {
|
||||||
|
return pln.ln.Addr()
|
||||||
|
}
|
|
@ -0,0 +1,63 @@
|
||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package pidlisten
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"go4.org/mem"
|
||||||
|
"io/fs"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"tailscale.com/util/dirwalk"
|
||||||
|
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewPIDListener wraps a net.Listener so that it only accepts connections from the current process.
|
||||||
|
func NewPIDListener(ln net.Listener) net.Listener {
|
||||||
|
return &listener{ln: ln}
|
||||||
|
}
|
||||||
|
|
||||||
|
var errFoundSocket = errors.New("found socket")
|
||||||
|
|
||||||
|
func checkPIDLocal(conn net.Conn) (bool, error) {
|
||||||
|
remoteAddr := conn.RemoteAddr()
|
||||||
|
var remoteIP net.IP
|
||||||
|
switch remoteAddr.Network() {
|
||||||
|
case "tcp":
|
||||||
|
remoteIP = remoteAddr.(*net.TCPAddr).IP
|
||||||
|
case "udp":
|
||||||
|
remoteIP = remoteAddr.(*net.UDPAddr).IP
|
||||||
|
default:
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !remoteIP.IsLoopback() {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// You can look up a net.Conn in both directions.
|
||||||
|
// There are different inodes for remote->local and local->remote.
|
||||||
|
// We want to look up the starting side of the net.Conn and check
|
||||||
|
// that its inode belongs to the current PID.
|
||||||
|
s, err := netlink.SocketGet(conn.RemoteAddr(), conn.LocalAddr())
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
want := fmt.Sprintf("socket:[%d]", s.INode)
|
||||||
|
dir := fmt.Sprintf("/proc/%d/fd", os.Getpid())
|
||||||
|
err = dirwalk.WalkShallow(mem.S(dir), func(name mem.RO, de fs.DirEntry) error {
|
||||||
|
n, err := os.Readlink(filepath.Join(dir, name.StringCopy()))
|
||||||
|
if err == nil && want == n {
|
||||||
|
return errFoundSocket
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err == errFoundSocket {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
package pidlisten
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
func checkPIDLocal(conn net.Conn) (bool, error) {
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
|
@ -0,0 +1,122 @@
|
||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
//go:build linux
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package pidlisten
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var flagDial = flag.String("dial", "", "if set, dials the given addr and reads until close")
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
flag.Parse()
|
||||||
|
if *flagDial != "" {
|
||||||
|
conn, err := net.DialTimeout("tcp", *flagDial, 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
b, err := io.ReadAll(conn)
|
||||||
|
fmt.Fprintf(os.Stderr, "%s", b)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPIDLocal(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
ok, err := checkPIDLocal(conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("checkPIDLocal=false, want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testExternalProcess(t *testing.T, ln net.Listener) string {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
c, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(c, "hello\n")
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := exec.Command(exe, "-dial="+ln.Addr().String()).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExternalDialWorks(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
out := testExternalProcess(t, ln)
|
||||||
|
if out != "hello\n" {
|
||||||
|
t.Errorf("out=%q, want hello", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPIDExternal(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
ln = NewPIDListener(ln)
|
||||||
|
out := testExternalProcess(t, ln)
|
||||||
|
|
||||||
|
if len(out) != 0 {
|
||||||
|
t.Errorf("unexpected socket output: %q", out)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue