diff options
| author | Elizabeth Hunt <elizabeth@simponic.xyz> | 2024-04-03 17:53:50 -0600 |
|---|---|---|
| committer | Elizabeth Hunt <elizabeth@simponic.xyz> | 2024-04-03 17:53:50 -0600 |
| commit | f38e8719c2a8537fe9b64ed8ceca45858a58e498 (patch) | |
| tree | 5cf2c7c7f6396f75bdb841db00638e4eef8e81e8 /api/auth | |
| parent | e398cf05402c010d594cea4e2dea307ca1a36dbe (diff) | |
| download | hatecomputers.club-f38e8719c2a8537fe9b64ed8ceca45858a58e498.tar.gz hatecomputers.club-f38e8719c2a8537fe9b64ed8ceca45858a58e498.zip | |
make it compile
Diffstat (limited to 'api/auth')
| -rw-r--r-- | api/auth/auth.go | 288 | ||||
| -rw-r--r-- | api/auth/auth_test.go | 36 |
2 files changed, 324 insertions, 0 deletions
diff --git a/api/auth/auth.go b/api/auth/auth.go new file mode 100644 index 0000000..dc348b2 --- /dev/null +++ b/api/auth/auth.go @@ -0,0 +1,288 @@ +package auth + +import ( + "crypto/sha256" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "time" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" +) + +func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + verifier := utils.RandomId() + utils.RandomId() + + sha2 := sha256.New() + io.WriteString(sha2, verifier) + codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) + + state := utils.RandomId() + url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge)) + + http.SetCookie(resp, &http.Cookie{ + Name: "verifier", + Value: verifier, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + http.SetCookie(resp, &http.Cookie{ + Name: "state", + Value: state, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + + http.Redirect(resp, req, url, http.StatusFound) + return success(context, req, resp) + } +} + +func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + state := req.URL.Query().Get("state") + code := req.URL.Query().Get("code") + + if code == "" || state == "" { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + if !verifyState(req, "state", state) { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + verifierCookie, err := req.Cookie("verifier") + if err != nil { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + reqContext := req.Context() + token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + client := context.Args.OauthConfig.Client(reqContext, token) + user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + + return failure(context, req, resp) + } + + session, err := database.MakeUserSessionFor(context.DBConn, user) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.SetCookie(resp, &http.Cookie{ + Name: "session", + Value: session.ID, + Path: "/", + SameSite: http.SameSiteLaxMode, + Secure: true, + }) + + redirect := "/" + redirectCookie, err := req.Cookie("redirect") + if err == nil && redirectCookie.Value != "" { + redirect = redirectCookie.Value + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + MaxAge: 0, + }) + } + + http.Redirect(resp, req, redirect, http.StatusFound) + return success(context, req, resp) + } +} + +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + typesKey, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if typesKey == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, typesKey.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + authHeader := req.Header.Get("Authorization") + user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) + + sessionCookie, err := req.Cookie("session") + if err == nil && sessionCookie.Value != "" { + user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) + } + + if userErr != nil || user == nil { + log.Println(userErr, user) + + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, // reset session cookie in case + }) + + context.User = nil + return failure(context, req, resp) + } + + context.User = user + return success(context, req, resp) + } +} + +func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + Value: req.URL.Path, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + http.Redirect(resp, req, "/login", http.StatusFound) + return failure(context, req, resp) + } +} + +func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} + +func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err == nil && sessionCookie.Value != "" { + _ = database.DeleteSession(context.DBConn, sessionCookie.Value) + } + + http.Redirect(resp, req, "/", http.StatusFound) + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, + }) + return success(context, req, resp) + } +} + +func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { + userResponse, err := client.Get(uri) + if err != nil { + return nil, err + } + + userStruct, err := createUserFromResponse(userResponse) + if err != nil { + return nil, err + } + + user, err := database.FindOrSaveUser(dbConn, userStruct) + if err != nil { + return nil, err + } + + return user, nil +} + +func createUserFromResponse(response *http.Response) (*database.User, error) { + user := &database.User{ + CreatedAt: time.Now(), + } + + err := json.NewDecoder(response.Body).Decode(user) + defer response.Body.Close() + + if err != nil { + log.Println(err) + return nil, err + } + + user.Username = strings.ToLower(user.Username) + user.Username = strings.Split(user.Username, "@")[0] + + return user, nil +} + +func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { + cookie, err := req.Cookie(stateCookieName) + if err != nil || cookie.Value != expectedState { + return false + } + + return true +} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go new file mode 100644 index 0000000..a6c2a45 --- /dev/null +++ b/api/auth/auth_test.go @@ -0,0 +1,36 @@ +package auth_test + +import ( + "database/sql" + "os" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func setup() (*sql.DB, *types.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +/* +todo: test types key creation ++ api key attached to user ++ user session is unique ++ goLogin goes to page in cookie +*/ |
