From bcdcc508ef4a0ae646937c91d0994a90bde719e1 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 16:26:39 -0600 Subject: add integration tests for dns server --- hcdns/server.go | 112 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 hcdns/server.go (limited to 'hcdns') diff --git a/hcdns/server.go b/hcdns/server.go new file mode 100644 index 0000000..ce7894b --- /dev/null +++ b/hcdns/server.go @@ -0,0 +1,112 @@ +package hcdns + +import ( + "database/sql" + "fmt" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "github.com/miekg/dns" + "log" +) + +const MAX_RECURSION = 15 + +func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") + if err != nil { + return nil, err + } + + var answers []dns.RR + for _, record := range internalCnames { + cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cname) + + cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cnameRecursive...) + } + + qtypeName := dns.TypeToString[qtype] + if qtypeName == "" { + return nil, fmt.Errorf("invalid query type %d", qtype) + } + + typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) + if err != nil { + return nil, err + } + for _, record := range typeDnsRecords { + answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) + if err != nil { + return nil, err + } + answers = append(answers, answer) + } + + return answers, nil +} + +func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + if maxDepth == 0 { + return nil, fmt.Errorf("too much recursion") + } + + answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) + if err != nil { + return nil, err + } + + return answers, nil +} + +type DnsHandler struct { + DnsResolvers []string + DbConn *sql.DB +} + +func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + msg := new(dns.Msg) + msg.SetReply(r) + msg.Authoritative = true + + for _, question := range r.Question { + answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) + if err != nil { + fmt.Println(err) + msg.SetRcode(r, dns.RcodeServerFailure) + w.WriteMsg(msg) + return + } + msg.Answer = append(msg.Answer, answers...) + } + + if len(msg.Answer) == 0 { + msg.SetRcode(r, dns.RcodeNameError) + } + + log.Println(msg.Answer) + w.WriteMsg(msg) +} + +func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { + handler := &DnsHandler{ + DbConn: dbConn, + } + addr := fmt.Sprintf(":%d", argv.DnsPort) + + return &dns.Server{ + Addr: addr, + Net: "udp", + Handler: handler, + UDPSize: 65535, + ReusePort: true, + } +} -- cgit v1.2.3-70-g09d2 From da6b6011fc8a73af7d0feb32f116e6b10de11b44 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Wed, 3 Apr 2024 15:33:02 -0600 Subject: refactor dns server test a bit --- hcdns/server_test.go | 254 +++++++++++++++++++++++++++++++++++++++++++++++++++ test/dns_test.go | 253 -------------------------------------------------- 2 files changed, 254 insertions(+), 253 deletions(-) create mode 100644 hcdns/server_test.go delete mode 100644 test/dns_test.go (limited to 'hcdns') diff --git a/hcdns/server_test.go b/hcdns/server_test.go new file mode 100644 index 0000000..177def4 --- /dev/null +++ b/hcdns/server_test.go @@ -0,0 +1,254 @@ +package hcdns_test + +import ( + "database/sql" + "fmt" + "math/rand" + "os" + "sync" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "github.com/miekg/dns" +) + +func randomPort() int { + return rand.Intn(3000) + 5192 +} + +func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { + randomDb := utils.RandomId() + dnsPort := randomPort() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + testUser := &database.User{ + ID: "test", + } + database.FindOrSaveUser(testDb, testUser) + + waitLock := &sync.Mutex{} + server := hcdns.MakeServer(&args.Arguments{ + DnsPort: dnsPort, + }, testDb) + server.NotifyStartedFunc = func() { + waitLock.Unlock() + } + waitLock.Lock() + + go func() { + server.ListenAndServe() + }() + waitLock.Lock() + + address := fmt.Sprintf("127.0.0.1:%d", dnsPort) + return testDb, server, &address, waitLock, func() { + server.Shutdown() + + testDb.Close() + os.Remove(randomDb) + } +} + +func TestWhenCNAMEIsResolved(t *testing.T) { + t.Log("TestWhenCNAMEIsResolved") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + records := []*database.DNSRecord{ + { + ID: "0", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "next.internal.example.com.", + TTL: 300, + Internal: true, + }, { + ID: "1", + UserID: "test", + Name: "next.internal.example.com.", + Type: "CNAME", + Content: "res.example.com.", + TTL: 300, + Internal: true, + }, + { + ID: "2", + UserID: "test", + Name: "res.example.com.", + Type: "A", + Content: "1.2.3.2", + TTL: 300, + Internal: true, + }, + } + + for _, record := range records { + database.SaveDNSRecord(testDb, record) + } + + qtype := dns.TypeA + domain := dns.Fqdn("cname.internal.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 3 { + t.Fatalf("expected 3 answers, got %d", len(in.Answer)) + } + + for i, record := range records { + if in.Answer[i].Header().Name != record.Name { + t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name) + } + + if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] { + t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype) + } + + if int(in.Answer[i].Header().Ttl) != record.TTL { + t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + } + + if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" { + t.Fatalf("expected final record to be the A record with correct IP") + } +} + +func TestWhenNoRecordNxDomain(t *testing.T) { + t.Log("TestWhenNoRecordNxDomain") + + _, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + qtype := dns.TypeA + domain := dns.Fqdn("nonexistant.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAME(t *testing.T) { + t.Log("TestWhenUnresolvingCNAME") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "nonexistant.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(in.Answer)) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[0].(*dns.CNAME).Target != cname.Content { + t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Rcode == dns.RcodeNameError { + t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { + t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "cname.internal.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) > 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeServerFailure { + t.Fatalf("expected SERVFAIL, got %d", in.Rcode) + } +} diff --git a/test/dns_test.go b/test/dns_test.go deleted file mode 100644 index d875f3f..0000000 --- a/test/dns_test.go +++ /dev/null @@ -1,253 +0,0 @@ -package hcdns - -import ( - "database/sql" - "fmt" - "math/rand" - "os" - "sync" - "testing" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" - "github.com/miekg/dns" -) - -func randomPort() int { - return rand.Intn(3000) + 1024 -} - -func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { - randomDb := utils.RandomId() - dnsPort := randomPort() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - testUser := &database.User{ - ID: "test", - } - database.FindOrSaveUser(testDb, testUser) - - waitLock := &sync.Mutex{} - server := hcdns.MakeServer(&args.Arguments{ - DnsPort: dnsPort, - }, testDb) - server.NotifyStartedFunc = func() { - waitLock.Unlock() - } - waitLock.Lock() - - go func() { - server.ListenAndServe() - }() - waitLock.Lock() - - address := fmt.Sprintf("127.0.0.1:%d", dnsPort) - return testDb, server, &address, waitLock, func() { - server.Shutdown() - - testDb.Close() - os.Remove(randomDb) - } -} - -func TestWhenCNAMEIsResolved(t *testing.T) { - t.Log("TestWhenCNAMEIsResolved") - - testDb, _, addr, lock, cleanup := setup() - defer cleanup() - defer lock.Unlock() - - cname := &database.DNSRecord{ - ID: "1", - UserID: "test", - Name: "cname.internal.example.com.", - Type: "CNAME", - Content: "res.example.com.", - TTL: 300, - Internal: true, - } - a := &database.DNSRecord{ - ID: "2", - UserID: "test", - Name: "res.example.com.", - Type: "A", - Content: "127.0.0.1", - TTL: 300, - Internal: true, - } - database.SaveDNSRecord(testDb, cname) - database.SaveDNSRecord(testDb, a) - - qtype := dns.TypeA - domain := dns.Fqdn(cname.Name) - client := &dns.Client{} - message := &dns.Msg{} - message.SetQuestion(domain, qtype) - - in, _, err := client.Exchange(message, *addr) - - if err != nil { - t.Fatal(err) - } - - if len(in.Answer) != 2 { - t.Fatalf("expected 2 answers, got %d", len(in.Answer)) - } - - if in.Answer[0].Header().Name != cname.Name { - t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) - } - - if in.Answer[1].Header().Name != a.Name { - t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) - } - - if in.Answer[0].(*dns.CNAME).Target != a.Name { - t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) - } - - if in.Answer[1].(*dns.A).A.String() != a.Content { - t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String()) - } - - if in.Answer[0].Header().Rrtype != dns.TypeCNAME { - t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) - } - - if in.Answer[1].Header().Rrtype != dns.TypeA { - t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype) - } - - if int(in.Answer[0].Header().Ttl) != cname.TTL { - t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl) - } - - if !in.Authoritative { - t.Fatalf("expected authoritative response") - } -} - -func TestWhenNoRecordNxDomain(t *testing.T) { - t.Log("TestWhenNoRecordNxDomain") - - _, _, addr, lock, cleanup := setup() - defer cleanup() - defer lock.Unlock() - - qtype := dns.TypeA - domain := dns.Fqdn("nonexistant.example.com.") - client := &dns.Client{} - message := &dns.Msg{} - message.SetQuestion(domain, qtype) - - in, _, err := client.Exchange(message, *addr) - - if err != nil { - t.Fatal(err) - } - - if len(in.Answer) != 0 { - t.Fatalf("expected 0 answers, got %d", len(in.Answer)) - } - - if in.Rcode != dns.RcodeNameError { - t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) - } -} - -func TestWhenUnresolvingCNAME(t *testing.T) { - t.Log("TestWhenUnresolvingCNAME") - - testDb, _, addr, lock, cleanup := setup() - defer cleanup() - defer lock.Unlock() - - cname := &database.DNSRecord{ - ID: "1", - UserID: "test", - Name: "cname.internal.example.com.", - Type: "CNAME", - Content: "nonexistant.example.com.", - TTL: 300, - Internal: true, - } - database.SaveDNSRecord(testDb, cname) - - qtype := dns.TypeA - domain := dns.Fqdn(cname.Name) - client := &dns.Client{} - message := &dns.Msg{} - message.SetQuestion(domain, qtype) - - in, _, err := client.Exchange(message, *addr) - - if err != nil { - t.Fatal(err) - } - - if len(in.Answer) != 1 { - t.Fatalf("expected 1 answer, got %d", len(in.Answer)) - } - - if !in.Authoritative { - t.Fatalf("expected authoritative response") - } - - if in.Answer[0].Header().Name != cname.Name { - t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) - } - - if in.Answer[0].Header().Rrtype != dns.TypeCNAME { - t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) - } - - if in.Answer[0].(*dns.CNAME).Target != cname.Content { - t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) - } - - if in.Rcode == dns.RcodeNameError { - t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) - } -} - -func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { - t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") - - testDb, _, addr, lock, cleanup := setup() - defer cleanup() - defer lock.Unlock() - - cname := &database.DNSRecord{ - ID: "1", - UserID: "test", - Name: "cname.internal.example.com.", - Type: "CNAME", - Content: "cname.internal.example.com.", - TTL: 300, - Internal: true, - } - database.SaveDNSRecord(testDb, cname) - - qtype := dns.TypeA - domain := dns.Fqdn(cname.Name) - client := &dns.Client{} - message := &dns.Msg{} - message.SetQuestion(domain, qtype) - - in, _, err := client.Exchange(message, *addr) - - if err != nil { - t.Fatal(err) - } - - if len(in.Answer) > 0 { - t.Fatalf("expected 0 answers, got %d", len(in.Answer)) - } - - if in.Rcode != dns.RcodeServerFailure { - t.Fatalf("expected SERVFAIL, got %d", in.Rcode) - } -} -- cgit v1.2.3-70-g09d2