summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorElizabeth Hunt <elizabeth@simponic.xyz>2024-04-06 13:36:13 -0600
committerElizabeth Hunt <elizabeth@simponic.xyz>2024-04-06 13:36:13 -0600
commit5177735b835289c8437799536d3654e5ab142fa3 (patch)
treeaa5a0c9588ca7b1b97ccd6c4f0cf245d2c3ac162
parentae640a253edb5935380975fb07430e910a83b340 (diff)
downloadhatecomputers.club-5177735b835289c8437799536d3654e5ab142fa3.tar.gz
hatecomputers.club-5177735b835289c8437799536d3654e5ab142fa3.zip
finish auth tests
-rw-r--r--api/auth/auth_test.go116
-rw-r--r--api/keys/keys.go9
-rw-r--r--api/serve.go2
-rw-r--r--database/users.go12
4 files changed, 88 insertions, 51 deletions
diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go
index 1e54099..a3d5b16 100644
--- a/api/auth/auth_test.go
+++ b/api/auth/auth_test.go
@@ -8,6 +8,7 @@ import (
"os"
"strings"
"testing"
+ "time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
@@ -23,24 +24,6 @@ func IdContinuation(context *types.RequestContext, req *http.Request, resp http.
}
}
-func setup() (*sql.DB, *types.RequestContext, func()) {
- randomDb := utils.RandomId()
-
- testDb := database.MakeConn(&randomDb)
- database.Migrate(testDb)
-
- context := &types.RequestContext{
- DBConn: testDb,
- Args: &args.Arguments{},
- TemplateData: &(map[string]interface{}{}),
- }
-
- return testDb, context, func() {
- testDb.Close()
- os.Remove(randomDb)
- }
-}
-
func FakedOauthServer() *httptest.Server {
oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/auth" {
@@ -89,7 +72,7 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
}
if r.URL.Path == "/me" {
- auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
+ auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation)
}
if r.URL.Path == "/logout" {
@@ -99,6 +82,30 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
return testServer
}
+func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) {
+ randomDb := utils.RandomId()
+
+ testDb := database.MakeConn(&randomDb)
+ database.Migrate(testDb)
+
+ context := &types.RequestContext{
+ DBConn: testDb,
+ Args: &args.Arguments{},
+ TemplateData: &(map[string]interface{}{}),
+ }
+
+ oauthServer := FakedOauthServer()
+ testServer := MockUserEndpointServer(context)
+
+ return testDb, context, oauthServer, testServer, func() {
+ oauthServer.Close()
+ testServer.Close()
+
+ testDb.Close()
+ os.Remove(randomDb)
+ }
+}
+
func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
return &oauth2.Config{
ClientID: "test",
@@ -146,14 +153,9 @@ func FollowAuthentication(
}
func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
- db, context, cleanup := setup()
+ db, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
- oauthServer := FakedOauthServer()
- testServer := MockUserEndpointServer(context)
- defer oauthServer.Close()
- defer testServer.Close()
-
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
user, _ := database.GetUser(db, "test")
@@ -174,14 +176,9 @@ func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
}
func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
- _, context, cleanup := setup()
+ _, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
- oauthServer := FakedOauthServer()
- testServer := MockUserEndpointServer(context)
- defer oauthServer.Close()
- defer testServer.Close()
-
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
req := httptest.NewRequest("GET", "/protected-path", nil)
@@ -201,14 +198,9 @@ func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
}
func TestOauthSetsUniqueSession(t *testing.T) {
- db, context, cleanup := setup()
+ db, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
- oauthServer := FakedOauthServer()
- testServer := MockUserEndpointServer(context)
- defer oauthServer.Close()
- defer testServer.Close()
-
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
cookies := make(map[string]*http.Cookie)
@@ -230,14 +222,9 @@ func TestOauthSetsUniqueSession(t *testing.T) {
}
func TestLogoutClearsSession(t *testing.T) {
- db, context, cleanup := setup()
+ db, context, oauthServer, testServer, cleanup := setup()
defer cleanup()
- oauthServer := FakedOauthServer()
- testServer := MockUserEndpointServer(context)
- defer oauthServer.Close()
- defer testServer.Close()
-
context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
cookies := make(map[string]*http.Cookie)
@@ -270,13 +257,52 @@ func TestLogoutClearsSession(t *testing.T) {
}
func TestRefreshUpdatesExpiration(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
+
+ session, _ := database.GetSession(db, cookies["session"].Value)
+
+ req := httptest.NewRequest("GET", "/me", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ updatedSession, _ := database.GetSession(db, cookies["session"].Value)
+
+ // if session expiration is greater than or equal to updated session expiration
+ if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) {
+ t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt)
+ }
}
func TestVerifySessionEnsuresNonExpired(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
-}
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
-func TestAPITokensAreEquivalentToSessions(t *testing.T) {
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
+ session, _ := database.GetSession(db, cookies["session"].Value)
+ session.ExpireAt = time.Now().Add(-time.Hour)
+ database.SaveSession(db, session)
+
+ req := httptest.NewRequest("GET", "/me", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
+ t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location"))
+ }
}
diff --git a/api/keys/keys.go b/api/keys/keys.go
index ad380fc..cef3f3c 100644
--- a/api/keys/keys.go
+++ b/api/keys/keys.go
@@ -62,27 +62,26 @@ func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request,
func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
- key := req.FormValue("key")
+ apiKey := req.FormValue("key")
- typesKey, err := database.GetAPIKey(context.DBConn, key)
+ key, err := database.GetAPIKey(context.DBConn, apiKey)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
- if (typesKey == nil) || (typesKey.UserID != context.User.ID) {
+ if (key == nil) || (key.UserID != context.User.ID) {
resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
- err = database.DeleteAPIKey(context.DBConn, key)
+ err = database.DeleteAPIKey(context.DBConn, apiKey)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
- http.Redirect(resp, req, "/keys", http.StatusFound)
return success(context, req, resp)
}
}
diff --git a/api/serve.go b/api/serve.go
index 2b0eba4..c8775d8 100644
--- a/api/serve.go
+++ b/api/serve.go
@@ -140,7 +140,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
diff --git a/database/users.go b/database/users.go
index 5cebb8f..6f9456e 100644
--- a/database/users.go
+++ b/database/users.go
@@ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error {
return nil
}
+func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) {
+ log.Println("saving session", session.ID)
+
+ _, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+
+ return session, nil
+}
+
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
newExpireAt := time.Now().Add(ExpiryDuration)