1 package info_websocket 2 3 import ( 4 "crypto/sha1" 5 "encoding/base64" 6 "errors" 7 "io" 8 "log" 9 "net" 10 "strings" 11 ) 12 13 func main() { 14 ln, err := net.Listen("tcp", ":8000")//监听端口 15 if err != nil { 16 log.Panic(err) 17 } 18 for { 19 log.Println("wss") 20 conn, err := ln.Accept()//等待客户的连接 21 if err != nil { 22 log.Println("Accept err:", err) 23 } 24 for { 25 handleConnection(conn) 26 } 27 } 28 } 29 30 func handleConnection(conn net.Conn) { 31 content := make([]byte, 1024) 32 _, err := conn.Read(content) 33 log.Println(string(content)) 34 if err != nil { 35 log.Println(err) 36 } 37 isHttp := false 38 // 先暂时这么判断 39 if string(content[0:3]) == "GET" { 40 isHttp = true 41 } 42 log.Println("isHttp:", isHttp) 43 if isHttp { 44 headers := parseHandshake(string(content)) 45 log.Println("headers", headers) 46 secWebsocketKey := headers["Sec-WebSocket-Key"] 47 // NOTE:这里省略其他的验证 48 guid := "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 49 // 计算Sec-WebSocket-Accept 50 h := sha1.New() 51 log.Println("accept raw:", secWebsocketKey+guid) 52 io.WriteString(h, secWebsocketKey+guid) 53 accept := make([]byte, 28) 54 base64.StdEncoding.Encode(accept, h.Sum(nil)) 55 log.Println(string(accept)) 56 response := "HTTP/1.1 101 Switching Protocols\r\n" 57 response = response + "Sec-WebSocket-Accept: " + string(accept) + "\r\n" 58 response = response + "Connection: Upgrade\r\n" 59 response = response + "Upgrade: websocket\r\n\r\n" 60 log.Println("response:", response) 61 if lenth, err := conn.Write([]byte(response)); err != nil { 62 log.Println(err) 63 } else { 64 log.Println("send len:", lenth) 65 } 66 wssocket := NewWsSocket(conn) 67 for { 68 data, err := wssocket.ReadIframe() 69 if err != nil { 70 log.Println("readIframe err:", err) 71 } 72 log.Println("read data:", string(data)) 73 err = wssocket.SendIframe([]byte("good")) 74 if err != nil { 75 log.Println("sendIframe err:", err) 76 } 77 log.Println("send data") 78 } 79 } else { 80 log.Println(string(content)) 81 // 直接读取 82 } 83 } 84 
// WsSocket is a minimal server-side WebSocket (RFC 6455) endpoint wrapped
// around a connection that has already completed the opening handshake.
type WsSocket struct {
	// MaskingKey is the masking key of the most recent masked frame read by
	// ReadIframe. It is never applied when sending: RFC 6455 §5.1 forbids a
	// server from masking the frames it sends to a client.
	MaskingKey []byte

	// Conn is the connection frames are read from and written to.
	Conn net.Conn
}

// ErrWsClosed is returned by ReadIframe after the peer sends a close frame.
var ErrWsClosed = errors.New("websocket: close frame received")

// Frame opcodes handled by WsSocket (RFC 6455 §5.2).
const (
	wsOpText  byte = 0x1
	wsOpClose byte = 0x8
	wsOpPing  byte = 0x9
	wsOpPong  byte = 0xA
)

// wsMaxFramePayload caps the payload length ReadIframe accepts for a single
// frame so that a peer cannot force an arbitrarily large allocation.
const wsMaxFramePayload = 16 << 20

// NewWsSocket wraps conn, which must already have completed the WebSocket
// opening handshake.
func NewWsSocket(conn net.Conn) *WsSocket {
	return &WsSocket{Conn: conn}
}

// SendIframe sends data to the peer as one final, unmasked text frame.
// Payloads of any length are supported: lengths above 125 bytes use the
// 16-bit or 64-bit extended length encodings. It returns the error from
// writing to Conn, if any.
func (ws *WsSocket) SendIframe(data []byte) error {
	return ws.writeFrame(wsOpText, data)
}

// writeFrame writes payload as a single unmasked frame with FIN set. The
// header and payload are assembled first so the frame goes out in a single
// Write call.
func (ws *WsSocket) writeFrame(opcode byte, payload []byte) error {
	n := len(payload)
	frame := make([]byte, 0, 10+n)
	frame = append(frame, 0x80|opcode) // FIN | opcode
	switch {
	case n <= 125:
		frame = append(frame, byte(n))
	case n <= 0xFFFF:
		frame = append(frame, 126, byte(n>>8), byte(n))
	default:
		frame = append(frame, 127)
		for shift := 56; shift >= 0; shift -= 8 {
			frame = append(frame, byte(uint64(n)>>uint(shift)))
		}
	}
	frame = append(frame, payload...)
	_, err := ws.Conn.Write(frame)
	return err
}

// ReadIframe reads one complete message from the peer and returns its
// unmasked payload. A fragmented message (frames with FIN clear followed by
// continuation frames) is reassembled into a single slice.
//
// Control frames met along the way are handled in place:
//   - a ping is answered with a pong carrying the same payload;
//   - a pong is ignored;
//   - a close frame is answered with a close frame (best effort) and
//     ErrWsClosed is returned.
//
// Masked frames are unmasked, and their key is stored in ws.MaskingKey.
// Unmasked frames are accepted as-is.
//
// NOTE: the ordering of data and continuation frames is not validated.
func (ws *WsSocket) ReadIframe() ([]byte, error) {
	var message []byte
	for {
		fin, opcode, payload, err := ws.readFrame()
		if err != nil {
			return nil, err
		}
		switch opcode {
		case wsOpPing:
			if err := ws.writeFrame(wsOpPong, payload); err != nil {
				return nil, err
			}
			continue
		case wsOpPong:
			continue
		case wsOpClose:
			// RFC 6455 §5.5.1: reply with a close frame, echoing the status
			// code if one was sent. Errors are ignored because the peer may
			// already be gone and the caller is told to stop either way.
			status := payload
			if len(status) > 2 {
				status = status[:2]
			}
			_ = ws.writeFrame(wsOpClose, status)
			return nil, ErrWsClosed
		}
		if fin {
			if message == nil {
				return payload, nil // unfragmented message: no extra copy
			}
			return append(message, payload...), nil
		}
		message = append(message, payload...)
	}
}

// readFrame reads exactly one frame from ws.Conn. It returns the FIN flag,
// the opcode and the unmasked payload. Short reads are completed with
// io.ReadFull, so a returned error means the frame could not be read in
// full or violates the framing rules checked below.
func (ws *WsSocket) readFrame() (fin bool, opcode byte, payload []byte, err error) {
	var head [2]byte
	if _, err = io.ReadFull(ws.Conn, head[:]); err != nil {
		return false, 0, nil, err
	}

	// Byte 0: FIN | RSV1 | RSV2 | RSV3 | 4-bit opcode.
	if head[0]&0x70 != 0 {
		return false, 0, nil, errors.New("websocket: reserved bits set but no extension was negotiated")
	}
	fin = head[0]&0x80 != 0
	opcode = head[0] & 0x0F
	switch opcode {
	case 0x0, 0x1, 0x2, 0x8, 0x9, 0xA: // continuation, text, binary, close, ping, pong
	default:
		return false, 0, nil, errors.New("websocket: reserved opcode")
	}

	// Byte 1: MASK | 7-bit length. The values 126 and 127 announce a 16-bit
	// or 64-bit big-endian extended length that follows.
	masked := head[1]&0x80 != 0
	length := uint64(head[1] & 0x7F)
	if length >= 126 {
		extLen := 2
		if length == 127 {
			extLen = 8
		}
		ext := make([]byte, extLen)
		if _, err = io.ReadFull(ws.Conn, ext); err != nil {
			return false, 0, nil, err
		}
		length = 0
		for _, b := range ext {
			length = length<<8 | uint64(b)
		}
	}

	// Control frames (opcode >= 0x8) must be final and carry at most 125 bytes.
	if opcode&0x8 != 0 && (!fin || length > 125) {
		return false, 0, nil, errors.New("websocket: invalid control frame")
	}
	if length > wsMaxFramePayload {
		return false, 0, nil, errors.New("websocket: frame payload too large")
	}

	var key []byte
	if masked {
		key = make([]byte, 4)
		if _, err = io.ReadFull(ws.Conn, key); err != nil {
			return false, 0, nil, err
		}
		ws.MaskingKey = key
	}

	payload = make([]byte, length)
	if _, err = io.ReadFull(ws.Conn, payload); err != nil {
		return false, 0, nil, err
	}
	if masked {
		for i := range payload {
			payload[i] ^= key[i%4]
		}
	}
	return fin, opcode, payload, nil
}

// parseHandshake extracts the header fields of an HTTP request head into a
// map keyed by the field name exactly as sent. Names are not canonicalized,
// so lookups are case-sensitive. Values keep any embedded colons, for
// example "Host: localhost:8000".
//
// Parsing rules:
//   - leading CR/LF characters are skipped;
//   - parsing stops at the first empty line, which ends the header section;
//   - lines without a colon are skipped;
//   - lines whose name part contains whitespace (such as the request line)
//     are skipped;
//   - when a field repeats, the last value wins.
func parseHandshake(content string) map[string]string {
	headers := make(map[string]string, 10)
	content = strings.TrimLeft(content, "\r\n")
	for _, line := range strings.Split(content, "\r\n") {
		if line == "" {
			break // blank line: end of the header section
		}
		colon := strings.Index(line, ":")
		if colon < 0 {
			continue
		}
		name := strings.TrimSpace(line[:colon])
		// Field names are tokens and never contain whitespace. This check
		// also skips a request line whose target contains a colon.
		if name == "" || strings.ContainsAny(name, " \t") {
			continue
		}
		headers[name] = strings.TrimSpace(line[colon+1:])
	}
	return headers
}