summaryrefslogtreecommitdiff
path: root/api/dns/dns.go
diff options
context:
space:
mode:
Diffstat (limited to 'api/dns/dns.go')
-rw-r--r--api/dns/dns.go72
1 files changed, 47 insertions, 25 deletions
diff --git a/api/dns/dns.go b/api/dns/dns.go
index 7e9c7c7..c24fa4a 100644
--- a/api/dns/dns.go
+++ b/api/dns/dns.go
@@ -36,29 +36,48 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request
}
}
-func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+func CreateDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
formErrors := types.BannerMessages{
Messages: []string{},
}
- internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
- name := req.FormValue("name")
- if internal && !strings.HasSuffix(name, ".") {
- name += "."
+ dnsRecord := &database.DNSRecord{}
+ id := req.FormValue("id")
+ isNewRecord := id == ""
+ if !isNewRecord {
+ retrievedDnsRecord, err := database.GetDNSRecord(context.DBConn, id)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ formErrors.Messages = append(formErrors.Messages, "error getting record from id")
+ } else {
+ dnsRecord = retrievedDnsRecord
+ }
+ } else {
+ dnsRecord.UserID = context.User.ID
+ }
+
+ dnsRecord.Internal = req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
+
+ dnsRecord.Name = req.FormValue("name")
+ if dnsRecord.Internal && !strings.HasSuffix(dnsRecord.Name, ".") {
+ dnsRecord.Name += "."
}
recordType := req.FormValue("type")
- recordType = strings.ToUpper(recordType)
+ dnsRecord.Type = strings.ToUpper(recordType)
+
+ dnsRecord.Content = req.FormValue("content")
- recordContent := req.FormValue("content")
ttl := req.FormValue("ttl")
ttlNum, err := strconv.Atoi(ttl)
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
formErrors.Messages = append(formErrors.Messages, "invalid ttl")
}
+ dnsRecord.TTL = ttlNum
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
if err != nil {
@@ -71,27 +90,29 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
formErrors.Messages = append(formErrors.Messages, "max records reached")
}
- dnsRecord := &database.DNSRecord{
- UserID: context.User.ID,
- Name: name,
- Type: recordType,
- Content: recordContent,
- TTL: ttlNum,
- Internal: internal,
- }
-
- if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
+ if len(formErrors.Messages) == 0 && !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
resp.WriteHeader(http.StatusUnauthorized)
- formErrors.Messages = append(formErrors.Messages, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
+ formErrors.Messages = append(formErrors.Messages, "external 'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
}
- if len(formErrors.Messages) == 0 {
+ if isNewRecord && len(formErrors.Messages) == 0 {
if dnsRecord.Internal {
dnsRecord.ID = utils.RandomId()
} else {
- dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
+ dnsRecord.ID, err = externalDnsAdapter.CreateDNSRecord(dnsRecord)
+ if err != nil {
+ log.Println("error creating external dns record", err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ formErrors.Messages = append(formErrors.Messages, err.Error())
+ }
+ }
+ }
+
+ if !isNewRecord && len(formErrors.Messages) == 0 {
+ if !dnsRecord.Internal {
+ err = externalDnsAdapter.UpdateDNSRecord(dnsRecord)
if err != nil {
- log.Println(err)
+ log.Println("error updating external dns record", err)
resp.WriteHeader(http.StatusInternalServerError)
formErrors.Messages = append(formErrors.Messages, err.Error())
}
@@ -108,20 +129,21 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
if len(formErrors.Messages) == 0 {
formSuccess := types.BannerMessages{
- Messages: []string{"record added."},
+ Messages: []string{"record saved."},
}
(*context.TemplateData)["Success"] = formSuccess
return success(context, req, resp)
}
- (*context.TemplateData)[""] = &formErrors
+ log.Println(formErrors.Messages)
+ (*context.TemplateData)["Error"] = &formErrors
(*context.TemplateData)["RecordForm"] = dnsRecord
return failure(context, req, resp)
}
}
}
-func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+func DeleteDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
recordId := req.FormValue("id")
@@ -138,7 +160,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
}
if !record.Internal {
- err = dnsAdapter.DeleteDNSRecord(recordId)
+ err = externalDnsAdapter.DeleteDNSRecord(recordId)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)