summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorElizabeth Hunt <elizabeth@simponic.xyz>2024-04-02 20:26:24 -0600
committerElizabeth Hunt <elizabeth@simponic.xyz>2024-04-02 20:26:24 -0600
commit385d4a84eb813ce6f777b6ab10642ad447f93321 (patch)
tree39ccac5de9aafe55dfa70e52d39ea78e534006a5 /test
parentce393a5ac1dedaa04a885b5400d66bcbbf794855 (diff)
downloadhatecomputers.club-385d4a84eb813ce6f777b6ab10642ad447f93321.tar.gz
hatecomputers.club-385d4a84eb813ce6f777b6ab10642ad447f93321.zip
fix dns race condition
Diffstat (limited to 'test')
-rw-r--r--test/dns_test.go49
1 files changed, 28 insertions, 21 deletions
diff --git a/test/dns_test.go b/test/dns_test.go
index 55bb060..2caabe4 100644
--- a/test/dns_test.go
+++ b/test/dns_test.go
@@ -21,10 +21,10 @@ func destroy(conn *sql.DB, path string) {
}
func randomPort() int {
- return rand.Intn(3000) + 10000
+ return rand.Intn(3000) + 1024
}
-func setup() (*sql.DB, *dns.Server, int, *string, func()) {
+func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) {
randomDb := utils.RandomId()
dnsPort := randomPort()
@@ -35,32 +35,35 @@ func setup() (*sql.DB, *dns.Server, int, *string, func()) {
}
database.FindOrSaveUser(testDb, testUser)
+ waitLock := &sync.Mutex{}
server := hcdns.MakeServer(&args.Arguments{
DnsPort: dnsPort,
}, testDb)
+ server.NotifyStartedFunc = func() {
+ waitLock.Unlock()
+ }
+ waitLock.Lock()
- waitGroup := sync.WaitGroup{}
- waitGroup.Add(1)
go func() {
server.ListenAndServe()
- waitGroup.Done()
}()
+ waitLock.Lock()
address := fmt.Sprintf("127.0.0.1:%d", dnsPort)
- return testDb, server, dnsPort, &address, func() {
+ return testDb, server, &address, waitLock, func() {
+ server.Shutdown()
+
testDb.Close()
os.Remove(randomDb)
-
- server.Shutdown()
- waitGroup.Wait()
}
}
func TestWhenCNAMEIsResolved(t *testing.T) {
t.Log("TestWhenCNAMEIsResolved")
- testDb, _, _, addr, cleanup := setup()
+ testDb, _, addr, lock, cleanup := setup()
defer cleanup()
+ defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -85,8 +88,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
- client := new(dns.Client)
- message := new(dns.Msg)
+ client := &dns.Client{}
+ message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
@@ -135,13 +138,14 @@ func TestWhenCNAMEIsResolved(t *testing.T) {
func TestWhenNoRecordNxDomain(t *testing.T) {
t.Log("TestWhenNoRecordNxDomain")
- _, _, _, addr, cleanup := setup()
+ _, _, addr, lock, cleanup := setup()
defer cleanup()
+ defer lock.Unlock()
qtype := dns.TypeA
domain := dns.Fqdn("nonexistant.example.com.")
- client := new(dns.Client)
- message := new(dns.Msg)
+ client := &dns.Client{}
+ message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
@@ -162,8 +166,9 @@ func TestWhenNoRecordNxDomain(t *testing.T) {
func TestWhenUnresolvingCNAME(t *testing.T) {
t.Log("TestWhenUnresolvingCNAME")
- testDb, _, _, addr, cleanup := setup()
+ testDb, _, addr, lock, cleanup := setup()
defer cleanup()
+ defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -178,8 +183,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
- client := new(dns.Client)
- message := new(dns.Msg)
+ client := &dns.Client{}
+ message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
@@ -216,8 +221,9 @@ func TestWhenUnresolvingCNAME(t *testing.T) {
func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
t.Log("TestWhenUnresolvingCNAMEWithMaxDepth")
- testDb, _, _, addr, cleanup := setup()
+ testDb, _, addr, lock, cleanup := setup()
defer cleanup()
+ defer lock.Unlock()
cname := &database.DNSRecord{
ID: "1",
@@ -232,8 +238,8 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
qtype := dns.TypeA
domain := dns.Fqdn(cname.Name)
- client := new(dns.Client)
- message := new(dns.Msg)
+ client := &dns.Client{}
+ message := &dns.Msg{}
message.SetQuestion(domain, qtype)
in, _, err := client.Exchange(message, *addr)
@@ -245,6 +251,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) {
if len(in.Answer) > 0 {
t.Fatalf("expected 0 answers, got %d", len(in.Answer))
}
+
if in.Rcode != dns.RcodeServerFailure {
t.Fatalf("expected SERVFAIL, got %d", in.Rcode)
}