/* ================================================================================= * License: GPL-2.0 license * Author: 众产® https://ciy.cn/code * Version: 0.1.3 ================================================================================= */ package zciyon import ( "bufio" "bytes" "compress/flate" "crypto/sha1" "encoding/base64" "encoding/binary" "errors" "io" "net" "net/http" "strings" "time" ) type CiyWebsocket struct { Isclose bool compress bool conn net.Conn brw *bufio.ReadWriter msgjsonfn func(code int, id int, json map[string]any) messagefn func(byt []byte) errorfn func(error) closefn func() } // 单个数据库连接池,平时单独使用 func NewCiyWebsocket(w http.ResponseWriter, r *http.Request) (*CiyWebsocket, error) { ws := &CiyWebsocket{} if r.Method != "GET" { return nil, errors.New("not a Get request") } if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" { return nil, errors.New("notfind Upgrade") } if strings.ToLower(r.Header.Get("Connection")) != "upgrade" { return nil, errors.New("notfind Connection") } if strings.ToLower(r.Header.Get("Sec-Websocket-Version")) != "13" { return nil, errors.New("notfind Connection") } extensions := Getstrparam(r.Header.Get("Sec-Websocket-Extensions"), ";") challengeKey := r.Header.Get("Sec-Websocket-Key") if challengeKey == "" { return nil, errors.New("nofind Sec-WebSocket-Key") } decoded, err := base64.StdEncoding.DecodeString(challengeKey) if err != nil { return nil, err } if len(decoded) != 16 { return nil, errors.New("error Sec-WebSocket-Key") } if _, ok := extensions["permessage-deflate"]; ok { ws.compress = true } ws.conn, ws.brw, err = http.NewResponseController(w).Hijack() if err != nil { return nil, err } if ws.brw.Reader.Buffered() > 0 { ws.conn.Close() return nil, errors.New("websocket: client sent data before handshake is complete") } p := make([]byte, 0, 1024) p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) h := sha1.New() h.Write([]byte(challengeKey)) h.Write([]byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) p = append(p, base64.StdEncoding.EncodeToString(h.Sum(nil))...) p = append(p, "\r\n"...) if ws.compress { p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) } p = append(p, "\r\n"...) ws.conn.SetDeadline(time.Time{}) if _, err = ws.conn.Write(p); err != nil { ws.conn.Close() return nil, err } go ws.thread() return ws, nil } func (thos *CiyWebsocket) OnMessage(fn func(byt []byte)) { thos.messagefn = fn } func (thos *CiyWebsocket) OnMessageJSON(fn func(code int, id int, json map[string]any)) { thos.msgjsonfn = fn } func (thos *CiyWebsocket) OnError(fn func(error)) { thos.errorfn = fn } func (thos *CiyWebsocket) OnClose(fn func()) { thos.closefn = fn } func (thos *CiyWebsocket) SendSucc(idx int, argjson ...map[string]any) error { json := map[string]any{} if len(argjson) > 0 { json = argjson[0] } json["_wsidx"] = idx json["code"] = 1 return thos.Send(JSON_Byte(json)) } func (thos *CiyWebsocket) SendFail(idx int, errmsg string) error { return thos.Send(JSON_Byte(map[string]any{ "_wsidx": idx, "errmsg": errmsg, })) } func (thos *CiyWebsocket) Send(wdata []byte) error { cancompress := thos.compress && len(wdata) > 10 b0 := byte(1) //1表示文本,2表示二进制 b0 |= (1 << 7) //flush b1 := byte(0) if cancompress { b0 |= (1 << 6) //rsv1 endata, err := endeflate(wdata) if err != nil { return err } wdata = endata } length := len(wdata) hdr := make([]byte, 0, 14) hdr = append(hdr, b0) switch { case length >= 65536: hdr = append(hdr, b1|127) binary.BigEndian.PutUint64(hdr[2:], uint64(length)) case length > 125: hdr = append(hdr, b1|126) binary.BigEndian.PutUint16(hdr[2:], uint16(length)) default: hdr = append(hdr, b1|byte(length)) } if _, err := thos.conn.Write(hdr); err != nil { thos.Close() return err } if _, err := thos.conn.Write(wdata); err != nil { thos.Close() return err } return nil } func (thos *CiyWebsocket) thread() { for { err := thos.watchread() if thos.Isclose { if thos.closefn != nil { go thos.closefn() } return } if err != nil && thos.errorfn != nil { go thos.errorfn(err) } } } func (thos *CiyWebsocket) Close() { thos.Isclose = true thos.conn.Close() } func (thos *CiyWebsocket) watchread() error { p, err := thos.read(2) if err != nil { if err == io.EOF { Sleep(0.1) return nil } return err } if len(p) != 2 { return nil } frameType := int(p[0] & 0xf) final := p[0]&(1<<7) != 0 rsv1 := p[0]&(1<<6) != 0 rsv2 := p[0]&(1<<5) != 0 rsv3 := p[0]&(1<<4) != 0 mask := p[1]&(1<<7) != 0 datalen := int(p[1] & 0x7f) //final为false,表示后续还有数据包,待完成。 //frameType=0 后续还有数据包brw //frameType=1 文本类型 //frameType=2 二进制类型 //frameType=8 连接断开 //frameType=9/10 心跳Ping/Pong。(客户端维护心跳包) if !final || rsv2 || rsv3 || (frameType > 2 && frameType != 8) { if buff := thos.brw.Reader.Buffered(); buff > 0 { thos.brw.Reader.Discard(buff) } return errors.New("data err, restart work") } switch datalen { case 126: p, err := thos.read(2) if err != nil { return err } datalen = int(binary.BigEndian.Uint16(p)) case 127: p, err := thos.read(8) if err != nil { return err } datalen = int(binary.BigEndian.Uint64(p)) } maskkey := make([]byte, 4) if mask { p, err = thos.read(4) if err != nil { return err } copy(maskkey, p) } byts := make([]byte, datalen) bytsidx := 0 for datalen > 0 { if datalen > 512 { pp := make([]byte, 512) n1, err := thos.brw.Reader.Read(pp) if err != nil { return err } copy(byts[bytsidx:bytsidx+n1], pp[:n1]) bytsidx += n1 datalen -= n1 } else { pp := make([]byte, datalen) n1, err := thos.brw.Reader.Read(pp) if err != nil { return err } copy(byts[bytsidx:bytsidx+n1], pp[:n1]) bytsidx += n1 datalen -= n1 } } if mask { for i := range byts { byts[i] ^= maskkey[i%4] } } if rsv1 { debyt, err := dedeflate(byts) if err != nil { return err } byts = debyt } if frameType == 1 || frameType == 2 { if len(byts) == 1 && byts[0] == 'h' { return nil } if thos.messagefn != nil { go thos.messagefn(byts) } if thos.msgjsonfn != nil { json := Byte_JSON(byts) if json != nil { code := Toint(json["code"]) id := Toint(json["_wsidx"]) go thos.msgjsonfn(code, id, json) } } } if frameType == 8 { thos.Close() } return nil } func endeflate(data []byte) ([]byte, error) { var b bytes.Buffer w, err := flate.NewWriter(&b, flate.BestCompression) if err != nil { return nil, err } if _, err := w.Write(data); err != nil { return nil, err } if err := w.Close(); err != nil { return nil, err } lastByte := b.Bytes()[len(b.Bytes())-1] if lastByte&0x03 != 0x00 { emptyDeflateBlock := []byte{0x00, 0x00, 0xff, 0xff, 0x00} b.Write(emptyDeflateBlock) } compressedData := b.Bytes()[:len(b.Bytes())-4] return compressedData, nil } func dedeflate(data []byte) ([]byte, error) { if len(data) >= 4 && bytes.Equal(data[len(data)-4:], []byte{0x00, 0x00, 0xff, 0xff}) { data = data[:len(data)-4] } data = append(data, []byte{0x00, 0x00, 0xff, 0xff}...) fr := flate.NewReader(bytes.NewReader(data)) decompressedData := new(bytes.Buffer) if _, err := io.Copy(decompressedData, fr); err != nil && !strings.Contains(err.Error(), "EOF") { return nil, err } if err := fr.Close(); err != nil && !strings.Contains(err.Error(), "EOF") { return nil, err } return decompressedData.Bytes(), nil } func (thos *CiyWebsocket) read(n int) ([]byte, error) { p, err := thos.brw.Reader.Peek(n) if err == io.EOF { return nil, io.EOF } thos.brw.Reader.Discard(len(p)) return p, nil }