From bcdcc508ef4a0ae646937c91d0994a90bde719e1 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 16:26:39 -0600 Subject: add integration tests for dns server --- dns/server.go | 113 -------------------------- hcdns/server.go | 112 +++++++++++++++++++++++++ main.go | 4 +- test/dns_test.go | 244 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 358 insertions(+), 115 deletions(-) delete mode 100644 dns/server.go create mode 100644 hcdns/server.go create mode 100644 test/dns_test.go diff --git a/dns/server.go b/dns/server.go deleted file mode 100644 index 9b3e5e9..0000000 --- a/dns/server.go +++ /dev/null @@ -1,113 +0,0 @@ -package dns - -import ( - "database/sql" - "fmt" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "github.com/miekg/dns" - "log" -) - -const MAX_RECURSION = 10 - -func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { - internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") - if err != nil { - return nil, err - } - - var answers []dns.RR - for _, record := range internalCnames { - cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) - if err != nil { - return nil, err - } - answers = append(answers, cname) - - cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) - if err != nil { - return nil, err - } - answers = append(answers, cnameRecursive...) - } - - qtypeName := dns.TypeToString[qtype] - if qtypeName == "" { - return nil, fmt.Errorf("invalid query type %d", qtype) - } - - typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) - if err != nil { - return nil, err - } - for _, record := range typeDnsRecords { - answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) - if err != nil { - return nil, err - } - answers = append(answers, answer) - } - - return answers, nil -} - -func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { - if maxDepth == 0 { - return nil, fmt.Errorf("too much recursion") - } - - answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) - if err != nil { - return nil, err - } - - if len(answers) > 0 { - return answers, nil - } - - return nil, fmt.Errorf("no records found for %s", domain) -} - -type DnsHandler struct { - DnsResolvers []string - DbConn *sql.DB -} - -func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - msg := new(dns.Msg) - msg.SetReply(r) - msg.Authoritative = true - - for _, question := range r.Question { - answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) - if err != nil { - fmt.Println(err) - continue - } - msg.Answer = append(msg.Answer, answers...) - } - - if len(msg.Answer) == 0 { - msg.SetRcode(r, dns.RcodeNameError) - } - - log.Println(msg.Answer) - w.WriteMsg(msg) -} - -func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { - handler := &DnsHandler{ - DnsResolvers: argv.DnsRecursion, - DbConn: dbConn, - } - addr := fmt.Sprintf(":%d", argv.DnsPort) - - return &dns.Server{ - Addr: addr, - Net: "udp", - Handler: handler, - UDPSize: 65535, - ReusePort: true, - } -} diff --git a/hcdns/server.go b/hcdns/server.go new file mode 100644 index 0000000..ce7894b --- /dev/null +++ b/hcdns/server.go @@ -0,0 +1,112 @@ +package hcdns + +import ( + "database/sql" + "fmt" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "github.com/miekg/dns" + "log" +) + +const MAX_RECURSION = 15 + +func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") + if err != nil { + return nil, err + } + + var answers []dns.RR + for _, record := range internalCnames { + cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cname) + + cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cnameRecursive...) + } + + qtypeName := dns.TypeToString[qtype] + if qtypeName == "" { + return nil, fmt.Errorf("invalid query type %d", qtype) + } + + typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) + if err != nil { + return nil, err + } + for _, record := range typeDnsRecords { + answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) + if err != nil { + return nil, err + } + answers = append(answers, answer) + } + + return answers, nil +} + +func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + if maxDepth == 0 { + return nil, fmt.Errorf("too much recursion") + } + + answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) + if err != nil { + return nil, err + } + + return answers, nil +} + +type DnsHandler struct { + DnsResolvers []string + DbConn *sql.DB +} + +func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + msg := new(dns.Msg) + msg.SetReply(r) + msg.Authoritative = true + + for _, question := range r.Question { + answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) + if err != nil { + fmt.Println(err) + msg.SetRcode(r, dns.RcodeServerFailure) + w.WriteMsg(msg) + return + } + msg.Answer = append(msg.Answer, answers...) + } + + if len(msg.Answer) == 0 { + msg.SetRcode(r, dns.RcodeNameError) + } + + log.Println(msg.Answer) + w.WriteMsg(msg) +} + +func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { + handler := &DnsHandler{ + DbConn: dbConn, + } + addr := fmt.Sprintf(":%d", argv.DnsPort) + + return &dns.Server{ + Addr: addr, + Net: "udp", + Handler: handler, + UDPSize: 65535, + ReusePort: true, + } +} diff --git a/main.go b/main.go index 2991821..e0f3e55 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,7 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/dns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler" "github.com/joho/godotenv" ) @@ -52,7 +52,7 @@ func main() { } if argv.Dns { - server := dns.MakeServer(argv, dbConn) + server := hcdns.MakeServer(argv, dbConn) log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort) go func() { err = server.ListenAndServe() diff --git a/test/dns_test.go b/test/dns_test.go new file mode 100644 index 0000000..ce6deb5 --- /dev/null +++ b/test/dns_test.go @@ -0,0 +1,244 @@ +package hcdns + +import ( + "database/sql" + "os" + "sync" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" + "github.com/miekg/dns" +) + +const ( + testDBPath = "test.db" + address = "127.0.0.1:8353" + dnsPort = 8353 +) + +func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) { + testDb := database.MakeConn(&dbPath) + database.Migrate(testDb) + testUser := &database.User{ + ID: "test", + } + database.FindOrSaveUser(testDb, testUser) + + server := hcdns.MakeServer(&args.Arguments{ + DnsPort: dnsPort, + }, testDb) + + waitGroup := sync.WaitGroup{} + waitGroup.Add(1) + go func() { + server.ListenAndServe() + waitGroup.Done() + }() + + return testDb, server, &waitGroup +} + +func destroy(conn *sql.DB, path string) { + conn.Close() + os.Remove(path) +} + +func TestWhenCNAMEIsResolved(t *testing.T) { + t.Log("TestWhenCNAMEIsResolved") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "res.example.com.", + TTL: 300, + Internal: true, + } + a := &database.DNSRecord{ + ID: "2", + UserID: "test", + Name: "res.example.com.", + Type: "A", + Content: "127.0.0.1", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + database.SaveDNSRecord(testDb, a) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 2 { + t.Fatalf("expected 2 answers, got %d", len(in.Answer)) + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[1].Header().Name != a.Name { + t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) + } + + if in.Answer[0].(*dns.CNAME).Target != a.Name { + t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Answer[1].(*dns.A).A.String() != a.Content { + t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String()) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[1].Header().Rrtype != dns.TypeA { + t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype) + } + + if int(in.Answer[0].Header().Ttl) != cname.TTL { + t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } +} + +func TestWhenNoRecordNxDomain(t *testing.T) { + t.Log("TestWhenNoRecordNxDomain") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + qtype := dns.TypeA + domain := dns.Fqdn("nonexistant.example.com.") + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAME(t *testing.T) { + t.Log("TestWhenUnresolvingCNAME") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "nonexistant.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(in.Answer)) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[0].(*dns.CNAME).Target != cname.Content { + t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Rcode == dns.RcodeNameError { + t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { + t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "cname.internal.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) > 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + if in.Rcode != dns.RcodeServerFailure { + t.Fatalf("expected SERVFAIL, got %d", in.Rcode) + } +} -- cgit v1.2.3-70-g09d2