339 lines
7.8 KiB
Go
339 lines
7.8 KiB
Go
/*
|
||
=================================================================================
|
||
* License: GPL-2.0 license
|
||
* Author: 众产® https://ciy.cn/code
|
||
* Version: 0.1.3
|
||
=================================================================================
|
||
*/
|
||
package zciyon
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"compress/flate"
|
||
"crypto/sha1"
|
||
"encoding/base64"
|
||
"encoding/binary"
|
||
"errors"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// CiyWebsocket is a server-side WebSocket connection returned by
// NewCiyWebsocket. A background goroutine reads frames from the client and
// delivers them through the callbacks registered with OnMessage,
// OnMessageJSON, OnError and OnClose; Send, SendSucc and SendFail write
// frames back to the client.
type CiyWebsocket struct {
	// Isclose becomes true once Close has been called; the reader goroutine
	// checks it to decide when to stop. NOTE(review): it is written and read
	// from different goroutines without synchronization — confirm with
	// `go test -race` and consider an atomic flag.
	Isclose bool
	// compress is true when permessage-deflate was negotiated in the handshake.
	compress bool
	// conn is the connection taken over from the HTTP server via Hijack.
	conn net.Conn
	// brw is the buffered reader/writer returned by Hijack; incoming frames
	// are read from brw.Reader.
	brw *bufio.ReadWriter
	// msgjsonfn is the OnMessageJSON callback (nil when unset).
	msgjsonfn func(code int, id int, json map[string]any)
	// messagefn is the OnMessage callback (nil when unset).
	messagefn func(byt []byte)
	// errorfn is the OnError callback (nil when unset).
	errorfn func(error)
	// closefn is the OnClose callback (nil when unset).
	closefn func()
}
|
||
|
||
// 单个数据库连接池,平时单独使用
|
||
func NewCiyWebsocket(w http.ResponseWriter, r *http.Request) (*CiyWebsocket, error) {
|
||
ws := &CiyWebsocket{}
|
||
if r.Method != "GET" {
|
||
return nil, errors.New("not a Get request")
|
||
}
|
||
if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
|
||
return nil, errors.New("notfind Upgrade")
|
||
}
|
||
if strings.ToLower(r.Header.Get("Connection")) != "upgrade" {
|
||
return nil, errors.New("notfind Connection")
|
||
}
|
||
if strings.ToLower(r.Header.Get("Sec-Websocket-Version")) != "13" {
|
||
return nil, errors.New("notfind Connection")
|
||
}
|
||
|
||
extensions := Getstrparam(r.Header.Get("Sec-Websocket-Extensions"), ";")
|
||
|
||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||
if challengeKey == "" {
|
||
return nil, errors.New("nofind Sec-WebSocket-Key")
|
||
}
|
||
decoded, err := base64.StdEncoding.DecodeString(challengeKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(decoded) != 16 {
|
||
return nil, errors.New("error Sec-WebSocket-Key")
|
||
}
|
||
|
||
if _, ok := extensions["permessage-deflate"]; ok {
|
||
ws.compress = true
|
||
}
|
||
|
||
ws.conn, ws.brw, err = http.NewResponseController(w).Hijack()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if ws.brw.Reader.Buffered() > 0 {
|
||
ws.conn.Close()
|
||
return nil, errors.New("websocket: client sent data before handshake is complete")
|
||
}
|
||
|
||
p := make([]byte, 0, 1024)
|
||
|
||
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
|
||
|
||
h := sha1.New()
|
||
h.Write([]byte(challengeKey))
|
||
h.Write([]byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
|
||
p = append(p, base64.StdEncoding.EncodeToString(h.Sum(nil))...)
|
||
p = append(p, "\r\n"...)
|
||
if ws.compress {
|
||
p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
|
||
}
|
||
p = append(p, "\r\n"...)
|
||
|
||
ws.conn.SetDeadline(time.Time{})
|
||
|
||
if _, err = ws.conn.Write(p); err != nil {
|
||
ws.conn.Close()
|
||
return nil, err
|
||
}
|
||
go ws.thread()
|
||
return ws, nil
|
||
}
|
||
|
||
func (thos *CiyWebsocket) OnMessage(fn func(byt []byte)) {
|
||
thos.messagefn = fn
|
||
}
|
||
func (thos *CiyWebsocket) OnMessageJSON(fn func(code int, id int, json map[string]any)) {
|
||
thos.msgjsonfn = fn
|
||
}
|
||
func (thos *CiyWebsocket) OnError(fn func(error)) {
|
||
thos.errorfn = fn
|
||
}
|
||
func (thos *CiyWebsocket) OnClose(fn func()) {
|
||
thos.closefn = fn
|
||
}
|
||
func (thos *CiyWebsocket) SendSucc(idx int, argjson ...map[string]any) error {
|
||
json := map[string]any{}
|
||
if len(argjson) > 0 {
|
||
json = argjson[0]
|
||
}
|
||
json["_wsidx"] = idx
|
||
json["code"] = 1
|
||
return thos.Send(JSON_Byte(json))
|
||
}
|
||
func (thos *CiyWebsocket) SendFail(idx int, errmsg string) error {
|
||
return thos.Send(JSON_Byte(map[string]any{
|
||
"_wsidx": idx,
|
||
"errmsg": errmsg,
|
||
}))
|
||
}
|
||
func (thos *CiyWebsocket) Send(wdata []byte) error {
|
||
cancompress := thos.compress && len(wdata) > 10
|
||
b0 := byte(1) //1表示文本,2表示二进制
|
||
b0 |= (1 << 7) //flush
|
||
b1 := byte(0)
|
||
if cancompress {
|
||
b0 |= (1 << 6) //rsv1
|
||
endata, err := endeflate(wdata)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
wdata = endata
|
||
}
|
||
length := len(wdata)
|
||
hdr := make([]byte, 0, 14)
|
||
hdr = append(hdr, b0)
|
||
switch {
|
||
case length >= 65536:
|
||
hdr = append(hdr, b1|127)
|
||
binary.BigEndian.PutUint64(hdr[2:], uint64(length))
|
||
case length > 125:
|
||
hdr = append(hdr, b1|126)
|
||
binary.BigEndian.PutUint16(hdr[2:], uint16(length))
|
||
default:
|
||
hdr = append(hdr, b1|byte(length))
|
||
}
|
||
if _, err := thos.conn.Write(hdr); err != nil {
|
||
thos.Close()
|
||
return err
|
||
}
|
||
if _, err := thos.conn.Write(wdata); err != nil {
|
||
thos.Close()
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
func (thos *CiyWebsocket) thread() {
|
||
for {
|
||
err := thos.watchread()
|
||
if thos.Isclose {
|
||
if thos.closefn != nil {
|
||
go thos.closefn()
|
||
}
|
||
return
|
||
}
|
||
if err != nil && thos.errorfn != nil {
|
||
go thos.errorfn(err)
|
||
}
|
||
}
|
||
}
|
||
func (thos *CiyWebsocket) Close() {
|
||
thos.Isclose = true
|
||
thos.conn.Close()
|
||
}
|
||
|
||
func (thos *CiyWebsocket) watchread() error {
|
||
p, err := thos.read(2)
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
Sleep(0.1)
|
||
return nil
|
||
}
|
||
return err
|
||
}
|
||
if len(p) != 2 {
|
||
return nil
|
||
}
|
||
frameType := int(p[0] & 0xf)
|
||
final := p[0]&(1<<7) != 0
|
||
rsv1 := p[0]&(1<<6) != 0
|
||
rsv2 := p[0]&(1<<5) != 0
|
||
rsv3 := p[0]&(1<<4) != 0
|
||
mask := p[1]&(1<<7) != 0
|
||
datalen := int(p[1] & 0x7f)
|
||
//final为false,表示后续还有数据包,待完成。
|
||
//frameType=0 后续还有数据包brw
|
||
//frameType=1 文本类型
|
||
//frameType=2 二进制类型
|
||
//frameType=8 连接断开
|
||
//frameType=9/10 心跳Ping/Pong。(客户端维护心跳包)
|
||
if !final || rsv2 || rsv3 || (frameType > 2 && frameType != 8) {
|
||
if buff := thos.brw.Reader.Buffered(); buff > 0 {
|
||
thos.brw.Reader.Discard(buff)
|
||
}
|
||
return errors.New("data err, restart work")
|
||
}
|
||
switch datalen {
|
||
case 126:
|
||
p, err := thos.read(2)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
datalen = int(binary.BigEndian.Uint16(p))
|
||
case 127:
|
||
p, err := thos.read(8)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
datalen = int(binary.BigEndian.Uint64(p))
|
||
}
|
||
maskkey := make([]byte, 4)
|
||
if mask {
|
||
p, err = thos.read(4)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
copy(maskkey, p)
|
||
}
|
||
byts := make([]byte, datalen)
|
||
bytsidx := 0
|
||
for datalen > 0 {
|
||
if datalen > 512 {
|
||
pp := make([]byte, 512)
|
||
n1, err := thos.brw.Reader.Read(pp)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
copy(byts[bytsidx:bytsidx+n1], pp[:n1])
|
||
bytsidx += n1
|
||
datalen -= n1
|
||
} else {
|
||
pp := make([]byte, datalen)
|
||
n1, err := thos.brw.Reader.Read(pp)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
copy(byts[bytsidx:bytsidx+n1], pp[:n1])
|
||
bytsidx += n1
|
||
datalen -= n1
|
||
}
|
||
}
|
||
if mask {
|
||
for i := range byts {
|
||
byts[i] ^= maskkey[i%4]
|
||
}
|
||
}
|
||
if rsv1 {
|
||
debyt, err := dedeflate(byts)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
byts = debyt
|
||
}
|
||
if frameType == 1 || frameType == 2 {
|
||
if len(byts) == 1 && byts[0] == 'h' {
|
||
return nil
|
||
}
|
||
if thos.messagefn != nil {
|
||
go thos.messagefn(byts)
|
||
}
|
||
if thos.msgjsonfn != nil {
|
||
json := Byte_JSON(byts)
|
||
if json != nil {
|
||
code := Toint(json["code"])
|
||
id := Toint(json["_wsidx"])
|
||
go thos.msgjsonfn(code, id, json)
|
||
}
|
||
}
|
||
}
|
||
if frameType == 8 {
|
||
thos.Close()
|
||
}
|
||
return nil
|
||
}
|
||
// endeflate compresses one message payload for permessage-deflate, following
// RFC 7692 §7.2.1.
//
// The payload is DEFLATE-compressed and ended with a sync flush, which
// terminates the output with an empty non-final stored block
// (0x00 0x00 0xff 0xff). Those four trailing octets are removed, as the RFC
// requires; the receiver appends them again before inflating. A fresh
// compressor is used per call, so no context is carried between messages.
func endeflate(data []byte) ([]byte, error) {
	var buf bytes.Buffer
	fw, err := flate.NewWriter(&buf, flate.BestCompression)
	if err != nil {
		return nil, err
	}
	if _, err := fw.Write(data); err != nil {
		return nil, err
	}
	// Flush (Z_SYNC_FLUSH) rather than Close: Close emits a BFINAL block, and
	// stripping its tail would not yield the form the RFC describes.
	if err := fw.Flush(); err != nil {
		return nil, err
	}
	out := buf.Bytes()
	marker := []byte{0x00, 0x00, 0xff, 0xff}
	if !bytes.HasSuffix(out, marker) {
		return nil, errors.New("endeflate: missing sync flush marker")
	}
	return out[:len(out)-len(marker)], nil
}
|
||
// dedeflate decompresses one permessage-deflate message payload, following
// RFC 7692 §7.2.2.
//
// The sender strips the 0x00 0x00 0xff 0xff of its sync flush, so those four
// octets are re-appended before inflating. A payload that still ends with them
// is tolerated (the duplicate is trimmed first). The stream has no BFINAL
// block, so running out of input is expected and is not an error. A fresh
// decompressor is used per call. data is only read, never written to,
// including its spare capacity.
func dedeflate(data []byte) ([]byte, error) {
	marker := []byte{0x00, 0x00, 0xff, 0xff}
	data = bytes.TrimSuffix(data, marker)
	fr := flate.NewReader(io.MultiReader(bytes.NewReader(data), bytes.NewReader(marker)))
	defer fr.Close()
	out, err := io.ReadAll(fr)
	if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
		return nil, err
	}
	return out, nil
}
|
||
|
||
func (thos *CiyWebsocket) read(n int) ([]byte, error) {
|
||
p, err := thos.brw.Reader.Peek(n)
|
||
if err == io.EOF {
|
||
return nil, io.EOF
|
||
}
|
||
thos.brw.Reader.Discard(len(p))
|
||
return p, nil
|
||
}
|