summaryrefslogtreecommitdiff
path: root/dns/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'dns/server.go')
-rw-r--r--dns/server.go110
1 files changed, 110 insertions, 0 deletions
diff --git a/dns/server.go b/dns/server.go
new file mode 100644
index 0000000..63bb067
--- /dev/null
+++ b/dns/server.go
@@ -0,0 +1,110 @@
+package dns
+
+import (
+ "database/sql"
+ "fmt"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "github.com/miekg/dns"
+ "log"
+)
+
+const MAX_RECURSION = 10
+
+func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ if maxDepth == 0 {
+ return nil, fmt.Errorf("too much recursion")
+ }
+
+ internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
+ if err != nil {
+ return nil, err
+ }
+
+ answers := []dns.RR{}
+ for _, record := range internalCnames {
+ cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
+ answers = append(answers, cnameRecursive...)
+ }
+
+ qtypeName := dns.TypeToString[qtype]
+ if qtypeName == "" {
+ return nil, fmt.Errorf("invalid query type %d", qtype)
+ }
+
+ typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
+ if err != nil {
+ return nil, err
+ }
+ for _, record := range typeDnsRecords {
+ answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
+ if err != nil {
+ return nil, err
+ }
+ answers = append(answers, answer)
+ }
+
+ if len(answers) > 0 {
+ // base case; we found the answer
+ return answers, nil
+ }
+
+ message := new(dns.Msg)
+ message.SetQuestion(dns.Fqdn(domain), qtype)
+ message.RecursionDesired = true
+
+ client := new(dns.Client)
+
+ i := 0
+ in, _, err := client.Exchange(message, dnsResolvers[i])
+ for err != nil {
+ i += 1
+ if i == len(dnsResolvers) {
+ log.Println(err)
+ return nil, err
+ }
+ in, _, err = client.Exchange(message, dnsResolvers[i])
+ }
+
+ answers = append(answers, in.Answer...)
+ return answers, nil
+}
+
+type DnsHandler struct {
+ DnsResolvers []string
+ DbConn *sql.DB
+}
+
+func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+ msg := new(dns.Msg)
+ msg.SetReply(r)
+ msg.Authoritative = true
+
+ for _, question := range r.Question {
+ answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
+ if err != nil {
+ fmt.Println(err)
+ continue
+ }
+ msg.Answer = append(msg.Answer, answers...)
+ }
+
+ log.Println(msg.Answer)
+ w.WriteMsg(msg)
+}
+
+func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
+ handler := &DnsHandler{
+ DnsResolvers: argv.DnsRecursion,
+ DbConn: dbConn,
+ }
+ addr := fmt.Sprintf(":%d", argv.DnsPort)
+
+ return &dns.Server{
+ Addr: addr,
+ Net: "udp",
+ Handler: handler,
+ UDPSize: 65535,
+ ReusePort: true,
+ }
+}