From 60fc4ebb599d82f5c7ddaca52f8aba74f0876381 Mon Sep 17 00:00:00 2001 From: simponic Date: Thu, 28 Mar 2024 16:58:07 -0400 Subject: internal recursive dns server (#2) Co-authored-by: Lizzy Hunt Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/2 --- api/dns.go | 74 +++++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 52 insertions(+), 22 deletions(-) (limited to 'api') diff --git a/api/dns.go b/api/dns.go index 5123acc..0205f5d 100644 --- a/api/dns.go +++ b/api/dns.go @@ -1,6 +1,7 @@ package api import ( + "database/sql" "log" "net/http" "strconv" @@ -8,16 +9,31 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) -const MAX_USER_RECORDS = 20 +const MAX_USER_RECORDS = 65 type FormError struct { Errors []string } -func userCanFuckWithDNSRecord(user *database.User, record *database.DNSRecord) bool { - return user.ID == record.UserID && (record.Name == user.Username || strings.HasSuffix(record.Name, "."+user.Username)) +func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool { + ownedByUser := (user.ID == record.UserID) + + if !record.Internal { + publicallyOwnedByUser := (record.Name == user.Username || strings.HasSuffix(record.Name, "."+user.Username)) + return ownedByUser && publicallyOwnedByUser + } + + owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) + if err != nil { + log.Println(err) + return false + } + + userIsOwnerOfDomain := owner == user.ID + return ownedByUser && userIsOwnerOfDomain } func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { @@ -40,8 +56,15 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res Errors: []string{}, } + internal := req.FormValue("internal") == "on" name := req.FormValue("name") + if internal && !strings.HasSuffix(name, ".") { + name += "." + } + recordType := req.FormValue("type") + recordType = strings.ToUpper(recordType) + recordContent := req.FormValue("content") ttl := req.FormValue("ttl") ttlNum, err := strconv.Atoi(ttl) @@ -50,11 +73,12 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res } dnsRecord := &database.DNSRecord{ - UserID: context.User.ID, - Name: name, - Type: recordType, - Content: recordContent, - TTL: ttlNum, + UserID: context.User.ID, + Name: name, + Type: recordType, + Content: recordContent, + TTL: ttlNum, + Internal: internal, } dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) @@ -67,18 +91,22 @@ func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, res formErrors.Errors = append(formErrors.Errors, "max records reached") } - if !userCanFuckWithDNSRecord(context.User, dnsRecord) { - formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username) + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) { + formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } if len(formErrors.Errors) == 0 { - cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord) - if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, err.Error()) + if dnsRecord.Internal { + dnsRecord.ID = utils.RandomId() + } else { + cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, err.Error()) + } + + dnsRecord.ID = cloudflareRecordId } - - dnsRecord.ID = cloudflareRecordId } if len(formErrors.Errors) == 0 { @@ -113,16 +141,18 @@ func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, res return failure(context, req, resp) } - if !userCanFuckWithDNSRecord(context.User, record) { + if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } - err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) + if !record.Internal { + err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } } err = database.DeleteDNSRecord(context.DBConn, recordId) -- cgit v1.2.3-70-g09d2