summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorElizabeth Hunt <elizabeth.hunt@simponic.xyz>2024-04-07 19:02:42 -0600
committerElizabeth Hunt <elizabeth.hunt@simponic.xyz>2024-04-07 19:02:42 -0600
commite2ce6804a76c771759603e3b3800a013275217a1 (patch)
treefa0303581fdfc166ffaf3c9a482434a2ffa34e3e
parent5768f07ce51271239b16b4cfda6206366002cefc (diff)
downloadhatecomputers.club-e2ce6804a76c771759603e3b3800a013275217a1.tar.gz
hatecomputers.club-e2ce6804a76c771759603e3b3800a013275217a1.zip
be authoritative, but only when there's no external queries occuring
-rw-r--r--hcdns/server.go9
-rw-r--r--hcdns/server_test.go160
2 files changed, 89 insertions, 80 deletions
diff --git a/hcdns/server.go b/hcdns/server.go
index e5a8d29..2e110e8 100644
--- a/hcdns/server.go
+++ b/hcdns/server.go
@@ -70,11 +70,12 @@ func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int)
answers = append(answers, cname)
cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1)
+ authoritative = authoritative && cnameAuth
if err != nil {
log.Println(err)
return nil, authoritative, err
}
- authoritative = authoritative && cnameAuth
+
answers = append(answers, cnameRecursive...)
}
@@ -126,14 +127,16 @@ func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dn
continue
}
- cnameAnswers, _, err := h.resolveDNS(cname.Target, qtype, maxDepth-1)
+ cnameAnswers, cnameAuth, err := h.resolveDNS(cname.Target, qtype, maxDepth-1)
+ authoritative = authoritative && cnameAuth
if err != nil {
return nil, false, err
}
answers = append(answers, cnameAnswers...)
}
- return answers, false, nil
+ authoritative = authoritative && len(externalAnswers) == 0
+ return answers, authoritative, nil
}
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
diff --git a/hcdns/server_test.go b/hcdns/server_test.go
index 9993bbf..f1b283f 100644
--- a/hcdns/server_test.go
+++ b/hcdns/server_test.go
@@ -58,83 +58,6 @@ func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) {
}
}
-func TestWhenExternalDomain(t *testing.T) {
- externalDb, _, externalAddr, externalCleanup := setup(nil)
- internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{
- DnsPort: randomPort(),
- DnsResolvers: []string{externalAddr},
- })
- defer internalCleanup()
- defer externalCleanup()
-
- authoritativeRecords := []database.DNSRecord{
- {
- ID: "1",
- UserID: "test",
- Name: "external.example.com.",
- Type: "CNAME",
- Content: "external.internal.example.com.",
- },
- }
- internalRecords := []database.DNSRecord{
- {
- ID: "1",
- UserID: "test",
- Name: "external.internal.example.com.",
- Type: "A",
- Content: "127.0.0.1",
- },
- {
- ID: "2",
- UserID: "test",
- Name: "test.internal.example.com.",
- Type: "CNAME",
- Content: "external.example.com.",
- },
- }
-
- for _, record := range authoritativeRecords {
- database.SaveDNSRecord(externalDb, &record)
- }
- for _, record := range internalRecords {
- database.SaveDNSRecord(internalDb, &record)
- }
-
- // ensure that if the record doesn't exist in the internal database, it will
- // go and query the external dns resolvers, then loop back to the internal
-
- qtype := dns.TypeA
- domain := dns.Fqdn("test.internal.example.com.")
- client := &dns.Client{}
- message := &dns.Msg{}
- message.SetQuestion(domain, qtype)
-
- in, _, err := client.Exchange(message, internalAddr)
-
- if err != nil {
- t.Fatal(err)
- }
-
- if len(in.Answer) != 3 {
- t.Fatalf("expected 3 answers, got %d", len(in.Answer))
- }
-
- aRecord := in.Answer[2]
- if aRecord.Header().Name != internalRecords[0].Name {
- t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name)
- }
- if aRecord.Header().Rrtype != dns.TypeA {
- t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type)
- }
- if aRecord.(*dns.A).A.String() != internalRecords[0].Content {
- t.Fatalf("expected %s, got %s", internalRecords[0].Content, aRecord.(*dns.A).A.String())
- }
-
- if in.Authoritative {
- t.Fatalf("expected non-authoritative response")
- }
-}
-
func TestWhenCNAMEIsResolved(t *testing.T) {
testDb, _, addr, cleanup := setup(nil)
defer cleanup()
@@ -322,3 +245,86 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
}
}
+
+func TestWhenExternalDomain(t *testing.T) {
+ externalDb, _, externalAddr, externalCleanup := setup(nil)
+ internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{
+ DnsPort: randomPort(),
+ DnsResolvers: []string{externalAddr},
+ })
+ defer internalCleanup()
+ defer externalCleanup()
+
+ authoritativeRecords := []database.DNSRecord{
+ {
+ ID: "1",
+ UserID: "test",
+ Name: "external.example.com.",
+ Type: "CNAME",
+ Content: "external.internal.example.com.",
+ },
+ {
+ ID: "2",
+ UserID: "test",
+ Name: "final.example.com.",
+ Type: "A",
+ Content: "127.0.0.1",
+ },
+ }
+ internalRecords := []database.DNSRecord{
+ {
+ ID: "1",
+ UserID: "test",
+ Name: "external.internal.example.com.",
+ Type: "CNAME",
+ Content: "final.example.com",
+ },
+ {
+ ID: "2",
+ UserID: "test",
+ Name: "test.internal.example.com.",
+ Type: "CNAME",
+ Content: "external.example.com.",
+ },
+ }
+
+ for _, record := range authoritativeRecords {
+ database.SaveDNSRecord(externalDb, &record)
+ }
+ for _, record := range internalRecords {
+ database.SaveDNSRecord(internalDb, &record)
+ }
+
+ // ensure that if the record doesn't exist in the internal database, it will
+ // go and query the external dns resolvers, then loop back to the internal
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("test.internal.example.com.")
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, internalAddr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 4 {
+ t.Fatalf("expected 4 answers, got %d", len(in.Answer))
+ }
+
+ aRecord := in.Answer[3]
+ if aRecord.Header().Name != authoritativeRecords[1].Name {
+ t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name)
+ }
+ if aRecord.Header().Rrtype != dns.TypeA {
+ t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type)
+ }
+ if aRecord.(*dns.A).A.String() != authoritativeRecords[1].Content {
+ t.Fatalf("expected %s, got %s", authoritativeRecords[1].Content, aRecord.(*dns.A).A.String())
+ }
+ if in.Authoritative {
+ t.Fatalf("expected non-authoritative response")
+ }
+}