summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Dockerfile2
-rw-r--r--args/args.go7
-rw-r--r--dns/server.go53
3 files changed, 28 insertions, 34 deletions
diff --git a/Dockerfile b/Dockerfile
index a46f6c4..591423f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
EXPOSE 8080
-CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"]
+CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]
diff --git a/args/args.go b/args/args.go
index 40dd1af..f71e8e3 100644
--- a/args/args.go
+++ b/args/args.go
@@ -22,9 +22,8 @@ type Arguments struct {
OauthConfig *oauth2.Config
OauthUserInfoURI string
- Dns bool
- DnsRecursion []string
- DnsPort int
+ Dns bool
+ DnsPort int
CloudflareToken string
CloudflareZone string
@@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) {
server := flag.Bool("server", false, "Run the server")
dns := flag.Bool("dns", false, "Run DNS resolver")
- dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers")
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
flag.Parse()
@@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) {
Migrate: *migrate,
Scheduler: *scheduler,
Dns: *dns,
- DnsRecursion: strings.Split(*dnsRecursion, ","),
DnsPort: *dnsPort,
OauthConfig: oauthConfig,
diff --git a/dns/server.go b/dns/server.go
index f5365e8..9b3e5e9 100644
--- a/dns/server.go
+++ b/dns/server.go
@@ -11,17 +11,13 @@ import (
const MAX_RECURSION = 10
-func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
- if maxDepth == 0 {
- return nil, fmt.Errorf("too much recursion")
- }
-
+func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
if err != nil {
return nil, err
}
- answers := []dns.RR{}
+ var answers []dns.RR
for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil {
@@ -29,7 +25,10 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
}
answers = append(answers, cname)
- cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
+ cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
+ if err != nil {
+ return nil, err
+ }
answers = append(answers, cnameRecursive...)
}
@@ -43,37 +42,31 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
return nil, err
}
for _, record := range typeDnsRecords {
- answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
+ answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil {
return nil, err
}
answers = append(answers, answer)
}
- if len(answers) > 0 {
- // base case; we found the answer
- return answers, nil
- }
+ return answers, nil
+}
- message := new(dns.Msg)
- message.SetQuestion(dns.Fqdn(domain), qtype)
- message.RecursionDesired = true
+func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ if maxDepth == 0 {
+ return nil, fmt.Errorf("too much recursion")
+ }
- client := new(dns.Client)
+ answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
+ if err != nil {
+ return nil, err
+ }
- i := 0
- in, _, err := client.Exchange(message, dnsResolvers[i])
- for err != nil {
- i += 1
- if i == len(dnsResolvers) {
- log.Println(err)
- return nil, err
- }
- in, _, err = client.Exchange(message, dnsResolvers[i])
+ if len(answers) > 0 {
+ return answers, nil
}
- answers = append(answers, in.Answer...)
- return answers, nil
+ return nil, fmt.Errorf("no records found for %s", domain)
}
type DnsHandler struct {
@@ -87,7 +80,7 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Authoritative = true
for _, question := range r.Question {
- answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
+ answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
if err != nil {
fmt.Println(err)
continue
@@ -95,6 +88,10 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Answer = append(msg.Answer, answers...)
}
+ if len(msg.Answer) == 0 {
+ msg.SetRcode(r, dns.RcodeNameError)
+ }
+
log.Println(msg.Answer)
w.WriteMsg(msg)
}