diff options
| author | Elizabeth <elizabeth@simponic.xyz> | 2024-04-02 16:26:39 -0600 |
|---|---|---|
| committer | Elizabeth <elizabeth@simponic.xyz> | 2024-04-02 16:26:39 -0600 |
| commit | bcdcc508ef4a0ae646937c91d0994a90bde719e1 (patch) | |
| tree | e5ee5aa97d2bb898428bb278973c16cc28aaa16a /hcdns | |
| parent | 657be669482462ada3b88672ff7497b652848176 (diff) | |
| download | hatecomputers.club-bcdcc508ef4a0ae646937c91d0994a90bde719e1.tar.gz hatecomputers.club-bcdcc508ef4a0ae646937c91d0994a90bde719e1.zip | |
add integration tests for dns server
Diffstat (limited to 'hcdns')
| -rw-r--r-- | hcdns/server.go | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/hcdns/server.go b/hcdns/server.go new file mode 100644 index 0000000..ce7894b --- /dev/null +++ b/hcdns/server.go @@ -0,0 +1,112 @@ +package hcdns + +import ( + "database/sql" + "fmt" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "github.com/miekg/dns" + "log" +) + +const MAX_RECURSION = 15 + +func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") + if err != nil { + return nil, err + } + + var answers []dns.RR + for _, record := range internalCnames { + cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cname) + + cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cnameRecursive...) + } + + qtypeName := dns.TypeToString[qtype] + if qtypeName == "" { + return nil, fmt.Errorf("invalid query type %d", qtype) + } + + typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) + if err != nil { + return nil, err + } + for _, record := range typeDnsRecords { + answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) + if err != nil { + return nil, err + } + answers = append(answers, answer) + } + + return answers, nil +} + +func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + if maxDepth == 0 { + return nil, fmt.Errorf("too much recursion") + } + + answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) + if err != nil { + return nil, err + } + + return answers, nil +} + +type DnsHandler struct { + DnsResolvers []string + DbConn *sql.DB +} + +func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + msg := new(dns.Msg) + msg.SetReply(r) + msg.Authoritative = true + + for _, question := range r.Question { + answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) + if err != nil { + fmt.Println(err) + msg.SetRcode(r, dns.RcodeServerFailure) + w.WriteMsg(msg) + return + } + msg.Answer = append(msg.Answer, answers...) + } + + if len(msg.Answer) == 0 { + msg.SetRcode(r, dns.RcodeNameError) + } + + log.Println(msg.Answer) + w.WriteMsg(msg) +} + +func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { + handler := &DnsHandler{ + DbConn: dbConn, + } + addr := fmt.Sprintf(":%d", argv.DnsPort) + + return &dns.Server{ + Addr: addr, + Net: "udp", + Handler: handler, + UDPSize: 65535, + ReusePort: true, + } +} |
