tailcfg: add DiscoKey, unify some code, add some tests
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>reviewable/pr480/r1
parent
d9054da86a
commit
88c305c8af
|
@ -13,7 +13,9 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tailscale/wireguard-go/wgcfg"
|
"github.com/tailscale/wireguard-go/wgcfg"
|
||||||
|
"go4.org/mem"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/opt"
|
"tailscale.com/types/opt"
|
||||||
"tailscale.com/types/structs"
|
"tailscale.com/types/structs"
|
||||||
)
|
)
|
||||||
|
@ -38,6 +40,10 @@ type MachineKey [32]byte
|
||||||
// NodeKey is the curve25519 public key for a node.
|
// NodeKey is the curve25519 public key for a node.
|
||||||
type NodeKey [32]byte
|
type NodeKey [32]byte
|
||||||
|
|
||||||
|
// DiscoKey is the curve25519 public key for path discovery key.
|
||||||
|
// It's never written to disk or reused between network start-ups.
|
||||||
|
type DiscoKey [32]byte
|
||||||
|
|
||||||
type Group struct {
|
type Group struct {
|
||||||
ID GroupID
|
ID GroupID
|
||||||
Name string
|
Name string
|
||||||
|
@ -127,6 +133,7 @@ type Node struct {
|
||||||
Key NodeKey
|
Key NodeKey
|
||||||
KeyExpiry time.Time
|
KeyExpiry time.Time
|
||||||
Machine MachineKey
|
Machine MachineKey
|
||||||
|
DiscoKey DiscoKey
|
||||||
Addresses []wgcfg.CIDR // IP addresses of this Node directly
|
Addresses []wgcfg.CIDR // IP addresses of this Node directly
|
||||||
AllowedIPs []wgcfg.CIDR // range of IP addresses to route to this node
|
AllowedIPs []wgcfg.CIDR // range of IP addresses to route to this node
|
||||||
Endpoints []string `json:",omitempty"` // IP+port (public via STUN, and local LANs)
|
Endpoints []string `json:",omitempty"` // IP+port (public via STUN, and local LANs)
|
||||||
|
@ -519,59 +526,43 @@ type Debug struct {
|
||||||
LogHeapURL string `json:",omitempty"`
|
LogHeapURL string `json:",omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) }
|
func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) }
|
||||||
|
func (k MachineKey) MarshalText() ([]byte, error) { return keyMarshalText("mkey:", k), nil }
|
||||||
|
func (k *MachineKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "mkey:", text) }
|
||||||
|
|
||||||
func (k MachineKey) MarshalText() ([]byte, error) {
|
func keyMarshalText(prefix string, k [32]byte) []byte {
|
||||||
buf := new(bytes.Buffer)
|
buf := bytes.NewBuffer(make([]byte, 0, len(prefix)+64))
|
||||||
fmt.Fprintf(buf, "mkey:%x", k[:])
|
fmt.Fprintf(buf, "%s%x", prefix, k[:])
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *MachineKey) UnmarshalText(text []byte) error {
|
func keyUnmarshalText(dst []byte, prefix string, text []byte) error {
|
||||||
s := string(text)
|
if len(text) < len(prefix) || string(text[:len(prefix)]) != prefix {
|
||||||
if !strings.HasPrefix(s, "mkey:") {
|
return fmt.Errorf("UnmarshalText: missing %q prefix", prefix)
|
||||||
return errors.New(`MachineKey.UnmarshalText: missing prefix`)
|
|
||||||
}
|
}
|
||||||
s = strings.TrimPrefix(s, `mkey:`)
|
pub, err := key.NewPublicFromHexMem(mem.B(text[len(prefix):]))
|
||||||
key, err := wgcfg.ParseHexKey(s)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("MachineKey.UnmarhsalText: %v", err)
|
return fmt.Errorf("UnmarshalText: after %q: %v", prefix, err)
|
||||||
}
|
}
|
||||||
copy(k[:], key[:])
|
copy(dst[:], pub[:])
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) }
|
func (k NodeKey) ShortString() string { return (key.Public(k)).ShortString() }
|
||||||
|
|
||||||
func (k NodeKey) ShortString() string {
|
func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) }
|
||||||
pk := wgcfg.Key(k)
|
func (k NodeKey) MarshalText() ([]byte, error) { return keyMarshalText("nodekey:", k), nil }
|
||||||
return pk.ShortString()
|
func (k *NodeKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "nodekey:", text) }
|
||||||
}
|
|
||||||
|
|
||||||
func (k NodeKey) MarshalText() ([]byte, error) {
|
// IsZero reports whether k is the zero value.
|
||||||
buf := new(bytes.Buffer)
|
func (k NodeKey) IsZero() bool { return k == NodeKey{} }
|
||||||
fmt.Fprintf(buf, "nodekey:%x", k[:])
|
|
||||||
return buf.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *NodeKey) UnmarshalText(text []byte) error {
|
func (k DiscoKey) String() string { return fmt.Sprintf("discokey:%x", k[:]) }
|
||||||
s := string(text)
|
func (k DiscoKey) MarshalText() ([]byte, error) { return keyMarshalText("discokey:", k), nil }
|
||||||
if !strings.HasPrefix(s, "nodekey:") {
|
func (k *DiscoKey) UnmarshalText(text []byte) error { return keyUnmarshalText(k[:], "discokey:", text) }
|
||||||
return errors.New(`Nodekey.UnmarshalText: missing prefix`)
|
|
||||||
}
|
|
||||||
s = strings.TrimPrefix(s, "nodekey:")
|
|
||||||
key, err := wgcfg.ParseHexKey(s)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("tailcfg.Ukey.UnmarhsalText: %v", err)
|
|
||||||
}
|
|
||||||
copy(k[:], key[:])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsZero reports whether k is the NodeKey zero value.
|
// IsZero reports whether k is the zero value.
|
||||||
func (k NodeKey) IsZero() bool {
|
func (k DiscoKey) IsZero() bool { return k == DiscoKey{} }
|
||||||
return k == NodeKey{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (id ID) String() string { return fmt.Sprintf("id:%x", int64(id)) }
|
func (id ID) String() string { return fmt.Sprintf("id:%x", int64(id)) }
|
||||||
func (id UserID) String() string { return fmt.Sprintf("userid:%x", int64(id)) }
|
func (id UserID) String() string { return fmt.Sprintf("userid:%x", int64(id)) }
|
||||||
|
@ -593,6 +584,7 @@ func (n *Node) Equal(n2 *Node) bool {
|
||||||
n.Key == n2.Key &&
|
n.Key == n2.Key &&
|
||||||
n.KeyExpiry.Equal(n2.KeyExpiry) &&
|
n.KeyExpiry.Equal(n2.KeyExpiry) &&
|
||||||
n.Machine == n2.Machine &&
|
n.Machine == n2.Machine &&
|
||||||
|
n.DiscoKey == n2.DiscoKey &&
|
||||||
reflect.DeepEqual(n.Addresses, n2.Addresses) &&
|
reflect.DeepEqual(n.Addresses, n2.Addresses) &&
|
||||||
reflect.DeepEqual(n.AllowedIPs, n2.AllowedIPs) &&
|
reflect.DeepEqual(n.AllowedIPs, n2.AllowedIPs) &&
|
||||||
reflect.DeepEqual(n.Endpoints, n2.Endpoints) &&
|
reflect.DeepEqual(n.Endpoints, n2.Endpoints) &&
|
||||||
|
|
|
@ -5,7 +5,9 @@
|
||||||
package tailcfg
|
package tailcfg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -176,7 +178,7 @@ func TestHostinfoEqual(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNodeEqual(t *testing.T) {
|
func TestNodeEqual(t *testing.T) {
|
||||||
nodeHandles := []string{"ID", "Name", "User", "Key", "KeyExpiry", "Machine", "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", "Created", "LastSeen", "KeepAlive", "MachineAuthorized"}
|
nodeHandles := []string{"ID", "Name", "User", "Key", "KeyExpiry", "Machine", "DiscoKey", "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", "Created", "LastSeen", "KeepAlive", "MachineAuthorized"}
|
||||||
if have := fieldsOf(reflect.TypeOf(Node{})); !reflect.DeepEqual(have, nodeHandles) {
|
if have := fieldsOf(reflect.TypeOf(Node{})); !reflect.DeepEqual(have, nodeHandles) {
|
||||||
t.Errorf("Node.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
|
t.Errorf("Node.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
|
||||||
have, nodeHandles)
|
have, nodeHandles)
|
||||||
|
@ -336,3 +338,51 @@ func TestNetInfoFields(t *testing.T) {
|
||||||
have, handled)
|
have, handled)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMachineKeyMarshal(t *testing.T) {
|
||||||
|
var k1, k2 MachineKey
|
||||||
|
for i := range k1 {
|
||||||
|
k1[i] = byte(i)
|
||||||
|
}
|
||||||
|
testKey(t, "mkey:", k1, &k2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNodeKeyMarshal(t *testing.T) {
|
||||||
|
var k1, k2 NodeKey
|
||||||
|
for i := range k1 {
|
||||||
|
k1[i] = byte(i)
|
||||||
|
}
|
||||||
|
testKey(t, "nodekey:", k1, &k2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscoKeyMarshal(t *testing.T) {
|
||||||
|
var k1, k2 DiscoKey
|
||||||
|
for i := range k1 {
|
||||||
|
k1[i] = byte(i)
|
||||||
|
}
|
||||||
|
testKey(t, "discokey:", k1, &k2)
|
||||||
|
}
|
||||||
|
|
||||||
|
type keyIn interface {
|
||||||
|
String() string
|
||||||
|
MarshalText() ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testKey(t *testing.T, prefix string, in keyIn, out encoding.TextUnmarshaler) {
|
||||||
|
got, err := in.MarshalText()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := out.UnmarshalText(got); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if s := in.String(); string(got) != s {
|
||||||
|
t.Errorf("MarshalText = %q != String %q", got, s)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(string(got), prefix) {
|
||||||
|
t.Errorf("%q didn't start with prefix %q", got, prefix)
|
||||||
|
}
|
||||||
|
if reflect.ValueOf(out).Elem().Interface() != in {
|
||||||
|
t.Errorf("mismatch after unmarshal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue