cmd/sniproxy: reimplement DNS server
Switch to golang.org/x/net/dns/dnsmessage, due to https://github.com/miekg/dns/issues/1427 Updates https://github.com/tailscale/tailscale/issues/1748 Signed-off-by: Denton Gentry <dgentry@tailscale.com>dgentry/sniproxy-dns
parent
5b4a35e7f1
commit
0c4fd8ef3a
|
@ -9,14 +9,13 @@ package main
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
"inet.af/tcpproxy"
|
"inet.af/tcpproxy"
|
||||||
"tailscale.com/client/tailscale"
|
"tailscale.com/client/tailscale"
|
||||||
"tailscale.com/net/netutil"
|
"tailscale.com/net/netutil"
|
||||||
|
@ -26,6 +25,7 @@ import (
|
||||||
var (
|
var (
|
||||||
ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
|
ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
|
||||||
dnsserv = flag.Bool("dns", true, "run a small DNS server to reply to any query with its own address")
|
dnsserv = flag.Bool("dns", true, "run a small DNS server to reply to any query with its own address")
|
||||||
|
tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -113,46 +113,100 @@ func (s *server) getAddresses() (ip4, ip6 netip.Addr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) serveDns() {
|
func (s *server) serveDns() {
|
||||||
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
buf := make([]byte, 1024)
|
||||||
switch r.Opcode {
|
|
||||||
case dns.OpcodeQuery:
|
|
||||||
m := s.dnsResponse(r)
|
|
||||||
m.SetReply(r)
|
|
||||||
w.WriteMsg(m)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
pconn, err := s.ts.ListenPacket("udp", ":53")
|
pconn, err := s.ts.ListenPacket("udp", ":53")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to start DNS listener: %s\n ", err.Error())
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, addr, err := pconn.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("pconn.ReadFrom failed: %v\n ", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg dnsmessage.Message
|
||||||
|
err = msg.Unpack(buf)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("dnsmessage.Message unpack failed: %v\n ", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err := s.dnsResponse(&msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("s.dnsResponse failed: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = pconn.WriteTo(buf, addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("pconn.WriteTo failed: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
|
||||||
|
resp := dnsmessage.NewBuilder(buf,
|
||||||
|
dnsmessage.Header{
|
||||||
|
ID: req.Header.ID,
|
||||||
|
Response: true,
|
||||||
|
Authoritative: true,
|
||||||
|
})
|
||||||
|
resp.EnableCompression()
|
||||||
|
|
||||||
|
if len(req.Questions) == 0 {
|
||||||
|
buf, _ = resp.Finish()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer := &dns.Server{PacketConn: pconn}
|
q := req.Questions[0]
|
||||||
err = dnsServer.ActivateAndServe()
|
err = resp.StartQuestions()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to start DNS server: %s\n ", err.Error())
|
return
|
||||||
}
|
}
|
||||||
}
|
resp.Question(q)
|
||||||
|
|
||||||
func (s *server) dnsResponse(requestMsg *dns.Msg) *dns.Msg {
|
|
||||||
responseMsg := new(dns.Msg)
|
|
||||||
if len(requestMsg.Question) == 0 {
|
|
||||||
return responseMsg
|
|
||||||
}
|
|
||||||
|
|
||||||
q := requestMsg.Question[0]
|
|
||||||
var rr dns.RR
|
|
||||||
ip4, ip6 := s.getAddresses()
|
ip4, ip6 := s.getAddresses()
|
||||||
|
err = resp.StartAnswers()
|
||||||
switch q.Qtype {
|
if err != nil {
|
||||||
case dns.TypeAAAA:
|
return
|
||||||
rr, _ = dns.NewRR(fmt.Sprintf("%s 120 IN AAAA %s", q.Name, ip6.String()))
|
|
||||||
|
|
||||||
case dns.TypeA:
|
|
||||||
rr, _ = dns.NewRR(fmt.Sprintf("%s 120 IN A %s", q.Name, ip4.String()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
responseMsg.Answer = append(responseMsg.Answer, rr)
|
switch q.Type {
|
||||||
return responseMsg
|
case dnsmessage.TypeAAAA:
|
||||||
|
err = resp.AAAAResource(
|
||||||
|
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||||
|
dnsmessage.AAAAResource{AAAA: ip6.As16()},
|
||||||
|
)
|
||||||
|
|
||||||
|
case dnsmessage.TypeA:
|
||||||
|
err = resp.AResource(
|
||||||
|
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||||
|
dnsmessage.AResource{A: ip4.As4()},
|
||||||
|
)
|
||||||
|
case dnsmessage.TypeSOA:
|
||||||
|
err = resp.SOAResource(
|
||||||
|
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||||
|
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
|
||||||
|
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
|
||||||
|
)
|
||||||
|
case dnsmessage.TypeNS:
|
||||||
|
err = resp.NSResource(
|
||||||
|
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||||
|
dnsmessage.NSResource{NS: tsMBox},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err = resp.Finish()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue