summaryrefslogtreecommitdiff
path: root/dns
diff options
context:
space:
mode:
Diffstat (limited to 'dns')
-rw-r--r--dns/server.go53
1 files changed, 25 insertions, 28 deletions
diff --git a/dns/server.go b/dns/server.go
index f5365e8..9b3e5e9 100644
--- a/dns/server.go
+++ b/dns/server.go
@@ -11,17 +11,13 @@ import (
const MAX_RECURSION = 10
-func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
- if maxDepth == 0 {
- return nil, fmt.Errorf("too much recursion")
- }
-
+func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
if err != nil {
return nil, err
}
- answers := []dns.RR{}
+ var answers []dns.RR
for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil {
@@ -29,7 +25,10 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
}
answers = append(answers, cname)
- cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
+ cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
+ if err != nil {
+ return nil, err
+ }
answers = append(answers, cnameRecursive...)
}
@@ -43,37 +42,31 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
return nil, err
}
for _, record := range typeDnsRecords {
- answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
+ answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil {
return nil, err
}
answers = append(answers, answer)
}
- if len(answers) > 0 {
- // base case; we found the answer
- return answers, nil
- }
+ return answers, nil
+}
- message := new(dns.Msg)
- message.SetQuestion(dns.Fqdn(domain), qtype)
- message.RecursionDesired = true
+func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ if maxDepth == 0 {
+ return nil, fmt.Errorf("too much recursion")
+ }
- client := new(dns.Client)
+ answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
+ if err != nil {
+ return nil, err
+ }
- i := 0
- in, _, err := client.Exchange(message, dnsResolvers[i])
- for err != nil {
- i += 1
- if i == len(dnsResolvers) {
- log.Println(err)
- return nil, err
- }
- in, _, err = client.Exchange(message, dnsResolvers[i])
+ if len(answers) > 0 {
+ return answers, nil
}
- answers = append(answers, in.Answer...)
- return answers, nil
+ return nil, fmt.Errorf("no records found for %s", domain)
}
type DnsHandler struct {
@@ -87,7 +80,7 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Authoritative = true
for _, question := range r.Question {
- answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
+ answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
if err != nil {
fmt.Println(err)
continue
@@ -95,6 +88,10 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Answer = append(msg.Answer, answers...)
}
+ if len(msg.Answer) == 0 {
+ msg.SetRcode(r, dns.RcodeNameError)
+ }
+
log.Println(msg.Answer)
w.WriteMsg(msg)
}