Compare commits
1 Commits
main
...
crawshaw/p
Author | SHA1 | Date |
---|---|---|
![]() |
5f256f114f |
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package pidlisten implements a TCP listener that only
|
||||
// accepts connections from the current process.
|
||||
package pidlisten
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
type listener struct {
|
||||
ln net.Listener
|
||||
}
|
||||
|
||||
func (pln *listener) Accept() (net.Conn, error) {
|
||||
for {
|
||||
conn, err := pln.ln.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ok, err := checkPIDLocal(conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("pidlisten: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (pln *listener) Close() error {
|
||||
return pln.ln.Close()
|
||||
}
|
||||
|
||||
func (pln *listener) Addr() net.Addr {
|
||||
return pln.ln.Addr()
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package pidlisten
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"go4.org/mem"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"tailscale.com/util/dirwalk"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
)
|
||||
|
||||
// NewPIDListener wraps a net.Listener so that it only accepts connections from the current process.
|
||||
func NewPIDListener(ln net.Listener) net.Listener {
|
||||
return &listener{ln: ln}
|
||||
}
|
||||
|
||||
var errFoundSocket = errors.New("found socket")
|
||||
|
||||
func checkPIDLocal(conn net.Conn) (bool, error) {
|
||||
remoteAddr := conn.RemoteAddr()
|
||||
var remoteIP net.IP
|
||||
switch remoteAddr.Network() {
|
||||
case "tcp":
|
||||
remoteIP = remoteAddr.(*net.TCPAddr).IP
|
||||
case "udp":
|
||||
remoteIP = remoteAddr.(*net.UDPAddr).IP
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
if !remoteIP.IsLoopback() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// You can look up a net.Conn in both directions.
|
||||
// There are different inodes for remote->local and local->remote.
|
||||
// We want to look up the starting side of the net.Conn and check
|
||||
// that its inode belongs to the current PID.
|
||||
s, err := netlink.SocketGet(conn.RemoteAddr(), conn.LocalAddr())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
want := fmt.Sprintf("socket:[%d]", s.INode)
|
||||
dir := fmt.Sprintf("/proc/%d/fd", os.Getpid())
|
||||
err = dirwalk.WalkShallow(mem.S(dir), func(name mem.RO, de fs.DirEntry) error {
|
||||
n, err := os.Readlink(filepath.Join(dir, name.StringCopy()))
|
||||
if err == nil && want == n {
|
||||
return errFoundSocket
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err == errFoundSocket {
|
||||
return true, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !linux
|
||||
// +build !linux
|
||||
|
||||
package pidlisten
|
||||
|
||||
import "net"
|
||||
|
||||
func checkPIDLocal(conn net.Conn) (bool, error) {
|
||||
panic("not implemented")
|
||||
}
|
|
@ -0,0 +1,122 @@
|
|||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package pidlisten
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var flagDial = flag.String("dial", "", "if set, dials the given addr and reads until close")
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
flag.Parse()
|
||||
if *flagDial != "" {
|
||||
conn, err := net.DialTimeout("tcp", *flagDial, 5*time.Second)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
b, err := io.ReadAll(conn)
|
||||
fmt.Fprintf(os.Stderr, "%s", b)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestPIDLocal(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ok, err := checkPIDLocal(conn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !ok {
|
||||
t.Errorf("checkPIDLocal=false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func testExternalProcess(t *testing.T, ln net.Listener) string {
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
fmt.Fprintf(c, "hello\n")
|
||||
c.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
out, err := exec.Command(exe, "-dial="+ln.Addr().String()).CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
func TestExternalDialWorks(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
out := testExternalProcess(t, ln)
|
||||
if out != "hello\n" {
|
||||
t.Errorf("out=%q, want hello", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPIDExternal(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
ln = NewPIDListener(ln)
|
||||
out := testExternalProcess(t, ln)
|
||||
|
||||
if len(out) != 0 {
|
||||
t.Errorf("unexpected socket output: %q", out)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue