From d7843d18d01a0b74319b66e5fbfbc680faa1951d Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 14:49:18 -0600 Subject: stop being authoritative for stuff not in internal dns --- Dockerfile | 2 +- args/args.go | 7 ++----- dns/server.go | 53 +++++++++++++++++++++++++---------------------------- 3 files changed, 28 insertions(+), 34 deletions(-) diff --git a/Dockerfile b/Dockerfile index a46f6c4..591423f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,4 +11,4 @@ RUN go build -o /app/hatecomputers EXPOSE 8080 -CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053", "--dns-recursion", "1.1.1.1:53,1.0.0.1:53"] +CMD ["/app/hatecomputers", "--server", "--migrate", "--port", "8080", "--template-path", "/app/templates", "--database-path", "/app/db/hatecomputers.db", "--static-path", "/app/static", "--scheduler", "--dns", "--dns-port", "8053"] diff --git a/args/args.go b/args/args.go index 40dd1af..f71e8e3 100644 --- a/args/args.go +++ b/args/args.go @@ -22,9 +22,8 @@ type Arguments struct { OauthConfig *oauth2.Config OauthUserInfoURI string - Dns bool - DnsRecursion []string - DnsPort int + Dns bool + DnsPort int CloudflareToken string CloudflareZone string @@ -45,7 +44,6 @@ func GetArgs() (*Arguments, error) { server := flag.Bool("server", false, "Run the server") dns := flag.Bool("dns", false, "Run DNS resolver") - dnsRecursion := flag.String("dns-recursion", "1.1.1.1:53,1.0.0.1:53", "Comma separated list of DNS resolvers") dnsPort := flag.Int("dns-port", 8053, "Port to listen on for DNS resolver") flag.Parse() @@ -104,7 +102,6 @@ func GetArgs() (*Arguments, error) { Migrate: *migrate, Scheduler: *scheduler, Dns: *dns, - DnsRecursion: strings.Split(*dnsRecursion, ","), DnsPort: *dnsPort, OauthConfig: oauthConfig, diff --git a/dns/server.go b/dns/server.go index f5365e8..9b3e5e9 100644 --- a/dns/server.go +++ b/dns/server.go @@ -11,17 +11,13 @@ import ( const MAX_RECURSION = 10 -func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { - if maxDepth == 0 { - return nil, fmt.Errorf("too much recursion") - } - +func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") if err != nil { return nil, err } - answers := []dns.RR{} + var answers []dns.RR for _, record := range internalCnames { cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) if err != nil { @@ -29,7 +25,10 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp } answers = append(answers, cname) - cnameRecursive, _ := resolveRecursive(dbConn, dnsResolvers, record.Content, qtype, maxDepth-1) + cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) + if err != nil { + return nil, err + } answers = append(answers, cnameRecursive...) } @@ -43,37 +42,31 @@ func resolveRecursive(dbConn *sql.DB, dnsResolvers []string, domain string, qtyp return nil, err } for _, record := range typeDnsRecords { - answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, record.Type, record.Content)) + answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) if err != nil { return nil, err } answers = append(answers, answer) } - if len(answers) > 0 { - // base case; we found the answer - return answers, nil - } + return answers, nil +} - message := new(dns.Msg) - message.SetQuestion(dns.Fqdn(domain), qtype) - message.RecursionDesired = true +func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + if maxDepth == 0 { + return nil, fmt.Errorf("too much recursion") + } - client := new(dns.Client) + answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) + if err != nil { + return nil, err + } - i := 0 - in, _, err := client.Exchange(message, dnsResolvers[i]) - for err != nil { - i += 1 - if i == len(dnsResolvers) { - log.Println(err) - return nil, err - } - in, _, err = client.Exchange(message, dnsResolvers[i]) + if len(answers) > 0 { + return answers, nil } - answers = append(answers, in.Answer...) - return answers, nil + return nil, fmt.Errorf("no records found for %s", domain) } type DnsHandler struct { @@ -87,7 +80,7 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { msg.Authoritative = true for _, question := range r.Question { - answers, err := resolveRecursive(h.DbConn, h.DnsResolvers, question.Name, question.Qtype, MAX_RECURSION) + answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) if err != nil { fmt.Println(err) continue @@ -95,6 +88,10 @@ func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { msg.Answer = append(msg.Answer, answers...) } + if len(msg.Answer) == 0 { + msg.SetRcode(r, dns.RcodeNameError) + } + log.Println(msg.Answer) w.WriteMsg(msg) } -- cgit v1.2.3-70-g09d2 From 657be669482462ada3b88672ff7497b652848176 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 14:53:50 -0600 Subject: defer body close after encoding json --- api/auth.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/auth.go b/api/auth.go index 0294edd..14e6924 100644 --- a/api/auth.go +++ b/api/auth.go @@ -259,11 +259,13 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us } func createUserFromResponse(response *http.Response) (*database.User, error) { - defer response.Body.Close() user := &database.User{ CreatedAt: time.Now(), } + err := json.NewDecoder(response.Body).Decode(user) + defer response.Body.Close() + if err != nil { log.Println(err) return nil, err -- cgit v1.2.3-70-g09d2 From bcdcc508ef4a0ae646937c91d0994a90bde719e1 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 16:26:39 -0600 Subject: add integration tests for dns server --- dns/server.go | 113 -------------------------- hcdns/server.go | 112 +++++++++++++++++++++++++ main.go | 4 +- test/dns_test.go | 244 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 358 insertions(+), 115 deletions(-) delete mode 100644 dns/server.go create mode 100644 hcdns/server.go create mode 100644 test/dns_test.go diff --git a/dns/server.go b/dns/server.go deleted file mode 100644 index 9b3e5e9..0000000 --- a/dns/server.go +++ /dev/null @@ -1,113 +0,0 @@ -package dns - -import ( - "database/sql" - "fmt" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "github.com/miekg/dns" - "log" -) - -const MAX_RECURSION = 10 - -func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { - internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") - if err != nil { - return nil, err - } - - var answers []dns.RR - for _, record := range internalCnames { - cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) - if err != nil { - return nil, err - } - answers = append(answers, cname) - - cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) - if err != nil { - return nil, err - } - answers = append(answers, cnameRecursive...) - } - - qtypeName := dns.TypeToString[qtype] - if qtypeName == "" { - return nil, fmt.Errorf("invalid query type %d", qtype) - } - - typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) - if err != nil { - return nil, err - } - for _, record := range typeDnsRecords { - answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) - if err != nil { - return nil, err - } - answers = append(answers, answer) - } - - return answers, nil -} - -func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { - if maxDepth == 0 { - return nil, fmt.Errorf("too much recursion") - } - - answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) - if err != nil { - return nil, err - } - - if len(answers) > 0 { - return answers, nil - } - - return nil, fmt.Errorf("no records found for %s", domain) -} - -type DnsHandler struct { - DnsResolvers []string - DbConn *sql.DB -} - -func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - msg := new(dns.Msg) - msg.SetReply(r) - msg.Authoritative = true - - for _, question := range r.Question { - answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) - if err != nil { - fmt.Println(err) - continue - } - msg.Answer = append(msg.Answer, answers...) - } - - if len(msg.Answer) == 0 { - msg.SetRcode(r, dns.RcodeNameError) - } - - log.Println(msg.Answer) - w.WriteMsg(msg) -} - -func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { - handler := &DnsHandler{ - DnsResolvers: argv.DnsRecursion, - DbConn: dbConn, - } - addr := fmt.Sprintf(":%d", argv.DnsPort) - - return &dns.Server{ - Addr: addr, - Net: "udp", - Handler: handler, - UDPSize: 65535, - ReusePort: true, - } -} diff --git a/hcdns/server.go b/hcdns/server.go new file mode 100644 index 0000000..ce7894b --- /dev/null +++ b/hcdns/server.go @@ -0,0 +1,112 @@ +package hcdns + +import ( + "database/sql" + "fmt" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "github.com/miekg/dns" + "log" +) + +const MAX_RECURSION = 15 + +func resolveInternalCNAMEs(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + internalCnames, err := database.FindDNSRecords(dbConn, domain, "CNAME") + if err != nil { + return nil, err + } + + var answers []dns.RR + for _, record := range internalCnames { + cname, err := dns.NewRR(fmt.Sprintf("%s %d IN CNAME %s", record.Name, record.TTL, record.Content)) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cname) + + cnameRecursive, err := resolveDNS(dbConn, record.Content, qtype, maxDepth-1) + if err != nil { + log.Println(err) + return nil, err + } + answers = append(answers, cnameRecursive...) + } + + qtypeName := dns.TypeToString[qtype] + if qtypeName == "" { + return nil, fmt.Errorf("invalid query type %d", qtype) + } + + typeDnsRecords, err := database.FindDNSRecords(dbConn, domain, qtypeName) + if err != nil { + return nil, err + } + for _, record := range typeDnsRecords { + answer, err := dns.NewRR(fmt.Sprintf("%s %d IN %s %s", record.Name, record.TTL, qtypeName, record.Content)) + if err != nil { + return nil, err + } + answers = append(answers, answer) + } + + return answers, nil +} + +func resolveDNS(dbConn *sql.DB, domain string, qtype uint16, maxDepth int) ([]dns.RR, error) { + if maxDepth == 0 { + return nil, fmt.Errorf("too much recursion") + } + + answers, err := resolveInternalCNAMEs(dbConn, domain, qtype, maxDepth) + if err != nil { + return nil, err + } + + return answers, nil +} + +type DnsHandler struct { + DnsResolvers []string + DbConn *sql.DB +} + +func (h *DnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + msg := new(dns.Msg) + msg.SetReply(r) + msg.Authoritative = true + + for _, question := range r.Question { + answers, err := resolveDNS(h.DbConn, question.Name, question.Qtype, MAX_RECURSION) + if err != nil { + fmt.Println(err) + msg.SetRcode(r, dns.RcodeServerFailure) + w.WriteMsg(msg) + return + } + msg.Answer = append(msg.Answer, answers...) + } + + if len(msg.Answer) == 0 { + msg.SetRcode(r, dns.RcodeNameError) + } + + log.Println(msg.Answer) + w.WriteMsg(msg) +} + +func MakeServer(argv *args.Arguments, dbConn *sql.DB) *dns.Server { + handler := &DnsHandler{ + DbConn: dbConn, + } + addr := fmt.Sprintf(":%d", argv.DnsPort) + + return &dns.Server{ + Addr: addr, + Net: "udp", + Handler: handler, + UDPSize: 65535, + ReusePort: true, + } +} diff --git a/main.go b/main.go index 2991821..e0f3e55 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,7 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/dns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/scheduler" "github.com/joho/godotenv" ) @@ -52,7 +52,7 @@ func main() { } if argv.Dns { - server := dns.MakeServer(argv, dbConn) + server := hcdns.MakeServer(argv, dbConn) log.Println("🚀🚀 DNS resolver listening on port", argv.DnsPort) go func() { err = server.ListenAndServe() diff --git a/test/dns_test.go b/test/dns_test.go new file mode 100644 index 0000000..ce6deb5 --- /dev/null +++ b/test/dns_test.go @@ -0,0 +1,244 @@ +package hcdns + +import ( + "database/sql" + "os" + "sync" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" + "github.com/miekg/dns" +) + +const ( + testDBPath = "test.db" + address = "127.0.0.1:8353" + dnsPort = 8353 +) + +func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) { + testDb := database.MakeConn(&dbPath) + database.Migrate(testDb) + testUser := &database.User{ + ID: "test", + } + database.FindOrSaveUser(testDb, testUser) + + server := hcdns.MakeServer(&args.Arguments{ + DnsPort: dnsPort, + }, testDb) + + waitGroup := sync.WaitGroup{} + waitGroup.Add(1) + go func() { + server.ListenAndServe() + waitGroup.Done() + }() + + return testDb, server, &waitGroup +} + +func destroy(conn *sql.DB, path string) { + conn.Close() + os.Remove(path) +} + +func TestWhenCNAMEIsResolved(t *testing.T) { + t.Log("TestWhenCNAMEIsResolved") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "res.example.com.", + TTL: 300, + Internal: true, + } + a := &database.DNSRecord{ + ID: "2", + UserID: "test", + Name: "res.example.com.", + Type: "A", + Content: "127.0.0.1", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + database.SaveDNSRecord(testDb, a) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 2 { + t.Fatalf("expected 2 answers, got %d", len(in.Answer)) + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[1].Header().Name != a.Name { + t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) + } + + if in.Answer[0].(*dns.CNAME).Target != a.Name { + t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Answer[1].(*dns.A).A.String() != a.Content { + t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String()) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[1].Header().Rrtype != dns.TypeA { + t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype) + } + + if int(in.Answer[0].Header().Ttl) != cname.TTL { + t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } +} + +func TestWhenNoRecordNxDomain(t *testing.T) { + t.Log("TestWhenNoRecordNxDomain") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + qtype := dns.TypeA + domain := dns.Fqdn("nonexistant.example.com.") + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAME(t *testing.T) { + t.Log("TestWhenUnresolvingCNAME") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "nonexistant.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(in.Answer)) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[0].(*dns.CNAME).Target != cname.Content { + t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Rcode == dns.RcodeNameError { + t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { + t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") + + testDb, server, _ := setup(testDBPath) + defer destroy(testDb, testDBPath) + defer server.Shutdown() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "cname.internal.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := new(dns.Client) + message := new(dns.Msg) + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, address) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) > 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + if in.Rcode != dns.RcodeServerFailure { + t.Fatalf("expected SERVFAIL, got %d", in.Rcode) + } +} -- cgit v1.2.3-70-g09d2 From 07c272b8098b6cefa17280f65b4079e99a18a6c2 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 16:32:03 -0600 Subject: add test step to ci --- .drone.yml | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/.drone.yml b/.drone.yml index 8f459a1..38ed21d 100644 --- a/.drone.yml +++ b/.drone.yml @@ -1,9 +1,15 @@ --- kind: pipeline type: docker -name: build, publish docker image, deploy +name: deployment steps: + - name: run tests + image: golang + commands: + - go build + - go test -v ./... + - name: docker image: plugins/docker settings: @@ -13,9 +19,10 @@ steps: from_secret: gitea_packpub_password registry: git.hatecomputers.club repo: git.hatecomputers.club/hatecomputers/hatecomputers.club - tags: - - latest - - main + when: + branch: + - main + - name: ssh image: appleboy/drone-ssh settings: @@ -27,6 +34,10 @@ steps: command_timeout: 2m script: - systemctl restart docker-compose@hatecomputers-club + when: + branch: + - main + trigger: branch: - - main + - main -- cgit v1.2.3-70-g09d2 From 14b450b9c87f391909a80383041844a48b5d77ea Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 16:36:54 -0600 Subject: update parallelism in ci --- .drone.yml | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/.drone.yml b/.drone.yml index 38ed21d..52f3ffa 100644 --- a/.drone.yml +++ b/.drone.yml @@ -1,15 +1,26 @@ --- kind: pipeline type: docker -name: deployment +name: build steps: - name: run tests image: golang commands: - go build - - go test -v ./... + - go test -p 1 -v ./... +trigger: + event: + - pull_request + - push + +--- +kind: pipeline +type: docker +name: build + +steps: - name: docker image: plugins/docker settings: @@ -19,10 +30,6 @@ steps: from_secret: gitea_packpub_password registry: git.hatecomputers.club repo: git.hatecomputers.club/hatecomputers/hatecomputers.club - when: - branch: - - main - - name: ssh image: appleboy/drone-ssh settings: @@ -34,10 +41,9 @@ steps: command_timeout: 2m script: - systemctl restart docker-compose@hatecomputers-club - when: - branch: - - main trigger: branch: - main + event: + - push -- cgit v1.2.3-70-g09d2 From 52d061e7cc309bdb1a73b3c2c0eda4c089770e9f Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 16:38:48 -0600 Subject: rename target in ci/cd --- .drone.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.drone.yml b/.drone.yml index 52f3ffa..94716d5 100644 --- a/.drone.yml +++ b/.drone.yml @@ -18,7 +18,7 @@ trigger: --- kind: pipeline type: docker -name: build +name: deploy steps: - name: docker -- cgit v1.2.3-70-g09d2 From 35a5e9a263eec2edf375eae119a31f03fb87caa0 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 16:51:13 -0600 Subject: use random ports and test db paths --- .dockerignore | 1 + test/dns_test.go | 63 +++++++++++++++++++++++++++++++------------------------- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/.dockerignore b/.dockerignore index 52be0d9..6045466 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,4 @@ hatecomputers.club Dockerfile *.db +.drone.yml diff --git a/test/dns_test.go b/test/dns_test.go index ce6deb5..55bb060 100644 --- a/test/dns_test.go +++ b/test/dns_test.go @@ -2,6 +2,8 @@ package hcdns import ( "database/sql" + "fmt" + "math/rand" "os" "sync" "testing" @@ -9,17 +11,24 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" "github.com/miekg/dns" ) -const ( - testDBPath = "test.db" - address = "127.0.0.1:8353" - dnsPort = 8353 -) +func destroy(conn *sql.DB, path string) { + conn.Close() + os.Remove(path) +} + +func randomPort() int { + return rand.Intn(3000) + 10000 +} -func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) { - testDb := database.MakeConn(&dbPath) +func setup() (*sql.DB, *dns.Server, int, *string, func()) { + randomDb := utils.RandomId() + dnsPort := randomPort() + + testDb := database.MakeConn(&randomDb) database.Migrate(testDb) testUser := &database.User{ ID: "test", @@ -37,20 +46,21 @@ func setup(dbPath string) (*sql.DB, *dns.Server, *sync.WaitGroup) { waitGroup.Done() }() - return testDb, server, &waitGroup -} + address := fmt.Sprintf("127.0.0.1:%d", dnsPort) + return testDb, server, dnsPort, &address, func() { + testDb.Close() + os.Remove(randomDb) -func destroy(conn *sql.DB, path string) { - conn.Close() - os.Remove(path) + server.Shutdown() + waitGroup.Wait() + } } func TestWhenCNAMEIsResolved(t *testing.T) { t.Log("TestWhenCNAMEIsResolved") - testDb, server, _ := setup(testDBPath) - defer destroy(testDb, testDBPath) - defer server.Shutdown() + testDb, _, _, addr, cleanup := setup() + defer cleanup() cname := &database.DNSRecord{ ID: "1", @@ -79,7 +89,7 @@ func TestWhenCNAMEIsResolved(t *testing.T) { message := new(dns.Msg) message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, address) + in, _, err := client.Exchange(message, *addr) if err != nil { t.Fatal(err) @@ -125,9 +135,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) { func TestWhenNoRecordNxDomain(t *testing.T) { t.Log("TestWhenNoRecordNxDomain") - testDb, server, _ := setup(testDBPath) - defer destroy(testDb, testDBPath) - defer server.Shutdown() + _, _, _, addr, cleanup := setup() + defer cleanup() qtype := dns.TypeA domain := dns.Fqdn("nonexistant.example.com.") @@ -135,7 +144,7 @@ func TestWhenNoRecordNxDomain(t *testing.T) { message := new(dns.Msg) message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, address) + in, _, err := client.Exchange(message, *addr) if err != nil { t.Fatal(err) @@ -153,9 +162,8 @@ func TestWhenNoRecordNxDomain(t *testing.T) { func TestWhenUnresolvingCNAME(t *testing.T) { t.Log("TestWhenUnresolvingCNAME") - testDb, server, _ := setup(testDBPath) - defer destroy(testDb, testDBPath) - defer server.Shutdown() + testDb, _, _, addr, cleanup := setup() + defer cleanup() cname := &database.DNSRecord{ ID: "1", @@ -174,7 +182,7 @@ func TestWhenUnresolvingCNAME(t *testing.T) { message := new(dns.Msg) message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, address) + in, _, err := client.Exchange(message, *addr) if err != nil { t.Fatal(err) @@ -208,9 +216,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) { func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") - testDb, server, _ := setup(testDBPath) - defer destroy(testDb, testDBPath) - defer server.Shutdown() + testDb, _, _, addr, cleanup := setup() + defer cleanup() cname := &database.DNSRecord{ ID: "1", @@ -229,7 +236,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { message := new(dns.Msg) message.SetQuestion(domain, qtype) - in, _, err := client.Exchange(message, address) + in, _, err := client.Exchange(message, *addr) if err != nil { t.Fatal(err) -- cgit v1.2.3-70-g09d2 From ce393a5ac1dedaa04a885b5400d66bcbbf794855 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Tue, 2 Apr 2024 16:52:07 -0600 Subject: only run tests on push --- .drone.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.drone.yml b/.drone.yml index 94716d5..b96d25e 100644 --- a/.drone.yml +++ b/.drone.yml @@ -12,7 +12,6 @@ steps: trigger: event: - - pull_request - push --- -- cgit v1.2.3-70-g09d2 From 385d4a84eb813ce6f777b6ab10642ad447f93321 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Tue, 2 Apr 2024 20:26:24 -0600 Subject: fix dns race condition --- .drone.yml | 7 ++++++- test/dns_test.go | 49 ++++++++++++++++++++++++++++--------------------- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/.drone.yml b/.drone.yml index b96d25e..d056e69 100644 --- a/.drone.yml +++ b/.drone.yml @@ -12,7 +12,7 @@ steps: trigger: event: - - push + - pull_request --- kind: pipeline @@ -20,6 +20,11 @@ type: docker name: deploy steps: + - name: run tests + image: golang + commands: + - go build + - go test -p 1 -v ./... - name: docker image: plugins/docker settings: diff --git a/test/dns_test.go b/test/dns_test.go index 55bb060..2caabe4 100644 --- a/test/dns_test.go +++ b/test/dns_test.go @@ -21,10 +21,10 @@ func destroy(conn *sql.DB, path string) { } func randomPort() int { - return rand.Intn(3000) + 10000 + return rand.Intn(3000) + 1024 } -func setup() (*sql.DB, *dns.Server, int, *string, func()) { +func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { randomDb := utils.RandomId() dnsPort := randomPort() @@ -35,32 +35,35 @@ func setup() (*sql.DB, *dns.Server, int, *string, func()) { } database.FindOrSaveUser(testDb, testUser) + waitLock := &sync.Mutex{} server := hcdns.MakeServer(&args.Arguments{ DnsPort: dnsPort, }, testDb) + server.NotifyStartedFunc = func() { + waitLock.Unlock() + } + waitLock.Lock() - waitGroup := sync.WaitGroup{} - waitGroup.Add(1) go func() { server.ListenAndServe() - waitGroup.Done() }() + waitLock.Lock() address := fmt.Sprintf("127.0.0.1:%d", dnsPort) - return testDb, server, dnsPort, &address, func() { + return testDb, server, &address, waitLock, func() { + server.Shutdown() + testDb.Close() os.Remove(randomDb) - - server.Shutdown() - waitGroup.Wait() } } func TestWhenCNAMEIsResolved(t *testing.T) { t.Log("TestWhenCNAMEIsResolved") - testDb, _, _, addr, cleanup := setup() + testDb, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -85,8 +88,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) { qtype := dns.TypeA domain := dns.Fqdn(cname.Name) - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -135,13 +138,14 @@ func TestWhenCNAMEIsResolved(t *testing.T) { func TestWhenNoRecordNxDomain(t *testing.T) { t.Log("TestWhenNoRecordNxDomain") - _, _, _, addr, cleanup := setup() + _, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() qtype := dns.TypeA domain := dns.Fqdn("nonexistant.example.com.") - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -162,8 +166,9 @@ func TestWhenNoRecordNxDomain(t *testing.T) { func TestWhenUnresolvingCNAME(t *testing.T) { t.Log("TestWhenUnresolvingCNAME") - testDb, _, _, addr, cleanup := setup() + testDb, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -178,8 +183,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) { qtype := dns.TypeA domain := dns.Fqdn(cname.Name) - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -216,8 +221,9 @@ func TestWhenUnresolvingCNAME(t *testing.T) { func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") - testDb, _, _, addr, cleanup := setup() + testDb, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -232,8 +238,8 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { qtype := dns.TypeA domain := dns.Fqdn(cname.Name) - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -245,6 +251,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { if len(in.Answer) > 0 { t.Fatalf("expected 0 answers, got %d", len(in.Answer)) } + if in.Rcode != dns.RcodeServerFailure { t.Fatalf("expected SERVFAIL, got %d", in.Rcode) } -- cgit v1.2.3-70-g09d2 From d0fa8924938c17a10028f7a13dfac346635ba61c Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Tue, 2 Apr 2024 20:27:01 -0600 Subject: remove unused destroy method --- test/dns_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/dns_test.go b/test/dns_test.go index 2caabe4..d875f3f 100644 --- a/test/dns_test.go +++ b/test/dns_test.go @@ -15,11 +15,6 @@ import ( "github.com/miekg/dns" ) -func destroy(conn *sql.DB, path string) { - conn.Close() - os.Remove(path) -} - func randomPort() int { return rand.Intn(3000) + 1024 } -- cgit v1.2.3-70-g09d2 From c32ca84e8a1a87994d5233a2d51c640a368b88bc Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Tue, 2 Apr 2024 23:28:55 -0600 Subject: cute cat cursor --- static/css/styles.css | 16 ++++++++++++++++ static/img/cursor-1.png | Bin 0 -> 570 bytes static/img/cursor-2.png | Bin 0 -> 563 bytes 3 files changed, 16 insertions(+) create mode 100644 static/img/cursor-1.png create mode 100644 static/img/cursor-2.png diff --git a/static/css/styles.css b/static/css/styles.css index 7486016..ba58018 100644 --- a/static/css/styles.css +++ b/static/css/styles.css @@ -15,6 +15,22 @@ padding: 0; color: var(--text-color); font-family: "ComicSans", sans-serif; + + cursor: url("/static/img/cursor-1.png"), auto; + -webkit-animation: cursor 400ms infinite; + animation: cursor 400ms infinite; +} + +@-webkit-keyframes cursor { + 0% {cursor: url("/static/img/cursor-2.png"), auto;} + 50% {cursor: url("/static/img/cursor-1.png"), auto;} + 100% {cursor: url("/static/img/cursor-2.png"), auto;} +} + +@keyframes cursor { + 0% {cursor: url("/static/img/cursor-2.png"), auto;} + 50% {cursor: url("/static/img/cursor-1.png"), auto;} + 100% {cursor: url("/static/img/cursor-2.png"), auto;} } body { diff --git a/static/img/cursor-1.png b/static/img/cursor-1.png new file mode 100644 index 0000000..68fbe5c Binary files /dev/null and b/static/img/cursor-1.png differ diff --git a/static/img/cursor-2.png b/static/img/cursor-2.png new file mode 100644 index 0000000..9851648 Binary files /dev/null and b/static/img/cursor-2.png differ -- cgit v1.2.3-70-g09d2 From cc33a90bfd455f36169b01b0cca064cd35e2524f Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Wed, 3 Apr 2024 14:27:55 -0600 Subject: abstract dns adapter --- adapters/cloudflare/cloudflare.go | 17 ++-- adapters/external_dns.go | 8 ++ api/dns.go | 177 +++++++++++++++++++------------------- api/serve.go | 10 ++- 4 files changed, 117 insertions(+), 95 deletions(-) create mode 100644 adapters/external_dns.go diff --git a/adapters/cloudflare/cloudflare.go b/adapters/cloudflare/cloudflare.go index 40b04a5..c302037 100644 --- a/adapters/cloudflare/cloudflare.go +++ b/adapters/cloudflare/cloudflare.go @@ -14,15 +14,20 @@ type CloudflareDNSResponse struct { Result database.DNSRecord `json:"result"` } -func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) (string, error) { - url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneId) +type CloudflareExternalDNSAdapter struct { + ZoneId string + APIToken string +} + +func (adapter *CloudflareExternalDNSAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", adapter.ZoneId) reqBody := fmt.Sprintf(`{"type":"%s","name":"%s","content":"%s","ttl":%d,"proxied":false}`, record.Type, record.Name, record.Content, record.TTL) payload := strings.NewReader(reqBody) req, _ := http.NewRequest("POST", url, payload) - req.Header.Add("Authorization", "Bearer "+apiToken) + req.Header.Add("Authorization", "Bearer "+adapter.APIToken) req.Header.Add("Content-Type", "application/json") res, err := http.DefaultClient.Do(req) @@ -48,12 +53,12 @@ func CreateDNSRecord(zoneId string, apiToken string, record *database.DNSRecord) return result.ID, nil } -func DeleteDNSRecord(zoneId string, apiToken string, id string) error { - url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneId, id) +func (adapter *CloudflareExternalDNSAdapter) DeleteDNSRecord(id string) error { + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", adapter.ZoneId, id) req, _ := http.NewRequest("DELETE", url, nil) - req.Header.Add("Authorization", "Bearer "+apiToken) + req.Header.Add("Authorization", "Bearer "+adapter.APIToken) res, err := http.DefaultClient.Do(req) if err != nil { diff --git a/adapters/external_dns.go b/adapters/external_dns.go new file mode 100644 index 0000000..c861283 --- /dev/null +++ b/adapters/external_dns.go @@ -0,0 +1,8 @@ +package external_dns + +import "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + +type ExternalDNSAdapter interface { + CreateDNSRecord(record *database.DNSRecord) (string, error) + DeleteDNSRecord(id string) error +} diff --git a/api/dns.go b/api/dns.go index ad41103..6f0e1fd 100644 --- a/api/dns.go +++ b/api/dns.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) @@ -64,116 +64,119 @@ func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp } } -func CreateDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - formErrors := FormError{ - Errors: []string{}, - } +func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + formErrors := FormError{ + Errors: []string{}, + } - internal := req.FormValue("internal") == "on" - name := req.FormValue("name") - if internal && !strings.HasSuffix(name, ".") { - name += "." - } + internal := req.FormValue("internal") == "on" + name := req.FormValue("name") + if internal && !strings.HasSuffix(name, ".") { + name += "." + } - recordType := req.FormValue("type") - recordType = strings.ToUpper(recordType) + recordType := req.FormValue("type") + recordType = strings.ToUpper(recordType) - recordContent := req.FormValue("content") - ttl := req.FormValue("ttl") - ttlNum, err := strconv.Atoi(ttl) - if err != nil { - formErrors.Errors = append(formErrors.Errors, "invalid ttl") - } + recordContent := req.FormValue("content") + ttl := req.FormValue("ttl") + ttlNum, err := strconv.Atoi(ttl) + if err != nil { + formErrors.Errors = append(formErrors.Errors, "invalid ttl") + } - dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - if dnsRecordCount >= MAX_USER_RECORDS { - formErrors.Errors = append(formErrors.Errors, "max records reached") - } + dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if dnsRecordCount >= MAX_USER_RECORDS { + formErrors.Errors = append(formErrors.Errors, "max records reached") + } - dnsRecord := &database.DNSRecord{ - UserID: context.User.ID, - Name: name, - Type: recordType, - Content: recordContent, - TTL: ttlNum, - Internal: internal, - } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) { - formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") - } + dnsRecord := &database.DNSRecord{ + UserID: context.User.ID, + Name: name, + Type: recordType, + Content: recordContent, + TTL: ttlNum, + Internal: internal, + } + + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) { + formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") + } - if len(formErrors.Errors) == 0 { - if dnsRecord.Internal { - dnsRecord.ID = utils.RandomId() - } else { - cloudflareRecordId, err := cloudflare.CreateDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, dnsRecord) + if len(formErrors.Errors) == 0 { + if dnsRecord.Internal { + dnsRecord.ID = utils.RandomId() + } else { + dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, err.Error()) + } + } + } + + if len(formErrors.Errors) == 0 { + _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) if err != nil { log.Println(err) - formErrors.Errors = append(formErrors.Errors, err.Error()) + formErrors.Errors = append(formErrors.Errors, "error saving record") } - - dnsRecord.ID = cloudflareRecordId } - } - if len(formErrors.Errors) == 0 { - _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) - if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, "error saving record") + if len(formErrors.Errors) == 0 { + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) } - } - if len(formErrors.Errors) == 0 { - http.Redirect(resp, req, "/dns", http.StatusFound) - return success(context, req, resp) - } - - (*context.TemplateData)["FormError"] = &formErrors - (*context.TemplateData)["RecordForm"] = dnsRecord + (*context.TemplateData)["FormError"] = &formErrors + (*context.TemplateData)["RecordForm"] = dnsRecord - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } } } -func DeleteDNSRecordContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - recordId := req.FormValue("id") - record, err := database.GetDNSRecord(context.DBConn, recordId) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } +func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + recordId := req.FormValue("id") + record, err := database.GetDNSRecord(context.DBConn, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } + if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } - if !record.Internal { - err = cloudflare.DeleteDNSRecord(context.Args.CloudflareZone, context.Args.CloudflareToken, recordId) + if !record.Internal { + err = dnsAdapter.DeleteDNSRecord(recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + } + + err = database.DeleteDNSRecord(context.DBConn, recordId) if err != nil { - log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - } - err = database.DeleteDNSRecord(context.DBConn, recordId) - if err != nil { - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) } - - http.Redirect(resp, req, "/dns", http.StatusFound) - return success(context, req, resp) } } diff --git a/api/serve.go b/api/serve.go index f71001d..1b632a1 100644 --- a/api/serve.go +++ b/api/serve.go @@ -7,6 +7,7 @@ import ( "net/http" "time" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" @@ -80,6 +81,11 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { fileServer := http.FileServer(http.Dir(argv.StaticPath)) mux.Handle("GET /static/", http.StripPrefix("/static/", CacheControlMiddleware(fileServer, 3600))) + cloudflareAdapter := &cloudflare.CloudflareExternalDNSAdapter{ + APIToken: argv.CloudflareToken, + ZoneId: argv.CloudflareZone, + } + makeRequestContext := func() *RequestContext { return &RequestContext{ DBConn: dbConn, @@ -126,12 +132,12 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation, FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation(cloudflareAdapter), GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { -- cgit v1.2.3-70-g09d2 From da6b6011fc8a73af7d0feb32f116e6b10de11b44 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Wed, 3 Apr 2024 15:33:02 -0600 Subject: refactor dns server test a bit --- hcdns/server_test.go | 254 +++++++++++++++++++++++++++++++++++++++++++++++++++ test/dns_test.go | 253 -------------------------------------------------- 2 files changed, 254 insertions(+), 253 deletions(-) create mode 100644 hcdns/server_test.go delete mode 100644 test/dns_test.go diff --git a/hcdns/server_test.go b/hcdns/server_test.go new file mode 100644 index 0000000..177def4 --- /dev/null +++ b/hcdns/server_test.go @@ -0,0 +1,254 @@ +package hcdns_test + +import ( + "database/sql" + "fmt" + "math/rand" + "os" + "sync" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "github.com/miekg/dns" +) + +func randomPort() int { + return rand.Intn(3000) + 5192 +} + +func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { + randomDb := utils.RandomId() + dnsPort := randomPort() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + testUser := &database.User{ + ID: "test", + } + database.FindOrSaveUser(testDb, testUser) + + waitLock := &sync.Mutex{} + server := hcdns.MakeServer(&args.Arguments{ + DnsPort: dnsPort, + }, testDb) + server.NotifyStartedFunc = func() { + waitLock.Unlock() + } + waitLock.Lock() + + go func() { + server.ListenAndServe() + }() + waitLock.Lock() + + address := fmt.Sprintf("127.0.0.1:%d", dnsPort) + return testDb, server, &address, waitLock, func() { + server.Shutdown() + + testDb.Close() + os.Remove(randomDb) + } +} + +func TestWhenCNAMEIsResolved(t *testing.T) { + t.Log("TestWhenCNAMEIsResolved") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + records := []*database.DNSRecord{ + { + ID: "0", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "next.internal.example.com.", + TTL: 300, + Internal: true, + }, { + ID: "1", + UserID: "test", + Name: "next.internal.example.com.", + Type: "CNAME", + Content: "res.example.com.", + TTL: 300, + Internal: true, + }, + { + ID: "2", + UserID: "test", + Name: "res.example.com.", + Type: "A", + Content: "1.2.3.2", + TTL: 300, + Internal: true, + }, + } + + for _, record := range records { + database.SaveDNSRecord(testDb, record) + } + + qtype := dns.TypeA + domain := dns.Fqdn("cname.internal.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 3 { + t.Fatalf("expected 3 answers, got %d", len(in.Answer)) + } + + for i, record := range records { + if in.Answer[i].Header().Name != record.Name { + t.Fatalf("expected %s, got %s", record.Name, in.Answer[i].Header().Name) + } + + if in.Answer[i].Header().Rrtype != dns.StringToType[record.Type] { + t.Fatalf("expected %s, got %d", record.Type, in.Answer[i].Header().Rrtype) + } + + if int(in.Answer[i].Header().Ttl) != record.TTL { + t.Fatalf("expected %d, got %d", record.TTL, in.Answer[i].Header().Ttl) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + } + + if in.Answer[2].(*dns.A).A.String() != "1.2.3.2" { + t.Fatalf("expected final record to be the A record with correct IP") + } +} + +func TestWhenNoRecordNxDomain(t *testing.T) { + t.Log("TestWhenNoRecordNxDomain") + + _, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + qtype := dns.TypeA + domain := dns.Fqdn("nonexistant.example.com.") + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeNameError { + t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAME(t *testing.T) { + t.Log("TestWhenUnresolvingCNAME") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "nonexistant.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) != 1 { + t.Fatalf("expected 1 answer, got %d", len(in.Answer)) + } + + if !in.Authoritative { + t.Fatalf("expected authoritative response") + } + + if in.Answer[0].Header().Name != cname.Name { + t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) + } + + if in.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) + } + + if in.Answer[0].(*dns.CNAME).Target != cname.Content { + t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) + } + + if in.Rcode == dns.RcodeNameError { + t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) + } +} + +func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { + t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") + + testDb, _, addr, lock, cleanup := setup() + defer cleanup() + defer lock.Unlock() + + cname := &database.DNSRecord{ + ID: "1", + UserID: "test", + Name: "cname.internal.example.com.", + Type: "CNAME", + Content: "cname.internal.example.com.", + TTL: 300, + Internal: true, + } + database.SaveDNSRecord(testDb, cname) + + qtype := dns.TypeA + domain := dns.Fqdn(cname.Name) + client := &dns.Client{} + message := &dns.Msg{} + message.SetQuestion(domain, qtype) + + in, _, err := client.Exchange(message, *addr) + + if err != nil { + t.Fatal(err) + } + + if len(in.Answer) > 0 { + t.Fatalf("expected 0 answers, got %d", len(in.Answer)) + } + + if in.Rcode != dns.RcodeServerFailure { + t.Fatalf("expected SERVFAIL, got %d", in.Rcode) + } +} diff --git a/test/dns_test.go b/test/dns_test.go deleted file mode 100644 index d875f3f..0000000 --- a/test/dns_test.go +++ /dev/null @@ -1,253 +0,0 @@ -package hcdns - -import ( - "database/sql" - "fmt" - "math/rand" - "os" - "sync" - "testing" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/hcdns" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" - "github.com/miekg/dns" -) - -func randomPort() int { - return rand.Intn(3000) + 1024 -} - -func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { - randomDb := utils.RandomId() - dnsPort := randomPort() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - testUser := &database.User{ - ID: "test", - } - database.FindOrSaveUser(testDb, testUser) - - waitLock := &sync.Mutex{} - server := hcdns.MakeServer(&args.Arguments{ - DnsPort: dnsPort, - }, testDb) - server.NotifyStartedFunc = func() { - waitLock.Unlock() - } - waitLock.Lock() - - go func() { - server.ListenAndServe() - }() - waitLock.Lock() - - address := fmt.Sprintf("127.0.0.1:%d", dnsPort) - return testDb, server, &address, waitLock, func() { - server.Shutdown() - - testDb.Close() - os.Remove(randomDb) - } -} - -func TestWhenCNAMEIsResolved(t *testing.T) { - t.Log("TestWhenCNAMEIsResolved") - - testDb, _, addr, lock, cleanup := setup() - defer cleanup() - defer lock.Unlock() - - cname := &database.DNSRecord{ - ID: "1", - UserID: "test", - Name: "cname.internal.example.com.", - Type: "CNAME", - Content: "res.example.com.", - TTL: 300, - Internal: true, - } - a := &database.DNSRecord{ - ID: "2", - UserID: "test", - Name: "res.example.com.", - Type: "A", - Content: "127.0.0.1", - TTL: 300, - Internal: true, - } - database.SaveDNSRecord(testDb, cname) - database.SaveDNSRecord(testDb, a) - - qtype := dns.TypeA - domain := dns.Fqdn(cname.Name) - client := &dns.Client{} - message := &dns.Msg{} - message.SetQuestion(domain, qtype) - - in, _, err := client.Exchange(message, *addr) - - if err != nil { - t.Fatal(err) - } - - if len(in.Answer) != 2 { - t.Fatalf("expected 2 answers, got %d", len(in.Answer)) - } - - if in.Answer[0].Header().Name != cname.Name { - t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) - } - - if in.Answer[1].Header().Name != a.Name { - t.Fatalf("expected res.example.com., got %s", in.Answer[1].Header().Name) - } - - if in.Answer[0].(*dns.CNAME).Target != a.Name { - t.Fatalf("expected res.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) - } - - if in.Answer[1].(*dns.A).A.String() != a.Content { - t.Fatalf("expected %s, got %s", a.Content, in.Answer[1].(*dns.A).A.String()) - } - - if in.Answer[0].Header().Rrtype != dns.TypeCNAME { - t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) - } - - if in.Answer[1].Header().Rrtype != dns.TypeA { - t.Fatalf("expected A, got %d", in.Answer[1].Header().Rrtype) - } - - if int(in.Answer[0].Header().Ttl) != cname.TTL { - t.Fatalf("expected %d, got %d", cname.TTL, in.Answer[0].Header().Ttl) - } - - if !in.Authoritative { - t.Fatalf("expected authoritative response") - } -} - -func TestWhenNoRecordNxDomain(t *testing.T) { - t.Log("TestWhenNoRecordNxDomain") - - _, _, addr, lock, cleanup := setup() - defer cleanup() - defer lock.Unlock() - - qtype := dns.TypeA - domain := dns.Fqdn("nonexistant.example.com.") - client := &dns.Client{} - message := &dns.Msg{} - message.SetQuestion(domain, qtype) - - in, _, err := client.Exchange(message, *addr) - - if err != nil { - t.Fatal(err) - } - - if len(in.Answer) != 0 { - t.Fatalf("expected 0 answers, got %d", len(in.Answer)) - } - - if in.Rcode != dns.RcodeNameError { - t.Fatalf("expected NXDOMAIN, got %d", in.Rcode) - } -} - -func TestWhenUnresolvingCNAME(t *testing.T) { - t.Log("TestWhenUnresolvingCNAME") - - testDb, _, addr, lock, cleanup := setup() - defer cleanup() - defer lock.Unlock() - - cname := &database.DNSRecord{ - ID: "1", - UserID: "test", - Name: "cname.internal.example.com.", - Type: "CNAME", - Content: "nonexistant.example.com.", - TTL: 300, - Internal: true, - } - database.SaveDNSRecord(testDb, cname) - - qtype := dns.TypeA - domain := dns.Fqdn(cname.Name) - client := &dns.Client{} - message := &dns.Msg{} - message.SetQuestion(domain, qtype) - - in, _, err := client.Exchange(message, *addr) - - if err != nil { - t.Fatal(err) - } - - if len(in.Answer) != 1 { - t.Fatalf("expected 1 answer, got %d", len(in.Answer)) - } - - if !in.Authoritative { - t.Fatalf("expected authoritative response") - } - - if in.Answer[0].Header().Name != cname.Name { - t.Fatalf("expected cname.internal.example.com., got %s", in.Answer[0].Header().Name) - } - - if in.Answer[0].Header().Rrtype != dns.TypeCNAME { - t.Fatalf("expected CNAME, got %d", in.Answer[0].Header().Rrtype) - } - - if in.Answer[0].(*dns.CNAME).Target != cname.Content { - t.Fatalf("expected nonexistant.example.com., got %s", in.Answer[0].(*dns.CNAME).Target) - } - - if in.Rcode == dns.RcodeNameError { - t.Fatalf("expected no NXDOMAIN, got %d", in.Rcode) - } -} - -func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { - t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") - - testDb, _, addr, lock, cleanup := setup() - defer cleanup() - defer lock.Unlock() - - cname := &database.DNSRecord{ - ID: "1", - UserID: "test", - Name: "cname.internal.example.com.", - Type: "CNAME", - Content: "cname.internal.example.com.", - TTL: 300, - Internal: true, - } - database.SaveDNSRecord(testDb, cname) - - qtype := dns.TypeA - domain := dns.Fqdn(cname.Name) - client := &dns.Client{} - message := &dns.Msg{} - message.SetQuestion(domain, qtype) - - in, _, err := client.Exchange(message, *addr) - - if err != nil { - t.Fatal(err) - } - - if len(in.Answer) > 0 { - t.Fatalf("expected 0 answers, got %d", len(in.Answer)) - } - - if in.Rcode != dns.RcodeServerFailure { - t.Fatalf("expected SERVFAIL, got %d", in.Rcode) - } -} -- cgit v1.2.3-70-g09d2 From 47cc8feefa34f35722719a82456ccd6257903d35 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Wed, 3 Apr 2024 15:58:44 -0600 Subject: rename auth redirect login name --- api/auth.go | 2 +- api/guestbook.go | 69 +++++++------------------------------------------------- api/hcaptcha.go | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 62 deletions(-) create mode 100644 api/hcaptcha.go diff --git a/api/auth.go b/api/auth.go index 14e6924..0e4c1ed 100644 --- a/api/auth.go +++ b/api/auth.go @@ -50,7 +50,7 @@ func StartSessionContinuation(context *RequestContext, req *http.Request, resp h } } -func InterceptCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { +func InterceptOauthCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { return func(success Continuation, failure Continuation) ContinuationChain { state := req.URL.Query().Get("state") code := req.URL.Query().Get("code") diff --git a/api/guestbook.go b/api/guestbook.go index 7b84f45..ee3c79a 100644 --- a/api/guestbook.go +++ b/api/guestbook.go @@ -1,8 +1,6 @@ package api import ( - "encoding/json" - "fmt" "log" "net/http" "strings" @@ -43,16 +41,11 @@ func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp return func(success Continuation, failure Continuation) ContinuationChain { name := req.FormValue("name") message := req.FormValue("message") - hCaptchaResponse := req.FormValue("h-captcha-response") formErrors := FormError{ Errors: []string{}, } - if hCaptchaResponse == "" { - formErrors.Errors = append(formErrors.Errors, "hCaptcha is required") - } - entry := &database.GuestbookEntry{ ID: utils.RandomId(), Name: name, @@ -60,22 +53,19 @@ func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp } formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...) - err := verifyHCaptcha(context.Args.HcaptchaSecret, hCaptchaResponse) - if err != nil { - log.Println(err) - - formErrors.Errors = append(formErrors.Errors, "hCaptcha verification failed") + if len(formErrors.Errors) == 0 { + _, err := database.SaveGuestbookEntry(context.DBConn, entry) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, "failed to save entry") + } } + if len(formErrors.Errors) > 0 { (*context.TemplateData)["FormError"] = formErrors (*context.TemplateData)["EntryForm"] = entry - return failure(context, req, resp) - } + resp.WriteHeader(http.StatusBadRequest) - _, err = database.SaveGuestbookEntry(context.DBConn, entry) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } @@ -96,46 +86,3 @@ func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp return success(context, req, resp) } } - -func HcaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - (*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{ - SiteKey: context.Args.HcaptchaSiteKey, - } - log.Println(context.Args.HcaptchaSiteKey) - return success(context, req, resp) - } -} - -func verifyHCaptcha(secret, response string) error { - verifyURL := "https://hcaptcha.com/siteverify" - body := strings.NewReader("secret=" + secret + "&response=" + response) - - req, err := http.NewRequest("POST", verifyURL, body) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - - jsonResponse := struct { - Success bool `json:"success"` - }{} - err = json.NewDecoder(resp.Body).Decode(&jsonResponse) - if err != nil { - return err - } - - if !jsonResponse.Success { - return fmt.Errorf("hcaptcha verification failed") - } - - defer resp.Body.Close() - return nil -} diff --git a/api/hcaptcha.go b/api/hcaptcha.go new file mode 100644 index 0000000..a310c01 --- /dev/null +++ b/api/hcaptcha.go @@ -0,0 +1,69 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" +) + +func verifyCaptcha(secret, response string) error { + verifyURL := "https://hcaptcha.com/siteverify" + body := strings.NewReader("secret=" + secret + "&response=" + response) + + req, err := http.NewRequest("POST", verifyURL, body) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + + jsonResponse := struct { + Success bool `json:"success"` + }{} + err = json.NewDecoder(resp.Body).Decode(&jsonResponse) + if err != nil { + return err + } + + if !jsonResponse.Success { + return fmt.Errorf("hcaptcha verification failed") + } + + defer resp.Body.Close() + return nil +} + +func CaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + (*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{ + SiteKey: context.Args.HcaptchaSiteKey, + } + return success(context, req, resp) + } +} + +func CaptchaVerificationContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { + return func(success Continuation, failure Continuation) ContinuationChain { + hCaptchaResponse := req.FormValue("h-captcha-response") + secretKey := context.Args.HcaptchaSecret + + err := verifyCaptcha(secretKey, hCaptchaResponse) + if err != nil { + (*context.TemplateData)["FormError"] = FormError{ + Errors: []string{"hCaptcha verification failed"}, + } + resp.WriteHeader(http.StatusBadRequest) + + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} -- cgit v1.2.3-70-g09d2 From 8c7d9b376249807e1595f440fa72c77cafbdaf6f Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Wed, 3 Apr 2024 15:59:12 -0600 Subject: dont always 200 on template render --- api/template.go | 1 - 1 file changed, 1 deletion(-) diff --git a/api/template.go b/api/template.go index eeaeb51..d637c64 100644 --- a/api/template.go +++ b/api/template.go @@ -66,7 +66,6 @@ func TemplateContinuation(path string, showBase bool) Continuation { return failure(context, req, resp) } - resp.WriteHeader(200) resp.Header().Set("Content-Type", "text/html") resp.Write(html.Bytes()) return success(context, req, resp) -- cgit v1.2.3-70-g09d2 From b74a955dcb8cc1d5d2599a1b096510a60e55e7d7 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Wed, 3 Apr 2024 15:59:19 -0600 Subject: add guestbook tests --- api/guestbook_test.go | 129 ++++++++++++++++++++++++++++++++++++++++++++++++++ api/serve.go | 18 +++---- 2 files changed, 135 insertions(+), 12 deletions(-) create mode 100644 api/guestbook_test.go diff --git a/api/guestbook_test.go b/api/guestbook_test.go new file mode 100644 index 0000000..5c1831f --- /dev/null +++ b/api/guestbook_test.go @@ -0,0 +1,129 @@ +package api_test + +import ( + "database/sql" + "net/http" + "net/http/httptest" + "os" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func setup() (*sql.DB, *api.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &api.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +func TestValidGuestbookPutsInDatabase(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + entries, err := database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + if len(entries) > 0 { + t.Errorf("expected no entries, got entries") + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation) + })) + defer ts.Close() + + req := httptest.NewRequest("POST", ts.URL, nil) + req.Form = map[string][]string{ + "name": {"test"}, + "message": {"test"}, + } + + w := httptest.NewRecorder() + ts.Config.Handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status code 200, got %d", w.Code) + } + + entries, err = database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + + if len(entries) != 1 { + t.Errorf("expected 1 entry, got %d", len(entries)) + } + + if entries[0].Name != req.FormValue("name") { + t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name) + } +} + +func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + entries, err := database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + if len(entries) > 0 { + t.Errorf("expected no entries, got entries") + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation) + })) + defer testServer.Close() + + reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n" + invalidRequests := []struct { + name string + message string + }{ + {"", "test"}, + {"test", ""}, + {"", ""}, + {"test", reallyLongStringThatWouldTakeTooMuchSpace}, + } + + for _, form := range invalidRequests { + req := httptest.NewRequest("POST", testServer.URL, nil) + req.Form = map[string][]string{ + "name": {form.name}, + "message": {form.message}, + } + + responseRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(responseRecorder, req) + + if responseRecorder.Code != http.StatusBadRequest { + t.Errorf("expected status code 400, got %d", responseRecorder.Code) + } + } + + entries, err = database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + + if len(entries) != 0 { + t.Errorf("expected 0 entries, got %d", len(entries)) + } +} diff --git a/api/serve.go b/api/serve.go index 1b632a1..9547ee0 100644 --- a/api/serve.go +++ b/api/serve.go @@ -88,9 +88,8 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { makeRequestContext := func() *RequestContext { return &RequestContext{ - DBConn: dbConn, - Args: argv, - + DBConn: dbConn, + Args: argv, TemplateData: &map[string]interface{}{}, } } @@ -100,7 +99,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) - mux.HandleFunc("GET /api/health", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() LogRequestContinuation(requestContext, r, w)(HealthCheckContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) @@ -112,12 +111,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(InterceptCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) - }) - - mux.HandleFunc("GET /me", func(w http.ResponseWriter, r *http.Request) { - requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(RefreshSessionContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) { @@ -157,12 +151,12 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(HcaptchaArgsContinuation, HcaptchaArgsContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaVerificationContinuation, CaptchaVerificationContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { -- cgit v1.2.3-70-g09d2 From e398cf05402c010d594cea4e2dea307ca1a36dbe Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Wed, 3 Apr 2024 16:22:19 -0600 Subject: checkpoint to save work; had to get on the bus --- api/auth_test.go | 37 +++++++++++++++++++++++++++++++++++++ api/dns.go | 17 ++++++----------- api/dns_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ api/serve.go | 4 ++++ 4 files changed, 103 insertions(+), 11 deletions(-) create mode 100644 api/auth_test.go create mode 100644 api/dns_test.go diff --git a/api/auth_test.go b/api/auth_test.go new file mode 100644 index 0000000..45ca12e --- /dev/null +++ b/api/auth_test.go @@ -0,0 +1,37 @@ +package api_test + +import ( + "database/sql" + "os" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func setup() (*sql.DB, *api.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &api.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +/* +todo: test api key creation ++ api key attached to user ++ user session is unique ++ goLogin goes to page in cookie +*/ diff --git a/api/dns.go b/api/dns.go index 6f0e1fd..7ade6e4 100644 --- a/api/dns.go +++ b/api/dns.go @@ -15,23 +15,18 @@ import ( const MAX_USER_RECORDS = 65 -type FormError struct { - Errors []string -} +var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} -func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord) bool { +func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { ownedByUser := (user.ID == record.UserID) if !ownedByUser { return false } if !record.Internal { - userOwnedDomains := []string{ - fmt.Sprintf("%s", user.Username), - fmt.Sprintf("%s.endpoints", user.Username), - } + for _, format := range ownedInternalDomainFormats { + domain := fmt.Sprintf(format, user.Username) - for _, domain := range userOwnedDomains { isInSubDomain := strings.HasSuffix(record.Name, "."+domain) if domain == record.Name || isInSubDomain { return true @@ -106,7 +101,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun Internal: internal, } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord) { + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) { formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } @@ -155,7 +150,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun return failure(context, req, resp) } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, record) { + if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } diff --git a/api/dns_test.go b/api/dns_test.go new file mode 100644 index 0000000..59dd85b --- /dev/null +++ b/api/dns_test.go @@ -0,0 +1,56 @@ +package api_test + +import ( + "database/sql" + "net/http" + "net/http/httptest" + "os" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func setup() (*sql.DB, *api.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &api.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +func TestThatOwnerCanPutRecordInDomain(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + testUser := &database.User{ + ID: "test", + Username: "test", + } + + records, err := database.GetUserDNSRecords(db, context.User.ID) + if err != nil { + t.Fatal(err) + } + if len(records) > 0 { + t.Errorf("expected no records, got records") + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + api.PutDNSRecordContinuation(context, r, w)(api.IdContinuation, api.IdContinuation) + })) + defer ts.Close() + +} diff --git a/api/serve.go b/api/serve.go index 9547ee0..1536f65 100644 --- a/api/serve.go +++ b/api/serve.go @@ -24,6 +24,10 @@ type RequestContext struct { User *database.User } +type FormError struct { + Errors []string +} + type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain type ContinuationChain func(Continuation, Continuation) ContinuationChain -- cgit v1.2.3-70-g09d2 From f38e8719c2a8537fe9b64ed8ceca45858a58e498 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Wed, 3 Apr 2024 17:53:50 -0600 Subject: make it compile --- api/api_keys.go | 87 ------------ api/auth.go | 287 --------------------------------------- api/auth/auth.go | 288 ++++++++++++++++++++++++++++++++++++++++ api/auth/auth_test.go | 36 +++++ api/auth_test.go | 37 ------ api/dns.go | 177 ------------------------ api/dns/dns.go | 178 +++++++++++++++++++++++++ api/dns/dns_test.go | 63 +++++++++ api/dns_test.go | 56 -------- api/guestbook.go | 88 ------------ api/guestbook/guestbook.go | 85 ++++++++++++ api/guestbook/guestbook_test.go | 136 +++++++++++++++++++ api/guestbook_test.go | 129 ------------------ api/hcaptcha.go | 69 ---------- api/hcaptcha/hcaptcha.go | 75 +++++++++++ api/keys/keys.go | 88 ++++++++++++ api/serve.go | 76 +++++------ api/template.go | 74 ----------- api/template/template.go | 76 +++++++++++ api/types/types.go | 28 ++++ 20 files changed, 1085 insertions(+), 1048 deletions(-) delete mode 100644 api/api_keys.go delete mode 100644 api/auth.go create mode 100644 api/auth/auth.go create mode 100644 api/auth/auth_test.go delete mode 100644 api/auth_test.go delete mode 100644 api/dns.go create mode 100644 api/dns/dns.go create mode 100644 api/dns/dns_test.go delete mode 100644 api/dns_test.go delete mode 100644 api/guestbook.go create mode 100644 api/guestbook/guestbook.go create mode 100644 api/guestbook/guestbook_test.go delete mode 100644 api/guestbook_test.go delete mode 100644 api/hcaptcha.go create mode 100644 api/hcaptcha/hcaptcha.go create mode 100644 api/keys/keys.go delete mode 100644 api/template.go create mode 100644 api/template/template.go create mode 100644 api/types/types.go diff --git a/api/api_keys.go b/api/api_keys.go deleted file mode 100644 index d636044..0000000 --- a/api/api_keys.go +++ /dev/null @@ -1,87 +0,0 @@ -package api - -import ( - "log" - "net/http" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -const MAX_USER_API_KEYS = 5 - -func ListAPIKeysContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - apiKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - (*context.TemplateData)["APIKeys"] = apiKeys - return success(context, req, resp) - } -} - -func CreateAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - formErrors := FormError{ - Errors: []string{}, - } - - numKeys, err := database.CountUserAPIKeys(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - if numKeys >= MAX_USER_API_KEYS { - formErrors.Errors = append(formErrors.Errors, "max api keys reached") - } - - if len(formErrors.Errors) > 0 { - (*context.TemplateData)["FormError"] = formErrors - return failure(context, req, resp) - } - - _, err = database.SaveAPIKey(context.DBConn, &database.UserApiKey{ - UserID: context.User.ID, - Key: utils.RandomId(), - }) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - return success(context, req, resp) - } -} - -func DeleteAPIKeyContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - key := req.FormValue("key") - - apiKey, err := database.GetAPIKey(context.DBConn, key) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - if (apiKey == nil) || (apiKey.UserID != context.User.ID) { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } - - err = database.DeleteAPIKey(context.DBConn, key) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - http.Redirect(resp, req, "/keys", http.StatusFound) - return success(context, req, resp) - } -} diff --git a/api/auth.go b/api/auth.go deleted file mode 100644 index 0e4c1ed..0000000 --- a/api/auth.go +++ /dev/null @@ -1,287 +0,0 @@ -package api - -import ( - "crypto/sha256" - "database/sql" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "log" - "net/http" - "strings" - "time" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" - "golang.org/x/oauth2" -) - -func StartSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - verifier := utils.RandomId() + utils.RandomId() - - sha2 := sha256.New() - io.WriteString(sha2, verifier) - codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) - - state := utils.RandomId() - url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge)) - - http.SetCookie(resp, &http.Cookie{ - Name: "verifier", - Value: verifier, - Path: "/", - Secure: true, - SameSite: http.SameSiteLaxMode, - MaxAge: 60, - }) - http.SetCookie(resp, &http.Cookie{ - Name: "state", - Value: state, - Path: "/", - Secure: true, - SameSite: http.SameSiteLaxMode, - MaxAge: 60, - }) - - http.Redirect(resp, req, url, http.StatusFound) - return success(context, req, resp) - } -} - -func InterceptOauthCodeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - state := req.URL.Query().Get("state") - code := req.URL.Query().Get("code") - - if code == "" || state == "" { - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) - } - - if !verifyState(req, "state", state) { - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) - } - verifierCookie, err := req.Cookie("verifier") - if err != nil { - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) - } - - reqContext := req.Context() - token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - client := context.Args.OauthConfig.Client(reqContext, token) - user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - - return failure(context, req, resp) - } - - session, err := database.MakeUserSessionFor(context.DBConn, user) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - http.SetCookie(resp, &http.Cookie{ - Name: "session", - Value: session.ID, - Path: "/", - SameSite: http.SameSiteLaxMode, - Secure: true, - }) - - redirect := "/" - redirectCookie, err := req.Cookie("redirect") - if err == nil && redirectCookie.Value != "" { - redirect = redirectCookie.Value - http.SetCookie(resp, &http.Cookie{ - Name: "redirect", - MaxAge: 0, - }) - } - - http.Redirect(resp, req, redirect, http.StatusFound) - return success(context, req, resp) - } -} - -func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { - if bearerToken == "" { - return nil, nil - } - - parts := strings.Split(bearerToken, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - return nil, nil - } - - apiKey, err := database.GetAPIKey(dbConn, parts[1]) - if err != nil { - return nil, err - } - if apiKey == nil { - return nil, nil - } - - user, err := database.GetUser(dbConn, apiKey.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - -func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { - session, err := database.GetSession(dbConn, sessionId) - if err != nil { - return nil, err - } - - if session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(dbConn, sessionId) - return nil, fmt.Errorf("session expired") - } - - user, err := database.GetUser(dbConn, session.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - -func VerifySessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - authHeader := req.Header.Get("Authorization") - user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) - - sessionCookie, err := req.Cookie("session") - if err == nil && sessionCookie.Value != "" { - user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) - } - - if userErr != nil || user == nil { - log.Println(userErr, user) - - http.SetCookie(resp, &http.Cookie{ - Name: "session", - MaxAge: 0, // reset session cookie in case - }) - - context.User = nil - return failure(context, req, resp) - } - - context.User = user - return success(context, req, resp) - } -} - -func GoLoginContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - http.SetCookie(resp, &http.Cookie{ - Name: "redirect", - Value: req.URL.Path, - Path: "/", - Secure: true, - SameSite: http.SameSiteLaxMode, - }) - - http.Redirect(resp, req, "/login", http.StatusFound) - return failure(context, req, resp) - } -} - -func RefreshSessionContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - sessionCookie, err := req.Cookie("session") - if err != nil { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } - - _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) - if err != nil { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } - - return success(context, req, resp) - } -} - -func LogoutContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - sessionCookie, err := req.Cookie("session") - if err == nil && sessionCookie.Value != "" { - _ = database.DeleteSession(context.DBConn, sessionCookie.Value) - } - - http.Redirect(resp, req, "/", http.StatusFound) - http.SetCookie(resp, &http.Cookie{ - Name: "session", - MaxAge: 0, - }) - return success(context, req, resp) - } -} - -func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { - userResponse, err := client.Get(uri) - if err != nil { - return nil, err - } - - userStruct, err := createUserFromResponse(userResponse) - if err != nil { - return nil, err - } - - user, err := database.FindOrSaveUser(dbConn, userStruct) - if err != nil { - return nil, err - } - - return user, nil -} - -func createUserFromResponse(response *http.Response) (*database.User, error) { - user := &database.User{ - CreatedAt: time.Now(), - } - - err := json.NewDecoder(response.Body).Decode(user) - defer response.Body.Close() - - if err != nil { - log.Println(err) - return nil, err - } - - user.Username = strings.ToLower(user.Username) - user.Username = strings.Split(user.Username, "@")[0] - - return user, nil -} - -func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { - cookie, err := req.Cookie(stateCookieName) - if err != nil || cookie.Value != expectedState { - return false - } - - return true -} diff --git a/api/auth/auth.go b/api/auth/auth.go new file mode 100644 index 0000000..dc348b2 --- /dev/null +++ b/api/auth/auth.go @@ -0,0 +1,288 @@ +package auth + +import ( + "crypto/sha256" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "time" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" +) + +func StartSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + verifier := utils.RandomId() + utils.RandomId() + + sha2 := sha256.New() + io.WriteString(sha2, verifier) + codeChallenge := base64.RawURLEncoding.EncodeToString(sha2.Sum(nil)) + + state := utils.RandomId() + url := context.Args.OauthConfig.AuthCodeURL(state, oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge", codeChallenge)) + + http.SetCookie(resp, &http.Cookie{ + Name: "verifier", + Value: verifier, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + http.SetCookie(resp, &http.Cookie{ + Name: "state", + Value: state, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 60, + }) + + http.Redirect(resp, req, url, http.StatusFound) + return success(context, req, resp) + } +} + +func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + state := req.URL.Query().Get("state") + code := req.URL.Query().Get("code") + + if code == "" || state == "" { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + if !verifyState(req, "state", state) { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + verifierCookie, err := req.Cookie("verifier") + if err != nil { + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + + reqContext := req.Context() + token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + client := context.Args.OauthConfig.Client(reqContext, token) + user, err := getOauthUser(context.DBConn, client, context.Args.OauthUserInfoURI) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + + return failure(context, req, resp) + } + + session, err := database.MakeUserSessionFor(context.DBConn, user) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.SetCookie(resp, &http.Cookie{ + Name: "session", + Value: session.ID, + Path: "/", + SameSite: http.SameSiteLaxMode, + Secure: true, + }) + + redirect := "/" + redirectCookie, err := req.Cookie("redirect") + if err == nil && redirectCookie.Value != "" { + redirect = redirectCookie.Value + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + MaxAge: 0, + }) + } + + http.Redirect(resp, req, redirect, http.StatusFound) + return success(context, req, resp) + } +} + +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + typesKey, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if typesKey == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, typesKey.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + authHeader := req.Header.Get("Authorization") + user, userErr := getUserFromAuthHeader(context.DBConn, authHeader) + + sessionCookie, err := req.Cookie("session") + if err == nil && sessionCookie.Value != "" { + user, userErr = getUserFromSession(context.DBConn, sessionCookie.Value) + } + + if userErr != nil || user == nil { + log.Println(userErr, user) + + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, // reset session cookie in case + }) + + context.User = nil + return failure(context, req, resp) + } + + context.User = user + return success(context, req, resp) + } +} + +func GoLoginContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + http.SetCookie(resp, &http.Cookie{ + Name: "redirect", + Value: req.URL.Path, + Path: "/", + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + http.Redirect(resp, req, "/login", http.StatusFound) + return failure(context, req, resp) + } +} + +func RefreshSessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) + if err != nil { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} + +func LogoutContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + sessionCookie, err := req.Cookie("session") + if err == nil && sessionCookie.Value != "" { + _ = database.DeleteSession(context.DBConn, sessionCookie.Value) + } + + http.Redirect(resp, req, "/", http.StatusFound) + http.SetCookie(resp, &http.Cookie{ + Name: "session", + MaxAge: 0, + }) + return success(context, req, resp) + } +} + +func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.User, error) { + userResponse, err := client.Get(uri) + if err != nil { + return nil, err + } + + userStruct, err := createUserFromResponse(userResponse) + if err != nil { + return nil, err + } + + user, err := database.FindOrSaveUser(dbConn, userStruct) + if err != nil { + return nil, err + } + + return user, nil +} + +func createUserFromResponse(response *http.Response) (*database.User, error) { + user := &database.User{ + CreatedAt: time.Now(), + } + + err := json.NewDecoder(response.Body).Decode(user) + defer response.Body.Close() + + if err != nil { + log.Println(err) + return nil, err + } + + user.Username = strings.ToLower(user.Username) + user.Username = strings.Split(user.Username, "@")[0] + + return user, nil +} + +func verifyState(req *http.Request, stateCookieName string, expectedState string) bool { + cookie, err := req.Cookie(stateCookieName) + if err != nil || cookie.Value != expectedState { + return false + } + + return true +} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go new file mode 100644 index 0000000..a6c2a45 --- /dev/null +++ b/api/auth/auth_test.go @@ -0,0 +1,36 @@ +package auth_test + +import ( + "database/sql" + "os" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func setup() (*sql.DB, *types.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +/* +todo: test types key creation ++ api key attached to user ++ user session is unique ++ goLogin goes to page in cookie +*/ diff --git a/api/auth_test.go b/api/auth_test.go deleted file mode 100644 index 45ca12e..0000000 --- a/api/auth_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package api_test - -import ( - "database/sql" - "os" - "testing" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -func setup() (*sql.DB, *api.RequestContext, func()) { - randomDb := utils.RandomId() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - - context := &api.RequestContext{ - DBConn: testDb, - Args: &args.Arguments{}, - TemplateData: &(map[string]interface{}{}), - } - - return testDb, context, func() { - testDb.Close() - os.Remove(randomDb) - } -} - -/* -todo: test api key creation -+ api key attached to user -+ user session is unique -+ goLogin goes to page in cookie -*/ diff --git a/api/dns.go b/api/dns.go deleted file mode 100644 index 7ade6e4..0000000 --- a/api/dns.go +++ /dev/null @@ -1,177 +0,0 @@ -package api - -import ( - "database/sql" - "fmt" - "log" - "net/http" - "strconv" - "strings" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -const MAX_USER_RECORDS = 65 - -var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} - -func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { - ownedByUser := (user.ID == record.UserID) - if !ownedByUser { - return false - } - - if !record.Internal { - for _, format := range ownedInternalDomainFormats { - domain := fmt.Sprintf(format, user.Username) - - isInSubDomain := strings.HasSuffix(record.Name, "."+domain) - if domain == record.Name || isInSubDomain { - return true - } - } - return false - } - - owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) - if err != nil { - log.Println(err) - return false - } - - userIsOwnerOfDomain := owner == user.ID - return ownedByUser && userIsOwnerOfDomain -} - -func ListDNSRecordsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - (*context.TemplateData)["DNSRecords"] = dnsRecords - return success(context, req, resp) - } -} - -func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - formErrors := FormError{ - Errors: []string{}, - } - - internal := req.FormValue("internal") == "on" - name := req.FormValue("name") - if internal && !strings.HasSuffix(name, ".") { - name += "." - } - - recordType := req.FormValue("type") - recordType = strings.ToUpper(recordType) - - recordContent := req.FormValue("content") - ttl := req.FormValue("ttl") - ttlNum, err := strconv.Atoi(ttl) - if err != nil { - formErrors.Errors = append(formErrors.Errors, "invalid ttl") - } - - dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - if dnsRecordCount >= MAX_USER_RECORDS { - formErrors.Errors = append(formErrors.Errors, "max records reached") - } - - dnsRecord := &database.DNSRecord{ - UserID: context.User.ID, - Name: name, - Type: recordType, - Content: recordContent, - TTL: ttlNum, - Internal: internal, - } - - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) { - formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") - } - - if len(formErrors.Errors) == 0 { - if dnsRecord.Internal { - dnsRecord.ID = utils.RandomId() - } else { - dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) - if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, err.Error()) - } - } - } - - if len(formErrors.Errors) == 0 { - _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) - if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, "error saving record") - } - } - - if len(formErrors.Errors) == 0 { - http.Redirect(resp, req, "/dns", http.StatusFound) - return success(context, req, resp) - } - - (*context.TemplateData)["FormError"] = &formErrors - (*context.TemplateData)["RecordForm"] = dnsRecord - - resp.WriteHeader(http.StatusBadRequest) - return failure(context, req, resp) - } - } -} - -func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - recordId := req.FormValue("id") - record, err := database.GetDNSRecord(context.DBConn, recordId) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) { - resp.WriteHeader(http.StatusUnauthorized) - return failure(context, req, resp) - } - - if !record.Internal { - err = dnsAdapter.DeleteDNSRecord(recordId) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - } - - err = database.DeleteDNSRecord(context.DBConn, recordId) - if err != nil { - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - http.Redirect(resp, req, "/dns", http.StatusFound) - return success(context, req, resp) - } - } -} diff --git a/api/dns/dns.go b/api/dns/dns.go new file mode 100644 index 0000000..4805146 --- /dev/null +++ b/api/dns/dns.go @@ -0,0 +1,178 @@ +package dns + +import ( + "database/sql" + "fmt" + "log" + "net/http" + "strconv" + "strings" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +const MAX_USER_RECORDS = 65 + +var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} + +func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { + ownedByUser := (user.ID == record.UserID) + if !ownedByUser { + return false + } + + if !record.Internal { + for _, format := range ownedInternalDomainFormats { + domain := fmt.Sprintf(format, user.Username) + + isInSubDomain := strings.HasSuffix(record.Name, "."+domain) + if domain == record.Name || isInSubDomain { + return true + } + } + return false + } + + owner, err := database.FindFirstDomainOwnerId(dbConn, record.Name) + if err != nil { + log.Println(err) + return false + } + + userIsOwnerOfDomain := owner == user.ID + return ownedByUser && userIsOwnerOfDomain +} + +func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + dnsRecords, err := database.GetUserDNSRecords(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["DNSRecords"] = dnsRecords + return success(context, req, resp) + } +} + +func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + formErrors := types.FormError{ + Errors: []string{}, + } + + internal := req.FormValue("internal") == "on" + name := req.FormValue("name") + if internal && !strings.HasSuffix(name, ".") { + name += "." + } + + recordType := req.FormValue("type") + recordType = strings.ToUpper(recordType) + + recordContent := req.FormValue("content") + ttl := req.FormValue("ttl") + ttlNum, err := strconv.Atoi(ttl) + if err != nil { + formErrors.Errors = append(formErrors.Errors, "invalid ttl") + } + + dnsRecordCount, err := database.CountUserDNSRecords(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if dnsRecordCount >= MAX_USER_RECORDS { + formErrors.Errors = append(formErrors.Errors, "max records reached") + } + + dnsRecord := &database.DNSRecord{ + UserID: context.User.ID, + Name: name, + Type: recordType, + Content: recordContent, + TTL: ttlNum, + Internal: internal, + } + + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) { + formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") + } + + if len(formErrors.Errors) == 0 { + if dnsRecord.Internal { + dnsRecord.ID = utils.RandomId() + } else { + dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, err.Error()) + } + } + } + + if len(formErrors.Errors) == 0 { + _, err := database.SaveDNSRecord(context.DBConn, dnsRecord) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, "error saving record") + } + } + + if len(formErrors.Errors) == 0 { + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) + } + + (*context.TemplateData)["FormError"] = &formErrors + (*context.TemplateData)["RecordForm"] = dnsRecord + + resp.WriteHeader(http.StatusBadRequest) + return failure(context, req, resp) + } + } +} + +func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + recordId := req.FormValue("id") + record, err := database.GetDNSRecord(context.DBConn, recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + if !record.Internal { + err = dnsAdapter.DeleteDNSRecord(recordId) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + } + + err = database.DeleteDNSRecord(context.DBConn, recordId) + if err != nil { + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/dns", http.StatusFound) + return success(context, req, resp) + } + } +} diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go new file mode 100644 index 0000000..cc56120 --- /dev/null +++ b/api/dns/dns_test.go @@ -0,0 +1,63 @@ +package dns_test + +import ( + "database/sql" + "net/http" + "net/http/httptest" + "os" + "testing" + + // "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { + return success(context, req, resp) + } +} + +func setup() (*sql.DB, *types.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +func TestThatOwnerCanPutRecordInDomain(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + _ = &database.User{ + ID: "test", + Username: "test", + } + + records, err := database.GetUserDNSRecords(db, context.User.ID) + if err != nil { + t.Fatal(err) + } + if len(records) > 0 { + t.Errorf("expected no records, got records") + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // dns.CreateDNSRecordContinuation(context, r, w)(IdContinuation, IdContinuation) + })) + defer ts.Close() + +} diff --git a/api/dns_test.go b/api/dns_test.go deleted file mode 100644 index 59dd85b..0000000 --- a/api/dns_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package api_test - -import ( - "database/sql" - "net/http" - "net/http/httptest" - "os" - "testing" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -func setup() (*sql.DB, *api.RequestContext, func()) { - randomDb := utils.RandomId() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - - context := &api.RequestContext{ - DBConn: testDb, - Args: &args.Arguments{}, - TemplateData: &(map[string]interface{}{}), - } - - return testDb, context, func() { - testDb.Close() - os.Remove(randomDb) - } -} - -func TestThatOwnerCanPutRecordInDomain(t *testing.T) { - db, context, cleanup := setup() - defer cleanup() - - testUser := &database.User{ - ID: "test", - Username: "test", - } - - records, err := database.GetUserDNSRecords(db, context.User.ID) - if err != nil { - t.Fatal(err) - } - if len(records) > 0 { - t.Errorf("expected no records, got records") - } - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - api.PutDNSRecordContinuation(context, r, w)(api.IdContinuation, api.IdContinuation) - })) - defer ts.Close() - -} diff --git a/api/guestbook.go b/api/guestbook.go deleted file mode 100644 index ee3c79a..0000000 --- a/api/guestbook.go +++ /dev/null @@ -1,88 +0,0 @@ -package api - -import ( - "log" - "net/http" - "strings" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -type HcaptchaArgs struct { - SiteKey string -} - -func validateGuestbookEntry(entry *database.GuestbookEntry) []string { - errors := []string{} - - if entry.Name == "" { - errors = append(errors, "name is required") - } - - if entry.Message == "" { - errors = append(errors, "message is required") - } - - messageLength := len(entry.Message) - if messageLength > 500 { - errors = append(errors, "message cannot be longer than 500 characters") - } - - newLines := strings.Count(entry.Message, "\n") - if newLines > 10 { - errors = append(errors, "message cannot contain more than 10 new lines") - } - - return errors -} - -func SignGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - name := req.FormValue("name") - message := req.FormValue("message") - - formErrors := FormError{ - Errors: []string{}, - } - - entry := &database.GuestbookEntry{ - ID: utils.RandomId(), - Name: name, - Message: message, - } - formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...) - - if len(formErrors.Errors) == 0 { - _, err := database.SaveGuestbookEntry(context.DBConn, entry) - if err != nil { - log.Println(err) - formErrors.Errors = append(formErrors.Errors, "failed to save entry") - } - } - - if len(formErrors.Errors) > 0 { - (*context.TemplateData)["FormError"] = formErrors - (*context.TemplateData)["EntryForm"] = entry - resp.WriteHeader(http.StatusBadRequest) - - return failure(context, req, resp) - } - - return success(context, req, resp) - } -} - -func ListGuestbookContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - entries, err := database.GetGuestbookEntries(context.DBConn) - if err != nil { - log.Println(err) - resp.WriteHeader(http.StatusInternalServerError) - return failure(context, req, resp) - } - - (*context.TemplateData)["GuestbookEntries"] = entries - return success(context, req, resp) - } -} diff --git a/api/guestbook/guestbook.go b/api/guestbook/guestbook.go new file mode 100644 index 0000000..60a7b4b --- /dev/null +++ b/api/guestbook/guestbook.go @@ -0,0 +1,85 @@ +package guestbook + +import ( + "log" + "net/http" + "strings" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func validateGuestbookEntry(entry *database.GuestbookEntry) []string { + errors := []string{} + + if entry.Name == "" { + errors = append(errors, "name is required") + } + + if entry.Message == "" { + errors = append(errors, "message is required") + } + + messageLength := len(entry.Message) + if messageLength > 500 { + errors = append(errors, "message cannot be longer than 500 characters") + } + + newLines := strings.Count(entry.Message, "\n") + if newLines > 10 { + errors = append(errors, "message cannot contain more than 10 new lines") + } + + return errors +} + +func SignGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + name := req.FormValue("name") + message := req.FormValue("message") + + formErrors := types.FormError{ + Errors: []string{}, + } + + entry := &database.GuestbookEntry{ + ID: utils.RandomId(), + Name: name, + Message: message, + } + formErrors.Errors = append(formErrors.Errors, validateGuestbookEntry(entry)...) + + if len(formErrors.Errors) == 0 { + _, err := database.SaveGuestbookEntry(context.DBConn, entry) + if err != nil { + log.Println(err) + formErrors.Errors = append(formErrors.Errors, "failed to save entry") + } + } + + if len(formErrors.Errors) > 0 { + (*context.TemplateData)["FormError"] = formErrors + (*context.TemplateData)["EntryForm"] = entry + resp.WriteHeader(http.StatusBadRequest) + + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} + +func ListGuestbookContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + entries, err := database.GetGuestbookEntries(context.DBConn) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["GuestbookEntries"] = entries + return success(context, req, resp) + } +} diff --git a/api/guestbook/guestbook_test.go b/api/guestbook/guestbook_test.go new file mode 100644 index 0000000..9fd6c62 --- /dev/null +++ b/api/guestbook/guestbook_test.go @@ -0,0 +1,136 @@ +package guestbook_test + +import ( + "database/sql" + "net/http" + "net/http/httptest" + "os" + "testing" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { + return success(context, req, resp) + } +} + +func setup() (*sql.DB, *types.RequestContext, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + return testDb, context, func() { + testDb.Close() + os.Remove(randomDb) + } +} + +func TestValidGuestbookPutsInDatabase(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + entries, err := database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + if len(entries) > 0 { + t.Errorf("expected no entries, got entries") + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation) + })) + defer ts.Close() + + req := httptest.NewRequest("POST", ts.URL, nil) + req.Form = map[string][]string{ + "name": {"test"}, + "message": {"test"}, + } + + w := httptest.NewRecorder() + ts.Config.Handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("expected status code 200, got %d", w.Code) + } + + entries, err = database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + + if len(entries) != 1 { + t.Errorf("expected 1 entry, got %d", len(entries)) + } + + if entries[0].Name != req.FormValue("name") { + t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name) + } +} + +func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + entries, err := database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + if len(entries) > 0 { + t.Errorf("expected no entries, got entries") + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + guestbook.SignGuestbookContinuation(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n" + invalidRequests := []struct { + name string + message string + }{ + {"", "test"}, + {"test", ""}, + {"", ""}, + {"test", reallyLongStringThatWouldTakeTooMuchSpace}, + } + + for _, form := range invalidRequests { + req := httptest.NewRequest("POST", testServer.URL, nil) + req.Form = map[string][]string{ + "name": {form.name}, + "message": {form.message}, + } + + responseRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(responseRecorder, req) + + if responseRecorder.Code != http.StatusBadRequest { + t.Errorf("expected status code 400, got %d", responseRecorder.Code) + } + } + + entries, err = database.GetGuestbookEntries(db) + if err != nil { + t.Fatal(err) + } + + if len(entries) != 0 { + t.Errorf("expected 0 entries, got %d", len(entries)) + } +} diff --git a/api/guestbook_test.go b/api/guestbook_test.go deleted file mode 100644 index 5c1831f..0000000 --- a/api/guestbook_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package api_test - -import ( - "database/sql" - "net/http" - "net/http/httptest" - "os" - "testing" - - "git.hatecomputers.club/hatecomputers/hatecomputers.club/api" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" -) - -func setup() (*sql.DB, *api.RequestContext, func()) { - randomDb := utils.RandomId() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - - context := &api.RequestContext{ - DBConn: testDb, - Args: &args.Arguments{}, - TemplateData: &(map[string]interface{}{}), - } - - return testDb, context, func() { - testDb.Close() - os.Remove(randomDb) - } -} - -func TestValidGuestbookPutsInDatabase(t *testing.T) { - db, context, cleanup := setup() - defer cleanup() - - entries, err := database.GetGuestbookEntries(db) - if err != nil { - t.Fatal(err) - } - if len(entries) > 0 { - t.Errorf("expected no entries, got entries") - } - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation) - })) - defer ts.Close() - - req := httptest.NewRequest("POST", ts.URL, nil) - req.Form = map[string][]string{ - "name": {"test"}, - "message": {"test"}, - } - - w := httptest.NewRecorder() - ts.Config.Handler.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Errorf("expected status code 200, got %d", w.Code) - } - - entries, err = database.GetGuestbookEntries(db) - if err != nil { - t.Fatal(err) - } - - if len(entries) != 1 { - t.Errorf("expected 1 entry, got %d", len(entries)) - } - - if entries[0].Name != req.FormValue("name") { - t.Errorf("expected name %s, got %s", req.FormValue("name"), entries[0].Name) - } -} - -func TestInvalidGuestbookNotFoundInDatabase(t *testing.T) { - db, context, cleanup := setup() - defer cleanup() - - entries, err := database.GetGuestbookEntries(db) - if err != nil { - t.Fatal(err) - } - if len(entries) > 0 { - t.Errorf("expected no entries, got entries") - } - - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - api.SignGuestbookContinuation(context, r, w)(api.IdContinuation, api.IdContinuation) - })) - defer testServer.Close() - - reallyLongStringThatWouldTakeTooMuchSpace := "a\na\na\na\na\na\na\na\na\na\na\n" - invalidRequests := []struct { - name string - message string - }{ - {"", "test"}, - {"test", ""}, - {"", ""}, - {"test", reallyLongStringThatWouldTakeTooMuchSpace}, - } - - for _, form := range invalidRequests { - req := httptest.NewRequest("POST", testServer.URL, nil) - req.Form = map[string][]string{ - "name": {form.name}, - "message": {form.message}, - } - - responseRecorder := httptest.NewRecorder() - testServer.Config.Handler.ServeHTTP(responseRecorder, req) - - if responseRecorder.Code != http.StatusBadRequest { - t.Errorf("expected status code 400, got %d", responseRecorder.Code) - } - } - - entries, err = database.GetGuestbookEntries(db) - if err != nil { - t.Fatal(err) - } - - if len(entries) != 0 { - t.Errorf("expected 0 entries, got %d", len(entries)) - } -} diff --git a/api/hcaptcha.go b/api/hcaptcha.go deleted file mode 100644 index a310c01..0000000 --- a/api/hcaptcha.go +++ /dev/null @@ -1,69 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "net/http" - "strings" -) - -func verifyCaptcha(secret, response string) error { - verifyURL := "https://hcaptcha.com/siteverify" - body := strings.NewReader("secret=" + secret + "&response=" + response) - - req, err := http.NewRequest("POST", verifyURL, body) - if err != nil { - return err - } - - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return err - } - - jsonResponse := struct { - Success bool `json:"success"` - }{} - err = json.NewDecoder(resp.Body).Decode(&jsonResponse) - if err != nil { - return err - } - - if !jsonResponse.Success { - return fmt.Errorf("hcaptcha verification failed") - } - - defer resp.Body.Close() - return nil -} - -func CaptchaArgsContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - (*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{ - SiteKey: context.Args.HcaptchaSiteKey, - } - return success(context, req, resp) - } -} - -func CaptchaVerificationContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - hCaptchaResponse := req.FormValue("h-captcha-response") - secretKey := context.Args.HcaptchaSecret - - err := verifyCaptcha(secretKey, hCaptchaResponse) - if err != nil { - (*context.TemplateData)["FormError"] = FormError{ - Errors: []string{"hCaptcha verification failed"}, - } - resp.WriteHeader(http.StatusBadRequest) - - return failure(context, req, resp) - } - - return success(context, req, resp) - } -} diff --git a/api/hcaptcha/hcaptcha.go b/api/hcaptcha/hcaptcha.go new file mode 100644 index 0000000..007190d --- /dev/null +++ b/api/hcaptcha/hcaptcha.go @@ -0,0 +1,75 @@ +package hcaptcha + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" +) + +type HcaptchaArgs struct { + SiteKey string +} + +func verifyCaptcha(secret, response string) error { + verifyURL := "https://hcaptcha.com/siteverify" + body := strings.NewReader("secret=" + secret + "&response=" + response) + + req, err := http.NewRequest("POST", verifyURL, body) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + + jsonResponse := struct { + Success bool `json:"success"` + }{} + err = json.NewDecoder(resp.Body).Decode(&jsonResponse) + if err != nil { + return err + } + + if !jsonResponse.Success { + return fmt.Errorf("hcaptcha verification failed") + } + + defer resp.Body.Close() + return nil +} + +func CaptchaArgsContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + (*context.TemplateData)["HcaptchaArgs"] = HcaptchaArgs{ + SiteKey: context.Args.HcaptchaSiteKey, + } + return success(context, req, resp) + } +} + +func CaptchaVerificationContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + hCaptchaResponse := req.FormValue("h-captcha-response") + secretKey := context.Args.HcaptchaSecret + + err := verifyCaptcha(secretKey, hCaptchaResponse) + if err != nil { + (*context.TemplateData)["FormError"] = types.FormError{ + Errors: []string{"hCaptcha verification failed"}, + } + resp.WriteHeader(http.StatusBadRequest) + + return failure(context, req, resp) + } + + return success(context, req, resp) + } +} diff --git a/api/keys/keys.go b/api/keys/keys.go new file mode 100644 index 0000000..ad380fc --- /dev/null +++ b/api/keys/keys.go @@ -0,0 +1,88 @@ +package keys + +import ( + "log" + "net/http" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" +) + +const MAX_USER_API_KEYS = 5 + +func ListAPIKeysContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + typesKeys, err := database.ListUserAPIKeys(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + (*context.TemplateData)["APIKeys"] = typesKeys + return success(context, req, resp) + } +} + +func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + formErrors := types.FormError{ + Errors: []string{}, + } + + numKeys, err := database.CountUserAPIKeys(context.DBConn, context.User.ID) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + if numKeys >= MAX_USER_API_KEYS { + formErrors.Errors = append(formErrors.Errors, "max types keys reached") + } + + if len(formErrors.Errors) > 0 { + (*context.TemplateData)["FormError"] = formErrors + return failure(context, req, resp) + } + + _, err = database.SaveAPIKey(context.DBConn, &database.UserApiKey{ + UserID: context.User.ID, + Key: utils.RandomId(), + }) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + return success(context, req, resp) + } +} + +func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + key := req.FormValue("key") + + typesKey, err := database.GetAPIKey(context.DBConn, key) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + if (typesKey == nil) || (typesKey.UserID != context.User.ID) { + resp.WriteHeader(http.StatusUnauthorized) + return failure(context, req, resp) + } + + err = database.DeleteAPIKey(context.DBConn, key) + if err != nil { + log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) + return failure(context, req, resp) + } + + http.Redirect(resp, req, "/keys", http.StatusFound) + return success(context, req, resp) + } +} diff --git a/api/serve.go b/api/serve.go index 1536f65..6d8c59c 100644 --- a/api/serve.go +++ b/api/serve.go @@ -8,31 +8,19 @@ import ( "time" "git.hatecomputers.club/hatecomputers/hatecomputers.club/adapters/cloudflare" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/guestbook" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/hcaptcha" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/keys" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/template" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" - "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) -type RequestContext struct { - DBConn *sql.DB - Args *args.Arguments - - Id string - Start time.Time - - TemplateData *map[string]interface{} - User *database.User -} - -type FormError struct { - Errors []string -} - -type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain -type ContinuationChain func(Continuation, Continuation) ContinuationChain - -func LogRequestContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, _failure Continuation) ContinuationChain { +func LogRequestContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { context.Start = time.Now() context.Id = utils.RandomId() @@ -41,8 +29,8 @@ func LogRequestContinuation(context *RequestContext, req *http.Request, resp htt } } -func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, _failure Continuation) ContinuationChain { +func LogExecutionTimeContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { end := time.Now() log.Println(context.Id, "took", end.Sub(context.Start)) @@ -51,22 +39,22 @@ func LogExecutionTimeContinuation(context *RequestContext, req *http.Request, re } } -func HealthCheckContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, _failure Continuation) ContinuationChain { +func HealthCheckContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { resp.WriteHeader(200) resp.Write([]byte("healthy")) return success(context, req, resp) } } -func FailurePassingContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(_success Continuation, failure Continuation) ContinuationChain { +func FailurePassingContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(_success types.Continuation, failure types.Continuation) types.ContinuationChain { return failure(context, req, resp) } } -func IdContinuation(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, _failure Continuation) ContinuationChain { +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { return success(context, req, resp) } } @@ -90,8 +78,8 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { ZoneId: argv.CloudflareZone, } - makeRequestContext := func() *RequestContext { - return &RequestContext{ + makeRequestContext := func() *types.RequestContext { + return &types.RequestContext{ DBConn: dbConn, Args: argv, TemplateData: &map[string]interface{}{}, @@ -100,7 +88,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation("home.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) { @@ -110,63 +98,63 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.StartSessionContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /auth", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.InterceptOauthCodeContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /logout", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.LogoutContinuation, FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /dns", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListDNSRecordsContinuation, GoLoginContinuation)(CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(TemplateContinuation("dns.html", true), TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteDNSRecordContinuation(cloudflareAdapter), GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(ListAPIKeysContinuation, GoLoginContinuation)(TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.ListAPIKeysContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("api_keys.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /keys", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CreateAPIKeyContinuation, GoLoginContinuation)(ListAPIKeysContinuation, ListAPIKeysContinuation)(TemplateContinuation("api_keys.html", true), TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.CreateAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(DeleteAPIKeyContinuation, GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(template.TemplateContinuation("guestbook.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /guestbook", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(CaptchaVerificationContinuation, CaptchaVerificationContinuation)(SignGuestbookContinuation, FailurePassingContinuation)(ListGuestbookContinuation, ListGuestbookContinuation)(CaptchaArgsContinuation, CaptchaArgsContinuation)(TemplateContinuation("guestbook.html", true), TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(hcaptcha.CaptchaVerificationContinuation, hcaptcha.CaptchaVerificationContinuation)(guestbook.SignGuestbookContinuation, FailurePassingContinuation)(guestbook.ListGuestbookContinuation, guestbook.ListGuestbookContinuation)(hcaptcha.CaptchaArgsContinuation, hcaptcha.CaptchaArgsContinuation)(template.TemplateContinuation("guestbook.html", true), template.TemplateContinuation("guestbook.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /{name}", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() name := r.PathValue("name") - LogRequestContinuation(requestContext, r, w)(VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(IdContinuation, IdContinuation)(template.TemplateContinuation(name+".html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) return &http.Server{ diff --git a/api/template.go b/api/template.go deleted file mode 100644 index d637c64..0000000 --- a/api/template.go +++ /dev/null @@ -1,74 +0,0 @@ -package api - -import ( - "bytes" - "errors" - "html/template" - "log" - "net/http" - "os" -) - -func renderTemplate(context *RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) { - templatePath := context.Args.TemplatePath - basePath := templatePath + "/base_empty.html" - if showBaseHtml { - basePath = templatePath + "/base.html" - } - - templateLocation := templatePath + "/" + templateName - tmpl, err := template.New("").ParseFiles(templateLocation, basePath) - if err != nil { - return bytes.Buffer{}, err - } - - dataPtr := context.TemplateData - if dataPtr == nil { - dataPtr = &map[string]interface{}{} - } - - data := *dataPtr - if data["User"] == nil { - data["User"] = context.User - } - - var buffer bytes.Buffer - err = tmpl.ExecuteTemplate(&buffer, "base", data) - - if err != nil { - return bytes.Buffer{}, err - } - return buffer, nil -} - -func TemplateContinuation(path string, showBase bool) Continuation { - return func(context *RequestContext, req *http.Request, resp http.ResponseWriter) ContinuationChain { - return func(success Continuation, failure Continuation) ContinuationChain { - html, err := renderTemplate(context, path, true) - if errors.Is(err, os.ErrNotExist) { - resp.WriteHeader(404) - html, err = renderTemplate(context, "404.html", true) - if err != nil { - log.Println("error rendering 404 template", err) - resp.WriteHeader(500) - return failure(context, req, resp) - } - - resp.Header().Set("Content-Type", "text/html") - resp.Write(html.Bytes()) - return failure(context, req, resp) - } - - if err != nil { - log.Println("error rendering template", err) - resp.WriteHeader(500) - resp.Write([]byte("error rendering template")) - return failure(context, req, resp) - } - - resp.Header().Set("Content-Type", "text/html") - resp.Write(html.Bytes()) - return success(context, req, resp) - } - } -} diff --git a/api/template/template.go b/api/template/template.go new file mode 100644 index 0000000..2875649 --- /dev/null +++ b/api/template/template.go @@ -0,0 +1,76 @@ +package template + +import ( + "bytes" + "errors" + "html/template" + "log" + "net/http" + "os" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" +) + +func renderTemplate(context *types.RequestContext, templateName string, showBaseHtml bool) (bytes.Buffer, error) { + templatePath := context.Args.TemplatePath + basePath := templatePath + "/base_empty.html" + if showBaseHtml { + basePath = templatePath + "/base.html" + } + + templateLocation := templatePath + "/" + templateName + tmpl, err := template.New("").ParseFiles(templateLocation, basePath) + if err != nil { + return bytes.Buffer{}, err + } + + dataPtr := context.TemplateData + if dataPtr == nil { + dataPtr = &map[string]interface{}{} + } + + data := *dataPtr + if data["User"] == nil { + data["User"] = context.User + } + + var buffer bytes.Buffer + err = tmpl.ExecuteTemplate(&buffer, "base", data) + + if err != nil { + return bytes.Buffer{}, err + } + return buffer, nil +} + +func TemplateContinuation(path string, showBase bool) types.Continuation { + return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + html, err := renderTemplate(context, path, true) + if errors.Is(err, os.ErrNotExist) { + resp.WriteHeader(404) + html, err = renderTemplate(context, "404.html", true) + if err != nil { + log.Println("error rendering 404 template", err) + resp.WriteHeader(500) + return failure(context, req, resp) + } + + resp.Header().Set("Content-Type", "text/html") + resp.Write(html.Bytes()) + return failure(context, req, resp) + } + + if err != nil { + log.Println("error rendering template", err) + resp.WriteHeader(500) + resp.Write([]byte("error rendering template")) + return failure(context, req, resp) + } + + resp.Header().Set("Content-Type", "text/html") + resp.Write(html.Bytes()) + return success(context, req, resp) + } + } +} diff --git a/api/types/types.go b/api/types/types.go new file mode 100644 index 0000000..bbc25ea --- /dev/null +++ b/api/types/types.go @@ -0,0 +1,28 @@ +package types + +import ( + "database/sql" + "net/http" + "time" + + "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" +) + +type RequestContext struct { + DBConn *sql.DB + Args *args.Arguments + + Id string + Start time.Time + + TemplateData *map[string]interface{} + User *database.User +} + +type FormError struct { + Errors []string +} + +type Continuation func(*RequestContext, *http.Request, http.ResponseWriter) ContinuationChain +type ContinuationChain func(Continuation, Continuation) ContinuationChain -- cgit v1.2.3-70-g09d2 From d9d39a01f24922b6de6ad65ceebcb3da501d2790 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Thu, 4 Apr 2024 15:08:50 -0600 Subject: dns api tests --- api/dns/dns.go | 22 ++- api/dns/dns_test.go | 393 +++++++++++++++++++++++++++++++++++++++++++++++++++- api/serve.go | 6 +- database/dns.go | 23 ++- 4 files changed, 421 insertions(+), 23 deletions(-) diff --git a/api/dns/dns.go b/api/dns/dns.go index 4805146..aa2f356 100644 --- a/api/dns/dns.go +++ b/api/dns/dns.go @@ -14,10 +14,6 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) -const MAX_USER_RECORDS = 65 - -var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} - func userCanFuckWithDNSRecord(dbConn *sql.DB, user *database.User, record *database.DNSRecord, ownedInternalDomainFormats []string) bool { ownedByUser := (user.ID == record.UserID) if !ownedByUser { @@ -60,14 +56,14 @@ func ListDNSRecordsContinuation(context *types.RequestContext, req *http.Request } } -func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { +func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter, maxUserRecords int, allowedUserDomainFormats []string) func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { formErrors := types.FormError{ Errors: []string{}, } - internal := req.FormValue("internal") == "on" + internal := req.FormValue("internal") == "on" || req.FormValue("internal") == "true" name := req.FormValue("name") if internal && !strings.HasSuffix(name, ".") { name += "." @@ -80,6 +76,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun ttl := req.FormValue("ttl") ttlNum, err := strconv.Atoi(ttl) if err != nil { + resp.WriteHeader(http.StatusBadRequest) formErrors.Errors = append(formErrors.Errors, "invalid ttl") } @@ -89,7 +86,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - if dnsRecordCount >= MAX_USER_RECORDS { + if dnsRecordCount >= maxUserRecords { + resp.WriteHeader(http.StatusTooManyRequests) formErrors.Errors = append(formErrors.Errors, "max records reached") } @@ -102,7 +100,8 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun Internal: internal, } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, USER_OWNED_INTERNAL_FMT_DOMAINS) { + if !userCanFuckWithDNSRecord(context.DBConn, context.User, dnsRecord, allowedUserDomainFormats) { + resp.WriteHeader(http.StatusUnauthorized) formErrors.Errors = append(formErrors.Errors, "'name' must end with "+context.User.Username+" or you must be a domain owner for internal domains") } @@ -113,6 +112,7 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun dnsRecord.ID, err = dnsAdapter.CreateDNSRecord(dnsRecord) if err != nil { log.Println(err) + resp.WriteHeader(http.StatusInternalServerError) formErrors.Errors = append(formErrors.Errors, err.Error()) } } @@ -127,14 +127,11 @@ func CreateDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun } if len(formErrors.Errors) == 0 { - http.Redirect(resp, req, "/dns", http.StatusFound) return success(context, req, resp) } (*context.TemplateData)["FormError"] = &formErrors (*context.TemplateData)["RecordForm"] = dnsRecord - - resp.WriteHeader(http.StatusBadRequest) return failure(context, req, resp) } } @@ -151,7 +148,7 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun return failure(context, req, resp) } - if !userCanFuckWithDNSRecord(context.DBConn, context.User, record, USER_OWNED_INTERNAL_FMT_DOMAINS) { + if !(record.UserID == context.User.ID) { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } @@ -171,7 +168,6 @@ func DeleteDNSRecordContinuation(dnsAdapter external_dns.ExternalDNSAdapter) fun return failure(context, req, resp) } - http.Redirect(resp, req, "/dns", http.StatusFound) return success(context, req, resp) } } diff --git a/api/dns/dns_test.go b/api/dns/dns_test.go index cc56120..43dc680 100644 --- a/api/dns/dns_test.go +++ b/api/dns/dns_test.go @@ -2,18 +2,25 @@ package dns_test import ( "database/sql" + "fmt" "net/http" "net/http/httptest" "os" + "strconv" "testing" + "time" - // "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/dns" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) +const MAX_USER_RECORDS = 10 + +var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} + func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { return success(context, req, resp) @@ -26,10 +33,19 @@ func setup() (*sql.DB, *types.RequestContext, func()) { testDb := database.MakeConn(&randomDb) database.Migrate(testDb) + user := &database.User{ + ID: "test", + Username: "test", + Mail: "test@test.com", + DisplayName: "test", + } + database.FindOrSaveUser(testDb, user) + context := &types.RequestContext{ DBConn: testDb, Args: &args.Arguments{}, TemplateData: &(map[string]interface{}{}), + User: user, } return testDb, context, func() { @@ -38,14 +54,33 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } +type SignallingExternalDnsAdapter struct { + AddChannel chan *database.DNSRecord + RmChannel chan string +} + +func (adapter *SignallingExternalDnsAdapter) CreateDNSRecord(record *database.DNSRecord) (string, error) { + id := utils.RandomId() + go func() { adapter.AddChannel <- record }() + + return id, nil +} + +func (adapter *SignallingExternalDnsAdapter) DeleteDNSRecord(id string) error { + go func() { adapter.RmChannel <- id }() + + return nil +} + func TestThatOwnerCanPutRecordInDomain(t *testing.T) { db, context, cleanup := setup() defer cleanup() - _ = &database.User{ - ID: "test", - Username: "test", + domainOwner := &database.DomainOwner{ + UserID: context.User.ID, + Domain: "test.domain.", } + domainOwner, _ = database.SaveDomainOwner(db, domainOwner) records, err := database.GetUserDNSRecords(db, context.User.ID) if err != nil { @@ -55,9 +90,353 @@ func TestThatOwnerCanPutRecordInDomain(t *testing.T) { t.Errorf("expected no records, got records") } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // dns.CreateDNSRecordContinuation(context, r, w)(IdContinuation, IdContinuation) + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + validOwner := httptest.NewRequest("POST", testServer.URL, nil) + validOwner.Form = map[string][]string{ + "internal": {"on"}, + "name": {"new.test.domain."}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"test.domain."}, + } + + validOwnerRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner) + if validOwnerRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", validOwnerRecorder.Code) + } + + validOwnerNonInternalRecorder := httptest.NewRecorder() + validOwner.Form["internal"] = []string{"off"} + testServer.Config.Handler.ServeHTTP(validOwnerNonInternalRecorder, validOwner) + if validOwnerNonInternalRecorder.Code != http.StatusUnauthorized { + t.Errorf("expected invalid return, got %d", validOwnerNonInternalRecorder.Code) + } + + invalidOwnerRecorder := httptest.NewRecorder() + invalidOwner := validOwner + invalidOwner.Form["internal"] = []string{"on"} + invalidOwner.Form["name"] = []string{"new.invalid.domain."} + testServer.Config.Handler.ServeHTTP(invalidOwnerRecorder, invalidOwner) + if invalidOwnerRecorder.Code != http.StatusUnauthorized { + t.Errorf("expected invalid return, got %d", invalidOwnerRecorder.Code) + } +} + +func TestThatUserCanAddToPublicEndpoints(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + responseRecorder := httptest.NewRecorder() + req := httptest.NewRequest("POST", testServer.URL, nil) + fmts := USER_OWNED_INTERNAL_FMT_DOMAINS + for _, format := range fmts { + name := fmt.Sprintf(format, context.User.Username) + + req.Form = map[string][]string{ + "internal": {"off"}, + "name": {name}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"test.domain."}, + } + + testServer.Config.Handler.ServeHTTP(responseRecorder, req) + if responseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", responseRecorder.Code) + } + + namedRecords, _ := database.FindDNSRecords(db, name, "CNAME") + if len(namedRecords) == 0 { + t.Errorf("saved record not found") + } + } +} + +func TestThatExternalDnsSaves(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + responseRecorder := httptest.NewRecorder() + externalRequest := httptest.NewRequest("POST", testServer.URL, nil) + + name := "test." + context.User.Username + externalRequest.Form = map[string][]string{ + "internal": {"off"}, + "name": {name}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"test.domain."}, + } + + testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest) + if responseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", responseRecorder.Code) + } + select { + case res := <-addChannel: + if res.Name != name || res.Type != "CNAME" || res.Content != "test.domain." { + t.Errorf("received the wrong external record") + } + case <-time.After(100 * time.Millisecond): + t.Errorf("timed out in waiting for external addition") + } + + domainOwner := &database.DomainOwner{ + UserID: context.User.ID, + Domain: "test.domain.", + } + domainOwner, _ = database.SaveDomainOwner(db, domainOwner) + internalRequest := externalRequest + internalRequest.Form["internal"] = []string{"on"} + internalRequest.Form["name"] = []string{"test.domain."} + + testServer.Config.Handler.ServeHTTP(responseRecorder, externalRequest) + if responseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", responseRecorder.Code) + } + select { + case _ = <-addChannel: + t.Errorf("expected nothing in the add channel") + case <-time.After(100 * time.Millisecond): + } +} + +func TestThatUserMustOwnRecordToRemove(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + rmChannel := make(chan string) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + RmChannel: rmChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation) })) - defer ts.Close() + defer testServer.Close() + nonOwnerUser := &database.User{ID: "n/a", Username: "testuser"} + _, err := database.FindOrSaveUser(db, nonOwnerUser) + if err != nil { + t.Error(err) + } + + record := &database.DNSRecord{ + ID: "1", + Internal: false, + Name: "test", + Type: "CNAME", + Content: "asdf", + TTL: 1000, + UserID: nonOwnerUser.ID, + } + _, err = database.SaveDNSRecord(db, record) + if err != nil { + t.Error(err) + } + + nonOwnerRecorder := httptest.NewRecorder() + nonOwner := httptest.NewRequest("POST", testServer.URL, nil) + nonOwner.Form = map[string][]string{ + "id": {record.ID}, + } + + testServer.Config.Handler.ServeHTTP(nonOwnerRecorder, nonOwner) + if nonOwnerRecorder.Code != http.StatusUnauthorized { + t.Errorf("expected unauthorized return, got %d", nonOwnerRecorder.Code) + } + + record.UserID = context.User.ID + record.ID = "2" + database.SaveDNSRecord(db, record) + + owner := nonOwner + owner.Form["id"] = []string{"2"} + ownerRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(ownerRecorder, owner) + if ownerRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", ownerRecorder.Code) + } +} + +func TestThatExternalDnsRemoves(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + record := &database.DNSRecord{ + ID: "1", + Internal: false, + Name: "test", + Type: "CNAME", + Content: "asdf", + TTL: 1000, + UserID: context.User.ID, + } + database.SaveDNSRecord(db, record) + + rmChannel := make(chan string) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + RmChannel: rmChannel, + } + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.DeleteDNSRecordContinuation(signallingDnsAdapter)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + externalResponseRecorder := httptest.NewRecorder() + deleteRequest := httptest.NewRequest("POST", testServer.URL, nil) + + deleteRequest.Form = map[string][]string{ + "id": {record.ID}, + } + + testServer.Config.Handler.ServeHTTP(externalResponseRecorder, deleteRequest) + if externalResponseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", externalResponseRecorder.Code) + } + select { + case res := <-rmChannel: + if res != record.ID { + t.Errorf("received the wrong external record") + } + case <-time.After(100 * time.Millisecond): + t.Errorf("timed out in waiting for external addition") + } + + record.Internal = true + record.Name = "test.domain." + database.SaveDNSRecord(db, record) + domainOwner := &database.DomainOwner{ + UserID: context.User.ID, + Domain: "test.domain.", + } + database.SaveDomainOwner(db, domainOwner) + + internalResponseRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(internalResponseRecorder, deleteRequest) + if internalResponseRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", internalResponseRecorder.Code) + } + select { + case _ = <-rmChannel: + t.Errorf("expected nothing in the rmchannel") + case <-time.After(100 * time.Millisecond): + } +} + +func TestRecordCountCannotExceed(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + record := &database.DNSRecord{ + Internal: false, + Name: context.User.Username, + Type: "CNAME", + Content: "asdf", + TTL: 1000, + UserID: context.User.ID, + } + + for i := 1; i <= MAX_USER_RECORDS; i++ { + record.ID = strconv.Itoa(i) + record.Name = record.ID + "." + record.Name + database.SaveDNSRecord(db, record) + } + + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + req := httptest.NewRequest("POST", testServer.URL, nil) + req.Form = map[string][]string{ + "internal": {"off"}, + "name": {record.Name}, + "type": {record.Type}, + "ttl": {"43000"}, + "content": {record.Content}, + } + + recorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(recorder, req) + if recorder.Code != http.StatusTooManyRequests { + t.Errorf("expected too many requests code return, got %d", recorder.Code) + } +} + +func TestInternalRecordAppendsTopLevelDot(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + domainOwner := &database.DomainOwner{ + UserID: context.User.ID, + Domain: "test.internal.", + } + database.SaveDomainOwner(db, domainOwner) + + addChannel := make(chan *database.DNSRecord) + signallingDnsAdapter := &SignallingExternalDnsAdapter{ + AddChannel: addChannel, + } + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dns.CreateDNSRecordContinuation(signallingDnsAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS)(context, r, w)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + validOwner := httptest.NewRequest("POST", testServer.URL, nil) + validOwner.Form = map[string][]string{ + "internal": {"on"}, + "name": {"test.internal"}, + "type": {"CNAME"}, + "ttl": {"43000"}, + "content": {"asdf.internal"}, + } + + validOwnerRecorder := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(validOwnerRecorder, validOwner) + if validOwnerRecorder.Code != http.StatusOK { + t.Errorf("expected valid return, got %d", validOwnerRecorder.Code) + } + + recordsAppendedDot, _ := database.FindDNSRecords(db, "test.internal.", "CNAME") + recordsWithoutDot, _ := database.FindDNSRecords(db, "test.internal", "CNAME") + + if len(recordsAppendedDot) != 1 && len(recordsWithoutDot) != 0 { + t.Errorf("expected dot appended") + } } diff --git a/api/serve.go b/api/serve.go index 6d8c59c..2b0eba4 100644 --- a/api/serve.go +++ b/api/serve.go @@ -116,14 +116,16 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(template.TemplateContinuation("dns.html", true), FailurePassingContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) + const MAX_USER_RECORDS = 100 + var USER_OWNED_INTERNAL_FMT_DOMAINS = []string{"%s", "%s.endpoints"} mux.HandleFunc("POST /dns", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter), FailurePassingContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.ListDNSRecordsContinuation, auth.GoLoginContinuation)(dns.CreateDNSRecordContinuation(cloudflareAdapter, MAX_USER_RECORDS, USER_OWNED_INTERNAL_FMT_DOMAINS), FailurePassingContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("POST /dns/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(dns.DeleteDNSRecordContinuation(cloudflareAdapter), auth.GoLoginContinuation)(dns.ListDNSRecordsContinuation, dns.ListDNSRecordsContinuation)(template.TemplateContinuation("dns.html", true), template.TemplateContinuation("dns.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /keys", func(w http.ResponseWriter, r *http.Request) { diff --git a/database/dns.go b/database/dns.go index fc01347..7851ab4 100644 --- a/database/dns.go +++ b/database/dns.go @@ -9,6 +9,12 @@ import ( "time" ) +type DomainOwner struct { + UserID string `json:"user_id"` + Domain string `json:"domain"` + CreatedAt time.Time `json:"created_at"` +} + type DNSRecord struct { ID string `json:"id"` UserID string `json:"user_id"` @@ -57,7 +63,10 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) { func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) { log.Println("saving dns record", record.ID) - record.CreatedAt = time.Now() + if (record.CreatedAt == time.Time{}) { + record.CreatedAt = time.Now() + } + _, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt) if err != nil { @@ -137,3 +146,15 @@ func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, err return records, nil } + +func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) { + log.Println("saving domain owner", domainOwner.Domain) + + domainOwner.CreatedAt = time.Now() + _, err := db.Exec("INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)", domainOwner.UserID, domainOwner.Domain, domainOwner.CreatedAt) + + if err != nil { + return nil, err + } + return domainOwner, nil +} -- cgit v1.2.3-70-g09d2 From 94984aa4b01e96773b71325b5b27e6f64d9bd102 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Thu, 4 Apr 2024 16:03:34 -0600 Subject: auth test scaffolding --- api/auth/auth.go | 115 +++++++++++++++++++++++++++----------------------- api/auth/auth_test.go | 74 +++++++++++++++++++++++++++++--- 2 files changed, 131 insertions(+), 58 deletions(-) diff --git a/api/auth/auth.go b/api/auth/auth.go index dc348b2..3c633cd 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -35,7 +35,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, - MaxAge: 60, + MaxAge: 200, }) http.SetCookie(resp, &http.Cookie{ Name: "state", @@ -43,7 +43,7 @@ func StartSessionContinuation(context *types.RequestContext, req *http.Request, Path: "/", Secure: true, SameSite: http.SameSiteLaxMode, - MaxAge: 60, + MaxAge: 200, }) http.Redirect(resp, req, url, http.StatusFound) @@ -102,6 +102,16 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req SameSite: http.SameSiteLaxMode, Secure: true, }) + http.SetCookie(resp, &http.Cookie{ + Name: "verifier", + Value: "", + MaxAge: 0, + }) + http.SetCookie(resp, &http.Cookie{ + Name: "state", + Value: "", + MaxAge: 0, + }) redirect := "/" redirectCookie, err := req.Cookie("redirect") @@ -110,6 +120,7 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req http.SetCookie(resp, &http.Cookie{ Name: "redirect", MaxAge: 0, + Value: "", }) } @@ -118,52 +129,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req } } -func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { - if bearerToken == "" { - return nil, nil - } - - parts := strings.Split(bearerToken, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - return nil, nil - } - - typesKey, err := database.GetAPIKey(dbConn, parts[1]) - if err != nil { - return nil, err - } - if typesKey == nil { - return nil, nil - } - - user, err := database.GetUser(dbConn, typesKey.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - -func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { - session, err := database.GetSession(dbConn, sessionId) - if err != nil { - return nil, err - } - - if session.ExpireAt.Before(time.Now()) { - session = nil - database.DeleteSession(dbConn, sessionId) - return nil, fmt.Errorf("session expired") - } - - user, err := database.GetUser(dbConn, session.UserID) - if err != nil { - return nil, err - } - - return user, nil -} - func VerifySessionContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { authHeader := req.Header.Get("Authorization") @@ -179,6 +144,7 @@ func VerifySessionContinuation(context *types.RequestContext, req *http.Request, http.SetCookie(resp, &http.Cookie{ Name: "session", + Value: "", MaxAge: 0, // reset session cookie in case }) @@ -210,13 +176,11 @@ func RefreshSessionContinuation(context *types.RequestContext, req *http.Request return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { sessionCookie, err := req.Cookie("session") if err != nil { - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } _, err = database.RefreshSession(context.DBConn, sessionCookie.Value) if err != nil { - resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } @@ -235,6 +199,7 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, + Value: "", }) return success(context, req, resp) } @@ -246,7 +211,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us return nil, err } - userStruct, err := createUserFromResponse(userResponse) + userStruct, err := createUserFromOauthResponse(userResponse) if err != nil { return nil, err } @@ -259,7 +224,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us return user, nil } -func createUserFromResponse(response *http.Response) (*database.User, error) { +func createUserFromOauthResponse(response *http.Response) (*database.User, error) { user := &database.User{ CreatedAt: time.Now(), } @@ -286,3 +251,49 @@ func verifyState(req *http.Request, stateCookieName string, expectedState string return true } + +func getUserFromAuthHeader(dbConn *sql.DB, bearerToken string) (*database.User, error) { + if bearerToken == "" { + return nil, nil + } + + parts := strings.Split(bearerToken, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + return nil, nil + } + + key, err := database.GetAPIKey(dbConn, parts[1]) + if err != nil { + return nil, err + } + if key == nil { + return nil, nil + } + + user, err := database.GetUser(dbConn, key.UserID) + if err != nil { + return nil, err + } + + return user, nil +} + +func getUserFromSession(dbConn *sql.DB, sessionId string) (*database.User, error) { + session, err := database.GetSession(dbConn, sessionId) + if err != nil { + return nil, err + } + + if session.ExpireAt.Before(time.Now()) { + session = nil + database.DeleteSession(dbConn, sessionId) + return nil, fmt.Errorf("session expired") + } + + user, err := database.GetUser(dbConn, session.UserID) + if err != nil { + return nil, err + } + + return user, nil +} diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index a6c2a45..caaedf1 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -2,14 +2,24 @@ package auth_test import ( "database/sql" + "net/http" + "net/http/httptest" "os" + "testing" + "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" ) +func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, _failure types.Continuation) types.ContinuationChain { + return success(context, req, resp) + } +} + func setup() (*sql.DB, *types.RequestContext, func()) { randomDb := utils.RandomId() @@ -28,9 +38,61 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } -/* -todo: test types key creation -+ api key attached to user -+ user session is unique -+ goLogin goes to page in cookie -*/ +func TestLoginSendsYouToRedirect(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + user := &database.User{ + ID: "test", + Username: "test", + } + database.FindOrSaveUser(db, user) + + session, _ := database.MakeUserSessionFor(db, user) + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + })) + defer testServer.Close() + + protectedPath := testServer.URL + "/protected-path" + req := httptest.NewRequest("GET", protectedPath, nil) + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + + location := resp.Header().Get("Location") + if resp.Code != http.StatusFound && location != "/login" { + t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) + } + + req.AddCookie(&http.Cookie{ + Name: "session", + Value: session.ID, + MaxAge: 60, + }) + resp = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + if resp.Code != http.StatusOK { +} + +func TestOauthFormatsUsername(t *testing.T) { + +} + +func TestSessionIsUnique(t *testing.T) {} + +func TestLogoutClearsCookie(t *testing.T) { + +} + +func TestRefreshUpdatesExpiration(t *testing.T) { + +} + +func TestVerifySessionEnsuresNonExpired(t *testing.T) { + +} + +func TestAPITokensAreEquivalentToSessions(t *testing.T) { + +} -- cgit v1.2.3-70-g09d2 From ae640a253edb5935380975fb07430e910a83b340 Mon Sep 17 00:00:00 2001 From: Elizabeth Date: Fri, 5 Apr 2024 15:43:03 -0600 Subject: add some auth test cases --- api/auth/auth.go | 9 +- api/auth/auth_test.go | 234 ++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 212 insertions(+), 31 deletions(-) diff --git a/api/auth/auth.go b/api/auth/auth.go index 3c633cd..becce24 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -74,7 +74,6 @@ func InterceptOauthCodeContinuation(context *types.RequestContext, req *http.Req reqContext := req.Context() token, err := context.Args.OauthConfig.Exchange(reqContext, code, oauth2.SetAuthURLParam("code_verifier", verifierCookie.Value)) if err != nil { - log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } @@ -195,12 +194,13 @@ func LogoutContinuation(context *types.RequestContext, req *http.Request, resp h _ = database.DeleteSession(context.DBConn, sessionCookie.Value) } - http.Redirect(resp, req, "/", http.StatusFound) http.SetCookie(resp, &http.Cookie{ Name: "session", MaxAge: 0, Value: "", }) + http.Redirect(resp, req, "/", http.StatusFound) + return success(context, req, resp) } } @@ -225,10 +225,7 @@ func getOauthUser(dbConn *sql.DB, client *http.Client, uri string) (*database.Us } func createUserFromOauthResponse(response *http.Response) (*database.User, error) { - user := &database.User{ - CreatedAt: time.Now(), - } - + user := &database.User{} err := json.NewDecoder(response.Body).Decode(user) defer response.Body.Close() diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index caaedf1..1e54099 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -2,9 +2,11 @@ package auth_test import ( "database/sql" + "log" "net/http" "net/http/httptest" "os" + "strings" "testing" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" @@ -12,6 +14,7 @@ import ( "git.hatecomputers.club/hatecomputers/hatecomputers.club/args" "git.hatecomputers.club/hatecomputers/hatecomputers.club/database" "git.hatecomputers.club/hatecomputers/hatecomputers.club/utils" + "golang.org/x/oauth2" ) func IdContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { @@ -38,51 +41,232 @@ func setup() (*sql.DB, *types.RequestContext, func()) { } } -func TestLoginSendsYouToRedirect(t *testing.T) { +func FakedOauthServer() *httptest.Server { + oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/auth" { + code := utils.RandomId() + + state := r.URL.Query().Get("state") + redirectPath := r.URL.Query().Get("redirect_uri") + redirectPath += "?code=" + code + "&state=" + state + + http.Redirect(w, r, redirectPath, http.StatusFound) + } + if r.URL.Path == "/token" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"test","token_type":"bearer","expires_in":3600,"refresh_token":"test","scope":"test"}`)) + } + if r.URL.Path == "/user" { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"sub":"test","name":"test","preferred_username":"test@domain.com"}`)) + } + })) + + return oauthServer +} + +func EchoUsernameContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { + return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { + resp.Write([]byte(context.User.Username)) + return success(context, req, resp) + } +} + +func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/protected-path" { + auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/login" { + log.Println("login") + auth.StartSessionContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/callback" { + log.Println("callback") + auth.InterceptOauthCodeContinuation(context, r, w)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/me" { + auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + } + + if r.URL.Path == "/logout" { + auth.LogoutContinuation(context, r, w)(IdContinuation, IdContinuation) + } + })) + return testServer +} + +func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { + return &oauth2.Config{ + ClientID: "test", + ClientSecret: "test", + Scopes: []string{"test"}, + Endpoint: oauth2.Endpoint{ + AuthURL: oauthServerURL + "/auth", + TokenURL: oauthServerURL + "/token", + }, + RedirectURL: testServerURL + "/callback", + }, oauthServerURL + "/user" +} + +func FollowAuthentication( + oauthServer *httptest.Server, + testServer *httptest.Server, + cookies map[string]*http.Cookie, + location string, +) (map[string]*http.Cookie, string) { + resp := httptest.NewRecorder() + resp.Code = 0 + + for resp.Code == 0 || resp.Code == http.StatusFound { + req := httptest.NewRequest("GET", location, nil) + resp = httptest.NewRecorder() + + for _, cookie := range cookies { + req.AddCookie(cookie) + } + if strings.HasPrefix(location, oauthServer.URL) { + oauthServer.Config.Handler.ServeHTTP(resp, req) + } else { + testServer.Config.Handler.ServeHTTP(resp, req) + } + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + if resp.Code == http.StatusFound { + location = resp.Header().Get("Location") + } + } + + return cookies, location +} + +func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { db, context, cleanup := setup() defer cleanup() - user := &database.User{ - ID: "test", - Username: "test", + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + user, _ := database.GetUser(db, "test") + if user != nil { + t.Errorf("expected no user, got user") } - database.FindOrSaveUser(db, user) - session, _ := database.MakeUserSessionFor(db, user) + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth.VerifySessionContinuation(context, r, w)(IdContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) - })) + user, _ = database.GetUser(db, "test") + if user == nil { + t.Errorf("expected a user to be created, could not find user") + } + if user.Username != "test" { + t.Errorf("expected username to be test, got %s", user.Username) + } +} + +func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { + _, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() defer testServer.Close() - protectedPath := testServer.URL + "/protected-path" - req := httptest.NewRequest("GET", protectedPath, nil) + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + req := httptest.NewRequest("GET", "/protected-path", nil) resp := httptest.NewRecorder() testServer.Config.Handler.ServeHTTP(resp, req) - location := resp.Header().Get("Location") - if resp.Code != http.StatusFound && location != "/login" { - t.Errorf("expected redirect code, got %d, to login, got %s", resp.Code, location) + if resp.Code != http.StatusFound && !strings.HasSuffix(location, "/login") { + t.Errorf("expected redirect to /login, got %d and %s", resp.Code, resp.Header().Get("Location")) } - req.AddCookie(&http.Cookie{ - Name: "session", - Value: session.ID, - MaxAge: 60, - }) - resp = httptest.NewRecorder() - testServer.Config.Handler.ServeHTTP(resp, req) - if resp.Code != http.StatusOK { + cookies := make(map[string]*http.Cookie) + cookies, location = FollowAuthentication(oauthServer, testServer, cookies, "/protected-page") + + if !(strings.HasSuffix(location, "/protected-page")) { + t.Errorf("expected to redirect back to /protected-page after login, got %s", location) + } } -func TestOauthFormatsUsername(t *testing.T) { +func TestOauthSetsUniqueSession(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + cookiesAgain := make(map[string]*http.Cookie) + cookiesAgain, _ = FollowAuthentication(oauthServer, testServer, cookiesAgain, "/me") + + sessionOne := cookies["session"].Value + sessionTwo := cookiesAgain["session"].Value + if sessionOne == sessionTwo { + t.Errorf("expected unique session ids, got %s and %s", sessionOne, sessionTwo) + } + session, _ := database.GetSession(db, sessionOne) + if session.UserID != "test" { + t.Errorf("expected session to be associated with user test, got %s", session.UserID) + } } -func TestSessionIsUnique(t *testing.T) {} +func TestLogoutClearsSession(t *testing.T) { + db, context, cleanup := setup() + defer cleanup() + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + defer oauthServer.Close() + defer testServer.Close() -func TestLogoutClearsCookie(t *testing.T) { + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/me") + + req := httptest.NewRequest("GET", "/logout", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + for _, cookie := range resp.Result().Cookies() { + cookies[cookie.Name] = cookie + } + + req = httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp = httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { + t.Errorf("expected redirect to /login after logout, got %d and %s", resp.Code, resp.Header().Get("Location")) + } + + session, _ := database.GetSession(db, cookies["session"].Value) + if session != nil { + t.Errorf("expected session to be deleted, got session") + } } func TestRefreshUpdatesExpiration(t *testing.T) { -- cgit v1.2.3-70-g09d2 From 5177735b835289c8437799536d3654e5ab142fa3 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Sat, 6 Apr 2024 13:36:13 -0600 Subject: finish auth tests --- api/auth/auth_test.go | 116 ++++++++++++++++++++++++++++++-------------------- api/keys/keys.go | 9 ++-- api/serve.go | 2 +- database/users.go | 12 ++++++ 4 files changed, 88 insertions(+), 51 deletions(-) diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index 1e54099..a3d5b16 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -8,6 +8,7 @@ import ( "os" "strings" "testing" + "time" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/auth" "git.hatecomputers.club/hatecomputers/hatecomputers.club/api/types" @@ -23,24 +24,6 @@ func IdContinuation(context *types.RequestContext, req *http.Request, resp http. } } -func setup() (*sql.DB, *types.RequestContext, func()) { - randomDb := utils.RandomId() - - testDb := database.MakeConn(&randomDb) - database.Migrate(testDb) - - context := &types.RequestContext{ - DBConn: testDb, - Args: &args.Arguments{}, - TemplateData: &(map[string]interface{}{}), - } - - return testDb, context, func() { - testDb.Close() - os.Remove(randomDb) - } -} - func FakedOauthServer() *httptest.Server { oauthServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/auth" { @@ -89,7 +72,7 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { } if r.URL.Path == "/me" { - auth.VerifySessionContinuation(context, r, w)(EchoUsernameContinuation, auth.GoLoginContinuation)(IdContinuation, IdContinuation) + auth.VerifySessionContinuation(context, r, w)(auth.RefreshSessionContinuation, auth.GoLoginContinuation)(EchoUsernameContinuation, IdContinuation)(IdContinuation, IdContinuation) } if r.URL.Path == "/logout" { @@ -99,6 +82,30 @@ func MockUserEndpointServer(context *types.RequestContext) *httptest.Server { return testServer } +func setup() (*sql.DB, *types.RequestContext, *httptest.Server, *httptest.Server, func()) { + randomDb := utils.RandomId() + + testDb := database.MakeConn(&randomDb) + database.Migrate(testDb) + + context := &types.RequestContext{ + DBConn: testDb, + Args: &args.Arguments{}, + TemplateData: &(map[string]interface{}{}), + } + + oauthServer := FakedOauthServer() + testServer := MockUserEndpointServer(context) + + return testDb, context, oauthServer, testServer, func() { + oauthServer.Close() + testServer.Close() + + testDb.Close() + os.Remove(randomDb) + } +} + func GetOauthConfig(oauthServerURL string, testServerURL string) (*oauth2.Config, string) { return &oauth2.Config{ ClientID: "test", @@ -146,14 +153,9 @@ func FollowAuthentication( } func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { - db, context, cleanup := setup() + db, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) user, _ := database.GetUser(db, "test") @@ -174,14 +176,9 @@ func TestOauthCreatesUserWithCorrectUsername(t *testing.T) { } func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { - _, context, cleanup := setup() + _, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) req := httptest.NewRequest("GET", "/protected-path", nil) @@ -201,14 +198,9 @@ func TestOauthRedirectsToPreviousLockedPage(t *testing.T) { } func TestOauthSetsUniqueSession(t *testing.T) { - db, context, cleanup := setup() + db, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) @@ -230,14 +222,9 @@ func TestOauthSetsUniqueSession(t *testing.T) { } func TestLogoutClearsSession(t *testing.T) { - db, context, cleanup := setup() + db, context, oauthServer, testServer, cleanup := setup() defer cleanup() - oauthServer := FakedOauthServer() - testServer := MockUserEndpointServer(context) - defer oauthServer.Close() - defer testServer.Close() - context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) cookies := make(map[string]*http.Cookie) @@ -270,13 +257,52 @@ func TestLogoutClearsSession(t *testing.T) { } func TestRefreshUpdatesExpiration(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() + + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) + + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") + + session, _ := database.GetSession(db, cookies["session"].Value) + + req := httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + updatedSession, _ := database.GetSession(db, cookies["session"].Value) + + // if session expiration is greater than or equal to updated session expiration + if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { + t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) + } } func TestVerifySessionEnsuresNonExpired(t *testing.T) { + db, context, oauthServer, testServer, cleanup := setup() + defer cleanup() -} + context.Args.OauthConfig, context.Args.OauthUserInfoURI = GetOauthConfig(oauthServer.URL, testServer.URL) -func TestAPITokensAreEquivalentToSessions(t *testing.T) { + cookies := make(map[string]*http.Cookie) + cookies, _ = FollowAuthentication(oauthServer, testServer, cookies, "/protected-path") + session, _ := database.GetSession(db, cookies["session"].Value) + session.ExpireAt = time.Now().Add(-time.Hour) + database.SaveSession(db, session) + + req := httptest.NewRequest("GET", "/me", nil) + for _, cookie := range cookies { + req.AddCookie(cookie) + } + resp := httptest.NewRecorder() + testServer.Config.Handler.ServeHTTP(resp, req) + + if resp.Code != http.StatusFound && !strings.HasSuffix(resp.Header().Get("Location"), "/login") { + t.Errorf("expected redirect to /login after session expiration, got %d and %s", resp.Code, resp.Header().Get("Location")) + } } diff --git a/api/keys/keys.go b/api/keys/keys.go index ad380fc..cef3f3c 100644 --- a/api/keys/keys.go +++ b/api/keys/keys.go @@ -62,27 +62,26 @@ func CreateAPIKeyContinuation(context *types.RequestContext, req *http.Request, func DeleteAPIKeyContinuation(context *types.RequestContext, req *http.Request, resp http.ResponseWriter) types.ContinuationChain { return func(success types.Continuation, failure types.Continuation) types.ContinuationChain { - key := req.FormValue("key") + apiKey := req.FormValue("key") - typesKey, err := database.GetAPIKey(context.DBConn, key) + key, err := database.GetAPIKey(context.DBConn, apiKey) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - if (typesKey == nil) || (typesKey.UserID != context.User.ID) { + if (key == nil) || (key.UserID != context.User.ID) { resp.WriteHeader(http.StatusUnauthorized) return failure(context, req, resp) } - err = database.DeleteAPIKey(context.DBConn, key) + err = database.DeleteAPIKey(context.DBConn, apiKey) if err != nil { log.Println(err) resp.WriteHeader(http.StatusInternalServerError) return failure(context, req, resp) } - http.Redirect(resp, req, "/keys", http.StatusFound) return success(context, req, resp) } } diff --git a/api/serve.go b/api/serve.go index 2b0eba4..c8775d8 100644 --- a/api/serve.go +++ b/api/serve.go @@ -140,7 +140,7 @@ func MakeServer(argv *args.Arguments, dbConn *sql.DB) *http.Server { mux.HandleFunc("POST /keys/delete", func(w http.ResponseWriter, r *http.Request) { requestContext := makeRequestContext() - LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) + LogRequestContinuation(requestContext, r, w)(auth.VerifySessionContinuation, FailurePassingContinuation)(keys.DeleteAPIKeyContinuation, auth.GoLoginContinuation)(keys.ListAPIKeysContinuation, keys.ListAPIKeysContinuation)(template.TemplateContinuation("api_keys.html", true), template.TemplateContinuation("api_keys.html", true))(LogExecutionTimeContinuation, LogExecutionTimeContinuation)(IdContinuation, IdContinuation) }) mux.HandleFunc("GET /guestbook", func(w http.ResponseWriter, r *http.Request) { diff --git a/database/users.go b/database/users.go index 5cebb8f..6f9456e 100644 --- a/database/users.go +++ b/database/users.go @@ -111,6 +111,18 @@ func DeleteSession(dbConn *sql.DB, sessionId string) error { return nil } +func SaveSession(dbConn *sql.DB, session *UserSession) (*UserSession, error) { + log.Println("saving session", session.ID) + + _, err := dbConn.Exec(`INSERT OR REPLACE INTO user_sessions (id, user_id, expire_at) VALUES (?, ?, ?);`, session.ID, session.UserID, session.ExpireAt) + if err != nil { + log.Println(err) + return nil, err + } + + return session, nil +} + func RefreshSession(dbConn *sql.DB, sessionId string) (*UserSession, error) { newExpireAt := time.Now().Add(ExpiryDuration) -- cgit v1.2.3-70-g09d2 From cad8e2c4ed5e3bab61ff243f8677f8a46eaeafb0 Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Sat, 6 Apr 2024 13:40:46 -0600 Subject: nits --- api/auth/auth.go | 2 +- api/auth/auth_test.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/api/auth/auth.go b/api/auth/auth.go index becce24..0ffbf9c 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -144,7 +144,7 @@ func VerifySessionContinuation(context *types.RequestContext, req *http.Request, http.SetCookie(resp, &http.Cookie{ Name: "session", Value: "", - MaxAge: 0, // reset session cookie in case + MaxAge: 0, }) context.User = nil diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index a3d5b16..5e67c6d 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -276,7 +276,6 @@ func TestRefreshUpdatesExpiration(t *testing.T) { updatedSession, _ := database.GetSession(db, cookies["session"].Value) - // if session expiration is greater than or equal to updated session expiration if session.ExpireAt.After(updatedSession.ExpireAt) || session.ExpireAt.Equal(updatedSession.ExpireAt) { t.Errorf("expected session expiration to be updated, got %s and %s", session.ExpireAt, updatedSession.ExpireAt) } -- cgit v1.2.3-70-g09d2