From 60fc4ebb599d82f5c7ddaca52f8aba74f0876381 Mon Sep 17 00:00:00 2001 From: simponic Date: Thu, 28 Mar 2024 16:58:07 -0400 Subject: internal recursive dns server (#2) Co-authored-by: Lizzy Hunt Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/2 --- dns/server.go | 110 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 dns/server.go (limited to 'dns/server.go') diff --git a/dns/server.go b/dns/server.go new file mode 100644 index 0000000..63bb067 --- /dev/null +++ b/dns/server.go @@ -0,0 +1,110 @@ +package dns + +import ( + "database/sql" + "fmt" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "github.com/miekg/dns" + "log" +) + +const MAX_RECURSION = 10 + +func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + if maxDepth == 0 { + return nil, fmt.Errorf("too much recursion") + } + + internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") + if err != nil { + return nil, err + } + + answers := []dns.RR{} + for _, record := range internalCnames { + cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1) + answers = append(answers, cnameRecursive...) + } + + qtypeName := dns.TypeToString[qtype] + if qtypeName == "" { + return nil, fmt.Errorf("invalid query type %d", qtype) + } + + typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) + if err != nil { + return nil, err + } + for _, record := range typeDnsRecords { + answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content)) + if err != nil { + return nil, err + } + answers = append(answers, answer) + } + + if len(answers) > 0 { + // base case; we found the answer + return answers, nil + } + + message := new(dns.Msg) + message.SetQuestion(dns.Fqdn(domain), qtype) + message.RecursionDesired = true + + client := new(dns.Client) + + i := 0 + in, _, err := client.Exchange(message, dnsResolvers[i]) + for err != nil { + i += 1 + if i == len(dnsResolvers) { + log.Println(err) + return nil, err + } + in, _, err = client.Exchange(message, dnsResolvers[i]) + } + + answers = append(answers, in.Answer...) + return answers, nil +} + +type DnsHandler struct { + DnsResolvers []string + DbConn *sql.DB +} + +func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + msg := new(dns.Msg) + msg.SetReply(r) + msg.Authoritative = true + + for _, question := range r.Question { + answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION) + if err != nil { + fmt.Println(err) + continue + } + msg.Answer = append(msg.Answer, answers...) + } + + log.Println(msg.Answer) + w.WriteMsg(msg) +} + +func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { + handler := &DnsHandler{ + DnsResolvers: argv.DnsRecursion, + DbConn: dbConn, + } + addr := fmt.Sprintf(":%d", argv.DnsPort) + + return &dns.Server{ + Addr: addr, + Net: "udp", + Handler: handler, + UDPSize: 65535, + ReusePort: true, + } +} -- cgit v1.2.3-70-g09d2