fix: use net.SplitHostPort for rate limit key and add IPv6 tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package httpx
|
package httpx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -107,19 +108,12 @@ func rateLimitKey(r *http.Request) string {
|
|||||||
}
|
}
|
||||||
return fwd
|
return fwd
|
||||||
}
|
}
|
||||||
// Strip port from RemoteAddr
|
// Strip port from RemoteAddr using net.SplitHostPort for correct IPv6 handling.
|
||||||
addr := r.RemoteAddr
|
addr := r.RemoteAddr
|
||||||
if idx := lastIndexByte(addr, ':'); idx > 0 {
|
if host, _, err := net.SplitHostPort(addr); err == nil {
|
||||||
return addr[:idx]
|
return host
|
||||||
}
|
}
|
||||||
return addr
|
return addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func lastIndexByte(s string, c byte) int {
|
|
||||||
for i := len(s) - 1; i >= 0; i-- {
|
|
||||||
if s[i] == c {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -144,3 +144,22 @@ func TestRateLimiter_WithRateLimit_XForwardedFor(t *testing.T) {
|
|||||||
t.Errorf("different IP: expected 200, got %d", rec.Code)
|
t.Errorf("different IP: expected 200, got %d", rec.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRateLimitKey_IPv6(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
remoteAddr string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"[::1]:8080", "::1"},
|
||||||
|
{"[2001:db8::1]:8080", "2001:db8::1"},
|
||||||
|
{"192.168.1.1:1234", "192.168.1.1"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
got := rateLimitKey(req)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("rateLimitKey(%q) = %q, want %q", tt.remoteAddr, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user