diff options
| author | simponic <simponic@hatecomputers.club> | 2024-04-06 15:43:18 -0400 |
|---|---|---|
| committer | simponic <simponic@hatecomputers.club> | 2024-04-06 15:43:18 -0400 |
| commit | 83cc6267fd5ce2f61200314424c5f400f65ff2ba (patch) | |
| tree | eafb35310236a15572cbb6e16ff8d6f181bfe240 /api/auth/auth_test.go | |
| parent | 569d2788ebfb90774faf361f62bfe7968e091465 (diff) | |
| parent | cad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 (diff) | |
| download | hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.tar.gz hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.zip | |
Merge pull request 'testing | dont be recursive for external domains | finalize oauth' (#5) from dont-be-authoritative into main
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/5
Diffstat (limited to 'api/auth/auth_test.go')
| -rw-r--r-- | api/auth/auth_test.go | 307 |
1 files changed, 307 insertions, 0 deletions
diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go new file mode 100644 index 0000000..5e67c6d --- /dev/null +++ b/api/auth/auth_test.go @@ -0,0 +1,307 @@ +package auth_test + +import ( + "database/sql" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" +) + +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { + return success(context, req, resp) + } +} + +func FakedOauthServer() *httptest.Server { + oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/auth" { + code := utils.RandomId() + + state := r.URL.Query().Get("state") + redirectPath := r.URL.Query().Get("redirect_uri") + redirectPath += "?code=" + code + "&state=" + state + + http.Redirect(w, r, redirectPath, http.StatusFound) + } + if r.URL.Path == "/token" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`)) + } + if r.URL.Path == "/user" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`)) + } + })) + + return oauthServer +} + +func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + resp.Write([]byte(context.User.Username)) + return success(context, req, resp) + } +} + +func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/protected-path" { + auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/login" { + log.Println("login") + auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/callback" { + log.Println("callback") + auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/me" { + auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/logout" { + auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation) + } + })) + return testServer +} + +func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + + return testDb, context, oauthServer, testServer, func() { + oauthServer.Close() + testServer.Close() + + testDb.Close() + os.Remove(randomDb) + } +} + +func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { + return &oauth2.Config{ + ClientID: "test", + ClientSecret: "test", + Scopes: []string{"test"}, + Endpoint: oauth2.Endpoint{ + AuthURL: oauthServerURL + "/auth", + TokenURL: oauthServerURL + "/token", + }, + RedirectURL: testServerURL + "/callback", + }, oauthServerURL + "/user" +} + +func FollowAuthentication( + oauthServer *httptest.Server, + testServer *httptest.Server, + cookies map[string]*http.Cookie, + location string, +) (map[string]*http.Cookie, string) { + resp := httptest.NewRecorder() + resp.Code = 0 + + for resp.Code == 0 || resp.Code == http.StatusFound { + req := httptest.NewRequest("GET", location, nil) + resp = httptest.NewRecorder() + + for _, cookie := range cookies { + req.AddCookie(cookie) + } + if strings.HasPrefix(location, oauthServer.URL) { + oauthServer.Config.Handler.ServeHTTP(resp, req) + } else { + testServer.Config.Handler.ServeHTTP(resp, req) + } + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + if resp.Code == http.StatusFound { + location = resp.Header().Get("Location") + } + } + + return cookies, location +} + +func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + user, _ := database.GetUser(db, "test") + if user != nil { + t.Errorf("expected no user, got user") + } + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + user, _ = database.GetUser(db, "test") + if user == nil { + t.Errorf("expected a user to be created, could not find user") + } + if user.Username != "test" { + t.Errorf("expected username to be test, got %s", user.Username) + } +} + +func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { + _, context, oauthServer, testServer, cleanup := setup() + defer cleanup() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + req := httptest.NewRequest("GET", "/protected-path", nil) + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + location := resp.Header().Get("Location") + if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") { + t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location")) + } + + cookies := make(map[string]*http.Cookie) + cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page") + + if !(strings.HasSuffix(location, "/protected-page")) { + t.Errorf("expected to redirect back to /protected-page after login, got %s", location) + } +} + +func TestOauthSetsUniqueSession(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + cookiesAgain := make(map[string]*http.Cookie) + cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me") + + sessionOne := cookies["session"].Value + sessionTwo := cookiesAgain["session"].Value + if sessionOne == sessionTwo { + t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo) + } + + session, _ := database.GetSession(db, sessionOne) + if session.UserID != "test" { + t.Errorf("expected session to be associated with user test, got %s", session.UserID) + } +} + +func TestLogoutClearsSession(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + req := httptest.NewRequest("GET", "/logout", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + req = httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { + t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location")) + } + + session, _ := database.GetSession(db, cookies["session"].Value) + if session != nil { + t.Errorf("expected session to be deleted, got session") + } +} + +func TestRefreshUpdatesExpiration(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") + + session, _ := database.GetSession(db, cookies["session"].Value) + + req := httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + + updatedSession, _ := database.GetSession(db, cookies["session"].Value) + + if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { + t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) + } +} + +func TestVerifySessionEnsuresNonExpired(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") + + session, _ := database.GetSession(db, cookies["session"].Value) + session.ExpireAt = time.Now().Add(-time.Hour) + database.SaveSession(db, session) + + req := httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + + if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { + t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location")) + } +} |
