summaryrefslogtreecommitdiff
path: root/hcdns/server.go
diff options
context:
space:
mode:
authorsimponic <simponic@hatecomputers.club>2024-04-06 15:43:18 -0400
committersimponic <simponic@hatecomputers.club>2024-04-06 15:43:18 -0400
commit83cc6267fd5ce2f61200314424c5f400f65ff2ba (patch)
treeeafb35310236a15572cbb6e16ff8d6f181bfe240 /hcdns/server.go
parent569d2788ebfb90774faf361f62bfe7968e091465 (diff)
parentcad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 (diff)
downloadhatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.tar.gz
hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.zip
Merge pull request 'testing | dont be recursive for external domains | finalize oauth' (#5) from dont-be-authoritative into main
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/5
Diffstat (limited to 'hcdns/server.go')
-rw-r--r--hcdns/server.go112
1 files changed, 112 insertions, 0 deletions
diff --git a/hcdns/server.go b/hcdns/server.go
new file mode 100644
index 0000000..ce7894b
--- /dev/null
+++ b/hcdns/server.go
@@ -0,0 +1,112 @@
+package hcdns
+
+import (
+ "database/sql"
+ "fmt"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "github.com/miekg/dns"
+ "log"
+)
+
+const MAX_RECURSION = 15
+
+func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
+ if err != nil {
+ return nil, err
+ }
+
+ var answers []dns.RR
+ for _, record := range internalCnames {
+ cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+ answers = append(answers, cname)
+
+ cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+ answers = append(answers, cnameRecursive...)
+ }
+
+ qtypeName := dns.TypeToString[qtype]
+ if qtypeName == "" {
+ return nil, fmt.Errorf("invalid query type %d", qtype)
+ }
+
+ typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName)
+ if err != nil {
+ return nil, err
+ }
+ for _, record := range typeDnsRecords {
+ answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
+ if err != nil {
+ return nil, err
+ }
+ answers = append(answers, answer)
+ }
+
+ return answers, nil
+}
+
+func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ if maxDepth == 0 {
+ return nil, fmt.Errorf("too much recursion")
+ }
+
+ answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
+ if err != nil {
+ return nil, err
+ }
+
+ return answers, nil
+}
+
+type DnsHandler struct {
+ DnsResolvers []string
+ DbConn *sql.DB
+}
+
+func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+ msg := new(dns.Msg)
+ msg.SetReply(r)
+ msg.Authoritative = true
+
+ for _, question := range r.Question {
+ answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
+ if err != nil {
+ fmt.Println(err)
+ msg.SetRcode(r, dns.RcodeServerFailure)
+ w.WriteMsg(msg)
+ return
+ }
+ msg.Answer = append(msg.Answer, answers...)
+ }
+
+ if len(msg.Answer) == 0 {
+ msg.SetRcode(r, dns.RcodeNameError)
+ }
+
+ log.Println(msg.Answer)
+ w.WriteMsg(msg)
+}
+
+func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
+ handler := &DnsHandler{
+ DbConn: dbConn,
+ }
+ addr := fmt.Sprintf(":%d", argv.DnsPort)
+
+ return &dns.Server{
+ Addr: addr,
+ Net: "udp",
+ Handler: handler,
+ UDPSize: 65535,
+ ReusePort: true,
+ }
+}