diff options
| author | simponic <simponic@hatecomputers.club> | 2024-03-28 12:57:35 -0400 |
|---|---|---|
| committer | simponic <simponic@hatecomputers.club> | 2024-03-28 12:57:35 -0400 |
| commit | b2fc689bdcff28bf75c0128db19ba4730d726b4f (patch) | |
| tree | 37c16d95183242516ba667aa5f441539d152c279 /api/auth.go | |
| parent | 75ba836d6072235fc7a71659f8630ab3c1b210ad (diff) | |
| download | hatecomputers.club-b2fc689bdcff28bf75c0128db19ba4730d726b4f.tar.gz hatecomputers.club-b2fc689bdcff28bf75c0128db19ba4730d726b4f.zip | |
dns api (#1)
Co-authored-by: Elizabeth Hunt <elizabeth.hunt@simponic.xyz>
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/1
Diffstat (limited to 'api/auth.go')
| -rw-r--r-- | api/auth.go | 74 |
1 files changed, 56 insertions, 18 deletions
diff --git a/api/auth.go b/api/auth.go index 4733971..dcddf5a 100644 --- a/api/auth.go +++ b/api/auth.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/base64" "encoding/json" + "fmt" "io" "log" "net/http" @@ -116,32 +117,69 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp } } +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + apiKey, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if apiKey == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, apiKey.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { + authHeader := req.Header.Get("Authorization") + user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) + sessionCookie, err := req.Cookie("session") - if err != nil { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) + if err == nil { + user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) } - session, err := database.GetSession(context.DBConn, sessionCookie.Value) - if err == nil && session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(context.DBConn, sessionCookie.Value) - } - if err != nil || session == nil { + if userErr != nil || user == nil { + log.Println(userErr, user) + http.SetCookie(resp, &http.Cookie{ Name: "session", - MaxAge: 0, + MaxAge: 0, // reset session cookie in case }) - - return failure(context, req, resp) - } - - user, err := database.GetUser(context.DBConn, session.UserID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } |
