summaryrefslogtreecommitdiff
path: root/api/dns
diff options
context:
space:
mode:
Diffstat (limited to 'api/dns')
-rw-r--r--api/dns/dns.go88
-rw-r--r--api/dns/dns_test.go4
2 files changed, 52 insertions, 40 deletions
diff --git a/api/dns/dns.go b/api/dns/dns.go
index aa2f356..6357dfc 100644
--- a/api/dns/dns.go
+++ b/api/dns/dns.go
@@ -8,39 +8,15 @@ import (
"strconv"
"strings"
- "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/external_dns"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
-func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
- ownedByUser := (user.ID == record.UserID)
- if !ownedByUser {
- return false
- }
-
- if !record.Internal {
- for _, format := range ownedInternalDomainFormats {
- domain := fmt.Sprintf(format, user.Username)
-
- isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
- if domain == record.Name || isInSubDomain {
- return true
- }
- }
- return false
- }
-
- owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
- if err != nil {
- log.Println(err)
- return false
- }
+const MaxUserRecords = 100
- userIsOwnerOfDomain := owner == user.ID
- return ownedByUser && userIsOwnerOfDomain
-}
+var UserOwnedInternalFmtDomains = []string{"%s", "%s.endpoints"}
func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
@@ -59,8 +35,8 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request
func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
- formErrors := types.FormError{
- Errors: []string{},
+ formErrors := types.BannerMessages{
+ Messages: []string{},
}
internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
@@ -77,7 +53,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
ttlNum, err := strconv.Atoi(ttl)
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
- formErrors.Errors = append(formErrors.Errors, "invalid ttl")
+ formErrors.Messages = append(formErrors.Messages, "invalid ttl")
}
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
@@ -88,7 +64,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
}
if dnsRecordCount >= maxUserRecords {
resp.WriteHeader(http.StatusTooManyRequests)
- formErrors.Errors = append(formErrors.Errors, "max records reached")
+ formErrors.Messages = append(formErrors.Messages, "max records reached")
}
dnsRecord := &database.DNSRecord{
@@ -102,10 +78,10 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
resp.WriteHeader(http.StatusUnauthorized)
- formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
+ formErrors.Messages = append(formErrors.Messages, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
}
- if len(formErrors.Errors) == 0 {
+ if len(formErrors.Messages) == 0 {
if dnsRecord.Internal {
dnsRecord.ID = utils.RandomId()
} else {
@@ -113,24 +89,28 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
- formErrors.Errors = append(formErrors.Errors, err.Error())
+ formErrors.Messages = append(formErrors.Messages, err.Error())
}
}
}
- if len(formErrors.Errors) == 0 {
+ if len(formErrors.Messages) == 0 {
_, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
if err != nil {
log.Println(err)
- formErrors.Errors = append(formErrors.Errors, "error saving record")
+ formErrors.Messages = append(formErrors.Messages, "error saving record")
}
}
- if len(formErrors.Errors) == 0 {
+ if len(formErrors.Messages) == 0 {
+ formSuccess := types.BannerMessages{
+ Messages: []string{"record added."},
+ }
+ (*context.TemplateData)["Success"] = formSuccess
return success(context, req, resp)
}
- (*context.TemplateData)["FormError"] = &formErrors
+ (*context.TemplateData)[""] = &formErrors
(*context.TemplateData)["RecordForm"] = dnsRecord
return failure(context, req, resp)
}
@@ -168,7 +148,39 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
return failure(context, req, resp)
}
+ formSuccess := types.BannerMessages{
+ Messages: []string{"record deleted."},
+ }
+ (*context.TemplateData)["Success"] = formSuccess
return success(context, req, resp)
}
}
}
+
+func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
+ ownedByUser := (user.ID == record.UserID)
+ if !ownedByUser {
+ return false
+ }
+
+ if !record.Internal {
+ for _, format := range ownedInternalDomainFormats {
+ domain := fmt.Sprintf(format, user.Username)
+
+ isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
+ if domain == record.Name || isInSubDomain {
+ return true
+ }
+ }
+ return false
+ }
+
+ owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
+ if err != nil {
+ log.Println(err)
+ return false
+ }
+
+ userIsOwnerOfDomain := owner == user.ID
+ return ownedByUser && userIsOwnerOfDomain
+}
diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go
index 43dc680..30baedf 100644
--- a/api/dns/dns_test.go
+++ b/api/dns/dns_test.go
@@ -39,7 +39,7 @@ func setup() (*sql.DB, *types.RequestContext, func()) {
Mail: "test@test.com",
DisplayName: "test",
}
- database.FindOrSaveUser(testDb, user)
+ database.FindOrSaveBaseUser(testDb, user)
context := &types.RequestContext{
DBConn: testDb,
@@ -246,7 +246,7 @@ func TestThatUserMustOwnRecordToRemove(t *testing.T) {
defer testServer.Close()
nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
- _, err := database.FindOrSaveUser(db, nonOwnerUser)
+ _, err := database.FindOrSaveBaseUser(db, nonOwnerUser)
if err != nil {
t.Error(err)
}