From f38e8719c2a8537fe9b64ed8ceca45858a58e498 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Wed, 3 Apr 2024 17:53:50 -0600 Subject: make it compile --- api/auth/auth.go | 288 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 288 insertions(+) create mode 100644 api/auth/auth.go (limited to 'api/auth/auth.go') diff --git a/api/auth/auth.go b/api/auth/auth.go new file mode 100644 index 0000000..dc348b2 --- /dev/null +++ b/api/auth/auth.go @@ -0,0 +1,288 @@ +package auth + +import ( + "crypto/sha256" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "time" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" +) + +func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + verifier := utils.RandomId() + utils.RandomId() + + sha2 := sha256.New() + io.WriteString(sha2, verifier) + codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) + + state := utils.RandomId() + url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge)) + + http.SetCookie(resp, &http.Cookie{ + Name: "verifier", + Value: verifier, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + http.SetCookie(resp, &http.Cookie{ + Name: "state", + Value: state, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + + http.Redirect(resp, req, url, http.StatusFound) + return success(context, req, resp) + } +} + +func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + state := req.URL.Query().Get("state") + code := req.URL.Query().Get("code") + + if code == "" || state == "" { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + if !verifyState(req, "state", state) { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + verifierCookie, err := req.Cookie("verifier") + if err != nil { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + reqContext := req.Context() + token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + client := context.Args.OauthConfig.Client(reqContext, token) + user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + + return failure(context, req, resp) + } + + session, err := database.MakeUserSessionFor(context.DBConn, user) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.SetCookie(resp, &http.Cookie{ + Name: "session", + Value: session.ID, + Path: "/", + SameSite: http.SameSiteLaxMode, + Secure: true, + }) + + redirect := "/" + redirectCookie, err := req.Cookie("redirect") + if err == nil && redirectCookie.Value != "" { + redirect = redirectCookie.Value + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + MaxAge: 0, + }) + } + + http.Redirect(resp, req, redirect, http.StatusFound) + return success(context, req, resp) + } +} + +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + typesKey, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if typesKey == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, typesKey.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + authHeader := req.Header.Get("Authorization") + user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) + + sessionCookie, err := req.Cookie("session") + if err == nil && sessionCookie.Value != "" { + user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) + } + + if userErr != nil || user == nil { + log.Println(userErr, user) + + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, // reset session cookie in case + }) + + context.User = nil + return failure(context, req, resp) + } + + context.User = user + return success(context, req, resp) + } +} + +func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + Value: req.URL.Path, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + http.Redirect(resp, req, "/login", http.StatusFound) + return failure(context, req, resp) + } +} + +func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} + +func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err == nil && sessionCookie.Value != "" { + _ = database.DeleteSession(context.DBConn, sessionCookie.Value) + } + + http.Redirect(resp, req, "/", http.StatusFound) + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, + }) + return success(context, req, resp) + } +} + +func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { + userResponse, err := client.Get(uri) + if err != nil { + return nil, err + } + + userStruct, err := createUserFromResponse(userResponse) + if err != nil { + return nil, err + } + + user, err := database.FindOrSaveUser(dbConn, userStruct) + if err != nil { + return nil, err + } + + return user, nil +} + +func createUserFromResponse(response *http.Response) (*database.User, error) { + user := &database.User{ + CreatedAt: time.Now(), + } + + err := json.NewDecoder(response.Body).Decode(user) + defer response.Body.Close() + + if err != nil { + log.Println(err) + return nil, err + } + + user.Username = strings.ToLower(user.Username) + user.Username = strings.Split(user.Username, "@")[0] + + return user, nil +} + +func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { + cookie, err := req.Cookie(stateCookieName) + if err != nil || cookie.Value != expectedState { + return false + } + + return true +} -- cgit v1.2.3-70-g09d2 From 94984aa4b01e96773b71325b5b27e6f64d9bd102 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Thu, 4 Apr 2024 16:03:34 -0600 Subject: auth test scaffolding --- api/auth/auth.go | 115 +++++++++++++++++++++++++++----------------------- api/auth/auth_test.go | 74 +++++++++++++++++++++++++++++--- 2 files changed, 131 insertions(+), 58 deletions(-) (limited to 'api/auth/auth.go') diff --git a/api/auth/auth.go b/api/auth/auth.go index dc348b2..3c633cd 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -35,7 +35,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, - MaxAge: 60, + MaxAge: 200, }) http.SetCookie(resp, &http.Cookie{ Name: "state", @@ -43,7 +43,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, - MaxAge: 60, + MaxAge: 200, }) http.Redirect(resp, req, url, http.StatusFound) @@ -102,6 +102,16 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req SameSite: http.SameSiteLaxMode, Secure: true, }) + http.SetCookie(resp, &http.Cookie{ + Name: "verifier", + Value: "", + MaxAge: 0, + }) + http.SetCookie(resp, &http.Cookie{ + Name: "state", + Value: "", + MaxAge: 0, + }) redirect := "/" redirectCookie, err := req.Cookie("redirect") @@ -110,6 +120,7 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req http.SetCookie(resp, &http.Cookie{ Name: "redirect", MaxAge: 0, + Value: "", }) } @@ -118,52 +129,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req } } -func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { - if bearerToken == "" { - return nil, nil - } - - parts := strings.Split(bearerToken, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - return nil, nil - } - - typesKey, err := database.GetAPIKey(dbConn, parts[1]) - if err != nil { - return nil, err - } - if typesKey == nil { - return nil, nil - } - - user, err := database.GetUser(dbConn, typesKey.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - -func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { - session, err := database.GetSession(dbConn, sessionId) - if err != nil { - return nil, err - } - - if session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(dbConn, sessionId) - return nil, fmt.Errorf("session expired") - } - - user, err := database.GetUser(dbConn, session.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { authHeader := req.Header.Get("Authorization") @@ -179,6 +144,7 @@ func VerifySessionContinuation(context *types.RequestContext, req *http.Request, http.SetCookie(resp, &http.Cookie{ Name: "session", + Value: "", MaxAge: 0, // reset session cookie in case }) @@ -210,13 +176,11 @@ func RefreshSessionContinuation(context *types.RequestContext, req *http.Request return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { sessionCookie, err := req.Cookie("session") if err != nil { - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) if err != nil { - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } @@ -235,6 +199,7 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, + Value: "", }) return success(context, req, resp) } @@ -246,7 +211,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us return nil, err } - userStruct, err := createUserFromResponse(userResponse) + userStruct, err := createUserFromOauthResponse(userResponse) if err != nil { return nil, err } @@ -259,7 +224,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us return user, nil } -func createUserFromResponse(response *http.Response) (*database.User, error) { +func createUserFromOauthResponse(response *http.Response) (*database.User, error) { user := &database.User{ CreatedAt: time.Now(), } @@ -286,3 +251,49 @@ func verifyState(req *http.Request, stateCookieName string, expectedState string return true } + +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + key, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if key == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, key.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index a6c2a45..caaedf1 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -2,14 +2,24 @@ package auth_test import ( "database/sql" + "net/http" + "net/http/httptest" "os" + "testing" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { + return success(context, req, resp) + } +} + func setup() (*sql.DB, *types.RequestContext, func()) { randomDb := utils.RandomId() @@ -28,9 +38,61 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } -/* -todo: test types key creation -+ api key attached to user -+ user session is unique -+ goLogin goes to page in cookie -*/ +func TestLoginSendsYouToRedirect(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + user := &database.User{ + ID: "test", + Username: "test", + } + database.FindOrSaveUser(db, user) + + session, _ := database.MakeUserSessionFor(db, user) + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + protectedPath := testServer.URL + "/protected-path" + req := httptest.NewRequest("GET", protectedPath, nil) + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + + location := resp.Header().Get("Location") + if resp.Code != http.StatusFound && location != "/login" { + t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) + } + + req.AddCookie(&http.Cookie{ + Name: "session", + Value: session.ID, + MaxAge: 60, + }) + resp = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { +} + +func TestOauthFormatsUsername(t *testing.T) { + +} + +func TestSessionIsUnique(t *testing.T) {} + +func TestLogoutClearsCookie(t *testing.T) { + +} + +func TestRefreshUpdatesExpiration(t *testing.T) { + +} + +func TestVerifySessionEnsuresNonExpired(t *testing.T) { + +} + +func TestAPITokensAreEquivalentToSessions(t *testing.T) { + +} -- cgit v1.2.3-70-g09d2 From ae640a253edb5935380975fb07430e910a83b340 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Fri, 5 Apr 2024 15:43:03 -0600 Subject: add some auth test cases --- api/auth/auth.go | 9 +- api/auth/auth_test.go | 234 ++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 212 insertions(+), 31 deletions(-) (limited to 'api/auth/auth.go') diff --git a/api/auth/auth.go b/api/auth/auth.go index 3c633cd..becce24 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -74,7 +74,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req reqContext := req.Context() token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) if err != nil { - log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } @@ -195,12 +194,13 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h _ = database.DeleteSession(context.DBConn, sessionCookie.Value) } - http.Redirect(resp, req, "/", http.StatusFound) http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, Value: "", }) + http.Redirect(resp, req, "/", http.StatusFound) + return success(context, req, resp) } } @@ -225,10 +225,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us } func createUserFromOauthResponse(response *http.Response) (*database.User, error) { - user := &database.User{ - CreatedAt: time.Now(), - } - + user := &database.User{} err := json.NewDecoder(response.Body).Decode(user) defer response.Body.Close() diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index caaedf1..1e54099 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -2,9 +2,11 @@ package auth_test import ( "database/sql" + "log" "net/http" "net/http/httptest" "os" + "strings" "testing" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" @@ -12,6 +14,7 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" ) func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { @@ -38,51 +41,232 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } -func TestLoginSendsYouToRedirect(t *testing.T) { +func FakedOauthServer() *httptest.Server { + oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/auth" { + code := utils.RandomId() + + state := r.URL.Query().Get("state") + redirectPath := r.URL.Query().Get("redirect_uri") + redirectPath += "?code=" + code + "&state=" + state + + http.Redirect(w, r, redirectPath, http.StatusFound) + } + if r.URL.Path == "/token" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`)) + } + if r.URL.Path == "/user" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`)) + } + })) + + return oauthServer +} + +func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + resp.Write([]byte(context.User.Username)) + return success(context, req, resp) + } +} + +func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/protected-path" { + auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/login" { + log.Println("login") + auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/callback" { + log.Println("callback") + auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/me" { + auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/logout" { + auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation) + } + })) + return testServer +} + +func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { + return &oauth2.Config{ + ClientID: "test", + ClientSecret: "test", + Scopes: []string{"test"}, + Endpoint: oauth2.Endpoint{ + AuthURL: oauthServerURL + "/auth", + TokenURL: oauthServerURL + "/token", + }, + RedirectURL: testServerURL + "/callback", + }, oauthServerURL + "/user" +} + +func FollowAuthentication( + oauthServer *httptest.Server, + testServer *httptest.Server, + cookies map[string]*http.Cookie, + location string, +) (map[string]*http.Cookie, string) { + resp := httptest.NewRecorder() + resp.Code = 0 + + for resp.Code == 0 || resp.Code == http.StatusFound { + req := httptest.NewRequest("GET", location, nil) + resp = httptest.NewRecorder() + + for _, cookie := range cookies { + req.AddCookie(cookie) + } + if strings.HasPrefix(location, oauthServer.URL) { + oauthServer.Config.Handler.ServeHTTP(resp, req) + } else { + testServer.Config.Handler.ServeHTTP(resp, req) + } + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + if resp.Code == http.StatusFound { + location = resp.Header().Get("Location") + } + } + + return cookies, location +} + +func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { db, context, cleanup := setup() defer cleanup() - user := &database.User{ - ID: "test", - Username: "test", + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + user, _ := database.GetUser(db, "test") + if user != nil { + t.Errorf("expected no user, got user") } - database.FindOrSaveUser(db, user) - session, _ := database.MakeUserSessionFor(db, user) + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) - })) + user, _ = database.GetUser(db, "test") + if user == nil { + t.Errorf("expected a user to be created, could not find user") + } + if user.Username != "test" { + t.Errorf("expected username to be test, got %s", user.Username) + } +} + +func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { + _, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() defer testServer.Close() - protectedPath := testServer.URL + "/protected-path" - req := httptest.NewRequest("GET", protectedPath, nil) + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + req := httptest.NewRequest("GET", "/protected-path", nil) resp := httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(resp, req) - location := resp.Header().Get("Location") - if resp.Code != http.StatusFound && location != "/login" { - t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) + if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") { + t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location")) } - req.AddCookie(&http.Cookie{ - Name: "session", - Value: session.ID, - MaxAge: 60, - }) - resp = httptest.NewRecorder() - testServer.Config.Handler.ServeHTTP(resp, req) - if resp.Code != http.StatusOK { + cookies := make(map[string]*http.Cookie) + cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page") + + if !(strings.HasSuffix(location, "/protected-page")) { + t.Errorf("expected to redirect back to /protected-page after login, got %s", location) + } } -func TestOauthFormatsUsername(t *testing.T) { +func TestOauthSetsUniqueSession(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + cookiesAgain := make(map[string]*http.Cookie) + cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me") + + sessionOne := cookies["session"].Value + sessionTwo := cookiesAgain["session"].Value + if sessionOne == sessionTwo { + t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo) + } + session, _ := database.GetSession(db, sessionOne) + if session.UserID != "test" { + t.Errorf("expected session to be associated with user test, got %s", session.UserID) + } } -func TestSessionIsUnique(t *testing.T) {} +func TestLogoutClearsSession(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() -func TestLogoutClearsCookie(t *testing.T) { + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + req := httptest.NewRequest("GET", "/logout", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + req = httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { + t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location")) + } + + session, _ := database.GetSession(db, cookies["session"].Value) + if session != nil { + t.Errorf("expected session to be deleted, got session") + } } func TestRefreshUpdatesExpiration(t *testing.T) { -- cgit v1.2.3-70-g09d2 From cad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Sat, 6 Apr 2024 13:40:46 -0600 Subject: nits --- api/auth/auth.go | 2 +- api/auth/auth_test.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) (limited to 'api/auth/auth.go') diff --git a/api/auth/auth.go b/api/auth/auth.go index becce24..0ffbf9c 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -144,7 +144,7 @@ func VerifySessionContinuation(context *types.RequestContext, req *http.Request, http.SetCookie(resp, &http.Cookie{ Name: "session", Value: "", - MaxAge: 0, // reset session cookie in case + MaxAge: 0, }) context.User = nil diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index a3d5b16..5e67c6d 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -276,7 +276,6 @@ func TestRefreshUpdatesExpiration(t *testing.T) { updatedSession, _ := database.GetSession(db, cookies["session"].Value) - // if session expiration is greater than or equal to updated session expiration if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) } -- cgit v1.2.3-70-g09d2