diff options
Diffstat (limited to 'api/auth.go')
| -rw-r--r-- | api/auth.go | 74 |
1 files changed, 56 insertions, 18 deletions
diff --git a/api/auth.go b/api/auth.go index 4733971..dcddf5a 100644 --- a/api/auth.go +++ b/api/auth.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/base64" "encoding/json" + "fmt" "io" "log" "net/http" @@ -116,32 +117,69 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp } } +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + apiKey, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if apiKey == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, apiKey.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { + authHeader := req.Header.Get("Authorization") + user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) + sessionCookie, err := req.Cookie("session") - if err != nil { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) + if err == nil { + user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) } - session, err := database.GetSession(context.DBConn, sessionCookie.Value) - if err == nil && session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(context.DBConn, sessionCookie.Value) - } - if err != nil || session == nil { + if userErr != nil || user == nil { + log.Println(userErr, user) + http.SetCookie(resp, &http.Cookie{ Name: "session", - MaxAge: 0, + MaxAge: 0, // reset session cookie in case }) - - return failure(context, req, resp) - } - - user, err := database.GetUser(context.DBConn, session.UserID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } |
