summaryrefslogtreecommitdiff
path: root/dns/server.go
blob: 9b3e5e90c4a9f86e71abf791afad1628987023d6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package dns

import (
	"database/sql"
	"fmt"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
	"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
	"github.com/miekg/dns"
	"log"
)

// MAX_RECURSION is the maxDepth that ServeDNS hands to resolveDNS for every
// question. It limits how far a chain of internal CNAME records is followed
// before resolution gives up with a "too much recursion" error.
// NOTE(review): Go naming would be MaxRecursion; kept as-is because the name
// is exported and may be referenced outside this file.
const MAX_RECURSION = 10

// resolveInternalCNAMEs answers one question for domain from the records
// stored in the database.
//
// Unless qtype is itself CNAME, every CNAME record stored for domain is
// emitted first. Each is followed by the answers for its target, resolved
// recursively, with one level of maxDepth consumed per hop. Records of type
// qtype stored directly for domain are appended last.
//
// A CNAME whose target has no locally stored records is still returned on
// its own. The client's resolver can then continue the lookup (RFC 1034
// §4.3.2) instead of the whole question failing.
//
// An error is returned in four cases:
//   - qtype has no known name;
//   - a CNAME chain is deeper than maxDepth allows (e.g. an alias loop);
//   - a database lookup fails;
//   - a stored record cannot be parsed into a dns.RR.
//
// An empty slice with a nil error means nothing matched.
func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
	// Validate the query type before doing any database work.
	qtypeName := dns.TypeToString[qtype]
	if qtypeName == "" {
		return nil, fmt.Errorf("invalid query type %d", qtype)
	}

	var answers []dns.RR

	// A CNAME question is answered by the plain type lookup below. Chasing
	// the alias as well would emit each CNAME twice, and would fail outright
	// whenever the target is not itself an alias.
	if qtype != dns.TypeCNAME {
		internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
		if err != nil {
			return nil, err
		}

		for _, record := range internalCnames {
			cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
			if err != nil {
				return nil, err
			}
			answers = append(answers, cname)

			// Following the alias costs one level of depth. Using the same
			// limit resolveDNS enforces keeps alias loops terminating.
			if maxDepth <= 1 {
				return nil, fmt.Errorf("too much recursion")
			}
			// An empty result here is not an error: the target may live
			// outside this server, and the CNAME alone is a valid answer.
			targetAnswers, err := resolveInternalCNAMEs(dbConn, record.Content, qtype, maxDepth-1)
			if err != nil {
				return nil, err
			}
			answers = append(answers, targetAnswers...)
		}
	}

	typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
	if err != nil {
		return nil, err
	}
	for _, record := range typeDnsRecords {
		answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
		if err != nil {
			return nil, err
		}
		answers = append(answers, answer)
	}

	return answers, nil
}

// resolveDNS answers a question for domain with type qtype from the records
// stored in the database, delegating the lookup to resolveInternalCNAMEs.
//
// maxDepth is the remaining recursion budget. A value of zero or less fails
// immediately with "too much recursion". Any lookup error is returned as-is.
// If the lookup succeeds but yields no records, an error is returned so the
// caller can treat the question as unresolved.
func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
	// Use <= rather than ==: a negative budget must also stop recursion.
	// Otherwise it would never reach exactly zero on an alias loop.
	if maxDepth <= 0 {
		return nil, fmt.Errorf("too much recursion")
	}

	answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
	if err != nil {
		return nil, err
	}

	if len(answers) > 0 {
		return answers, nil
	}

	return nil, fmt.Errorf("no records found for %s", domain)
}

// DnsHandler answers DNS queries from records stored in a SQL database.
// MakeServer installs it as the server's handler.
type DnsHandler struct {
	// DnsResolvers is filled from the DnsRecursion command-line argument.
	// NOTE(review): not read anywhere in this file. Presumably intended for
	// forwarding queries the database cannot answer; confirm.
	DnsResolvers []string
	// DbConn is the database queried for DNS records.
	DbConn       *sql.DB
}

// ServeDNS answers every question in r from the database-backed records and
// writes a single authoritative reply to w.
//
// A question that fails to resolve is logged and skipped. If no question
// produced an answer, the reply's rcode is NXDOMAIN. A failure to write the
// reply is logged, since there is no caller to return it to.
func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	msg := new(dns.Msg)
	msg.SetReply(r)
	msg.Authoritative = true

	for _, question := range r.Question {
		answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
		if err != nil {
			// Use the same logger as the rest of the handler (not stdout),
			// and record which question failed.
			log.Printf("resolving %s %s: %v", question.Name, dns.TypeToString[question.Qtype], err)
			continue
		}
		msg.Answer = append(msg.Answer, answers...)
	}

	if len(msg.Answer) == 0 {
		// NOTE(review): every failure is reported as NXDOMAIN, including
		// database errors and names that exist but lack the requested type.
		// SERVFAIL or NOERROR/NODATA would be more accurate once resolveDNS
		// can distinguish those cases.
		msg.SetRcode(r, dns.RcodeNameError)
	}

	log.Println(msg.Answer)
	if err := w.WriteMsg(msg); err != nil {
		log.Printf("writing reply to %v: %v", w.RemoteAddr(), err)
	}
}

// MakeServer builds, without starting, a UDP DNS server. It listens on port
// argv.DnsPort of every local address and dispatches queries to a DnsHandler
// backed by dbConn.
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
	server := &dns.Server{
		Net:       "udp",
		UDPSize:   65535,
		ReusePort: true,
	}

	// An empty host in ":port" binds every interface.
	server.Addr = fmt.Sprintf(":%d", argv.DnsPort)
	server.Handler = &DnsHandler{
		DbConn:       dbConn,
		DnsResolvers: argv.DnsRecursion,
	}

	return server
}