checker.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package policy
import (
	"errors"
	"fmt"
	"net/http"
	"net/netip"
	"regexp"
	"sort"
	"strings"

	"github.com/TecharoHQ/anubis/internal"
	"github.com/TecharoHQ/anubis/lib/policy/checker"
	"github.com/gaissmai/bart"
)
  13. var (
  14. ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration")
  15. )
// RemoteAddrChecker matches requests by client IP address (taken from the
// X-Real-Ip header) against a set of CIDR prefixes.
type RemoteAddrChecker struct {
	// prefixTable holds every configured CIDR prefix; lookups use Contains.
	prefixTable *bart.Lite
	// hash identifies this checker; it is derived from the comma-joined CIDR
	// strings exactly as they were configured.
	hash string
}
  20. func NewRemoteAddrChecker(cidrs []string) (checker.Impl, error) {
  21. table := new(bart.Lite)
  22. for _, cidr := range cidrs {
  23. prefix, err := netip.ParsePrefix(cidr)
  24. if err != nil {
  25. return nil, fmt.Errorf("%w: range %s not parsing: %w", ErrMisconfiguration, cidr, err)
  26. }
  27. table.Insert(prefix)
  28. }
  29. return &RemoteAddrChecker{
  30. prefixTable: table,
  31. hash: internal.FastHash(strings.Join(cidrs, ",")),
  32. }, nil
  33. }
  34. func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) {
  35. host := r.Header.Get("X-Real-Ip")
  36. if host == "" {
  37. return false, fmt.Errorf("%w: header X-Real-Ip is not set", ErrMisconfiguration)
  38. }
  39. addr, err := netip.ParseAddr(host)
  40. if err != nil {
  41. return false, fmt.Errorf("%w: %s is not an IP address: %w", ErrMisconfiguration, host, err)
  42. }
  43. // Convert IPv4-mapped IPv6 addresses to IPv4
  44. if addr.Is6() && addr.Is4In6() {
  45. addr = addr.Unmap()
  46. }
  47. return rac.prefixTable.Contains(addr), nil
  48. }
  49. func (rac *RemoteAddrChecker) Hash() string {
  50. return rac.hash
  51. }
// HeaderMatchesChecker matches requests whose value for a given header
// matches a regular expression. A missing header is looked up as the empty
// string by http.Header.Get and tested like any other value.
//
// Field order matters: constructors in this file build it with positional
// composite literals.
type HeaderMatchesChecker struct {
	header string         // header name passed to http.Header.Get
	regexp *regexp.Regexp // pattern the header value must match
	hash   string         // identity hash of the checker
}
  57. func NewUserAgentChecker(rexStr string) (checker.Impl, error) {
  58. return NewHeaderMatchesChecker("User-Agent", rexStr)
  59. }
  60. func NewHeaderMatchesChecker(header, rexStr string) (checker.Impl, error) {
  61. rex, err := regexp.Compile(strings.TrimSpace(rexStr))
  62. if err != nil {
  63. return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
  64. }
  65. return &HeaderMatchesChecker{strings.TrimSpace(header), rex, internal.FastHash(header + ": " + rexStr)}, nil
  66. }
  67. func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) {
  68. if hmc.regexp.MatchString(r.Header.Get(hmc.header)) {
  69. return true, nil
  70. }
  71. return false, nil
  72. }
  73. func (hmc *HeaderMatchesChecker) Hash() string {
  74. return hmc.hash
  75. }
// PathChecker matches requests whose path (or X-Original-URI header, when
// present) matches a regular expression.
//
// Field order matters: NewPathChecker builds it with a positional composite
// literal.
type PathChecker struct {
	regexp *regexp.Regexp // pattern the path must match
	hash   string         // identity hash of the checker
}
  80. func NewPathChecker(rexStr string) (checker.Impl, error) {
  81. rex, err := regexp.Compile(strings.TrimSpace(rexStr))
  82. if err != nil {
  83. return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err)
  84. }
  85. return &PathChecker{rex, internal.FastHash(rexStr)}, nil
  86. }
  87. func (pc *PathChecker) Check(r *http.Request) (bool, error) {
  88. originalUrl := r.Header.Get("X-Original-URI")
  89. if originalUrl != "" {
  90. if pc.regexp.MatchString(originalUrl) {
  91. return true, nil
  92. }
  93. }
  94. if pc.regexp.MatchString(r.URL.Path) {
  95. return true, nil
  96. }
  97. return false, nil
  98. }
  99. func (pc *PathChecker) Hash() string {
  100. return pc.hash
  101. }
  102. func NewHeaderExistsChecker(key string) checker.Impl {
  103. return headerExistsChecker{strings.TrimSpace(key)}
  104. }
// headerExistsChecker matches requests that carry a non-empty value for a
// given header.
//
// It is built with positional composite literals in this file, so keep the
// single field in place.
type headerExistsChecker struct {
	header string // header name passed to http.Header.Get
}
  108. func (hec headerExistsChecker) Check(r *http.Request) (bool, error) {
  109. if r.Header.Get(hec.header) != "" {
  110. return true, nil
  111. }
  112. return false, nil
  113. }
  114. func (hec headerExistsChecker) Hash() string {
  115. return internal.FastHash(hec.header)
  116. }
  117. func NewHeadersChecker(headermap map[string]string) (checker.Impl, error) {
  118. var result checker.List
  119. var errs []error
  120. for key, rexStr := range headermap {
  121. if rexStr == ".*" {
  122. result = append(result, headerExistsChecker{strings.TrimSpace(key)})
  123. continue
  124. }
  125. rex, err := regexp.Compile(strings.TrimSpace(rexStr))
  126. if err != nil {
  127. errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err))
  128. continue
  129. }
  130. result = append(result, &HeaderMatchesChecker{key, rex, internal.FastHash(key + ": " + rexStr)})
  131. }
  132. if len(errs) != 0 {
  133. return nil, errors.Join(errs...)
  134. }
  135. return result, nil
  136. }