package fi.iki.elonen;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CharsetEncoder;
import java.util.Arrays;
import java.util.List;
/**
 * A single WebSocket frame (RFC 6455, section 5.2): opcode, FIN flag, optional 4-byte masking key
 * and the unmasked payload bytes.
 *
 * <p>Frames are parsed with {@link #read(InputStream)} (close frames come back as
 * {@link CloseFrame}) and serialized with {@link #write(OutputStream)}. Instances are mutable
 * through their setters and carry no synchronization.
 */
public class WebSocketFrame {
    private OpCode opCode;
    private boolean fin;
    private byte[] maskingKey;
    private byte[] payload;
    // Payload length in bytes; refreshed from the payload array whenever it is set or written.
    private transient int _payloadLength;
    // Lazily decoded text form of the payload; cleared whenever the binary payload changes.
    private transient String _payloadString;

    private WebSocketFrame(OpCode opCode, boolean fin) {
        setOpCode(opCode);
        setFin(fin);
    }

    /**
     * Creates a frame with a binary payload.
     *
     * @param opCode     frame opcode
     * @param fin        whether this is the final fragment of a message
     * @param payload    payload bytes (stored by reference, not copied)
     * @param maskingKey 4-byte masking key, or {@code null} for an unmasked frame
     * @throws IllegalArgumentException if {@code maskingKey} is non-null and not 4 bytes long
     */
    public WebSocketFrame(OpCode opCode, boolean fin, byte[] payload, byte[] maskingKey) {
        this(opCode, fin);
        setMaskingKey(maskingKey);
        setBinaryPayload(payload);
    }

    /** Creates an unmasked frame with a binary payload. */
    public WebSocketFrame(OpCode opCode, boolean fin, byte[] payload) {
        this(opCode, fin, payload, null);
    }

    /**
     * Creates a frame whose payload is the UTF-8 encoding of {@code payload}.
     *
     * @throws CharacterCodingException if {@code payload} cannot be encoded as UTF-8
     */
    public WebSocketFrame(OpCode opCode, boolean fin, String payload, byte[] maskingKey) throws CharacterCodingException {
        this(opCode, fin);
        setMaskingKey(maskingKey);
        setTextPayload(payload);
    }

    /** Creates an unmasked frame whose payload is the UTF-8 encoding of {@code payload}. */
    public WebSocketFrame(OpCode opCode, boolean fin, String payload) throws CharacterCodingException {
        this(opCode, fin, payload, null);
    }

    /**
     * Copy constructor. The payload and masking-key arrays are shared with {@code clone}, not copied.
     */
    public WebSocketFrame(WebSocketFrame clone) {
        setOpCode(clone.getOpCode());
        setFin(clone.isFin());
        setBinaryPayload(clone.getBinaryPayload());
        setMaskingKey(clone.getMaskingKey());
    }

    /**
     * Assembles a final (FIN) frame whose payload is the concatenation, in list order, of the
     * payloads of {@code fragments}.
     *
     * @throws WebSocketException with {@link CloseCode#MessageTooBig} if the combined payload does
     *                            not fit into a Java array
     */
    public WebSocketFrame(OpCode opCode, List<WebSocketFrame> fragments) throws WebSocketException {
        setOpCode(opCode);
        setFin(true);
        long totalLength = 0;
        for (WebSocketFrame fragment : fragments) {
            totalLength += fragment.getBinaryPayload().length;
        }
        if (totalLength < 0 || totalLength > Integer.MAX_VALUE) {
            throw new WebSocketException(CloseCode.MessageTooBig, "Max frame length has been exceeded.");
        }
        byte[] assembled = new byte[(int) totalLength];
        int offset = 0;
        for (WebSocketFrame fragment : fragments) {
            byte[] part = fragment.getBinaryPayload();
            System.arraycopy(part, 0, assembled, offset, part.length);
            offset += part.length;
        }
        setBinaryPayload(assembled);
    }

    // --------------------------------GETTERS---------------------------------

    public OpCode getOpCode() {
        return opCode;
    }

    public void setOpCode(OpCode opcode) {
        this.opCode = opcode;
    }

    public boolean isFin() {
        return fin;
    }

    public void setFin(boolean fin) {
        this.fin = fin;
    }

    /** Returns {@code true} iff a 4-byte masking key is set. */
    public boolean isMasked() {
        return maskingKey != null && maskingKey.length == 4;
    }

    public byte[] getMaskingKey() {
        return maskingKey;
    }

    /**
     * Sets the masking key.
     *
     * @param maskingKey a 4-byte key, or {@code null} to unmask the frame
     * @throws IllegalArgumentException if the key is non-null and not exactly 4 bytes long
     */
    public void setMaskingKey(byte[] maskingKey) {
        if (maskingKey != null && maskingKey.length != 4) {
            throw new IllegalArgumentException("MaskingKey " + Arrays.toString(maskingKey) + " hasn't length 4");
        }
        this.maskingKey = maskingKey;
    }

    /** Removes the masking key. */
    public void setUnmasked() {
        setMaskingKey(null);
    }

    /** Returns the (unmasked) payload bytes; the internal array is returned, not a copy. */
    public byte[] getBinaryPayload() {
        return payload;
    }

    /** Replaces the payload (stored by reference) and invalidates the cached text form. */
    public void setBinaryPayload(byte[] payload) {
        this.payload = payload;
        this._payloadLength = payload.length;
        this._payloadString = null;
    }

    /**
     * Returns the payload decoded as UTF-8, caching the result.
     *
     * @throws RuntimeException wrapping a {@link CharacterCodingException} if the payload is not
     *                          valid UTF-8
     */
    public String getTextPayload() {
        if (_payloadString == null) {
            try {
                _payloadString = binary2Text(getBinaryPayload());
            } catch (CharacterCodingException e) {
                throw new RuntimeException("Undetected CharacterCodingException", e);
            }
        }
        return _payloadString;
    }

    /**
     * Sets the payload to the UTF-8 encoding of {@code payload}.
     *
     * @throws CharacterCodingException if {@code payload} cannot be encoded as UTF-8
     */
    public void setTextPayload(String payload) throws CharacterCodingException {
        this.payload = text2Binary(payload);
        // The frame length is measured in encoded bytes, not in Java chars.
        this._payloadLength = this.payload.length;
        this._payloadString = payload;
    }

    // --------------------------------SERIALIZATION---------------------------

    /**
     * Reads one complete frame from {@code in} and unmasks its payload.
     *
     * @return the frame; close frames are returned as {@link CloseFrame}
     * @throws EOFException             if the stream ends before the frame is complete
     * @throws WebSocketException       on protocol violations (reserved bits set, unknown opcode,
     *                                  fragmented or oversized control frames, non-minimal length
     *                                  encoding, payload too large)
     * @throws CharacterCodingException if a Text frame or close reason is not valid UTF-8
     */
    public static WebSocketFrame read(InputStream in) throws IOException {
        byte head = (byte) checkedRead(in.read());
        boolean fin = ((head & 0x80) != 0);
        OpCode opCode = OpCode.find((byte) (head & 0x0F));
        if ((head & 0x70) != 0) {
            throw new WebSocketException(CloseCode.ProtocolError, "The reserved bits (" + Integer.toBinaryString(head & 0x70) + ") must be 0.");
        }
        if (opCode == null) {
            throw new WebSocketException(CloseCode.ProtocolError, "Received frame with reserved/unknown opcode " + (head & 0x0F) + ".");
        } else if (opCode.isControlFrame() && !fin) {
            throw new WebSocketException(CloseCode.ProtocolError, "Fragmented control frame.");
        }
        WebSocketFrame frame = new WebSocketFrame(opCode, fin);
        frame.readPayloadInfo(in);
        frame.readPayload(in);
        if (frame.getOpCode() == OpCode.Close) {
            return new CloseFrame(frame);
        } else {
            return frame;
        }
    }

    /** Passes through a value returned by {@code InputStream.read}, turning end-of-stream (-1) into EOFException. */
    private static int checkedRead(int read) throws IOException {
        if (read < 0) {
            throw new EOFException();
        }
        return read;
    }

    /** Reads the mask bit, the (7, 7+16 or 7+64 bit) payload length and the optional masking key. */
    private void readPayloadInfo(InputStream in) throws IOException {
        byte b = (byte) checkedRead(in.read());
        boolean masked = ((b & 0x80) != 0);
        _payloadLength = (byte) (0x7F & b);
        if (_payloadLength == 126) {
            // checkedRead must return int for this to work
            _payloadLength = (checkedRead(in.read()) << 8 | checkedRead(in.read())) & 0xFFFF;
            if (_payloadLength < 126) {
                throw new WebSocketException(CloseCode.ProtocolError, "Invalid data frame 2byte length. (not using minimal length encoding)");
            }
        } else if (_payloadLength == 127) {
            long length = 0;
            for (int i = 0; i < 8; i++) {
                // Accumulate in a long so that bytes >= 0x80 are never sign-extended.
                length = (length << 8) | checkedRead(in.read());
            }
            // A negative value means the most significant bit was set, which RFC 6455 forbids;
            // it is rejected together with non-minimal encodings.
            if (length < 65536) {
                throw new WebSocketException(CloseCode.ProtocolError, "Invalid data frame 8byte length. (not using minimal length encoding)");
            }
            if (length > Integer.MAX_VALUE) {
                throw new WebSocketException(CloseCode.MessageTooBig, "Max frame length has been exceeded.");
            }
            this._payloadLength = (int) length;
        }
        if (opCode.isControlFrame()) {
            if (_payloadLength > 125) {
                throw new WebSocketException(CloseCode.ProtocolError, "Control frame with payload length > 125 bytes.");
            }
            if (opCode == OpCode.Close && _payloadLength == 1) {
                throw new WebSocketException(CloseCode.ProtocolError, "Received close frame with payload len 1.");
            }
        }
        if (masked) {
            maskingKey = new byte[4];
            int read = 0;
            while (read < maskingKey.length) {
                read += checkedRead(in.read(maskingKey, read, maskingKey.length - read));
            }
        }
    }

    /** Reads {@code _payloadLength} payload bytes, unmasks them, and validates UTF-8 for Text frames. */
    private void readPayload(InputStream in) throws IOException {
        // NOTE(review): the length comes from the peer and is only capped at Integer.MAX_VALUE;
        // consider an application-level limit before allocating.
        payload = new byte[_payloadLength];
        int read = 0;
        while (read < _payloadLength) {
            read += checkedRead(in.read(payload, read, _payloadLength - read));
        }
        if (isMasked()) {
            for (int i = 0; i < payload.length; i++) {
                payload[i] ^= maskingKey[i % 4];
            }
        }
        // Test for Unicode errors.
        // NOTE(review): for a non-final Text fragment a multi-byte UTF-8 sequence may be split
        // across frames, which would make this decode fail; confirm the intended validation.
        if (getOpCode() == OpCode.Text) {
            _payloadString = binary2Text(getBinaryPayload());
        }
    }

    /**
     * Serializes this frame to {@code out}, masking the payload if a masking key is set, and
     * flushes the stream.
     */
    public void write(OutputStream out) throws IOException {
        byte header = 0;
        if (fin) {
            header |= 0x80;
        }
        header |= opCode.getValue() & 0x0F;
        out.write(header);

        byte[] data = getBinaryPayload();
        _payloadLength = data.length;
        int maskBit = isMasked() ? 0x80 : 0;
        if (_payloadLength <= 125) {
            out.write(maskBit | _payloadLength);
        } else if (_payloadLength <= 0xFFFF) {
            out.write(maskBit | 126);
            out.write(_payloadLength >>> 8);
            out.write(_payloadLength);
        } else {
            out.write(maskBit | 127);
            // An int holds at most 31 bits, so the upper four bytes of the 64-bit length are zero.
            out.write(0);
            out.write(0);
            out.write(0);
            out.write(0);
            out.write(_payloadLength >>> 24);
            out.write(_payloadLength >>> 16);
            out.write(_payloadLength >>> 8);
            out.write(_payloadLength);
        }

        if (isMasked()) {
            out.write(maskingKey);
            // Mask into a buffer and write once instead of issuing one write call per byte.
            byte[] masked = new byte[_payloadLength];
            for (int i = 0; i < _payloadLength; i++) {
                masked[i] = (byte) (data[i] ^ maskingKey[i % 4]);
            }
            out.write(masked);
        } else {
            out.write(data);
        }
        out.flush();
    }

    // --------------------------------ENCODING--------------------------------

    public static final Charset TEXT_CHARSET = Charset.forName("UTF-8");
    // Retained for API compatibility. The helpers below create a fresh coder per call because
    // CharsetDecoder/CharsetEncoder instances are stateful and must not be shared across threads.
    public static final CharsetDecoder TEXT_DECODER = TEXT_CHARSET.newDecoder();
    public static final CharsetEncoder TEXT_ENCODER = TEXT_CHARSET.newEncoder();

    /**
     * Decodes the whole array as UTF-8.
     *
     * @throws CharacterCodingException if the bytes are not valid UTF-8
     */
    public static String binary2Text(byte[] payload) throws CharacterCodingException {
        return binary2Text(payload, 0, payload.length);
    }

    /**
     * Decodes {@code length} bytes starting at {@code offset} as UTF-8.
     *
     * @throws CharacterCodingException if the bytes are not valid UTF-8
     */
    public static String binary2Text(byte[] payload, int offset, int length) throws CharacterCodingException {
        return TEXT_CHARSET.newDecoder().decode(ByteBuffer.wrap(payload, offset, length)).toString();
    }

    /**
     * Encodes {@code payload} as UTF-8 and returns exactly the encoded bytes.
     *
     * @throws CharacterCodingException if the string cannot be encoded (e.g. unpaired surrogates)
     */
    public static byte[] text2Binary(String payload) throws CharacterCodingException {
        ByteBuffer encoded = TEXT_CHARSET.newEncoder().encode(CharBuffer.wrap(payload));
        // The encoder's backing array is usually larger than the output; copy only the valid bytes.
        byte[] bytes = new byte[encoded.remaining()];
        encoded.get(bytes);
        return bytes;
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder("WS[");
        sb.append(getOpCode());
        sb.append(", ").append(isFin() ? "fin" : "inter");
        sb.append(", ").append(isMasked() ? "masked" : "unmasked");
        sb.append(", ").append(payloadToString());
        sb.append(']');
        return sb.toString();
    }

    /**
     * Renders the payload for {@link #toString()}: byte count, then up to 100 chars of text for Text
     * frames or up to 50 bytes as two-digit hex otherwise.
     */
    protected String payloadToString() {
        if (payload == null) {
            return "null";
        }
        final StringBuilder sb = new StringBuilder();
        sb.append('[').append(payload.length).append("b] ");
        if (getOpCode() == OpCode.Text) {
            String text = getTextPayload();
            if (text.length() > 100) {
                sb.append(text.substring(0, 100)).append("...");
            } else {
                sb.append(text);
            }
        } else {
            sb.append("0x");
            for (int i = 0; i < Math.min(payload.length, 50); ++i) {
                int value = payload[i] & 0xFF;
                // Zero-pad so every byte renders as exactly two hex digits.
                if (value < 0x10) {
                    sb.append('0');
                }
                sb.append(Integer.toHexString(value));
            }
            if (payload.length > 50) {
                sb.append("...");
            }
        }
        return sb.toString();
    }

    // --------------------------------CONSTANTS-------------------------------

    /** Frame opcodes defined by RFC 6455. */
    public static enum OpCode {
        Continuation(0), Text(1), Binary(2), Close(8), Ping(9), Pong(10);

        private final byte code;

        private OpCode(int code) {
            this.code = (byte) code;
        }

        public byte getValue() {
            return code;
        }

        /** Returns {@code true} for Close, Ping and Pong. */
        public boolean isControlFrame() {
            return this == Close || this == Ping || this == Pong;
        }

        /** Returns the opcode with the given wire value, or {@code null} if there is none. */
        public static OpCode find(byte value) {
            for (OpCode opcode : values()) {
                if (opcode.getValue() == value) {
                    return opcode;
                }
            }
            return null;
        }
    }

    /** Close status codes carried in the first two bytes of a close frame's payload. */
    public static enum CloseCode {
        NormalClosure(1000), GoingAway(1001), ProtocolError(1002), UnsupportedData(1003), NoStatusRcvd(1005),
        AbnormalClosure(1006), InvalidFramePayloadData(1007), PolicyViolation(1008), MessageTooBig(1009),
        MandatoryExt(1010), InternalServerError(1011), TLSHandshake(1015);

        private final int code;

        private CloseCode(int code) {
            this.code = code;
        }

        public int getValue() {
            return code;
        }

        /** Returns the close code with the given numeric value, or {@code null} if there is none. */
        public static CloseCode find(int value) {
            for (CloseCode code : values()) {
                if (code.getValue() == value) {
                    return code;
                }
            }
            return null;
        }
    }

    // ------------------------------------------------------------------------

    /** A Close frame exposing its status code and UTF-8 reason. */
    public static class CloseFrame extends WebSocketFrame {
        private CloseCode _closeCode;
        private String _closeReason;

        /** Wraps a parsed Close frame and extracts the status code and reason from its payload. */
        private CloseFrame(WebSocketFrame wrap) throws CharacterCodingException {
            super(wrap);
            assert wrap.getOpCode() == OpCode.Close;
            byte[] data = wrap.getBinaryPayload();
            if (data.length >= 2) {
                _closeCode = CloseCode.find(rawCloseCode(data));
                _closeReason = binary2Text(data, 2, data.length - 2);
            }
        }

        /**
         * Creates a Close frame.
         *
         * @param code        status code, or {@code null} for an empty payload (the reason is then dropped)
         * @param closeReason reason text; {@code null} is treated as empty
         * @throws CharacterCodingException if the reason cannot be encoded as UTF-8
         */
        public CloseFrame(CloseCode code, String closeReason) throws CharacterCodingException {
            super(OpCode.Close, true, generatePayload(code, closeReason));
            // Mirror what parsing the generated payload would yield.
            if (code != null) {
                _closeCode = code;
                _closeReason = closeReason != null ? closeReason : "";
            }
        }

        /** Builds the payload: 2-byte big-endian status code followed by the UTF-8 reason. */
        private static byte[] generatePayload(CloseCode code, String closeReason) throws CharacterCodingException {
            if (code == null) {
                return new byte[0];
            }
            byte[] reasonBytes = text2Binary(closeReason != null ? closeReason : "");
            byte[] payload = new byte[reasonBytes.length + 2];
            payload[0] = (byte) ((code.getValue() >> 8) & 0xFF);
            payload[1] = (byte) ((code.getValue()) & 0xFF);
            System.arraycopy(reasonBytes, 0, payload, 2, reasonBytes.length);
            return payload;
        }

        /** Reads the big-endian status code from the first two payload bytes. */
        private static int rawCloseCode(byte[] data) {
            return (data[0] & 0xFF) << 8 | (data[1] & 0xFF);
        }

        @Override
        protected String payloadToString() {
            String codeText;
            if (_closeCode != null) {
                codeText = String.valueOf(_closeCode);
            } else {
                byte[] data = getBinaryPayload();
                // Show the numeric value when a code is present but not a known CloseCode.
                codeText = "UnknownCloseCode[" + (data != null && data.length >= 2 ? String.valueOf(rawCloseCode(data)) : "null") + "]";
            }
            return codeText + (_closeReason != null && !_closeReason.isEmpty() ? ": " + _closeReason : "");
        }

        public CloseCode getCloseCode() {
            return _closeCode;
        }

        public String getCloseReason() {
            return _closeReason;
        }
    }
}