summaryrefslogtreecommitdiff
path: root/api
diff options
context:
space:
mode:
Diffstat (limited to 'api')
-rw-r--r--api/dns/dns.go72
-rw-r--r--api/dns/dns_test.go74
-rw-r--r--api/template/template.go2
3 files changed, 122 insertions, 26 deletions
diff --git a/api/dns/dns.go b/api/dns/dns.go
index 7e9c7c7..c24fa4a 100644
--- a/api/dns/dns.go
+++ b/api/dns/dns.go
@@ -36,29 +36,48 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request
}
}
-func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+func CreateDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
formErrors := types.BannerMessages{
Messages: []string{},
}
- internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
- name := req.FormValue("name")
- if internal && !strings.HasSuffix(name, ".") {
- name += "."
+ dnsRecord := &database.DNSRecord{}
+ id := req.FormValue("id")
+ isNewRecord := id == ""
+ if !isNewRecord {
+ retrievedDnsRecord, err := database.GetDNSRecord(context.DBConn, id)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ formErrors.Messages = append(formErrors.Messages, "error getting record from id")
+ } else {
+ dnsRecord = retrievedDnsRecord
+ }
+ } else {
+ dnsRecord.UserID = context.User.ID
+ }
+
+ dnsRecord.Internal = req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
+
+ dnsRecord.Name = req.FormValue("name")
+ if dnsRecord.Internal && !strings.HasSuffix(dnsRecord.Name, ".") {
+ dnsRecord.Name += "."
}
recordType := req.FormValue("type")
- recordType = strings.ToUpper(recordType)
+ dnsRecord.Type = strings.ToUpper(recordType)
+
+ dnsRecord.Content = req.FormValue("content")
- recordContent := req.FormValue("content")
ttl := req.FormValue("ttl")
ttlNum, err := strconv.Atoi(ttl)
if err != nil {
resp.WriteHeader(http.StatusBadRequest)
formErrors.Messages = append(formErrors.Messages, "invalid ttl")
}
+ dnsRecord.TTL = ttlNum
dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
if err != nil {
@@ -71,27 +90,29 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
formErrors.Messages = append(formErrors.Messages, "max records reached")
}
- dnsRecord := &database.DNSRecord{
- UserID: context.User.ID,
- Name: name,
- Type: recordType,
- Content: recordContent,
- TTL: ttlNum,
- Internal: internal,
- }
-
- if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
+ if len(formErrors.Messages) == 0 && !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
resp.WriteHeader(http.StatusUnauthorized)
- formErrors.Messages = append(formErrors.Messages, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
+ formErrors.Messages = append(formErrors.Messages, "external 'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
}
- if len(formErrors.Messages) == 0 {
+ if isNewRecord && len(formErrors.Messages) == 0 {
if dnsRecord.Internal {
dnsRecord.ID = utils.RandomId()
} else {
- dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
+ dnsRecord.ID, err = externalDnsAdapter.CreateDNSRecord(dnsRecord)
+ if err != nil {
+ log.Println("error creating external dns record", err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ formErrors.Messages = append(formErrors.Messages, err.Error())
+ }
+ }
+ }
+
+ if !isNewRecord && len(formErrors.Messages) == 0 {
+ if !dnsRecord.Internal {
+ err = externalDnsAdapter.UpdateDNSRecord(dnsRecord)
if err != nil {
- log.Println(err)
+ log.Println("error updating external dns record", err)
resp.WriteHeader(http.StatusInternalServerError)
formErrors.Messages = append(formErrors.Messages, err.Error())
}
@@ -108,20 +129,21 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, max
if len(formErrors.Messages) == 0 {
formSuccess := types.BannerMessages{
- Messages: []string{"record added."},
+ Messages: []string{"record saved."},
}
(*context.TemplateData)["Success"] = formSuccess
return success(context, req, resp)
}
- (*context.TemplateData)[""] = &formErrors
+ log.Println(formErrors.Messages)
+ (*context.TemplateData)["Error"] = &formErrors
(*context.TemplateData)["RecordForm"] = dnsRecord
return failure(context, req, resp)
}
}
}
-func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+func DeleteDNSRecordContinuation(externalDnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
recordId := req.FormValue("id")
@@ -138,7 +160,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun
}
if !record.Internal {
- err = dnsAdapter.DeleteDNSRecord(recordId)
+ err = externalDnsAdapter.DeleteDNSRecord(recordId)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go
index 30baedf..c4c581b 100644
--- a/api/dns/dns_test.go
+++ b/api/dns/dns_test.go
@@ -57,6 +57,7 @@ func setup() (*sql.DB, *types.RequestContext, func()) {
type SignallingExternalDnsAdapter struct {
AddChannel chan *database.DNSRecord
RmChannel chan string
+ UpdateChan chan *database.DNSRecord
}
func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
@@ -72,6 +73,12 @@ func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
return nil
}
+func (adapter *SignallingExternalDnsAdapter) UpdateDNSRecord(record *database.DNSRecord) error {
+ go func() { adapter.UpdateChan <- record }()
+
+ return nil
+}
+
func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
@@ -172,6 +179,73 @@ func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
}
}
+func TestThatUserCanUpdateExistingRecord(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ updateChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ UpdateChan: updateChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ responseRecorder := httptest.NewRecorder()
+ nonexistantRecord := httptest.NewRequest("POST", testServer.URL, nil)
+
+ id := "1"
+ name := "test." + context.User.Username
+ nonexistantRecord.Form = map[string][]string{
+ "id": {id},
+ "internal": {"off"},
+ "name": {name},
+ "type": {"CNAME"},
+ "ttl": {"43000"},
+ "content": {"new.domain."},
+ }
+
+ testServer.Config.Handler.ServeHTTP(responseRecorder, nonexistantRecord)
+ if responseRecorder.Code != http.StatusInternalServerError {
+ t.Errorf("expected internal server error return, got %d", responseRecorder.Code)
+ }
+
+ record := &database.DNSRecord{
+ ID: id,
+ Internal: false,
+ Name: name,
+ Type: "CNAME",
+ Content: "test.domain.",
+ TTL: 43000,
+ UserID: context.User.ID,
+ }
+ _, err := database.SaveDNSRecord(db, record)
+ if err != nil {
+ t.Error(err)
+ }
+
+ existantRecord := nonexistantRecord
+ existantRecordRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(existantRecordRecorder, existantRecord)
+ if existantRecordRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", existantRecordRecorder.Code)
+ }
+ select {
+ case req := <-updateChannel:
+ newRecord, err := database.GetDNSRecord(db, req.ID)
+ if err != nil {
+ t.Error(err)
+ }
+ if newRecord.Content != "new.domain." {
+ t.Errorf("expected updated record, got %s", newRecord.Content)
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Errorf("expected updated record channel")
+ }
+}
+
func TestThatExternalDnsSaves(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
diff --git a/api/template/template.go b/api/template/template.go
index 2875649..ad6a573 100644
--- a/api/template/template.go
+++ b/api/template/template.go
@@ -46,7 +46,7 @@ func renderTemplate(context *types.RequestContext, templateName string, showBase
func TemplateContinuation(path string, showBase bool) types.Continuation {
return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
- html, err := renderTemplate(context, path, true)
+ html, err := renderTemplate(context, path, showBase)
if errors.Is(err, os.ErrNotExist) {
resp.WriteHeader(404)
html, err = renderTemplate(context, "404.html", true)