summaryrefslogtreecommitdiff
path: root/dns/server.go
blob: 63bb067c4872cc4a4b6c481ef0625e59c319409a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package dns

import (
	"database/sql"
	"fmt"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"github.com/miekg/dns"
	"log"
)

const (
	// MAX_RECURSION is the initial depth budget handed to resolveRecursive by
	// ServeDNS. Each followed internal CNAME consumes one unit; once the budget
	// reaches zero, resolution fails with a "too much recursion" error.
	MAX_RECURSION = 10
)

// resolveRecursive answers a single question for domain with record type qtype.
//
// Records stored in the local database take precedence:
//   - unless qtype is CNAME, every internal CNAME record for domain is added to
//     the answer followed by the (recursively resolved) records of its target;
//   - every stored record of type qtype for domain is added to the answer.
//
// Only when the database yields no answers at all is the question forwarded
// to the upstream resolvers in dnsResolvers, which are tried in order until
// one completes the exchange without a transport error.
//
// maxDepth bounds how many CNAME hops may be followed; when it is exhausted
// (e.g. a CNAME loop) a "too much recursion" error is returned. Errors are
// also returned for an unknown qtype, database or record-parse failures, an
// empty resolver list, and when every upstream resolver fails.
func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
	// <= rather than == so a negative budget cannot recurse without bound.
	if maxDepth <= 0 {
		return nil, fmt.Errorf("too much recursion")
	}

	// Validate the query type up front so an invalid type never triggers
	// database lookups or CNAME chasing.
	qtypeName := dns.TypeToString[qtype]
	if qtypeName == "" {
		return nil, fmt.Errorf("invalid query type %d", qtype)
	}

	var answers []dns.RR

	// A CNAME query asks for the alias record itself; it is answered by the
	// typed lookup below and must not be chased (that would duplicate it).
	if qtype != dns.TypeCNAME {
		internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
		if err != nil {
			return nil, err
		}
		for _, record := range internalCnames {
			// RFC 1034 §4.3.2: the CNAME RR belongs in the answer ahead of the
			// target's records, otherwise clients see records for a name they
			// never asked about.
			cname, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
			if err != nil {
				return nil, err
			}
			// Propagate failures (including "too much recursion" from CNAME
			// loops) instead of silently dropping them.
			target, err := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
			if err != nil {
				return nil, fmt.Errorf("resolving CNAME target %q: %w", record.Content, err)
			}
			answers = append(answers, cname)
			answers = append(answers, target...)
		}
	}

	typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
	if err != nil {
		return nil, err
	}
	for _, record := range typeDnsRecords {
		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
		if err != nil {
			return nil, err
		}
		answers = append(answers, answer)
	}

	if len(answers) > 0 {
		// base case; the local database answered the question
		return answers, nil
	}

	// Guard against indexing into an empty resolver list.
	if len(dnsResolvers) == 0 {
		return nil, fmt.Errorf("no upstream resolvers configured for %q", domain)
	}

	message := new(dns.Msg)
	message.SetQuestion(dns.Fqdn(domain), qtype)
	message.RecursionDesired = true

	client := new(dns.Client)

	// Try each upstream in order; the first successful exchange wins.
	// NOTE(review): the upstream Rcode (e.g. NXDOMAIN) is not propagated to
	// the caller; only the answer section is returned.
	var lastErr error
	for _, resolver := range dnsResolvers {
		in, _, err := client.Exchange(message, resolver)
		if err != nil {
			lastErr = err
			continue
		}
		return in.Answer, nil
	}

	// Not logged here: the caller decides how to report the failure.
	return nil, fmt.Errorf("querying upstream resolvers for %q: %w", domain, lastErr)
}

// DnsHandler serves DNS queries via its ServeDNS method; MakeServer installs
// it as the server's Handler. DbConn is the database searched (through
// database.FindDNSRecords) for locally managed records. DnsResolvers lists the
// upstream resolvers consulted, in order, when the database has no answer.
// NOTE(review): resolver entries are presumably "host:port" strings, as
// dns.Client.Exchange expects; confirm against args.DnsRecursion.
type DnsHandler struct {
	DnsResolvers []string
	DbConn       *sql.DB
}

// ServeDNS implements dns.Handler. Each question in r is resolved with
// resolveRecursive (local database first, then the configured upstream
// resolvers), and all answers are collected into a single reply.
//
// If any question cannot be resolved, the reply's Rcode is set to SERVFAIL.
// Otherwise clients would receive, and possibly cache, an empty NOERROR
// response for a lookup that never completed.
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	msg := new(dns.Msg)
	msg.SetReply(r)
	// NOTE(review): answers forwarded from upstream are also flagged
	// authoritative; kept as-is for compatibility.
	msg.Authoritative = true

	for _, question := range r.Question {
		answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
		if err != nil {
			log.Printf("resolving %q (type %d): %v", question.Name, question.Qtype, err)
			msg.Rcode = dns.RcodeServerFailure
			continue
		}
		msg.Answer = append(msg.Answer, answers...)
	}

	log.Println(msg.Answer)
	if err := w.WriteMsg(msg); err != nil {
		log.Println(err)
	}
}

// MakeServer builds, but does not start, a UDP DNS server bound to
// argv.DnsPort on all interfaces. Queries go to a DnsHandler that consults
// dbConn first and falls back to the upstream resolvers in argv.DnsRecursion.
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
	server := &dns.Server{
		Addr: fmt.Sprintf(":%d", argv.DnsPort),
		Net:  "udp",
		Handler: &DnsHandler{
			DbConn:       dbConn,
			DnsResolvers: argv.DnsRecursion,
		},
		UDPSize:   65535,
		ReusePort: true,
	}
	return server
}