diff options
| author | Lizzy Hunt <lizzy.hunt@usu.edu> | 2024-03-27 15:02:31 -0600 |
|---|---|---|
| committer | Lizzy Hunt <lizzy.hunt@usu.edu> | 2024-03-27 15:02:31 -0600 |
| commit | 0dc2679005e70c50024bc49e750f3998a0c4c24b (patch) | |
| tree | 73153522195608ee2ed3bbb4c2ed3cbc621b6b07 /api | |
| parent | 8d65f4e23026dce5d04e9a4afaf216f0732482a6 (diff) | |
| download | hatecomputers.club-0dc2679005e70c50024bc49e750f3998a0c4c24b.tar.gz hatecomputers.club-0dc2679005e70c50024bc49e750f3998a0c4c24b.zip | |
authentication! oauth2!
Diffstat (limited to 'api')
| -rw-r--r-- | api/auth.go | 245 | ||||
| -rw-r--r-- | api/serve.go | 42 | ||||
| -rw-r--r-- | api/template.go | 7 |
3 files changed, 278 insertions, 16 deletions
diff --git a/api/auth.go b/api/auth.go new file mode 100644 index 0000000..4733971 --- /dev/null +++ b/api/auth.go @@ -0,0 +1,245 @@ +package api + +import ( + "crypto/sha256" + "database/sql" + "encoding/base64" + "encoding/json" + "io" + "log" + "net/http" + "strings" + "time" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" +) + +func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + verifier := utils.RandomId() + utils.RandomId() + + sha2 := sha256.New() + io.WriteString(sha2, verifier) + codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) + + state := utils.RandomId() + url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge)) + + http.SetCookie(resp, &http.Cookie{ + Name: "verifier", + Value: verifier, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + http.SetCookie(resp, &http.Cookie{ + Name: "state", + Value: state, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + + http.Redirect(resp, req, url, http.StatusFound) + return success(context, req, resp) + } +} + +func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + state := req.URL.Query().Get("state") + code := req.URL.Query().Get("code") + + if code == "" || state == "" { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + if !verifyState(req, "state", state) { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + verifierCookie, err := req.Cookie("verifier") + if err != nil { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + reqContext := req.Context() + token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + client := context.Args.OauthConfig.Client(reqContext, token) + user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + + return failure(context, req, resp) + } + + session, err := database.MakeUserSessionFor(context.DBConn, user) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.SetCookie(resp, &http.Cookie{ + Name: "session", + Value: session.ID, + Path: "/", + SameSite: http.SameSiteLaxMode, + Secure: true, + }) + + redirect := "/" + redirectCookie, err := req.Cookie("redirect") + if err == nil && redirectCookie.Value != "" { + redirect = redirectCookie.Value + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + MaxAge: 0, + }) + } + + http.Redirect(resp, req, redirect, http.StatusFound) + return success(context, req, resp) + } +} + +func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + session, err := database.GetSession(context.DBConn, sessionCookie.Value) + if err == nil && session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(context.DBConn, sessionCookie.Value) + } + if err != nil || session == nil { + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, + }) + + return failure(context, req, resp) + } + + user, err := database.GetUser(context.DBConn, session.UserID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + context.User = user + return success(context, req, resp) + } +} + +func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + Value: req.URL.Path, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + http.Redirect(resp, req, "/login", http.StatusFound) + return failure(context, req, resp) + } +} + +func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} + +func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err == nil && sessionCookie.Value != "" { + _ = database.DeleteSession(context.DBConn, sessionCookie.Value) + } + + http.Redirect(resp, req, "/", http.StatusFound) + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, + }) + return success(context, req, resp) + } +} + +func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { + userResponse, err := client.Get(uri) + if err != nil { + return nil, err + } + + userStruct, err := createUserFromResponse(userResponse) + if err != nil { + return nil, err + } + + user, err := database.FindOrSaveUser(dbConn, userStruct) + if err != nil { + return nil, err + } + + return user, nil +} + +func createUserFromResponse(response *http.Response) (*database.User, error) { + defer response.Body.Close() + user := &database.User{ + CreatedAt: time.Now(), + } + err := json.NewDecoder(response.Body).Decode(user) + if err != nil { + log.Println(err) + return nil, err + } + + user.Username = strings.ToLower(user.Username) + user.Username = strings.Split(user.Username, "@")[0] + + return user, nil +} + +func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { + cookie, err := req.Cookie(stateCookieName) + if err != nil || cookie.Value != expectedState { + return false + } + + return true +} diff --git a/api/serve.go b/api/serve.go index 2b95297..df30e76 100644 --- a/api/serve.go +++ b/api/serve.go @@ -1,7 +1,6 @@ package api import ( - "crypto/rand" "database/sql" "fmt" "log" @@ -9,6 +8,8 @@ import ( "time" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) type RequestContext struct { @@ -17,28 +18,17 @@ type RequestContext struct { Id string Start time.Time + + User *database.User } type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain type ContinuationChain func(Continuation, Continuation) ContinuationChain -func randomId() string { - uuid := make([]byte, 16) - _, err := rand.Read(uuid) - if err != nil { - panic(err) - } - - uuid[8] = uuid[8]&^0xc0 | 0x80 - uuid[6] = uuid[6]&^0xf0 | 0x40 - - return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:]) -} - func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, _failure Continuation) ContinuationChain { context.Start = time.Now() - context.Id = randomId() + context.Id = utils.RandomId() log.Println(req.Method, req.URL.Path, req.RemoteAddr, context.Id) return success(context, req, resp) @@ -90,7 +80,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(TemplateContinuation("home.html", nil, true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", nil, true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) { @@ -98,6 +88,26 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) + mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + + mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) { + requestContext := makeRequestContext() + LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + }) + mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() name := r.PathValue("name") diff --git a/api/template.go b/api/template.go index c666029..a4ccfa8 100644 --- a/api/template.go +++ b/api/template.go @@ -22,6 +22,13 @@ func renderTemplate(context *RequestContext, templateName string, showBaseHtml b return bytes.Buffer{}, err } + if data == nil { + data = map[string]interface{}{} + } + if context.User != nil { + data.(map[string]interface{})["User"] = context.User + } + var buffer bytes.Buffer err = tmpl.ExecuteTemplate(&buffer, "base", data) |
