summaryrefslogtreecommitdiff
path: root/hcdns/server_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'hcdns/server_test.go')
-rw-r--r--hcdns/server_test.go124
1 files changed, 97 insertions, 27 deletions
diff --git a/hcdns/server_test.go b/hcdns/server_test.go
index 177def4..9993bbf 100644
--- a/hcdns/server_test.go
+++ b/hcdns/server_test.go
@@ -19,9 +19,8 @@ func randomPort() int {
return rand.Intn(3000) + 5192
}
-func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
+func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) {
randomDb := utils.RandomId()
- dnsPort := randomPort()
testDb := database.MakeConn(&randomDb)
database.Migrate(testDb)
@@ -30,10 +29,15 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
}
database.FindOrSaveUser(testDb, testUser)
+ dnsArguments := arguments
+ if dnsArguments == nil {
+ dnsArguments = &args.Arguments{
+ DnsPort: randomPort(),
+ }
+ }
+
waitLock := &sync.Mutex{}
- server := hcdns.MakeServer(&args.Arguments{
- DnsPort: dnsPort,
- }, testDb)
+ server := hcdns.MakeServer(dnsArguments, testDb)
server.NotifyStartedFunc = func() {
waitLock.Unlock()
}
@@ -44,8 +48,9 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
}()
waitLock.Lock()
- address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
- return testDb, server, &address, waitLock, func() {
+ address := fmt.Sprintf("127.0.0.1:%d", dnsArguments.DnsPort)
+ return testDb, server, address, func() {
+ waitLock.Unlock()
server.Shutdown()
testDb.Close()
@@ -53,12 +58,86 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
}
}
-func TestWhenCNAMEIsResolved(t *testing.T) {
- t.Log("TestWhenCNAMEIsResolved")
+func TestWhenExternalDomain(t *testing.T) {
+ externalDb, _, externalAddr, externalCleanup := setup(nil)
+ internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{
+ DnsPort: randomPort(),
+ DnsResolvers: []string{externalAddr},
+ })
+ defer internalCleanup()
+ defer externalCleanup()
+
+ authoritativeRecords := []database.DNSRecord{
+ {
+ ID: "1",
+ UserID: "test",
+ Name: "external.example.com.",
+ Type: "CNAME",
+ Content: "external.internal.example.com.",
+ },
+ }
+ internalRecords := []database.DNSRecord{
+ {
+ ID: "1",
+ UserID: "test",
+ Name: "external.internal.example.com.",
+ Type: "A",
+ Content: "127.0.0.1",
+ },
+ {
+ ID: "2",
+ UserID: "test",
+ Name: "test.internal.example.com.",
+ Type: "CNAME",
+ Content: "external.example.com.",
+ },
+ }
+
+ for _, record := range authoritativeRecords {
+ database.SaveDNSRecord(externalDb, &record)
+ }
+ for _, record := range internalRecords {
+ database.SaveDNSRecord(internalDb, &record)
+ }
- testDb, _, addr, lock, cleanup := setup()
+ // ensure that if the record doesn't exist in the internal database, it will
+ // go and query the external dns resolvers, then loop back to the internal
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("test.internal.example.com.")
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, internalAddr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 3 {
+ t.Fatalf("expected 3 answers, got %d", len(in.Answer))
+ }
+
+ aRecord := in.Answer[2]
+ if aRecord.Header().Name != internalRecords[0].Name {
+ t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name)
+ }
+ if aRecord.Header().Rrtype != dns.TypeA {
+ t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type)
+ }
+ if aRecord.(*dns.A).A.String() != internalRecords[0].Content {
+ t.Fatalf("expected %s, got %s", internalRecords[0].Content, aRecord.(*dns.A).A.String())
+ }
+
+ if in.Authoritative {
+ t.Fatalf("expected non-authoritative response")
+ }
+}
+
+func TestWhenCNAMEIsResolved(t *testing.T) {
+ testDb, _, addr, cleanup := setup(nil)
defer cleanup()
- defer lock.Unlock()
records := []*database.DNSRecord{
{
@@ -99,7 +178,7 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
- in, _, err := client.Exchange(message, *addr)
+ in, _, err := client.Exchange(message, addr)
if err != nil {
t.Fatal(err)
}
@@ -132,11 +211,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
}
func TestWhenNoRecordNxDomain(t *testing.T) {
- t.Log("TestWhenNoRecordNxDomain")
-
- _, _, addr, lock, cleanup := setup()
+ _, _, addr, cleanup := setup(nil)
defer cleanup()
- defer lock.Unlock()
qtype := dns.TypeA
domain := dns.Fqdn("nonexistant.example.com.")
@@ -144,7 +220,7 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
- in, _, err := client.Exchange(message, *addr)
+ in, _, err := client.Exchange(message, addr)
if err != nil {
t.Fatal(err)
@@ -160,11 +236,8 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
}
func TestWhenUnresolvingCNAME(t *testing.T) {
- t.Log("TestWhenUnresolvingCNAME")
-
- testDb, _, addr, lock, cleanup := setup()
+ testDb, _, addr, cleanup := setup(nil)
defer cleanup()
- defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -183,7 +256,7 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
- in, _, err := client.Exchange(message, *addr)
+ in, _, err := client.Exchange(message, addr)
if err != nil {
t.Fatal(err)
@@ -215,11 +288,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
}
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
- t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
-
- testDb, _, addr, lock, cleanup := setup()
+ testDb, _, addr, cleanup := setup(nil)
defer cleanup()
- defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -238,7 +308,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
- in, _, err := client.Exchange(message, *addr)
+ in, _, err := client.Exchange(message, addr)
if err != nil {
t.Fatal(err)