summaryrefslogtreecommitdiff
path: root/hcdns
diff options
context:
space:
mode:
Diffstat (limited to 'hcdns')
-rw-r--r--hcdns/server.go112
1 files changed, 112 insertions, 0 deletions
diff --git a/hcdns/server.go b/hcdns/server.go
new file mode 100644
index 0000000..ce7894b
--- /dev/null
+++ b/hcdns/server.go
@@ -0,0 +1,112 @@
+package hcdns
+
+import (
+ "database/sql"
+ "fmt"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "github.com/miekg/dns"
+ "log"
+)
+
+const MAX_RECURSION = 15
+
+func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
+ if err != nil {
+ return nil, err
+ }
+
+ var answers []dns.RR
+ for _, record := range internalCnames {
+ cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+ answers = append(answers, cname)
+
+ cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+ answers = append(answers, cnameRecursive...)
+ }
+
+ qtypeName := dns.TypeToString[qtype]
+ if qtypeName == "" {
+ return nil, fmt.Errorf("invalid query type %d", qtype)
+ }
+
+ typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
+ if err != nil {
+ return nil, err
+ }
+ for _, record := range typeDnsRecords {
+ answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
+ if err != nil {
+ return nil, err
+ }
+ answers = append(answers, answer)
+ }
+
+ return answers, nil
+}
+
+func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ if maxDepth == 0 {
+ return nil, fmt.Errorf("too much recursion")
+ }
+
+ answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
+ if err != nil {
+ return nil, err
+ }
+
+ return answers, nil
+}
+
+type DnsHandler struct {
+ DnsResolvers []string
+ DbConn *sql.DB
+}
+
+func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+ msg := new(dns.Msg)
+ msg.SetReply(r)
+ msg.Authoritative = true
+
+ for _, question := range r.Question {
+ answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
+ if err != nil {
+ fmt.Println(err)
+ msg.SetRcode(r, dns.RcodeServerFailure)
+ w.WriteMsg(msg)
+ return
+ }
+ msg.Answer = append(msg.Answer, answers...)
+ }
+
+ if len(msg.Answer) == 0 {
+ msg.SetRcode(r, dns.RcodeNameError)
+ }
+
+ log.Println(msg.Answer)
+ w.WriteMsg(msg)
+}
+
+func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
+ handler := &DnsHandler{
+ DbConn: dbConn,
+ }
+ addr := fmt.Sprintf(":%d", argv.DnsPort)
+
+ return &dns.Server{
+ Addr: addr,
+ Net: "udp",
+ Handler: handler,
+ UDPSize: 65535,
+ ReusePort: true,
+ }
+}