first commit
This commit is contained in:
88
pkg/utils/limit.go
Normal file
88
pkg/utils/limit.go
Normal file
@ -0,0 +1,88 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/anxpp/beego/context"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// Global Rate Limiter Configuration
|
||||
// 使用 rate.NewLimiter 创建一个全局限流器。
|
||||
// rate.Every(5*time.Second/500) 设置了令牌的生成速率,每 5 秒生成 500 个令牌,即每秒 100 个。500 是桶的容量,允许在短时间内处理 500 个突发请求。
|
||||
const (
|
||||
globalBurst = 500 // Example: A burst capacity of 500
|
||||
globalDuration = 5 * time.Second
|
||||
)
|
||||
|
||||
// User Rate Limiter Configuration
|
||||
// 使用 sync.Map 存储每个用户的限流器。sync.Map 是并发安全的,适合在多个 goroutine 中存取数据。getUserLimiter 函数负责获取或为新用户创建限流器。
|
||||
const (
|
||||
userBurst = 20 // Example: A burst capacity of 20
|
||||
userDuration = 3 * time.Second
|
||||
)
|
||||
|
||||
// Global Limiter
|
||||
var globalLimiter *rate.Limiter
|
||||
|
||||
func init() {
|
||||
// Initialize global rate limiter with a 5-second interval
|
||||
globalLimiter = rate.NewLimiter(rate.Every(globalDuration/time.Duration(globalBurst)), globalBurst)
|
||||
}
|
||||
|
||||
// User-specific limiters
|
||||
var (
|
||||
userLimiters sync.Map
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
// getUserLimiter retrieves or creates a user-specific rate limiter.
|
||||
func getUserLimiter(userID string) *rate.Limiter {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
limiter, ok := userLimiters.Load(userID)
|
||||
if !ok {
|
||||
// Create a new limiter for the user
|
||||
newLimiter := rate.NewLimiter(rate.Every(userDuration/time.Duration(userBurst)), userBurst)
|
||||
userLimiters.Store(userID, newLimiter)
|
||||
return newLimiter
|
||||
}
|
||||
|
||||
return limiter.(*rate.Limiter)
|
||||
}
|
||||
|
||||
// RateLimitFilter is a Beego filter for rate limiting.
|
||||
// 过滤器逻辑 (RateLimitFilter):
|
||||
// 首先调用 globalLimiter.Allow() 检查是否超过了全局限制。如果超出,直接返回 429 错误。
|
||||
// 然后,获取用户 ID。在实际应用中,你可能从 ctx.Input.Header、ctx.Input.Session 或其他地方获取。
|
||||
// 通过 getUserLimiter 获取该用户的限流器,并调用 userLimiter.Allow() 检查是否超过了用户限制。如果超出,返回 429 错误。
|
||||
func RateLimitFilter(ctx *context.Context) {
|
||||
// --- Global Rate Limiting ---
|
||||
if !globalLimiter.Allow() {
|
||||
ctx.ResponseWriter.WriteHeader(429) // 429 Too Many Requests
|
||||
// ctx.ResponseWriter.Write([]byte("429 Too Many Requests - Global Limit Exceeded"))
|
||||
ctx.ResponseWriter.Write([]byte("429 当前服务器请求过多"))
|
||||
return
|
||||
}
|
||||
|
||||
// --- Per-User Rate Limiting ---
|
||||
// In a real application, you would get the user ID from the session, JWT, or request header.
|
||||
// For this example, we'll use a hardcoded value or a placeholder.
|
||||
// You can replace this with your actual logic to get the user ID.
|
||||
userID := ctx.Input.Param("x-token") // Example: Get user ID from URL parameter
|
||||
if userID == "" {
|
||||
// Handle cases where the user ID is not available.
|
||||
userID = "anonymous"
|
||||
}
|
||||
|
||||
userLimiter := getUserLimiter(userID)
|
||||
if !userLimiter.Allow() {
|
||||
ctx.ResponseWriter.WriteHeader(429) // 429 Too Many Requests
|
||||
// ctx.ResponseWriter.Write([]byte(fmt.Sprintf("429 Too Many Requests - User %s Limit Exceeded", userID)))
|
||||
ctx.ResponseWriter.Write([]byte("429 当前用户请求过多"))
|
||||
return
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user