c5_labsci/zciyon/ws.go
2026-01-27 00:52:00 +08:00

339 lines
7.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
=================================================================================
* License: GPL-2.0 license
* Author: 众产® https://ciy.cn/code
* Version: 0.1.3
=================================================================================
*/
package zciyon
import (
"bufio"
"bytes"
"compress/flate"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"errors"
"io"
"net"
"net/http"
"strings"
"time"
)
// CiyWebsocket is a server-side WebSocket connection created by
// NewCiyWebsocket. A background goroutine reads incoming frames and hands
// them to the callbacks registered with OnMessage, OnMessageJSON, OnError and
// OnClose; Send, SendSucc and SendFail write frames to the peer.
type CiyWebsocket struct {
	// Isclose is set to true by Close; the read loop stops once it observes it.
	Isclose bool
	// compress is true when the client offered permessage-deflate during the
	// handshake; outgoing payloads longer than 10 bytes are then compressed.
	compress bool
	// conn and brw are the hijacked HTTP connection and its buffered
	// reader/writer. Incoming frames are read through brw.Reader; outgoing
	// frames are written directly to conn.
	conn net.Conn
	brw  *bufio.ReadWriter
	// Callbacks set by OnMessageJSON, OnMessage, OnError and OnClose.
	// NOTE(review): they are written by the On* setters and read by the read
	// goroutine without any synchronization — presumably they must be
	// registered right after NewCiyWebsocket returns; confirm with callers.
	msgjsonfn func(code int, id int, json map[string]any)
	messagefn func(byt []byte)
	errorfn   func(error)
	closefn   func()
}
// NewCiyWebsocket performs the server side of the WebSocket opening handshake
// (RFC 6455) for r and takes over the underlying connection from w.
//
// The request must be a GET whose Upgrade header contains the "websocket"
// token, whose Connection header contains the "upgrade" token, that carries
// "Sec-WebSocket-Version: 13", and whose Sec-WebSocket-Key decodes from base64
// to 16 bytes. If the parsed Sec-WebSocket-Extensions header contains
// permessage-deflate, compression is enabled and advertised with
// server_no_context_takeover and client_no_context_takeover.
//
// On success the 101 response has been written and a goroutine reading
// incoming frames is already running. On failure an error is returned; if
// the connection had already been hijacked it is closed.
func NewCiyWebsocket(w http.ResponseWriter, r *http.Request) (*CiyWebsocket, error) {
	ws := &CiyWebsocket{}
	if r.Method != "GET" {
		return nil, errors.New("not a Get request")
	}
	// Upgrade and Connection are comma-separated token lists (RFC 7230 §6.1,
	// §6.7). Firefox, for example, sends "Connection: keep-alive, Upgrade",
	// which an exact string comparison would reject.
	if !wsHeaderHasToken(r.Header, "Upgrade", "websocket") {
		return nil, errors.New("notfind Upgrade")
	}
	if !wsHeaderHasToken(r.Header, "Connection", "upgrade") {
		return nil, errors.New("notfind Connection")
	}
	if r.Header.Get("Sec-Websocket-Version") != "13" {
		return nil, errors.New("notfind Sec-WebSocket-Version")
	}
	extensions := Getstrparam(r.Header.Get("Sec-Websocket-Extensions"), ";")
	challengeKey := r.Header.Get("Sec-Websocket-Key")
	if challengeKey == "" {
		return nil, errors.New("nofind Sec-WebSocket-Key")
	}
	decoded, err := base64.StdEncoding.DecodeString(challengeKey)
	if err != nil {
		return nil, err
	}
	if len(decoded) != 16 {
		return nil, errors.New("error Sec-WebSocket-Key")
	}
	if _, ok := extensions["permessage-deflate"]; ok {
		ws.compress = true
	}
	ws.conn, ws.brw, err = http.NewResponseController(w).Hijack()
	if err != nil {
		return nil, err
	}
	// Anything already buffered would be data the client sent before it
	// received our 101 response; refuse rather than misparse it as frames.
	if ws.brw.Reader.Buffered() > 0 {
		ws.conn.Close()
		return nil, errors.New("websocket: client sent data before handshake is complete")
	}
	p := make([]byte, 0, 1024)
	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
	// Sec-WebSocket-Accept = base64(SHA-1(key + fixed GUID)), RFC 6455 §4.2.2.
	h := sha1.New()
	h.Write([]byte(challengeKey))
	h.Write([]byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
	p = append(p, base64.StdEncoding.EncodeToString(h.Sum(nil))...)
	p = append(p, "\r\n"...)
	if ws.compress {
		p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
	}
	p = append(p, "\r\n"...)
	// Clear any read/write deadline that may still be set on the connection so
	// the long-lived websocket is not cut off.
	ws.conn.SetDeadline(time.Time{})
	if _, err = ws.conn.Write(p); err != nil {
		ws.conn.Close()
		return nil, err
	}
	go ws.thread()
	return ws, nil
}

// wsHeaderHasToken reports whether any value of header name contains token
// as one of its comma-separated elements, compared case-insensitively.
func wsHeaderHasToken(h http.Header, name, token string) bool {
	for _, value := range h.Values(name) {
		for _, part := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(part), token) {
				return true
			}
		}
	}
	return false
}
// OnMessage registers fn as the handler for incoming text and binary
// messages. fn receives the payload after unmasking and, for compressed
// frames, inflation, and is invoked on a new goroutine for each message.
// A payload consisting of the single byte 'h' is a heartbeat and is not delivered.
func (thos *CiyWebsocket) OnMessage(fn func(byt []byte)) { thos.messagefn = fn }
// OnMessageJSON registers fn as the handler for incoming messages that parse
// as a JSON object. fn receives json["code"] and json["_wsidx"] converted to
// int, plus the whole decoded map, on a new goroutine per message.
func (thos *CiyWebsocket) OnMessageJSON(fn func(code int, id int, json map[string]any)) { thos.msgjsonfn = fn }
// OnError registers fn to be called, on a new goroutine, with each error the
// read loop reports while the websocket is still open.
func (thos *CiyWebsocket) OnError(fn func(error)) { thos.errorfn = fn }
// OnClose registers fn to be called once, on a new goroutine, when the read
// loop observes that the websocket has been closed (for example by Close or
// by a close frame from the peer).
func (thos *CiyWebsocket) OnClose(fn func()) { thos.closefn = fn }
// SendSucc sends a success reply for the request identified by idx.
//
// The optional argjson[0] supplies additional fields. They are copied into a
// new map, so the caller's map is left untouched and a nil map is accepted.
// "_wsidx" is set to idx and "code" to 1, overriding any keys of the same
// name from argjson[0]. The reply is encoded with JSON_Byte and sent with Send.
func (thos *CiyWebsocket) SendSucc(idx int, argjson ...map[string]any) error {
	var extra map[string]any
	if len(argjson) > 0 {
		extra = argjson[0]
	}
	json := make(map[string]any, len(extra)+2)
	for k, v := range extra {
		json[k] = v
	}
	json["_wsidx"] = idx
	json["code"] = 1
	return thos.Send(JSON_Byte(json))
}
// SendFail sends an error reply for the request identified by idx. The reply
// carries only "_wsidx" and "errmsg"; unlike SendSucc it has no "code" field.
func (thos *CiyWebsocket) SendFail(idx int, errmsg string) error {
	reply := make(map[string]any, 2)
	reply["errmsg"] = errmsg
	reply["_wsidx"] = idx
	payload := JSON_Byte(reply)
	return thos.Send(payload)
}
// Send writes wdata to the peer as one unfragmented text frame (FIN set,
// opcode 1).
//
// When permessage-deflate was negotiated and wdata is longer than 10 bytes,
// the payload is compressed with endeflate and RSV1 is set. Server frames are
// never masked. A compression error is returned as is; a write error closes
// the websocket and is then returned.
func (thos *CiyWebsocket) Send(wdata []byte) error {
	// FIN (bit 7) + opcode 1 = text; opcode 2 would mean binary.
	b0 := byte(1<<7 | 1)
	if thos.compress && len(wdata) > 10 {
		endata, err := endeflate(wdata)
		if err != nil {
			return err
		}
		// RSV1 marks a compressed message (RFC 7692).
		b0 |= 1 << 6
		wdata = endata
	}
	length := len(wdata)
	// Header is at most 10 bytes for an unmasked frame: 2 fixed + 8 length.
	frame := make([]byte, 0, 10+length)
	frame = append(frame, b0)
	// The second byte holds the mask bit (always 0 here) and the length code.
	// Longer payloads need their extended length appended after it
	// (RFC 6455 §5.2).
	switch {
	case length >= 65536:
		frame = append(frame, 127)
		frame = binary.BigEndian.AppendUint64(frame, uint64(length))
	case length > 125:
		frame = append(frame, 126)
		frame = binary.BigEndian.AppendUint16(frame, uint16(length))
	default:
		frame = append(frame, byte(length))
	}
	frame = append(frame, wdata...)
	// Header and payload go out in a single Write. With two writes, a Send
	// from another goroutine (callbacks run on their own goroutines and may
	// reply) could slip in between and corrupt the frame stream.
	if _, err := thos.conn.Write(frame); err != nil {
		thos.Close()
		return err
	}
	return nil
}
// thread is the read loop started by NewCiyWebsocket. It calls watchread
// repeatedly. While the websocket is open, any error it returns is passed to
// errorfn on a new goroutine. Once Isclose is set, closefn (if any) is
// started on a new goroutine and the loop ends.
func (thos *CiyWebsocket) thread() {
	for {
		readErr := thos.watchread()
		if !thos.Isclose {
			if readErr != nil && thos.errorfn != nil {
				go thos.errorfn(readErr)
			}
			continue
		}
		if onClose := thos.closefn; onClose != nil {
			go onClose()
		}
		return
	}
}
// Close marks the websocket as closed and closes the underlying connection.
// No WebSocket close frame is sent.
func (thos *CiyWebsocket) Close() {
	conn := thos.conn
	thos.Isclose = true
	// The error is deliberately ignored: Close may run more than once (from
	// Send, from the read side and from callers), and a second close of the
	// connection fails harmlessly.
	_ = conn.Close()
}
// watchread reads one frame from the connection and dispatches it.
//
// Dispatch by opcode:
//   - Text (1) and binary (2) frames are unmasked and inflated when RSV1 is
//     set. They are then handed to messagefn and/or msgjsonfn, each on its
//     own goroutine. A payload consisting of the single byte 'h' is the
//     client's heartbeat and is dropped.
//   - A close frame (8) is answered with an empty close frame, and the
//     connection is closed.
//
// Only unfragmented frames are supported. FIN must be set, RSV2/RSV3 must be
// clear, and only opcodes 1, 2 and 8 are accepted. Anything else discards the
// currently buffered input and returns an error. This is a best-effort
// resynchronization that is not guaranteed to land on a frame boundary.
//
// Any failure to read from the connection (including the peer going away)
// closes the websocket: a partially read frame leaves the byte stream
// unusable. A clean EOF between frames is not reported as an error.
func (thos *CiyWebsocket) watchread() error {
	// maxPayload bounds a single frame so a hostile client cannot make the
	// server allocate arbitrary memory, or panic on a length that does not
	// fit in an int. Larger frames close the connection. Tune as needed.
	const maxPayload = 64 << 20

	// readn reads exactly n bytes via thos.read; a failed or short read
	// closes the connection.
	readn := func(n int) ([]byte, error) {
		p, err := thos.read(n)
		if err == nil && len(p) != n {
			err = io.ErrUnexpectedEOF
		}
		if err != nil {
			thos.Close()
			return nil, err
		}
		return p, nil
	}

	p, err := readn(2)
	if err != nil {
		if err == io.EOF {
			// Peer closed the connection between frames; thread sees Isclose.
			return nil
		}
		return err
	}
	frameType := int(p[0] & 0xf)
	final := p[0]&(1<<7) != 0
	rsv1 := p[0]&(1<<6) != 0
	rsv2 := p[0]&(1<<5) != 0
	rsv3 := p[0]&(1<<4) != 0
	mask := p[1]&(1<<7) != 0
	datalen := int(p[1] & 0x7f)
	// final == false means more fragments follow.
	// frameType 0: continuation of a fragmented message
	// frameType 1: text, 2: binary
	// frameType 8: connection close
	// frameType 9/10: ping/pong (heartbeats are kept by the client instead)
	if !final || rsv2 || rsv3 || (frameType > 2 && frameType != 8) {
		if buff := thos.brw.Reader.Buffered(); buff > 0 {
			thos.brw.Reader.Discard(buff)
		}
		return errors.New("data err, restart work")
	}
	switch datalen {
	case 126:
		ext, err := readn(2)
		if err != nil {
			return err
		}
		datalen = int(binary.BigEndian.Uint16(ext))
	case 127:
		ext, err := readn(8)
		if err != nil {
			return err
		}
		// Also rejects lengths with the most significant bit set, which
		// RFC 6455 §5.2 forbids.
		n64 := binary.BigEndian.Uint64(ext)
		if n64 > maxPayload {
			thos.Close()
			return errors.New("websocket: frame payload too large")
		}
		datalen = int(n64)
	}
	var maskkey [4]byte
	if mask {
		key, err := readn(4)
		if err != nil {
			return err
		}
		copy(maskkey[:], key)
	}
	byts := make([]byte, datalen)
	if _, err := io.ReadFull(thos.brw.Reader, byts); err != nil {
		thos.Close()
		return err
	}
	if mask {
		for i := range byts {
			byts[i] ^= maskkey[i%4]
		}
	}
	if rsv1 {
		debyt, err := dedeflate(byts)
		if err != nil {
			return err
		}
		byts = debyt
	}
	if frameType == 1 || frameType == 2 {
		// A lone "h" is the client's heartbeat, not an application message.
		if len(byts) == 1 && byts[0] == 'h' {
			return nil
		}
		if thos.messagefn != nil {
			go thos.messagefn(byts)
		}
		if thos.msgjsonfn != nil {
			if json := Byte_JSON(byts); json != nil {
				go thos.msgjsonfn(Toint(json["code"]), Toint(json["_wsidx"]), json)
			}
		}
	}
	if frameType == 8 {
		// RFC 6455 §5.5.1: answer a close frame with a close frame before
		// dropping the connection. Best effort: the connection is being torn
		// down, so a write error changes nothing.
		_, _ = thos.conn.Write([]byte{0x88, 0x00})
		thos.Close()
	}
	return nil
}
// endeflate compresses data for a permessage-deflate frame (RFC 7692).
//
// The payload is deflated at flate.BestCompression and the writer is closed,
// which terminates the DEFLATE stream. If the last byte of that stream has
// either of its two low bits set, the bytes 00 00 ff ff 00 are appended.
// Finally the last four bytes are dropped, because the receiver re-appends
// 00 00 ff ff before inflating.
// NOTE(review): when the two low bits are clear, the four dropped bytes
// belong to the compressed stream itself — confirm that branch cannot occur
// with compress/flate's Close output.
func endeflate(data []byte) ([]byte, error) {
	var buf bytes.Buffer
	fw, err := flate.NewWriter(&buf, flate.BestCompression)
	if err != nil {
		return nil, err
	}
	if _, err = fw.Write(data); err != nil {
		return nil, err
	}
	if err = fw.Close(); err != nil {
		return nil, err
	}
	if stream := buf.Bytes(); stream[len(stream)-1]&0x03 != 0x00 {
		buf.Write([]byte{0x00, 0x00, 0xff, 0xff, 0x00})
	}
	out := buf.Bytes()
	return out[:len(out)-4], nil
}
// dedeflate inflates one permessage-deflate message payload (RFC 7692).
//
// The sender strips the trailing 00 00 ff ff of its sync flush. That tail is
// re-supplied here exactly once, whether or not data still ends with it.
// The tail is provided by a separate reader, so the caller's slice (and any
// spare capacity behind it) is never written to.
//
// A message normally ends without a final DEFLATE block, so the inflater
// runs out of input and reports io.ErrUnexpectedEOF. That, like io.EOF, marks
// the end of the message and is not an error. Any other decompression error
// is returned.
func dedeflate(data []byte) ([]byte, error) {
	tail := []byte{0x00, 0x00, 0xff, 0xff}
	data = bytes.TrimSuffix(data, tail)
	fr := flate.NewReader(io.MultiReader(bytes.NewReader(data), bytes.NewReader(tail)))
	endOfInput := func(err error) bool {
		return errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)
	}
	var out bytes.Buffer
	if _, err := io.Copy(&out, fr); err != nil && !endOfInput(err) {
		return nil, err
	}
	if err := fr.Close(); err != nil && !endOfInput(err) {
		return nil, err
	}
	return out.Bytes(), nil
}
// read reads exactly n bytes from the connection into a newly allocated slice.
//
// Any failure closes the websocket, because after a failed or partial read
// the frame stream cannot be resynchronized. The failure is then returned:
//   - io.EOF when the peer has gone away and no bytes were read;
//   - io.ErrUnexpectedEOF for a truncated read;
//   - the underlying network error otherwise.
//
// Closing lets the read loop observe Isclose and stop instead of retrying a
// dead connection.
func (thos *CiyWebsocket) read(n int) ([]byte, error) {
	p := make([]byte, n)
	if _, err := io.ReadFull(thos.brw.Reader, p); err != nil {
		thos.Close()
		return nil, err
	}
	return p, nil
}