summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--hcdns/server.go (renamed from dns/server.go)19
-rw-r--r--main.go4
-rw-r--r--test/dns_test.go244
3 files changed, 255 insertions, 12 deletions
diff --git a/dns/server.go b/hcdns/server.go
index 9b3e5e9..ce7894b 100644
--- a/dns/server.go
+++ b/hcdns/server.go
@@ -1,4 +1,4 @@
-package dns
+package hcdns
import (
"database/sql"
@@ -9,7 +9,7 @@ import (
"log"
)
-const MAX_RECURSION = 10
+const MAX_RECURSION = 15
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
@@ -21,12 +21,14 @@ func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth
for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil {
+ log.Println(err)
return nil, err
}
answers = append(answers, cname)
cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
if err != nil {
+ log.Println(err)
return nil, err
}
answers = append(answers, cnameRecursive...)
@@ -62,11 +64,7 @@ func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dn
return nil, err
}
- if len(answers) > 0 {
- return answers, nil
- }
-
- return nil, fmt.Errorf("no records found for %s", domain)
+ return answers, nil
}
type DnsHandler struct {
@@ -83,7 +81,9 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
if err != nil {
fmt.Println(err)
- continue
+ msg.SetRcode(r, dns.RcodeServerFailure)
+ w.WriteMsg(msg)
+ return
}
msg.Answer = append(msg.Answer, answers...)
}
@@ -98,8 +98,7 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
handler := &DnsHandler{
- DnsResolvers: argv.DnsRecursion,
- DbConn: dbConn,
+ DbConn: dbConn,
}
addr := fmt.Sprintf(":%d", argv.DnsPort)
diff --git a/main.go b/main.go
index 2991821..e0f3e55 100644
--- a/main.go
+++ b/main.go
@@ -6,7 +6,7 @@ import (
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
- "git.hatecomputers.club/hatecomputers/hatecomputers.club/dns"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
"github.com/joho/godotenv"
)
@@ -52,7 +52,7 @@ func main() {
}
if argv.Dns {
- server := dns.MakeServer(argv, dbConn)
+ server := hcdns.MakeServer(argv, dbConn)
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
go func() {
err = server.ListenAndServe()
diff --git a/test/dns_test.go b/test/dns_test.go
new file mode 100644
index 0000000..ce6deb5
--- /dev/null
+++ b/test/dns_test.go
@@ -0,0 +1,244 @@
+package hcdns
+
+import (
+ "database/sql"
+ "os"
+ "sync"
+ "testing"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
+ "github.com/miekg/dns"
+)
+
+const (
+ testDBPath = "test.db"
+ address = "127.0.0.1:8353"
+ dnsPort = 8353
+)
+
+func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) {
+ testDb := database.MakeConn(&dbPath)
+ database.Migrate(testDb)
+ testUser := &database.User{
+ ID: "test",
+ }
+ database.FindOrSaveUser(testDb, testUser)
+
+ server := hcdns.MakeServer(&args.Arguments{
+ DnsPort: dnsPort,
+ }, testDb)
+
+ waitGroup := sync.WaitGroup{}
+ waitGroup.Add(1)
+ go func() {
+ server.ListenAndServe()
+ waitGroup.Done()
+ }()
+
+ return testDb, server, &waitGroup
+}
+
+func destroy(conn *sql.DB, path string) {
+ conn.Close()
+ os.Remove(path)
+}
+
+func TestWhenCNAMEIsResolved(t *testing.T) {
+ t.Log("TestWhenCNAMEIsResolved")
+
+ testDb, server, _ := setup(testDBPath)
+ defer destroy(testDb, testDBPath)
+ defer server.Shutdown()
+
+ cname := &database.DNSRecord{
+ ID: "1",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "res.example.com.",
+ TTL: 300,
+ Internal: true,
+ }
+ a := &database.DNSRecord{
+ ID: "2",
+ UserID: "test",
+ Name: "res.example.com.",
+ Type: "A",
+ Content: "127.0.0.1",
+ TTL: 300,
+ Internal: true,
+ }
+ database.SaveDNSRecord(testDb, cname)
+ database.SaveDNSRecord(testDb, a)
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn(cname.Name)
+ client := new(dns.Client)
+ message := new(dns.Msg)
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, address)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 2 {
+ t.Fatalf("expected 2 answers, got %d", len(in.Answer))
+ }
+
+ if in.Answer[0].Header().Name != cname.Name {
+ t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
+ }
+
+ if in.Answer[1].Header().Name != a.Name {
+ t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name)
+ }
+
+ if in.Answer[0].(*dns.CNAME).Target != a.Name {
+ t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
+ }
+
+ if in.Answer[1].(*dns.A).A.String() != a.Content {
+ t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String())
+ }
+
+ if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
+ t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
+ }
+
+ if in.Answer[1].Header().Rrtype != dns.TypeA {
+ t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype)
+ }
+
+ if int(in.Answer[0].Header().Ttl) != cname.TTL {
+ t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl)
+ }
+
+ if !in.Authoritative {
+ t.Fatalf("expected authoritative response")
+ }
+}
+
+func TestWhenNoRecordNxDomain(t *testing.T) {
+ t.Log("TestWhenNoRecordNxDomain")
+
+ testDb, server, _ := setup(testDBPath)
+ defer destroy(testDb, testDBPath)
+ defer server.Shutdown()
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("nonexistant.example.com.")
+ client := new(dns.Client)
+ message := new(dns.Msg)
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, address)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 0 {
+ t.Fatalf("expected 0 answers, got %d", len(in.Answer))
+ }
+
+ if in.Rcode != dns.RcodeNameError {
+ t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
+ }
+}
+
+func TestWhenUnresolvingCNAME(t *testing.T) {
+ t.Log("TestWhenUnresolvingCNAME")
+
+ testDb, server, _ := setup(testDBPath)
+ defer destroy(testDb, testDBPath)
+ defer server.Shutdown()
+
+ cname := &database.DNSRecord{
+ ID: "1",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "nonexistant.example.com.",
+ TTL: 300,
+ Internal: true,
+ }
+ database.SaveDNSRecord(testDb, cname)
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn(cname.Name)
+ client := new(dns.Client)
+ message := new(dns.Msg)
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, address)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 1 {
+ t.Fatalf("expected 1 answer, got %d", len(in.Answer))
+ }
+
+ if !in.Authoritative {
+ t.Fatalf("expected authoritative response")
+ }
+
+ if in.Answer[0].Header().Name != cname.Name {
+ t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
+ }
+
+ if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
+ t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
+ }
+
+ if in.Answer[0].(*dns.CNAME).Target != cname.Content {
+ t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
+ }
+
+ if in.Rcode == dns.RcodeNameError {
+ t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
+ }
+}
+
+func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
+ t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
+
+ testDb, server, _ := setup(testDBPath)
+ defer destroy(testDb, testDBPath)
+ defer server.Shutdown()
+
+ cname := &database.DNSRecord{
+ ID: "1",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "cname.internal.example.com.",
+ TTL: 300,
+ Internal: true,
+ }
+ database.SaveDNSRecord(testDb, cname)
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn(cname.Name)
+ client := new(dns.Client)
+ message := new(dns.Msg)
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, address)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) > 0 {
+ t.Fatalf("expected 0 answers, got %d", len(in.Answer))
+ }
+ if in.Rcode != dns.RcodeServerFailure {
+ t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
+ }
+}