summaryrefslogtreecommitdiff
path: root/api/auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'api/auth.go')
-rw-r--r--api/auth.go74
1 files changed, 56 insertions, 18 deletions
diff --git a/api/auth.go b/api/auth.go
index 4733971..dcddf5a 100644
--- a/api/auth.go
+++ b/api/auth.go
@@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/base64"
"encoding/json"
+ "fmt"
"io"
"log"
"net/http"
@@ -116,32 +117,69 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
}
}
+func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
+ if bearerToken == "" {
+ return nil, nil
+ }
+
+ parts := strings.Split(bearerToken, " ")
+ if len(parts) != 2 || parts[0] != "Bearer" {
+ return nil, nil
+ }
+
+ apiKey, err := database.GetAPIKey(dbConn, parts[1])
+ if err != nil {
+ return nil, err
+ }
+ if apiKey == nil {
+ return nil, nil
+ }
+
+ user, err := database.GetUser(dbConn, apiKey.UserID)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
+
+func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
+ session, err := database.GetSession(dbConn, sessionId)
+ if err != nil {
+ return nil, err
+ }
+
+ if session.ExpireAt.Before(time.Now()) {
+ session = nil
+ database.DeleteSession(dbConn, sessionId)
+ return nil, fmt.Errorf("session expired")
+ }
+
+ user, err := database.GetUser(dbConn, session.UserID)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
+
func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
return func(success Continuation, failure Continuation) ContinuationChain {
+ authHeader := req.Header.Get("Authorization")
+ user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
+
sessionCookie, err := req.Cookie("session")
- if err != nil {
- resp.WriteHeader(http.StatusUnauthorized)
- return failure(context, req, resp)
+ if err == nil {
+ user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value)
}
- session, err := database.GetSession(context.DBConn, sessionCookie.Value)
- if err == nil && session.ExpireAt.Before(time.Now()) {
- session = nil
- database.DeleteSession(context.DBConn, sessionCookie.Value)
- }
- if err != nil || session == nil {
+ if userErr != nil || user == nil {
+ log.Println(userErr, user)
+
http.SetCookie(resp, &http.Cookie{
Name: "session",
- MaxAge: 0,
+ MaxAge: 0, // reset session cookie in case
})
-
- return failure(context, req, resp)
- }
-
- user, err := database.GetUser(context.DBConn, session.UserID)
- if err != nil {
- log.Println(err)
- resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}