summaryrefslogtreecommitdiff
path: root/api
diff options
context:
space:
mode:
authorsimponic <simponic@hatecomputers.club>2024-03-28 12:57:35 -0400
committersimponic <simponic@hatecomputers.club>2024-03-28 12:57:35 -0400
commitb2fc689bdcff28bf75c0128db19ba4730d726b4f (patch)
tree37c16d95183242516ba667aa5f441539d152c279 /api
parent75ba836d6072235fc7a71659f8630ab3c1b210ad (diff)
downloadhatecomputers.club-b2fc689bdcff28bf75c0128db19ba4730d726b4f.tar.gz
hatecomputers.club-b2fc689bdcff28bf75c0128db19ba4730d726b4f.zip
dns api (#1)
Co-authored-by: Elizabeth Hunt <elizabeth.hunt@simponic.xyz> Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/1
Diffstat (limited to 'api')
-rw-r--r--api/api_keys.go84
-rw-r--r--api/auth.go74
-rw-r--r--api/dns.go114
-rw-r--r--api/serve.go29
4 files changed, 281 insertions, 20 deletions
diff --git a/api/api_keys.go b/api/api_keys.go
new file mode 100644
index 0000000..17ed6c9
--- /dev/null
+++ b/api/api_keys.go
@@ -0,0 +1,84 @@
+package api
+
+import (
+ "log"
+ "net/http"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+)
+
+const MAX_USER_API_KEYS = 5
+
+func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ (*context.TemplateData)["APIKeys"] = apiKeys
+ return success(context, req, resp)
+ }
+}
+
+func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ formErrors := FormError{
+ Errors: []string{},
+ }
+
+ apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ if len(apiKeys) >= MAX_USER_API_KEYS {
+ formErrors.Errors = append(formErrors.Errors, "max api keys reached")
+ }
+
+ _, err = database.SaveAPIKey(context.DBConn, &database.UserApiKey{
+ UserID: context.User.ID,
+ Key: utils.RandomId(),
+ })
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ http.Redirect(resp, req, "/keys", http.StatusFound)
+ return success(context, req, resp)
+ }
+}
+
+func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ key := req.FormValue("key")
+
+ apiKey, err := database.GetAPIKey(context.DBConn, key)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+ if (apiKey == nil) || (apiKey.UserID != context.User.ID) {
+ resp.WriteHeader(http.StatusUnauthorized)
+ return failure(context, req, resp)
+ }
+
+ err = database.DeleteAPIKey(context.DBConn, key)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ http.Redirect(resp, req, "/keys", http.StatusFound)
+ return success(context, req, resp)
+ }
+}
diff --git a/api/auth.go b/api/auth.go
index 4733971..dcddf5a 100644
--- a/api/auth.go
+++ b/api/auth.go
@@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/base64"
"encoding/json"
+ "fmt"
"io"
"log"
"net/http"
@@ -116,32 +117,69 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
}
}
+func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
+ if bearerToken == "" {
+ return nil, nil
+ }
+
+ parts := strings.Split(bearerToken, " ")
+ if len(parts) != 2 || parts[0] != "Bearer" {
+ return nil, nil
+ }
+
+ apiKey, err := database.GetAPIKey(dbConn, parts[1])
+ if err != nil {
+ return nil, err
+ }
+ if apiKey == nil {
+ return nil, nil
+ }
+
+ user, err := database.GetUser(dbConn, apiKey.UserID)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
+
+func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
+ session, err := database.GetSession(dbConn, sessionId)
+ if err != nil {
+ return nil, err
+ }
+
+ if session.ExpireAt.Before(time.Now()) {
+ session = nil
+ database.DeleteSession(dbConn, sessionId)
+ return nil, fmt.Errorf("session expired")
+ }
+
+ user, err := database.GetUser(dbConn, session.UserID)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
+
func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
+ authHeader := req.Header.Get("Authorization")
+ user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
+
sessionCookie, err := req.Cookie("session")
- if err != nil {
- resp.WriteHeader(http.StatusUnauthorized)
- return failure(context, req, resp)
+ if err == nil {
+ user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
}
- session, err := database.GetSession(context.DBConn, sessionCookie.Value)
- if err == nil && session.ExpireAt.Before(time.Now()) {
- session = nil
- database.DeleteSession(context.DBConn, sessionCookie.Value)
- }
- if err != nil || session == nil {
+ if userErr != nil || user == nil {
+ log.Println(userErr, user)
+
http.SetCookie(resp, &http.Cookie{
Name: "session",
- MaxAge: 0,
+ MaxAge: 0, // reset session cookie in case
})
-
- return failure(context, req, resp)
- }
-
- user, err := database.GetUser(context.DBConn, session.UserID)
- if err != nil {
- log.Println(err)
- resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
diff --git a/api/dns.go b/api/dns.go
index 3105f91..5123acc 100644
--- a/api/dns.go
+++ b/api/dns.go
@@ -3,10 +3,23 @@ package api
import (
"log"
"net/http"
+ "strconv"
+ "strings"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
)
+const MAX_USER_RECORDS = 20
+
+type FormError struct {
+ Errors []string
+}
+
+func userCanFuckWithDNSRecord(user *database.User, record *database.DNSRecord) bool {
+ return user.ID == record.UserID && (record.Name == user.Username || strings.HasSuffix(record.Name, "."+user.Username))
+}
+
func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
@@ -17,7 +30,108 @@ func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp
}
(*context.TemplateData)["DNSRecords"] = dnsRecords
+ return success(context, req, resp)
+ }
+}
+
+func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ formErrors := FormError{
+ Errors: []string{},
+ }
+
+ name := req.FormValue("name")
+ recordType := req.FormValue("type")
+ recordContent := req.FormValue("content")
+ ttl := req.FormValue("ttl")
+ ttlNum, err := strconv.Atoi(ttl)
+ if err != nil {
+ formErrors.Errors = append(formErrors.Errors, "invalid ttl")
+ }
+
+ dnsRecord := &database.DNSRecord{
+ UserID: context.User.ID,
+ Name: name,
+ Type: recordType,
+ Content: recordContent,
+ TTL: ttlNum,
+ }
+
+ dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+ if len(dnsRecords) >= MAX_USER_RECORDS {
+ formErrors.Errors = append(formErrors.Errors, "max records reached")
+ }
+
+ if !userCanFuckWithDNSRecord(context.User, dnsRecord) {
+ formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username)
+ }
+
+ if len(formErrors.Errors) == 0 {
+ cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord)
+ if err != nil {
+ log.Println(err)
+ formErrors.Errors = append(formErrors.Errors, err.Error())
+ }
+
+ dnsRecord.ID = cloudflareRecordId
+ }
+
+ if len(formErrors.Errors) == 0 {
+ _, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
+ if err != nil {
+ log.Println(err)
+ formErrors.Errors = append(formErrors.Errors, "error saving record")
+ }
+ }
+
+ if len(formErrors.Errors) == 0 {
+ http.Redirect(resp, req, "/dns", http.StatusFound)
+ return success(context, req, resp)
+ }
+
+ (*context.TemplateData)["DNSRecords"] = dnsRecords
+ (*context.TemplateData)["FormError"] = &formErrors
+ (*context.TemplateData)["RecordForm"] = dnsRecord
+
+ resp.WriteHeader(http.StatusBadRequest)
+ return failure(context, req, resp)
+ }
+}
+
+func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ recordId := req.FormValue("id")
+ record, err := database.GetDNSRecord(context.DBConn, recordId)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ if !userCanFuckWithDNSRecord(context.User, record) {
+ resp.WriteHeader(http.StatusUnauthorized)
+ return failure(context, req, resp)
+ }
+
+ err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ err = database.DeleteDNSRecord(context.DBConn, recordId)
+ if err != nil {
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+ http.Redirect(resp, req, "/dns", http.StatusFound)
return success(context, req, resp)
}
}
diff --git a/api/serve.go b/api/serve.go
index 38b65b2..d16ea99 100644
--- a/api/serve.go
+++ b/api/serve.go
@@ -70,7 +70,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
mux := http.NewServeMux()
fileServer := http.FileServer(http.Dir(argv.StaticPath))
- mux.Handle("/static/", http.StripPrefix("/static/", fileServer))
+ mux.Handle("GET /static/", http.StripPrefix("/static/", fileServer))
makeRequestContext := func() *RequestContext {
return &RequestContext{
@@ -81,7 +81,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
}
}
- mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
@@ -116,6 +116,31 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
+ mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
+ requestContext := makeRequestContext()
+ LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateDNSRecordContinuation, GoLoginContinuation)(IdContinuation, TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ })
+
+ mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
+ requestContext := makeRequestContext()
+ LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ })
+
+ mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
+ requestContext := makeRequestContext()
+ LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ })
+
+ mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) {
+ requestContext := makeRequestContext()
+ LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(IdContinuation, TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ })
+
+ mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
+ requestContext := makeRequestContext()
+ LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ })
+
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
name := r.PathValue("name")