From 5768f07ce51271239b16b4cfda6206366002cefc Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Sun, 7 Apr 2024 17:04:43 -0600 Subject: init --- hcdns/server_test.go | 124 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 97 insertions(+), 27 deletions(-) (limited to 'hcdns/server_test.go') diff --git a/hcdns/server_test.go b/hcdns/server_test.go index 177def4..9993bbf 100644 --- a/hcdns/server_test.go +++ b/hcdns/server_test.go @@ -19,9 +19,8 @@ func randomPort() int { return rand.Intn(3000) + 5192 } -func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { +func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) { randomDb := utils.RandomId() - dnsPort := randomPort() testDb := database.MakeConn(&randomDb) database.Migrate(testDb) @@ -30,10 +29,15 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { } database.FindOrSaveUser(testDb, testUser) + dnsArguments := arguments + if dnsArguments == nil { + dnsArguments = &args.Arguments{ + DnsPort: randomPort(), + } + } + waitLock := &sync.Mutex{} - server := hcdns.MakeServer(&args.Arguments{ - DnsPort: dnsPort, - }, testDb) + server := hcdns.MakeServer(dnsArguments, testDb) server.NotifyStartedFunc = func() { waitLock.Unlock() } @@ -44,8 +48,9 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { }() waitLock.Lock() - address := fmt.Sprintf("127.0.0.1:%d", dnsPort) - return testDb, server, &address, waitLock, func() { + address := fmt.Sprintf("127.0.0.1:%d", dnsArguments.DnsPort) + return testDb, server, address, func() { + waitLock.Unlock() server.Shutdown() testDb.Close() @@ -53,12 +58,86 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { } } -func TestWhenCNAMEIsResolved(t *testing.T) { - t.Log("TestWhenCNAMEIsResolved") +func TestWhenExternalDomain(t *testing.T) { + externalDb, _, externalAddr, externalCleanup := setup(nil) + internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{ + DnsPort: randomPort(), + DnsResolvers: []string{externalAddr}, + }) + defer internalCleanup() + defer externalCleanup() + + authoritativeRecords := []database.DNSRecord{ + { + ID: "1", + UserID: "test", + Name: "external.example.com.", + Type: "CNAME", + Content: "external.internal.example.com.", + }, + } + internalRecords := []database.DNSRecord{ + { + ID: "1", + UserID: "test", + Name: "external.internal.example.com.", + Type: "A", + Content: "127.0.0.1", + }, + { + ID: "2", + UserID: "test", + Name: "test.internal.example.com.", + Type: "CNAME", + Content: "external.example.com.", + }, + } + + for _, record := range authoritativeRecords { + database.SaveDNSRecord(externalDb, &record) + } + for _, record := range internalRecords { + database.SaveDNSRecord(internalDb, &record) + } - testDb, _, addr, lock, cleanup := setup() + // ensure that if the record doesn't exist in the internal database, it will + // go and query the external dns resolvers, then loop back to the internal + + qtype := dns.TypeA + domain := dns.Fqdn("test.internal.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, internalAddr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 3 { + t.Fatalf("expected 3 answers, got %d", len(in.Answer)) + } + + aRecord := in.Answer[2] + if aRecord.Header().Name != internalRecords[0].Name { + t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name) + } + if aRecord.Header().Rrtype != dns.TypeA { + t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type) + } + if aRecord.(*dns.A).A.String() != internalRecords[0].Content { + t.Fatalf("expected %s, got %s", internalRecords[0].Content, aRecord.(*dns.A).A.String()) + } + + if in.Authoritative { + t.Fatalf("expected non-authoritative response") + } +} + +func TestWhenCNAMEIsResolved(t *testing.T) { + testDb, _, addr, cleanup := setup(nil) defer cleanup() - defer lock.Unlock() records := []*database.DNSRecord{ { @@ -99,7 +178,7 @@ func TestWhenCNAMEIsResolved(t *testing.T) { message := &dns.Msg{} message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, *addr) + in, _, err := client.Exchange(message, addr) if err != nil { t.Fatal(err) } @@ -132,11 +211,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) { } func TestWhenNoRecordNxDomain(t *testing.T) { - t.Log("TestWhenNoRecordNxDomain") - - _, _, addr, lock, cleanup := setup() + _, _, addr, cleanup := setup(nil) defer cleanup() - defer lock.Unlock() qtype := dns.TypeA domain := dns.Fqdn("nonexistant.example.com.") @@ -144,7 +220,7 @@ func TestWhenNoRecordNxDomain(t *testing.T) { message := &dns.Msg{} message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, *addr) + in, _, err := client.Exchange(message, addr) if err != nil { t.Fatal(err) @@ -160,11 +236,8 @@ func TestWhenNoRecordNxDomain(t *testing.T) { } func TestWhenUnresolvingCNAME(t *testing.T) { - t.Log("TestWhenUnresolvingCNAME") - - testDb, _, addr, lock, cleanup := setup() + testDb, _, addr, cleanup := setup(nil) defer cleanup() - defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -183,7 +256,7 @@ func TestWhenUnresolvingCNAME(t *testing.T) { message := &dns.Msg{} message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, *addr) + in, _, err := client.Exchange(message, addr) if err != nil { t.Fatal(err) @@ -215,11 +288,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) { } func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { - t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") - - testDb, _, addr, lock, cleanup := setup() + testDb, _, addr, cleanup := setup(nil) defer cleanup() - defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -238,7 +308,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { message := &dns.Msg{} message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, *addr) + in, _, err := client.Exchange(message, addr) if err != nil { t.Fatal(err) -- cgit v1.2.3-70-g09d2