summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorElizabeth Hunt <elizabeth@simponic.xyz>2024-04-02 20:26:24 -0600
committerElizabeth Hunt <elizabeth@simponic.xyz>2024-04-02 20:26:24 -0600
commit385d4a84eb813ce6f777b6ab10642ad447f93321 (patch)
tree39ccac5de9aafe55dfa70e52d39ea78e534006a5
parentce393a5ac1dedaa04a885b5400d66bcbbf794855 (diff)
downloadhatecomputers.club-385d4a84eb813ce6f777b6ab10642ad447f93321.tar.gz
hatecomputers.club-385d4a84eb813ce6f777b6ab10642ad447f93321.zip
fix dns race condition
-rw-r--r--.drone.yml7
-rw-r--r--test/dns_test.go49
2 files changed, 34 insertions, 22 deletions
diff --git a/.drone.yml b/.drone.yml
index b96d25e..d056e69 100644
--- a/.drone.yml
+++ b/.drone.yml
@@ -12,7 +12,7 @@ steps:
trigger:
event:
- - push
+ - pull_request
---
kind: pipeline
@@ -20,6 +20,11 @@ type: docker
name: deploy
steps:
+ - name: run tests
+ image: golang
+ commands:
+ - go build
+ - go test -p 1 -v ./...
- name: docker
image: plugins/docker
settings:
diff --git a/test/dns_test.go b/test/dns_test.go
index 55bb060..2caabe4 100644
--- a/test/dns_test.go
+++ b/test/dns_test.go
@@ -21,10 +21,10 @@ func destroy(conn *sql.DB, path string) {
}
func randomPort() int {
- return rand.Intn(3000) + 10000
+ return rand.Intn(3000) + 1024
}
-func setup() (*sql.DB, *dns.Server, int, *string, func()) {
+func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
randomDb := utils.RandomId()
dnsPort := randomPort()
@@ -35,32 +35,35 @@ func setup() (*sql.DB, *dns.Server, int, *string, func()) {
}
database.FindOrSaveUser(testDb, testUser)
+ waitLock := &sync.Mutex{}
server := hcdns.MakeServer(&args.Arguments{
DnsPort: dnsPort,
}, testDb)
+ server.NotifyStartedFunc = func() {
+ waitLock.Unlock()
+ }
+ waitLock.Lock()
- waitGroup := sync.WaitGroup{}
- waitGroup.Add(1)
go func() {
server.ListenAndServe()
- waitGroup.Done()
}()
+ waitLock.Lock()
address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
- return testDb, server, dnsPort, &address, func() {
+ return testDb, server, &address, waitLock, func() {
+ server.Shutdown()
+
testDb.Close()
os.Remove(randomDb)
-
- server.Shutdown()
- waitGroup.Wait()
}
}
func TestWhenCNAMEIsResolved(t *testing.T) {
t.Log("TestWhenCNAMEIsResolved")
- testDb, _, _, addr, cleanup := setup()
+ testDb, _, addr, lock, cleanup := setup()
defer cleanup()
+ defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -85,8 +88,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
- client := new(dns.Client)
- message := new(dns.Msg)
+ client := &dns.Client{}
+ message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
@@ -135,13 +138,14 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
func TestWhenNoRecordNxDomain(t *testing.T) {
t.Log("TestWhenNoRecordNxDomain")
- _, _, _, addr, cleanup := setup()
+ _, _, addr, lock, cleanup := setup()
defer cleanup()
+ defer lock.Unlock()
qtype := dns.TypeA
domain := dns.Fqdn("nonexistant.example.com.")
- client := new(dns.Client)
- message := new(dns.Msg)
+ client := &dns.Client{}
+ message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
@@ -162,8 +166,9 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
func TestWhenUnresolvingCNAME(t *testing.T) {
t.Log("TestWhenUnresolvingCNAME")
- testDb, _, _, addr, cleanup := setup()
+ testDb, _, addr, lock, cleanup := setup()
defer cleanup()
+ defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -178,8 +183,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
- client := new(dns.Client)
- message := new(dns.Msg)
+ client := &dns.Client{}
+ message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
@@ -216,8 +221,9 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
- testDb, _, _, addr, cleanup := setup()
+ testDb, _, addr, lock, cleanup := setup()
defer cleanup()
+ defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -232,8 +238,8 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
- client := new(dns.Client)
- message := new(dns.Msg)
+ client := &dns.Client{}
+ message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
@@ -245,6 +251,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
if len(in.Answer) > 0 {
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
}
+
if in.Rcode != dns.RcodeServerFailure {
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
}