Go:编写内网DNS服务器

时间:July 26, 2019 分类:

目录:

依赖的是github.com/miekg/dns

  • server.go 定义 DNS 服务器的 Server 结构体,它实现了 miekg/dns 的 Handler 接口,以 &Server 作为参数传给 dns.ActivateAndServe
  • main.go为主启动流程,用于监听端口,创建DNSServer,接收信号并做对应处理
  • handler.go 中实现 dns.Handler 接口要求的 ServeDNS 方法

写死域名解析的DNS

server.go

package main

import (
    "sync"
)

// Server is the handler value handed to dns.ActivateAndServe.
// In this hard-coded version it carries no state; mu is declared for later
// use and is not locked by any method yet.
type Server struct {
    mu sync.RWMutex // reserved; currently unused
}

// NewServer allocates a zero-valued Server and returns a pointer to it.
func NewServer() *Server {
    s := new(Server)
    return s
}

main.go

package main

import (
    "net"
    "runtime"
    "syscall"
    "os"
    "os/signal"
    "github.com/miekg/dns"
    "github.com/sirupsen/logrus"
)

// log is the package-level logrus logger used by every file in package main.
var log = logrus.New()

// main binds UDP and TCP port 53, serves DNS with a Server in the background,
// and exits on SIGHUP, SIGINT or SIGTERM. Any setup failure is fatal
// (logrus Fatal logs and then exits the process).
func main() {

    runtime.GOMAXPROCS(runtime.NumCPU())

    server := NewServer()

    // Resolve and bind the UDP listener.
    udpAddr, err := net.ResolveUDPAddr("udp", ":53")
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal(
            "resolve udp address error")
    }
    p, err := net.ListenUDP("udp", udpAddr)
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("listen udp error")
    }

    // Resolve and bind the TCP listener. The error must be checked: a nil
    // *net.TCPListener would be passed to ActivateAndServe as a non-nil
    // net.Listener and crash on Accept.
    tcpAddr, err := net.ResolveTCPAddr("tcp", ":53")
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal(
            "resolve tcp address error")
    }
    l, err := net.ListenTCP("tcp", tcpAddr)
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("listen tcp error")
    }

    // Serve in the background; if serving stops with an error, exit instead
    // of lingering as a process that answers nothing.
    go func() {
        if err := dns.ActivateAndServe(l, p, server); err != nil {
            log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("dns server error")
        }
    }()

    // signal.Notify never blocks when delivering, so the channel needs a
    // buffer or a signal arriving at the wrong moment is dropped.
    sig := make(chan os.Signal, 1)
    signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
    for s := range sig {
        // Every watched signal terminates this version of the server.
        log.WithFields(logrus.Fields{"signal": s}).Fatal("exiting")
    }
}

handler.go

package main

import (
    "net"
    "github.com/miekg/dns"
    "github.com/sirupsen/logrus"
)

// 用于返回失败的DNS解析
// HandleFailed answers r with a SERVFAIL response that carries no records.
func (s *Server) HandleFailed(w dns.ResponseWriter, r *dns.Msg) {
    resp := &dns.Msg{}
    resp.SetRcode(r, dns.RcodeServerFailure)
    w.WriteMsg(resp)
}

// 重写的ServeDNS接口,参考https://github.com/miekg/dns/blob/master/server.go#L25
// ServeDNS implements dns.Handler with a hard-coded answer table.
//
// Only the first question is looked at. A queries for www.whysdomain.com get
// 192.168.1.1 and 192.168.1.2; A queries for any other name get 192.168.1.3
// and 192.168.1.4. Every record has a 60-second TTL and the reply is marked
// authoritative. Non-A queries are logged and answered with SERVFAIL.
// The name comparison is exact (case-sensitive) after dropping one trailing dot.
func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
    // Nothing to answer.
    if r == nil || len(r.Question) == 0 {
        return
    }
    q := r.Question[0]

    if q.Qtype != dns.TypeA {
        log.WithFields(logrus.Fields{"Qtype": q.Qtype}).Error("wrong query type")
        s.HandleFailed(w, r)
        return
    }

    const ttl = 60

    // Compare without the trailing root label.
    domain := q.Name
    if domain[len(domain)-1] == '.' {
        domain = domain[:len(domain)-1]
    }

    addresses := []string{"192.168.1.3", "192.168.1.4"}
    if domain == "www.whysdomain.com" {
        addresses = []string{"192.168.1.1", "192.168.1.2"}
    }

    reply := new(dns.Msg)
    reply.SetReply(r)
    reply.Authoritative = true
    for _, addr := range addresses {
        reply.Answer = append(reply.Answer, &dns.A{
            Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl},
            A:   net.ParseIP(addr),
        })
    }
    w.WriteMsg(reply)
}

安装依赖

$ go get github.com/miekg/dns
$ go get github.com/sirupsen/logrus

如果无法通过 go get 获取 golang.org/x 下的依赖包,可以从 GitHub 上的镜像仓库手动克隆,再软链接为 golang.org/x:

$ mkdir -p $GOPATH/src/github.com/golang/ $GOPATH/src/golang.org
$ git clone https://github.com/golang/sys.git $GOPATH/src/github.com/golang/sys
$ git clone https://github.com/golang/net.git $GOPATH/src/github.com/golang/net
$ git clone https://github.com/golang/text.git $GOPATH/src/github.com/golang/text
$ git clone https://github.com/golang/lint.git $GOPATH/src/github.com/golang/lint
$ git clone https://github.com/golang/tools.git $GOPATH/src/github.com/golang/tools
$ git clone https://github.com/golang/crypto.git $GOPATH/src/github.com/golang/crypto
$ ln -s $GOPATH/src/github.com/golang/ $GOPATH/src/golang.org/x

修改 /etc/resolv.conf,将 nameserver 指向本机运行的 DNS 服务:

options timeout:2 attempts:3 rotate single-request-reopen
; generated by /usr/sbin/dhclient-script
nameserver 127.0.0.1

测试dns解析可以看到已经按照写死的dns进行解析了

$ nslookup blog.whysdomain.com
Server:     127.0.0.1
Address:    127.0.0.1#53

Name:   blog.whysdomain.com
Address: 192.168.1.3
Name:   blog.whysdomain.com
Address: 192.168.1.4

$ nslookup www.whysdomain.com
Server:     127.0.0.1
Address:    127.0.0.1#53

Name:   www.whysdomain.com
Address: 192.168.1.1
Name:   www.whysdomain.com
Address: 192.168.1.2

通过数据库提供解析数据

util.go 定义了 getRedisPool 函数,根据配置文件中的 [redis] 配置创建并返回 *redigo.Pool 连接池

package main

import (
    "fmt"
    "time"
    "github.com/sirupsen/logrus"
    redigo "github.com/gomodule/redigo/redis"
)



// getRedisPool builds a Redis connection pool from the [redis] section of config.
//
// It only constructs the pool and performs no network I/O itself; connections
// are made by the Dial function when the pool needs one. Each dial carries
// connect, read and write timeouts. Without them, an unreachable or stalled
// Redis would block the goroutine serving a DNS query indefinitely. The
// values are kept below the resolver timeout used in this setup
// (resolv.conf "timeout:2").
func getRedisPool(config *tomlConfig) *redigo.Pool {
    const (
        connectTimeout = time.Second
        ioTimeout      = 500 * time.Millisecond
    )

    redisAddress := fmt.Sprintf("%s:%d", config.Redis.Host, config.Redis.Port)
    log.WithFields(logrus.Fields{"redisAddress": redisAddress}).Debug("get redis pool")

    return &redigo.Pool{
        MaxActive: config.Redis.MaxActive,
        MaxIdle:   config.Redis.MaxIdle,
        // IdleTimeout is configured in seconds.
        IdleTimeout: time.Duration(config.Redis.IdleTimeout) * time.Second,
        Dial: func() (redigo.Conn, error) {
            return redigo.Dial("tcp", redisAddress,
                redigo.DialConnectTimeout(connectTimeout),
                redigo.DialReadTimeout(ioTimeout),
                redigo.DialWriteTimeout(ioTimeout),
            )
        },
    }
}

修改server,添加数据库相关

server.go

package main

import (
    "sync"
    redigo "github.com/gomodule/redigo/redis"
)

// Server answers DNS queries from records kept in Redis.
type Server struct {
    redisPool *redigo.Pool // connection pool used for record lookups; replaced by reload_config
    ttl       int          // default TTL in seconds when a name has no TTL key in Redis
    mu        sync.RWMutex // reload_config replaces redisPool while holding the write lock
}

// NewServer creates a Server backed by the Redis instance described in config
// and checks that Redis answers PING before returning it.
//
// It returns the PING error if the check fails. In that case the pool created
// here is closed before returning, because nothing else holds a reference to
// it and its connections would otherwise leak.
func NewServer(config *tomlConfig) (*Server, error) {
    redisPool := getRedisPool(config)

    // Probe connectivity once, and hand the connection back before
    // deciding what to do with the pool.
    redisClient := redisPool.Get()
    _, err := redisClient.Do("PING")
    redisClient.Close()
    if err != nil {
        redisPool.Close()
        return nil, err
    }
    return &Server{redisPool: redisPool, ttl: config.BaseDNS.Ttl}, nil
}

config.go用于读取配置,加载配置文件方法

  • load_config通过toml加载配置文件,返回config对象
  • reload_config获取新的配置,然后通过更改Server对象属性的指针实现热加载
package main

import (
    "github.com/BurntSushi/toml"
    "github.com/sirupsen/logrus"
)


type (
    // tomlConfig mirrors the layout of dns.toml.
    tomlConfig struct {
        Redis   redis   // [redis] table
        BaseDNS baseDNS // [baseDNS] table
    }

    // redis holds the settings used to build the Redis connection pool.
    redis struct {
        Host        string
        Port        int
        MaxActive   int // passed to redigo Pool.MaxActive
        MaxIdle     int // passed to redigo Pool.MaxIdle
        IdleTimeout int // seconds; converted to a time.Duration for the pool
    }

    // baseDNS holds server-wide DNS defaults.
    baseDNS struct {
        Ttl int // default answer TTL in seconds
    }
)

// load_config decodes ./dns.toml (relative to the working directory) into a
// tomlConfig.
//
// On success the returned config is never nil, even when the file is empty
// or has no recognized tables; unset fields keep their zero values. On
// failure it returns a nil config together with the decode error.
func load_config() (*tomlConfig, error) {
    const filePath = "./dns.toml"
    // Decode into an allocated struct rather than a nil pointer: decoding into
    // **tomlConfig leaves the pointer nil when the file sets nothing, which
    // would surface later as a nil dereference in the caller.
    config := new(tomlConfig)
    if _, err := toml.DecodeFile(filePath, config); err != nil {
        return nil, err
    }
    return config, nil
}

// reload_config re-reads dns.toml and, if the Redis it points to answers
// PING, installs a new connection pool and default TTL on server.
//
// Any failure is logged and leaves server untouched. The swap is done under
// server.mu. After the swap, the replaced pool is closed so its idle
// connections are released instead of leaking on every reload. redigo
// closes connections that are handed back to a closed pool. A lookup that
// picked up the old pool right before the swap may still fail once.
func reload_config(server *Server) {
    config, err := load_config()
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Error("get config error, reloading abort")
        return
    }

    redisPool := getRedisPool(config)
    redisClient := redisPool.Get()
    _, err = redisClient.Do("PING")
    redisClient.Close()
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Error("get redis connection error, reloading abort")
        // The new pool will never be used; release it.
        redisPool.Close()
        return
    }

    // Swap the pool and default TTL under the write lock.
    server.mu.Lock()
    oldPool := server.redisPool
    server.redisPool = redisPool
    server.ttl = config.BaseDNS.Ttl
    server.mu.Unlock()

    if oldPool != nil {
        oldPool.Close()
    }
    log.Info("reload config succeeded")
}

main.go

package main

import (
    "net"
    "runtime"
    "syscall"
    "os"
    "os/signal"
    "github.com/miekg/dns"
    "github.com/sirupsen/logrus"
)

// log is the package-level logrus logger used by every file in package main.
var log = logrus.New()

// main loads dns.toml, connects to Redis, binds UDP and TCP port 53 and
// serves DNS in the background.
//
// Signals: SIGHUP reloads the configuration in a new goroutine; SIGINT and
// SIGTERM exit. Setup failures are fatal (logrus Fatal logs and exits).
func main() {

    runtime.GOMAXPROCS(runtime.NumCPU())

    config, err := load_config()
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("load config error")
    }
    server, err := NewServer(config)
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("setup server error")
    }

    udpAddr, err := net.ResolveUDPAddr("udp", ":53")
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("resolve udp address error")
    }
    p, err := net.ListenUDP("udp", udpAddr)
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("listen udp error")
    }
    tcpAddr, err := net.ResolveTCPAddr("tcp", ":53")
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("resolve tcp address error")
    }
    // Must be checked: a nil *net.TCPListener would reach ActivateAndServe as
    // a non-nil net.Listener and crash on Accept.
    l, err := net.ListenTCP("tcp", tcpAddr)
    if err != nil {
        log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("listen tcp error")
    }

    // Exit if serving stops with an error rather than idling without a server.
    go func() {
        if err := dns.ActivateAndServe(l, p, server); err != nil {
            log.WithFields(logrus.Fields{"error": err.Error()}).Fatal("dns server error")
        }
    }()

    // signal.Notify never blocks when delivering, so the channel needs a
    // buffer or a signal arriving while we are busy is dropped.
    sig := make(chan os.Signal, 1)
    signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
    for s := range sig {
        switch s {
        // SIGHUP triggers a hot reload of the configuration.
        case syscall.SIGHUP:
            log.WithFields(logrus.Fields{"signal": s}).Info("start reloading config")
            go reload_config(server)
        case syscall.SIGINT, syscall.SIGTERM:
            log.WithFields(logrus.Fields{"signal": s}).Fatal("exiting")
        default:
            log.WithFields(logrus.Fields{"signal": s}).Warn("unrecognized signal")
        }
    }
}

handler.go这边添加了查库返回对应域名解析和对没有解析的域名的错误返回

package main

import (
    "fmt"
    "net"
    "strings"

    "github.com/miekg/dns"
    "github.com/sirupsen/logrus"
    redigo "github.com/gomodule/redigo/redis"
)

// 没有域名解析
// HandleFailed replies to r with SERVFAIL and no records. ServeDNS uses it
// for unsupported query types and when the Redis lookup fails.
func (s *Server) HandleFailed(w dns.ResponseWriter, r *dns.Msg) {
    resp := &dns.Msg{}
    resp.SetRcode(r, dns.RcodeServerFailure)
    w.WriteMsg(resp)
}

// 解析域名
// queryDomain looks up the A-record addresses and TTL for domain in Redis.
//
// Addresses come from the whole list stored at "<domain>:A", in list order.
// The TTL comes from the integer stored at "<domain>:ttl". If that key is
// missing, not an integer, or negative, the server's default TTL is used.
// A domain with no A list yields an empty slice and a nil error. A non-nil
// error means the A-record lookup itself failed.
func (s *Server) queryDomain(domain string) ([]string, int, error) {
    // reload_config swaps Server fields while holding s.mu; take a snapshot
    // under the read lock instead of racing with it.
    s.mu.RLock()
    pool, defaultTTL := s.redisPool, s.ttl
    s.mu.RUnlock()

    redisClient := pool.Get()
    defer redisClient.Close()

    ARecordKey := fmt.Sprintf("%s:%s", domain, "A")
    ttlKey := fmt.Sprintf("%s:%s", domain, "ttl")

    // Fetch every address in the A-record list.
    addresses, err := redigo.Strings(redisClient.Do("LRANGE", ARecordKey, 0, -1))
    if err != nil {
        log.WithFields(logrus.Fields{"ARecordKey": ARecordKey, "error": err.Error()}).Warn("get A record from redis failed")
        return nil, 0, err
    }
    log.WithFields(logrus.Fields{"addresses": addresses}).Info("query domain addresses")

    // Fetch the per-domain TTL.
    ttl, err := redigo.Int(redisClient.Do("GET", ttlKey))
    log.WithFields(logrus.Fields{"ttl": ttl, "err": err}).Info("get ttl")
    // Fall back to the default when there is no usable TTL. A negative value
    // would wrap to a huge number when converted to uint32 for the RR header.
    if err != nil || ttl < 0 {
        ttl = defaultTTL
    }
    return addresses, ttl, nil
}

// ServeDNS implements dns.Handler by answering A queries from Redis.
//
// Only the first question is considered. For the Redis lookup the name is
// lowercased and one trailing dot is removed. The answer records carry the
// name exactly as the client sent it. The reply is authoritative and uses
// the TTL returned by queryDomain. HandleFailed answers (SERVFAIL) when:
//   - the query type is not A,
//   - the Redis lookup fails, or
//   - the name has no valid IPv4 addresses in Redis.
// Exactly one response is written per query.
func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
    if r == nil || len(r.Question) == 0 {
        return
    }
    // Name being queried.
    name := r.Question[0].Name
    log.WithFields(logrus.Fields{"Domain": name}).Info("query domain")
    // Query type.
    qtype := r.Question[0].Qtype
    if qtype != dns.TypeA {
        log.WithFields(logrus.Fields{"Qtype": qtype}).Error("wrong query type")
        s.HandleFailed(w, r)
        return
    }

    // DNS names compare case-insensitively (RFC 4343) and resolvers may send
    // randomized case, but Redis keys are case-sensitive, so normalize.
    // TrimSuffix is also safe on an empty name, unlike indexing len-1.
    domain := strings.ToLower(strings.TrimSuffix(name, "."))

    addresses, ttl, err := s.queryDomain(domain)
    if err != nil {
        // Return here: falling through would write a second response.
        s.HandleFailed(w, r)
        return
    }

    m := new(dns.Msg)
    m.SetReply(r)
    m.Authoritative = true

    // Build one A record per valid IPv4 address.
    for _, address := range addresses {
        ip := net.ParseIP(address).To4()
        if ip == nil {
            // Not an IPv4 address; putting it in an A record would produce
            // a malformed answer.
            log.WithFields(logrus.Fields{"domain": domain, "address": address}).Warn("skip invalid A record")
            continue
        }
        m.Answer = append(m.Answer, &dns.A{
            Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: uint32(ttl)},
            A:   ip,
        })
    }

    // A name with no usable records gets the error response.
    if len(m.Answer) == 0 {
        s.HandleFailed(w, r)
        return
    }
    w.WriteMsg(m)
}

dns.toml

[redis]
host = "172.31.51.204"
port = 6379
maxActive = 100
maxIdle = 50
idleTimeout = 30

[baseDNS]
ttl = 60

添加测试数据

127.0.0.1:6379> RPUSH www.whysdomain.com:A 192.168.1.1
(integer) 1
127.0.0.1:6379> RPUSH www.whysdomain.com:A 192.168.1.2
(integer) 2
127.0.0.1:6379> RPUSH blog.whysdomain.com:A 192.168.2.1
(integer) 1
127.0.0.1:6379> RPUSH blog.whysdomain.com:A 192.168.2.2
(integer) 2
127.0.0.1:6379> SET www.whysdomain.com:ttl 7200
OK
127.0.0.1:6379> LRANGE www.whysdomain.com:A 0 -1
1) "192.168.1.1"
2) "192.168.1.2"
127.0.0.1:6379> LRANGE blog.whysdomain.com:A 0 -1
1) "192.168.2.1"
2) "192.168.2.2"
127.0.0.1:6379> GET www.whysdomain.com:ttl
"7200"

按照代码,www.whysdomain.com解析到的IP地址为192.168.1.1和192.168.1.2,TTL为7200,而blog.whysdomain.com解析到的IP地址为192.168.2.1和192.168.2.2,TTL为默认的60(即dns.toml中[baseDNS]的ttl)

$ dig www.whysdomain.com

; <<>> DiG 9.9.4-RedHat-9.9.4-51.el7_4.1 <<>> www.whysdomain.com
;; global options: +cmd
;; Got answer:
;; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: 33525
;; flags: qr aa rd; QUERY: 1, ANSWER: 2, AUTHORITY: 0, ADDITIONAL: 0
;; WARNING: recursion requested but not available

;; QUESTION SECTION:
;www.whysdomain.com.        IN  A

;; ANSWER SECTION:
www.whysdomain.com. 7200    IN  A   192.168.1.1
www.whysdomain.com. 7200    IN  A   192.168.1.2

;; Query time: 0 msec
;; SERVER: 127.0.0.1#53(127.0.0.1)
;; WHEN: Thu Jul 25 19:50:31 CST 2019
;; MSG SIZE  rcvd: 104

$ dig blog.whysdomain.com

; <<>> DiG 9.9.4-RedHat-9.9.4-51.el7_4.1 <<>> blog.whysdomain.com
;; global options: +cmd
;; Got answer:
;; ->>HEADER<<- opcode: QUERY, status: SERVFAIL, id: 17207
;; flags: qr rd; QUERY: 1, ANSWER: 0, AUTHORITY: 0, ADDITIONAL: 0
;; WARNING: recursion requested but not available

;; QUESTION SECTION:
;blog.whysdomain.com.       IN  A

;; Query time: 0 msec
;; SERVER: 127.0.0.1#53(127.0.0.1)
;; WHEN: Thu Jul 25 19:50:34 CST 2019
;; MSG SIZE  rcvd: 37

余下说明

  • 作为内网DNS,解析量通常很大,应调高进程的最大打开文件数(ulimit -n),例如设为65535
  • 定期备份解析到其他开源DNS服务器防止有BUG
  • 其他解析等也可以单独写处理逻辑
  • 可以写一套程序用于添加修改删除DNS解析,或者直接在这个程序上添加接口