summaryrefslogtreecommitdiff
path: root/database/dns.go
blob: 7851ab41a2888ee2ab992cbdfe3115a8f15ff651 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
package database

import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"strings"
	"time"

	_ "github.com/mattn/go-sqlite3"
)

// DomainOwner associates a domain name with the user that owns it. Values are
// written to the domain_owners table by SaveDomainOwner; lookups go through
// FindFirstDomainOwnerId, which also walks up to parent domains.
type DomainOwner struct {
	UserID    string    `json:"user_id"`
	Domain    string    `json:"domain"`
	// CreatedAt is overwritten with time.Now() by SaveDomainOwner on every save.
	CreatedAt time.Time `json:"created_at"`
}

// DNSRecord is one row of the dns_records table. The fields appear in the same
// order as the columns named by SaveDNSRecord's INSERT, which is also the order
// the read helpers in this file pass to Scan.
type DNSRecord struct {
	ID        string    `json:"id"`
	// UserID is the owning user; the per-user queries filter on this column.
	UserID    string    `json:"user_id"`
	Name      string    `json:"name"`
	// Type is compared against the qtype argument of FindDNSRecords.
	Type      string    `json:"type"`
	Content   string    `json:"content"`
	TTL       int       `json:"ttl"`
	// NOTE(review): the meaning of Internal is not visible in this file — confirm with callers.
	Internal  bool      `json:"internal"`
	// CreatedAt is filled in with time.Now() by SaveDNSRecord when left unset.
	CreatedAt time.Time `json:"created_at"`
}

// CountUserDNSRecords returns the number of dns_records rows whose user_id is
// userID. On any query or scan error it returns a zero count together with
// that error, unwrapped.
func CountUserDNSRecords(db *sql.DB, userID string) (int, error) {
	log.Println("counting dns records for user", userID)

	var total int
	scanErr := db.
		QueryRow("SELECT COUNT(*) FROM dns_records WHERE user_id = ?", userID).
		Scan(&total)
	if scanErr != nil {
		return 0, scanErr
	}
	return total, nil
}

// GetUserDNSRecords returns every DNS record owned by userID, in the order the
// database yields them. A user with no records gets a nil slice and a nil
// error. Errors from the query, a row scan, or row iteration are returned
// unwrapped with a nil slice.
//
// The column list is spelled out so the Scan destinations stay aligned with the
// query regardless of the table's physical column order or columns added later.
// rows.Err is checked after the loop because rows.Next also returns false when
// iteration fails. Without that check, a mid-stream failure would look like a
// successful but truncated result.
func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) {
	log.Println("getting dns records for user", userID)

	rows, err := db.Query("SELECT id, user_id, name, type, content, ttl, internal, created_at FROM dns_records WHERE user_id = ?", userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var records []DNSRecord
	for rows.Next() {
		var record DNSRecord
		if err := rows.Scan(&record.ID, &record.UserID, &record.Name, &record.Type, &record.Content, &record.TTL, &record.Internal, &record.CreatedAt); err != nil {
			return nil, err
		}
		records = append(records, record)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return records, nil
}

// SaveDNSRecord writes record to dns_records using SQLite's INSERT OR REPLACE.
// If record.CreatedAt is unset (the zero instant), it is filled in with the
// current time before the write. The record is modified in place, and the same
// pointer is returned on success. A database error is returned unwrapped with a
// nil record.
//
// A nil record is rejected with an error instead of causing a panic.
func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) {
	if record == nil {
		return nil, errors.New("saving dns record: nil record")
	}
	log.Println("saving dns record", record.ID)

	// IsZero rather than == time.Time{}: == also compares the Location, so a
	// zero instant that carries a location would otherwise be stored as-is.
	if record.CreatedAt.IsZero() {
		record.CreatedAt = time.Now()
	}

	_, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt)
	if err != nil {
		return nil, err
	}
	return record, nil
}

// GetDNSRecord loads the dns_records row whose id is recordID. If no such row
// exists, the error is sql.ErrNoRows, returned unwrapped so callers can compare
// against it directly.
//
// Columns are named explicitly, in the order of the Scan destinations, so the
// query does not depend on the table's physical column order and keeps working
// if dns_records gains a column.
func GetDNSRecord(db *sql.DB, recordID string) (*DNSRecord, error) {
	log.Println("getting dns record", recordID)

	var record DNSRecord
	err := db.QueryRow(
		"SELECT id, user_id, name, type, content, ttl, internal, created_at FROM dns_records WHERE id = ?",
		recordID,
	).Scan(&record.ID, &record.UserID, &record.Name, &record.Type, &record.Content, &record.TTL, &record.Internal, &record.CreatedAt)
	if err != nil {
		return nil, err
	}
	return &record, nil
}

// DeleteDNSRecord removes the dns_records row with the given id. The statement
// result is discarded, so deleting an id that does not exist is not an error;
// only a failure reported by the database is returned.
func DeleteDNSRecord(db *sql.DB, recordID string) error {
	log.Println("deleting dns record", recordID)

	_, execErr := db.Exec("DELETE FROM dns_records WHERE id = ?", recordID)
	return execErr
}

// FindFirstDomainOwnerId returns the user ID that owns domain, or that owns the
// nearest parent of domain with a registered owner.
//
// The lookup starts with the full name and strips one leading label at a time.
// For "a.b.example.com" it tries a.b.example.com, b.example.com, example.com,
// then com. It stops at the first domain_owners row whose user_id is non-empty.
// Names with fewer than two labels are rejected without querying.
//
// A missing row (sql.ErrNoRows) or a row with an empty user_id moves the search
// to the parent. Any other database error is returned immediately, wrapped with
// the name being looked up, rather than being reported as "no owner found".
func FindFirstDomainOwnerId(db *sql.DB, domain string) (string, error) {
	log.Println("finding domain owner for", domain)

	parts := strings.Split(domain, ".")
	if len(parts) < 2 {
		return "", fmt.Errorf("invalid domain; must have at least two parts")
	}

	for i := range parts {
		candidate := strings.Join(parts[i:], ".")

		var ownerID string
		err := db.QueryRow("SELECT user_id FROM domain_owners WHERE domain = ?", candidate).Scan(&ownerID)
		switch {
		case errors.Is(err, sql.ErrNoRows):
			// Nothing registered at this level; try the parent domain.
		case err != nil:
			return "", fmt.Errorf("looking up owner of %q: %w", candidate, err)
		case ownerID == "":
			// A row with an empty user_id names no owner. Treat it like a
			// missing row and move on instead of re-querying the same name.
		default:
			return ownerID, nil
		}
	}

	return "", fmt.Errorf("no owner found for domain")
}

// FindDNSRecords returns the records whose name equals name and whose type
// equals qtype, in the order the database yields them. When nothing matches it
// returns a nil slice and a nil error. Errors from the query, a row scan, or row
// iteration are returned unwrapped with a nil slice.
//
// The column list is spelled out so the Scan destinations stay aligned with the
// query regardless of the table's physical column order or columns added later.
// rows.Err is checked after the loop because rows.Next also returns false when
// iteration fails. Without that check, such a failure would look like a
// successful but truncated result.
func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, error) {
	log.Println("finding dns record(s) for", name, qtype)

	rows, err := dbConn.Query("SELECT id, user_id, name, type, content, ttl, internal, created_at FROM dns_records WHERE name = ? AND type = ?", name, qtype)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var records []DNSRecord
	for rows.Next() {
		var record DNSRecord
		if err := rows.Scan(&record.ID, &record.UserID, &record.Name, &record.Type, &record.Content, &record.TTL, &record.Internal, &record.CreatedAt); err != nil {
			return nil, err
		}
		records = append(records, record)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return records, nil
}

// SaveDomainOwner writes domainOwner to the domain_owners table using SQLite's
// INSERT OR REPLACE. CreatedAt is reset to the current time on every call, even
// if the struct already carried a timestamp. The struct is modified in place,
// and the same pointer is returned on success. On a failed write it returns nil
// and the unwrapped error, but CreatedAt has already been updated by then.
func SaveDomainOwner(db *sql.DB, domainOwner *DomainOwner) (*DomainOwner, error) {
	log.Println("saving domain owner", domainOwner.Domain)

	stamp := time.Now()
	domainOwner.CreatedAt = stamp

	const stmt = "INSERT OR REPLACE INTO domain_owners (user_id, domain, created_at) VALUES (?, ?, ?)"
	if _, err := db.Exec(stmt, domainOwner.UserID, domainOwner.Domain, stamp); err != nil {
		return nil, err
	}
	return domainOwner, nil
}