dns_test.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. package dns
import (
	"context"
	"errors"
	"net"
	"os"
	"reflect"
	"testing"
	"time"

	"github.com/TecharoHQ/anubis/lib/store/memory"
)
  11. // newTestDNS is a helper function to create a new Dns object with an in-memory cache for testing.
  12. func newTestDNS(forwardTTL int, reverseTTL int) *Dns {
  13. ctx := context.Background()
  14. memStore := memory.New(ctx)
  15. cache := NewDNSCache(forwardTTL, reverseTTL, memStore)
  16. return New(ctx, cache)
  17. }
  18. // mockLookupAddr is a mock implementation of the net.LookupAddr function.
  19. func mockLookupAddr(addr string) ([]string, error) {
  20. switch addr {
  21. case "8.8.8.8":
  22. return []string{"dns.google."}, nil
  23. case "1.1.1.1":
  24. return []string{"one.one.one.one."}, nil
  25. case "208.67.222.222":
  26. return []string{"resolver1.opendns.com."}, nil
  27. case "9.9.9.9":
  28. return nil, &net.DNSError{Err: "no such host", Name: "9.9.9.9", IsNotFound: true}
  29. case "1.2.3.4":
  30. return nil, errors.New("unknown error")
  31. default:
  32. return nil, &net.DNSError{Err: "no such host", Name: addr, IsNotFound: true}
  33. }
  34. }
  35. // mockLookupHost is a mock implementation of the net.LookupHost function.
  36. func mockLookupHost(host string) ([]string, error) {
  37. switch host {
  38. case "dns.google":
  39. return []string{"8.8.8.8", "8.8.4.4"}, nil
  40. case "one.one.one.one":
  41. return []string{"1.1.1.1", "1.0.0.1"}, nil
  42. case "resolver1.opendns.com":
  43. return []string{"208.67.222.222"}, nil
  44. case "example.com":
  45. return nil, &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}
  46. default:
  47. return nil, &net.DNSError{Err: "no such host", Name: host, IsNotFound: true}
  48. }
  49. }
  50. func TestMain(m *testing.M) {
  51. // Before all tests
  52. originalLookupAddr := DNSLookupAddr
  53. originalLookupHost := DNSLookupHost
  54. DNSLookupAddr = mockLookupAddr
  55. DNSLookupHost = mockLookupHost
  56. // Run tests
  57. exitCode := m.Run()
  58. // After all tests
  59. DNSLookupAddr = originalLookupAddr
  60. DNSLookupHost = originalLookupHost
  61. // Exit
  62. if exitCode != 0 {
  63. panic(exitCode)
  64. }
  65. }
  66. func TestDns_ArpaReverseIP(t *testing.T) {
  67. d := newTestDNS(0, 0)
  68. tests := []struct {
  69. name string
  70. ip string
  71. want string
  72. wantErr bool
  73. }{
  74. {"ipv4", "192.0.2.1", "1.2.0.192", false},
  75. {"ipv6", "2001:db8::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2", false},
  76. {"invalid ip", "invalid", "invalid", true},
  77. {"ipv4-mapped ipv6", "::ffff:192.0.2.1", "1.2.0.192", false},
  78. }
  79. for _, tt := range tests {
  80. t.Run(tt.name, func(t *testing.T) {
  81. got, err := d.ArpaReverseIP(tt.ip)
  82. if (err != nil) != tt.wantErr {
  83. t.Errorf("ArpaReverseIP() error = %v, wantErr %v", err, tt.wantErr)
  84. return
  85. }
  86. if got != tt.want {
  87. t.Errorf("ArpaReverseIP() = %v, want %v", got, tt.want)
  88. }
  89. })
  90. }
  91. }
  92. func TestDns_ReverseDNS(t *testing.T) {
  93. d := newTestDNS(1, 1) // short TTL for testing cache
  94. // First call - cache miss
  95. t.Run("cache miss", func(t *testing.T) {
  96. got, err := d.ReverseDNS("8.8.8.8")
  97. if err != nil {
  98. t.Fatalf("ReverseDNS() error = %v", err)
  99. }
  100. want := []string{"dns.google"}
  101. if !reflect.DeepEqual(got, want) {
  102. t.Errorf("ReverseDNS() = %v, want %v", got, want)
  103. }
  104. })
  105. // Second call - cache hit
  106. t.Run("cache hit", func(t *testing.T) {
  107. // Temporarily replace lookup function to ensure cache is used
  108. originalLookupAddr := DNSLookupAddr
  109. DNSLookupAddr = func(addr string) ([]string, error) {
  110. return nil, errors.New("should not be called")
  111. }
  112. defer func() { DNSLookupAddr = originalLookupAddr }()
  113. got, err := d.ReverseDNS("8.8.8.8")
  114. if err != nil {
  115. t.Fatalf("ReverseDNS() error = %v", err)
  116. }
  117. want := []string{"dns.google"}
  118. if !reflect.DeepEqual(got, want) {
  119. t.Errorf("ReverseDNS() = %v, want %v", got, want)
  120. }
  121. })
  122. // Test cache expiration
  123. t.Run("cache expiration", func(t *testing.T) {
  124. time.Sleep(2 * time.Second)
  125. // Now the cache should be expired
  126. // We expect the mock to be called again
  127. // To test this we will change the mock to return something different
  128. originalLookupAddr := DNSLookupAddr
  129. DNSLookupAddr = func(addr string) ([]string, error) {
  130. if addr == "8.8.8.8" {
  131. return []string{"expired.google."}, nil
  132. }
  133. return mockLookupAddr(addr)
  134. }
  135. defer func() { DNSLookupAddr = originalLookupAddr }()
  136. got, err := d.ReverseDNS("8.8.8.8")
  137. if err != nil {
  138. t.Fatalf("ReverseDNS() error = %v", err)
  139. }
  140. want := []string{"expired.google"}
  141. if !reflect.DeepEqual(got, want) {
  142. t.Errorf("ReverseDNS() = %v, want %v", got, want)
  143. }
  144. })
  145. // Test not found
  146. t.Run("not found", func(t *testing.T) {
  147. got, err := d.ReverseDNS("9.9.9.9")
  148. if err != nil {
  149. t.Fatalf("ReverseDNS() error = %v", err)
  150. }
  151. if len(got) != 0 {
  152. t.Errorf("ReverseDNS() = %v, want empty slice", got)
  153. }
  154. })
  155. }
  156. func TestDns_LookupHost(t *testing.T) {
  157. d := newTestDNS(1, 1)
  158. t.Run("cache miss", func(t *testing.T) {
  159. got, err := d.LookupHost("dns.google")
  160. if err != nil {
  161. t.Fatalf("LookupHost() error = %v", err)
  162. }
  163. want := []string{"8.8.8.8", "8.8.4.4"}
  164. if !reflect.DeepEqual(got, want) {
  165. t.Errorf("LookupHost() = %v, want %v", got, want)
  166. }
  167. })
  168. t.Run("cache hit", func(t *testing.T) {
  169. originalLookupHost := DNSLookupHost
  170. DNSLookupHost = func(host string) ([]string, error) {
  171. return nil, errors.New("should not be called")
  172. }
  173. defer func() { DNSLookupHost = originalLookupHost }()
  174. got, err := d.LookupHost("dns.google")
  175. if err != nil {
  176. t.Fatalf("LookupHost() error = %v", err)
  177. }
  178. want := []string{"8.8.8.8", "8.8.4.4"}
  179. if !reflect.DeepEqual(got, want) {
  180. t.Errorf("LookupHost() = %v, want %v", got, want)
  181. }
  182. })
  183. t.Run("cache expiration", func(t *testing.T) {
  184. time.Sleep(2 * time.Second)
  185. originalLookupHost := DNSLookupHost
  186. DNSLookupHost = func(host string) ([]string, error) {
  187. if host == "dns.google" {
  188. return []string{"9.9.9.9"}, nil
  189. }
  190. return mockLookupHost(host)
  191. }
  192. defer func() { DNSLookupHost = originalLookupHost }()
  193. got, err := d.LookupHost("dns.google")
  194. if err != nil {
  195. t.Fatalf("LookupHost() error = %v", err)
  196. }
  197. want := []string{"9.9.9.9"}
  198. if !reflect.DeepEqual(got, want) {
  199. t.Errorf("LookupHost() = %v, want %v", got, want)
  200. }
  201. })
  202. t.Run("not found", func(t *testing.T) {
  203. got, err := d.LookupHost("example.com")
  204. if err != nil {
  205. t.Fatalf("LookupHost() error = %v", err)
  206. }
  207. if len(got) != 0 {
  208. t.Errorf("LookupHost() = %v, want empty slice", got)
  209. }
  210. })
  211. }
  212. func TestDns_VerifyFCrDNS(t *testing.T) {
  213. d := newTestDNS(1, 1)
  214. // Helper to convert string to *string
  215. p := func(s string) *string {
  216. return &s
  217. }
  218. tests := []struct {
  219. name string
  220. ip string
  221. pattern *string
  222. want bool
  223. }{
  224. // Cases without pattern
  225. {"valid no pattern", "8.8.8.8", nil, true},
  226. {"valid partial no pattern", "1.1.1.1", nil, true},
  227. {"not found no pattern", "9.9.9.9", nil, true},
  228. {"unknown error no pattern", "1.2.3.4", nil, false},
  229. // Cases with pattern
  230. {"valid match", "8.8.8.8", p(`.*\.google$`), true},
  231. {"valid no match", "8.8.8.8", p(`\.com$`), false},
  232. {"not found with pattern", "9.9.9.9", p(".*"), false},
  233. {"unknown error with pattern", "1.2.3.4", p(".*"), false},
  234. {"invalid pattern", "8.8.8.8", p(`[`), false},
  235. }
  236. for _, tt := range tests {
  237. t.Run(tt.name, func(t *testing.T) {
  238. if got := d.VerifyFCrDNS(tt.ip, tt.pattern); got != tt.want {
  239. t.Errorf("VerifyFCrDNS() = %v, want %v", got, tt.want)
  240. }
  241. })
  242. }
  243. t.Run("reverse cache hit", func(t *testing.T) {
  244. // Prime the cache
  245. if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
  246. t.Fatalf("VerifyFCrDNS() priming failed, got %v, want true", got)
  247. }
  248. // Now test with a failing lookup to ensure cache is used
  249. originalLookupAddr := DNSLookupAddr
  250. DNSLookupAddr = func(addr string) ([]string, error) {
  251. return nil, errors.New("should not be called")
  252. }
  253. defer func() { DNSLookupAddr = originalLookupAddr }()
  254. if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
  255. t.Errorf("VerifyFCrDNS() = %v, want true", got)
  256. }
  257. })
  258. t.Run("forward cache hit", func(t *testing.T) {
  259. // Prime the cache
  260. if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
  261. t.Fatalf("VerifyFCrDNS() priming failed, got %v, want true", got)
  262. }
  263. // Now test with a failing lookup to ensure cache is used
  264. originalLookupHost := DNSLookupHost
  265. DNSLookupHost = func(host string) ([]string, error) {
  266. return nil, errors.New("should not be called")
  267. }
  268. defer func() { DNSLookupHost = originalLookupHost }()
  269. if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
  270. t.Errorf("VerifyFCrDNS() = %v, want true", got)
  271. }
  272. })
  273. }