Files
servicebase/pkg/utils/limit.go
2025-11-18 17:48:20 +08:00

89 lines
3.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package utils
import (
"sync"
"time"
"github.com/anxpp/beego/context"
"golang.org/x/time/rate"
)
// Global rate limiter configuration.
//
// globalLimiter is built from these two values: it refills one token every
// globalDuration/globalBurst (5s / 500 = 10ms, i.e. 100 requests per second on
// average) and its bucket holds at most globalBurst tokens, so up to 500
// requests can be admitted in a short burst.
const (
	globalDuration = 5 * time.Second // window over which globalBurst tokens are refilled
	globalBurst    = 500             // bucket capacity (maximum burst size)
)
// Per-user rate limiter configuration.
//
// Each limiter created by getUserLimiter refills one token every
// userDuration/userBurst (3s / 20 = 150ms, roughly 6.67 requests per second)
// and its bucket holds at most userBurst tokens, allowing a burst of up to 20
// requests per user.
const (
	userDuration = 3 * time.Second // window over which userBurst tokens are refilled
	userBurst    = 20              // per-user bucket capacity (maximum burst size)
)
// globalLimiter is the server-wide token bucket consulted for every request.
//
// It refills one token every globalDuration/globalBurst (10ms with the current
// constants, i.e. 100 tokens per second) and holds at most globalBurst tokens.
// It is built by a package-level initializer rather than in init() so that it is
// already set before any init function in this package runs.
var globalLimiter = rate.NewLimiter(rate.Every(globalDuration/time.Duration(globalBurst)), globalBurst)
// User-specific limiters.
var (
	// userLimiters maps a user key (string) to that user's *rate.Limiter.
	// Entries are written once and then only read, which is the grow-only cache
	// case sync.Map is designed for.
	// NOTE(review): entries are never removed, so the map grows with every
	// distinct key seen. If keys come from client-supplied data, consider evicting
	// idle limiters.
	userLimiters sync.Map

	// mu used to serialize getUserLimiter. It is no longer used there (sync.Map
	// already provides the needed synchronization). It is kept only so that any
	// other reference to it in this package keeps compiling.
	mu sync.Mutex
)

// getUserLimiter returns the rate limiter for userID, creating it on first use.
//
// A newly created limiter refills one token every userDuration/userBurst and
// holds at most userBurst tokens. The function is safe for concurrent use.
// Lookups of an existing user are a lock-free Load. When several goroutines
// see the same new user at once, LoadOrStore lets exactly one limiter be stored,
// and every caller receives that same instance.
func getUserLimiter(userID string) *rate.Limiter {
	if existing, ok := userLimiters.Load(userID); ok {
		return existing.(*rate.Limiter)
	}
	fresh := rate.NewLimiter(rate.Every(userDuration/time.Duration(userBurst)), userBurst)
	// If another goroutine stored a limiter first, use that one and discard ours.
	actual, _ := userLimiters.LoadOrStore(userID, fresh)
	return actual.(*rate.Limiter)
}
// RateLimitFilter is a Beego filter that applies two token-bucket limits to each
// request. It replies 429 Too Many Requests and stops when either is exhausted:
//
//  1. globalLimiter, shared by every request to the server;
//  2. the per-user limiter returned by getUserLimiter.
//
// The global check runs first. While the server is saturated, rejected requests
// therefore never reach getUserLimiter and cannot add new per-user entries. The
// trade-off is that a request rejected by its per-user limit has already spent
// one global token.
//
// The user key is taken from the "x-token" router parameter, then from the
// "x-token" request header, and finally defaults to the shared key "anonymous".
// NOTE(review): the key is client-supplied and not verified here. A client can
// rotate tokens to obtain fresh per-user buckets, so the global limit is the
// only hard cap.
func RateLimitFilter(ctx *context.Context) {
	// tooMany writes a plain-text 429 response. Content-Type must be set before
	// WriteHeader; header changes made after WriteHeader are not sent.
	tooMany := func(msg string) {
		w := ctx.ResponseWriter
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(429) // 429 Too Many Requests
		// Best effort: if the client has already gone away there is nothing left to do.
		_, _ = w.Write([]byte(msg))
	}

	// --- Global rate limiting ---
	if !globalLimiter.Allow() {
		tooMany("429 当前服务器请求过多")
		return
	}

	// --- Per-user rate limiting ---
	// In upstream Beego, router parameters are stored under ":"-prefixed keys
	// (e.g. ":id"). Param("x-token") then only yields a value if something upstream
	// set that exact key (presumably the same in this fork — verify). Falling back
	// to the request header keeps clients that send an X-Token header from all
	// sharing the single "anonymous" bucket.
	userID := ctx.Input.Param("x-token")
	if userID == "" {
		userID = ctx.Input.Header("x-token")
	}
	if userID == "" {
		userID = "anonymous"
	}

	if !getUserLimiter(userID).Allow() {
		tooMany("429 当前用户请求过多")
		return
	}
}