summaryrefslogtreecommitdiff
path: root/hcdns/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'hcdns/server.go')
-rw-r--r--hcdns/server.go120
1 files changed, 93 insertions, 27 deletions
diff --git a/hcdns/server.go b/hcdns/server.go
index ce7894b..e5a8d29 100644
--- a/hcdns/server.go
+++ b/hcdns/server.go
@@ -11,74 +11,139 @@ import (
const MAX_RECURSION = 15
-func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
- internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
+type DnsHandler struct {
+ DnsResolvers []string
+ DbConn *sql.DB
+}
+
+func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) {
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(dns.Fqdn(domain), qtype)
+ message.RecursionDesired = true
+
+ if len(h.DnsResolvers) == 0 {
+ return []dns.RR{}, nil
+ }
+
+ i := 0
+ in, _, err := client.Exchange(message, h.DnsResolvers[i])
+ for err != nil && i < len(h.DnsResolvers) {
+ i++
+ in, _, err = client.Exchange(message, h.DnsResolvers[i])
+ }
+
if err != nil {
return nil, err
}
+ if len(in.Answer) == 0 {
+ return nil, nil
+ }
+
+ return in.Answer, nil
+}
+
+func resultSetFound(answers []dns.RR, domain string, qtype uint16) bool {
+ for _, answer := range answers {
+ if answer.Header().Name == domain && answer.Header().Rrtype == qtype {
+ return true
+ }
+ }
+ return false
+}
+
+func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
+ internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME")
+ if err != nil {
+ return nil, true, err
+ }
+
+ authoritative := true
var answers []dns.RR
for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil {
log.Println(err)
- return nil, err
+ return nil, authoritative, err
}
answers = append(answers, cname)
- cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
+ cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1)
if err != nil {
log.Println(err)
- return nil, err
+ return nil, authoritative, err
}
+ authoritative = authoritative && cnameAuth
answers = append(answers, cnameRecursive...)
}
qtypeName := dns.TypeToString[qtype]
- if qtypeName == "" {
- return nil, fmt.Errorf("invalid query type %d", qtype)
- }
-
- typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
+ records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName)
if err != nil {
- return nil, err
+ return nil, authoritative, err
}
- for _, record := range typeDnsRecords {
+
+ for _, record := range records {
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil {
- return nil, err
+ return nil, authoritative, err
}
answers = append(answers, answer)
}
- return answers, nil
+ return answers, authoritative, nil
}
-func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
+ log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth)
if maxDepth == 0 {
- return nil, fmt.Errorf("too much recursion")
+ return nil, false, fmt.Errorf("too much recursion")
}
- answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
+ answers, authoritative, err := h.recursiveResolve(domain, qtype, maxDepth)
if err != nil {
- return nil, err
+ return nil, false, err
}
- return answers, nil
-}
+ if len(answers) > 0 { // base case - we got the answer
+ return answers, authoritative, nil
+ }
-type DnsHandler struct {
- DnsResolvers []string
- DbConn *sql.DB
+ externalAnswers, err := h.resolveExternal(domain, qtype)
+ if err != nil {
+ return nil, false, err
+ }
+
+ answers = append(answers, externalAnswers...)
+ if resultSetFound(externalAnswers, domain, qtype) {
+ return answers, false, nil
+ }
+
+ for _, answer := range externalAnswers {
+ cname, ok := answer.(*dns.CNAME)
+ if !ok {
+ continue
+ }
+
+ cnameAnswers, _, err := h.resolveDNS(cname.Target, qtype, maxDepth-1)
+ if err != nil {
+ return nil, false, err
+ }
+ answers = append(answers, cnameAnswers...)
+ }
+
+ return answers, false, nil
}
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
- msg := new(dns.Msg)
+ msg := &dns.Msg{}
msg.SetReply(r)
- msg.Authoritative = true
+ msg.Authoritative = false
for _, question := range r.Question {
- answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
+ answers, authoritative, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION)
+ msg.Authoritative = authoritative
if err != nil {
fmt.Println(err)
msg.SetRcode(r, dns.RcodeServerFailure)
@@ -98,7 +163,8 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
handler := &DnsHandler{
- DbConn: dbConn,
+ DbConn: dbConn,
+ DnsResolvers: argv.DnsResolvers,
}
addr := fmt.Sprintf(":%d", argv.DnsPort)