summaryrefslogtreecommitdiff
path: root/api
diff options
context:
space:
mode:
Diffstat (limited to 'api')
-rw-r--r--api/auth.go245
-rw-r--r--api/serve.go42
-rw-r--r--api/template.go7
3 files changed, 278 insertions, 16 deletions
diff --git a/api/auth.go b/api/auth.go
new file mode 100644
index 0000000..4733971
--- /dev/null
+++ b/api/auth.go
@@ -0,0 +1,245 @@
+package api
+
+import (
+ "crypto/sha256"
+ "database/sql"
+ "encoding/base64"
+ "encoding/json"
+ "io"
+ "log"
+ "net/http"
+ "strings"
+ "time"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+ "golang.org/x/oauth2"
+)
+
+func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ verifier := utils.RandomId() + utils.RandomId()
+
+ sha2 := sha256.New()
+ io.WriteString(sha2, verifier)
+ codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil))
+
+ state := utils.RandomId()
+ url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge))
+
+ http.SetCookie(resp, &http.Cookie{
+ Name: "verifier",
+ Value: verifier,
+ Path: "/",
+ Secure: true,
+ SameSite: http.SameSiteLaxMode,
+ MaxAge: 60,
+ })
+ http.SetCookie(resp, &http.Cookie{
+ Name: "state",
+ Value: state,
+ Path: "/",
+ Secure: true,
+ SameSite: http.SameSiteLaxMode,
+ MaxAge: 60,
+ })
+
+ http.Redirect(resp, req, url, http.StatusFound)
+ return success(context, req, resp)
+ }
+}
+
+func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ state := req.URL.Query().Get("state")
+ code := req.URL.Query().Get("code")
+
+ if code == "" || state == "" {
+ resp.WriteHeader(http.StatusBadRequest)
+ return failure(context, req, resp)
+ }
+
+ if !verifyState(req, "state", state) {
+ resp.WriteHeader(http.StatusBadRequest)
+ return failure(context, req, resp)
+ }
+ verifierCookie, err := req.Cookie("verifier")
+ if err != nil {
+ resp.WriteHeader(http.StatusBadRequest)
+ return failure(context, req, resp)
+ }
+
+ reqContext := req.Context()
+ token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ client := context.Args.OauthConfig.Client(reqContext, token)
+ user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+
+ return failure(context, req, resp)
+ }
+
+ session, err := database.MakeUserSessionFor(context.DBConn, user)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ http.SetCookie(resp, &http.Cookie{
+ Name: "session",
+ Value: session.ID,
+ Path: "/",
+ SameSite: http.SameSiteLaxMode,
+ Secure: true,
+ })
+
+ redirect := "/"
+ redirectCookie, err := req.Cookie("redirect")
+ if err == nil && redirectCookie.Value != "" {
+ redirect = redirectCookie.Value
+ http.SetCookie(resp, &http.Cookie{
+ Name: "redirect",
+ MaxAge: 0,
+ })
+ }
+
+ http.Redirect(resp, req, redirect, http.StatusFound)
+ return success(context, req, resp)
+ }
+}
+
+func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ sessionCookie, err := req.Cookie("session")
+ if err != nil {
+ resp.WriteHeader(http.StatusUnauthorized)
+ return failure(context, req, resp)
+ }
+
+ session, err := database.GetSession(context.DBConn, sessionCookie.Value)
+ if err == nil && session.ExpireAt.Before(time.Now()) {
+ session = nil
+ database.DeleteSession(context.DBConn, sessionCookie.Value)
+ }
+ if err != nil || session == nil {
+ http.SetCookie(resp, &http.Cookie{
+ Name: "session",
+ MaxAge: 0,
+ })
+
+ return failure(context, req, resp)
+ }
+
+ user, err := database.GetUser(context.DBConn, session.UserID)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusUnauthorized)
+ return failure(context, req, resp)
+ }
+
+ context.User = user
+ return success(context, req, resp)
+ }
+}
+
+func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ http.SetCookie(resp, &http.Cookie{
+ Name: "redirect",
+ Value: req.URL.Path,
+ Path: "/",
+ Secure: true,
+ SameSite: http.SameSiteLaxMode,
+ })
+
+ http.Redirect(resp, req, "/login", http.StatusFound)
+ return failure(context, req, resp)
+ }
+}
+
+func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ sessionCookie, err := req.Cookie("session")
+ if err != nil {
+ resp.WriteHeader(http.StatusUnauthorized)
+ return failure(context, req, resp)
+ }
+
+ _, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
+ if err != nil {
+ resp.WriteHeader(http.StatusUnauthorized)
+ return failure(context, req, resp)
+ }
+
+ return success(context, req, resp)
+ }
+}
+
+func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
+ return func(success Continuation, failure Continuation) ContinuationChain {
+ sessionCookie, err := req.Cookie("session")
+ if err == nil && sessionCookie.Value != "" {
+ _ = database.DeleteSession(context.DBConn, sessionCookie.Value)
+ }
+
+ http.Redirect(resp, req, "/", http.StatusFound)
+ http.SetCookie(resp, &http.Cookie{
+ Name: "session",
+ MaxAge: 0,
+ })
+ return success(context, req, resp)
+ }
+}
+
+func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) {
+ userResponse, err := client.Get(uri)
+ if err != nil {
+ return nil, err
+ }
+
+ userStruct, err := createUserFromResponse(userResponse)
+ if err != nil {
+ return nil, err
+ }
+
+ user, err := database.FindOrSaveUser(dbConn, userStruct)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
+
+func createUserFromResponse(response *http.Response) (*database.User, error) {
+ defer response.Body.Close()
+ user := &database.User{
+ CreatedAt: time.Now(),
+ }
+ err := json.NewDecoder(response.Body).Decode(user)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+
+ user.Username = strings.ToLower(user.Username)
+ user.Username = strings.Split(user.Username, "@")[0]
+
+ return user, nil
+}
+
+func verifyState(req *http.Request, stateCookieName string, expectedState string) bool {
+ cookie, err := req.Cookie(stateCookieName)
+ if err != nil || cookie.Value != expectedState {
+ return false
+ }
+
+ return true
+}
diff --git a/api/serve.go b/api/serve.go
index 2b95297..df30e76 100644
--- a/api/serve.go
+++ b/api/serve.go
@@ -1,7 +1,6 @@
package api
import (
- "crypto/rand"
"database/sql"
"fmt"
"log"
@@ -9,6 +8,8 @@ import (
"time"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
type RequestContext struct {
@@ -17,28 +18,17 @@ type RequestContext struct {
Id string
Start time.Time
+
+ User *database.User
}
type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
type ContinuationChain func(Continuation, Continuation) ContinuationChain
-func randomId() string {
- uuid := make([]byte, 16)
- _, err := rand.Read(uuid)
- if err != nil {
- panic(err)
- }
-
- uuid[8] = uuid[8]&^0xc0 | 0x80
- uuid[6] = uuid[6]&^0xf0 | 0x40
-
- return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:])
-}
-
func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, _failure Continuation) ContinuationChain {
context.Start = time.Now()
- context.Id = randomId()
+ context.Id = utils.RandomId()
log.Println(req.Method, req.URL.Path, req.RemoteAddr, context.Id)
return success(context, req, resp)
@@ -90,7 +80,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(TemplateContinuation("home.html", nil, true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", nil, true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) {
@@ -98,6 +88,26 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
+ mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
+ requestContext := makeRequestContext()
+ LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ })
+
+ mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
+ requestContext := makeRequestContext()
+ LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ })
+
+ mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) {
+ requestContext := makeRequestContext()
+ LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ })
+
+ mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
+ requestContext := makeRequestContext()
+ LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ })
+
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
name := r.PathValue("name")
diff --git a/api/template.go b/api/template.go
index c666029..a4ccfa8 100644
--- a/api/template.go
+++ b/api/template.go
@@ -22,6 +22,13 @@ func renderTemplate(context *RequestContext, templateName string, showBaseHtml b
return bytes.Buffer{}, err
}
+ if data == nil {
+ data = map[string]interface{}{}
+ }
+ if context.User != nil {
+ data.(map[string]interface{})["User"] = context.User
+ }
+
var buffer bytes.Buffer
err = tmpl.ExecuteTemplate(&buffer, "base", data)