From dbd548d428f222babb4e1d6a182b90f19b192e1f Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Mon, 24 Jun 2024 00:18:28 -0700 Subject: POST record with id to update to fix cloudflare 500 --- api/dns/dns.go | 72 ++++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 25 deletions(-) (limited to 'api/dns/dns.go') diff --git a/api/dns/dns.go b/api/dns/dns.go index 7e9c7c7..c24fa4a 100644 --- a/api/dns/dns.go +++ b/api/dns/dns.go @@ -36,29 +36,48 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request } } -func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { +func CreateDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { formErrors := types.BannerMessages{ Messages: []string{}, } - internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" - name := req.FormValue("name") - if internal && !strings.HasSuffix(name, ".") { - name += "." + dnsRecord := &database.DNSRecord{} + id := req.FormValue("id") + isNewRecord := id == "" + if !isNewRecord { + retrievedDnsRecord, err := database.GetDNSRecord(context.DBConn, id) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + formErrors.Messages = append(formErrors.Messages, "error getting record from id") + } else { + dnsRecord = retrievedDnsRecord + } + } else { + dnsRecord.UserID = context.User.ID + } + + dnsRecord.Internal = req.FormValue("internal") == "on" || req.FormValue("internal") == "true" + + dnsRecord.Name = req.FormValue("name") + if dnsRecord.Internal && !strings.HasSuffix(dnsRecord.Name, ".") { + dnsRecord.Name += "." } recordType := req.FormValue("type") - recordType = strings.ToUpper(recordType) + dnsRecord.Type = strings.ToUpper(recordType) + + dnsRecord.Content = req.FormValue("content") - recordContent := req.FormValue("content") ttl := req.FormValue("ttl") ttlNum, err := strconv.Atoi(ttl) if err != nil { resp.WriteHeader(http.StatusBadRequest) formErrors.Messages = append(formErrors.Messages, "invalid ttl") } + dnsRecord.TTL = ttlNum dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) if err != nil { @@ -71,27 +90,29 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max formErrors.Messages = append(formErrors.Messages, "max records reached") } - dnsRecord := &database.DNSRecord{ - UserID: context.User.ID, - Name: name, - Type: recordType, - Content: recordContent, - TTL: ttlNum, - Internal: internal, - } - - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { + if len(formErrors.Messages) == 0 && !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { resp.WriteHeader(http.StatusUnauthorized) - formErrors.Messages = append(formErrors.Messages, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") + formErrors.Messages = append(formErrors.Messages, "external 'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } - if len(formErrors.Messages) == 0 { + if isNewRecord && len(formErrors.Messages) == 0 { if dnsRecord.Internal { dnsRecord.ID = utils.RandomId() } else { - dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) + dnsRecord.ID, err = externalDnsAdapter.CreateDNSRecord(dnsRecord) + if err != nil { + log.Println("error creating external dns record", err) + resp.WriteHeader(http.StatusInternalServerError) + formErrors.Messages = append(formErrors.Messages, err.Error()) + } + } + } + + if !isNewRecord && len(formErrors.Messages) == 0 { + if !dnsRecord.Internal { + err = externalDnsAdapter.UpdateDNSRecord(dnsRecord) if err != nil { - log.Println(err) + log.Println("error updating external dns record", err) resp.WriteHeader(http.StatusInternalServerError) formErrors.Messages = append(formErrors.Messages, err.Error()) } @@ -108,20 +129,21 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max if len(formErrors.Messages) == 0 { formSuccess := types.BannerMessages{ - Messages: []string{"record added."}, + Messages: []string{"record saved."}, } (*context.TemplateData)["Success"] = formSuccess return success(context, req, resp) } - (*context.TemplateData)[""] = &formErrors + log.Println(formErrors.Messages) + (*context.TemplateData)["Error"] = &formErrors (*context.TemplateData)["RecordForm"] = dnsRecord return failure(context, req, resp) } } } -func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { +func DeleteDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { recordId := req.FormValue("id") @@ -138,7 +160,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun } if !record.Internal { - err = dnsAdapter.DeleteDNSRecord(recordId) + err = externalDnsAdapter.DeleteDNSRecord(recordId) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) -- cgit v1.2.3-70-g09d2