fix: use net.SplitHostPort for rate limit key and add IPv6 tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package httpx
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -107,19 +108,12 @@ func rateLimitKey(r *http.Request) string {
|
||||
}
|
||||
return fwd
|
||||
}
|
||||
// Strip port from RemoteAddr
|
||||
// Strip port from RemoteAddr using net.SplitHostPort for correct IPv6 handling.
|
||||
addr := r.RemoteAddr
|
||||
if idx := lastIndexByte(addr, ':'); idx > 0 {
|
||||
return addr[:idx]
|
||||
if host, _, err := net.SplitHostPort(addr); err == nil {
|
||||
return host
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
func lastIndexByte(s string, c byte) int {
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
if s[i] == c {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
|
||||
@@ -144,3 +144,22 @@ func TestRateLimiter_WithRateLimit_XForwardedFor(t *testing.T) {
|
||||
t.Errorf("different IP: expected 200, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitKey_IPv6(t *testing.T) {
|
||||
tests := []struct {
|
||||
remoteAddr string
|
||||
want string
|
||||
}{
|
||||
{"[::1]:8080", "::1"},
|
||||
{"[2001:db8::1]:8080", "2001:db8::1"},
|
||||
{"192.168.1.1:1234", "192.168.1.1"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
got := rateLimitKey(req)
|
||||
if got != tt.want {
|
||||
t.Errorf("rateLimitKey(%q) = %q, want %q", tt.remoteAddr, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user