summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsimponic <simponic@hatecomputers.club>2024-04-07 21:08:43 -0400
committersimponic <simponic@hatecomputers.club>2024-04-07 21:08:43 -0400
commit86c4ad160a0442713680ff1eaa85ead635b10f8f (patch)
treefa0303581fdfc166ffaf3c9a482434a2ffa34e3e
parent83cc6267fd5ce2f61200314424c5f400f65ff2ba (diff)
parente2ce6804a76c771759603e3b3800a013275217a1 (diff)
downloadhatecomputers.club-86c4ad160a0442713680ff1eaa85ead635b10f8f.tar.gz
hatecomputers.club-86c4ad160a0442713680ff1eaa85ead635b10f8f.zip
Merge pull request 'reimplement-recursive-resolver' (#6) from reimplement-recursive-resolver into main
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/6
-rw-r--r--Dockerfile2
-rw-r--r--args/args.go12
-rw-r--r--hcdns/server.go123
-rw-r--r--hcdns/server_test.go130
4 files changed, 208 insertions, 59 deletions
diff --git a/Dockerfile b/Dockerfile
index 591423f..82f411a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
EXPOSE 8080
-CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]
+CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-resolvers", "1.1.1.1,1.0.0.1"]
diff --git a/args/args.go b/args/args.go
index f71e8e3..8465fc8 100644
--- a/args/args.go
+++ b/args/args.go
@@ -22,8 +22,9 @@ type Arguments struct {
OauthConfig *oauth2.Config
OauthUserInfoURI string
- Dns bool
- DnsPort int
+ DnsResolvers []string
+ Dns bool
+ DnsPort int
CloudflareToken string
CloudflareZone string
@@ -36,6 +37,7 @@ func GetArgs() (*Arguments, error) {
databasePath := flag.String("database-path", "./hatecomputers.db", "Path to the SQLite database")
templatePath := flag.String("template-path", "./templates", "Path to the template directory")
staticPath := flag.String("static-path", "./static", "Path to the static directory")
+ dnsResolvers := flag.String("dns-resolvers", "1.1.1.1,1.0.0.1", "Comma-separated list of DNS resolvers")
scheduler := flag.Bool("scheduler", false, "Run scheduled jobs via cron")
migrate := flag.Bool("migrate", false, "Run the migrations")
@@ -101,8 +103,10 @@ func GetArgs() (*Arguments, error) {
Server: *server,
Migrate: *migrate,
Scheduler: *scheduler,
- Dns: *dns,
- DnsPort: *dnsPort,
+
+ Dns: *dns,
+ DnsPort: *dnsPort,
+ DnsResolvers: strings.Split(*dnsResolvers, ","),
OauthConfig: oauthConfig,
OauthUserInfoURI: oauthUserInfoURI,
diff --git a/hcdns/server.go b/hcdns/server.go
index ce7894b..2e110e8 100644
--- a/hcdns/server.go
+++ b/hcdns/server.go
@@ -11,74 +11,142 @@ import (
const MAX_RECURSION = 15
-func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
- internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
+type DnsHandler struct {
+ DnsResolvers []string
+ DbConn *sql.DB
+}
+
+func (h *DnsHandler) resolveExternal(domain string, qtype uint16) ([]dns.RR, error) {
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(dns.Fqdn(domain), qtype)
+ message.RecursionDesired = true
+
+ if len(h.DnsResolvers) == 0 {
+ return []dns.RR{}, nil
+ }
+
+ i := 0
+ in, _, err := client.Exchange(message, h.DnsResolvers[i])
+ for err != nil && i < len(h.DnsResolvers) {
+ i++
+ in, _, err = client.Exchange(message, h.DnsResolvers[i])
+ }
+
if err != nil {
return nil, err
}
+ if len(in.Answer) == 0 {
+ return nil, nil
+ }
+
+ return in.Answer, nil
+}
+
+func resultSetFound(answers []dns.RR, domain string, qtype uint16) bool {
+ for _, answer := range answers {
+ if answer.Header().Name == domain && answer.Header().Rrtype == qtype {
+ return true
+ }
+ }
+ return false
+}
+
+func (h *DnsHandler) recursiveResolve(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
+ internalCnames, err := database.FindDNSRecords(h.DbConn, domain, "CNAME")
+ if err != nil {
+ return nil, true, err
+ }
+
+ authoritative := true
var answers []dns.RR
for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil {
log.Println(err)
- return nil, err
+ return nil, authoritative, err
}
answers = append(answers, cname)
- cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
+ cnameRecursive, cnameAuth, err := h.resolveDNS(record.Content, qtype, maxDepth-1)
+ authoritative = authoritative && cnameAuth
if err != nil {
log.Println(err)
- return nil, err
+ return nil, authoritative, err
}
+
answers = append(answers, cnameRecursive...)
}
qtypeName := dns.TypeToString[qtype]
- if qtypeName == "" {
- return nil, fmt.Errorf("invalid query type %d", qtype)
- }
-
- typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
+ records, err := database.FindDNSRecords(h.DbConn, domain, qtypeName)
if err != nil {
- return nil, err
+ return nil, authoritative, err
}
- for _, record := range typeDnsRecords {
+
+ for _, record := range records {
answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil {
- return nil, err
+ return nil, authoritative, err
}
answers = append(answers, answer)
}
- return answers, nil
+ return answers, authoritative, nil
}
-func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+func (h *DnsHandler) resolveDNS(domain string, qtype uint16, maxDepth int) ([]dns.RR, bool, error) {
+ log.Println("resolving", domain, dns.TypeToString[qtype], maxDepth)
if maxDepth == 0 {
- return nil, fmt.Errorf("too much recursion")
+ return nil, false, fmt.Errorf("too much recursion")
}
- answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
+ answers, authoritative, err := h.recursiveResolve(domain, qtype, maxDepth)
if err != nil {
- return nil, err
+ return nil, false, err
}
- return answers, nil
-}
+ if len(answers) > 0 { // base case - we got the answer
+ return answers, authoritative, nil
+ }
-type DnsHandler struct {
- DnsResolvers []string
- DbConn *sql.DB
+ externalAnswers, err := h.resolveExternal(domain, qtype)
+ if err != nil {
+ return nil, false, err
+ }
+
+ answers = append(answers, externalAnswers...)
+ if resultSetFound(externalAnswers, domain, qtype) {
+ return answers, false, nil
+ }
+
+ for _, answer := range externalAnswers {
+ cname, ok := answer.(*dns.CNAME)
+ if !ok {
+ continue
+ }
+
+ cnameAnswers, cnameAuth, err := h.resolveDNS(cname.Target, qtype, maxDepth-1)
+ authoritative = authoritative && cnameAuth
+ if err != nil {
+ return nil, false, err
+ }
+ answers = append(answers, cnameAnswers...)
+ }
+
+ authoritative = authoritative && len(externalAnswers) == 0
+ return answers, authoritative, nil
}
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
- msg := new(dns.Msg)
+ msg := &dns.Msg{}
msg.SetReply(r)
- msg.Authoritative = true
+ msg.Authoritative = false
for _, question := range r.Question {
- answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
+ answers, authoritative, err := h.resolveDNS(question.Name, question.Qtype, MAX_RECURSION)
+ msg.Authoritative = authoritative
if err != nil {
fmt.Println(err)
msg.SetRcode(r, dns.RcodeServerFailure)
@@ -98,7 +166,8 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
handler := &DnsHandler{
- DbConn: dbConn,
+ DbConn: dbConn,
+ DnsResolvers: argv.DnsResolvers,
}
addr := fmt.Sprintf(":%d", argv.DnsPort)
diff --git a/hcdns/server_test.go b/hcdns/server_test.go
index 177def4..f1b283f 100644
--- a/hcdns/server_test.go
+++ b/hcdns/server_test.go
@@ -19,9 +19,8 @@ func randomPort() int {
return rand.Intn(3000) + 5192
}
-func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
+func setup(arguments *args.Arguments) (*sql.DB, *dns.Server, string, func()) {
randomDb := utils.RandomId()
- dnsPort := randomPort()
testDb := database.MakeConn(&randomDb)
database.Migrate(testDb)
@@ -30,10 +29,15 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
}
database.FindOrSaveUser(testDb, testUser)
+ dnsArguments := arguments
+ if dnsArguments == nil {
+ dnsArguments = &args.Arguments{
+ DnsPort: randomPort(),
+ }
+ }
+
waitLock := &sync.Mutex{}
- server := hcdns.MakeServer(&args.Arguments{
- DnsPort: dnsPort,
- }, testDb)
+ server := hcdns.MakeServer(dnsArguments, testDb)
server.NotifyStartedFunc = func() {
waitLock.Unlock()
}
@@ -44,8 +48,9 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
}()
waitLock.Lock()
- address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
- return testDb, server, &address, waitLock, func() {
+ address := fmt.Sprintf("127.0.0.1:%d", dnsArguments.DnsPort)
+ return testDb, server, address, func() {
+ waitLock.Unlock()
server.Shutdown()
testDb.Close()
@@ -54,11 +59,8 @@ func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
}
func TestWhenCNAMEIsResolved(t *testing.T) {
- t.Log("TestWhenCNAMEIsResolved")
-
- testDb, _, addr, lock, cleanup := setup()
+ testDb, _, addr, cleanup := setup(nil)
defer cleanup()
- defer lock.Unlock()
records := []*database.DNSRecord{
{
@@ -99,7 +101,7 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
- in, _, err := client.Exchange(message, *addr)
+ in, _, err := client.Exchange(message, addr)
if err != nil {
t.Fatal(err)
}
@@ -132,11 +134,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
}
func TestWhenNoRecordNxDomain(t *testing.T) {
- t.Log("TestWhenNoRecordNxDomain")
-
- _, _, addr, lock, cleanup := setup()
+ _, _, addr, cleanup := setup(nil)
defer cleanup()
- defer lock.Unlock()
qtype := dns.TypeA
domain := dns.Fqdn("nonexistant.example.com.")
@@ -144,7 +143,7 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
- in, _, err := client.Exchange(message, *addr)
+ in, _, err := client.Exchange(message, addr)
if err != nil {
t.Fatal(err)
@@ -160,11 +159,8 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
}
func TestWhenUnresolvingCNAME(t *testing.T) {
- t.Log("TestWhenUnresolvingCNAME")
-
- testDb, _, addr, lock, cleanup := setup()
+ testDb, _, addr, cleanup := setup(nil)
defer cleanup()
- defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -183,7 +179,7 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
- in, _, err := client.Exchange(message, *addr)
+ in, _, err := client.Exchange(message, addr)
if err != nil {
t.Fatal(err)
@@ -215,11 +211,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
}
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
- t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
-
- testDb, _, addr, lock, cleanup := setup()
+ testDb, _, addr, cleanup := setup(nil)
defer cleanup()
- defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -238,7 +231,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
message := &dns.Msg{}
message.SetQuestion(domain, qtype)
- in, _, err := client.Exchange(message, *addr)
+ in, _, err := client.Exchange(message, addr)
if err != nil {
t.Fatal(err)
@@ -252,3 +245,86 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
}
}
+
+func TestWhenExternalDomain(t *testing.T) {
+ externalDb, _, externalAddr, externalCleanup := setup(nil)
+ internalDb, _, internalAddr, internalCleanup := setup(&args.Arguments{
+ DnsPort: randomPort(),
+ DnsResolvers: []string{externalAddr},
+ })
+ defer internalCleanup()
+ defer externalCleanup()
+
+ authoritativeRecords := []database.DNSRecord{
+ {
+ ID: "1",
+ UserID: "test",
+ Name: "external.example.com.",
+ Type: "CNAME",
+ Content: "external.internal.example.com.",
+ },
+ {
+ ID: "2",
+ UserID: "test",
+ Name: "final.example.com.",
+ Type: "A",
+ Content: "127.0.0.1",
+ },
+ }
+ internalRecords := []database.DNSRecord{
+ {
+ ID: "1",
+ UserID: "test",
+ Name: "external.internal.example.com.",
+ Type: "CNAME",
+ Content: "final.example.com",
+ },
+ {
+ ID: "2",
+ UserID: "test",
+ Name: "test.internal.example.com.",
+ Type: "CNAME",
+ Content: "external.example.com.",
+ },
+ }
+
+ for _, record := range authoritativeRecords {
+ database.SaveDNSRecord(externalDb, &record)
+ }
+ for _, record := range internalRecords {
+ database.SaveDNSRecord(internalDb, &record)
+ }
+
+ // ensure that if the record doesn't exist in the internal database, it will
+ // go and query the external dns resolvers, then loop back to the internal
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("test.internal.example.com.")
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, internalAddr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 4 {
+ t.Fatalf("expected 4 answers, got %d", len(in.Answer))
+ }
+
+ aRecord := in.Answer[3]
+ if aRecord.Header().Name != authoritativeRecords[1].Name {
+ t.Fatalf("expected %s, got %s", domain, aRecord.Header().Name)
+ }
+ if aRecord.Header().Rrtype != dns.TypeA {
+ t.Fatalf("expected %s, got %s", dns.TypeToString[aRecord.Header().Rrtype], internalRecords[1].Type)
+ }
+ if aRecord.(*dns.A).A.String() != authoritativeRecords[1].Content {
+ t.Fatalf("expected %s, got %s", authoritativeRecords[1].Content, aRecord.(*dns.A).A.String())
+ }
+ if in.Authoritative {
+ t.Fatalf("expected non-authoritative response")
+ }
+}