var events = require("events");var http = require("http");var crypto = require("crypto");var util = require("util");// 操作码var opcodes = { TEXT: 1, BINARY: 2, CLOSE: 8, PING: 9, PONG: 10};var WebSocketConnection = function(req, socket, upgradeHead) { var self = this; var key = hashWebSocketKey(req.headers[‘sec-websocket-key‘]); // 建立连接 socket.write(‘HTTP/1.1 101 Web Socket Protocol Handshake\r\n‘ + ‘Upgrade: WebSocket\r\n‘ + ‘Connection: Upgrade\r\n‘ + ‘sec-websocket-accept: ‘ + key + ‘\r\n‘ + ‘\r\n‘ ); socket.on(‘data‘, function (buf) { self.buffer = Buffer.concat([self.buffer, buf]); while (self._processBuffer()) { /// process buffer while it contains complete frames } }); socket.on(‘close‘, function (buf) { if (!self.closed) { // 自定义错误 self.emit(‘close‘, 1006); self.closed = true; } }); // initialize connection state this.socket = socket; this.buffer = new Buffer(0); this.closed = false;}// 继承events.EventEmitterutil.inherits(WebSocketConnection, events.EventEmitter);WebSocketConnection.prototype.send = function (obj) { var opcode; var payload; if (Buffer.isBuffer(obj)) { opcode = opcodes.BINARY; payload = obj; } else if (typeof obj == ‘string‘) { opcode = opcodes.TEXT; payload = new Buffer(obj, ‘utf8‘); } else { throw new Error(‘Cannot send object.Must be string or Buffer.‘); } this._doSend(opcode, payload);}WebSocketConnection.prototype.close = function (code, reason) { var opcode = opcodes.CLOSE; var buffer; if (code) { buffer = new Buffer(Buffer.byteLength(reason) + 2); buffer.writeUInt16BE(code, 0); buffer.write(reason, 2); } else { buffer = new Buffer(0); } this._doSend(opcode, buffer); this.closed = true;}WebSocketConnection.prototype._processBuffer = function () { var buf = this.buffer; if (buf.length < 2) return; var idx = 2; var b1 = buf.readUInt8(0); var fin = b1 & 0x80; // fin var opcode = b1 & 0x0f; // 操作码 var b2 = buf.readUInt8(1); var mask = b2 & 0x80; // 掩码 var length = b2 & 0x7f; // 长度 if (length > 125) { if (buf.length < 8) return; if (length == 
126) { length = buf.readUInt16BE(2); idx += 2; } else if (length == 127) { var highBits = buf.readUInt32BE(2); if (highBits != 0) { // 1009代表消息过大 this.close(1009, ""); } length = buf.readUInt32BE(6); idx += 8; } } if (buf.length < idx + 4 + length) { return; } maskBytes = buf.slice(idx, idx + 4); idx += 4; var payload = buf.slice(idx, idx + length); payload = unmask(maskBytes, payload); this._handleFrame(opcode, payload); this.buffer = buf.slice(idx + length); // 数据清空 return true;}// 处理WebSocket帧WebSocketConnection.prototype._handleFrame = function (opcode, buffer) { var payload; switch (opcode) { case opcodes.TEXT: payload = buffer.toString(‘utf8‘); // 发送接收数据事件 this.emit(‘data‘, opcode, payload); break; case opcodes.BINARY: payload = buffer; // 发送接收数据事件 this.emit(‘data‘, opcode, payload); break; case opcodes.PING: this._doSend(opcodes.PONG, buffer); break; case opcodes.PONG: // break; case opcodes.CLOSE: var code, reason; if (buffer.length >= 2) { code = buffer.readUInt16BE(0); reason = buffer.toString(‘utf8‘, 2); } this.close(code, reason); // 发送close事件 this.emit(‘close‘, code, reason); break; default: // 1002代表协议错误 this.close(1002, ‘unknown opcode‘); }}WebSocketConnection.prototype._doSend = function (opcode, payload) { // 基于TCP发送数据 this.socket.write(encodeMessage(opcode, payload));}// 计算sec-websocket-accept值var hashWebSocketKey = function (key) { var sha1 = crypto.createHash(‘sha1‘); sha1.update(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", ‘ascii‘); return sha1.digest(‘base64‘);}// 数据掩码解析var unmask = function (maskBytes, data) { var payload = new Buffer(data.length); for (var i = 0; i < data.length; i++) { payload[i] = maskBytes[i % 4] ^ data[i]; } return payload;}// 发送数据封装var encodeMessage = function (opcode, payload) { var buf; var b1 = 0x80 | opcode; // fin置为1 var b2 = 0; // 没有掩码 var length = payload.length; if (length < 126) { buf = new Buffer(payload.length + 2 + 0); b2 |= length; buf.writeUInt8(b1, 0); buf.writeUInt8(b2, 1); payload.copy(buf, 2); } else 
if (length < (1 << 16)) { // buf = new Buffer(payload.length + 2 + 2); b2 |= 126; buf.writeUInt8(b1, 0); buf.writeUInt8(b2, 1); buf.writeUInt16BE(length, 2); payload.copy(buf, 4); } else { buf = new Buffer(payload + 2 + 8); b2 |= 127; buf.writeUInt8(b1, 0); buf.writeUInt8(b2, 1); buf.writeUInt32BE(0, 2); // 必需为0 buf.writeUInt32BE(length, 6); payload.copy(buf, 10); } return buf;}exports.listen = function (port, host, connectionHandler) { var srv = http.createServer(function (req, res) {}); // 监听upgrade事件并生成WebScoket连接 srv.on(‘upgrade‘, function (req, socket, upgradeHead) { var ws = new WebSocketConnection(req, socket, upgradeHead); connectionHandler(ws); }); srv.listen(port, host);}
var websocket = require('./websocket-example');

// Echo server: log each message and send it back unchanged.
websocket.listen(9999, 'localhost', function (conn) {
  console.log('connection opened');
  conn.on('data', function (opcode, data) {
    console.log('message:', data);
    conn.send(data);
  });
  conn.on('close', function (code, reason) {
    console.log('connection closed:', code, reason);
  });
});
注：不得不佩服 Node.js 代码的简洁与易读性。之前在项目中开发过 Node.js 的支付 SDK，那时就发现了 Node.js 的魅力，希望你也能喜欢上它。
public class WebSocketServer {

    /**
     * Starts the Netty WebSocket server on localhost:9999 and blocks until the
     * server channel is closed; both event loop groups are shut down on exit.
     */
    public static void main(String[] args) throws InterruptedException {
        // One group accepts incoming connections, the other serves their I/O.
        EventLoopGroup acceptors = new NioEventLoopGroup();
        EventLoopGroup workers = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap()
                    .group(acceptors, workers)
                    .channel(NioServerSocketChannel.class)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    .childHandler(new WebSocketInitalizer());
            ChannelFuture bound = bootstrap.bind("localhost", 9999).sync();
            // Park the main thread until the listening channel closes.
            bound.channel().closeFuture().sync();
        } finally {
            acceptors.shutdownGracefully();
            workers.shutdownGracefully();
        }
    }
}
public class WebSocketInitalizer extends ChannelInitializer<SocketChannel> { @Override protected void initChannel(SocketChannel ch) throws Exception { ChannelPipeline pipeline = ch.pipeline(); pipeline.addLast(new HttpServerCodec()); pipeline.addLast(new HttpObjectAggregator(8192)); pipeline.addLast(new ChunkedWriteHandler()); // 这个是最重要的Handler,后面稍微跟踪一下源码 pipeline.addLast(new WebSocketServerProtocolHandler("/ws")); pipeline.addLast(new MyWebSocketHandler()); }}
public class MyWebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> { @Override protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception { Channel channel = ctx.channel(); System.out.println(channel.remoteAddress() + ": " + msg.text()); ctx.channel().writeAndFlush(msg); } @Override public void handlerAdded(ChannelHandlerContext ctx) throws Exception { System.out.println("login: " + ctx.channel().id().asLongText()); } @Override public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { System.out.println("logout: " + ctx.channel().id().asLongText()); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { ctx.channel().close(); }}
protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); // http1.1 101 Switching Protocols if (headers != null) { res.headers().add(headers); } CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY); if (key == null) { throw new WebSocketHandshakeException("not a WebSocket request: missing key"); } // WEBSOCKET_13_ACCEPT_GUID为"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID; byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); String accept = WebSocketUtil.base64(sha1); res.headers().add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET); res.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE); res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept); String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL); if (subprotocols != null) { // 服务器挑选子协议 String selectedSubprotocol = selectSubprotocol(subprotocols); if (selectedSubprotocol == null) { } else { res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selectedSubprotocol); } } return res;}
// Excerpt of Netty's WebSocket frame decoder; some details of the original code were removed.
//
// Incremental frame parser driven by `state`. Each case deliberately falls
// through to the next one after its part of the frame has been read. The
// READING_SIZE, MASKING_KEY and PAYLOAD cases return early, without advancing
// `state`, when too few bytes are readable, and resume there on the next call.
// protocolViolation(...) is defined elsewhere; presumably it fails the
// connection and moves `state` to CORRUPT — confirm against the full source.
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out){
    switch (state) {
        case READING_FIRST:
            // NOTE(review): no readability check before readByte() here or in
            // READING_SECOND; the upstream source presumably has one that was
            // dropped from this excerpt — confirm.
            framePayloadLength = 0;
            byte b = in.readByte();
            frameFinalFlag = (b & 0x80) != 0; // FIN
            frameRsv = (b & 0x70) >> 4; // RSV
            frameOpcode = b & 0x0F; // OPCODE
            state = State.READING_SECOND;
        case READING_SECOND:
            b = in.readByte();
            frameMasked = (b & 0x80) != 0; // MASK
            framePayloadLen1 = b & 0x7F; // 7-bit LEN (126/127 select an extended length)
            if (frameRsv != 0 && !allowExtensions) {
                // RSV must be 0 when no extension was negotiated.
                protocolViolation(ctx, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
                return;
            }
            if (!allowMaskMismatch && expectMaskedFrames != frameMasked) {
                protocolViolation(ctx, "received a frame that is not masked as expected");
                return;
            }
            if (frameOpcode > 7) { // control frame (close 8, ping 9, pong 10)
                if (!frameFinalFlag) {
                    // Control frames must not be fragmented: FIN must be set.
                    protocolViolation(ctx, "fragmented control frame");
                    return;
                }
                if (framePayloadLen1 > 125) {
                    // Control frame payloads may not exceed 125 bytes.
                    protocolViolation(ctx, "control frame with payload length > 125 octets");
                    return;
                }
                // Any control opcode other than close/ping/pong is reserved.
                if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING || frameOpcode == OPCODE_PONG)) {
                    protocolViolation(ctx, "control frame using reserved opcode " + frameOpcode);
                    return;
                }
                // A close payload is empty or at least a 2-byte status code; 1 byte is invalid.
                if (frameOpcode == 8 && framePayloadLen1 == 1) {
                    protocolViolation(ctx, "received close control frame with payload len 1");
                    return;
                }
            } else {
                // Data frame: 0 continuation, 1 text, 2 binary; other values are reserved.
                if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT || frameOpcode == OPCODE_BINARY)) {
                    protocolViolation(ctx, "data frame using reserved opcode " + frameOpcode);
                    return;
                }
                // check opcode vs message fragmentation state 1/2
                if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
                    protocolViolation(ctx, "received continuation data frame outside fragmented message");
                    return;
                }
                // check opcode vs message fragmentation state 2/2
                // (the OPCODE_PING test is redundant here: this branch only sees opcodes <= 7)
                if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT && frameOpcode != OPCODE_PING) {
                    protocolViolation(ctx, "received non-continuation data frame while inside fragmented message");
                    return;
                }
            }
            state = State.READING_SIZE;
        case READING_SIZE:
            // Read frame payload length
            if (framePayloadLen1 == 126) {
                // 16-bit extended length; must not encode a value that fits in 7 bits.
                if (in.readableBytes() < 2) {
                    return;
                }
                framePayloadLength = in.readUnsignedShort();
                if (framePayloadLength < 126) {
                    protocolViolation(ctx, "invalid data frame length (not using minimal length encoding)");
                    return;
                }
            } else if (framePayloadLen1 == 127) {
                // 64-bit extended length; must not encode a value that fits in 16 bits.
                if (in.readableBytes() < 8) {
                    return;
                }
                framePayloadLength = in.readLong();
                if (framePayloadLength < 65536) {
                    protocolViolation(ctx, "invalid data frame length (not using minimal length encoding)");
                    return;
                }
            } else {
                framePayloadLength = framePayloadLen1;
            }
            // Reject payloads above the configured maximum.
            if (framePayloadLength > maxFramePayloadLength) {
                protocolViolation(ctx, "Max frame length of " + maxFramePayloadLength + " has been exceeded.");
                return;
            }
            state = State.MASKING_KEY;
        case MASKING_KEY:
            if (frameMasked) {
                if (in.readableBytes() < 4) {
                    return;
                }
                // The 4-byte key array is allocated once and reused for later frames.
                if (maskingKey == null) {
                    maskingKey = new byte[4];
                }
                in.readBytes(maskingKey);
            }
            state = State.PAYLOAD;
        case PAYLOAD:
            if (in.readableBytes() < framePayloadLength) {
                return;
            }
            ByteBuf payloadBuffer = null;
            try {
                // toFrameLength(...) is defined elsewhere; presumably it narrows the
                // long length to an int — confirm.
                payloadBuffer = readBytes(ctx.alloc(), in, toFrameLength(framePayloadLength));
                state = State.READING_FIRST;
                if (frameMasked) {
                    unmask(payloadBuffer); // remove the client's masking in place
                }
                // Each branch below hands payloadBuffer to a frame object and nulls
                // the local so the finally block does not release it.
                if (frameOpcode == OPCODE_PING) {
                    out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
                    payloadBuffer = null;
                    return;
                }
                if (frameOpcode == OPCODE_PONG) {
                    out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
                    payloadBuffer = null;
                    return;
                }
                if (frameOpcode == OPCODE_CLOSE) {
                    receivedClosingHandshake = true;
                    checkCloseFrameBody(ctx, payloadBuffer);
                    out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
                    payloadBuffer = null;
                    return;
                }
                // Track fragmentation: a final frame ends the message, a non-final
                // one extends it. (PING already returned above, so the inner check
                // is always true here.)
                if (frameFinalFlag) {
                    if (frameOpcode != OPCODE_PING) {
                        fragmentedFramesCount = 0;
                    }
                } else {
                    fragmentedFramesCount++;
                }
                if (frameOpcode == OPCODE_TEXT) {
                    out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
                    payloadBuffer = null;
                    return;
                } else if (frameOpcode == OPCODE_BINARY) {
                    out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
                    payloadBuffer = null;
                    return;
                } else if (frameOpcode == OPCODE_CONT) {
                    out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer));
                    payloadBuffer = null;
                    return;
                } else {
                    throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: " + frameOpcode);
                }
            } finally {
                // Release the buffer only if no frame object took ownership of it.
                if (payloadBuffer != null) {
                    payloadBuffer.release();
                }
            }
        case CORRUPT:
            // Corrupt stream: discard one readable byte per call.
            if (in.isReadable()) {
                in.readByte();
            }
            return;
        default:
            throw new Error("Shouldn‘t reach here.");
    }
}

// ByteBuf is similar to NIO's ByteBuffer but keeps separate reader and writer indices:
// +-------------------+------------------+------------------+
// | discardable bytes |  readable bytes  |  writable bytes  |
// +-------------------+------------------+------------------+
// |                   |                  |                  |
// 0      <=      readerIndex   <=   writerIndex    <=    capacity
//
// Unmasks the readable region of `frame` in place. It XORs four bytes at a time
// with the 32-bit mask, then handles the remaining 0-3 bytes one at a time.
// NOTE(review): the tail uses maskingKey[i % 4] with absolute indices, which
// is only correct when readerIndex is a multiple of 4. The buffer returned by
// readBytes(...) above presumably starts at 0 — confirm.
private void unmask(ByteBuf frame) {
    int i = frame.readerIndex();
    int end = frame.writerIndex();
    ByteOrder order = frame.order();
    // Build the mask big-endian; "& 0xFF" prevents sign extension of each byte.
    int intMask = ((maskingKey[0] & 0xFF) << 24)
            | ((maskingKey[1] & 0xFF) << 16)
            | ((maskingKey[2] & 0xFF) << 8)
            | (maskingKey[3] & 0xFF);
    // Match the buffer's byte order so getInt/setInt line up with the mask bytes.
    if (order == ByteOrder.LITTLE_ENDIAN) {
        intMask = Integer.reverseBytes(intMask);
    }
    for (; i + 3 < end; i += 4) {
        frame.setInt(i, frame.getInt(i) ^ intMask);
    }
    for (; i < end; i++) {
        frame.setByte(i, frame.getByte(i) ^ maskingKey[i % 4]);
    }
}
/**
 * Encodes one WebSocketFrame into its wire format and adds the resulting
 * buffer(s) to {@code out}.
 *
 * Byte 0 holds FIN, RSV and the opcode; byte 1 holds the mask bit and a 7-bit
 * length, followed by a 16-bit or 64-bit extended length when needed. When
 * {@code maskPayload} is set, a 4-byte masking key is written and the payload
 * is XOR-masked into the same buffer. Otherwise a large payload
 * ({@code > GATHERING_WRITE_THRESHOLD}) is added as a second, retained buffer
 * instead of being copied.
 *
 * @throws UnsupportedOperationException for unknown frame types
 * @throws TooLongFrameException for PING payloads over 125 bytes
 */
protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg, List<Object> out) {
    final ByteBuf data = msg.content();
    byte[] mask;
    byte opcode;
    // Map the frame class to its opcode.
    if (msg instanceof TextWebSocketFrame) {
        opcode = OPCODE_TEXT;
    } else if (msg instanceof PingWebSocketFrame) {
        opcode = OPCODE_PING;
    } else if (msg instanceof PongWebSocketFrame) {
        opcode = OPCODE_PONG;
    } else if (msg instanceof CloseWebSocketFrame) {
        opcode = OPCODE_CLOSE;
    } else if (msg instanceof BinaryWebSocketFrame) {
        opcode = OPCODE_BINARY;
    } else if (msg instanceof ContinuationWebSocketFrame) {
        opcode = OPCODE_CONT;
    } else {
        throw new UnsupportedOperationException("Cannot encode frame of type: " + msg.getClass().getName());
    }
    int length = data.readableBytes();
    // Byte 0: FIN (bit 7), RSV (bits 6-4), opcode (bits 3-0).
    int b0 = 0;
    if (msg.isFinalFragment()) {
        b0 |= 1 << 7;
    }
    b0 |= msg.rsv() % 8 << 4;
    b0 |= opcode % 128;
    // NOTE(review): only PING is length-checked here; PONG/CLOSE payloads over
    // 125 bytes are not rejected by this method.
    if (opcode == OPCODE_PING && length > 125) {
        throw new TooLongFrameException("invalid payload for PING (payload length must be <= 125, was " + length);
    }
    boolean release = true;
    ByteBuf buf = null;
    try {
        int maskLength = maskPayload ? 4 : 0;
        // Header sizes: 2 bytes (7-bit length), 4 bytes (16-bit), 10 bytes (64-bit),
        // plus the masking key. The payload is counted in the allocation only when
        // it will be copied into this buffer.
        if (length <= 125) {
            int size = 2 + maskLength;
            if (maskPayload || length <= GATHERING_WRITE_THRESHOLD) {
                size += length;
            }
            buf = ctx.alloc().buffer(size);
            buf.writeByte(b0);
            byte b = (byte) (maskPayload ? 0x80 | (byte) length : (byte) length);
            buf.writeByte(b);
        } else if (length <= 0xFFFF) {
            int size = 4 + maskLength;
            if (maskPayload || length <= GATHERING_WRITE_THRESHOLD) {
                size += length;
            }
            buf = ctx.alloc().buffer(size);
            buf.writeByte(b0);
            buf.writeByte(maskPayload ? 0xFE : 126); // 0xFE = mask bit | 126
            buf.writeByte(length >>> 8 & 0xFF);
            buf.writeByte(length & 0xFF);
        } else {
            int size = 10 + maskLength;
            if (maskPayload || length <= GATHERING_WRITE_THRESHOLD) {
                size += length;
            }
            buf = ctx.alloc().buffer(size);
            buf.writeByte(b0);
            buf.writeByte(maskPayload ? 0xFF : 127); // 0xFF = mask bit | 127
            buf.writeLong(length);
        }
        // Write payload
        if (maskPayload) {
            // NOTE(review): RFC 6455 5.3 asks for a masking key from a strong entropy
            // source. Math.random() is not one, and the value is always below
            // Integer.MAX_VALUE, so the top bit of the key is always 0.
            int random = (int) (Math.random() * Integer.MAX_VALUE);
            mask = ByteBuffer.allocate(4).putInt(random).array();
            buf.writeBytes(mask);
            ByteOrder srcOrder = data.order();
            ByteOrder dstOrder = buf.order();
            int counter = 0;
            int i = data.readerIndex();
            int end = data.writerIndex();
            // Fast path: mask four bytes at a time when both buffers share a byte order.
            if (srcOrder == dstOrder) {
                int intMask = ((mask[0] & 0xFF) << 24)
                        | ((mask[1] & 0xFF) << 16)
                        | ((mask[2] & 0xFF) << 8)
                        | (mask[3] & 0xFF);
                if (srcOrder == ByteOrder.LITTLE_ENDIAN) {
                    intMask = Integer.reverseBytes(intMask);
                }
                for (; i + 3 < end; i += 4) {
                    int intData = data.getInt(i);
                    buf.writeInt(intData ^ intMask);
                }
            }
            // Remaining bytes (or every byte if the orders differ). counter starts at 0
            // because the fast path consumed a multiple of 4 bytes.
            for (; i < end; i++) {
                byte byteData = data.getByte(i);
                buf.writeByte(byteData ^ mask[counter++ % 4]);
            }
            out.add(buf);
        } else {
            if (buf.writableBytes() >= data.readableBytes()) {
                // Small payload: copy it after the header into a single buffer.
                buf.writeBytes(data);
                out.add(buf);
            } else {
                // Large payload: emit header and payload as two buffers; retain()
                // keeps data alive past the release of msg by the caller.
                out.add(buf);
                out.add(data.retain());
            }
        }
        release = false;
    } finally {
        // Free the header buffer only if encoding failed before it was handed to out.
        if (release && buf != null) {
            buf.release();
        }
    }
}
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <meta http-equiv="X-UA-Compatible" content="ie=edge">
  <title>WebSocket</title>
</head>
<body>
  <div id="output"></div>
  <script>
    // Minimal echo client: connect, send "hello", log the reply, then close.
    var output;
    var ws;

    function setup() {
      output = document.getElementById('output');
      // The Netty server only upgrades requests to "/ws"; the Node.js server
      // accepts any path, so "/ws" works with both.
      ws = new WebSocket('ws://localhost:9999/ws');
      ws.onopen = function (e) {
        log('Connected');
        sendMessage('hello');
      };
      ws.onclose = function (e) {
        log('Disconnected:' + e.reason);
      };
      ws.onerror = function (e) {
        log('Error ');
      };
      ws.onmessage = function (e) {
        log('Message received:' + e.data);
        ws.close();
      };
    }

    function sendMessage(msg) {
      ws.send(msg);
      log('Message sent');
    }

    // Appends a line to the page (as text, not HTML) and mirrors it to the console.
    function log(s) {
      var p = document.createElement('p');
      p.style.wordWrap = 'break-word';
      p.textContent = s;
      output.appendChild(p);
      console.log(s);
    }

    setup();
  </script>
</body>
</html>
参考: