diff options
Diffstat (limited to 'hcdns')
| -rw-r--r-- | hcdns/server.go | 112 | ||||
| -rw-r--r-- | hcdns/server_test.go | 254 |
2 files changed, 366 insertions, 0 deletions
diff --git a/hcdns/server.go b/hcdns/server.go new file mode 100644 index 0000000..ce7894b --- /dev/null +++ b/hcdns/server.go @@ -0,0 +1,112 @@ +package hcdns + +import ( + "database/sql" + "fmt" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "github.com/miekg/dns" + "log" +) + +const MAX_RECURSION = 15 + +func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") + if err != nil { + return nil, err + } + + var answers []dns.RR + for _, record := range internalCnames { + cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cname) + + cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cnameRecursive...) + } + + qtypeName := dns.TypeToString[qtype] + if qtypeName == "" { + return nil, fmt.Errorf("invalid query type %d", qtype) + } + + typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) + if err != nil { + return nil, err + } + for _, record := range typeDnsRecords { + answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) + if err != nil { + return nil, err + } + answers = append(answers, answer) + } + + return answers, nil +} + +func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + if maxDepth == 0 { + return nil, fmt.Errorf("too much recursion") + } + + answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) + if err != nil { + return nil, err + } + + return answers, nil +} + +type DnsHandler struct { + DnsResolvers []string + DbConn *sql.DB +} + +func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + msg := new(dns.Msg) + msg.SetReply(r) + msg.Authoritative = true + + for _, question := range r.Question { + answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) + if err != nil { + fmt.Println(err) + msg.SetRcode(r, dns.RcodeServerFailure) + w.WriteMsg(msg) + return + } + msg.Answer = append(msg.Answer, answers...) + } + + if len(msg.Answer) == 0 { + msg.SetRcode(r, dns.RcodeNameError) + } + + log.Println(msg.Answer) + w.WriteMsg(msg) +} + +func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { + handler := &DnsHandler{ + DbConn: dbConn, + } + addr := fmt.Sprintf(":%d", argv.DnsPort) + + return &dns.Server{ + Addr: addr, + Net: "udp", + Handler: handler, + UDPSize: 65535, + ReusePort: true, + } +} diff --git a/hcdns/server_test.go b/hcdns/server_test.go new file mode 100644 index 0000000..177def4 --- /dev/null +++ b/hcdns/server_test.go @@ -0,0 +1,254 @@ +package hcdns_test + +import ( + "database/sql" + "fmt" + "math/rand" + "os" + "sync" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "github.com/miekg/dns" +) + +func randomPort() int { + return rand.Intn(3000) + 5192 +} + +func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { + randomDb := utils.RandomId() + dnsPort := randomPort() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + testUser := &database.User{ + ID: "test", + } + database.FindOrSaveUser(testDb, testUser) + + waitLock := &sync.Mutex{} + server := hcdns.MakeServer(&args.Arguments{ + DnsPort: dnsPort, + }, testDb) + server.NotifyStartedFunc = func() { + waitLock.Unlock() + } + waitLock.Lock() + + go func() { + server.ListenAndServe() + }() + waitLock.Lock() + + address := fmt.Sprintf("127.0.0.1:%d", dnsPort) + return testDb, server, &address, waitLock, func() { + server.Shutdown() + + testDb.Close() + os.Remove(randomDb) + } +} + +func TestWhenCNAMEIsResolved(t *testing.T) { + t.Log("TestWhenCNAMEIsResolved") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + records := []*database.DNSRecord{ + { + ID: "0", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "next.internal.example.com.", + TTL: 300, + Internal: true, + }, { + ID: "1", + UserID: "test", + Name: "next.internal.example.com.", + Type: "CNAME", + Content: "res.example.com.", + TTL: 300, + Internal: true, + }, + { + ID: "2", + UserID: "test", + Name: "res.example.com.", + Type: "A", + Content: "1.2.3.2", + TTL: 300, + Internal: true, + }, + } + + for _, record := range records { + database.SaveDNSRecord(testDb, record) + } + + qtype := dns.TypeA + domain := dns.Fqdn("cname.internal.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 3 { + t.Fatalf("expected 3 answers, got %d", len(in.Answer)) + } + + for i, record := range records { + if in.Answer[i].Header().Name != record.Name { + t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name) + } + + if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] { + t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype) + } + + if int(in.Answer[i].Header().Ttl) != record.TTL { + t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + } + + if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" { + t.Fatalf("expected final record to be the A record with correct IP") + } +} + +func TestWhenNoRecordNxDomain(t *testing.T) { + t.Log("TestWhenNoRecordNxDomain") + + _, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + qtype := dns.TypeA + domain := dns.Fqdn("nonexistant.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAME(t *testing.T) { + t.Log("TestWhenUnresolvingCNAME") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "nonexistant.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(in.Answer)) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[0].(*dns.CNAME).Target != cname.Content { + t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Rcode == dns.RcodeNameError { + t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { + t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "cname.internal.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) > 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeServerFailure { + t.Fatalf("expected SERVFAIL, got %d", in.Rcode) + } +} |
