tailscale/wgengine/wgcfg/parser.go

198 lines
4.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package wgcfg
import (
"bufio"
"encoding/hex"
"fmt"
"io"
"net"
"strconv"
"strings"
"inet.af/netaddr"
)
// ParseError reports a value that could not be parsed.
// why describes the problem and offender is the text that caused it.
type ParseError struct {
	why, offender string
}

// Error implements the error interface. The message has the form
// "why: offender".
func (e *ParseError) Error() string {
	return fmt.Sprint(e.why, ": ", e.offender)
}
// validateEndpoints treats s as a comma-separated list of endpoints and
// passes each entry to parseEndpoint. It returns the first parse error
// encountered, or nil if every entry parses.
func validateEndpoints(s string) error {
	for _, endpoint := range strings.Split(s, ",") {
		if _, _, err := parseEndpoint(endpoint); err != nil {
			return err
		}
	}
	return nil
}
// parseEndpoint splits a single "host:port" endpoint into host and port.
//
// The port is the text after the last ':' and must parse as a base-10
// 16-bit unsigned integer. A host that contains ':' is accepted only when
// it is an IPv6 literal enclosed in square brackets; the brackets are
// removed from the returned host. Any other use of brackets or colons in
// the host is rejected with a *ParseError. Hosts without brackets or colons
// are returned unchanged and are not otherwise validated.
//
// On error the returned host is "" and the port is 0.
func parseEndpoint(s string) (host string, port uint16, err error) {
	i := strings.LastIndexByte(s, ':')
	if i < 0 {
		return "", 0, &ParseError{"Missing port from endpoint", s}
	}
	host, portStr := s[:i], s[i+1:]
	if host == "" {
		return "", 0, &ParseError{"Invalid endpoint host", host}
	}
	uport, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		return "", 0, err
	}

	opens := host[0] == '['
	closes := host[len(host)-1] == ']'
	// Any colon counts, including one at index 0: an unbracketed host such
	// as "::1" must be rejected just like "fe80::1".
	hasColon := strings.IndexByte(host, ':') >= 0
	if !opens && !closes && !hasColon {
		return host, uint16(uport), nil
	}

	bracketErr := &ParseError{"Brackets must contain an IPv6 address", host}
	if !opens || !closes || !hasColon || len(host) <= 3 {
		return "", 0, bracketErr
	}
	inner := host[1 : len(host)-1]
	if ip := net.ParseIP(inner); ip == nil || len(ip) != net.IPv6len {
		return "", 0, bracketErr
	}
	return inner, uint16(uport), nil
}
// parseKeyHex decodes s as hexadecimal and returns the result as a Key.
// It returns a *ParseError if s is not valid hex or does not decode to
// exactly KeySize bytes.
func parseKeyHex(s string) (*Key, error) {
	raw, decodeErr := hex.DecodeString(s)
	switch {
	case decodeErr != nil:
		return nil, &ParseError{"Invalid key: " + decodeErr.Error(), s}
	case len(raw) != KeySize:
		return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
	}

	var key Key
	copy(key[:], raw)
	return &key, nil
}
// FromUAPI generates a Config from r.
// r should be generated by calling device.IpcGetOperation;
// it is not compatible with other uapi streams.
//
// The input is read line by line as key=value pairs; empty lines are
// skipped. Lines that appear before the first public_key line are handled
// as device settings. Every public_key line starts a new peer, and the
// lines after it are applied to that peer. Parsing stops at the first line
// that cannot be handled, and the error is returned with a nil Config.
func FromUAPI(r io.Reader) (*Config, error) {
	var (
		cfg      = new(Config)
		current  *Peer // peer that the following lines apply to
		inDevice = true
		lines    = bufio.NewScanner(r)
	)

	for lines.Scan() {
		text := lines.Text()
		if text == "" {
			continue
		}

		fields := strings.Split(text, "=")
		if n := len(fields); n != 2 {
			return nil, fmt.Errorf("failed to parse line %q, found %d =-separated parts, want 2", text, n)
		}
		key, value := fields[0], fields[1]

		var err error
		switch {
		case key == "public_key":
			// Load/create the peer we are now configuring.
			inDevice = false
			current, err = cfg.handlePublicKeyLine(value)
		case inDevice:
			err = cfg.handleDeviceLine(key, value)
		default:
			err = cfg.handlePeerLine(current, key, value)
		}
		if err != nil {
			return nil, err
		}
	}

	if err := lines.Err(); err != nil {
		return nil, err
	}
	return cfg, nil
}
// handleDeviceLine applies a single device-level key/value pair to cfg.
// private_key is parsed as a hex key and stored in cfg.PrivateKey,
// listen_port is parsed as a 16-bit unsigned integer and stored in
// cfg.ListenPort, and fwmark is accepted but ignored. Any other key is
// reported as an error.
func (cfg *Config) handleDeviceLine(key, value string) error {
	switch key {
	case "fwmark":
		// ignore
		return nil
	case "listen_port":
		n, err := strconv.ParseUint(value, 10, 16)
		if err != nil {
			return fmt.Errorf("failed to parse listen_port: %w", err)
		}
		cfg.ListenPort = uint16(n)
		return nil
	case "private_key":
		parsed, err := parseKeyHex(value)
		if err != nil {
			return err
		}
		// wireguard-go guarantees not to send zero value; private keys are already clamped.
		cfg.PrivateKey = PrivateKey(*parsed)
		return nil
	}
	return fmt.Errorf("unexpected IpcGetOperation key: %v", key)
}
// handlePublicKeyLine starts a new peer. It parses value as a hex key,
// appends a Peer with that public key to cfg.Peers, and returns a pointer
// to the appended element. If the key does not parse, cfg is not modified.
//
// The returned pointer refers to an element of cfg.Peers; a later append
// to that slice may move the elements and leave the pointer stale.
func (cfg *Config) handlePublicKeyLine(value string) (*Peer, error) {
	pub, err := parseKeyHex(value)
	if err != nil {
		return nil, err
	}
	idx := len(cfg.Peers)
	cfg.Peers = append(cfg.Peers, Peer{PublicKey: *pub})
	return &cfg.Peers[idx], nil
}
// handlePeerLine applies a single peer-level key/value pair to peer.
//
//   - endpoint: validated as a comma-separated endpoint list, then stored
//     verbatim in peer.Endpoints.
//   - persistent_keepalive_interval: parsed as a 16-bit unsigned integer.
//   - allowed_ip: parsed as an IP prefix and appended to peer.AllowedIPs.
//   - protocol_version: must be "1".
//   - preshared_key, last_handshake_time_sec, last_handshake_time_nsec,
//     tx_bytes, rx_bytes: accepted and ignored.
//
// Any other key is reported as an error.
func (cfg *Config) handlePeerLine(peer *Peer, key, value string) error {
	switch key {
	case "preshared_key", "last_handshake_time_sec", "last_handshake_time_nsec", "tx_bytes", "rx_bytes":
		// ignore
		return nil
	case "protocol_version":
		if value != "1" {
			return fmt.Errorf("invalid protocol version: %v", value)
		}
		return nil
	case "endpoint":
		if err := validateEndpoints(value); err != nil {
			return err
		}
		peer.Endpoints = value
		return nil
	case "persistent_keepalive_interval":
		interval, err := strconv.ParseUint(value, 10, 16)
		if err != nil {
			return err
		}
		peer.PersistentKeepalive = uint16(interval)
		return nil
	case "allowed_ip":
		prefix, err := netaddr.ParseIPPrefix(value)
		if err != nil {
			return err
		}
		peer.AllowedIPs = append(peer.AllowedIPs, prefix)
		return nil
	}
	return fmt.Errorf("unexpected IpcGetOperation key: %v", key)
}