summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsimponic <simponic@hatecomputers.club>2024-04-06 15:43:18 -0400
committersimponic <simponic@hatecomputers.club>2024-04-06 15:43:18 -0400
commit83cc6267fd5ce2f61200314424c5f400f65ff2ba (patch)
treeeafb35310236a15572cbb6e16ff8d6f181bfe240
parent569d2788ebfb90774faf361f62bfe7968e091465 (diff)
parentcad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 (diff)
downloadhatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.tar.gz
hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.zip
Merge pull request 'testing | dont be recursive for external domains | finalize oauth' (#5) from dont-be-authoritative into main
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/5
-rw-r--r--.dockerignore1
-rw-r--r--.drone.yml31
-rw-r--r--Dockerfile2
-rw-r--r--adapters/cloudflare/cloudflare.go17
-rw-r--r--adapters/external_dns.go8
-rw-r--r--api/auth/auth.go (renamed from api/auth.go)155
-rw-r--r--api/auth/auth_test.go307
-rw-r--r--api/dns.go179
-rw-r--r--api/dns/dns.go174
-rw-r--r--api/dns/dns_test.go442
-rw-r--r--api/guestbook.go141
-rw-r--r--api/guestbook/guestbook.go85
-rw-r--r--api/guestbook/guestbook_test.go136
-rw-r--r--api/hcaptcha/hcaptcha.go75
-rw-r--r--api/keys/keys.go (renamed from api/api_keys.go)32
-rw-r--r--api/serve.go90
-rw-r--r--api/template/template.go (renamed from api/template.go)13
-rw-r--r--api/types/types.go28
-rw-r--r--args/args.go7
-rw-r--r--database/dns.go23
-rw-r--r--database/users.go12
-rw-r--r--hcdns/server.go (renamed from dns/server.go)62
-rw-r--r--hcdns/server_test.go254
-rw-r--r--main.go4
-rw-r--r--static/css/styles.css16
-rw-r--r--static/img/cursor-1.pngbin0 -> 570 bytes
-rw-r--r--static/img/cursor-2.pngbin0 -> 563 bytes
27 files changed, 1779 insertions, 515 deletions
diff --git a/.dockerignore b/.dockerignore
index 52be0d9..6045466 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -2,3 +2,4 @@
hatecomputers.club
Dockerfile
*.db
+.drone.yml
diff --git a/.drone.yml b/.drone.yml
index 8f459a1..d056e69 100644
--- a/.drone.yml
+++ b/.drone.yml
@@ -1,9 +1,30 @@
---
kind: pipeline
type: docker
-name: build, publish docker image, deploy
+name: build
steps:
+ - name: run tests
+ image: golang
+ commands:
+ - go build
+ - go test -p 1 -v ./...
+
+trigger:
+ event:
+ - pull_request
+
+---
+kind: pipeline
+type: docker
+name: deploy
+
+steps:
+ - name: run tests
+ image: golang
+ commands:
+ - go build
+ - go test -p 1 -v ./...
- name: docker
image: plugins/docker
settings:
@@ -13,9 +34,6 @@ steps:
from_secret: gitea_packpub_password
registry: git.hatecomputers.club
repo: git.hatecomputers.club/hatecomputers/hatecomputers.club
- tags:
- - latest
- - main
- name: ssh
image: appleboy/drone-ssh
settings:
@@ -27,6 +45,9 @@ steps:
command_timeout: 2m
script:
- systemctl restart docker-compose@hatecomputers-club
+
trigger:
branch:
- - main
+ - main
+ event:
+ - push
diff --git a/Dockerfile b/Dockerfile
index a46f6c4..591423f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers
EXPOSE 8080
-CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"]
+CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"]
diff --git a/adapters/cloudflare/cloudflare.go b/adapters/cloudflare/cloudflare.go
index 40b04a5..c302037 100644
--- a/adapters/cloudflare/cloudflare.go
+++ b/adapters/cloudflare/cloudflare.go
@@ -14,15 +14,20 @@ type CloudflareDNSResponse struct {
Result database.DNSRecord `json:"result"`
}
-func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) {
- url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId)
+type CloudflareExternalDNSAdapter struct {
+ ZoneId string
+ APIToken string
+}
+
+func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
+ url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId)
reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL)
payload := strings.NewReader(reqBody)
req, _ := http.NewRequest("POST", url, payload)
- req.Header.Add("Authorization", "Bearer "+apiToken)
+ req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
req.Header.Add("Content-Type", "application/json")
res, err := http.DefaultClient.Do(req)
@@ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord)
return result.ID, nil
}
-func DeleteDNSRecord(zoneId string, apiToken string, id string) error {
- url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id)
+func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error {
+ url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id)
req, _ := http.NewRequest("DELETE", url, nil)
- req.Header.Add("Authorization", "Bearer "+apiToken)
+ req.Header.Add("Authorization", "Bearer "+adapter.APIToken)
res, err := http.DefaultClient.Do(req)
if err != nil {
diff --git a/adapters/external_dns.go b/adapters/external_dns.go
new file mode 100644
index 0000000..c861283
--- /dev/null
+++ b/adapters/external_dns.go
@@ -0,0 +1,8 @@
+package external_dns
+
+import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+
+type ExternalDNSAdapter interface {
+ CreateDNSRecord(record *database.DNSRecord) (string, error)
+ DeleteDNSRecord(id string) error
+}
diff --git a/api/auth.go b/api/auth/auth.go
index 0294edd..0ffbf9c 100644
--- a/api/auth.go
+++ b/api/auth/auth.go
@@ -1,4 +1,4 @@
-package api
+package auth
import (
"crypto/sha256"
@@ -12,13 +12,14 @@ import (
"strings"
"time"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
"golang.org/x/oauth2"
)
-func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
+func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
verifier := utils.RandomId() + utils.RandomId()
sha2 := sha256.New()
@@ -34,7 +35,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
Path: "/",
Secure: true,
SameSite: http.SameSiteLaxMode,
- MaxAge: 60,
+ MaxAge: 200,
})
http.SetCookie(resp, &http.Cookie{
Name: "state",
@@ -42,7 +43,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
Path: "/",
Secure: true,
SameSite: http.SameSiteLaxMode,
- MaxAge: 60,
+ MaxAge: 200,
})
http.Redirect(resp, req, url, http.StatusFound)
@@ -50,8 +51,8 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h
}
}
-func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
+func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
state := req.URL.Query().Get("state")
code := req.URL.Query().Get("code")
@@ -73,7 +74,6 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
reqContext := req.Context()
token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value))
if err != nil {
- log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
@@ -101,6 +101,16 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
SameSite: http.SameSiteLaxMode,
Secure: true,
})
+ http.SetCookie(resp, &http.Cookie{
+ Name: "verifier",
+ Value: "",
+ MaxAge: 0,
+ })
+ http.SetCookie(resp, &http.Cookie{
+ Name: "state",
+ Value: "",
+ MaxAge: 0,
+ })
redirect := "/"
redirectCookie, err := req.Cookie("redirect")
@@ -109,6 +119,7 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
http.SetCookie(resp, &http.Cookie{
Name: "redirect",
MaxAge: 0,
+ Value: "",
})
}
@@ -117,54 +128,8 @@ func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp
}
}
-func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
- if bearerToken == "" {
- return nil, nil
- }
-
- parts := strings.Split(bearerToken, " ")
- if len(parts) != 2 || parts[0] != "Bearer" {
- return nil, nil
- }
-
- apiKey, err := database.GetAPIKey(dbConn, parts[1])
- if err != nil {
- return nil, err
- }
- if apiKey == nil {
- return nil, nil
- }
-
- user, err := database.GetUser(dbConn, apiKey.UserID)
- if err != nil {
- return nil, err
- }
-
- return user, nil
-}
-
-func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
- session, err := database.GetSession(dbConn, sessionId)
- if err != nil {
- return nil, err
- }
-
- if session.ExpireAt.Before(time.Now()) {
- session = nil
- database.DeleteSession(dbConn, sessionId)
- return nil, fmt.Errorf("session expired")
- }
-
- user, err := database.GetUser(dbConn, session.UserID)
- if err != nil {
- return nil, err
- }
-
- return user, nil
-}
-
-func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
+func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
authHeader := req.Header.Get("Authorization")
user, userErr := getUserFromAuthHeader(context.DBConn, authHeader)
@@ -178,7 +143,8 @@ func VerifySessionContinuation(context *RequestContext, req *http.Request, resp
http.SetCookie(resp, &http.Cookie{
Name: "session",
- MaxAge: 0, // reset session cookie in case
+ Value: "",
+ MaxAge: 0,
})
context.User = nil
@@ -190,8 +156,8 @@ func VerifySessionContinuation(context *RequestContext, req *http.Request, resp
}
}
-func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
+func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
http.SetCookie(resp, &http.Cookie{
Name: "redirect",
Value: req.URL.Path,
@@ -205,17 +171,15 @@ func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.R
}
}
-func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
+func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
sessionCookie, err := req.Cookie("session")
if err != nil {
- resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
_, err = database.RefreshSession(context.DBConn, sessionCookie.Value)
if err != nil {
- resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
@@ -223,18 +187,20 @@ func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp
}
}
-func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
+func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
sessionCookie, err := req.Cookie("session")
if err == nil && sessionCookie.Value != "" {
_ = database.DeleteSession(context.DBConn, sessionCookie.Value)
}
- http.Redirect(resp, req, "/", http.StatusFound)
http.SetCookie(resp, &http.Cookie{
Name: "session",
MaxAge: 0,
+ Value: "",
})
+ http.Redirect(resp, req, "/", http.StatusFound)
+
return success(context, req, resp)
}
}
@@ -245,7 +211,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us
return nil, err
}
- userStruct, err := createUserFromResponse(userResponse)
+ userStruct, err := createUserFromOauthResponse(userResponse)
if err != nil {
return nil, err
}
@@ -258,12 +224,11 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us
return user, nil
}
-func createUserFromResponse(response *http.Response) (*database.User, error) {
- defer response.Body.Close()
- user := &database.User{
- CreatedAt: time.Now(),
- }
+func createUserFromOauthResponse(response *http.Response) (*database.User, error) {
+ user := &database.User{}
err := json.NewDecoder(response.Body).Decode(user)
+ defer response.Body.Close()
+
if err != nil {
log.Println(err)
return nil, err
@@ -283,3 +248,49 @@ func verifyState(req *http.Request, stateCookieName string, expectedState string
return true
}
+
+func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) {
+ if bearerToken == "" {
+ return nil, nil
+ }
+
+ parts := strings.Split(bearerToken, " ")
+ if len(parts) != 2 || parts[0] != "Bearer" {
+ return nil, nil
+ }
+
+ key, err := database.GetAPIKey(dbConn, parts[1])
+ if err != nil {
+ return nil, err
+ }
+ if key == nil {
+ return nil, nil
+ }
+
+ user, err := database.GetUser(dbConn, key.UserID)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
+
+func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) {
+ session, err := database.GetSession(dbConn, sessionId)
+ if err != nil {
+ return nil, err
+ }
+
+ if session.ExpireAt.Before(time.Now()) {
+ session = nil
+ database.DeleteSession(dbConn, sessionId)
+ return nil, fmt.Errorf("session expired")
+ }
+
+ user, err := database.GetUser(dbConn, session.UserID)
+ if err != nil {
+ return nil, err
+ }
+
+ return user, nil
+}
diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go
new file mode 100644
index 0000000..5e67c6d
--- /dev/null
+++ b/api/auth/auth_test.go
@@ -0,0 +1,307 @@
+package auth_test
+
+import (
+ "database/sql"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+ "golang.org/x/oauth2"
+)
+
+func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
+ return success(context, req, resp)
+ }
+}
+
+func FakedOauthServer() *httptest.Server {
+ oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/auth" {
+ code := utils.RandomId()
+
+ state := r.URL.Query().Get("state")
+ redirectPath := r.URL.Query().Get("redirect_uri")
+ redirectPath += "?code=" + code + "&state=" + state
+
+ http.Redirect(w, r, redirectPath, http.StatusFound)
+ }
+ if r.URL.Path == "/token" {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`))
+ }
+ if r.URL.Path == "/user" {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`))
+ }
+ }))
+
+ return oauthServer
+}
+
+func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ resp.Write([]byte(context.User.Username))
+ return success(context, req, resp)
+ }
+}
+
+func MockUserEndpointServer(context *types.RequestContext) *httptest.Server {
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/protected-path" {
+ auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/login" {
+ log.Println("login")
+ auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/callback" {
+ log.Println("callback")
+ auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/me" {
+ auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation)
+ }
+
+ if r.URL.Path == "/logout" {
+ auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }
+ }))
+ return testServer
+}
+
+func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) {
+ randomDb := utils.RandomId()
+
+ testDb := database.MakeConn(&randomDb)
+ database.Migrate(testDb)
+
+ context := &types.RequestContext{
+ DBConn: testDb,
+ Args: &args.Arguments{},
+ TemplateData: &(map[string]interface{}{}),
+ }
+
+ oauthServer := FakedOauthServer()
+ testServer := MockUserEndpointServer(context)
+
+ return testDb, context, oauthServer, testServer, func() {
+ oauthServer.Close()
+ testServer.Close()
+
+ testDb.Close()
+ os.Remove(randomDb)
+ }
+}
+
+func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) {
+ return &oauth2.Config{
+ ClientID: "test",
+ ClientSecret: "test",
+ Scopes: []string{"test"},
+ Endpoint: oauth2.Endpoint{
+ AuthURL: oauthServerURL + "/auth",
+ TokenURL: oauthServerURL + "/token",
+ },
+ RedirectURL: testServerURL + "/callback",
+ }, oauthServerURL + "/user"
+}
+
+func FollowAuthentication(
+ oauthServer *httptest.Server,
+ testServer *httptest.Server,
+ cookies map[string]*http.Cookie,
+ location string,
+) (map[string]*http.Cookie, string) {
+ resp := httptest.NewRecorder()
+ resp.Code = 0
+
+ for resp.Code == 0 || resp.Code == http.StatusFound {
+ req := httptest.NewRequest("GET", location, nil)
+ resp = httptest.NewRecorder()
+
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ if strings.HasPrefix(location, oauthServer.URL) {
+ oauthServer.Config.Handler.ServeHTTP(resp, req)
+ } else {
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ }
+ for _, cookie := range resp.Result().Cookies() {
+ cookies[cookie.Name] = cookie
+ }
+
+ if resp.Code == http.StatusFound {
+ location = resp.Header().Get("Location")
+ }
+ }
+
+ return cookies, location
+}
+
+func TestOauthCreatesUserWithCorrectUsername(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ user, _ := database.GetUser(db, "test")
+ if user != nil {
+ t.Errorf("expected no user, got user")
+ }
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
+
+ user, _ = database.GetUser(db, "test")
+ if user == nil {
+ t.Errorf("expected a user to be created, could not find user")
+ }
+ if user.Username != "test" {
+ t.Errorf("expected username to be test, got %s", user.Username)
+ }
+}
+
+func TestOauthRedirectsToPreviousLockedPage(t *testing.T) {
+ _, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ req := httptest.NewRequest("GET", "/protected-path", nil)
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ location := resp.Header().Get("Location")
+ if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") {
+ t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location"))
+ }
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page")
+
+ if !(strings.HasSuffix(location, "/protected-page")) {
+ t.Errorf("expected to redirect back to /protected-page after login, got %s", location)
+ }
+}
+
+func TestOauthSetsUniqueSession(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
+
+ cookiesAgain := make(map[string]*http.Cookie)
+ cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me")
+
+ sessionOne := cookies["session"].Value
+ sessionTwo := cookiesAgain["session"].Value
+ if sessionOne == sessionTwo {
+ t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo)
+ }
+
+ session, _ := database.GetSession(db, sessionOne)
+ if session.UserID != "test" {
+ t.Errorf("expected session to be associated with user test, got %s", session.UserID)
+ }
+}
+
+func TestLogoutClearsSession(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me")
+
+ req := httptest.NewRequest("GET", "/logout", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ for _, cookie := range resp.Result().Cookies() {
+ cookies[cookie.Name] = cookie
+ }
+
+ req = httptest.NewRequest("GET", "/me", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp = httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+ if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
+ t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location"))
+ }
+
+ session, _ := database.GetSession(db, cookies["session"].Value)
+ if session != nil {
+ t.Errorf("expected session to be deleted, got session")
+ }
+}
+
+func TestRefreshUpdatesExpiration(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
+
+ session, _ := database.GetSession(db, cookies["session"].Value)
+
+ req := httptest.NewRequest("GET", "/me", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+
+ updatedSession, _ := database.GetSession(db, cookies["session"].Value)
+
+ if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) {
+ t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt)
+ }
+}
+
+func TestVerifySessionEnsuresNonExpired(t *testing.T) {
+ db, context, oauthServer, testServer, cleanup := setup()
+ defer cleanup()
+
+ context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL)
+
+ cookies := make(map[string]*http.Cookie)
+ cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path")
+
+ session, _ := database.GetSession(db, cookies["session"].Value)
+ session.ExpireAt = time.Now().Add(-time.Hour)
+ database.SaveSession(db, session)
+
+ req := httptest.NewRequest("GET", "/me", nil)
+ for _, cookie := range cookies {
+ req.AddCookie(cookie)
+ }
+ resp := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") {
+ t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location"))
+ }
+}
diff --git a/api/dns.go b/api/dns.go
deleted file mode 100644
index ad41103..0000000
--- a/api/dns.go
+++ /dev/null
@@ -1,179 +0,0 @@
-package api
-
-import (
- "database/sql"
- "fmt"
- "log"
- "net/http"
- "strconv"
- "strings"
-
- "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
- "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
- "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
-)
-
-const MAX_USER_RECORDS = 65
-
-type FormError struct {
- Errors []string
-}
-
-func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool {
- ownedByUser := (user.ID == record.UserID)
- if !ownedByUser {
- return false
- }
-
- if !record.Internal {
- userOwnedDomains := []string{
- fmt.Sprintf("%s", user.Username),
- fmt.Sprintf("%s.endpoints", user.Username),
- }
-
- for _, domain := range userOwnedDomains {
- isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
- if domain == record.Name || isInSubDomain {
- return true
- }
- }
- return false
- }
-
- owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
- if err != nil {
- log.Println(err)
- return false
- }
-
- userIsOwnerOfDomain := owner == user.ID
- return ownedByUser && userIsOwnerOfDomain
-}
-
-func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
- dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
- if err != nil {
- log.Println(err)
- resp.WriteHeader(http.StatusInternalServerError)
- return failure(context, req, resp)
- }
-
- (*context.TemplateData)["DNSRecords"] = dnsRecords
- return success(context, req, resp)
- }
-}
-
-func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
- formErrors := FormError{
- Errors: []string{},
- }
-
- internal := req.FormValue("internal") == "on"
- name := req.FormValue("name")
- if internal && !strings.HasSuffix(name, ".") {
- name += "."
- }
-
- recordType := req.FormValue("type")
- recordType = strings.ToUpper(recordType)
-
- recordContent := req.FormValue("content")
- ttl := req.FormValue("ttl")
- ttlNum, err := strconv.Atoi(ttl)
- if err != nil {
- formErrors.Errors = append(formErrors.Errors, "invalid ttl")
- }
-
- dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
- if err != nil {
- log.Println(err)
- resp.WriteHeader(http.StatusInternalServerError)
- return failure(context, req, resp)
- }
- if dnsRecordCount >= MAX_USER_RECORDS {
- formErrors.Errors = append(formErrors.Errors, "max records reached")
- }
-
- dnsRecord := &database.DNSRecord{
- UserID: context.User.ID,
- Name: name,
- Type: recordType,
- Content: recordContent,
- TTL: ttlNum,
- Internal: internal,
- }
- if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) {
- formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
- }
-
- if len(formErrors.Errors) == 0 {
- if dnsRecord.Internal {
- dnsRecord.ID = utils.RandomId()
- } else {
- cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord)
- if err != nil {
- log.Println(err)
- formErrors.Errors = append(formErrors.Errors, err.Error())
- }
-
- dnsRecord.ID = cloudflareRecordId
- }
- }
-
- if len(formErrors.Errors) == 0 {
- _, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
- if err != nil {
- log.Println(err)
- formErrors.Errors = append(formErrors.Errors, "error saving record")
- }
- }
-
- if len(formErrors.Errors) == 0 {
- http.Redirect(resp, req, "/dns", http.StatusFound)
- return success(context, req, resp)
- }
-
- (*context.TemplateData)["FormError"] = &formErrors
- (*context.TemplateData)["RecordForm"] = dnsRecord
-
- resp.WriteHeader(http.StatusBadRequest)
- return failure(context, req, resp)
- }
-}
-
-func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
- recordId := req.FormValue("id")
- record, err := database.GetDNSRecord(context.DBConn, recordId)
- if err != nil {
- log.Println(err)
- resp.WriteHeader(http.StatusInternalServerError)
- return failure(context, req, resp)
- }
-
- if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) {
- resp.WriteHeader(http.StatusUnauthorized)
- return failure(context, req, resp)
- }
-
- if !record.Internal {
- err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId)
- if err != nil {
- log.Println(err)
- resp.WriteHeader(http.StatusInternalServerError)
- return failure(context, req, resp)
- }
- }
-
- err = database.DeleteDNSRecord(context.DBConn, recordId)
- if err != nil {
- resp.WriteHeader(http.StatusInternalServerError)
- return failure(context, req, resp)
- }
-
- http.Redirect(resp, req, "/dns", http.StatusFound)
- return success(context, req, resp)
- }
-}
diff --git a/api/dns/dns.go b/api/dns/dns.go
new file mode 100644
index 0000000..aa2f356
--- /dev/null
+++ b/api/dns/dns.go
@@ -0,0 +1,174 @@
+package dns
+
+import (
+ "database/sql"
+ "fmt"
+ "log"
+ "net/http"
+ "strconv"
+ "strings"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+)
+
+func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool {
+ ownedByUser := (user.ID == record.UserID)
+ if !ownedByUser {
+ return false
+ }
+
+ if !record.Internal {
+ for _, format := range ownedInternalDomainFormats {
+ domain := fmt.Sprintf(format, user.Username)
+
+ isInSubDomain := strings.HasSuffix(record.Name, "."+domain)
+ if domain == record.Name || isInSubDomain {
+ return true
+ }
+ }
+ return false
+ }
+
+ owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name)
+ if err != nil {
+ log.Println(err)
+ return false
+ }
+
+ userIsOwnerOfDomain := owner == user.ID
+ return ownedByUser && userIsOwnerOfDomain
+}
+
+func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ (*context.TemplateData)["DNSRecords"] = dnsRecords
+ return success(context, req, resp)
+ }
+}
+
+func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ formErrors := types.FormError{
+ Errors: []string{},
+ }
+
+ internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true"
+ name := req.FormValue("name")
+ if internal && !strings.HasSuffix(name, ".") {
+ name += "."
+ }
+
+ recordType := req.FormValue("type")
+ recordType = strings.ToUpper(recordType)
+
+ recordContent := req.FormValue("content")
+ ttl := req.FormValue("ttl")
+ ttlNum, err := strconv.Atoi(ttl)
+ if err != nil {
+ resp.WriteHeader(http.StatusBadRequest)
+ formErrors.Errors = append(formErrors.Errors, "invalid ttl")
+ }
+
+ dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+ if dnsRecordCount >= maxUserRecords {
+ resp.WriteHeader(http.StatusTooManyRequests)
+ formErrors.Errors = append(formErrors.Errors, "max records reached")
+ }
+
+ dnsRecord := &database.DNSRecord{
+ UserID: context.User.ID,
+ Name: name,
+ Type: recordType,
+ Content: recordContent,
+ TTL: ttlNum,
+ Internal: internal,
+ }
+
+ if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) {
+ resp.WriteHeader(http.StatusUnauthorized)
+ formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains")
+ }
+
+ if len(formErrors.Errors) == 0 {
+ if dnsRecord.Internal {
+ dnsRecord.ID = utils.RandomId()
+ } else {
+ dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ formErrors.Errors = append(formErrors.Errors, err.Error())
+ }
+ }
+ }
+
+ if len(formErrors.Errors) == 0 {
+ _, err := database.SaveDNSRecord(context.DBConn, dnsRecord)
+ if err != nil {
+ log.Println(err)
+ formErrors.Errors = append(formErrors.Errors, "error saving record")
+ }
+ }
+
+ if len(formErrors.Errors) == 0 {
+ return success(context, req, resp)
+ }
+
+ (*context.TemplateData)["FormError"] = &formErrors
+ (*context.TemplateData)["RecordForm"] = dnsRecord
+ return failure(context, req, resp)
+ }
+ }
+}
+
+func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ recordId := req.FormValue("id")
+ record, err := database.GetDNSRecord(context.DBConn, recordId)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ if !(record.UserID == context.User.ID) {
+ resp.WriteHeader(http.StatusUnauthorized)
+ return failure(context, req, resp)
+ }
+
+ if !record.Internal {
+ err = dnsAdapter.DeleteDNSRecord(recordId)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+ }
+
+ err = database.DeleteDNSRecord(context.DBConn, recordId)
+ if err != nil {
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ return success(context, req, resp)
+ }
+ }
+}
diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go
new file mode 100644
index 0000000..43dc680
--- /dev/null
+++ b/api/dns/dns_test.go
@@ -0,0 +1,442 @@
+package dns_test
+
+import (
+ "database/sql"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strconv"
+ "testing"
+ "time"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+)
+
+const MAX_USER_RECORDS = 10
+
+var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
+
+func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
+ return success(context, req, resp)
+ }
+}
+
+func setup() (*sql.DB, *types.RequestContext, func()) {
+ randomDb := utils.RandomId()
+
+ testDb := database.MakeConn(&randomDb)
+ database.Migrate(testDb)
+
+ user := &database.User{
+ ID: "test",
+ Username: "test",
+ Mail: "test@test.com",
+ DisplayName: "test",
+ }
+ database.FindOrSaveUser(testDb, user)
+
+ context := &types.RequestContext{
+ DBConn: testDb,
+ Args: &args.Arguments{},
+ TemplateData: &(map[string]interface{}{}),
+ User: user,
+ }
+
+ return testDb, context, func() {
+ testDb.Close()
+ os.Remove(randomDb)
+ }
+}
+
+type SignallingExternalDnsAdapter struct {
+ AddChannel chan *database.DNSRecord
+ RmChannel chan string
+}
+
+func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
+ id := utils.RandomId()
+ go func() { adapter.AddChannel <- record }()
+
+ return id, nil
+}
+
+func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
+ go func() { adapter.RmChannel <- id }()
+
+ return nil
+}
+
+func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ domainOwner := &database.DomainOwner{
+ UserID: context.User.ID,
+ Domain: "test.domain.",
+ }
+ domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
+
+ records, err := database.GetUserDNSRecords(db, context.User.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(records) > 0 {
+ t.Errorf("expected no records, got records")
+ }
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ validOwner := httptest.NewRequest("POST", testServer.URL, nil)
+ validOwner.Form = map[string][]string{
+ "internal": {"on"},
+ "name": {"new.test.domain."},
+ "type": {"CNAME"},
+ "ttl": {"43000"},
+ "content": {"test.domain."},
+ }
+
+ validOwnerRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
+ if validOwnerRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
+ }
+
+ validOwnerNonInternalRecorder := httptest.NewRecorder()
+ validOwner.Form["internal"] = []string{"off"}
+ testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner)
+ if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
+ t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
+ }
+
+ invalidOwnerRecorder := httptest.NewRecorder()
+ invalidOwner := validOwner
+ invalidOwner.Form["internal"] = []string{"on"}
+ invalidOwner.Form["name"] = []string{"new.invalid.domain."}
+ testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner)
+ if invalidOwnerRecorder.Code != http.StatusUnauthorized {
+ t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
+ }
+}
+
+func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ responseRecorder := httptest.NewRecorder()
+ req := httptest.NewRequest("POST", testServer.URL, nil)
+ fmts := USER_OWNED_INTERNAL_FMT_DOMAINS
+ for _, format := range fmts {
+ name := fmt.Sprintf(format, context.User.Username)
+
+ req.Form = map[string][]string{
+ "internal": {"off"},
+ "name": {name},
+ "type": {"CNAME"},
+ "ttl": {"43000"},
+ "content": {"test.domain."},
+ }
+
+ testServer.Config.Handler.ServeHTTP(responseRecorder, req)
+ if responseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", responseRecorder.Code)
+ }
+
+ namedRecords, _ := database.FindDNSRecords(db, name, "CNAME")
+ if len(namedRecords) == 0 {
+ t.Errorf("saved record not found")
+ }
+ }
+}
+
+func TestThatExternalDnsSaves(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ responseRecorder := httptest.NewRecorder()
+ externalRequest := httptest.NewRequest("POST", testServer.URL, nil)
+
+ name := "test." + context.User.Username
+ externalRequest.Form = map[string][]string{
+ "internal": {"off"},
+ "name": {name},
+ "type": {"CNAME"},
+ "ttl": {"43000"},
+ "content": {"test.domain."},
+ }
+
+ testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
+ if responseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", responseRecorder.Code)
+ }
+ select {
+ case res := <-addChannel:
+ if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
+ t.Errorf("received the wrong external record")
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Errorf("timed out in waiting for external addition")
+ }
+
+ domainOwner := &database.DomainOwner{
+ UserID: context.User.ID,
+ Domain: "test.domain.",
+ }
+ domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
+ internalRequest := externalRequest
+ internalRequest.Form["internal"] = []string{"on"}
+ internalRequest.Form["name"] = []string{"test.domain."}
+
+ testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
+ if responseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", responseRecorder.Code)
+ }
+ select {
+ case _ = <-addChannel:
+ t.Errorf("expected nothing in the add channel")
+ case <-time.After(100 * time.Millisecond):
+ }
+}
+
+func TestThatUserMustOwnRecordToRemove(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ rmChannel := make(chan string)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ RmChannel: rmChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
+ _, err := database.FindOrSaveUser(db, nonOwnerUser)
+ if err != nil {
+ t.Error(err)
+ }
+
+ record := &database.DNSRecord{
+ ID: "1",
+ Internal: false,
+ Name: "test",
+ Type: "CNAME",
+ Content: "asdf",
+ TTL: 1000,
+ UserID: nonOwnerUser.ID,
+ }
+ _, err = database.SaveDNSRecord(db, record)
+ if err != nil {
+ t.Error(err)
+ }
+
+ nonOwnerRecorder := httptest.NewRecorder()
+ nonOwner := httptest.NewRequest("POST", testServer.URL, nil)
+ nonOwner.Form = map[string][]string{
+ "id": {record.ID},
+ }
+
+ testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner)
+ if nonOwnerRecorder.Code != http.StatusUnauthorized {
+ t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
+ }
+
+ record.UserID = context.User.ID
+ record.ID = "2"
+ database.SaveDNSRecord(db, record)
+
+ owner := nonOwner
+ owner.Form["id"] = []string{"2"}
+ ownerRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(ownerRecorder, owner)
+ if ownerRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", ownerRecorder.Code)
+ }
+}
+
+func TestThatExternalDnsRemoves(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ record := &database.DNSRecord{
+ ID: "1",
+ Internal: false,
+ Name: "test",
+ Type: "CNAME",
+ Content: "asdf",
+ TTL: 1000,
+ UserID: context.User.ID,
+ }
+ database.SaveDNSRecord(db, record)
+
+ rmChannel := make(chan string)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ RmChannel: rmChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ externalResponseRecorder := httptest.NewRecorder()
+ deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)
+
+ deleteRequest.Form = map[string][]string{
+ "id": {record.ID},
+ }
+
+ testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
+ if externalResponseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
+ }
+ select {
+ case res := <-rmChannel:
+ if res != record.ID {
+ t.Errorf("received the wrong external record")
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Errorf("timed out in waiting for external addition")
+ }
+
+ record.Internal = true
+ record.Name = "test.domain."
+ database.SaveDNSRecord(db, record)
+ domainOwner := &database.DomainOwner{
+ UserID: context.User.ID,
+ Domain: "test.domain.",
+ }
+ database.SaveDomainOwner(db, domainOwner)
+
+ internalResponseRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
+ if internalResponseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
+ }
+ select {
+ case _ = <-rmChannel:
+ t.Errorf("expected nothing in the rmchannel")
+ case <-time.After(100 * time.Millisecond):
+ }
+}
+
+func TestRecordCountCannotExceed(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ record := &database.DNSRecord{
+ Internal: false,
+ Name: context.User.Username,
+ Type: "CNAME",
+ Content: "asdf",
+ TTL: 1000,
+ UserID: context.User.ID,
+ }
+
+ for i := 1; i <= MAX_USER_RECORDS; i++ {
+ record.ID = strconv.Itoa(i)
+ record.Name = record.ID + "." + record.Name
+ database.SaveDNSRecord(db, record)
+ }
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ req := httptest.NewRequest("POST", testServer.URL, nil)
+ req.Form = map[string][]string{
+ "internal": {"off"},
+ "name": {record.Name},
+ "type": {record.Type},
+ "ttl": {"43000"},
+ "content": {record.Content},
+ }
+
+ recorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(recorder, req)
+ if recorder.Code != http.StatusTooManyRequests {
+ t.Errorf("expected too many requests code return, got %d", recorder.Code)
+ }
+}
+
+func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ domainOwner := &database.DomainOwner{
+ UserID: context.User.ID,
+ Domain: "test.internal.",
+ }
+ database.SaveDomainOwner(db, domainOwner)
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ validOwner := httptest.NewRequest("POST", testServer.URL, nil)
+ validOwner.Form = map[string][]string{
+ "internal": {"on"},
+ "name": {"test.internal"},
+ "type": {"CNAME"},
+ "ttl": {"43000"},
+ "content": {"asdf.internal"},
+ }
+
+ validOwnerRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
+ if validOwnerRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
+ }
+
+ recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME")
+ recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME")
+
+ if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 {
+ t.Errorf("expected dot appended")
+ }
+}
diff --git a/api/guestbook.go b/api/guestbook.go
deleted file mode 100644
index 7b84f45..0000000
--- a/api/guestbook.go
+++ /dev/null
@@ -1,141 +0,0 @@
-package api
-
-import (
- "encoding/json"
- "fmt"
- "log"
- "net/http"
- "strings"
-
- "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
- "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
-)
-
-type HcaptchaArgs struct {
- SiteKey string
-}
-
-func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
- errors := []string{}
-
- if entry.Name == "" {
- errors = append(errors, "name is required")
- }
-
- if entry.Message == "" {
- errors = append(errors, "message is required")
- }
-
- messageLength := len(entry.Message)
- if messageLength > 500 {
- errors = append(errors, "message cannot be longer than 500 characters")
- }
-
- newLines := strings.Count(entry.Message, "\n")
- if newLines > 10 {
- errors = append(errors, "message cannot contain more than 10 new lines")
- }
-
- return errors
-}
-
-func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
- name := req.FormValue("name")
- message := req.FormValue("message")
- hCaptchaResponse := req.FormValue("h-captcha-response")
-
- formErrors := FormError{
- Errors: []string{},
- }
-
- if hCaptchaResponse == "" {
- formErrors.Errors = append(formErrors.Errors, "hCaptcha is required")
- }
-
- entry := &database.GuestbookEntry{
- ID: utils.RandomId(),
- Name: name,
- Message: message,
- }
- formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
-
- err := verifyHCaptcha(context.Args.HcaptchaSecret, hCaptchaResponse)
- if err != nil {
- log.Println(err)
-
- formErrors.Errors = append(formErrors.Errors, "hCaptcha verification failed")
- }
- if len(formErrors.Errors) > 0 {
- (*context.TemplateData)["FormError"] = formErrors
- (*context.TemplateData)["EntryForm"] = entry
- return failure(context, req, resp)
- }
-
- _, err = database.SaveGuestbookEntry(context.DBConn, entry)
- if err != nil {
- log.Println(err)
- resp.WriteHeader(http.StatusInternalServerError)
- return failure(context, req, resp)
- }
-
- return success(context, req, resp)
- }
-}
-
-func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
- entries, err := database.GetGuestbookEntries(context.DBConn)
- if err != nil {
- log.Println(err)
- resp.WriteHeader(http.StatusInternalServerError)
- return failure(context, req, resp)
- }
-
- (*context.TemplateData)["GuestbookEntries"] = entries
- return success(context, req, resp)
- }
-}
-
-func HcaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
- (*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
- SiteKey: context.Args.HcaptchaSiteKey,
- }
- log.Println(context.Args.HcaptchaSiteKey)
- return success(context, req, resp)
- }
-}
-
-func verifyHCaptcha(secret, response string) error {
- verifyURL := "https://hcaptcha.com/siteverify"
- body := strings.NewReader("secret=" + secret + "&response=" + response)
-
- req, err := http.NewRequest("POST", verifyURL, body)
- if err != nil {
- return err
- }
-
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
-
- client := &http.Client{}
- resp, err := client.Do(req)
- if err != nil {
- return err
- }
-
- jsonResponse := struct {
- Success bool `json:"success"`
- }{}
- err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
- if err != nil {
- return err
- }
-
- if !jsonResponse.Success {
- return fmt.Errorf("hcaptcha verification failed")
- }
-
- defer resp.Body.Close()
- return nil
-}
diff --git a/api/guestbook/guestbook.go b/api/guestbook/guestbook.go
new file mode 100644
index 0000000..60a7b4b
--- /dev/null
+++ b/api/guestbook/guestbook.go
@@ -0,0 +1,85 @@
+package guestbook
+
+import (
+ "log"
+ "net/http"
+ "strings"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+)
+
+func validateGuestbookEntry(entry *database.GuestbookEntry) []string {
+ errors := []string{}
+
+ if entry.Name == "" {
+ errors = append(errors, "name is required")
+ }
+
+ if entry.Message == "" {
+ errors = append(errors, "message is required")
+ }
+
+ messageLength := len(entry.Message)
+ if messageLength > 500 {
+ errors = append(errors, "message cannot be longer than 500 characters")
+ }
+
+ newLines := strings.Count(entry.Message, "\n")
+ if newLines > 10 {
+ errors = append(errors, "message cannot contain more than 10 new lines")
+ }
+
+ return errors
+}
+
+func SignGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ name := req.FormValue("name")
+ message := req.FormValue("message")
+
+ formErrors := types.FormError{
+ Errors: []string{},
+ }
+
+ entry := &database.GuestbookEntry{
+ ID: utils.RandomId(),
+ Name: name,
+ Message: message,
+ }
+ formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...)
+
+ if len(formErrors.Errors) == 0 {
+ _, err := database.SaveGuestbookEntry(context.DBConn, entry)
+ if err != nil {
+ log.Println(err)
+ formErrors.Errors = append(formErrors.Errors, "failed to save entry")
+ }
+ }
+
+ if len(formErrors.Errors) > 0 {
+ (*context.TemplateData)["FormError"] = formErrors
+ (*context.TemplateData)["EntryForm"] = entry
+ resp.WriteHeader(http.StatusBadRequest)
+
+ return failure(context, req, resp)
+ }
+
+ return success(context, req, resp)
+ }
+}
+
+func ListGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ entries, err := database.GetGuestbookEntries(context.DBConn)
+ if err != nil {
+ log.Println(err)
+ resp.WriteHeader(http.StatusInternalServerError)
+ return failure(context, req, resp)
+ }
+
+ (*context.TemplateData)["GuestbookEntries"] = entries
+ return success(context, req, resp)
+ }
+}
diff --git a/api/guestbook/guestbook_test.go b/api/guestbook/guestbook_test.go
new file mode 100644
index 0000000..9fd6c62
--- /dev/null
+++ b/api/guestbook/guestbook_test.go
@@ -0,0 +1,136 @@
+package guestbook_test
+
+import (
+ "database/sql"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+)
+
+func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
+ return success(context, req, resp)
+ }
+}
+
+func setup() (*sql.DB, *types.RequestContext, func()) {
+ randomDb := utils.RandomId()
+
+ testDb := database.MakeConn(&randomDb)
+ database.Migrate(testDb)
+
+ context := &types.RequestContext{
+ DBConn: testDb,
+ Args: &args.Arguments{},
+ TemplateData: &(map[string]interface{}{}),
+ }
+
+ return testDb, context, func() {
+ testDb.Close()
+ os.Remove(randomDb)
+ }
+}
+
+func TestValidGuestbookPutsInDatabase(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ entries, err := database.GetGuestbookEntries(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(entries) > 0 {
+ t.Errorf("expected no entries, got entries")
+ }
+
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer ts.Close()
+
+ req := httptest.NewRequest("POST", ts.URL, nil)
+ req.Form = map[string][]string{
+ "name": {"test"},
+ "message": {"test"},
+ }
+
+ w := httptest.NewRecorder()
+ ts.Config.Handler.ServeHTTP(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("expected status code 200, got %d", w.Code)
+ }
+
+ entries, err = database.GetGuestbookEntries(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(entries) != 1 {
+ t.Errorf("expected 1 entry, got %d", len(entries))
+ }
+
+ if entries[0].Name != req.FormValue("name") {
+ t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name)
+ }
+}
+
+func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ entries, err := database.GetGuestbookEntries(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(entries) > 0 {
+ t.Errorf("expected no entries, got entries")
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n"
+ invalidRequests := []struct {
+ name string
+ message string
+ }{
+ {"", "test"},
+ {"test", ""},
+ {"", ""},
+ {"test", reallyLongStringThatWouldTakeTooMuchSpace},
+ }
+
+ for _, form := range invalidRequests {
+ req := httptest.NewRequest("POST", testServer.URL, nil)
+ req.Form = map[string][]string{
+ "name": {form.name},
+ "message": {form.message},
+ }
+
+ responseRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(responseRecorder, req)
+
+ if responseRecorder.Code != http.StatusBadRequest {
+ t.Errorf("expected status code 400, got %d", responseRecorder.Code)
+ }
+ }
+
+ entries, err = database.GetGuestbookEntries(db)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(entries) != 0 {
+ t.Errorf("expected 0 entries, got %d", len(entries))
+ }
+}
diff --git a/api/hcaptcha/hcaptcha.go b/api/hcaptcha/hcaptcha.go
new file mode 100644
index 0000000..007190d
--- /dev/null
+++ b/api/hcaptcha/hcaptcha.go
@@ -0,0 +1,75 @@
+package hcaptcha
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
+)
+
+type HcaptchaArgs struct {
+ SiteKey string
+}
+
+func verifyCaptcha(secret, response string) error {
+ verifyURL := "https://hcaptcha.com/siteverify"
+ body := strings.NewReader("secret=" + secret + "&response=" + response)
+
+ req, err := http.NewRequest("POST", verifyURL, body)
+ if err != nil {
+ return err
+ }
+
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ return err
+ }
+
+ jsonResponse := struct {
+ Success bool `json:"success"`
+ }{}
+ err = json.NewDecoder(resp.Body).Decode(&jsonResponse)
+ if err != nil {
+ return err
+ }
+
+ if !jsonResponse.Success {
+ return fmt.Errorf("hcaptcha verification failed")
+ }
+
+ defer resp.Body.Close()
+ return nil
+}
+
+func CaptchaArgsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ (*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{
+ SiteKey: context.Args.HcaptchaSiteKey,
+ }
+ return success(context, req, resp)
+ }
+}
+
+func CaptchaVerificationContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ hCaptchaResponse := req.FormValue("h-captcha-response")
+ secretKey := context.Args.HcaptchaSecret
+
+ err := verifyCaptcha(secretKey, hCaptchaResponse)
+ if err != nil {
+ (*context.TemplateData)["FormError"] = types.FormError{
+ Errors: []string{"hCaptcha verification failed"},
+ }
+ resp.WriteHeader(http.StatusBadRequest)
+
+ return failure(context, req, resp)
+ }
+
+ return success(context, req, resp)
+ }
+}
diff --git a/api/api_keys.go b/api/keys/keys.go
index d636044..cef3f3c 100644
--- a/api/api_keys.go
+++ b/api/keys/keys.go
@@ -1,32 +1,33 @@
-package api
+package keys
import (
"log"
"net/http"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
const MAX_USER_API_KEYS = 5
-func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
- apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
+func ListAPIKeysContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ typesKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
- (*context.TemplateData)["APIKeys"] = apiKeys
+ (*context.TemplateData)["APIKeys"] = typesKeys
return success(context, req, resp)
}
}
-func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
- formErrors := FormError{
+func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ formErrors := types.FormError{
Errors: []string{},
}
@@ -38,7 +39,7 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h
}
if numKeys >= MAX_USER_API_KEYS {
- formErrors.Errors = append(formErrors.Errors, "max api keys reached")
+ formErrors.Errors = append(formErrors.Errors, "max types keys reached")
}
if len(formErrors.Errors) > 0 {
@@ -59,29 +60,28 @@ func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp h
}
}
-func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
- key := req.FormValue("key")
+func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
+ apiKey := req.FormValue("key")
- apiKey, err := database.GetAPIKey(context.DBConn, key)
+ key, err := database.GetAPIKey(context.DBConn, apiKey)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
- if (apiKey == nil) || (apiKey.UserID != context.User.ID) {
+ if (key == nil) || (key.UserID != context.User.ID) {
resp.WriteHeader(http.StatusUnauthorized)
return failure(context, req, resp)
}
- err = database.DeleteAPIKey(context.DBConn, key)
+ err = database.DeleteAPIKey(context.DBConn, apiKey)
if err != nil {
log.Println(err)
resp.WriteHeader(http.StatusInternalServerError)
return failure(context, req, resp)
}
- http.Redirect(resp, req, "/keys", http.StatusFound)
return success(context, req, resp)
}
}
diff --git a/api/serve.go b/api/serve.go
index f71001d..c8775d8 100644
--- a/api/serve.go
+++ b/api/serve.go
@@ -7,27 +7,20 @@ import (
"net/http"
"time"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/hcaptcha"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/keys"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/template"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
- "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
)
-type RequestContext struct {
- DBConn *sql.DB
- Args *args.Arguments
-
- Id string
- Start time.Time
-
- TemplateData *map[string]interface{}
- User *database.User
-}
-
-type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
-type ContinuationChain func(Continuation, Continuation) ContinuationChain
-
-func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, _failure Continuation) ContinuationChain {
+func LogRequestContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
context.Start = time.Now()
context.Id = utils.RandomId()
@@ -36,8 +29,8 @@ func LogRequestContinuation(context *RequestContext, req *http.Request, resp htt
}
}
-func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, _failure Continuation) ContinuationChain {
+func LogExecutionTimeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
end := time.Now()
log.Println(context.Id, "took", end.Sub(context.Start))
@@ -46,22 +39,22 @@ func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, re
}
}
-func HealthCheckContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, _failure Continuation) ContinuationChain {
+func HealthCheckContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
resp.WriteHeader(200)
resp.Write([]byte("healthy"))
return success(context, req, resp)
}
}
-func FailurePassingContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(_success Continuation, failure Continuation) ContinuationChain {
+func FailurePassingContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(_success types.Continuation, failure types.Continuation) types.ContinuationChain {
return failure(context, req, resp)
}
}
-func IdContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, _failure Continuation) ContinuationChain {
+func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
return success(context, req, resp)
}
}
@@ -80,89 +73,90 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server {
fileServer := http.FileServer(http.Dir(argv.StaticPath))
mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600)))
- makeRequestContext := func() *RequestContext {
- return &RequestContext{
- DBConn: dbConn,
- Args: argv,
+ cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{
+ APIToken: argv.CloudflareToken,
+ ZoneId: argv.CloudflareZone,
+ }
+ makeRequestContext := func() *types.RequestContext {
+ return &types.RequestContext{
+ DBConn: dbConn,
+ Args: argv,
TemplateData: &map[string]interface{}{},
}
}
mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
- mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) {
+ mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
- })
-
- mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) {
- requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
+ const MAX_USER_RECORDS = 100
+ var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.ListAPIKeysContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(ListAPIKeysContinuation, ListAPIKeysContinuation)(TemplateContinuation("api_keys.html", true), TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.CreateAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(template.TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaVerificationContinuation, hcaptcha.CaptchaVerificationContinuation)(guestbook.SignGuestbookContinuation, FailurePassingContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(template.TemplateContinuation("guestbook.html", true), template.TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) {
requestContext := makeRequestContext()
name := r.PathValue("name")
- LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
+ LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation)
})
return &http.Server{
diff --git a/api/template.go b/api/template/template.go
index eeaeb51..2875649 100644
--- a/api/template.go
+++ b/api/template/template.go
@@ -1,4 +1,4 @@
-package api
+package template
import (
"bytes"
@@ -7,9 +7,11 @@ import (
"log"
"net/http"
"os"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
)
-func renderTemplate(context *RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) {
+func renderTemplate(context *types.RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) {
templatePath := context.Args.TemplatePath
basePath := templatePath + "/base_empty.html"
if showBaseHtml {
@@ -41,9 +43,9 @@ func renderTemplate(context *RequestContext, templateName string, showBaseHtml b
return buffer, nil
}
-func TemplateContinuation(path string, showBase bool) Continuation {
- return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain {
- return func(success Continuation, failure Continuation) ContinuationChain {
+func TemplateContinuation(path string, showBase bool) types.Continuation {
+ return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, failure types.Continuation) types.ContinuationChain {
html, err := renderTemplate(context, path, true)
if errors.Is(err, os.ErrNotExist) {
resp.WriteHeader(404)
@@ -66,7 +68,6 @@ func TemplateContinuation(path string, showBase bool) Continuation {
return failure(context, req, resp)
}
- resp.WriteHeader(200)
resp.Header().Set("Content-Type", "text/html")
resp.Write(html.Bytes())
return success(context, req, resp)
diff --git a/api/types/types.go b/api/types/types.go
new file mode 100644
index 0000000..bbc25ea
--- /dev/null
+++ b/api/types/types.go
@@ -0,0 +1,28 @@
+package types
+
+import (
+ "database/sql"
+ "net/http"
+ "time"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+)
+
+type RequestContext struct {
+ DBConn *sql.DB
+ Args *args.Arguments
+
+ Id string
+ Start time.Time
+
+ TemplateData *map[string]interface{}
+ User *database.User
+}
+
+type FormError struct {
+ Errors []string
+}
+
+type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain
+type ContinuationChain func(Continuation, Continuation) ContinuationChain
diff --git a/args/args.go b/args/args.go
index 40dd1af..f71e8e3 100644
--- a/args/args.go
+++ b/args/args.go
@@ -22,9 +22,8 @@ type Arguments struct {
OauthConfig *oauth2.Config
OauthUserInfoURI string
- Dns bool
- DnsRecursion []string
- DnsPort int
+ Dns bool
+ DnsPort int
CloudflareToken string
CloudflareZone string
@@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) {
server := flag.Bool("server", false, "Run the server")
dns := flag.Bool("dns", false, "Run DNS resolver")
- dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers")
dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver")
flag.Parse()
@@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) {
Migrate: *migrate,
Scheduler: *scheduler,
Dns: *dns,
- DnsRecursion: strings.Split(*dnsRecursion, ","),
DnsPort: *dnsPort,
OauthConfig: oauthConfig,
diff --git a/database/dns.go b/database/dns.go
index fc01347..7851ab4 100644
--- a/database/dns.go
+++ b/database/dns.go
@@ -9,6 +9,12 @@ import (
"time"
)
+type DomainOwner struct {
+ UserID string `json:"user_id"`
+ Domain string `json:"domain"`
+ CreatedAt time.Time `json:"created_at"`
+}
+
type DNSRecord struct {
ID string `json:"id"`
UserID string `json:"user_id"`
@@ -57,7 +63,10 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) {
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
log.Println("saving dns record", record.ID)
- record.CreatedAt = time.Now()
+ if (record.CreatedAt == time.Time{}) {
+ record.CreatedAt = time.Now()
+ }
+
_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
if err != nil {
@@ -137,3 +146,15 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err
return records, nil
}
+
+func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) {
+ log.Println("saving domain owner", domainOwner.Domain)
+
+ domainOwner.CreatedAt = time.Now()
+ _, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt)
+
+ if err != nil {
+ return nil, err
+ }
+ return domainOwner, nil
+}
diff --git a/database/users.go b/database/users.go
index 5cebb8f..6f9456e 100644
--- a/database/users.go
+++ b/database/users.go
@@ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error {
return nil
}
+func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) {
+ log.Println("saving session", session.ID)
+
+ _, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
+
+ return session, nil
+}
+
func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) {
newExpireAt := time.Now().Add(ExpiryDuration)
diff --git a/dns/server.go b/hcdns/server.go
index f5365e8..ce7894b 100644
--- a/dns/server.go
+++ b/hcdns/server.go
@@ -1,4 +1,4 @@
-package dns
+package hcdns
import (
"database/sql"
@@ -9,27 +9,28 @@ import (
"log"
)
-const MAX_RECURSION = 10
-
-func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
- if maxDepth == 0 {
- return nil, fmt.Errorf("too much recursion")
- }
+const MAX_RECURSION = 15
+func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME")
if err != nil {
return nil, err
}
- answers := []dns.RR{}
+ var answers []dns.RR
for _, record := range internalCnames {
cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content))
if err != nil {
+ log.Println(err)
return nil, err
}
answers = append(answers, cname)
- cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1)
+ cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1)
+ if err != nil {
+ log.Println(err)
+ return nil, err
+ }
answers = append(answers, cnameRecursive...)
}
@@ -43,36 +44,26 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp
return nil, err
}
for _, record := range typeDnsRecords {
- answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content))
+ answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content))
if err != nil {
return nil, err
}
answers = append(answers, answer)
}
- if len(answers) > 0 {
- // base case; we found the answer
- return answers, nil
- }
-
- message := new(dns.Msg)
- message.SetQuestion(dns.Fqdn(domain), qtype)
- message.RecursionDesired = true
+ return answers, nil
+}
- client := new(dns.Client)
+func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) {
+ if maxDepth == 0 {
+ return nil, fmt.Errorf("too much recursion")
+ }
- i := 0
- in, _, err := client.Exchange(message, dnsResolvers[i])
- for err != nil {
- i += 1
- if i == len(dnsResolvers) {
- log.Println(err)
- return nil, err
- }
- in, _, err = client.Exchange(message, dnsResolvers[i])
+ answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth)
+ if err != nil {
+ return nil, err
}
- answers = append(answers, in.Answer...)
return answers, nil
}
@@ -87,22 +78,27 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
msg.Authoritative = true
for _, question := range r.Question {
- answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION)
+ answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION)
if err != nil {
fmt.Println(err)
- continue
+ msg.SetRcode(r, dns.RcodeServerFailure)
+ w.WriteMsg(msg)
+ return
}
msg.Answer = append(msg.Answer, answers...)
}
+ if len(msg.Answer) == 0 {
+ msg.SetRcode(r, dns.RcodeNameError)
+ }
+
log.Println(msg.Answer)
w.WriteMsg(msg)
}
func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server {
handler := &DnsHandler{
- DnsResolvers: argv.DnsRecursion,
- DbConn: dbConn,
+ DbConn: dbConn,
}
addr := fmt.Sprintf(":%d", argv.DnsPort)
diff --git a/hcdns/server_test.go b/hcdns/server_test.go
new file mode 100644
index 0000000..177def4
--- /dev/null
+++ b/hcdns/server_test.go
@@ -0,0 +1,254 @@
+package hcdns_test
+
+import (
+ "database/sql"
+ "fmt"
+ "math/rand"
+ "os"
+ "sync"
+ "testing"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+ "github.com/miekg/dns"
+)
+
+func randomPort() int {
+ return rand.Intn(3000) + 5192
+}
+
+func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
+ randomDb := utils.RandomId()
+ dnsPort := randomPort()
+
+ testDb := database.MakeConn(&randomDb)
+ database.Migrate(testDb)
+ testUser := &database.User{
+ ID: "test",
+ }
+ database.FindOrSaveUser(testDb, testUser)
+
+ waitLock := &sync.Mutex{}
+ server := hcdns.MakeServer(&args.Arguments{
+ DnsPort: dnsPort,
+ }, testDb)
+ server.NotifyStartedFunc = func() {
+ waitLock.Unlock()
+ }
+ waitLock.Lock()
+
+ go func() {
+ server.ListenAndServe()
+ }()
+ waitLock.Lock()
+
+ address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
+ return testDb, server, &address, waitLock, func() {
+ server.Shutdown()
+
+ testDb.Close()
+ os.Remove(randomDb)
+ }
+}
+
+func TestWhenCNAMEIsResolved(t *testing.T) {
+ t.Log("TestWhenCNAMEIsResolved")
+
+ testDb, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ records := []*database.DNSRecord{
+ {
+ ID: "0",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "next.internal.example.com.",
+ TTL: 300,
+ Internal: true,
+ }, {
+ ID: "1",
+ UserID: "test",
+ Name: "next.internal.example.com.",
+ Type: "CNAME",
+ Content: "res.example.com.",
+ TTL: 300,
+ Internal: true,
+ },
+ {
+ ID: "2",
+ UserID: "test",
+ Name: "res.example.com.",
+ Type: "A",
+ Content: "1.2.3.2",
+ TTL: 300,
+ Internal: true,
+ },
+ }
+
+ for _, record := range records {
+ database.SaveDNSRecord(testDb, record)
+ }
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("cname.internal.example.com.")
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 3 {
+ t.Fatalf("expected 3 answers, got %d", len(in.Answer))
+ }
+
+ for i, record := range records {
+ if in.Answer[i].Header().Name != record.Name {
+ t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name)
+ }
+
+ if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] {
+ t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype)
+ }
+
+ if int(in.Answer[i].Header().Ttl) != record.TTL {
+ t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl)
+ }
+
+ if !in.Authoritative {
+ t.Fatalf("expected authoritative response")
+ }
+ }
+
+ if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" {
+ t.Fatalf("expected final record to be the A record with correct IP")
+ }
+}
+
+func TestWhenNoRecordNxDomain(t *testing.T) {
+ t.Log("TestWhenNoRecordNxDomain")
+
+ _, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn("nonexistant.example.com.")
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 0 {
+ t.Fatalf("expected 0 answers, got %d", len(in.Answer))
+ }
+
+ if in.Rcode != dns.RcodeNameError {
+ t.Fatalf("expected NXDOMAIN, got %d", in.Rcode)
+ }
+}
+
+func TestWhenUnresolvingCNAME(t *testing.T) {
+ t.Log("TestWhenUnresolvingCNAME")
+
+ testDb, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ cname := &database.DNSRecord{
+ ID: "1",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "nonexistant.example.com.",
+ TTL: 300,
+ Internal: true,
+ }
+ database.SaveDNSRecord(testDb, cname)
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn(cname.Name)
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) != 1 {
+ t.Fatalf("expected 1 answer, got %d", len(in.Answer))
+ }
+
+ if !in.Authoritative {
+ t.Fatalf("expected authoritative response")
+ }
+
+ if in.Answer[0].Header().Name != cname.Name {
+ t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name)
+ }
+
+ if in.Answer[0].Header().Rrtype != dns.TypeCNAME {
+ t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype)
+ }
+
+ if in.Answer[0].(*dns.CNAME).Target != cname.Content {
+ t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target)
+ }
+
+ if in.Rcode == dns.RcodeNameError {
+ t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode)
+ }
+}
+
+func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
+ t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
+
+ testDb, _, addr, lock, cleanup := setup()
+ defer cleanup()
+ defer lock.Unlock()
+
+ cname := &database.DNSRecord{
+ ID: "1",
+ UserID: "test",
+ Name: "cname.internal.example.com.",
+ Type: "CNAME",
+ Content: "cname.internal.example.com.",
+ TTL: 300,
+ Internal: true,
+ }
+ database.SaveDNSRecord(testDb, cname)
+
+ qtype := dns.TypeA
+ domain := dns.Fqdn(cname.Name)
+ client := &dns.Client{}
+ message := &dns.Msg{}
+ message.SetQuestion(domain, qtype)
+
+ in, _, err := client.Exchange(message, *addr)
+
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(in.Answer) > 0 {
+ t.Fatalf("expected 0 answers, got %d", len(in.Answer))
+ }
+
+ if in.Rcode != dns.RcodeServerFailure {
+ t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
+ }
+}
diff --git a/main.go b/main.go
index 2991821..e0f3e55 100644
--- a/main.go
+++ b/main.go
@@ -6,7 +6,7 @@ import (
"git.hatecomputers.club/hatecomputers/hatecomputers.club/api"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
- "git.hatecomputers.club/hatecomputers/hatecomputers.club/dns"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns"
"git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler"
"github.com/joho/godotenv"
)
@@ -52,7 +52,7 @@ func main() {
}
if argv.Dns {
- server := dns.MakeServer(argv, dbConn)
+ server := hcdns.MakeServer(argv, dbConn)
log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort)
go func() {
err = server.ListenAndServe()
diff --git a/static/css/styles.css b/static/css/styles.css
index 7486016..ba58018 100644
--- a/static/css/styles.css
+++ b/static/css/styles.css
@@ -15,6 +15,22 @@
padding: 0;
color: var(--text-color);
font-family: "ComicSans", sans-serif;
+
+ cursor: url("/static/img/cursor-1.png"), auto;
+ -webkit-animation: cursor 400ms infinite;
+ animation: cursor 400ms infinite;
+}
+
+@-webkit-keyframes cursor {
+ 0% {cursor: url("/static/img/cursor-2.png"), auto;}
+ 50% {cursor: url("/static/img/cursor-1.png"), auto;}
+ 100% {cursor: url("/static/img/cursor-2.png"), auto;}
+}
+
+@keyframes cursor {
+ 0% {cursor: url("/static/img/cursor-2.png"), auto;}
+ 50% {cursor: url("/static/img/cursor-1.png"), auto;}
+ 100% {cursor: url("/static/img/cursor-2.png"), auto;}
}
body {
diff --git a/static/img/cursor-1.png b/static/img/cursor-1.png
new file mode 100644
index 0000000..68fbe5c
--- /dev/null
+++ b/static/img/cursor-1.png
Binary files differ
diff --git a/static/img/cursor-2.png b/static/img/cursor-2.png
new file mode 100644
index 0000000..9851648
--- /dev/null
+++ b/static/img/cursor-2.png
Binary files differ