diff options
Diffstat (limited to 'api/dns/dns.go')
| -rw-r--r-- | api/dns/dns.go | 22 |
1 files changed, 9 insertions, 13 deletions
diff --git a/api/dns/dns.go b/api/dns/dns.go index 4805146..aa2f356 100644 --- a/api/dns/dns.go +++ b/api/dns/dns.go @@ -14,10 +14,6 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) -const MAX_USER_RECORDS = 65 - -var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} - func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { ownedByUser := (user.ID == record.UserID) if !ownedByUser { @@ -60,14 +56,14 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request } } -func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { +func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { formErrors := types.FormError{ Errors: []string{}, } - internal := req.FormValue("internal") == "on" + internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" name := req.FormValue("name") if internal && !strings.HasSuffix(name, ".") { name += "." @@ -80,6 +76,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun ttl := req.FormValue("ttl") ttlNum, err := strconv.Atoi(ttl) if err != nil { + resp.WriteHeader(http.StatusBadRequest) formErrors.Errors = append(formErrors.Errors, "invalid ttl") } @@ -89,7 +86,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - if dnsRecordCount >= MAX_USER_RECORDS { + if dnsRecordCount >= maxUserRecords { + resp.WriteHeader(http.StatusTooManyRequests) formErrors.Errors = append(formErrors.Errors, "max records reached") } @@ -102,7 +100,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun Internal: internal, } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) { + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { + resp.WriteHeader(http.StatusUnauthorized) formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } @@ -113,6 +112,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) if err != nil { log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) formErrors.Errors = append(formErrors.Errors, err.Error()) } } @@ -127,14 +127,11 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun } if len(formErrors.Errors) == 0 { - http.Redirect(resp, req, "/dns", http.StatusFound) return success(context, req, resp) } (*context.TemplateData)["FormError"] = &formErrors (*context.TemplateData)["RecordForm"] = dnsRecord - - resp.WriteHeader(http.StatusBadRequest) return failure(context, req, resp) } } @@ -151,7 +148,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun return failure(context, req, resp) } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) { + if !(record.UserID == context.User.ID) { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } @@ -171,7 +168,6 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun return failure(context, req, resp) } - http.Redirect(resp, req, "/dns", http.StatusFound) return success(context, req, resp) } } |
