diff --git a/context.go b/context.go index 138110d..c555f5c 100644 --- a/context.go +++ b/context.go @@ -9,10 +9,13 @@ const ( requestLimitKey ) +// WithIncrement sets the increment value in the context. func WithIncrement(ctx context.Context, value int) context.Context { return context.WithValue(ctx, incrementKey, value) } +// getIncrement gets the increment value from the context, which was set by +// [WithIncrement]. func getIncrement(ctx context.Context) int { if value, ok := ctx.Value(incrementKey).(int); ok { return value @@ -20,10 +23,13 @@ func getIncrement(ctx context.Context) int { return 1 } +// WithRequestLimit sets the request limit in the context. func WithRequestLimit(ctx context.Context, value int) context.Context { return context.WithValue(ctx, requestLimitKey, value) } +// getRequestLimit gets the request limit from the context, which was set by +// [WithRequestLimit]. func getRequestLimit(ctx context.Context) int { if value, ok := ctx.Value(requestLimitKey).(int); ok { return value diff --git a/go.mod b/go.mod index 90b49fa..3b8a098 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,12 @@ go 1.23.0 toolchain go1.24.1 -require github.com/zeebo/xxh3 v1.0.2 - -require golang.org/x/sys v0.30.0 // indirect +require ( + github.com/zeebo/xxh3 v1.0.2 + golang.org/x/sync v0.12.0 +) require ( github.com/klauspost/cpuid/v2 v2.2.10 // indirect - golang.org/x/sync v0.12.0 + golang.org/x/sys v0.30.0 // indirect ) diff --git a/httprate.go b/httprate.go index 75a438d..ee6d810 100644 --- a/httprate.go +++ b/httprate.go @@ -7,40 +7,71 @@ import ( "time" ) +// Limit creates a new [net/http] middleware that limits requests by the given +// request limit and window length. The returned middleware will call the next +// handler if the request limit is not exceeded. func Limit(requestLimit int, windowLength time.Duration, options ...Option) func(next http.Handler) http.Handler { return NewRateLimiter(requestLimit, windowLength, options...).Handler } +// KeyFunc is a function that derives a key for the given request. type KeyFunc func(r *http.Request) (string, error) + +// Option is a function that configures the rate limiter. type Option func(rl *RateLimiter) -// Set custom response headers. If empty, the header is omitted. +// ResponseHeaders defines custom response headers. If empty, the header is omitted. type ResponseHeaders struct { - Limit string // Default: X-RateLimit-Limit - Remaining string // Default: X-RateLimit-Remaining - Increment string // Default: X-RateLimit-Increment - Reset string // Default: X-RateLimit-Reset - RetryAfter string // Default: Retry-After -} - + // Limit is the total number of requests that are permitted before the rate limit + // is exceeded. Default: "X-RateLimit-Limit". + Limit string + // Remaining is the number of requests remaining before the rate limit is + // exceeded. Default: "X-RateLimit-Remaining". + Remaining string + // Increment is the number of requests incremented by the rate limiter. Default: + // "X-RateLimit-Increment". + Increment string + // Reset is the time at which the rate limit will be reset. Default: + // "X-RateLimit-Reset". + Reset string + // RetryAfter is the time in seconds after which the rate limit will be reset. + // Default: "Retry-After". + RetryAfter string +} + +// LimitAll is a shortcut for [Limit] which uses a shared default key, resulting in +// a single rate-limiter for all requests. func LimitAll(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { return Limit(requestLimit, windowLength) } +// LimitByIP is a shortcut for [Limit] with the key function set to [KeyByIP], +// returning a new [net/http] middleware that limits requests by IP address. func LimitByIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByIP)) } +// LimitByRealIP is a shortcut for [Limit] with the key function set to [KeyByRealIP], +// returning a new [net/http] middleware that limits requests by real IP address. func LimitByRealIP(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByRealIP)) } +// LimitByEndpoint is a shortcut for [Limit] with the key function set to [KeyByEndpoint], +// returning a new [net/http] middleware that limits requests by endpoint. +func LimitByEndpoint(requestLimit int, windowLength time.Duration) func(next http.Handler) http.Handler { + return Limit(requestLimit, windowLength, WithKeyFuncs(KeyByEndpoint)) +} + +// Key returns a key function that always returns the specified key. func Key(key string) func(r *http.Request) (string, error) { return func(r *http.Request) (string, error) { return key, nil } } +// KeyByIP uses the canonicalized remote address, [net/http.Request.RemoteAddr], +// to get the IP address. func KeyByIP(r *http.Request) (string, error) { ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { @@ -49,6 +80,9 @@ func KeyByIP(r *http.Request) (string, error) { return canonicalizeIP(ip), nil } +// KeyByRealIP uses the "True-Client-IP", "X-Real-IP", and "X-Forwarded-For" +// headers (in that order of precedence) to get the IP address, after canonicalizing. +// If none of the headers are present, the remote address is used. func KeyByRealIP(r *http.Request) (string, error) { var ip string @@ -73,10 +107,12 @@ func KeyByRealIP(r *http.Request) (string, error) { return canonicalizeIP(ip), nil } +// KeyByEndpoint uses the URL path, [net/url.URL.Path] as the key. func KeyByEndpoint(r *http.Request) (string, error) { return r.URL.Path, nil } +// WithKeyFuncs composes multiple key functions into a single key. func WithKeyFuncs(keyFuncs ...KeyFunc) Option { return func(rl *RateLimiter) { if len(keyFuncs) > 0 { @@ -85,42 +121,76 @@ func WithKeyFuncs(keyFuncs ...KeyFunc) Option { } } +// WithKeyByIP is an option which sets the key function to [KeyByIP]. func WithKeyByIP() Option { return WithKeyFuncs(KeyByIP) } +// WithKeyByRealIP is an option which sets the key function to [KeyByRealIP]. func WithKeyByRealIP() Option { return WithKeyFuncs(KeyByRealIP) } +// WithKeyByEndpoint is an option which sets the key function to [KeyByEndpoint]. +func WithKeyByEndpoint() Option { + return WithKeyFuncs(KeyByEndpoint) +} + +// WithLimitHandler is an option which sets the limit handler to the given +// [http.HandlerFunc]. If not set, the default limit handler is used. func WithLimitHandler(h http.HandlerFunc) Option { return func(rl *RateLimiter) { rl.onRateLimited = h } } +// WithErrorHandler is an option which sets the error handler to the given +// function. If not set, the default error handler is used. func WithErrorHandler(h func(http.ResponseWriter, *http.Request, error)) Option { return func(rl *RateLimiter) { rl.onError = h } } +// WithLimitCounter is an option which sets the limit counter to the given +// [LimitCounter]. If not set, the default [LocalLimitCounter] is used. func WithLimitCounter(c LimitCounter) Option { return func(rl *RateLimiter) { rl.limitCounter = c } } +// WithResponseHeaders is an option which sets the response headers to the given +// [ResponseHeaders]. If not set, the default response headers are used. func WithResponseHeaders(headers ResponseHeaders) Option { return func(rl *RateLimiter) { rl.headers = headers } } +// WithNoop is an option which does nothing. func WithNoop() Option { return func(rl *RateLimiter) {} } +// Skip is a middleware that allows the rate limiter headers to be applied onto a +// request, without actually including the request in the rate limit. Use this for +// endpoints that can be used for checking the rate limit, without affecting the +// rate limit. NOTE: This MUST be loaded in your middleware stack before the rate +// limiter. +// +// Example: +// +// rl := httprate.Limit(100, time.Minute) +// r.With(rl).Get(...) // Will be rate limited. +// r.With(httprate.Skip, rl).Get(...) // Will not be rate limited, but still sets appropriate headers. +func Skip(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r.WithContext(WithIncrement(r.Context(), 0))) + }) +} + +// composedKeyFunc composes multiple key functions into a single key. func composedKeyFunc(keyFuncs ...KeyFunc) KeyFunc { return func(r *http.Request) (string, error) { var key strings.Builder @@ -151,11 +221,10 @@ func canonicalizeIP(ip string) string { case ':': // IPv6 isIPv6 = true - break } } if !isIPv6 { - // Not an IP address at all + // Not an IP address at all. return ip } diff --git a/httprate_test.go b/httprate_test.go index 8fa7d30..202f184 100644 --- a/httprate_test.go +++ b/httprate_test.go @@ -1,6 +1,11 @@ package httprate -import "testing" +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) func Test_canonicalizeIP(t *testing.T) { tests := []struct { @@ -57,3 +62,46 @@ func Test_canonicalizeIP(t *testing.T) { }) } } + +func TestSkip(t *testing.T) { + window := time.Minute + limit := 3 + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + t.Run("without skip exhausts limit", func(t *testing.T) { + limited := LimitAll(limit, window)(inner) + want := []int{http.StatusOK, http.StatusOK, http.StatusOK, http.StatusTooManyRequests} + for i, wantCode := range want { + rec := httptest.NewRecorder() + limited.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + if got := rec.Code; got != wantCode { + t.Fatalf("request %d: status = %d, want %d", i, got, wantCode) + } + } + }) + + t.Run("with skip does not count toward limit", func(t *testing.T) { + limited := LimitAll(limit, window)(inner) + skipped := Skip(limited) + n := limit + 10 + for i := range n { + rec := httptest.NewRecorder() + skipped.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + if got := rec.Code; got != http.StatusOK { + t.Fatalf("request %d: status = %d, want %d", i, got, http.StatusOK) + } + } + }) + + t.Run("with skip still sets rate limit headers", func(t *testing.T) { + limited := LimitAll(limit, window)(inner) + rec := httptest.NewRecorder() + Skip(limited).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) + if got := rec.Header().Get("X-RateLimit-Limit"); got != "3" { + t.Errorf("X-RateLimit-Limit = %q, want %q", got, "3") + } + }) +} diff --git a/limit_key.go b/limit_key.go index 7e9b32a..c20cdcf 100644 --- a/limit_key.go +++ b/limit_key.go @@ -7,9 +7,10 @@ import ( "github.com/zeebo/xxh3" ) +// LimitCounterKey computes a hash key for the given key and window. func LimitCounterKey(key string, window time.Time) uint64 { h := xxh3.New() - h.WriteString(key) - h.WriteString(strconv.FormatInt(window.Unix(), 10)) + _, _ = h.WriteString(key) + _, _ = h.WriteString(strconv.FormatInt(window.Unix(), 10)) return h.Sum64() } diff --git a/limiter.go b/limiter.go index 6a62009..80b2e50 100644 --- a/limiter.go +++ b/limiter.go @@ -8,6 +8,9 @@ import ( "time" ) +// LimitCounter is an interface that defines the methods for a rate limit counter. +// It is used to store and retrieve the rate limit counter for a given key and +// window. A default implementation is provided by [NewLocalLimitCounter]. type LimitCounter interface { Config(requestLimit int, windowLength time.Duration) Increment(key string, currentWindow time.Time) error @@ -15,6 +18,9 @@ type LimitCounter interface { Get(key string, currentWindow, previousWindow time.Time) (int, int, error) } +// NewRateLimiter creates a new [RateLimiter] with the given request limit and +// window length. The returned rate limiter will use the default [LocalLimitCounter] +// implementation, if not overridden. func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *RateLimiter { rl := &RateLimiter{ requestLimit: requestLimit, @@ -125,10 +131,12 @@ func (l *RateLimiter) RespondOnLimit(w http.ResponseWriter, r *http.Request, key return onLimit } +// Counter returns the limit counter used by the rate limiter. func (l *RateLimiter) Counter() LimitCounter { return l.limitCounter } +// Status returns the current status of the rate limiter for the given key. func (l *RateLimiter) Status(key string) (bool, float64, error) { return l.calculateRate(key, l.requestLimit) } @@ -149,6 +157,8 @@ func (l *RateLimiter) Handler(next http.Handler) http.Handler { }) } +// calculateRate calculates the rate for the given key and request limit. It does +// not increment the counter. func (l *RateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) { now := time.Now().UTC() currentWindow := now.Truncate(l.windowLength) diff --git a/limiter_test.go b/limiter_test.go index 5ac41c1..6008859 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -163,7 +163,7 @@ func TestResponseHeaders(t *testing.T) { h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) router := httprate.LimitAll(tt.requestsLimit, time.Minute)(h) - for i := 0; i < count; i++ { + for i := range count { req := httptest.NewRequest("GET", "/", nil) req = req.WithContext(httprate.WithIncrement(req.Context(), tt.increments[i])) recorder := httptest.NewRecorder() @@ -224,7 +224,7 @@ func TestCustomResponseHeaders(t *testing.T) { 1, time.Minute, httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Wow Slow Down Kiddo", 429) + http.Error(w, "Wow Slow Down Kiddo", http.StatusTooManyRequests) }), httprate.WithResponseHeaders(tt.headers), )(h) @@ -316,7 +316,7 @@ func TestLimitHandler(t *testing.T) { tt.requestsLimit, tt.windowLength, httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Wow Slow Down Kiddo", 429) + http.Error(w, "Wow Slow Down Kiddo", http.StatusTooManyRequests) }), )(h) @@ -329,7 +329,7 @@ func TestLimitHandler(t *testing.T) { t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respStatus, expected.StatusCode) } buf := new(bytes.Buffer) - buf.ReadFrom(result.Body) + _, _ = buf.ReadFrom(result.Body) respBody := strings.TrimSuffix(buf.String(), "\n") if respBody != expected.Body { @@ -395,7 +395,7 @@ func TestOverrideRequestLimit(t *testing.T) { 3, time.Minute, httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Wow Slow Down Kiddo", 429) + http.Error(w, "Wow Slow Down Kiddo", http.StatusTooManyRequests) }), )(h) @@ -458,7 +458,7 @@ func TestRateLimitPayload(t *testing.T) { return } - w.Write([]byte("login at 5 req/min\n")) + _, _ = w.Write([]byte("login at 5 req/min\n")) }) responses := []struct { diff --git a/local_counter.go b/local_counter.go index 37cd083..39de5ce 100644 --- a/local_counter.go +++ b/local_counter.go @@ -7,8 +7,8 @@ import ( "github.com/zeebo/xxh3" ) -// NewLocalLimitCounter creates an instance of localCounter, -// which is an in-memory implementation of http.LimitCounter. +// NewLocalLimitCounter creates an instance of localCounter, which is an in-memory +// implementation of [LimitCounter]. // // All methods are guaranteed to always return nil error. func NewLocalLimitCounter(windowLength time.Duration) *localCounter { @@ -38,7 +38,7 @@ func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount i hkey := limitCounterKey(key) - count, _ := c.latestCounters[hkey] + count := c.latestCounters[hkey] c.latestCounters[hkey] = count + amount return nil @@ -48,14 +48,14 @@ func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) c.mu.RLock() defer c.mu.RUnlock() - if c.latestWindow == currentWindow { - curr, _ := c.latestCounters[limitCounterKey(key)] - prev, _ := c.previousCounters[limitCounterKey(key)] + if c.latestWindow.Equal(currentWindow) { + curr := c.latestCounters[limitCounterKey(key)] + prev := c.previousCounters[limitCounterKey(key)] return curr, prev, nil } - if c.latestWindow == previousWindow { - prev, _ := c.latestCounters[limitCounterKey(key)] + if c.latestWindow.Equal(previousWindow) { + prev := c.latestCounters[limitCounterKey(key)] return 0, prev, nil } @@ -72,12 +72,12 @@ func (c *localCounter) Increment(key string, currentWindow time.Time) error { } func (c *localCounter) evict(currentWindow time.Time) { - if c.latestWindow == currentWindow { + if c.latestWindow.Equal(currentWindow) { return } previousWindow := currentWindow.Add(-c.windowLength) - if c.latestWindow == previousWindow { + if c.latestWindow.Equal(previousWindow) { c.latestWindow = currentWindow // Shift the windows without map re-allocation. clear(c.previousCounters) @@ -93,6 +93,6 @@ func (c *localCounter) evict(currentWindow time.Time) { func limitCounterKey(key string) uint64 { h := xxh3.New() - h.WriteString(key) + _, _ = h.WriteString(key) return h.Sum64() } diff --git a/local_counter_test.go b/local_counter_test.go index b3ee478..3d4d0da 100644 --- a/local_counter_test.go +++ b/local_counter_test.go @@ -103,8 +103,7 @@ func TestLocalCounter(t *testing.T) { if tt.incrBy > 0 { var g errgroup.Group - for i := 0; i < concurrentRequests; i++ { - i := i + for i := range concurrentRequests { g.Go(func() error { key := fmt.Sprintf("key:%v", i) return limitCounter.IncrementBy(key, currentWindow, tt.incrBy) @@ -116,8 +115,7 @@ func TestLocalCounter(t *testing.T) { } var g errgroup.Group - for i := 0; i < concurrentRequests; i++ { - i := i + for i := range concurrentRequests { g.Go(func() error { key := fmt.Sprintf("key:%v", i) curr, prev, err := limitCounter.Get(key, currentWindow, previousWindow) @@ -155,14 +153,14 @@ func BenchmarkLocalCounter(b *testing.B) { wg := sync.WaitGroup{} wg.Add(1000) - for i := 0; i < 1000; i++ { + for i := range 1000 { // Simulate concurrent requests with different rate-limit keys. - go func(i int) { + go func() { defer wg.Done() _, _, _ = limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow) _ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, rand.Intn(100)) - }(i) + }() } wg.Wait() }