diff options
| author | simponic <simponic@hatecomputers.club> | 2024-04-06 15:43:18 -0400 |
|---|---|---|
| committer | simponic <simponic@hatecomputers.club> | 2024-04-06 15:43:18 -0400 |
| commit | 83cc6267fd5ce2f61200314424c5f400f65ff2ba (patch) | |
| tree | eafb35310236a15572cbb6e16ff8d6f181bfe240 /api/dns/dns.go | |
| parent | 569d2788ebfb90774faf361f62bfe7968e091465 (diff) | |
| parent | cad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 (diff) | |
| download | hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.tar.gz hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.zip | |
Merge pull request 'testing | dont be recursive for external domains | finalize oauth' (#5) from dont-be-authoritative into main
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/5
Diffstat (limited to 'api/dns/dns.go')
| -rw-r--r-- | api/dns/dns.go | 174 |
1 files changed, 174 insertions, 0 deletions
diff --git a/api/dns/dns.go b/api/dns/dns.go new file mode 100644 index 0000000..aa2f356 --- /dev/null +++ b/api/dns/dns.go @@ -0,0 +1,174 @@ +package dns + +import ( + "database/sql" + "fmt" + "log" + "net/http" + "strconv" + "strings" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { + ownedByUser := (user.ID == record.UserID) + if !ownedByUser { + return false + } + + if !record.Internal { + for _, format := range ownedInternalDomainFormats { + domain := fmt.Sprintf(format, user.Username) + + isInSubDomain := strings.HasSuffix(record.Name, "."+domain) + if domain == record.Name || isInSubDomain { + return true + } + } + return false + } + + owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) + if err != nil { + log.Println(err) + return false + } + + userIsOwnerOfDomain := owner == user.ID + return ownedByUser && userIsOwnerOfDomain +} + +func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["DNSRecords"] = dnsRecords + return success(context, req, resp) + } +} + +func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + formErrors := types.FormError{ + Errors: []string{}, + } + + internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" + name := req.FormValue("name") + if internal && !strings.HasSuffix(name, ".") { + name += "." + } + + recordType := req.FormValue("type") + recordType = strings.ToUpper(recordType) + + recordContent := req.FormValue("content") + ttl := req.FormValue("ttl") + ttlNum, err := strconv.Atoi(ttl) + if err != nil { + resp.WriteHeader(http.StatusBadRequest) + formErrors.Errors = append(formErrors.Errors, "invalid ttl") + } + + dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if dnsRecordCount >= maxUserRecords { + resp.WriteHeader(http.StatusTooManyRequests) + formErrors.Errors = append(formErrors.Errors, "max records reached") + } + + dnsRecord := &database.DNSRecord{ + UserID: context.User.ID, + Name: name, + Type: recordType, + Content: recordContent, + TTL: ttlNum, + Internal: internal, + } + + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { + resp.WriteHeader(http.StatusUnauthorized) + formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") + } + + if len(formErrors.Errors) == 0 { + if dnsRecord.Internal { + dnsRecord.ID = utils.RandomId() + } else { + dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + formErrors.Errors = append(formErrors.Errors, err.Error()) + } + } + } + + if len(formErrors.Errors) == 0 { + _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, "error saving record") + } + } + + if len(formErrors.Errors) == 0 { + return success(context, req, resp) + } + + (*context.TemplateData)["FormError"] = &formErrors + (*context.TemplateData)["RecordForm"] = dnsRecord + return failure(context, req, resp) + } + } +} + +func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + recordId := req.FormValue("id") + record, err := database.GetDNSRecord(context.DBConn, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if !(record.UserID == context.User.ID) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + if !record.Internal { + err = dnsAdapter.DeleteDNSRecord(recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + } + + err = database.DeleteDNSRecord(context.DBConn, recordId) + if err != nil { + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + return success(context, req, resp) + } + } +} |
