checker_test.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. package policy
  2. import (
  3. "errors"
  4. "net/http"
  5. "testing"
  6. )
  7. func TestRemoteAddrChecker(t *testing.T) {
  8. for _, tt := range []struct {
  9. err error
  10. name string
  11. ip string
  12. cidrs []string
  13. ok bool
  14. }{
  15. {
  16. name: "match_ipv4",
  17. cidrs: []string{"0.0.0.0/0"},
  18. ip: "1.1.1.1",
  19. ok: true,
  20. err: nil,
  21. },
  22. {
  23. name: "match_ipv4_in_ipv6",
  24. cidrs: []string{"0.0.0.0/0"},
  25. ip: "::ffff:1.1.1.1",
  26. ok: true,
  27. err: nil,
  28. },
  29. {
  30. name: "match_ipv4_in_ipv6_hex",
  31. cidrs: []string{"0.0.0.0/0"},
  32. ip: "::ffff:101:101",
  33. ok: true,
  34. err: nil,
  35. },
  36. {
  37. name: "match_ipv6",
  38. cidrs: []string{"::/0"},
  39. ip: "cafe:babe::",
  40. ok: true,
  41. err: nil,
  42. },
  43. {
  44. name: "not_match_ipv4",
  45. cidrs: []string{"1.1.1.1/32"},
  46. ip: "1.1.1.2",
  47. ok: false,
  48. err: nil,
  49. },
  50. {
  51. name: "not_match_ipv6",
  52. cidrs: []string{"cafe:babe::/128"},
  53. ip: "cafe:babe:4::/128",
  54. ok: false,
  55. err: nil,
  56. },
  57. {
  58. name: "no_ip_set",
  59. cidrs: []string{"::/0"},
  60. ok: false,
  61. err: ErrMisconfiguration,
  62. },
  63. {
  64. name: "invalid_ip",
  65. cidrs: []string{"::/0"},
  66. ip: "According to all natural laws of aviation",
  67. ok: false,
  68. err: ErrMisconfiguration,
  69. },
  70. } {
  71. t.Run(tt.name, func(t *testing.T) {
  72. rac, err := NewRemoteAddrChecker(tt.cidrs)
  73. if err != nil && !errors.Is(err, tt.err) {
  74. t.Fatalf("creating RemoteAddrChecker failed: %v", err)
  75. }
  76. r, err := http.NewRequest(http.MethodGet, "/", nil)
  77. if err != nil {
  78. t.Fatalf("can't make request: %v", err)
  79. }
  80. if tt.ip != "" {
  81. r.Header.Add("X-Real-Ip", tt.ip)
  82. }
  83. ok, err := rac.Check(r)
  84. if tt.ok != ok {
  85. t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
  86. }
  87. if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
  88. t.Errorf("err: %v, wanted: %v", err, tt.err)
  89. }
  90. })
  91. }
  92. }
  93. func TestHeaderMatchesChecker(t *testing.T) {
  94. for _, tt := range []struct {
  95. err error
  96. name string
  97. header string
  98. rexStr string
  99. reqHeaderKey string
  100. reqHeaderValue string
  101. ok bool
  102. }{
  103. {
  104. name: "match",
  105. header: "Cf-Worker",
  106. rexStr: ".*",
  107. reqHeaderKey: "Cf-Worker",
  108. reqHeaderValue: "true",
  109. ok: true,
  110. err: nil,
  111. },
  112. {
  113. name: "not_match",
  114. header: "Cf-Worker",
  115. rexStr: "false",
  116. reqHeaderKey: "Cf-Worker",
  117. reqHeaderValue: "true",
  118. ok: false,
  119. err: nil,
  120. },
  121. {
  122. name: "not_present",
  123. header: "Cf-Worker",
  124. rexStr: "foobar",
  125. reqHeaderKey: "Something-Else",
  126. reqHeaderValue: "true",
  127. ok: false,
  128. err: nil,
  129. },
  130. {
  131. name: "invalid_regex",
  132. rexStr: "a(b",
  133. err: ErrMisconfiguration,
  134. },
  135. } {
  136. t.Run(tt.name, func(t *testing.T) {
  137. hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr)
  138. if err != nil && !errors.Is(err, tt.err) {
  139. t.Fatalf("creating HeaderMatchesChecker failed")
  140. }
  141. if tt.err != nil && hmc == nil {
  142. return
  143. }
  144. r, err := http.NewRequest(http.MethodGet, "/", nil)
  145. if err != nil {
  146. t.Fatalf("can't make request: %v", err)
  147. }
  148. r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue)
  149. ok, err := hmc.Check(r)
  150. if tt.ok != ok {
  151. t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
  152. }
  153. if err != nil && tt.err != nil && !errors.Is(err, tt.err) {
  154. t.Errorf("err: %v, wanted: %v", err, tt.err)
  155. }
  156. })
  157. }
  158. }
  159. func TestHeaderExistsChecker(t *testing.T) {
  160. for _, tt := range []struct {
  161. name string
  162. header string
  163. reqHeader string
  164. ok bool
  165. }{
  166. {
  167. name: "match",
  168. header: "Authorization",
  169. reqHeader: "Authorization",
  170. ok: true,
  171. },
  172. {
  173. name: "not_match",
  174. header: "Authorization",
  175. reqHeader: "Authentication",
  176. },
  177. } {
  178. t.Run(tt.name, func(t *testing.T) {
  179. hec := headerExistsChecker{tt.header}
  180. r, err := http.NewRequest(http.MethodGet, "/", nil)
  181. if err != nil {
  182. t.Fatalf("can't make request: %v", err)
  183. }
  184. r.Header.Set(tt.reqHeader, "hunter2")
  185. ok, err := hec.Check(r)
  186. if tt.ok != ok {
  187. t.Errorf("ok: %v, wanted: %v", ok, tt.ok)
  188. }
  189. if err != nil {
  190. t.Errorf("err: %v", err)
  191. }
  192. })
  193. }
  194. }
  195. func TestPathChecker_XOriginalURI(t *testing.T) {
  196. tests := []struct {
  197. name string
  198. regex string
  199. xOriginalURI string
  200. urlPath string
  201. headerKey string
  202. expectedMatch bool
  203. expectError bool
  204. }{
  205. {
  206. name: "X-Original-URI matches regex (with trailing space - current typo)",
  207. regex: "^/api/.*",
  208. xOriginalURI: "/api/users",
  209. urlPath: "/different/path",
  210. headerKey: "X-Original-URI",
  211. expectedMatch: true,
  212. expectError: false,
  213. },
  214. {
  215. name: "X-Original-URI doesn't match, falls back to URL.Path",
  216. regex: "^/admin/.*",
  217. xOriginalURI: "/api/users",
  218. urlPath: "/admin/dashboard",
  219. headerKey: "X-Original-URI",
  220. expectedMatch: true,
  221. expectError: false,
  222. },
  223. {
  224. name: "Neither X-Original-URI nor URL.Path match",
  225. regex: "^/admin/.*",
  226. xOriginalURI: "/api/users",
  227. urlPath: "/public/info",
  228. headerKey: "X-Original-URI ",
  229. expectedMatch: false,
  230. expectError: false,
  231. },
  232. {
  233. name: "Empty X-Original-URI, URL.Path matches",
  234. regex: "^/static/.*",
  235. xOriginalURI: "",
  236. urlPath: "/static/css/style.css",
  237. headerKey: "X-Original-URI",
  238. expectedMatch: true,
  239. expectError: false,
  240. },
  241. {
  242. name: "Complex regex matching X-Original-URI",
  243. regex: `^/api/v[0-9]+/(users|posts)/[0-9]+$`,
  244. xOriginalURI: "/api/v1/users/123",
  245. urlPath: "/different",
  246. headerKey: "X-Original-URI",
  247. expectedMatch: true,
  248. expectError: false,
  249. },
  250. }
  251. for _, tt := range tests {
  252. t.Run(tt.name, func(t *testing.T) {
  253. // Create the PathChecker
  254. pc, err := NewPathChecker(tt.regex)
  255. if err != nil {
  256. if !tt.expectError {
  257. t.Fatalf("NewPathChecker() unexpected error: %v", err)
  258. }
  259. return
  260. }
  261. if tt.expectError {
  262. t.Fatal("NewPathChecker() expected error but got none")
  263. }
  264. req, err := http.NewRequest("GET", "http://example.com"+tt.urlPath, nil)
  265. if err != nil {
  266. t.Fatalf("Failed to create request: %v", err)
  267. }
  268. if tt.xOriginalURI != "" {
  269. req.Header.Set(tt.headerKey, tt.xOriginalURI)
  270. }
  271. match, err := pc.Check(req)
  272. if err != nil {
  273. t.Fatalf("Check() unexpected error: %v", err)
  274. }
  275. if match != tt.expectedMatch {
  276. t.Errorf("Check() = %v, want %v", match, tt.expectedMatch)
  277. }
  278. })
  279. }
  280. }