summaryrefslogtreecommitdiff
path: root/hcdns/server_test.go
diff options
context:
space:
mode:
authorElizabeth <elizabeth@simponic.xyz>2024-04-03 15:33:02 -0600
committerElizabeth <elizabeth@simponic.xyz>2024-04-03 15:33:02 -0600
commitda6b6011fc8a73af7d0feb32f116e6b10de11b44 (patch)
treef7d6fefdc11352cd98aba5106aad8504aa87d3ba /hcdns/server_test.go
parentcc33a90bfd455f36169b01b0cca064cd35e2524f (diff)
downloadhatecomputers.club-da6b6011fc8a73af7d0feb32f116e6b10de11b44.tar.gz
hatecomputers.club-da6b6011fc8a73af7d0feb32f116e6b10de11b44.zip
refactor dns server test a bit
Diffstat (limited to 'hcdns/server_test.go')
-rw-r--r--hcdns/server_test.go254
1 files changed, 254 insertions, 0 deletions
diff --git a/hcdns/server_test.go b/hcdns/server_test.go
new file mode 100644
index 0000000..177def4
--- /dev/null
+++ b/hcdns/server_test.go
@@ -0,0 +1,254 @@
+package hcdns_test
+
+import (
+ "database/sql"
+ "fmt"
+ "math/rand"
+ "os"
+ "sync"
+ "testing"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+ "github.com/miekg/dns"
+)
+
+func randomPort() int {
+ return rand.Intn(3000) + 5192
+}
+
+func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
+ randomDb := utils.RandomId()
+ dnsPort := randomPort()
+
+ testDb := database.MakeConn(&randomDb)
+ database.Migrate(testDb)
+ testUser := &database.User{
+ ID: "test",
+ }
+ database.FindOrSaveUser(testDb, testUser)
+
+ waitLock := &sync.Mutex{}
+ server := hcdns.MakeServer(&args.Arguments{
+ DnsPort: dnsPort,
+ }, testDb)
+ server.NotifyStartedFunc = func() {
+ waitLock.Unlock()
+ }
+ waitLock.Lock()
+
+ go func() {
+ server.ListenAndServe()
+ }()
+ waitLock.Lock()
+
+ address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
+ return testDb, server, &address, waitLock, func() {
+ server.Shutdown()
+
+ testDb.Close()
+ os.Remove(randomDb)
+ }
+}
+
+func TestWhenCNAMEIsResolved(t *testing.T) {
+ t.Log("TestWhenCNAMEIsResolved")
+
+ testDb, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ records := []*database.DNSRecord{
+ {
+ ID: "0",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "next.internal.example.com.",
+ TTL: 300,
+ Internal: true,
+ }, {
+ ID: "1",
+ UserID: "test",
+ Name: "next.internal.example.com.",
+ Type: "CNAME",
+ Content: "res.example.com.",
+ TTL: 300,
+ Internal: true,
+ },
+ {
+ ID: "2",
+ UserID: "test",
+ Name: "res.example.com.",
+ Type: "A",
+ Content: "1.2.3.2",
+ TTL: 300,
+ Internal: true,
+ },
+ }
+
+ for _, record := range records {
+ database.SaveDNSRecord(testDb, record)
+ }
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("cname.internal.example.com.")
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 3 {
+ t.Fatalf("expected 3 answers, got %d", len(in.Answer))
+ }
+
+ for i, record := range records {
+ if in.Answer[i].Header().Name != record.Name {
+ t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
+ }
+
+ if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
+ t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
+ }
+
+ if int(in.Answer[i].Header().Ttl) != record.TTL {
+ t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
+ }
+
+ if !in.Authoritative {
+ t.Fatalf("expected authoritative response")
+ }
+ }
+
+ if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
+ t.Fatalf("expected final record to be the A record with correct IP")
+ }
+}
+
+func TestWhenNoRecordNxDomain(t *testing.T) {
+ t.Log("TestWhenNoRecordNxDomain")
+
+ _, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("nonexistant.example.com.")
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 0 {
+ t.Fatalf("expected 0 answers, got %d", len(in.Answer))
+ }
+
+ if in.Rcode != dns.RcodeNameError {
+ t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
+ }
+}
+
+func TestWhenUnresolvingCNAME(t *testing.T) {
+ t.Log("TestWhenUnresolvingCNAME")
+
+ testDb, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ cname := &database.DNSRecord{
+ ID: "1",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "nonexistant.example.com.",
+ TTL: 300,
+ Internal: true,
+ }
+ database.SaveDNSRecord(testDb, cname)
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn(cname.Name)
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 1 {
+ t.Fatalf("expected 1 answer, got %d", len(in.Answer))
+ }
+
+ if !in.Authoritative {
+ t.Fatalf("expected authoritative response")
+ }
+
+ if in.Answer[0].Header().Name != cname.Name {
+ t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
+ }
+
+ if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
+ t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
+ }
+
+ if in.Answer[0].(*dns.CNAME).Target != cname.Content {
+ t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
+ }
+
+ if in.Rcode == dns.RcodeNameError {
+ t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
+ }
+}
+
+func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
+ t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
+
+ testDb, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ cname := &database.DNSRecord{
+ ID: "1",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "cname.internal.example.com.",
+ TTL: 300,
+ Internal: true,
+ }
+ database.SaveDNSRecord(testDb, cname)
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn(cname.Name)
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) > 0 {
+ t.Fatalf("expected 0 answers, got %d", len(in.Answer))
+ }
+
+ if in.Rcode != dns.RcodeServerFailure {
+ t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
+ }
+}