diff options
| author | simponic <simponic@hatecomputers.club> | 2024-03-28 16:58:07 -0400 |
|---|---|---|
| committer | simponic <simponic@hatecomputers.club> | 2024-03-28 16:58:07 -0400 |
| commit | 60fc4ebb599d82f5c7ddaca52f8aba74f0876381 (patch) | |
| tree | abe1eebb6154453cfa67812d7dfc982d758931a0 /database | |
| parent | dee173cc63d3b51d47c1a321096a4963fe458075 (diff) | |
| download | hatecomputers.club-60fc4ebb599d82f5c7ddaca52f8aba74f0876381.tar.gz hatecomputers.club-60fc4ebb599d82f5c7ddaca52f8aba74f0876381.zip | |
internal recursive dns server (#2)
Co-authored-by: Lizzy Hunt <lizzy.hunt@usu.edu>
Reviewed-on: https://git.hatecomputers.club/hatecomputers/hatecomputers.club/pulls/2
Diffstat (limited to 'database')
| -rw-r--r-- | database/dns.go | 59 | ||||
| -rw-r--r-- | database/migrate.go | 22 |
2 files changed, 78 insertions, 3 deletions
diff --git a/database/dns.go b/database/dns.go index bb5c1ef..568653d 100644 --- a/database/dns.go +++ b/database/dns.go @@ -2,8 +2,10 @@ package database import ( "database/sql" + "fmt" _ "github.com/mattn/go-sqlite3" "log" + "strings" "time" ) @@ -14,6 +16,7 @@ type DNSRecord struct { Type string `json:"type"` Content string `json:"content"` TTL int `json:"ttl"` + Internal bool `json:"internal"` CreatedAt time.Time `json:"created_at"` } @@ -29,7 +32,7 @@ func GetUserDNSRecords(db *sql.DB, userID string) ([]DNSRecord, error) { var records []DNSRecord for rows.Next() { var record DNSRecord - err := rows.Scan(&record.ID, &record.UserID, &record.Name, &record.Type, &record.Content, &record.TTL, &record.CreatedAt) + err := rows.Scan(&record.ID, &record.UserID, &record.Name, &record.Type, &record.Content, &record.TTL, &record.Internal, &record.CreatedAt) if err != nil { return nil, err } @@ -43,7 +46,7 @@ func SaveDNSRecord(db *sql.DB, record *DNSRecord) (*DNSRecord, error) { log.Println("saving dns record", record) record.CreatedAt = time.Now() - _, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.CreatedAt) + _, err := db.Exec("INSERT OR REPLACE INTO dns_records (id, user_id, name, type, content, ttl, internal, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", record.ID, record.UserID, record.Name, record.Type, record.Content, record.TTL, record.Internal, record.CreatedAt) if err != nil { return nil, err @@ -56,7 +59,7 @@ func GetDNSRecord(db *sql.DB, recordID string) (*DNSRecord, error) { row := db.QueryRow("SELECT * FROM dns_records WHERE id = ?", recordID) var record DNSRecord - err := row.Scan(&record.ID, &record.UserID, &record.Name, &record.Type, &record.Content, &record.TTL, &record.CreatedAt) + err := row.Scan(&record.ID, &record.UserID, &record.Name, &record.Type, &record.Content, &record.TTL, &record.Internal, &record.CreatedAt) if err != nil { return nil, err } @@ -72,3 +75,53 @@ func DeleteDNSRecord(db *sql.DB, recordID string) error { } return nil } + +func FindFirstDomainOwnerId(db *sql.DB, domain string) (string, error) { + log.Println("finding domain owner for", domain) + + ownerID := "" + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return ownerID, fmt.Errorf("invalid domain; must have at least two parts") + } + + for ownerID == "" { + row := db.QueryRow("SELECT user_id FROM domain_owners WHERE domain = ?", strings.Join(parts, ".")) + err := row.Scan(&ownerID) + + if err != nil { + if len(parts) == 1 { + break + } + parts = parts[1:] + } + } + + if ownerID == "" { + return ownerID, fmt.Errorf("no owner found for domain") + } + return ownerID, nil +} + +func FindDNSRecords(dbConn *sql.DB, name string, qtype string) ([]DNSRecord, error) { + log.Println("finding dns record(s) for", name, qtype) + + rows, err := dbConn.Query("SELECT * FROM dns_records WHERE name = ? AND type = ?", name, qtype) + if err != nil { + return nil, err + } + + defer rows.Close() + + var records []DNSRecord + for rows.Next() { + var record DNSRecord + err := rows.Scan(&record.ID, &record.UserID, &record.Name, &record.Type, &record.Content, &record.TTL, &record.Internal, &record.CreatedAt) + if err != nil { + return nil, err + } + records = append(records, record) + } + + return records, nil +} diff --git a/database/migrate.go b/database/migrate.go index b75c123..de1db4c 100644 --- a/database/migrate.go +++ b/database/migrate.go @@ -57,6 +57,7 @@ func MigrateDNSRecords(dbConn *sql.DB) (*sql.DB, error) { type TEXT NOT NULL, content TEXT NOT NULL, ttl INTEGER NOT NULL, + internal BOOLEAN NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE);`) if err != nil { @@ -65,6 +66,26 @@ func MigrateDNSRecords(dbConn *sql.DB) (*sql.DB, error) { return dbConn, nil } +func MigrateDomainOwners(dbConn *sql.DB) (*sql.DB, error) { + log.Println("migrating domain_owners table") + + _, err := dbConn.Exec(`CREATE TABLE IF NOT EXISTS domain_owners ( + user_id INTEGER NOT NULL, + domain TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + );`) + if err != nil { + return dbConn, err + } + + _, err = dbConn.Exec(`CREATE UNIQUE INDEX IF NOT EXISTS idx_domain_owners_domain ON domain_owners (domain);`) + if err != nil { + return dbConn, err + } + return dbConn, nil +} + func MigrateUserSessions(dbConn *sql.DB) (*sql.DB, error) { log.Println("migrating user_sessions table") @@ -88,6 +109,7 @@ func Migrate(dbConn *sql.DB) (*sql.DB, error) { MigrateUsers, MigrateUserSessions, MigrateApiKeys, + MigrateDomainOwners, MigrateDNSRecords, } |
