diff options
Diffstat (limited to 'hcdns/server.go')
| -rw-r--r-- | hcdns/server.go | 120 |
1 files changed, 93 insertions, 27 deletions
diff --git a/hcdns/server.go b/hcdns/server.go index ce7894b..e5a8d29 100644 --- a/hcdns/server.go +++ b/hcdns/server.go @@ -11,74 +11,139 @@ import ( const MAX_RECURSION = 15 -func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { - internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") +type DnsHandler struct { + DnsResolvers []string + DbConn *sql.DB +} + +func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) { + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(dns.Fqdn(domain), qtype) + message.RecursionDesired = true + + if len(h.DnsResolvers) == 0 { + return []dns.RR{}, nil + } + + i := 0 + in, _, err := client.Exchange(message, h.DnsResolvers[i]) + for err != nil && i < len(h.DnsResolvers) { + i++ + in, _, err = client.Exchange(message, h.DnsResolvers[i]) + } + if err != nil { return nil, err } + if len(in.Answer) == 0 { + return nil, nil + } + + return in.Answer, nil +} + +func resultSetFound(answers []dns.RR, domain string, qtype uint16) bool { + for _, answer := range answers { + if answer.Header().Name == domain && answer.Header().Rrtype == qtype { + return true + } + } + return false +} + +func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) { + internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME") + if err != nil { + return nil, true, err + } + + authoritative := true var answers []dns.RR for _, record := range internalCnames { cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) if err != nil { log.Println(err) - return nil, err + return nil, authoritative, err } answers = append(answers, cname) - cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) + cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1) if err != nil { log.Println(err) - return nil, err + return nil, authoritative, err } + authoritative = authoritative && cnameAuth answers = append(answers, cnameRecursive...) } qtypeName := dns.TypeToString[qtype] - if qtypeName == "" { - return nil, fmt.Errorf("invalid query type %d", qtype) - } - - typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) + records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName) if err != nil { - return nil, err + return nil, authoritative, err } - for _, record := range typeDnsRecords { + + for _, record := range records { answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) if err != nil { - return nil, err + return nil, authoritative, err } answers = append(answers, answer) } - return answers, nil + return answers, authoritative, nil } -func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { +func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) { + log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth) if maxDepth == 0 { - return nil, fmt.Errorf("too much recursion") + return nil, false, fmt.Errorf("too much recursion") } - answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) + answers, authoritative, err := h.recursiveResolve(domain, qtype, maxDepth) if err != nil { - return nil, err + return nil, false, err } - return answers, nil -} + if len(answers) > 0 { // base case - we got the answer + return answers, authoritative, nil + } -type DnsHandler struct { - DnsResolvers []string - DbConn *sql.DB + externalAnswers, err := h.resolveExternal(domain, qtype) + if err != nil { + return nil, false, err + } + + answers = append(answers, externalAnswers...) + if resultSetFound(externalAnswers, domain, qtype) { + return answers, false, nil + } + + for _, answer := range externalAnswers { + cname, ok := answer.(*dns.CNAME) + if !ok { + continue + } + + cnameAnswers, _, err := h.resolveDNS(cname.Target, qtype, maxDepth-1) + if err != nil { + return nil, false, err + } + answers = append(answers, cnameAnswers...) + } + + return answers, false, nil } func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - msg := new(dns.Msg) + msg := &dns.Msg{} msg.SetReply(r) - msg.Authoritative = true + msg.Authoritative = false for _, question := range r.Question { - answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) + answers, authoritative, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION) + msg.Authoritative = authoritative if err != nil { fmt.Println(err) msg.SetRcode(r, dns.RcodeServerFailure) @@ -98,7 +163,8 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { handler := &DnsHandler{ - DbConn: dbConn, + DbConn: dbConn, + DnsResolvers: argv.DnsResolvers, } addr := fmt.Sprintf(":%d", argv.DnsPort) |
