main.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package main
  2. import (
  3. "context"
  4. "flag"
  5. "log"
  6. "log/slog"
  7. "net"
  8. "net/http"
  9. "net/http/httputil"
  10. "net/url"
  11. "os"
  12. "path/filepath"
  13. "strings"
  14. "time"
  15. "github.com/TecharoHQ/anubis/internal"
  16. "github.com/facebookgo/flagenv"
  17. "github.com/google/uuid"
  18. )
  19. var (
  20. bind = flag.String("bind", ":3004", "port to listen on")
  21. certDir = flag.String("cert-dir", "/xe/pki", "where to read mounted certificates from")
  22. certFname = flag.String("cert-fname", "cert.pem", "certificate filename")
  23. keyFname = flag.String("key-fname", "key.pem", "key filename")
  24. proxyTo = flag.String("proxy-to", "http://localhost:5000", "where to reverse proxy to")
  25. slogLevel = flag.String("slog-level", "info", "logging level")
  26. )
  27. func main() {
  28. flagenv.Parse()
  29. flag.Parse()
  30. internal.InitSlog(*slogLevel)
  31. slog.Info("starting",
  32. "bind", *bind,
  33. "cert-dir", *certDir,
  34. "cert-fname", *certFname,
  35. "key-fname", *keyFname,
  36. "proxy-to", *proxyTo,
  37. )
  38. cert := filepath.Join(*certDir, *certFname)
  39. key := filepath.Join(*certDir, *keyFname)
  40. st, err := os.Stat(cert)
  41. if err != nil {
  42. slog.Error("can't stat cert file", "certFname", cert)
  43. os.Exit(1)
  44. }
  45. lastModified := st.ModTime()
  46. go func(lm time.Time) {
  47. t := time.NewTicker(time.Hour)
  48. defer t.Stop()
  49. for range t.C {
  50. st, err := os.Stat(cert)
  51. if err != nil {
  52. slog.Error("can't stat file", "fname", cert, "err", err)
  53. continue
  54. }
  55. if st.ModTime().After(lm) {
  56. slog.Info("new cert detected", "oldTime", lm.Format(time.RFC3339), "newTime", st.ModTime().Format(time.RFC3339))
  57. os.Exit(0)
  58. }
  59. }
  60. }(lastModified)
  61. u, err := url.Parse(*proxyTo)
  62. if err != nil {
  63. log.Fatal(err)
  64. }
  65. h := httputil.NewSingleHostReverseProxy(u)
  66. if u.Scheme == "unix" {
  67. slog.Info("using unix socket proxy")
  68. h = &httputil.ReverseProxy{
  69. Director: func(r *http.Request) {
  70. r.URL.Scheme = "http"
  71. r.URL.Host = r.Host
  72. r.Header.Set("X-Forwarded-Proto", "https")
  73. r.Header.Set("X-Forwarded-Scheme", "https")
  74. r.Header.Set("X-Request-Id", uuid.NewString())
  75. r.Header.Set("X-Scheme", "https")
  76. remoteHost, remotePort, err := net.SplitHostPort(r.Host)
  77. if err == nil {
  78. r.Header.Set("X-Forwarded-Host", remoteHost)
  79. r.Header.Set("X-Forwarded-Port", remotePort)
  80. } else {
  81. r.Header.Set("X-Forwarded-Host", r.Host)
  82. }
  83. host, _, err := net.SplitHostPort(r.RemoteAddr)
  84. if err == nil {
  85. r.Header.Set("X-Real-Ip", host)
  86. }
  87. },
  88. Transport: &http.Transport{
  89. DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
  90. return net.Dial("unix", strings.TrimPrefix(*proxyTo, "unix://"))
  91. },
  92. },
  93. }
  94. }
  95. log.Fatal(
  96. http.ListenAndServeTLS(
  97. *bind,
  98. cert,
  99. key,
  100. h,
  101. ),
  102. )
  103. }