diff options
Diffstat (limited to 'test')
| -rw-r--r-- | test/dns_test.go | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/test/dns_test.go b/test/dns_test.go new file mode 100644 index 0000000..ce6deb5 --- /dev/null +++ b/test/dns_test.go @@ -0,0 +1,244 @@ +package hcdns + +import ( + "database/sql" + "os" + "sync" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" + "github.com/miekg/dns" +) + +const ( + testDBPath = "test.db" + address = "127.0.0.1:8353" + dnsPort = 8353 +) + +func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) { + testDb := database.MakeConn(&dbPath) + database.Migrate(testDb) + testUser := &database.User{ + ID: "test", + } + database.FindOrSaveUser(testDb, testUser) + + server := hcdns.MakeServer(&args.Arguments{ + DnsPort: dnsPort, + }, testDb) + + waitGroup := sync.WaitGroup{} + waitGroup.Add(1) + go func() { + server.ListenAndServe() + waitGroup.Done() + }() + + return testDb, server, &waitGroup +} + +func destroy(conn *sql.DB, path string) { + conn.Close() + os.Remove(path) +} + +func TestWhenCNAMEIsResolved(t *testing.T) { + t.Log("TestWhenCNAMEIsResolved") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "res.example.com.", + TTL: 300, + Internal: true, + } + a := &database.DNSRecord{ + ID: "2", + UserID: "test", + Name: "res.example.com.", + Type: "A", + Content: "127.0.0.1", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + database.SaveDNSRecord(testDb, a) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 2 { + t.Fatalf("expected 2 answers, got %d", len(in.Answer)) + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[1].Header().Name != a.Name { + t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) + } + + if in.Answer[0].(*dns.CNAME).Target != a.Name { + t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Answer[1].(*dns.A).A.String() != a.Content { + t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String()) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[1].Header().Rrtype != dns.TypeA { + t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype) + } + + if int(in.Answer[0].Header().Ttl) != cname.TTL { + t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } +} + +func TestWhenNoRecordNxDomain(t *testing.T) { + t.Log("TestWhenNoRecordNxDomain") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + qtype := dns.TypeA + domain := dns.Fqdn("nonexistant.example.com.") + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAME(t *testing.T) { + t.Log("TestWhenUnresolvingCNAME") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "nonexistant.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(in.Answer)) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[0].(*dns.CNAME).Target != cname.Content { + t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Rcode == dns.RcodeNameError { + t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { + t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "cname.internal.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) > 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + if in.Rcode != dns.RcodeServerFailure { + t.Fatalf("expected SERVFAIL, got %d", in.Rcode) + } +} |
