diff options
| author | Elizabeth Hunt <elizabeth@simponic.xyz> | 2024-04-02 20:26:24 -0600 |
|---|---|---|
| committer | Elizabeth Hunt <elizabeth@simponic.xyz> | 2024-04-02 20:26:24 -0600 |
| commit | 385d4a84eb813ce6f777b6ab10642ad447f93321 (patch) | |
| tree | 39ccac5de9aafe55dfa70e52d39ea78e534006a5 /test/dns_test.go | |
| parent | ce393a5ac1dedaa04a885b5400d66bcbbf794855 (diff) | |
| download | hatecomputers.club-385d4a84eb813ce6f777b6ab10642ad447f93321.tar.gz hatecomputers.club-385d4a84eb813ce6f777b6ab10642ad447f93321.zip | |
fix dns race condition
Diffstat (limited to 'test/dns_test.go')
| -rw-r--r-- | test/dns_test.go | 49 |
1 files changed, 28 insertions, 21 deletions
diff --git a/test/dns_test.go b/test/dns_test.go index 55bb060..2caabe4 100644 --- a/test/dns_test.go +++ b/test/dns_test.go @@ -21,10 +21,10 @@ func destroy(conn *sql.DB, path string) { } func randomPort() int { - return rand.Intn(3000) + 10000 + return rand.Intn(3000) + 1024 } -func setup() (*sql.DB, *dns.Server, int, *string, func()) { +func setup() (*sql.DB, *dns.Server, *string, *sync.Mutex, func()) { randomDb := utils.RandomId() dnsPort := randomPort() @@ -35,32 +35,35 @@ func setup() (*sql.DB, *dns.Server, int, *string, func()) { } database.FindOrSaveUser(testDb, testUser) + waitLock := &sync.Mutex{} server := hcdns.MakeServer(&args.Arguments{ DnsPort: dnsPort, }, testDb) + server.NotifyStartedFunc = func() { + waitLock.Unlock() + } + waitLock.Lock() - waitGroup := sync.WaitGroup{} - waitGroup.Add(1) go func() { server.ListenAndServe() - waitGroup.Done() }() + waitLock.Lock() address := fmt.Sprintf("127.0.0.1:%d", dnsPort) - return testDb, server, dnsPort, &address, func() { + return testDb, server, &address, waitLock, func() { + server.Shutdown() + testDb.Close() os.Remove(randomDb) - - server.Shutdown() - waitGroup.Wait() } } func TestWhenCNAMEIsResolved(t *testing.T) { t.Log("TestWhenCNAMEIsResolved") - testDb, _, _, addr, cleanup := setup() + testDb, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -85,8 +88,8 @@ func TestWhenCNAMEIsResolved(t *testing.T) { qtype := dns.TypeA domain := dns.Fqdn(cname.Name) - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -135,13 +138,14 @@ func TestWhenCNAMEIsResolved(t *testing.T) { func TestWhenNoRecordNxDomain(t *testing.T) { t.Log("TestWhenNoRecordNxDomain") - _, _, _, addr, cleanup := setup() + _, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() qtype := dns.TypeA domain := dns.Fqdn("nonexistant.example.com.") - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -162,8 +166,9 @@ func TestWhenNoRecordNxDomain(t *testing.T) { func TestWhenUnresolvingCNAME(t *testing.T) { t.Log("TestWhenUnresolvingCNAME") - testDb, _, _, addr, cleanup := setup() + testDb, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -178,8 +183,8 @@ func TestWhenUnresolvingCNAME(t *testing.T) { qtype := dns.TypeA domain := dns.Fqdn(cname.Name) - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -216,8 +221,9 @@ func TestWhenUnresolvingCNAME(t *testing.T) { func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { t.Log("TestWhenUnresolvingCNAMEWithMaxDepth") - testDb, _, _, addr, cleanup := setup() + testDb, _, addr, lock, cleanup := setup() defer cleanup() + defer lock.Unlock() cname := &database.DNSRecord{ ID: "1", @@ -232,8 +238,8 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { qtype := dns.TypeA domain := dns.Fqdn(cname.Name) - client := new(dns.Client) - message := new(dns.Msg) + client := &dns.Client{} + message := &dns.Msg{} message.SetQuestion(domain, qtype) in, _, err := client.Exchange(message, *addr) @@ -245,6 +251,7 @@ func TestWhenUnresolvingCNAMEWithMaxDepth(t *testing.T) { if len(in.Answer) > 0 { t.Fatalf("expected 0 answers, got %d", len(in.Answer)) } + if in.Rcode != dns.RcodeServerFailure { t.Fatalf("expected SERVFAIL, got %d", in.Rcode) } |
