89 lines
3.1 KiB
Go
89 lines
3.1 KiB
Go
package utils
|
||
|
||
import (
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/anxpp/beego/context"
|
||
|
||
"golang.org/x/time/rate"
|
||
)
|
||
|
||
// Global Rate Limiter Configuration
|
||
// 使用 rate.NewLimiter 创建一个全局限流器。
|
||
// rate.Every(5*time.Second/500) 设置了令牌的生成速率,每 5 秒生成 500 个令牌,即每秒 100 个。500 是桶的容量,允许在短时间内处理 500 个突发请求。
|
||
const (
|
||
globalBurst = 500 // Example: A burst capacity of 500
|
||
globalDuration = 5 * time.Second
|
||
)
|
||
|
||
// User Rate Limiter Configuration
|
||
// 使用 sync.Map 存储每个用户的限流器。sync.Map 是并发安全的,适合在多个 goroutine 中存取数据。getUserLimiter 函数负责获取或为新用户创建限流器。
|
||
const (
|
||
userBurst = 20 // Example: A burst capacity of 20
|
||
userDuration = 3 * time.Second
|
||
)
|
||
|
||
// Global Limiter
|
||
var globalLimiter *rate.Limiter
|
||
|
||
func init() {
|
||
// Initialize global rate limiter with a 5-second interval
|
||
globalLimiter = rate.NewLimiter(rate.Every(globalDuration/time.Duration(globalBurst)), globalBurst)
|
||
}
|
||
|
||
// User-specific limiters
|
||
var (
|
||
userLimiters sync.Map
|
||
mu sync.Mutex
|
||
)
|
||
|
||
// getUserLimiter retrieves or creates a user-specific rate limiter.
|
||
func getUserLimiter(userID string) *rate.Limiter {
|
||
mu.Lock()
|
||
defer mu.Unlock()
|
||
|
||
limiter, ok := userLimiters.Load(userID)
|
||
if !ok {
|
||
// Create a new limiter for the user
|
||
newLimiter := rate.NewLimiter(rate.Every(userDuration/time.Duration(userBurst)), userBurst)
|
||
userLimiters.Store(userID, newLimiter)
|
||
return newLimiter
|
||
}
|
||
|
||
return limiter.(*rate.Limiter)
|
||
}
|
||
|
||
// RateLimitFilter is a Beego filter for rate limiting.
|
||
// 过滤器逻辑 (RateLimitFilter):
|
||
// 首先调用 globalLimiter.Allow() 检查是否超过了全局限制。如果超出,直接返回 429 错误。
|
||
// 然后,获取用户 ID。在实际应用中,你可能从 ctx.Input.Header、ctx.Input.Session 或其他地方获取。
|
||
// 通过 getUserLimiter 获取该用户的限流器,并调用 userLimiter.Allow() 检查是否超过了用户限制。如果超出,返回 429 错误。
|
||
func RateLimitFilter(ctx *context.Context) {
|
||
// --- Global Rate Limiting ---
|
||
if !globalLimiter.Allow() {
|
||
ctx.ResponseWriter.WriteHeader(429) // 429 Too Many Requests
|
||
// ctx.ResponseWriter.Write([]byte("429 Too Many Requests - Global Limit Exceeded"))
|
||
ctx.ResponseWriter.Write([]byte("429 当前服务器请求过多"))
|
||
return
|
||
}
|
||
|
||
// --- Per-User Rate Limiting ---
|
||
// In a real application, you would get the user ID from the session, JWT, or request header.
|
||
// For this example, we'll use a hardcoded value or a placeholder.
|
||
// You can replace this with your actual logic to get the user ID.
|
||
userID := ctx.Input.Param("x-token") // Example: Get user ID from URL parameter
|
||
if userID == "" {
|
||
// Handle cases where the user ID is not available.
|
||
userID = "anonymous"
|
||
}
|
||
|
||
userLimiter := getUserLimiter(userID)
|
||
if !userLimiter.Allow() {
|
||
ctx.ResponseWriter.WriteHeader(429) // 429 Too Many Requests
|
||
// ctx.ResponseWriter.Write([]byte(fmt.Sprintf("429 Too Many Requests - User %s Limit Exceeded", userID)))
|
||
ctx.ResponseWriter.Write([]byte("429 当前用户请求过多"))
|
||
return
|
||
}
|
||
}
|