581 lines
16 KiB
Go
581 lines
16 KiB
Go
package wxpay_utility
|
||
|
||
import (
|
||
"bytes"
|
||
"crypto"
|
||
"crypto/aes"
|
||
"crypto/cipher"
|
||
"crypto/rand"
|
||
"crypto/rsa"
|
||
"crypto/sha1"
|
||
"crypto/sha256"
|
||
"crypto/x509"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"encoding/pem"
|
||
"errors"
|
||
"fmt"
|
||
"hash"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"strconv"
|
||
"time"
|
||
|
||
"github.com/tjfoc/gmsm/sm3"
|
||
)
|
||
|
||
// MchConfig 商户信息配置,用于调用商户API
|
||
type MchConfig struct {
|
||
mchId string
|
||
certificateSerialNo string
|
||
privateKeyFilePath string
|
||
wechatPayPublicKeyId string
|
||
wechatPayPublicKeyFilePath string
|
||
privateKey *rsa.PrivateKey
|
||
wechatPayPublicKey *rsa.PublicKey
|
||
}
|
||
|
||
// MchId 商户号
|
||
func (c *MchConfig) MchId() string {
|
||
return c.mchId
|
||
}
|
||
|
||
// CertificateSerialNo 商户API证书序列号
|
||
func (c *MchConfig) CertificateSerialNo() string {
|
||
return c.certificateSerialNo
|
||
}
|
||
|
||
// PrivateKey 商户API证书对应的私钥
|
||
func (c *MchConfig) PrivateKey() *rsa.PrivateKey {
|
||
return c.privateKey
|
||
}
|
||
|
||
// WechatPayPublicKeyId 微信支付公钥ID
|
||
func (c *MchConfig) WechatPayPublicKeyId() string {
|
||
return c.wechatPayPublicKeyId
|
||
}
|
||
|
||
// WechatPayPublicKey 微信支付公钥
|
||
func (c *MchConfig) WechatPayPublicKey() *rsa.PublicKey {
|
||
return c.wechatPayPublicKey
|
||
}
|
||
|
||
// CreateMchConfig MchConfig 构造函数
|
||
func CreateMchConfig(
|
||
mchId string,
|
||
certificateSerialNo string,
|
||
privateKeyFilePath string,
|
||
wechatPayPublicKeyId string,
|
||
wechatPayPublicKeyFilePath string,
|
||
) (*MchConfig, error) {
|
||
mchConfig := &MchConfig{
|
||
mchId: mchId,
|
||
certificateSerialNo: certificateSerialNo,
|
||
privateKeyFilePath: privateKeyFilePath,
|
||
wechatPayPublicKeyId: wechatPayPublicKeyId,
|
||
wechatPayPublicKeyFilePath: wechatPayPublicKeyFilePath,
|
||
}
|
||
privateKey, err := LoadPrivateKeyWithPath(mchConfig.privateKeyFilePath)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
mchConfig.privateKey = privateKey
|
||
wechatPayPublicKey, err := LoadPublicKeyWithPath(mchConfig.wechatPayPublicKeyFilePath)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
mchConfig.wechatPayPublicKey = wechatPayPublicKey
|
||
return mchConfig, nil
|
||
}
|
||
|
||
// LoadPrivateKey 通过私钥的文本内容加载私钥
|
||
func LoadPrivateKey(privateKeyStr string) (privateKey *rsa.PrivateKey, err error) {
|
||
block, _ := pem.Decode([]byte(privateKeyStr))
|
||
if block == nil {
|
||
return nil, fmt.Errorf("decode private key err")
|
||
}
|
||
if block.Type != "PRIVATE KEY" {
|
||
return nil, fmt.Errorf("the kind of PEM should be PRVATE KEY")
|
||
}
|
||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("parse private key err:%s", err.Error())
|
||
}
|
||
privateKey, ok := key.(*rsa.PrivateKey)
|
||
if !ok {
|
||
return nil, fmt.Errorf("not a RSA private key")
|
||
}
|
||
return privateKey, nil
|
||
}
|
||
|
||
// LoadPublicKey 通过公钥的文本内容加载公钥
|
||
func LoadPublicKey(publicKeyStr string) (publicKey *rsa.PublicKey, err error) {
|
||
block, _ := pem.Decode([]byte(publicKeyStr))
|
||
if block == nil {
|
||
return nil, errors.New("decode public key error")
|
||
}
|
||
if block.Type != "PUBLIC KEY" {
|
||
return nil, fmt.Errorf("the kind of PEM should be PUBLIC KEY")
|
||
}
|
||
key, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("parse public key err:%s", err.Error())
|
||
}
|
||
publicKey, ok := key.(*rsa.PublicKey)
|
||
if !ok {
|
||
return nil, fmt.Errorf("%s is not rsa public key", publicKeyStr)
|
||
}
|
||
return publicKey, nil
|
||
}
|
||
|
||
// LoadPrivateKeyWithPath 通过私钥的文件路径内容加载私钥
|
||
func LoadPrivateKeyWithPath(path string) (privateKey *rsa.PrivateKey, err error) {
|
||
privateKeyBytes, err := os.ReadFile(path)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("read private pem file err:%s", err.Error())
|
||
}
|
||
return LoadPrivateKey(string(privateKeyBytes))
|
||
}
|
||
|
||
// LoadPublicKeyWithPath 通过公钥的文件路径加载公钥
|
||
func LoadPublicKeyWithPath(path string) (publicKey *rsa.PublicKey, err error) {
|
||
publicKeyBytes, err := os.ReadFile(path)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("read certificate pem file err:%s", err.Error())
|
||
}
|
||
return LoadPublicKey(string(publicKeyBytes))
|
||
}
|
||
|
||
// EncryptOAEPWithPublicKey 使用 OAEP padding方式用公钥进行加密
|
||
func EncryptOAEPWithPublicKey(message string, publicKey *rsa.PublicKey) (ciphertext string, err error) {
|
||
if publicKey == nil {
|
||
return "", fmt.Errorf("you should input *rsa.PublicKey")
|
||
}
|
||
ciphertextByte, err := rsa.EncryptOAEP(sha1.New(), rand.Reader, publicKey, []byte(message), nil)
|
||
if err != nil {
|
||
return "", fmt.Errorf("encrypt message with public key err:%s", err.Error())
|
||
}
|
||
ciphertext = base64.StdEncoding.EncodeToString(ciphertextByte)
|
||
return ciphertext, nil
|
||
}
|
||
|
||
// DecryptAES256GCM 使用 AEAD_AES_256_GCM 算法进行解密
|
||
//
|
||
// 可以使用此算法完成微信支付回调报文解密
|
||
func DecryptAES256GCM(aesKey, associatedData, nonce, ciphertext string) (plaintext string, err error) {
|
||
decodedCiphertext, err := base64.StdEncoding.DecodeString(ciphertext)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
c, err := aes.NewCipher([]byte(aesKey))
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
gcm, err := cipher.NewGCM(c)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
dataBytes, err := gcm.Open(nil, []byte(nonce), decodedCiphertext, []byte(associatedData))
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return string(dataBytes), nil
|
||
}
|
||
|
||
// SignSHA256WithRSA 通过私钥对字符串以 SHA256WithRSA 算法生成签名信息
|
||
func SignSHA256WithRSA(source string, privateKey *rsa.PrivateKey) (signature string, err error) {
|
||
if privateKey == nil {
|
||
return "", fmt.Errorf("private key should not be nil")
|
||
}
|
||
h := crypto.Hash.New(crypto.SHA256)
|
||
_, err = h.Write([]byte(source))
|
||
if err != nil {
|
||
return "", nil
|
||
}
|
||
hashed := h.Sum(nil)
|
||
signatureByte, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, hashed)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return base64.StdEncoding.EncodeToString(signatureByte), nil
|
||
}
|
||
|
||
// VerifySHA256WithRSA 通过公钥对字符串和签名结果以 SHA256WithRSA 验证签名有效性
|
||
func VerifySHA256WithRSA(source string, signature string, publicKey *rsa.PublicKey) error {
|
||
if publicKey == nil {
|
||
return fmt.Errorf("public key should not be nil")
|
||
}
|
||
|
||
sigBytes, err := base64.StdEncoding.DecodeString(signature)
|
||
if err != nil {
|
||
return fmt.Errorf("verify failed: signature is not base64 encoded")
|
||
}
|
||
hashed := sha256.Sum256([]byte(source))
|
||
err = rsa.VerifyPKCS1v15(publicKey, crypto.SHA256, hashed[:], sigBytes)
|
||
if err != nil {
|
||
return fmt.Errorf("verify signature with public key error:%s", err.Error())
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GenerateNonce 生成一个长度为 NonceLength 的随机字符串(只包含大小写字母与数字)
|
||
func GenerateNonce() (string, error) {
|
||
const (
|
||
// NonceSymbols 随机字符串可用字符集
|
||
NonceSymbols = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||
// NonceLength 随机字符串的长度
|
||
NonceLength = 32
|
||
)
|
||
|
||
bytes := make([]byte, NonceLength)
|
||
_, err := rand.Read(bytes)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
symbolsByteLength := byte(len(NonceSymbols))
|
||
for i, b := range bytes {
|
||
bytes[i] = NonceSymbols[b%symbolsByteLength]
|
||
}
|
||
return string(bytes), nil
|
||
}
|
||
|
||
// BuildAuthorization 构建请求头中的 Authorization 信息
|
||
func BuildAuthorization(
|
||
mchid string,
|
||
certificateSerialNo string,
|
||
privateKey *rsa.PrivateKey,
|
||
method string,
|
||
canonicalURL string,
|
||
body []byte,
|
||
) (string, error) {
|
||
const (
|
||
SignatureMessageFormat = "%s\n%s\n%d\n%s\n%s\n" // 数字签名原文格式
|
||
// HeaderAuthorizationFormat 请求头中的 Authorization 拼接格式
|
||
HeaderAuthorizationFormat = "WECHATPAY2-SHA256-RSA2048 mchid=\"%s\",nonce_str=\"%s\",timestamp=\"%d\",serial_no=\"%s\",signature=\"%s\""
|
||
)
|
||
|
||
nonce, err := GenerateNonce()
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
timestamp := time.Now().Unix()
|
||
message := fmt.Sprintf(SignatureMessageFormat, method, canonicalURL, timestamp, nonce, body)
|
||
signature, err := SignSHA256WithRSA(message, privateKey)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
authorization := fmt.Sprintf(
|
||
HeaderAuthorizationFormat,
|
||
mchid, nonce, timestamp, certificateSerialNo, signature,
|
||
)
|
||
return authorization, nil
|
||
}
|
||
|
||
// ExtractResponseBody 提取应答报文的 Body
|
||
func ExtractResponseBody(response *http.Response) ([]byte, error) {
|
||
if response.Body == nil {
|
||
return nil, nil
|
||
}
|
||
|
||
body, err := io.ReadAll(response.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("read response body err:[%s]", err.Error())
|
||
}
|
||
response.Body = io.NopCloser(bytes.NewBuffer(body))
|
||
return body, nil
|
||
}
|
||
|
||
const (
|
||
WechatPayTimestamp = "Wechatpay-Timestamp" // 微信支付回包时间戳
|
||
WechatPayNonce = "Wechatpay-Nonce" // 微信支付回包随机字符串
|
||
WechatPaySignature = "Wechatpay-Signature" // 微信支付回包签名信息
|
||
WechatPaySerial = "Wechatpay-Serial" // 微信支付回包平台序列号
|
||
RequestID = "Request-Id" // 微信支付回包请求ID
|
||
)
|
||
|
||
func validateWechatPaySignature(
|
||
wechatpayPublicKeyId string,
|
||
wechatpayPublicKey *rsa.PublicKey,
|
||
headers *http.Header,
|
||
body []byte,
|
||
) error {
|
||
timestampStr := headers.Get(WechatPayTimestamp)
|
||
serialNo := headers.Get(WechatPaySerial)
|
||
signature := headers.Get(WechatPaySignature)
|
||
nonce := headers.Get(WechatPayNonce)
|
||
|
||
// 拒绝过期请求
|
||
timestamp, err := strconv.ParseInt(timestampStr, 10, 64)
|
||
if err != nil {
|
||
return fmt.Errorf("invalid timestamp: %w", err)
|
||
}
|
||
if time.Now().Sub(time.Unix(timestamp, 0)) > 5*time.Minute {
|
||
return fmt.Errorf("timestamp expired: %d", timestamp)
|
||
}
|
||
|
||
if serialNo != wechatpayPublicKeyId {
|
||
return fmt.Errorf(
|
||
"serial-no mismatch: got %s, expected %s",
|
||
serialNo,
|
||
wechatpayPublicKeyId,
|
||
)
|
||
}
|
||
|
||
message := fmt.Sprintf("%s\n%s\n%s\n", timestampStr, nonce, body)
|
||
if err := VerifySHA256WithRSA(message, signature, wechatpayPublicKey); err != nil {
|
||
return fmt.Errorf("invalid signature: %v", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ValidateResponse 验证微信支付回包的签名信息
|
||
func ValidateResponse(
|
||
wechatpayPublicKeyId string,
|
||
wechatpayPublicKey *rsa.PublicKey,
|
||
headers *http.Header,
|
||
body []byte,
|
||
) error {
|
||
if err := validateWechatPaySignature(wechatpayPublicKeyId, wechatpayPublicKey, headers, body); err != nil {
|
||
return fmt.Errorf("validate response err: %w, RequestID: %s", err, headers.Get(RequestID))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func validateNotification(
|
||
wechatpayPublicKeyId string,
|
||
wechatpayPublicKey *rsa.PublicKey,
|
||
headers *http.Header,
|
||
body []byte,
|
||
) error {
|
||
if err := validateWechatPaySignature(wechatpayPublicKeyId, wechatpayPublicKey, headers, body); err != nil {
|
||
return fmt.Errorf("validate notification err: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Resource 微信支付通知请求中的资源数据
|
||
type Resource struct {
|
||
Algorithm string `json:"algorithm"`
|
||
Ciphertext string `json:"ciphertext"`
|
||
AssociatedData string `json:"associated_data"`
|
||
Nonce string `json:"nonce"`
|
||
OriginalType string `json:"original_type"`
|
||
}
|
||
|
||
// Notification 微信支付通知的数据结构
|
||
type Notification struct {
|
||
ID string `json:"id"`
|
||
CreateTime *time.Time `json:"create_time"`
|
||
EventType string `json:"event_type"`
|
||
ResourceType string `json:"resource_type"`
|
||
Resource *Resource `json:"resource"`
|
||
Summary string `json:"summary"`
|
||
|
||
Plaintext string // 解密后的业务数据(JSON字符串)
|
||
}
|
||
|
||
func (c *Notification) validate() error {
|
||
if c.Resource == nil {
|
||
return errors.New("resource is nil")
|
||
}
|
||
|
||
if c.Resource.Algorithm != "AEAD_AES_256_GCM" {
|
||
return fmt.Errorf("unsupported algorithm: %s", c.Resource.Algorithm)
|
||
}
|
||
|
||
if c.Resource.Ciphertext == "" {
|
||
return errors.New("ciphertext is empty")
|
||
}
|
||
|
||
if c.Resource.AssociatedData == "" {
|
||
return errors.New("associated_data is empty")
|
||
}
|
||
|
||
if c.Resource.Nonce == "" {
|
||
return errors.New("nonce is empty")
|
||
}
|
||
|
||
if c.Resource.OriginalType == "" {
|
||
return fmt.Errorf("original_type is empty")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (c *Notification) decrypt(apiv3Key string) error {
|
||
if err := c.validate(); err != nil {
|
||
return fmt.Errorf("notification format err: %w", err)
|
||
}
|
||
|
||
plaintext, err := DecryptAES256GCM(
|
||
apiv3Key,
|
||
c.Resource.AssociatedData,
|
||
c.Resource.Nonce,
|
||
c.Resource.Ciphertext,
|
||
)
|
||
if err != nil {
|
||
return fmt.Errorf("notification decrypt err: %w", err)
|
||
}
|
||
|
||
c.Plaintext = plaintext
|
||
return nil
|
||
}
|
||
|
||
// ParseNotification 解析微信支付通知的报文,返回通知中的业务数据
|
||
// Notification.PlainText 为解密后的业务数据JSON字符串,请自行反序列化后使用
|
||
func ParseNotification(
|
||
wechatpayPublicKeyId string,
|
||
wechatpayPublicKey *rsa.PublicKey,
|
||
apiv3Key string,
|
||
headers *http.Header,
|
||
body []byte,
|
||
) (*Notification, error) {
|
||
if err := validateNotification(wechatpayPublicKeyId, wechatpayPublicKey, headers, body); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
notification := &Notification{}
|
||
if err := json.Unmarshal(body, notification); err != nil {
|
||
return nil, fmt.Errorf("parse notification err: %w", err)
|
||
}
|
||
|
||
if err := notification.decrypt(apiv3Key); err != nil {
|
||
return nil, fmt.Errorf("notification decrypt err: %w", err)
|
||
}
|
||
|
||
return notification, nil
|
||
}
|
||
|
||
// ApiException 微信支付API错误异常,发送HTTP请求成功,但返回状态码不是 2XX 时抛出本异常
|
||
type ApiException struct {
|
||
statusCode int // 应答报文的 HTTP 状态码
|
||
header http.Header // 应答报文的 Header 信息
|
||
body []byte // 应答报文的 Body 原文
|
||
errorCode string // 微信支付回包的错误码
|
||
errorMessage string // 微信支付回包的错误信息
|
||
}
|
||
|
||
func (c *ApiException) Error() string {
|
||
buf := bytes.NewBuffer(nil)
|
||
buf.WriteString(fmt.Sprintf("api error:[StatusCode: %d, Body: %s", c.statusCode, string(c.body)))
|
||
if len(c.header) > 0 {
|
||
buf.WriteString(" Header: ")
|
||
for key, value := range c.header {
|
||
buf.WriteString(fmt.Sprintf("\n - %v=%v", key, value))
|
||
}
|
||
buf.WriteString("\n")
|
||
}
|
||
buf.WriteString("]")
|
||
return buf.String()
|
||
}
|
||
|
||
func (c *ApiException) StatusCode() int {
|
||
return c.statusCode
|
||
}
|
||
|
||
func (c *ApiException) Header() http.Header {
|
||
return c.header
|
||
}
|
||
|
||
func (c *ApiException) Body() []byte {
|
||
return c.body
|
||
}
|
||
|
||
func (c *ApiException) ErrorCode() string {
|
||
return c.errorCode
|
||
}
|
||
|
||
func (c *ApiException) ErrorMessage() string {
|
||
return c.errorMessage
|
||
}
|
||
|
||
func NewApiException(statusCode int, header http.Header, body []byte) error {
|
||
ret := &ApiException{
|
||
statusCode: statusCode,
|
||
header: header,
|
||
body: body,
|
||
}
|
||
|
||
bodyObject := map[string]any{}
|
||
if err := json.Unmarshal(body, &bodyObject); err == nil {
|
||
if val, ok := bodyObject["code"]; ok {
|
||
ret.errorCode = val.(string)
|
||
}
|
||
if val, ok := bodyObject["message"]; ok {
|
||
ret.errorMessage = val.(string)
|
||
}
|
||
}
|
||
|
||
return ret
|
||
}
|
||
|
||
// Time 复制 time.Time 对象,并返回复制体的指针
|
||
func Time(t time.Time) *time.Time {
|
||
return &t
|
||
}
|
||
|
||
// String 复制 string 对象,并返回复制体的指针
|
||
func String(s string) *string {
|
||
return &s
|
||
}
|
||
|
||
// Bytes 复制 []byte 对象,并返回复制体的指针
|
||
func Bytes(b []byte) *[]byte {
|
||
return &b
|
||
}
|
||
|
||
// Bool 复制 bool 对象,并返回复制体的指针
|
||
func Bool(b bool) *bool {
|
||
return &b
|
||
}
|
||
|
||
// Float64 复制 float64 对象,并返回复制体的指针
|
||
func Float64(f float64) *float64 {
|
||
return &f
|
||
}
|
||
|
||
// Float32 复制 float32 对象,并返回复制体的指针
|
||
func Float32(f float32) *float32 {
|
||
return &f
|
||
}
|
||
|
||
// Int64 复制 int64 对象,并返回复制体的指针
|
||
func Int64(i int64) *int64 {
|
||
return &i
|
||
}
|
||
|
||
// Int32 复制 int64 对象,并返回复制体的指针
|
||
func Int32(i int32) *int32 {
|
||
return &i
|
||
}
|
||
|
||
// generateHashFromStream 从io.Reader流中生成哈希值的通用函数
|
||
func generateHashFromStream(reader io.Reader, hashFunc func() hash.Hash, algorithmName string) (string, error) {
|
||
hash := hashFunc()
|
||
if _, err := io.Copy(hash, reader); err != nil {
|
||
return "", fmt.Errorf("failed to read stream for %s: %w", algorithmName, err)
|
||
}
|
||
return fmt.Sprintf("%x", hash.Sum(nil)), nil
|
||
}
|
||
|
||
// GenerateSHA256FromStream 从io.Reader流中生成SHA256哈希值
|
||
func GenerateSHA256FromStream(reader io.Reader) (string, error) {
|
||
return generateHashFromStream(reader, sha256.New, "SHA256")
|
||
}
|
||
|
||
// GenerateSHA1FromStream 从io.Reader流中生成SHA1哈希值
|
||
func GenerateSHA1FromStream(reader io.Reader) (string, error) {
|
||
return generateHashFromStream(reader, sha1.New, "SHA1")
|
||
}
|
||
|
||
// GenerateSM3FromStream 从io.Reader流中生成SM3哈希值
|
||
func GenerateSM3FromStream(reader io.Reader) (string, error) {
|
||
h := sm3.New()
|
||
if _, err := io.Copy(h, reader); err != nil {
|
||
return "", fmt.Errorf("failed to read stream for SM3: %w", err)
|
||
}
|
||
return fmt.Sprintf("%x", h.Sum(nil)), nil
|
||
}
|