From dbd548d428f222babb4e1d6a182b90f19b192e1f Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Mon, 24 Jun 2024 00:18:28 -0700 Subject: POST record with id to update to fix cloudflare 500 --- api/dns/dns.go | 72 ++++++++++++++++++++++++++++++---------------- api/dns/dns_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ api/template/template.go | 2 +- 3 files changed, 122 insertions(+), 26 deletions(-) (limited to 'api') diff --git a/api/dns/dns.go b/api/dns/dns.go index 7e9c7c7..c24fa4a 100644 --- a/api/dns/dns.go +++ b/api/dns/dns.go @@ -36,29 +36,48 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request } } -func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { +func CreateDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { formErrors := types.BannerMessages{ Messages: []string{}, } - internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" - name := req.FormValue("name") - if internal && !strings.HasSuffix(name, ".") { - name += "." + dnsRecord := &database.DNSRecord{} + id := req.FormValue("id") + isNewRecord := id == "" + if !isNewRecord { + retrievedDnsRecord, err := database.GetDNSRecord(context.DBConn, id) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + formErrors.Messages = append(formErrors.Messages, "error getting record from id") + } else { + dnsRecord = retrievedDnsRecord + } + } else { + dnsRecord.UserID = context.User.ID + } + + dnsRecord.Internal = req.FormValue("internal") == "on" || req.FormValue("internal") == "true" + + dnsRecord.Name = req.FormValue("name") + if dnsRecord.Internal && !strings.HasSuffix(dnsRecord.Name, ".") { + dnsRecord.Name += "." } recordType := req.FormValue("type") - recordType = strings.ToUpper(recordType) + dnsRecord.Type = strings.ToUpper(recordType) + + dnsRecord.Content = req.FormValue("content") - recordContent := req.FormValue("content") ttl := req.FormValue("ttl") ttlNum, err := strconv.Atoi(ttl) if err != nil { resp.WriteHeader(http.StatusBadRequest) formErrors.Messages = append(formErrors.Messages, "invalid ttl") } + dnsRecord.TTL = ttlNum dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) if err != nil { @@ -71,27 +90,29 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max formErrors.Messages = append(formErrors.Messages, "max records reached") } - dnsRecord := &database.DNSRecord{ - UserID: context.User.ID, - Name: name, - Type: recordType, - Content: recordContent, - TTL: ttlNum, - Internal: internal, - } - - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { + if len(formErrors.Messages) == 0 && !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { resp.WriteHeader(http.StatusUnauthorized) - formErrors.Messages = append(formErrors.Messages, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") + formErrors.Messages = append(formErrors.Messages, "external 'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } - if len(formErrors.Messages) == 0 { + if isNewRecord && len(formErrors.Messages) == 0 { if dnsRecord.Internal { dnsRecord.ID = utils.RandomId() } else { - dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) + dnsRecord.ID, err = externalDnsAdapter.CreateDNSRecord(dnsRecord) + if err != nil { + log.Println("error creating external dns record", err) + resp.WriteHeader(http.StatusInternalServerError) + formErrors.Messages = append(formErrors.Messages, err.Error()) + } + } + } + + if !isNewRecord && len(formErrors.Messages) == 0 { + if !dnsRecord.Internal { + err = externalDnsAdapter.UpdateDNSRecord(dnsRecord) if err != nil { - log.Println(err) + log.Println("error updating external dns record", err) resp.WriteHeader(http.StatusInternalServerError) formErrors.Messages = append(formErrors.Messages, err.Error()) } @@ -108,20 +129,21 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max if len(formErrors.Messages) == 0 { formSuccess := types.BannerMessages{ - Messages: []string{"record added."}, + Messages: []string{"record saved."}, } (*context.TemplateData)["Success"] = formSuccess return success(context, req, resp) } - (*context.TemplateData)[""] = &formErrors + log.Println(formErrors.Messages) + (*context.TemplateData)["Error"] = &formErrors (*context.TemplateData)["RecordForm"] = dnsRecord return failure(context, req, resp) } } } -func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { +func DeleteDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { recordId := req.FormValue("id") @@ -138,7 +160,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun } if !record.Internal { - err = dnsAdapter.DeleteDNSRecord(recordId) + err = externalDnsAdapter.DeleteDNSRecord(recordId) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go index 30baedf..c4c581b 100644 --- a/api/dns/dns_test.go +++ b/api/dns/dns_test.go @@ -57,6 +57,7 @@ func setup() (*sql.DB, *types.RequestContext, func()) { type SignallingExternalDnsAdapter struct { AddChannel chan *database.DNSRecord RmChannel chan string + UpdateChan chan *database.DNSRecord } func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { @@ -72,6 +73,12 @@ func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error { return nil } +func (adapter *SignallingExternalDnsAdapter) UpdateDNSRecord(record *database.DNSRecord) error { + go func() { adapter.UpdateChan <- record }() + + return nil +} + func TestThatOwnerCanPutRecordInDomain(t *testing.T) { db, context, cleanup := setup() defer cleanup() @@ -172,6 +179,73 @@ func TestThatUserCanAddToPublicEndpoints(t *testing.T) { } } +func TestThatUserCanUpdateExistingRecord(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + updateChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + UpdateChan: updateChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + responseRecorder := httptest.NewRecorder() + nonexistantRecord := httptest.NewRequest("POST", testServer.URL, nil) + + id := "1" + name := "test." + context.User.Username + nonexistantRecord.Form = map[string][]string{ + "id": {id}, + "internal": {"off"}, + "name": {name}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"new.domain."}, + } + + testServer.Config.Handler.ServeHTTP(responseRecorder, nonexistantRecord) + if responseRecorder.Code != http.StatusInternalServerError { + t.Errorf("expected internal server error return, got %d", responseRecorder.Code) + } + + record := &database.DNSRecord{ + ID: id, + Internal: false, + Name: name, + Type: "CNAME", + Content: "test.domain.", + TTL: 43000, + UserID: context.User.ID, + } + _, err := database.SaveDNSRecord(db, record) + if err != nil { + t.Error(err) + } + + existantRecord := nonexistantRecord + existantRecordRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(existantRecordRecorder, existantRecord) + if existantRecordRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", existantRecordRecorder.Code) + } + select { + case req := <-updateChannel: + newRecord, err := database.GetDNSRecord(db, req.ID) + if err != nil { + t.Error(err) + } + if newRecord.Content != "new.domain." { + t.Errorf("expected updated record, got %s", newRecord.Content) + } + case <-time.After(100 * time.Millisecond): + t.Errorf("expected updated record channel") + } +} + func TestThatExternalDnsSaves(t *testing.T) { db, context, cleanup := setup() defer cleanup() diff --git a/api/template/template.go b/api/template/template.go index 2875649..ad6a573 100644 --- a/api/template/template.go +++ b/api/template/template.go @@ -46,7 +46,7 @@ func renderTemplate(context *types.RequestContext, templateName string, showBase func TemplateContinuation(path string, showBase bool) types.Continuation { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { - html, err := renderTemplate(context, path, true) + html, err := renderTemplate(context, path, showBase) if errors.Is(err, os.ErrNotExist) { resp.WriteHeader(404) html, err = renderTemplate(context, "404.html", true) -- cgit v1.2.3-70-g09d2