summaryrefslogtreecommitdiff
path: root/hcdns
diff options
context:
space:
mode:
Diffstat (limited to 'hcdns')
-rw-r--r--hcdns/server.go112
-rw-r--r--hcdns/server_test.go254
2 files changed, 366 insertions, 0 deletions
diff --git a/hcdns/server.go b/hcdns/server.go
new file mode 100644
index 0000000..ce7894b
--- /dev/null
+++ b/hcdns/server.go
@@ -0,0 +1,112 @@
+package hcdns
+
+import (
+ "database/sql"
+ "fmt"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "github.com/miekg/dns"
+ "log"
+)
+
+const MAX_RECURSION = 15
+
+func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
+ if err != nil {
+ return nil, err
+ }
+
+ var answers []dns.RR
+ for _, record := range internalCnames {
+ cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+ answers = append(answers, cname)
+
+ cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+ answers = append(answers, cnameRecursive...)
+ }
+
+ qtypeName := dns.TypeToString[qtype]
+ if qtypeName == "" {
+ return nil, fmt.Errorf("invalid query type %d", qtype)
+ }
+
+ typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
+ if err != nil {
+ return nil, err
+ }
+ for _, record := range typeDnsRecords {
+ answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
+ if err != nil {
+ return nil, err
+ }
+ answers = append(answers, answer)
+ }
+
+ return answers, nil
+}
+
+func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ if maxDepth == 0 {
+ return nil, fmt.Errorf("too much recursion")
+ }
+
+ answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
+ if err != nil {
+ return nil, err
+ }
+
+ return answers, nil
+}
+
+type DnsHandler struct {
+ DnsResolvers []string
+ DbConn *sql.DB
+}
+
+func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+ msg := new(dns.Msg)
+ msg.SetReply(r)
+ msg.Authoritative = true
+
+ for _, question := range r.Question {
+ answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
+ if err != nil {
+ fmt.Println(err)
+ msg.SetRcode(r, dns.RcodeServerFailure)
+ w.WriteMsg(msg)
+ return
+ }
+ msg.Answer = append(msg.Answer, answers...)
+ }
+
+ if len(msg.Answer) == 0 {
+ msg.SetRcode(r, dns.RcodeNameError)
+ }
+
+ log.Println(msg.Answer)
+ w.WriteMsg(msg)
+}
+
+func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
+ handler := &DnsHandler{
+ DbConn: dbConn,
+ }
+ addr := fmt.Sprintf(":%d", argv.DnsPort)
+
+ return &dns.Server{
+ Addr: addr,
+ Net: "udp",
+ Handler: handler,
+ UDPSize: 65535,
+ ReusePort: true,
+ }
+}
diff --git a/hcdns/server_test.go b/hcdns/server_test.go
new file mode 100644
index 0000000..177def4
--- /dev/null
+++ b/hcdns/server_test.go
@@ -0,0 +1,254 @@
+package hcdns_test
+
+import (
+ "database/sql"
+ "fmt"
+ "math/rand"
+ "os"
+ "sync"
+ "testing"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+ "github.com/miekg/dns"
+)
+
+func randomPort() int {
+ return rand.Intn(3000) + 5192
+}
+
+func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
+ randomDb := utils.RandomId()
+ dnsPort := randomPort()
+
+ testDb := database.MakeConn(&randomDb)
+ database.Migrate(testDb)
+ testUser := &database.User{
+ ID: "test",
+ }
+ database.FindOrSaveUser(testDb, testUser)
+
+ waitLock := &sync.Mutex{}
+ server := hcdns.MakeServer(&args.Arguments{
+ DnsPort: dnsPort,
+ }, testDb)
+ server.NotifyStartedFunc = func() {
+ waitLock.Unlock()
+ }
+ waitLock.Lock()
+
+ go func() {
+ server.ListenAndServe()
+ }()
+ waitLock.Lock()
+
+ address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
+ return testDb, server, &address, waitLock, func() {
+ server.Shutdown()
+
+ testDb.Close()
+ os.Remove(randomDb)
+ }
+}
+
+func TestWhenCNAMEIsResolved(t *testing.T) {
+ t.Log("TestWhenCNAMEIsResolved")
+
+ testDb, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ records := []*database.DNSRecord{
+ {
+ ID: "0",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "next.internal.example.com.",
+ TTL: 300,
+ Internal: true,
+ }, {
+ ID: "1",
+ UserID: "test",
+ Name: "next.internal.example.com.",
+ Type: "CNAME",
+ Content: "res.example.com.",
+ TTL: 300,
+ Internal: true,
+ },
+ {
+ ID: "2",
+ UserID: "test",
+ Name: "res.example.com.",
+ Type: "A",
+ Content: "1.2.3.2",
+ TTL: 300,
+ Internal: true,
+ },
+ }
+
+ for _, record := range records {
+ database.SaveDNSRecord(testDb, record)
+ }
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("cname.internal.example.com.")
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 3 {
+ t.Fatalf("expected 3 answers, got %d", len(in.Answer))
+ }
+
+ for i, record := range records {
+ if in.Answer[i].Header().Name != record.Name {
+ t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
+ }
+
+ if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
+ t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
+ }
+
+ if int(in.Answer[i].Header().Ttl) != record.TTL {
+ t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
+ }
+
+ if !in.Authoritative {
+ t.Fatalf("expected authoritative response")
+ }
+ }
+
+ if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
+ t.Fatalf("expected final record to be the A record with correct IP")
+ }
+}
+
+func TestWhenNoRecordNxDomain(t *testing.T) {
+ t.Log("TestWhenNoRecordNxDomain")
+
+ _, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("nonexistant.example.com.")
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 0 {
+ t.Fatalf("expected 0 answers, got %d", len(in.Answer))
+ }
+
+ if in.Rcode != dns.RcodeNameError {
+ t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
+ }
+}
+
+func TestWhenUnresolvingCNAME(t *testing.T) {
+ t.Log("TestWhenUnresolvingCNAME")
+
+ testDb, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ cname := &database.DNSRecord{
+ ID: "1",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "nonexistant.example.com.",
+ TTL: 300,
+ Internal: true,
+ }
+ database.SaveDNSRecord(testDb, cname)
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn(cname.Name)
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 1 {
+ t.Fatalf("expected 1 answer, got %d", len(in.Answer))
+ }
+
+ if !in.Authoritative {
+ t.Fatalf("expected authoritative response")
+ }
+
+ if in.Answer[0].Header().Name != cname.Name {
+ t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
+ }
+
+ if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
+ t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
+ }
+
+ if in.Answer[0].(*dns.CNAME).Target != cname.Content {
+ t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
+ }
+
+ if in.Rcode == dns.RcodeNameError {
+ t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
+ }
+}
+
+func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
+ t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
+
+ testDb, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ cname := &database.DNSRecord{
+ ID: "1",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "cname.internal.example.com.",
+ TTL: 300,
+ Internal: true,
+ }
+ database.SaveDNSRecord(testDb, cname)
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn(cname.Name)
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) > 0 {
+ t.Fatalf("expected 0 answers, got %d", len(in.Answer))
+ }
+
+ if in.Rcode != dns.RcodeServerFailure {
+ t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
+ }
+}