diff options
Diffstat (limited to 'api/dns')
| -rw-r--r-- | api/dns/dns.go | 88 | ||||
| -rw-r--r-- | api/dns/dns_test.go | 4 |
2 files changed, 52 insertions, 40 deletions
diff --git a/api/dns/dns.go b/api/dns/dns.go index aa2f356..6357dfc 100644 --- a/api/dns/dns.go +++ b/api/dns/dns.go @@ -8,39 +8,15 @@ import ( "strconv" "strings" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/external_dns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) -func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { - ownedByUser := (user.ID == record.UserID) - if !ownedByUser { - return false - } - - if !record.Internal { - for _, format := range ownedInternalDomainFormats { - domain := fmt.Sprintf(format, user.Username) - - isInSubDomain := strings.HasSuffix(record.Name, "."+domain) - if domain == record.Name || isInSubDomain { - return true - } - } - return false - } - - owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) - if err != nil { - log.Println(err) - return false - } +const MaxUserRecords = 100 - userIsOwnerOfDomain := owner == user.ID - return ownedByUser && userIsOwnerOfDomain -} +var UserOwnedInternalFmtDomains = []string{"%s", "%s.endpoints"} func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { @@ -59,8 +35,8 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { - formErrors := types.FormError{ - Errors: []string{}, + formErrors := types.BannerMessages{ + Messages: []string{}, } internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" @@ -77,7 +53,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max ttlNum, err := strconv.Atoi(ttl) if err != nil { resp.WriteHeader(http.StatusBadRequest) - formErrors.Errors = append(formErrors.Errors, "invalid ttl") + formErrors.Messages = append(formErrors.Messages, "invalid ttl") } dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) @@ -88,7 +64,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max } if dnsRecordCount >= maxUserRecords { resp.WriteHeader(http.StatusTooManyRequests) - formErrors.Errors = append(formErrors.Errors, "max records reached") + formErrors.Messages = append(formErrors.Messages, "max records reached") } dnsRecord := &database.DNSRecord{ @@ -102,10 +78,10 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { resp.WriteHeader(http.StatusUnauthorized) - formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") + formErrors.Messages = append(formErrors.Messages, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } - if len(formErrors.Errors) == 0 { + if len(formErrors.Messages) == 0 { if dnsRecord.Internal { dnsRecord.ID = utils.RandomId() } else { @@ -113,24 +89,28 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) - formErrors.Errors = append(formErrors.Errors, err.Error()) + formErrors.Messages = append(formErrors.Messages, err.Error()) } } } - if len(formErrors.Errors) == 0 { + if len(formErrors.Messages) == 0 { _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) if err != nil { log.Println(err) - formErrors.Errors = append(formErrors.Errors, "error saving record") + formErrors.Messages = append(formErrors.Messages, "error saving record") } } - if len(formErrors.Errors) == 0 { + if len(formErrors.Messages) == 0 { + formSuccess := types.BannerMessages{ + Messages: []string{"record added."}, + } + (*context.TemplateData)["Success"] = formSuccess return success(context, req, resp) } - (*context.TemplateData)["FormError"] = &formErrors + (*context.TemplateData)[""] = &formErrors (*context.TemplateData)["RecordForm"] = dnsRecord return failure(context, req, resp) } @@ -168,7 +148,39 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun return failure(context, req, resp) } + formSuccess := types.BannerMessages{ + Messages: []string{"record deleted."}, + } + (*context.TemplateData)["Success"] = formSuccess return success(context, req, resp) } } } + +func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { + ownedByUser := (user.ID == record.UserID) + if !ownedByUser { + return false + } + + if !record.Internal { + for _, format := range ownedInternalDomainFormats { + domain := fmt.Sprintf(format, user.Username) + + isInSubDomain := strings.HasSuffix(record.Name, "."+domain) + if domain == record.Name || isInSubDomain { + return true + } + } + return false + } + + owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) + if err != nil { + log.Println(err) + return false + } + + userIsOwnerOfDomain := owner == user.ID + return ownedByUser && userIsOwnerOfDomain +} diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go index 43dc680..30baedf 100644 --- a/api/dns/dns_test.go +++ b/api/dns/dns_test.go @@ -39,7 +39,7 @@ func setup() (*sql.DB, *types.RequestContext, func()) { Mail: "test@test.com", DisplayName: "test", } - database.FindOrSaveUser(testDb, user) + database.FindOrSaveBaseUser(testDb, user) context := &types.RequestContext{ DBConn: testDb, @@ -246,7 +246,7 @@ func TestThatUserMustOwnRecordToRemove(t *testing.T) { defer testServer.Close() nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"} - _, err := database.FindOrSaveUser(db, nonOwnerUser) + _, err := database.FindOrSaveBaseUser(db, nonOwnerUser) if err != nil { t.Error(err) } |
