From f38e8719c2a8537fe9b64ed8ceca45858a58e498 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Wed, 3 Apr 2024 17:53:50 -0600 Subject: make it compile --- api/api_keys.go | 87 ------------ api/auth.go | 287 --------------------------------------- api/auth/auth.go | 288 ++++++++++++++++++++++++++++++++++++++++ api/auth/auth_test.go | 36 +++++ api/auth_test.go | 37 ------ api/dns.go | 177 ------------------------ api/dns/dns.go | 178 +++++++++++++++++++++++++ api/dns/dns_test.go | 63 +++++++++ api/dns_test.go | 56 -------- api/guestbook.go | 88 ------------ api/guestbook/guestbook.go | 85 ++++++++++++ api/guestbook/guestbook_test.go | 136 +++++++++++++++++++ api/guestbook_test.go | 129 ------------------ api/hcaptcha.go | 69 ---------- api/hcaptcha/hcaptcha.go | 75 +++++++++++ api/keys/keys.go | 88 ++++++++++++ api/serve.go | 76 +++++------ api/template.go | 74 ----------- api/template/template.go | 76 +++++++++++ api/types/types.go | 28 ++++ 20 files changed, 1085 insertions(+), 1048 deletions(-) delete mode 100644 api/api_keys.go delete mode 100644 api/auth.go create mode 100644 api/auth/auth.go create mode 100644 api/auth/auth_test.go delete mode 100644 api/auth_test.go delete mode 100644 api/dns.go create mode 100644 api/dns/dns.go create mode 100644 api/dns/dns_test.go delete mode 100644 api/dns_test.go delete mode 100644 api/guestbook.go create mode 100644 api/guestbook/guestbook.go create mode 100644 api/guestbook/guestbook_test.go delete mode 100644 api/guestbook_test.go delete mode 100644 api/hcaptcha.go create mode 100644 api/hcaptcha/hcaptcha.go create mode 100644 api/keys/keys.go delete mode 100644 api/template.go create mode 100644 api/template/template.go create mode 100644 api/types/types.go (limited to 'api') diff --git a/api/api_keys.go b/api/api_keys.go deleted file mode 100644 index d636044..0000000 --- a/api/api_keys.go +++ /dev/null @@ -1,87 +0,0 @@ -package api - -import ( - "log" - "net/http" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -const MAX_USER_API_KEYS = 5 - -func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - (*context.TemplateData)["APIKeys"] = apiKeys - return success(context, req, resp) - } -} - -func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - formErrors := FormError{ - Errors: []string{}, - } - - numKeys, err := database.CountUserAPIKeys(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - if numKeys >= MAX_USER_API_KEYS { - formErrors.Errors = append(formErrors.Errors, "max api keys reached") - } - - if len(formErrors.Errors) > 0 { - (*context.TemplateData)["FormError"] = formErrors - return failure(context, req, resp) - } - - _, err = database.SaveAPIKey(context.DBConn, &database.UserApiKey{ - UserID: context.User.ID, - Key: utils.RandomId(), - }) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - return success(context, req, resp) - } -} - -func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - key := req.FormValue("key") - - apiKey, err := database.GetAPIKey(context.DBConn, key) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - if (apiKey == nil) || (apiKey.UserID != context.User.ID) { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } - - err = database.DeleteAPIKey(context.DBConn, key) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - http.Redirect(resp, req, "/keys", http.StatusFound) - return success(context, req, resp) - } -} diff --git a/api/auth.go b/api/auth.go deleted file mode 100644 index 0e4c1ed..0000000 --- a/api/auth.go +++ /dev/null @@ -1,287 +0,0 @@ -package api - -import ( - "crypto/sha256" - "database/sql" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "log" - "net/http" - "strings" - "time" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" - "golang.org/x/oauth2" -) - -func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - verifier := utils.RandomId() + utils.RandomId() - - sha2 := sha256.New() - io.WriteString(sha2, verifier) - codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) - - state := utils.RandomId() - url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge)) - - http.SetCookie(resp, &http.Cookie{ - Name: "verifier", - Value: verifier, - Path: "/", - Secure: true, - SameSite: http.SameSiteLaxMode, - MaxAge: 60, - }) - http.SetCookie(resp, &http.Cookie{ - Name: "state", - Value: state, - Path: "/", - Secure: true, - SameSite: http.SameSiteLaxMode, - MaxAge: 60, - }) - - http.Redirect(resp, req, url, http.StatusFound) - return success(context, req, resp) - } -} - -func InterceptOauthCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - state := req.URL.Query().Get("state") - code := req.URL.Query().Get("code") - - if code == "" || state == "" { - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) - } - - if !verifyState(req, "state", state) { - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) - } - verifierCookie, err := req.Cookie("verifier") - if err != nil { - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) - } - - reqContext := req.Context() - token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - client := context.Args.OauthConfig.Client(reqContext, token) - user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - - return failure(context, req, resp) - } - - session, err := database.MakeUserSessionFor(context.DBConn, user) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - http.SetCookie(resp, &http.Cookie{ - Name: "session", - Value: session.ID, - Path: "/", - SameSite: http.SameSiteLaxMode, - Secure: true, - }) - - redirect := "/" - redirectCookie, err := req.Cookie("redirect") - if err == nil && redirectCookie.Value != "" { - redirect = redirectCookie.Value - http.SetCookie(resp, &http.Cookie{ - Name: "redirect", - MaxAge: 0, - }) - } - - http.Redirect(resp, req, redirect, http.StatusFound) - return success(context, req, resp) - } -} - -func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { - if bearerToken == "" { - return nil, nil - } - - parts := strings.Split(bearerToken, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - return nil, nil - } - - apiKey, err := database.GetAPIKey(dbConn, parts[1]) - if err != nil { - return nil, err - } - if apiKey == nil { - return nil, nil - } - - user, err := database.GetUser(dbConn, apiKey.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - -func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { - session, err := database.GetSession(dbConn, sessionId) - if err != nil { - return nil, err - } - - if session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(dbConn, sessionId) - return nil, fmt.Errorf("session expired") - } - - user, err := database.GetUser(dbConn, session.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - -func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - authHeader := req.Header.Get("Authorization") - user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) - - sessionCookie, err := req.Cookie("session") - if err == nil && sessionCookie.Value != "" { - user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) - } - - if userErr != nil || user == nil { - log.Println(userErr, user) - - http.SetCookie(resp, &http.Cookie{ - Name: "session", - MaxAge: 0, // reset session cookie in case - }) - - context.User = nil - return failure(context, req, resp) - } - - context.User = user - return success(context, req, resp) - } -} - -func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - http.SetCookie(resp, &http.Cookie{ - Name: "redirect", - Value: req.URL.Path, - Path: "/", - Secure: true, - SameSite: http.SameSiteLaxMode, - }) - - http.Redirect(resp, req, "/login", http.StatusFound) - return failure(context, req, resp) - } -} - -func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - sessionCookie, err := req.Cookie("session") - if err != nil { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } - - _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) - if err != nil { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } - - return success(context, req, resp) - } -} - -func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - sessionCookie, err := req.Cookie("session") - if err == nil && sessionCookie.Value != "" { - _ = database.DeleteSession(context.DBConn, sessionCookie.Value) - } - - http.Redirect(resp, req, "/", http.StatusFound) - http.SetCookie(resp, &http.Cookie{ - Name: "session", - MaxAge: 0, - }) - return success(context, req, resp) - } -} - -func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { - userResponse, err := client.Get(uri) - if err != nil { - return nil, err - } - - userStruct, err := createUserFromResponse(userResponse) - if err != nil { - return nil, err - } - - user, err := database.FindOrSaveUser(dbConn, userStruct) - if err != nil { - return nil, err - } - - return user, nil -} - -func createUserFromResponse(response *http.Response) (*database.User, error) { - user := &database.User{ - CreatedAt: time.Now(), - } - - err := json.NewDecoder(response.Body).Decode(user) - defer response.Body.Close() - - if err != nil { - log.Println(err) - return nil, err - } - - user.Username = strings.ToLower(user.Username) - user.Username = strings.Split(user.Username, "@")[0] - - return user, nil -} - -func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { - cookie, err := req.Cookie(stateCookieName) - if err != nil || cookie.Value != expectedState { - return false - } - - return true -} diff --git a/api/auth/auth.go b/api/auth/auth.go new file mode 100644 index 0000000..dc348b2 --- /dev/null +++ b/api/auth/auth.go @@ -0,0 +1,288 @@ +package auth + +import ( + "crypto/sha256" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "time" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" +) + +func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + verifier := utils.RandomId() + utils.RandomId() + + sha2 := sha256.New() + io.WriteString(sha2, verifier) + codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) + + state := utils.RandomId() + url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge)) + + http.SetCookie(resp, &http.Cookie{ + Name: "verifier", + Value: verifier, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + http.SetCookie(resp, &http.Cookie{ + Name: "state", + Value: state, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + + http.Redirect(resp, req, url, http.StatusFound) + return success(context, req, resp) + } +} + +func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + state := req.URL.Query().Get("state") + code := req.URL.Query().Get("code") + + if code == "" || state == "" { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + if !verifyState(req, "state", state) { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + verifierCookie, err := req.Cookie("verifier") + if err != nil { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + reqContext := req.Context() + token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + client := context.Args.OauthConfig.Client(reqContext, token) + user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + + return failure(context, req, resp) + } + + session, err := database.MakeUserSessionFor(context.DBConn, user) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.SetCookie(resp, &http.Cookie{ + Name: "session", + Value: session.ID, + Path: "/", + SameSite: http.SameSiteLaxMode, + Secure: true, + }) + + redirect := "/" + redirectCookie, err := req.Cookie("redirect") + if err == nil && redirectCookie.Value != "" { + redirect = redirectCookie.Value + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + MaxAge: 0, + }) + } + + http.Redirect(resp, req, redirect, http.StatusFound) + return success(context, req, resp) + } +} + +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + typesKey, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if typesKey == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, typesKey.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + authHeader := req.Header.Get("Authorization") + user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) + + sessionCookie, err := req.Cookie("session") + if err == nil && sessionCookie.Value != "" { + user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) + } + + if userErr != nil || user == nil { + log.Println(userErr, user) + + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, // reset session cookie in case + }) + + context.User = nil + return failure(context, req, resp) + } + + context.User = user + return success(context, req, resp) + } +} + +func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + Value: req.URL.Path, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + http.Redirect(resp, req, "/login", http.StatusFound) + return failure(context, req, resp) + } +} + +func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} + +func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err == nil && sessionCookie.Value != "" { + _ = database.DeleteSession(context.DBConn, sessionCookie.Value) + } + + http.Redirect(resp, req, "/", http.StatusFound) + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, + }) + return success(context, req, resp) + } +} + +func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { + userResponse, err := client.Get(uri) + if err != nil { + return nil, err + } + + userStruct, err := createUserFromResponse(userResponse) + if err != nil { + return nil, err + } + + user, err := database.FindOrSaveUser(dbConn, userStruct) + if err != nil { + return nil, err + } + + return user, nil +} + +func createUserFromResponse(response *http.Response) (*database.User, error) { + user := &database.User{ + CreatedAt: time.Now(), + } + + err := json.NewDecoder(response.Body).Decode(user) + defer response.Body.Close() + + if err != nil { + log.Println(err) + return nil, err + } + + user.Username = strings.ToLower(user.Username) + user.Username = strings.Split(user.Username, "@")[0] + + return user, nil +} + +func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { + cookie, err := req.Cookie(stateCookieName) + if err != nil || cookie.Value != expectedState { + return false + } + + return true +} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go new file mode 100644 index 0000000..a6c2a45 --- /dev/null +++ b/api/auth/auth_test.go @@ -0,0 +1,36 @@ +package auth_test + +import ( + "database/sql" + "os" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func setup() (*sql.DB, *types.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +/* +todo: test types key creation ++ api key attached to user ++ user session is unique ++ goLogin goes to page in cookie +*/ diff --git a/api/auth_test.go b/api/auth_test.go deleted file mode 100644 index 45ca12e..0000000 --- a/api/auth_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package api_test - -import ( - "database/sql" - "os" - "testing" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -func setup() (*sql.DB, *api.RequestContext, func()) { - randomDb := utils.RandomId() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - - context := &api.RequestContext{ - DBConn: testDb, - Args: &args.Arguments{}, - TemplateData: &(map[string]interface{}{}), - } - - return testDb, context, func() { - testDb.Close() - os.Remove(randomDb) - } -} - -/* -todo: test api key creation -+ api key attached to user -+ user session is unique -+ goLogin goes to page in cookie -*/ diff --git a/api/dns.go b/api/dns.go deleted file mode 100644 index 7ade6e4..0000000 --- a/api/dns.go +++ /dev/null @@ -1,177 +0,0 @@ -package api - -import ( - "database/sql" - "fmt" - "log" - "net/http" - "strconv" - "strings" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -const MAX_USER_RECORDS = 65 - -var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} - -func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { - ownedByUser := (user.ID == record.UserID) - if !ownedByUser { - return false - } - - if !record.Internal { - for _, format := range ownedInternalDomainFormats { - domain := fmt.Sprintf(format, user.Username) - - isInSubDomain := strings.HasSuffix(record.Name, "."+domain) - if domain == record.Name || isInSubDomain { - return true - } - } - return false - } - - owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) - if err != nil { - log.Println(err) - return false - } - - userIsOwnerOfDomain := owner == user.ID - return ownedByUser && userIsOwnerOfDomain -} - -func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - (*context.TemplateData)["DNSRecords"] = dnsRecords - return success(context, req, resp) - } -} - -func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - formErrors := FormError{ - Errors: []string{}, - } - - internal := req.FormValue("internal") == "on" - name := req.FormValue("name") - if internal && !strings.HasSuffix(name, ".") { - name += "." - } - - recordType := req.FormValue("type") - recordType = strings.ToUpper(recordType) - - recordContent := req.FormValue("content") - ttl := req.FormValue("ttl") - ttlNum, err := strconv.Atoi(ttl) - if err != nil { - formErrors.Errors = append(formErrors.Errors, "invalid ttl") - } - - dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - if dnsRecordCount >= MAX_USER_RECORDS { - formErrors.Errors = append(formErrors.Errors, "max records reached") - } - - dnsRecord := &database.DNSRecord{ - UserID: context.User.ID, - Name: name, - Type: recordType, - Content: recordContent, - TTL: ttlNum, - Internal: internal, - } - - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) { - formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") - } - - if len(formErrors.Errors) == 0 { - if dnsRecord.Internal { - dnsRecord.ID = utils.RandomId() - } else { - dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) - if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, err.Error()) - } - } - } - - if len(formErrors.Errors) == 0 { - _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) - if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, "error saving record") - } - } - - if len(formErrors.Errors) == 0 { - http.Redirect(resp, req, "/dns", http.StatusFound) - return success(context, req, resp) - } - - (*context.TemplateData)["FormError"] = &formErrors - (*context.TemplateData)["RecordForm"] = dnsRecord - - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) - } - } -} - -func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - recordId := req.FormValue("id") - record, err := database.GetDNSRecord(context.DBConn, recordId) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } - - if !record.Internal { - err = dnsAdapter.DeleteDNSRecord(recordId) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - } - - err = database.DeleteDNSRecord(context.DBConn, recordId) - if err != nil { - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - http.Redirect(resp, req, "/dns", http.StatusFound) - return success(context, req, resp) - } - } -} diff --git a/api/dns/dns.go b/api/dns/dns.go new file mode 100644 index 0000000..4805146 --- /dev/null +++ b/api/dns/dns.go @@ -0,0 +1,178 @@ +package dns + +import ( + "database/sql" + "fmt" + "log" + "net/http" + "strconv" + "strings" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +const MAX_USER_RECORDS = 65 + +var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} + +func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { + ownedByUser := (user.ID == record.UserID) + if !ownedByUser { + return false + } + + if !record.Internal { + for _, format := range ownedInternalDomainFormats { + domain := fmt.Sprintf(format, user.Username) + + isInSubDomain := strings.HasSuffix(record.Name, "."+domain) + if domain == record.Name || isInSubDomain { + return true + } + } + return false + } + + owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) + if err != nil { + log.Println(err) + return false + } + + userIsOwnerOfDomain := owner == user.ID + return ownedByUser && userIsOwnerOfDomain +} + +func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["DNSRecords"] = dnsRecords + return success(context, req, resp) + } +} + +func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + formErrors := types.FormError{ + Errors: []string{}, + } + + internal := req.FormValue("internal") == "on" + name := req.FormValue("name") + if internal && !strings.HasSuffix(name, ".") { + name += "." + } + + recordType := req.FormValue("type") + recordType = strings.ToUpper(recordType) + + recordContent := req.FormValue("content") + ttl := req.FormValue("ttl") + ttlNum, err := strconv.Atoi(ttl) + if err != nil { + formErrors.Errors = append(formErrors.Errors, "invalid ttl") + } + + dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if dnsRecordCount >= MAX_USER_RECORDS { + formErrors.Errors = append(formErrors.Errors, "max records reached") + } + + dnsRecord := &database.DNSRecord{ + UserID: context.User.ID, + Name: name, + Type: recordType, + Content: recordContent, + TTL: ttlNum, + Internal: internal, + } + + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) { + formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") + } + + if len(formErrors.Errors) == 0 { + if dnsRecord.Internal { + dnsRecord.ID = utils.RandomId() + } else { + dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, err.Error()) + } + } + } + + if len(formErrors.Errors) == 0 { + _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, "error saving record") + } + } + + if len(formErrors.Errors) == 0 { + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) + } + + (*context.TemplateData)["FormError"] = &formErrors + (*context.TemplateData)["RecordForm"] = dnsRecord + + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + } +} + +func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + recordId := req.FormValue("id") + record, err := database.GetDNSRecord(context.DBConn, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + if !record.Internal { + err = dnsAdapter.DeleteDNSRecord(recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + } + + err = database.DeleteDNSRecord(context.DBConn, recordId) + if err != nil { + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) + } + } +} diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go new file mode 100644 index 0000000..cc56120 --- /dev/null +++ b/api/dns/dns_test.go @@ -0,0 +1,63 @@ +package dns_test + +import ( + "database/sql" + "net/http" + "net/http/httptest" + "os" + "testing" + + // "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { + return success(context, req, resp) + } +} + +func setup() (*sql.DB, *types.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +func TestThatOwnerCanPutRecordInDomain(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + _ = &database.User{ + ID: "test", + Username: "test", + } + + records, err := database.GetUserDNSRecords(db, context.User.ID) + if err != nil { + t.Fatal(err) + } + if len(records) > 0 { + t.Errorf("expected no records, got records") + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // dns.CreateDNSRecordContinuation(context, r, w)(IdContinuation, IdContinuation) + })) + defer ts.Close() + +} diff --git a/api/dns_test.go b/api/dns_test.go deleted file mode 100644 index 59dd85b..0000000 --- a/api/dns_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package api_test - -import ( - "database/sql" - "net/http" - "net/http/httptest" - "os" - "testing" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -func setup() (*sql.DB, *api.RequestContext, func()) { - randomDb := utils.RandomId() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - - context := &api.RequestContext{ - DBConn: testDb, - Args: &args.Arguments{}, - TemplateData: &(map[string]interface{}{}), - } - - return testDb, context, func() { - testDb.Close() - os.Remove(randomDb) - } -} - -func TestThatOwnerCanPutRecordInDomain(t *testing.T) { - db, context, cleanup := setup() - defer cleanup() - - testUser := &database.User{ - ID: "test", - Username: "test", - } - - records, err := database.GetUserDNSRecords(db, context.User.ID) - if err != nil { - t.Fatal(err) - } - if len(records) > 0 { - t.Errorf("expected no records, got records") - } - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - api.PutDNSRecordContinuation(context, r, w)(api.IdContinuation, api.IdContinuation) - })) - defer ts.Close() - -} diff --git a/api/guestbook.go b/api/guestbook.go deleted file mode 100644 index ee3c79a..0000000 --- a/api/guestbook.go +++ /dev/null @@ -1,88 +0,0 @@ -package api - -import ( - "log" - "net/http" - "strings" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -type HcaptchaArgs struct { - SiteKey string -} - -func validateGuestbookEntry(entry *database.GuestbookEntry) []string { - errors := []string{} - - if entry.Name == "" { - errors = append(errors, "name is required") - } - - if entry.Message == "" { - errors = append(errors, "message is required") - } - - messageLength := len(entry.Message) - if messageLength > 500 { - errors = append(errors, "message cannot be longer than 500 characters") - } - - newLines := strings.Count(entry.Message, "\n") - if newLines > 10 { - errors = append(errors, "message cannot contain more than 10 new lines") - } - - return errors -} - -func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - name := req.FormValue("name") - message := req.FormValue("message") - - formErrors := FormError{ - Errors: []string{}, - } - - entry := &database.GuestbookEntry{ - ID: utils.RandomId(), - Name: name, - Message: message, - } - formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...) - - if len(formErrors.Errors) == 0 { - _, err := database.SaveGuestbookEntry(context.DBConn, entry) - if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, "failed to save entry") - } - } - - if len(formErrors.Errors) > 0 { - (*context.TemplateData)["FormError"] = formErrors - (*context.TemplateData)["EntryForm"] = entry - resp.WriteHeader(http.StatusBadRequest) - - return failure(context, req, resp) - } - - return success(context, req, resp) - } -} - -func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - entries, err := database.GetGuestbookEntries(context.DBConn) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - (*context.TemplateData)["GuestbookEntries"] = entries - return success(context, req, resp) - } -} diff --git a/api/guestbook/guestbook.go b/api/guestbook/guestbook.go new file mode 100644 index 0000000..60a7b4b --- /dev/null +++ b/api/guestbook/guestbook.go @@ -0,0 +1,85 @@ +package guestbook + +import ( + "log" + "net/http" + "strings" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func validateGuestbookEntry(entry *database.GuestbookEntry) []string { + errors := []string{} + + if entry.Name == "" { + errors = append(errors, "name is required") + } + + if entry.Message == "" { + errors = append(errors, "message is required") + } + + messageLength := len(entry.Message) + if messageLength > 500 { + errors = append(errors, "message cannot be longer than 500 characters") + } + + newLines := strings.Count(entry.Message, "\n") + if newLines > 10 { + errors = append(errors, "message cannot contain more than 10 new lines") + } + + return errors +} + +func SignGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + name := req.FormValue("name") + message := req.FormValue("message") + + formErrors := types.FormError{ + Errors: []string{}, + } + + entry := &database.GuestbookEntry{ + ID: utils.RandomId(), + Name: name, + Message: message, + } + formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...) + + if len(formErrors.Errors) == 0 { + _, err := database.SaveGuestbookEntry(context.DBConn, entry) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, "failed to save entry") + } + } + + if len(formErrors.Errors) > 0 { + (*context.TemplateData)["FormError"] = formErrors + (*context.TemplateData)["EntryForm"] = entry + resp.WriteHeader(http.StatusBadRequest) + + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} + +func ListGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + entries, err := database.GetGuestbookEntries(context.DBConn) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["GuestbookEntries"] = entries + return success(context, req, resp) + } +} diff --git a/api/guestbook/guestbook_test.go b/api/guestbook/guestbook_test.go new file mode 100644 index 0000000..9fd6c62 --- /dev/null +++ b/api/guestbook/guestbook_test.go @@ -0,0 +1,136 @@ +package guestbook_test + +import ( + "database/sql" + "net/http" + "net/http/httptest" + "os" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { + return success(context, req, resp) + } +} + +func setup() (*sql.DB, *types.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +func TestValidGuestbookPutsInDatabase(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + entries, err := database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + if len(entries) > 0 { + t.Errorf("expected no entries, got entries") + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation) + })) + defer ts.Close() + + req := httptest.NewRequest("POST", ts.URL, nil) + req.Form = map[string][]string{ + "name": {"test"}, + "message": {"test"}, + } + + w := httptest.NewRecorder() + ts.Config.Handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status code 200, got %d", w.Code) + } + + entries, err = database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + + if len(entries) != 1 { + t.Errorf("expected 1 entry, got %d", len(entries)) + } + + if entries[0].Name != req.FormValue("name") { + t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name) + } +} + +func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + entries, err := database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + if len(entries) > 0 { + t.Errorf("expected no entries, got entries") + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n" + invalidRequests := []struct { + name string + message string + }{ + {"", "test"}, + {"test", ""}, + {"", ""}, + {"test", reallyLongStringThatWouldTakeTooMuchSpace}, + } + + for _, form := range invalidRequests { + req := httptest.NewRequest("POST", testServer.URL, nil) + req.Form = map[string][]string{ + "name": {form.name}, + "message": {form.message}, + } + + responseRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(responseRecorder, req) + + if responseRecorder.Code != http.StatusBadRequest { + t.Errorf("expected status code 400, got %d", responseRecorder.Code) + } + } + + entries, err = database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + + if len(entries) != 0 { + t.Errorf("expected 0 entries, got %d", len(entries)) + } +} diff --git a/api/guestbook_test.go b/api/guestbook_test.go deleted file mode 100644 index 5c1831f..0000000 --- a/api/guestbook_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package api_test - -import ( - "database/sql" - "net/http" - "net/http/httptest" - "os" - "testing" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -func setup() (*sql.DB, *api.RequestContext, func()) { - randomDb := utils.RandomId() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - - context := &api.RequestContext{ - DBConn: testDb, - Args: &args.Arguments{}, - TemplateData: &(map[string]interface{}{}), - } - - return testDb, context, func() { - testDb.Close() - os.Remove(randomDb) - } -} - -func TestValidGuestbookPutsInDatabase(t *testing.T) { - db, context, cleanup := setup() - defer cleanup() - - entries, err := database.GetGuestbookEntries(db) - if err != nil { - t.Fatal(err) - } - if len(entries) > 0 { - t.Errorf("expected no entries, got entries") - } - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation) - })) - defer ts.Close() - - req := httptest.NewRequest("POST", ts.URL, nil) - req.Form = map[string][]string{ - "name": {"test"}, - "message": {"test"}, - } - - w := httptest.NewRecorder() - ts.Config.Handler.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Errorf("expected status code 200, got %d", w.Code) - } - - entries, err = database.GetGuestbookEntries(db) - if err != nil { - t.Fatal(err) - } - - if len(entries) != 1 { - t.Errorf("expected 1 entry, got %d", len(entries)) - } - - if entries[0].Name != req.FormValue("name") { - t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name) - } -} - -func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) { - db, context, cleanup := setup() - defer cleanup() - - entries, err := database.GetGuestbookEntries(db) - if err != nil { - t.Fatal(err) - } - if len(entries) > 0 { - t.Errorf("expected no entries, got entries") - } - - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation) - })) - defer testServer.Close() - - reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n" - invalidRequests := []struct { - name string - message string - }{ - {"", "test"}, - {"test", ""}, - {"", ""}, - {"test", reallyLongStringThatWouldTakeTooMuchSpace}, - } - - for _, form := range invalidRequests { - req := httptest.NewRequest("POST", testServer.URL, nil) - req.Form = map[string][]string{ - "name": {form.name}, - "message": {form.message}, - } - - responseRecorder := httptest.NewRecorder() - testServer.Config.Handler.ServeHTTP(responseRecorder, req) - - if responseRecorder.Code != http.StatusBadRequest { - t.Errorf("expected status code 400, got %d", responseRecorder.Code) - } - } - - entries, err = database.GetGuestbookEntries(db) - if err != nil { - t.Fatal(err) - } - - if len(entries) != 0 { - t.Errorf("expected 0 entries, got %d", len(entries)) - } -} diff --git a/api/hcaptcha.go b/api/hcaptcha.go deleted file mode 100644 index a310c01..0000000 --- a/api/hcaptcha.go +++ /dev/null @@ -1,69 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "net/http" - "strings" -) - -func verifyCaptcha(secret, response string) error { - verifyURL := "https://hcaptcha.com/siteverify" - body := strings.NewReader("secret=" + secret + "&response=" + response) - - req, err := http.NewRequest("POST", verifyURL, body) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - - jsonResponse := struct { - Success bool `json:"success"` - }{} - err = json.NewDecoder(resp.Body).Decode(&jsonResponse) - if err != nil { - return err - } - - if !jsonResponse.Success { - return fmt.Errorf("hcaptcha verification failed") - } - - defer resp.Body.Close() - return nil -} - -func CaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - (*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{ - SiteKey: context.Args.HcaptchaSiteKey, - } - return success(context, req, resp) - } -} - -func CaptchaVerificationContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - hCaptchaResponse := req.FormValue("h-captcha-response") - secretKey := context.Args.HcaptchaSecret - - err := verifyCaptcha(secretKey, hCaptchaResponse) - if err != nil { - (*context.TemplateData)["FormError"] = FormError{ - Errors: []string{"hCaptcha verification failed"}, - } - resp.WriteHeader(http.StatusBadRequest) - - return failure(context, req, resp) - } - - return success(context, req, resp) - } -} diff --git a/api/hcaptcha/hcaptcha.go b/api/hcaptcha/hcaptcha.go new file mode 100644 index 0000000..007190d --- /dev/null +++ b/api/hcaptcha/hcaptcha.go @@ -0,0 +1,75 @@ +package hcaptcha + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" +) + +type HcaptchaArgs struct { + SiteKey string +} + +func verifyCaptcha(secret, response string) error { + verifyURL := "https://hcaptcha.com/siteverify" + body := strings.NewReader("secret=" + secret + "&response=" + response) + + req, err := http.NewRequest("POST", verifyURL, body) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + + jsonResponse := struct { + Success bool `json:"success"` + }{} + err = json.NewDecoder(resp.Body).Decode(&jsonResponse) + if err != nil { + return err + } + + if !jsonResponse.Success { + return fmt.Errorf("hcaptcha verification failed") + } + + defer resp.Body.Close() + return nil +} + +func CaptchaArgsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + (*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{ + SiteKey: context.Args.HcaptchaSiteKey, + } + return success(context, req, resp) + } +} + +func CaptchaVerificationContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + hCaptchaResponse := req.FormValue("h-captcha-response") + secretKey := context.Args.HcaptchaSecret + + err := verifyCaptcha(secretKey, hCaptchaResponse) + if err != nil { + (*context.TemplateData)["FormError"] = types.FormError{ + Errors: []string{"hCaptcha verification failed"}, + } + resp.WriteHeader(http.StatusBadRequest) + + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} diff --git a/api/keys/keys.go b/api/keys/keys.go new file mode 100644 index 0000000..ad380fc --- /dev/null +++ b/api/keys/keys.go @@ -0,0 +1,88 @@ +package keys + +import ( + "log" + "net/http" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +const MAX_USER_API_KEYS = 5 + +func ListAPIKeysContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + typesKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["APIKeys"] = typesKeys + return success(context, req, resp) + } +} + +func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + formErrors := types.FormError{ + Errors: []string{}, + } + + numKeys, err := database.CountUserAPIKeys(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if numKeys >= MAX_USER_API_KEYS { + formErrors.Errors = append(formErrors.Errors, "max types keys reached") + } + + if len(formErrors.Errors) > 0 { + (*context.TemplateData)["FormError"] = formErrors + return failure(context, req, resp) + } + + _, err = database.SaveAPIKey(context.DBConn, &database.UserApiKey{ + UserID: context.User.ID, + Key: utils.RandomId(), + }) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + return success(context, req, resp) + } +} + +func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + key := req.FormValue("key") + + typesKey, err := database.GetAPIKey(context.DBConn, key) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if (typesKey == nil) || (typesKey.UserID != context.User.ID) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + err = database.DeleteAPIKey(context.DBConn, key) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/keys", http.StatusFound) + return success(context, req, resp) + } +} diff --git a/api/serve.go b/api/serve.go index 1536f65..6d8c59c 100644 --- a/api/serve.go +++ b/api/serve.go @@ -8,31 +8,19 @@ import ( "time" "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/hcaptcha" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/keys" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/template" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) -type RequestContext struct { - DBConn *sql.DB - Args *args.Arguments - - Id string - Start time.Time - - TemplateData *map[string]interface{} - User *database.User -} - -type FormError struct { - Errors []string -} - -type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain -type ContinuationChain func(Continuation, Continuation) ContinuationChain - -func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, _failure Continuation) ContinuationChain { +func LogRequestContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { context.Start = time.Now() context.Id = utils.RandomId() @@ -41,8 +29,8 @@ func LogRequestContinuation(context *RequestContext, req *http.Request, resp htt } } -func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, _failure Continuation) ContinuationChain { +func LogExecutionTimeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { end := time.Now() log.Println(context.Id, "took", end.Sub(context.Start)) @@ -51,22 +39,22 @@ func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, re } } -func HealthCheckContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, _failure Continuation) ContinuationChain { +func HealthCheckContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { resp.WriteHeader(200) resp.Write([]byte("healthy")) return success(context, req, resp) } } -func FailurePassingContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(_success Continuation, failure Continuation) ContinuationChain { +func FailurePassingContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(_success types.Continuation, failure types.Continuation) types.ContinuationChain { return failure(context, req, resp) } } -func IdContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, _failure Continuation) ContinuationChain { +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { return success(context, req, resp) } } @@ -90,8 +78,8 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { ZoneId: argv.CloudflareZone, } - makeRequestContext := func() *RequestContext { - return &RequestContext{ + makeRequestContext := func() *types.RequestContext { + return &types.RequestContext{ DBConn: dbConn, Args: argv, TemplateData: &map[string]interface{}{}, @@ -100,7 +88,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { @@ -110,63 +98,63 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation(cloudflareAdapter), GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.ListAPIKeysContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(ListAPIKeysContinuation, ListAPIKeysContinuation)(TemplateContinuation("api_keys.html", true), TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.CreateAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(template.TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaVerificationContinuation, CaptchaVerificationContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaVerificationContinuation, hcaptcha.CaptchaVerificationContinuation)(guestbook.SignGuestbookContinuation, FailurePassingContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(template.TemplateContinuation("guestbook.html", true), template.TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() name := r.PathValue("name") - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) return &http.Server{ diff --git a/api/template.go b/api/template.go deleted file mode 100644 index d637c64..0000000 --- a/api/template.go +++ /dev/null @@ -1,74 +0,0 @@ -package api - -import ( - "bytes" - "errors" - "html/template" - "log" - "net/http" - "os" -) - -func renderTemplate(context *RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) { - templatePath := context.Args.TemplatePath - basePath := templatePath + "/base_empty.html" - if showBaseHtml { - basePath = templatePath + "/base.html" - } - - templateLocation := templatePath + "/" + templateName - tmpl, err := template.New("").ParseFiles(templateLocation, basePath) - if err != nil { - return bytes.Buffer{}, err - } - - dataPtr := context.TemplateData - if dataPtr == nil { - dataPtr = &map[string]interface{}{} - } - - data := *dataPtr - if data["User"] == nil { - data["User"] = context.User - } - - var buffer bytes.Buffer - err = tmpl.ExecuteTemplate(&buffer, "base", data) - - if err != nil { - return bytes.Buffer{}, err - } - return buffer, nil -} - -func TemplateContinuation(path string, showBase bool) Continuation { - return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - html, err := renderTemplate(context, path, true) - if errors.Is(err, os.ErrNotExist) { - resp.WriteHeader(404) - html, err = renderTemplate(context, "404.html", true) - if err != nil { - log.Println("error rendering 404 template", err) - resp.WriteHeader(500) - return failure(context, req, resp) - } - - resp.Header().Set("Content-Type", "text/html") - resp.Write(html.Bytes()) - return failure(context, req, resp) - } - - if err != nil { - log.Println("error rendering template", err) - resp.WriteHeader(500) - resp.Write([]byte("error rendering template")) - return failure(context, req, resp) - } - - resp.Header().Set("Content-Type", "text/html") - resp.Write(html.Bytes()) - return success(context, req, resp) - } - } -} diff --git a/api/template/template.go b/api/template/template.go new file mode 100644 index 0000000..2875649 --- /dev/null +++ b/api/template/template.go @@ -0,0 +1,76 @@ +package template + +import ( + "bytes" + "errors" + "html/template" + "log" + "net/http" + "os" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" +) + +func renderTemplate(context *types.RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) { + templatePath := context.Args.TemplatePath + basePath := templatePath + "/base_empty.html" + if showBaseHtml { + basePath = templatePath + "/base.html" + } + + templateLocation := templatePath + "/" + templateName + tmpl, err := template.New("").ParseFiles(templateLocation, basePath) + if err != nil { + return bytes.Buffer{}, err + } + + dataPtr := context.TemplateData + if dataPtr == nil { + dataPtr = &map[string]interface{}{} + } + + data := *dataPtr + if data["User"] == nil { + data["User"] = context.User + } + + var buffer bytes.Buffer + err = tmpl.ExecuteTemplate(&buffer, "base", data) + + if err != nil { + return bytes.Buffer{}, err + } + return buffer, nil +} + +func TemplateContinuation(path string, showBase bool) types.Continuation { + return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + html, err := renderTemplate(context, path, true) + if errors.Is(err, os.ErrNotExist) { + resp.WriteHeader(404) + html, err = renderTemplate(context, "404.html", true) + if err != nil { + log.Println("error rendering 404 template", err) + resp.WriteHeader(500) + return failure(context, req, resp) + } + + resp.Header().Set("Content-Type", "text/html") + resp.Write(html.Bytes()) + return failure(context, req, resp) + } + + if err != nil { + log.Println("error rendering template", err) + resp.WriteHeader(500) + resp.Write([]byte("error rendering template")) + return failure(context, req, resp) + } + + resp.Header().Set("Content-Type", "text/html") + resp.Write(html.Bytes()) + return success(context, req, resp) + } + } +} diff --git a/api/types/types.go b/api/types/types.go new file mode 100644 index 0000000..bbc25ea --- /dev/null +++ b/api/types/types.go @@ -0,0 +1,28 @@ +package types + +import ( + "database/sql" + "net/http" + "time" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" +) + +type RequestContext struct { + DBConn *sql.DB + Args *args.Arguments + + Id string + Start time.Time + + TemplateData *map[string]interface{} + User *database.User +} + +type FormError struct { + Errors []string +} + +type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain +type ContinuationChain func(Continuation, Continuation) ContinuationChain -- cgit v1.2.3-70-g09d2