diff options
| author | simponic <simponic@hatecomputers.club> | 2024-04-06 15:43:18 -0400 |
|---|---|---|
| committer | simponic <simponic@hatecomputers.club> | 2024-04-06 15:43:18 -0400 |
| commit | 83cc6267fd5ce2f61200314424c5f400f65ff2ba (patch) | |
| tree | eafb35310236a15572cbb6e16ff8d6f181bfe240 /hcdns/server_test.go | |
| parent | 569d2788ebfb90774faf361f62bfe7968e091465 (diff) | |
| parent | cad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 (diff) | |
| download | hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.tar.gz hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.zip | |
Merge pull request 'testing | dont be recursive for external domains | finalize oauth' (#5) from dont-be-authoritative into main
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/5
Diffstat (limited to 'hcdns/server_test.go')
| -rw-r--r-- | hcdns/server_test.go | 254 |
1 files changed, 254 insertions, 0 deletions
diff --git a/hcdns/server_test.go b/hcdns/server_test.go new file mode 100644 index 0000000..177def4 --- /dev/null +++ b/hcdns/server_test.go @@ -0,0 +1,254 @@ +package hcdns_test + +import ( + "database/sql" + "fmt" + "math/rand" + "os" + "sync" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "github.com/miekg/dns" +) + +func randomPort() int { + return rand.Intn(3000) + 5192 +} + +func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { + randomDb := utils.RandomId() + dnsPort := randomPort() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + testUser := &database.User{ + ID: "test", + } + database.FindOrSaveUser(testDb, testUser) + + waitLock := &sync.Mutex{} + server := hcdns.MakeServer(&args.Arguments{ + DnsPort: dnsPort, + }, testDb) + server.NotifyStartedFunc = func() { + waitLock.Unlock() + } + waitLock.Lock() + + go func() { + server.ListenAndServe() + }() + waitLock.Lock() + + address := fmt.Sprintf("127.0.0.1:%d", dnsPort) + return testDb, server, &address, waitLock, func() { + server.Shutdown() + + testDb.Close() + os.Remove(randomDb) + } +} + +func TestWhenCNAMEIsResolved(t *testing.T) { + t.Log("TestWhenCNAMEIsResolved") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + records := []*database.DNSRecord{ + { + ID: "0", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "next.internal.example.com.", + TTL: 300, + Internal: true, + }, { + ID: "1", + UserID: "test", + Name: "next.internal.example.com.", + Type: "CNAME", + Content: "res.example.com.", + TTL: 300, + Internal: true, + }, + { + ID: "2", + UserID: "test", + Name: "res.example.com.", + Type: "A", + Content: "1.2.3.2", + TTL: 300, + Internal: true, + }, + } + + for _, record := range records { + database.SaveDNSRecord(testDb, record) + } + + qtype := dns.TypeA + domain := dns.Fqdn("cname.internal.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 3 { + t.Fatalf("expected 3 answers, got %d", len(in.Answer)) + } + + for i, record := range records { + if in.Answer[i].Header().Name != record.Name { + t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name) + } + + if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] { + t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype) + } + + if int(in.Answer[i].Header().Ttl) != record.TTL { + t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + } + + if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" { + t.Fatalf("expected final record to be the A record with correct IP") + } +} + +func TestWhenNoRecordNxDomain(t *testing.T) { + t.Log("TestWhenNoRecordNxDomain") + + _, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + qtype := dns.TypeA + domain := dns.Fqdn("nonexistant.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAME(t *testing.T) { + t.Log("TestWhenUnresolvingCNAME") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "nonexistant.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(in.Answer)) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[0].(*dns.CNAME).Target != cname.Content { + t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Rcode == dns.RcodeNameError { + t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { + t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "cname.internal.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) > 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeServerFailure { + t.Fatalf("expected SERVFAIL, got %d", in.Rcode) + } +} |
