summaryrefslogtreecommitdiff
path: root/api/dns/dns_test.go
diff options
context:
space:
mode:
authorsimponic <simponic@hatecomputers.club>2024-04-06 15:43:18 -0400
committersimponic <simponic@hatecomputers.club>2024-04-06 15:43:18 -0400
commit83cc6267fd5ce2f61200314424c5f400f65ff2ba (patch)
treeeafb35310236a15572cbb6e16ff8d6f181bfe240 /api/dns/dns_test.go
parent569d2788ebfb90774faf361f62bfe7968e091465 (diff)
parentcad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 (diff)
downloadhatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.tar.gz
hatecomputers.club-83cc6267fd5ce2f61200314424c5f400f65ff2ba.zip
Merge pull request 'testing | dont be recursive for external domains | finalize oauth' (#5) from dont-be-authoritative into main
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/5
Diffstat (limited to 'api/dns/dns_test.go')
-rw-r--r--api/dns/dns_test.go442
1 files changed, 442 insertions, 0 deletions
diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go
new file mode 100644
index 0000000..43dc680
--- /dev/null
+++ b/api/dns/dns_test.go
@@ -0,0 +1,442 @@
+package dns_test
+
+import (
+ "database/sql"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strconv"
+ "testing"
+ "time"
+
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/args"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/database"
+ "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils"
+)
+
+const MAX_USER_RECORDS = 10
+
+var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"}
+
+func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain {
+ return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain {
+ return success(context, req, resp)
+ }
+}
+
+func setup() (*sql.DB, *types.RequestContext, func()) {
+ randomDb := utils.RandomId()
+
+ testDb := database.MakeConn(&randomDb)
+ database.Migrate(testDb)
+
+ user := &database.User{
+ ID: "test",
+ Username: "test",
+ Mail: "test@test.com",
+ DisplayName: "test",
+ }
+ database.FindOrSaveUser(testDb, user)
+
+ context := &types.RequestContext{
+ DBConn: testDb,
+ Args: &args.Arguments{},
+ TemplateData: &(map[string]interface{}{}),
+ User: user,
+ }
+
+ return testDb, context, func() {
+ testDb.Close()
+ os.Remove(randomDb)
+ }
+}
+
+type SignallingExternalDnsAdapter struct {
+ AddChannel chan *database.DNSRecord
+ RmChannel chan string
+}
+
+func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) {
+ id := utils.RandomId()
+ go func() { adapter.AddChannel <- record }()
+
+ return id, nil
+}
+
+func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error {
+ go func() { adapter.RmChannel <- id }()
+
+ return nil
+}
+
+func TestThatOwnerCanPutRecordInDomain(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ domainOwner := &database.DomainOwner{
+ UserID: context.User.ID,
+ Domain: "test.domain.",
+ }
+ domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
+
+ records, err := database.GetUserDNSRecords(db, context.User.ID)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(records) > 0 {
+ t.Errorf("expected no records, got records")
+ }
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ validOwner := httptest.NewRequest("POST", testServer.URL, nil)
+ validOwner.Form = map[string][]string{
+ "internal": {"on"},
+ "name": {"new.test.domain."},
+ "type": {"CNAME"},
+ "ttl": {"43000"},
+ "content": {"test.domain."},
+ }
+
+ validOwnerRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
+ if validOwnerRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
+ }
+
+ validOwnerNonInternalRecorder := httptest.NewRecorder()
+ validOwner.Form["internal"] = []string{"off"}
+ testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner)
+ if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized {
+ t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code)
+ }
+
+ invalidOwnerRecorder := httptest.NewRecorder()
+ invalidOwner := validOwner
+ invalidOwner.Form["internal"] = []string{"on"}
+ invalidOwner.Form["name"] = []string{"new.invalid.domain."}
+ testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner)
+ if invalidOwnerRecorder.Code != http.StatusUnauthorized {
+ t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code)
+ }
+}
+
+func TestThatUserCanAddToPublicEndpoints(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ responseRecorder := httptest.NewRecorder()
+ req := httptest.NewRequest("POST", testServer.URL, nil)
+ fmts := USER_OWNED_INTERNAL_FMT_DOMAINS
+ for _, format := range fmts {
+ name := fmt.Sprintf(format, context.User.Username)
+
+ req.Form = map[string][]string{
+ "internal": {"off"},
+ "name": {name},
+ "type": {"CNAME"},
+ "ttl": {"43000"},
+ "content": {"test.domain."},
+ }
+
+ testServer.Config.Handler.ServeHTTP(responseRecorder, req)
+ if responseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", responseRecorder.Code)
+ }
+
+ namedRecords, _ := database.FindDNSRecords(db, name, "CNAME")
+ if len(namedRecords) == 0 {
+ t.Errorf("saved record not found")
+ }
+ }
+}
+
+func TestThatExternalDnsSaves(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ responseRecorder := httptest.NewRecorder()
+ externalRequest := httptest.NewRequest("POST", testServer.URL, nil)
+
+ name := "test." + context.User.Username
+ externalRequest.Form = map[string][]string{
+ "internal": {"off"},
+ "name": {name},
+ "type": {"CNAME"},
+ "ttl": {"43000"},
+ "content": {"test.domain."},
+ }
+
+ testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
+ if responseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", responseRecorder.Code)
+ }
+ select {
+ case res := <-addChannel:
+ if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." {
+ t.Errorf("received the wrong external record")
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Errorf("timed out in waiting for external addition")
+ }
+
+ domainOwner := &database.DomainOwner{
+ UserID: context.User.ID,
+ Domain: "test.domain.",
+ }
+ domainOwner, _ = database.SaveDomainOwner(db, domainOwner)
+ internalRequest := externalRequest
+ internalRequest.Form["internal"] = []string{"on"}
+ internalRequest.Form["name"] = []string{"test.domain."}
+
+ testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest)
+ if responseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", responseRecorder.Code)
+ }
+ select {
+ case _ = <-addChannel:
+ t.Errorf("expected nothing in the add channel")
+ case <-time.After(100 * time.Millisecond):
+ }
+}
+
+func TestThatUserMustOwnRecordToRemove(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ rmChannel := make(chan string)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ RmChannel: rmChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"}
+ _, err := database.FindOrSaveUser(db, nonOwnerUser)
+ if err != nil {
+ t.Error(err)
+ }
+
+ record := &database.DNSRecord{
+ ID: "1",
+ Internal: false,
+ Name: "test",
+ Type: "CNAME",
+ Content: "asdf",
+ TTL: 1000,
+ UserID: nonOwnerUser.ID,
+ }
+ _, err = database.SaveDNSRecord(db, record)
+ if err != nil {
+ t.Error(err)
+ }
+
+ nonOwnerRecorder := httptest.NewRecorder()
+ nonOwner := httptest.NewRequest("POST", testServer.URL, nil)
+ nonOwner.Form = map[string][]string{
+ "id": {record.ID},
+ }
+
+ testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner)
+ if nonOwnerRecorder.Code != http.StatusUnauthorized {
+ t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code)
+ }
+
+ record.UserID = context.User.ID
+ record.ID = "2"
+ database.SaveDNSRecord(db, record)
+
+ owner := nonOwner
+ owner.Form["id"] = []string{"2"}
+ ownerRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(ownerRecorder, owner)
+ if ownerRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", ownerRecorder.Code)
+ }
+}
+
+func TestThatExternalDnsRemoves(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ record := &database.DNSRecord{
+ ID: "1",
+ Internal: false,
+ Name: "test",
+ Type: "CNAME",
+ Content: "asdf",
+ TTL: 1000,
+ UserID: context.User.ID,
+ }
+ database.SaveDNSRecord(db, record)
+
+ rmChannel := make(chan string)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ RmChannel: rmChannel,
+ }
+
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ externalResponseRecorder := httptest.NewRecorder()
+ deleteRequest := httptest.NewRequest("POST", testServer.URL, nil)
+
+ deleteRequest.Form = map[string][]string{
+ "id": {record.ID},
+ }
+
+ testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest)
+ if externalResponseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", externalResponseRecorder.Code)
+ }
+ select {
+ case res := <-rmChannel:
+ if res != record.ID {
+ t.Errorf("received the wrong external record")
+ }
+ case <-time.After(100 * time.Millisecond):
+ t.Errorf("timed out in waiting for external addition")
+ }
+
+ record.Internal = true
+ record.Name = "test.domain."
+ database.SaveDNSRecord(db, record)
+ domainOwner := &database.DomainOwner{
+ UserID: context.User.ID,
+ Domain: "test.domain.",
+ }
+ database.SaveDomainOwner(db, domainOwner)
+
+ internalResponseRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest)
+ if internalResponseRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", internalResponseRecorder.Code)
+ }
+ select {
+ case _ = <-rmChannel:
+ t.Errorf("expected nothing in the rmchannel")
+ case <-time.After(100 * time.Millisecond):
+ }
+}
+
+func TestRecordCountCannotExceed(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ record := &database.DNSRecord{
+ Internal: false,
+ Name: context.User.Username,
+ Type: "CNAME",
+ Content: "asdf",
+ TTL: 1000,
+ UserID: context.User.ID,
+ }
+
+ for i := 1; i <= MAX_USER_RECORDS; i++ {
+ record.ID = strconv.Itoa(i)
+ record.Name = record.ID + "." + record.Name
+ database.SaveDNSRecord(db, record)
+ }
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ req := httptest.NewRequest("POST", testServer.URL, nil)
+ req.Form = map[string][]string{
+ "internal": {"off"},
+ "name": {record.Name},
+ "type": {record.Type},
+ "ttl": {"43000"},
+ "content": {record.Content},
+ }
+
+ recorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(recorder, req)
+ if recorder.Code != http.StatusTooManyRequests {
+ t.Errorf("expected too many requests code return, got %d", recorder.Code)
+ }
+}
+
+func TestInternalRecordAppendsTopLevelDot(t *testing.T) {
+ db, context, cleanup := setup()
+ defer cleanup()
+
+ domainOwner := &database.DomainOwner{
+ UserID: context.User.ID,
+ Domain: "test.internal.",
+ }
+ database.SaveDomainOwner(db, domainOwner)
+
+ addChannel := make(chan *database.DNSRecord)
+ signallingDnsAdapter := &SignallingExternalDnsAdapter{
+ AddChannel: addChannel,
+ }
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation)
+ }))
+ defer testServer.Close()
+
+ validOwner := httptest.NewRequest("POST", testServer.URL, nil)
+ validOwner.Form = map[string][]string{
+ "internal": {"on"},
+ "name": {"test.internal"},
+ "type": {"CNAME"},
+ "ttl": {"43000"},
+ "content": {"asdf.internal"},
+ }
+
+ validOwnerRecorder := httptest.NewRecorder()
+ testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner)
+ if validOwnerRecorder.Code != http.StatusOK {
+ t.Errorf("expected valid return, got %d", validOwnerRecorder.Code)
+ }
+
+ recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME")
+ recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME")
+
+ if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 {
+ t.Errorf("expected dot appended")
+ }
+}