summaryrefslogtreecommitdiff
path: root/api/auth
diff options
context:
space:
mode:
authorsimponic <simponic@hatecomputers.club>2024-04-06 15:43:18 -0400
committersimponic <simponic@hatecomputers.club>2024-04-06 15:43:18 -0400
commit83cc6267fd5ce2f61200314424c5f400f65ff2ba (patch)
treeeafb35310236a15572cbb6e16ff8d6f181bfe240 /api/auth
parent569d2788ebfb90774faf361f62bfe7968e091465 (diff)
parentcad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 (diff)
downloadhatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.tar.gz
hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.zip
Merge pull request 'testing | dont be recursive for external domains | finalize oauth' (#5) from dont-be-authoritative into main
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/5
Diffstat (limited to 'api/auth')
-rw-r--r--api/auth/auth.go296
-rw-r--r--api/auth/auth_test.go307
2 files changed, 603 insertions, 0 deletions
diff --git a/api/auth/auth.go b/api/auth/auth.go
new file mode 100644
index 0000000..0ffbf9c
--- /dev/null
+++ b/api/auth/auth.go
@@ -0,0 +1,296 @@
+package auth
+
+import (
+ "crypto/sha256"
+ "database/sql"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "strings"
+ "time"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+ "golang.org/x/oauth2"
+)
+
+func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ verifier := utils.RandomId() + utils.RandomId()
+
+ sha2 := sha256.New()
+ io.WriteString(sha2, verifier)
+ codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil))
+
+ state := utils.RandomId()
+ url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge))
+
+ http.SetCookie(resp, &http.Cookie{
+ Name: "verifier",
+ Value: verifier,
+ Path: "/",
+ Secure: true,
+ SameSite: http.SameSiteLaxMode,
+ MaxAge: 200,
+ })
+ http.SetCookie(resp, &http.Cookie{
+ Name: "state",
+ Value: state,
+ Path: "/",
+ Secure: true,
+ SameSite: http.SameSiteLaxMode,
+ MaxAge: 200,
+ })
+
+ http.Redirect(resp, req, url, http.StatusFound)
+ return success(context, req, resp)
+ }
+}
+
+func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ state := req.URL.Query().Get("state")
+ code := req.URL.Query().Get("code")
+
+ if code == "" || state == "" {
+ resp.WriteHeader(http.StatusBadRequest)
+ return failure(context, req, resp)
+ }
+
+ if !verifyState(req, "state", state) {
+ resp.WriteHeader(http.StatusBadRequest)
+ return failure(context, req, resp)
+ }
+ verifierCookie, err := req.Cookie("verifier")
+ if err != nil {
+ resp.WriteHeader(http.StatusBadRequest)
+ return failure(context, req, resp)
+ }
+
+ reqContext := req.Context()
+ token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
+ if err != nil {
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ client := context.Args.OauthConfig.Client(reqContext, token)
+ user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+
+ return failure(context, req, resp)
+ }
+
+ session, err := database.MakeUserSessionFor(context.DBConn, user)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ http.SetCookie(resp, &http.Cookie{
+ Name: "session",
+ Value: session.ID,
+ Path: "/",
+ SameSite: http.SameSiteLaxMode,
+ Secure: true,
+ })
+ http.SetCookie(resp, &http.Cookie{
+ Name: "verifier",
+ Value: "",
+ MaxAge: 0,
+ })
+ http.SetCookie(resp, &http.Cookie{
+ Name: "state",
+ Value: "",
+ MaxAge: 0,
+ })
+
+ redirect := "/"
+ redirectCookie, err := req.Cookie("redirect")
+ if err == nil && redirectCookie.Value != "" {
+ redirect = redirectCookie.Value
+ http.SetCookie(resp, &http.Cookie{
+ Name: "redirect",
+ MaxAge: 0,
+ Value: "",
+ })
+ }
+
+ http.Redirect(resp, req, redirect, http.StatusFound)
+ return success(context, req, resp)
+ }
+}
+
+func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ authHeader := req.Header.Get("Authorization")
+ user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
+
+ sessionCookie, err := req.Cookie("session")
+ if err == nil && sessionCookie.Value != "" {
+ user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
+ }
+
+ if userErr != nil || user == nil {
+ log.Println(userErr, user)
+
+ http.SetCookie(resp, &http.Cookie{
+ Name: "session",
+ Value: "",
+ MaxAge: 0,
+ })
+
+ context.User = nil
+ return failure(context, req, resp)
+ }
+
+ context.User = user
+ return success(context, req, resp)
+ }
+}
+
+func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ http.SetCookie(resp, &http.Cookie{
+ Name: "redirect",
+ Value: req.URL.Path,
+ Path: "/",
+ Secure: true,
+ SameSite: http.SameSiteLaxMode,
+ })
+
+ http.Redirect(resp, req, "/login", http.StatusFound)
+ return failure(context, req, resp)
+ }
+}
+
+func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ sessionCookie, err := req.Cookie("session")
+ if err != nil {
+ return failure(context, req, resp)
+ }
+
+ _, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
+ if err != nil {
+ return failure(context, req, resp)
+ }
+
+ return success(context, req, resp)
+ }
+}
+
+func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ sessionCookie, err := req.Cookie("session")
+ if err == nil && sessionCookie.Value != "" {
+ _ = database.DeleteSession(context.DBConn, sessionCookie.Value)
+ }
+
+ http.SetCookie(resp, &http.Cookie{
+ Name: "session",
+ MaxAge: 0,
+ Value: "",
+ })
+ http.Redirect(resp, req, "/", http.StatusFound)
+
+ return success(context, req, resp)
+ }
+}
+
+func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
+ userResponse, err := client.Get(uri)
+ if err != nil {
+ return nil, err
+ }
+
+ userStruct, err := createUserFromOauthResponse(userResponse)
+ if err != nil {
+ return nil, err
+ }
+
+ user, err := database.FindOrSaveUser(dbConn, userStruct)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
+
+func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
+ user := &database.User{}
+ err := json.NewDecoder(response.Body).Decode(user)
+ defer response.Body.Close()
+
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+
+ user.Username = strings.ToLower(user.Username)
+ user.Username = strings.Split(user.Username, "@")[0]
+
+ return user, nil
+}
+
+func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
+ cookie, err := req.Cookie(stateCookieName)
+ if err != nil || cookie.Value != expectedState {
+ return false
+ }
+
+ return true
+}
+
+func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
+ if bearerToken == "" {
+ return nil, nil
+ }
+
+ parts := strings.Split(bearerToken, " ")
+ if len(parts) != 2 || parts[0] != "Bearer" {
+ return nil, nil
+ }
+
+ key, err := database.GetAPIKey(dbConn, parts[1])
+ if err != nil {
+ return nil, err
+ }
+ if key == nil {
+ return nil, nil
+ }
+
+ user, err := database.GetUser(dbConn, key.UserID)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
+
+func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
+ session, err := database.GetSession(dbConn, sessionId)
+ if err != nil {
+ return nil, err
+ }
+
+ if session.ExpireAt.Before(time.Now()) {
+ session = nil
+ database.DeleteSession(dbConn, sessionId)
+ return nil, fmt.Errorf("session expired")
+ }
+
+ user, err := database.GetUser(dbConn, session.UserID)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go
new file mode 100644
index 0000000..5e67c6d
--- /dev/null
+++ b/api/auth/auth_test.go
@@ -0,0 +1,307 @@
+package auth_test
+
+import (
+ "database/sql"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+ "golang.org/x/oauth2"
+)
+
+func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
+ return success(context, req, resp)
+ }
+}
+
+func FakedOauthServer() *httptest.Server {
+ oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/auth" {
+ code := utils.RandomId()
+
+ state := r.URL.Query().Get("state")
+ redirectPath := r.URL.Query().Get("redirect_uri")
+ redirectPath += "?code=" + code + "&state=" + state
+
+ http.Redirect(w, r, redirectPath, http.StatusFound)
+ }
+ if r.URL.Path == "/token" {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`))
+ }
+ if r.URL.Path == "/user" {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`))
+ }
+ }))
+
+ return oauthServer
+}
+
+func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ resp.Write([]byte(context.User.Username))
+ return success(context, req, resp)
+ }
+}
+
+func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/protected-path" {
+ auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/login" {
+ log.Println("login")
+ auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/callback" {
+ log.Println("callback")
+ auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/me" {
+ auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/logout" {
+ auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }
+ }))
+ return testServer
+}
+
+func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) {
+ randomDb := utils.RandomId()
+
+ testDb := database.MakeConn(&randomDb)
+ database.Migrate(testDb)
+
+ context := &types.RequestContext{
+ DBConn: testDb,
+ Args: &args.Arguments{},
+ TemplateData: &(map[string]interface{}{}),
+ }
+
+ oauthServer := FakedOauthServer()
+ testServer := MockUserEndpointServer(context)
+
+ return testDb, context, oauthServer, testServer, func() {
+ oauthServer.Close()
+ testServer.Close()
+
+ testDb.Close()
+ os.Remove(randomDb)
+ }
+}
+
+func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
+ return &oauth2.Config{
+ ClientID: "test",
+ ClientSecret: "test",
+ Scopes: []string{"test"},
+ Endpoint: oauth2.Endpoint{
+ AuthURL: oauthServerURL + "/auth",
+ TokenURL: oauthServerURL + "/token",
+ },
+ RedirectURL: testServerURL + "/callback",
+ }, oauthServerURL + "/user"
+}
+
+func FollowAuthentication(
+ oauthServer *httptest.Server,
+ testServer *httptest.Server,
+ cookies map[string]*http.Cookie,
+ location string,
+) (map[string]*http.Cookie, string) {
+ resp := httptest.NewRecorder()
+ resp.Code = 0
+
+ for resp.Code == 0 || resp.Code == http.StatusFound {
+ req := httptest.NewRequest("GET", location, nil)
+ resp = httptest.NewRecorder()
+
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ if strings.HasPrefix(location, oauthServer.URL) {
+ oauthServer.Config.Handler.ServeHTTP(resp, req)
+ } else {
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ }
+ for _, cookie := range resp.Result().Cookies() {
+ cookies[cookie.Name] = cookie
+ }
+
+ if resp.Code == http.StatusFound {
+ location = resp.Header().Get("Location")
+ }
+ }
+
+ return cookies, location
+}
+
+func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ user, _ := database.GetUser(db, "test")
+ if user != nil {
+ t.Errorf("expected no user, got user")
+ }
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
+
+ user, _ = database.GetUser(db, "test")
+ if user == nil {
+ t.Errorf("expected a user to be created, could not find user")
+ }
+ if user.Username != "test" {
+ t.Errorf("expected username to be test, got %s", user.Username)
+ }
+}
+
+func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
+ _, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ req := httptest.NewRequest("GET", "/protected-path", nil)
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ location := resp.Header().Get("Location")
+ if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") {
+ t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location"))
+ }
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page")
+
+ if !(strings.HasSuffix(location, "/protected-page")) {
+ t.Errorf("expected to redirect back to /protected-page after login, got %s", location)
+ }
+}
+
+func TestOauthSetsUniqueSession(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
+
+ cookiesAgain := make(map[string]*http.Cookie)
+ cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me")
+
+ sessionOne := cookies["session"].Value
+ sessionTwo := cookiesAgain["session"].Value
+ if sessionOne == sessionTwo {
+ t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo)
+ }
+
+ session, _ := database.GetSession(db, sessionOne)
+ if session.UserID != "test" {
+ t.Errorf("expected session to be associated with user test, got %s", session.UserID)
+ }
+}
+
+func TestLogoutClearsSession(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
+
+ req := httptest.NewRequest("GET", "/logout", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ for _, cookie := range resp.Result().Cookies() {
+ cookies[cookie.Name] = cookie
+ }
+
+ req = httptest.NewRequest("GET", "/me", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp = httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
+ t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location"))
+ }
+
+ session, _ := database.GetSession(db, cookies["session"].Value)
+ if session != nil {
+ t.Errorf("expected session to be deleted, got session")
+ }
+}
+
+func TestRefreshUpdatesExpiration(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
+
+ session, _ := database.GetSession(db, cookies["session"].Value)
+
+ req := httptest.NewRequest("GET", "/me", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+
+ updatedSession, _ := database.GetSession(db, cookies["session"].Value)
+
+ if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) {
+ t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt)
+ }
+}
+
+func TestVerifySessionEnsuresNonExpired(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
+
+ session, _ := database.GetSession(db, cookies["session"].Value)
+ session.ExpireAt = time.Now().Add(-time.Hour)
+ database.SaveSession(db, session)
+
+ req := httptest.NewRequest("GET", "/me", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
+ t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location"))
+ }
+}