diff options
| author | Elizabeth Hunt <elizabeth@simponic.xyz> | 2024-04-06 13:36:13 -0600 |
|---|---|---|
| committer | Elizabeth Hunt <elizabeth@simponic.xyz> | 2024-04-06 13:36:13 -0600 |
| commit | 5177735b835289c8437799536d3654e5ab142fa3 (patch) | |
| tree | aa5a0c9588ca7b1b97ccd6c4f0cf245d2c3ac162 /api | |
| parent | ae640a253edb5935380975fb07430e910a83b340 (diff) | |
| download | hatecomputers.club-5177735b835289c8437799536d3654e5ab142fa3.tar.gz hatecomputers.club-5177735b835289c8437799536d3654e5ab142fa3.zip | |
finish auth tests
Diffstat (limited to 'api')
| -rw-r--r-- | api/auth/auth_test.go | 116 | ||||
| -rw-r--r-- | api/keys/keys.go | 9 | ||||
| -rw-r--r-- | api/serve.go | 2 |
3 files changed, 76 insertions, 51 deletions
diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index 1e54099..a3d5b16 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -8,6 +8,7 @@ import ( "os" "strings" "testing" + "time" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" @@ -23,24 +24,6 @@ func IdContinuation(context *types.RequestContext, req *http.Request, resp http. } } -func setup() (*sql.DB, *types.RequestContext, func()) { - randomDb := utils.RandomId() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - - context := &types.RequestContext{ - DBConn: testDb, - Args: &args.Arguments{}, - TemplateData: &(map[string]interface{}{}), - } - - return testDb, context, func() { - testDb.Close() - os.Remove(randomDb) - } -} - func FakedOauthServer() *httptest.Server { oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/auth" { @@ -89,7 +72,7 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { } if r.URL.Path == "/me" { - auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation) } if r.URL.Path == "/logout" { @@ -99,6 +82,30 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { return testServer } +func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + + return testDb, context, oauthServer, testServer, func() { + oauthServer.Close() + testServer.Close() + + testDb.Close() + os.Remove(randomDb) + } +} + func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { return &oauth2.Config{ ClientID: "test", @@ -146,14 +153,9 @@ func FollowAuthentication( } func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { - db, context, cleanup := setup() + db, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) user, _ := database.GetUser(db, "test") @@ -174,14 +176,9 @@ func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { } func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { - _, context, cleanup := setup() + _, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) req := httptest.NewRequest("GET", "/protected-path", nil) @@ -201,14 +198,9 @@ func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { } func TestOauthSetsUniqueSession(t *testing.T) { - db, context, cleanup := setup() + db, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) @@ -230,14 +222,9 @@ func TestOauthSetsUniqueSession(t *testing.T) { } func TestLogoutClearsSession(t *testing.T) { - db, context, cleanup := setup() + db, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) @@ -270,13 +257,52 @@ func TestLogoutClearsSession(t *testing.T) { } func TestRefreshUpdatesExpiration(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") + + session, _ := database.GetSession(db, cookies["session"].Value) + + req := httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + updatedSession, _ := database.GetSession(db, cookies["session"].Value) + + // if session expiration is greater than or equal to updated session expiration + if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { + t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) + } } func TestVerifySessionEnsuresNonExpired(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() -} + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) -func TestAPITokensAreEquivalentToSessions(t *testing.T) { + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") + session, _ := database.GetSession(db, cookies["session"].Value) + session.ExpireAt = time.Now().Add(-time.Hour) + database.SaveSession(db, session) + + req := httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + + if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { + t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location")) + } } diff --git a/api/keys/keys.go b/api/keys/keys.go index ad380fc..cef3f3c 100644 --- a/api/keys/keys.go +++ b/api/keys/keys.go @@ -62,27 +62,26 @@ func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { - key := req.FormValue("key") + apiKey := req.FormValue("key") - typesKey, err := database.GetAPIKey(context.DBConn, key) + key, err := database.GetAPIKey(context.DBConn, apiKey) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - if (typesKey == nil) || (typesKey.UserID != context.User.ID) { + if (key == nil) || (key.UserID != context.User.ID) { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } - err = database.DeleteAPIKey(context.DBConn, key) + err = database.DeleteAPIKey(context.DBConn, apiKey) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - http.Redirect(resp, req, "/keys", http.StatusFound) return success(context, req, resp) } } diff --git a/api/serve.go b/api/serve.go index 2b0eba4..c8775d8 100644 --- a/api/serve.go +++ b/api/serve.go @@ -140,7 +140,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { |
