From f960fbfabdc66c81171cc8b09562b7e426425de5 Mon Sep 17 00:00:00 2001 From: Stephane Bortzmeyer Date: Mon, 16 Apr 2018 14:18:50 +0200 Subject: [PATCH] Implements NSID (option -nsid). Closes #11 --- check-soa.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 3 deletions(-) diff --git a/check-soa.go b/check-soa.go index 31e0f7d..273149e 100644 --- a/check-soa.go +++ b/check-soa.go @@ -6,6 +6,7 @@ package main import ( + "encoding/hex" "errors" "flag" "fmt" @@ -42,11 +43,13 @@ type SOAreply struct { retrieved bool msg string rtt time.Duration + nsid string } type nameServer struct { name string ips []string + fnsid []string globalErrMsg string success []bool errMsg []string @@ -67,6 +70,7 @@ var ( version *bool quiet *bool noedns *bool + nsid *bool bufsize *int tcp *bool nodnssec *bool @@ -149,6 +153,17 @@ func soaQuery(mychan chan SOAreply, zone string, name string, server string) { if !*noedns { m.SetEdns0(uint16(*bufsize), !*nodnssec) } + if *nsid { + o := new(dns.OPT) + o.Hdr.Name = "." // MUST be the root zone, per definition. + o.Hdr.Rrtype = dns.TypeOPT + e := new(dns.EDNS0_NSID) + e.Code = dns.EDNS0NSID + e.Nsid = "" + o.Option = append(o.Option, e) + m.Extra = make([]dns.RR, 1) + m.Extra[0] = o + } m.Id = dns.Id() if *recursion { m.RecursionDesired = true @@ -174,6 +189,18 @@ func soaQuery(mychan chan SOAreply, zone string, name string, server string) { result.msg = fmt.Sprintf("%s", err.Error()) } else { result.rtt = rtt + if *nsid { + for n := range soa.Extra { + if soa.Extra[n].Header().Rrtype == dns.TypeOPT { + for m := range soa.Extra[n].(*dns.OPT).Option { + switch e := soa.Extra[n].(*dns.OPT).Option[m].(type) { + case *dns.EDNS0_NSID: + result.nsid = e.Nsid + } + } + } + } + } if soa.Rcode != dns.RcodeSuccess { result.msg = dns.RcodeToString[soa.Rcode] break @@ -288,13 +315,22 @@ func masterTask(zone string, nameservers map[string]nameServer) (uint, uint, boo } soaResult := <-soaChannel _, present := results[soaResult.name] + fnsid := make([]byte, 0) + if *nsid { + fnsid = make([]byte, hex.DecodedLen(len(soaResult.nsid))) + n, err := hex.Decode(fnsid, []byte(soaResult.nsid)) + if err != nil || n != hex.DecodedLen(len(soaResult.nsid)) { + fnsid = []byte("ERROR IN DECODING") + } + } if !present { results[soaResult.name] = nameServer{name: soaResult.name, ips: make([]string, 0), success: make([]bool, 0), errMsg: make([]string, 0), serial: make([]uint32, 0), - rtts: make([]time.Duration, 0)} + rtts: make([]time.Duration, 0), + fnsid: make([]string, 0)} } if !soaResult.retrieved { results[soaResult.name] = nameServer{name: soaResult.name, @@ -302,7 +338,8 @@ func masterTask(zone string, nameservers map[string]nameServer) (uint, uint, boo success: append(results[soaResult.name].success, false), errMsg: append(results[soaResult.name].errMsg, fmt.Sprintf("%s", soaResult.msg)), serial: append(results[soaResult.name].serial, 0), - rtts: append(results[soaResult.name].rtts, soaResult.rtt)} + rtts: append(results[soaResult.name].rtts, soaResult.rtt), + fnsid: append(results[soaResult.name].fnsid, string(fnsid))} success = false } else { results[soaResult.name] = nameServer{name: soaResult.name, @@ -310,7 +347,8 @@ func masterTask(zone string, nameservers map[string]nameServer) (uint, uint, boo success: append(results[soaResult.name].success, true), errMsg: append(results[soaResult.name].errMsg, ""), serial: append(results[soaResult.name].serial, soaResult.serial), - rtts: append(results[soaResult.name].rtts, soaResult.rtt)} + rtts: append(results[soaResult.name].rtts, soaResult.rtt), + fnsid: append(results[soaResult.name].fnsid, string(fnsid))} } } for name := range nameservers { @@ -338,6 +376,7 @@ func main() { version = flag.Bool("v", false, "Displays version of the code") quiet = flag.Bool("q", false, "Quiet mode, display only errors") noedns = flag.Bool("r", false, "Disable EDNS format") + nsid = flag.Bool("nsid", false, "Enable NSID option") bufsize = flag.Int("b", int(EDNSBUFFERSIZE), "EDNS buffer size") tcp = flag.Bool("tcp", false, "Use TCP") // DNSSEC DO is on by default, to detect firewall or @@ -359,6 +398,11 @@ func main() { flag.Usage() os.Exit(1) } + if *noedns && *nsid { + fmt.Fprintf(os.Stderr, "NSID requires EDNS\n") + flag.Usage() + os.Exit(1) + } if *v4only && *v6only { fmt.Fprintf(os.Stderr, "v4-only or v6-only but not both\n") flag.Usage() @@ -477,6 +521,9 @@ func main() { if *times && result.rtts[i] != 0 { msg = msg + fmt.Sprintf(" (%d ms)", int(float64(result.rtts[i])/1e6)) } + if *nsid && result.fnsid[i] != "" { + msg = msg + fmt.Sprintf(" (NSID %s)", result.fnsid[i]) + } if !*quiet || !result.success[i] { fmt.Printf("\t%s: %s: %s\n", result.ips[i], code, msg) }