From f38e8719c2a8537fe9b64ed8ceca45858a58e498 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Wed, 3 Apr 2024 17:53:50 -0600 Subject: make it compile --- api/auth/auth_test.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 api/auth/auth_test.go (limited to 'api/auth/auth_test.go') diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go new file mode 100644 index 0000000..a6c2a45 --- /dev/null +++ b/api/auth/auth_test.go @@ -0,0 +1,36 @@ +package auth_test + +import ( + "database/sql" + "os" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func setup() (*sql.DB, *types.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +/* +todo: test types key creation ++ api key attached to user ++ user session is unique ++ goLogin goes to page in cookie +*/ -- cgit v1.2.3-70-g09d2 From 94984aa4b01e96773b71325b5b27e6f64d9bd102 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Thu, 4 Apr 2024 16:03:34 -0600 Subject: auth test scaffolding --- api/auth/auth.go | 115 +++++++++++++++++++++++++++----------------------- api/auth/auth_test.go | 74 +++++++++++++++++++++++++++++--- 2 files changed, 131 insertions(+), 58 deletions(-) (limited to 'api/auth/auth_test.go') diff --git a/api/auth/auth.go b/api/auth/auth.go index dc348b2..3c633cd 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -35,7 +35,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, - MaxAge: 60, + MaxAge: 200, }) http.SetCookie(resp, &http.Cookie{ Name: "state", @@ -43,7 +43,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, - MaxAge: 60, + MaxAge: 200, }) http.Redirect(resp, req, url, http.StatusFound) @@ -102,6 +102,16 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req SameSite: http.SameSiteLaxMode, Secure: true, }) + http.SetCookie(resp, &http.Cookie{ + Name: "verifier", + Value: "", + MaxAge: 0, + }) + http.SetCookie(resp, &http.Cookie{ + Name: "state", + Value: "", + MaxAge: 0, + }) redirect := "/" redirectCookie, err := req.Cookie("redirect") @@ -110,6 +120,7 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req http.SetCookie(resp, &http.Cookie{ Name: "redirect", MaxAge: 0, + Value: "", }) } @@ -118,52 +129,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req } } -func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { - if bearerToken == "" { - return nil, nil - } - - parts := strings.Split(bearerToken, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - return nil, nil - } - - typesKey, err := database.GetAPIKey(dbConn, parts[1]) - if err != nil { - return nil, err - } - if typesKey == nil { - return nil, nil - } - - user, err := database.GetUser(dbConn, typesKey.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - -func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { - session, err := database.GetSession(dbConn, sessionId) - if err != nil { - return nil, err - } - - if session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(dbConn, sessionId) - return nil, fmt.Errorf("session expired") - } - - user, err := database.GetUser(dbConn, session.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { authHeader := req.Header.Get("Authorization") @@ -179,6 +144,7 @@ func VerifySessionContinuation(context *types.RequestContext, req *http.Request, http.SetCookie(resp, &http.Cookie{ Name: "session", + Value: "", MaxAge: 0, // reset session cookie in case }) @@ -210,13 +176,11 @@ func RefreshSessionContinuation(context *types.RequestContext, req *http.Request return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { sessionCookie, err := req.Cookie("session") if err != nil { - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) if err != nil { - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } @@ -235,6 +199,7 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, + Value: "", }) return success(context, req, resp) } @@ -246,7 +211,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us return nil, err } - userStruct, err := createUserFromResponse(userResponse) + userStruct, err := createUserFromOauthResponse(userResponse) if err != nil { return nil, err } @@ -259,7 +224,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us return user, nil } -func createUserFromResponse(response *http.Response) (*database.User, error) { +func createUserFromOauthResponse(response *http.Response) (*database.User, error) { user := &database.User{ CreatedAt: time.Now(), } @@ -286,3 +251,49 @@ func verifyState(req *http.Request, stateCookieName string, expectedState string return true } + +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + key, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if key == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, key.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index a6c2a45..caaedf1 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -2,14 +2,24 @@ package auth_test import ( "database/sql" + "net/http" + "net/http/httptest" "os" + "testing" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { + return success(context, req, resp) + } +} + func setup() (*sql.DB, *types.RequestContext, func()) { randomDb := utils.RandomId() @@ -28,9 +38,61 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } -/* -todo: test types key creation -+ api key attached to user -+ user session is unique -+ goLogin goes to page in cookie -*/ +func TestLoginSendsYouToRedirect(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + user := &database.User{ + ID: "test", + Username: "test", + } + database.FindOrSaveUser(db, user) + + session, _ := database.MakeUserSessionFor(db, user) + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + protectedPath := testServer.URL + "/protected-path" + req := httptest.NewRequest("GET", protectedPath, nil) + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + + location := resp.Header().Get("Location") + if resp.Code != http.StatusFound && location != "/login" { + t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) + } + + req.AddCookie(&http.Cookie{ + Name: "session", + Value: session.ID, + MaxAge: 60, + }) + resp = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { +} + +func TestOauthFormatsUsername(t *testing.T) { + +} + +func TestSessionIsUnique(t *testing.T) {} + +func TestLogoutClearsCookie(t *testing.T) { + +} + +func TestRefreshUpdatesExpiration(t *testing.T) { + +} + +func TestVerifySessionEnsuresNonExpired(t *testing.T) { + +} + +func TestAPITokensAreEquivalentToSessions(t *testing.T) { + +} -- cgit v1.2.3-70-g09d2 From ae640a253edb5935380975fb07430e910a83b340 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Fri, 5 Apr 2024 15:43:03 -0600 Subject: add some auth test cases --- api/auth/auth.go | 9 +- api/auth/auth_test.go | 234 ++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 212 insertions(+), 31 deletions(-) (limited to 'api/auth/auth_test.go') diff --git a/api/auth/auth.go b/api/auth/auth.go index 3c633cd..becce24 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -74,7 +74,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req reqContext := req.Context() token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) if err != nil { - log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } @@ -195,12 +194,13 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h _ = database.DeleteSession(context.DBConn, sessionCookie.Value) } - http.Redirect(resp, req, "/", http.StatusFound) http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, Value: "", }) + http.Redirect(resp, req, "/", http.StatusFound) + return success(context, req, resp) } } @@ -225,10 +225,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us } func createUserFromOauthResponse(response *http.Response) (*database.User, error) { - user := &database.User{ - CreatedAt: time.Now(), - } - + user := &database.User{} err := json.NewDecoder(response.Body).Decode(user) defer response.Body.Close() diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index caaedf1..1e54099 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -2,9 +2,11 @@ package auth_test import ( "database/sql" + "log" "net/http" "net/http/httptest" "os" + "strings" "testing" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" @@ -12,6 +14,7 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" ) func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { @@ -38,51 +41,232 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } -func TestLoginSendsYouToRedirect(t *testing.T) { +func FakedOauthServer() *httptest.Server { + oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/auth" { + code := utils.RandomId() + + state := r.URL.Query().Get("state") + redirectPath := r.URL.Query().Get("redirect_uri") + redirectPath += "?code=" + code + "&state=" + state + + http.Redirect(w, r, redirectPath, http.StatusFound) + } + if r.URL.Path == "/token" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`)) + } + if r.URL.Path == "/user" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`)) + } + })) + + return oauthServer +} + +func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + resp.Write([]byte(context.User.Username)) + return success(context, req, resp) + } +} + +func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/protected-path" { + auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/login" { + log.Println("login") + auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/callback" { + log.Println("callback") + auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/me" { + auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/logout" { + auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation) + } + })) + return testServer +} + +func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { + return &oauth2.Config{ + ClientID: "test", + ClientSecret: "test", + Scopes: []string{"test"}, + Endpoint: oauth2.Endpoint{ + AuthURL: oauthServerURL + "/auth", + TokenURL: oauthServerURL + "/token", + }, + RedirectURL: testServerURL + "/callback", + }, oauthServerURL + "/user" +} + +func FollowAuthentication( + oauthServer *httptest.Server, + testServer *httptest.Server, + cookies map[string]*http.Cookie, + location string, +) (map[string]*http.Cookie, string) { + resp := httptest.NewRecorder() + resp.Code = 0 + + for resp.Code == 0 || resp.Code == http.StatusFound { + req := httptest.NewRequest("GET", location, nil) + resp = httptest.NewRecorder() + + for _, cookie := range cookies { + req.AddCookie(cookie) + } + if strings.HasPrefix(location, oauthServer.URL) { + oauthServer.Config.Handler.ServeHTTP(resp, req) + } else { + testServer.Config.Handler.ServeHTTP(resp, req) + } + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + if resp.Code == http.StatusFound { + location = resp.Header().Get("Location") + } + } + + return cookies, location +} + +func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { db, context, cleanup := setup() defer cleanup() - user := &database.User{ - ID: "test", - Username: "test", + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + user, _ := database.GetUser(db, "test") + if user != nil { + t.Errorf("expected no user, got user") } - database.FindOrSaveUser(db, user) - session, _ := database.MakeUserSessionFor(db, user) + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) - })) + user, _ = database.GetUser(db, "test") + if user == nil { + t.Errorf("expected a user to be created, could not find user") + } + if user.Username != "test" { + t.Errorf("expected username to be test, got %s", user.Username) + } +} + +func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { + _, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() defer testServer.Close() - protectedPath := testServer.URL + "/protected-path" - req := httptest.NewRequest("GET", protectedPath, nil) + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + req := httptest.NewRequest("GET", "/protected-path", nil) resp := httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(resp, req) - location := resp.Header().Get("Location") - if resp.Code != http.StatusFound && location != "/login" { - t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) + if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") { + t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location")) } - req.AddCookie(&http.Cookie{ - Name: "session", - Value: session.ID, - MaxAge: 60, - }) - resp = httptest.NewRecorder() - testServer.Config.Handler.ServeHTTP(resp, req) - if resp.Code != http.StatusOK { + cookies := make(map[string]*http.Cookie) + cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page") + + if !(strings.HasSuffix(location, "/protected-page")) { + t.Errorf("expected to redirect back to /protected-page after login, got %s", location) + } } -func TestOauthFormatsUsername(t *testing.T) { +func TestOauthSetsUniqueSession(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + cookiesAgain := make(map[string]*http.Cookie) + cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me") + + sessionOne := cookies["session"].Value + sessionTwo := cookiesAgain["session"].Value + if sessionOne == sessionTwo { + t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo) + } + session, _ := database.GetSession(db, sessionOne) + if session.UserID != "test" { + t.Errorf("expected session to be associated with user test, got %s", session.UserID) + } } -func TestSessionIsUnique(t *testing.T) {} +func TestLogoutClearsSession(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() -func TestLogoutClearsCookie(t *testing.T) { + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + req := httptest.NewRequest("GET", "/logout", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + req = httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { + t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location")) + } + + session, _ := database.GetSession(db, cookies["session"].Value) + if session != nil { + t.Errorf("expected session to be deleted, got session") + } } func TestRefreshUpdatesExpiration(t *testing.T) { -- cgit v1.2.3-70-g09d2 From 5177735b835289c8437799536d3654e5ab142fa3 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Sat, 6 Apr 2024 13:36:13 -0600 Subject: finish auth tests --- api/auth/auth_test.go | 116 ++++++++++++++++++++++++++++++-------------------- api/keys/keys.go | 9 ++-- api/serve.go | 2 +- database/users.go | 12 ++++++ 4 files changed, 88 insertions(+), 51 deletions(-) (limited to 'api/auth/auth_test.go') diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index 1e54099..a3d5b16 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -8,6 +8,7 @@ import ( "os" "strings" "testing" + "time" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" @@ -23,24 +24,6 @@ func IdContinuation(context *types.RequestContext, req *http.Request, resp http. } } -func setup() (*sql.DB, *types.RequestContext, func()) { - randomDb := utils.RandomId() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - - context := &types.RequestContext{ - DBConn: testDb, - Args: &args.Arguments{}, - TemplateData: &(map[string]interface{}{}), - } - - return testDb, context, func() { - testDb.Close() - os.Remove(randomDb) - } -} - func FakedOauthServer() *httptest.Server { oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/auth" { @@ -89,7 +72,7 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { } if r.URL.Path == "/me" { - auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation) } if r.URL.Path == "/logout" { @@ -99,6 +82,30 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { return testServer } +func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + + return testDb, context, oauthServer, testServer, func() { + oauthServer.Close() + testServer.Close() + + testDb.Close() + os.Remove(randomDb) + } +} + func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { return &oauth2.Config{ ClientID: "test", @@ -146,14 +153,9 @@ func FollowAuthentication( } func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { - db, context, cleanup := setup() + db, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) user, _ := database.GetUser(db, "test") @@ -174,14 +176,9 @@ func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { } func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { - _, context, cleanup := setup() + _, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) req := httptest.NewRequest("GET", "/protected-path", nil) @@ -201,14 +198,9 @@ func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { } func TestOauthSetsUniqueSession(t *testing.T) { - db, context, cleanup := setup() + db, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) @@ -230,14 +222,9 @@ func TestOauthSetsUniqueSession(t *testing.T) { } func TestLogoutClearsSession(t *testing.T) { - db, context, cleanup := setup() + db, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) @@ -270,13 +257,52 @@ func TestLogoutClearsSession(t *testing.T) { } func TestRefreshUpdatesExpiration(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") + + session, _ := database.GetSession(db, cookies["session"].Value) + + req := httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + updatedSession, _ := database.GetSession(db, cookies["session"].Value) + + // if session expiration is greater than or equal to updated session expiration + if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { + t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) + } } func TestVerifySessionEnsuresNonExpired(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() -} + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) -func TestAPITokensAreEquivalentToSessions(t *testing.T) { + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") + session, _ := database.GetSession(db, cookies["session"].Value) + session.ExpireAt = time.Now().Add(-time.Hour) + database.SaveSession(db, session) + + req := httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + + if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { + t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location")) + } } diff --git a/api/keys/keys.go b/api/keys/keys.go index ad380fc..cef3f3c 100644 --- a/api/keys/keys.go +++ b/api/keys/keys.go @@ -62,27 +62,26 @@ func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { - key := req.FormValue("key") + apiKey := req.FormValue("key") - typesKey, err := database.GetAPIKey(context.DBConn, key) + key, err := database.GetAPIKey(context.DBConn, apiKey) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - if (typesKey == nil) || (typesKey.UserID != context.User.ID) { + if (key == nil) || (key.UserID != context.User.ID) { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } - err = database.DeleteAPIKey(context.DBConn, key) + err = database.DeleteAPIKey(context.DBConn, apiKey) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - http.Redirect(resp, req, "/keys", http.StatusFound) return success(context, req, resp) } } diff --git a/api/serve.go b/api/serve.go index 2b0eba4..c8775d8 100644 --- a/api/serve.go +++ b/api/serve.go @@ -140,7 +140,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { diff --git a/database/users.go b/database/users.go index 5cebb8f..6f9456e 100644 --- a/database/users.go +++ b/database/users.go @@ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error { return nil } +func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) { + log.Println("saving session", session.ID) + + _, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt) + if err != nil { + log.Println(err) + return nil, err + } + + return session, nil +} + func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) { newExpireAt := time.Now().Add(ExpiryDuration) -- cgit v1.2.3-70-g09d2 From cad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Sat, 6 Apr 2024 13:40:46 -0600 Subject: nits --- api/auth/auth.go | 2 +- api/auth/auth_test.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) (limited to 'api/auth/auth_test.go') diff --git a/api/auth/auth.go b/api/auth/auth.go index becce24..0ffbf9c 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -144,7 +144,7 @@ func VerifySessionContinuation(context *types.RequestContext, req *http.Request, http.SetCookie(resp, &http.Cookie{ Name: "session", Value: "", - MaxAge: 0, // reset session cookie in case + MaxAge: 0, }) context.User = nil diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index a3d5b16..5e67c6d 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -276,7 +276,6 @@ func TestRefreshUpdatesExpiration(t *testing.T) { updatedSession, _ := database.GetSession(db, cookies["session"].Value) - // if session expiration is greater than or equal to updated session expiration if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) } -- cgit v1.2.3-70-g09d2