summaryrefslogtreecommitdiff
path: root/api/auth/auth_test.go
diff options
context:
space:
mode:
authorElizabeth <elizabeth@simponic.xyz>2024-04-05 15:43:03 -0600
committerElizabeth <elizabeth@simponic.xyz>2024-04-05 15:43:03 -0600
commitae640a253edb5935380975fb07430e910a83b340 (patch)
tree62290bea86e150da13ccc46fb324670818df4a97 /api/auth/auth_test.go
parent94984aa4b01e96773b71325b5b27e6f64d9bd102 (diff)
downloadhatecomputers.club-ae640a253edb5935380975fb07430e910a83b340.tar.gz
hatecomputers.club-ae640a253edb5935380975fb07430e910a83b340.zip
add some auth test cases
Diffstat (limited to 'api/auth/auth_test.go')
-rw-r--r--api/auth/auth_test.go234
1 files changed, 209 insertions, 25 deletions
diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go
index caaedf1..1e54099 100644
--- a/api/auth/auth_test.go
+++ b/api/auth/auth_test.go
@@ -2,9 +2,11 @@ package auth_test
import (
"database/sql"
+ "log"
"net/http"
"net/http/httptest"
"os"
+ "strings"
"testing"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
@@ -12,6 +14,7 @@ import (
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+ "golang.org/x/oauth2"
)
func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
@@ -38,51 +41,232 @@ func setup() (*sql.DB, *types.RequestContext, func()) {
}
}
-func TestLoginSendsYouToRedirect(t *testing.T) {
+func FakedOauthServer() *httptest.Server {
+ oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/auth" {
+ code := utils.RandomId()
+
+ state := r.URL.Query().Get("state")
+ redirectPath := r.URL.Query().Get("redirect_uri")
+ redirectPath += "?code=" + code + "&state=" + state
+
+ http.Redirect(w, r, redirectPath, http.StatusFound)
+ }
+ if r.URL.Path == "/token" {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`))
+ }
+ if r.URL.Path == "/user" {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`))
+ }
+ }))
+
+ return oauthServer
+}
+
+func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ resp.Write([]byte(context.User.Username))
+ return success(context, req, resp)
+ }
+}
+
+func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/protected-path" {
+ auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/login" {
+ log.Println("login")
+ auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/callback" {
+ log.Println("callback")
+ auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/me" {
+ auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/logout" {
+ auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }
+ }))
+ return testServer
+}
+
+func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
+ return &oauth2.Config{
+ ClientID: "test",
+ ClientSecret: "test",
+ Scopes: []string{"test"},
+ Endpoint: oauth2.Endpoint{
+ AuthURL: oauthServerURL + "/auth",
+ TokenURL: oauthServerURL + "/token",
+ },
+ RedirectURL: testServerURL + "/callback",
+ }, oauthServerURL + "/user"
+}
+
+func FollowAuthentication(
+ oauthServer *httptest.Server,
+ testServer *httptest.Server,
+ cookies map[string]*http.Cookie,
+ location string,
+) (map[string]*http.Cookie, string) {
+ resp := httptest.NewRecorder()
+ resp.Code = 0
+
+ for resp.Code == 0 || resp.Code == http.StatusFound {
+ req := httptest.NewRequest("GET", location, nil)
+ resp = httptest.NewRecorder()
+
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ if strings.HasPrefix(location, oauthServer.URL) {
+ oauthServer.Config.Handler.ServeHTTP(resp, req)
+ } else {
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ }
+ for _, cookie := range resp.Result().Cookies() {
+ cookies[cookie.Name] = cookie
+ }
+
+ if resp.Code == http.StatusFound {
+ location = resp.Header().Get("Location")
+ }
+ }
+
+ return cookies, location
+}
+
+func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
db, context, cleanup := setup()
defer cleanup()
- user := &database.User{
- ID: "test",
- Username: "test",
+ oauthServer := FakedOauthServer()
+ testServer := MockUserEndpointServer(context)
+ defer oauthServer.Close()
+ defer testServer.Close()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ user, _ := database.GetUser(db, "test")
+ if user != nil {
+ t.Errorf("expected no user, got user")
}
- database.FindOrSaveUser(db, user)
- session, _ := database.MakeUserSessionFor(db, user)
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
- testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
- }))
+ user, _ = database.GetUser(db, "test")
+ if user == nil {
+ t.Errorf("expected a user to be created, could not find user")
+ }
+ if user.Username != "test" {
+ t.Errorf("expected username to be test, got %s", user.Username)
+ }
+}
+
+func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
+ _, context, cleanup := setup()
+ defer cleanup()
+
+ oauthServer := FakedOauthServer()
+ testServer := MockUserEndpointServer(context)
+ defer oauthServer.Close()
defer testServer.Close()
- protectedPath := testServer.URL + "/protected-path"
- req := httptest.NewRequest("GET", protectedPath, nil)
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ req := httptest.NewRequest("GET", "/protected-path", nil)
resp := httptest.NewRecorder()
testServer.Config.Handler.ServeHTTP(resp, req)
-
location := resp.Header().Get("Location")
- if resp.Code != http.StatusFound && location != "/login" {
- t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location)
+ if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") {
+ t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location"))
}
- req.AddCookie(&http.Cookie{
- Name: "session",
- Value: session.ID,
- MaxAge: 60,
- })
- resp = httptest.NewRecorder()
- testServer.Config.Handler.ServeHTTP(resp, req)
- if resp.Code != http.StatusOK {
+ cookies := make(map[string]*http.Cookie)
+ cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page")
+
+ if !(strings.HasSuffix(location, "/protected-page")) {
+ t.Errorf("expected to redirect back to /protected-page after login, got %s", location)
+ }
}
-func TestOauthFormatsUsername(t *testing.T) {
+func TestOauthSetsUniqueSession(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ oauthServer := FakedOauthServer()
+ testServer := MockUserEndpointServer(context)
+ defer oauthServer.Close()
+ defer testServer.Close()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
+
+ cookiesAgain := make(map[string]*http.Cookie)
+ cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me")
+
+ sessionOne := cookies["session"].Value
+ sessionTwo := cookiesAgain["session"].Value
+ if sessionOne == sessionTwo {
+ t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo)
+ }
+ session, _ := database.GetSession(db, sessionOne)
+ if session.UserID != "test" {
+ t.Errorf("expected session to be associated with user test, got %s", session.UserID)
+ }
}
-func TestSessionIsUnique(t *testing.T) {}
+func TestLogoutClearsSession(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ oauthServer := FakedOauthServer()
+ testServer := MockUserEndpointServer(context)
+ defer oauthServer.Close()
+ defer testServer.Close()
-func TestLogoutClearsCookie(t *testing.T) {
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
+
+ req := httptest.NewRequest("GET", "/logout", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ for _, cookie := range resp.Result().Cookies() {
+ cookies[cookie.Name] = cookie
+ }
+
+ req = httptest.NewRequest("GET", "/me", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp = httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
+ t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location"))
+ }
+
+ session, _ := database.GetSession(db, cookies["session"].Value)
+ if session != nil {
+ t.Errorf("expected session to be deleted, got session")
+ }
}
func TestRefreshUpdatesExpiration(t *testing.T) {