From 0c4fd8ef3aac958d4f8e80bebdba689b926d54c4 Mon Sep 17 00:00:00 2001 From: Denton Gentry Date: Sun, 5 Mar 2023 09:21:40 -0800 Subject: [PATCH] cmd/sniproxy: reimplement DNS server Switch to golang.org/x/net/dns/dnsmessage, due to https://github.com/miekg/dns/issues/1427 Updates https://github.com/tailscale/tailscale/issues/1748 Signed-off-by: Denton Gentry --- cmd/sniproxy/snipproxy.go | 120 +++++++++++++++++++++++++++----------- 1 file changed, 87 insertions(+), 33 deletions(-) diff --git a/cmd/sniproxy/snipproxy.go b/cmd/sniproxy/snipproxy.go index 34a2a7fbb..9672551cd 100644 --- a/cmd/sniproxy/snipproxy.go +++ b/cmd/sniproxy/snipproxy.go @@ -9,14 +9,13 @@ package main import ( "context" "flag" - "fmt" "log" "net" "net/netip" "strings" "time" - "github.com/miekg/dns" + "golang.org/x/net/dns/dnsmessage" "inet.af/tcpproxy" "tailscale.com/client/tailscale" "tailscale.com/net/netutil" @@ -26,6 +25,7 @@ import ( var ( ports = flag.String("ports", "443", "comma-separated list of ports to proxy") dnsserv = flag.Bool("dns", true, "run a small DNS server to reply to any query with its own address") + tsMBox = dnsmessage.MustNewName("support.tailscale.com.") ) func main() { @@ -113,46 +113,100 @@ func (s *server) getAddresses() (ip4, ip6 netip.Addr) { } func (s *server) serveDns() { - dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { - switch r.Opcode { - case dns.OpcodeQuery: - m := s.dnsResponse(r) - m.SetReply(r) - w.WriteMsg(m) - } - }) - + buf := make([]byte, 1024) pconn, err := s.ts.ListenPacket("udp", ":53") if err != nil { - log.Printf("Failed to start DNS listener: %s\n ", err.Error()) + log.Fatal(err) + } + + for { + _, addr, err := pconn.ReadFrom(buf) + if err != nil { + log.Printf("pconn.ReadFrom failed: %v\n ", err) + continue + } + + var msg dnsmessage.Message + err = msg.Unpack(buf) + if err != nil { + log.Printf("dnsmessage.Message unpack failed: %v\n ", err) + continue + } + + buf, err := s.dnsResponse(&msg) + if err != nil { + log.Printf("s.dnsResponse failed: %v\n", err) + continue + } + + _, err = pconn.WriteTo(buf, addr) + if err != nil { + log.Printf("pconn.WriteTo failed: %v\n", err) + continue + } + } +} + +func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) { + resp := dnsmessage.NewBuilder(buf, + dnsmessage.Header{ + ID: req.Header.ID, + Response: true, + Authoritative: true, + }) + resp.EnableCompression() + + if len(req.Questions) == 0 { + buf, _ = resp.Finish() return } - dnsServer := &dns.Server{PacketConn: pconn} - err = dnsServer.ActivateAndServe() + q := req.Questions[0] + err = resp.StartQuestions() if err != nil { - log.Printf("Failed to start DNS server: %s\n ", err.Error()) + return } -} + resp.Question(q) -func (s *server) dnsResponse(requestMsg *dns.Msg) *dns.Msg { - responseMsg := new(dns.Msg) - if len(requestMsg.Question) == 0 { - return responseMsg - } - - q := requestMsg.Question[0] - var rr dns.RR ip4, ip6 := s.getAddresses() - - switch q.Qtype { - case dns.TypeAAAA: - rr, _ = dns.NewRR(fmt.Sprintf("%s 120 IN AAAA %s", q.Name, ip6.String())) - - case dns.TypeA: - rr, _ = dns.NewRR(fmt.Sprintf("%s 120 IN A %s", q.Name, ip4.String())) + err = resp.StartAnswers() + if err != nil { + return } - responseMsg.Answer = append(responseMsg.Answer, rr) - return responseMsg + switch q.Type { + case dnsmessage.TypeAAAA: + err = resp.AAAAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AAAAResource{AAAA: ip6.As16()}, + ) + + case dnsmessage.TypeA: + err = resp.AResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AResource{A: ip4.As4()}, + ) + case dnsmessage.TypeSOA: + err = resp.SOAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, + Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, + ) + case dnsmessage.TypeNS: + err = resp.NSResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.NSResource{NS: tsMBox}, + ) + } + + if err != nil { + return + } + + buf, err = resp.Finish() + if err != nil { + return + } + + return }