diff options
| author | simponic <simponic@hatecomputers.club> | 2024-03-28 12:57:35 -0400 |
|---|---|---|
| committer | simponic <simponic@hatecomputers.club> | 2024-03-28 12:57:35 -0400 |
| commit | b2fc689bdcff28bf75c0128db19ba4730d726b4f (patch) | |
| tree | 37c16d95183242516ba667aa5f441539d152c279 /api | |
| parent | 75ba836d6072235fc7a71659f8630ab3c1b210ad (diff) | |
| download | hatecomputers.club-b2fc689bdcff28bf75c0128db19ba4730d726b4f.tar.gz hatecomputers.club-b2fc689bdcff28bf75c0128db19ba4730d726b4f.zip | |
dns api (#1)
Co-authored-by: Elizabeth Hunt <elizabeth.hunt@simponic.xyz>
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/1
Diffstat (limited to 'api')
| -rw-r--r-- | api/api_keys.go | 84 | ||||
| -rw-r--r-- | api/auth.go | 74 | ||||
| -rw-r--r-- | api/dns.go | 114 | ||||
| -rw-r--r-- | api/serve.go | 29 |
4 files changed, 281 insertions, 20 deletions
diff --git a/api/api_keys.go b/api/api_keys.go new file mode 100644 index 0000000..17ed6c9 --- /dev/null +++ b/api/api_keys.go @@ -0,0 +1,84 @@ +package api + +import ( + "log" + "net/http" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +const MAX_USER_API_KEYS = 5 + +func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["APIKeys"] = apiKeys + return success(context, req, resp) + } +} + +func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + formErrors := FormError{ + Errors: []string{}, + } + + apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if len(apiKeys) >= MAX_USER_API_KEYS { + formErrors.Errors = append(formErrors.Errors, "max api keys reached") + } + + _, err = database.SaveAPIKey(context.DBConn, &database.UserApiKey{ + UserID: context.User.ID, + Key: utils.RandomId(), + }) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/keys", http.StatusFound) + return success(context, req, resp) + } +} + +func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + key := req.FormValue("key") + + apiKey, err := database.GetAPIKey(context.DBConn, key) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if (apiKey == nil) || (apiKey.UserID != context.User.ID) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + err = database.DeleteAPIKey(context.DBConn, key) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/keys", http.StatusFound) + return success(context, req, resp) + } +} diff --git a/api/auth.go b/api/auth.go index 4733971..dcddf5a 100644 --- a/api/auth.go +++ b/api/auth.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/base64" "encoding/json" + "fmt" "io" "log" "net/http" @@ -116,32 +117,69 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp } } +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + apiKey, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if apiKey == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, apiKey.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { + authHeader := req.Header.Get("Authorization") + user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) + sessionCookie, err := req.Cookie("session") - if err != nil { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) + if err == nil { + user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) } - session, err := database.GetSession(context.DBConn, sessionCookie.Value) - if err == nil && session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(context.DBConn, sessionCookie.Value) - } - if err != nil || session == nil { + if userErr != nil || user == nil { + log.Println(userErr, user) + http.SetCookie(resp, &http.Cookie{ Name: "session", - MaxAge: 0, + MaxAge: 0, // reset session cookie in case }) - - return failure(context, req, resp) - } - - user, err := database.GetUser(context.DBConn, session.UserID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } @@ -3,10 +3,23 @@ package api import ( "log" "net/http" + "strconv" + "strings" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" ) +const MAX_USER_RECORDS = 20 + +type FormError struct { + Errors []string +} + +func userCanFuckWithDNSRecord(user *database.User, record *database.DNSRecord) bool { + return user.ID == record.UserID && (record.Name == user.Username || strings.HasSuffix(record.Name, "."+user.Username)) +} + func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) @@ -17,7 +30,108 @@ func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp } (*context.TemplateData)["DNSRecords"] = dnsRecords + return success(context, req, resp) + } +} + +func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + formErrors := FormError{ + Errors: []string{}, + } + + name := req.FormValue("name") + recordType := req.FormValue("type") + recordContent := req.FormValue("content") + ttl := req.FormValue("ttl") + ttlNum, err := strconv.Atoi(ttl) + if err != nil { + formErrors.Errors = append(formErrors.Errors, "invalid ttl") + } + + dnsRecord := &database.DNSRecord{ + UserID: context.User.ID, + Name: name, + Type: recordType, + Content: recordContent, + TTL: ttlNum, + } + + dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if len(dnsRecords) >= MAX_USER_RECORDS { + formErrors.Errors = append(formErrors.Errors, "max records reached") + } + + if !userCanFuckWithDNSRecord(context.User, dnsRecord) { + formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username) + } + + if len(formErrors.Errors) == 0 { + cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, err.Error()) + } + + dnsRecord.ID = cloudflareRecordId + } + + if len(formErrors.Errors) == 0 { + _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, "error saving record") + } + } + + if len(formErrors.Errors) == 0 { + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) + } + + (*context.TemplateData)["DNSRecords"] = dnsRecords + (*context.TemplateData)["FormError"] = &formErrors + (*context.TemplateData)["RecordForm"] = dnsRecord + + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } +} + +func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + recordId := req.FormValue("id") + record, err := database.GetDNSRecord(context.DBConn, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if !userCanFuckWithDNSRecord(context.User, record) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + err = database.DeleteDNSRecord(context.DBConn, recordId) + if err != nil { + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + http.Redirect(resp, req, "/dns", http.StatusFound) return success(context, req, resp) } } diff --git a/api/serve.go b/api/serve.go index 38b65b2..d16ea99 100644 --- a/api/serve.go +++ b/api/serve.go @@ -70,7 +70,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux := http.NewServeMux() fileServer := http.FileServer(http.Dir(argv.StaticPath)) - mux.Handle("/static/", http.StripPrefix("/static/", fileServer)) + mux.Handle("GET /static/", http.StripPrefix("/static/", fileServer)) makeRequestContext := func() *RequestContext { return &RequestContext{ @@ -81,7 +81,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { } } - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) @@ -116,6 +116,31 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) + mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateDNSRecordContinuation, GoLoginContinuation)(IdContinuation, TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(IdContinuation, TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() name := r.PathValue("name") |
