Java程序  |  431行  |  14.83 KB

package fi.iki.elonen;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CharsetEncoder;
import java.util.Arrays;
import java.util.List;

/**
 * A single WebSocket frame: FIN flag, opcode, optional 4-byte masking key and
 * payload. The wire layout read and written here (FIN + opcode byte, mask bit +
 * 7/16/64-bit length, masking key, payload) follows the base framing protocol
 * of RFC 6455, section 5.2.
 *
 * <p>Instances are mutable and hold no synchronization of their own. Payload and
 * masking-key arrays are stored and returned by reference, not copied.
 */
public class WebSocketFrame {
    private OpCode opCode;
    private boolean fin;
    private byte[] maskingKey;

    private byte[] payload;

    // Cached length of {@link #payload} in bytes; also used as scratch while parsing.
    private transient int _payloadLength;
    // Lazily decoded text form of the payload; null until requested or set.
    private transient String _payloadString;

    /**
     * Creates a frame without payload; used while parsing from a stream.
     */
    private WebSocketFrame(OpCode opCode, boolean fin) {
        setOpCode(opCode);
        setFin(fin);
    }

    /**
     * Creates a frame with a binary payload.
     *
     * @param opCode     the frame type
     * @param fin        whether this is the final fragment of a message
     * @param payload    the payload bytes (stored by reference, must not be null)
     * @param maskingKey a 4-byte masking key, or null for an unmasked frame
     * @throws IllegalArgumentException if the masking key is not null and not 4 bytes long
     */
    public WebSocketFrame(OpCode opCode, boolean fin, byte[] payload, byte[] maskingKey) {
        this(opCode, fin);
        setMaskingKey(maskingKey);
        setBinaryPayload(payload);
    }

    /**
     * Creates an unmasked frame with a binary payload.
     */
    public WebSocketFrame(OpCode opCode, boolean fin, byte[] payload) {
        this(opCode, fin, payload, null);
    }

    /**
     * Creates a frame whose payload is the UTF-8 encoding of the given text.
     *
     * @throws CharacterCodingException if the text cannot be encoded as UTF-8
     * @throws IllegalArgumentException if the masking key is not null and not 4 bytes long
     */
    public WebSocketFrame(OpCode opCode, boolean fin, String payload, byte[] maskingKey) throws CharacterCodingException {
        this(opCode, fin);
        setMaskingKey(maskingKey);
        setTextPayload(payload);
    }

    /**
     * Creates an unmasked frame whose payload is the UTF-8 encoding of the given text.
     *
     * @throws CharacterCodingException if the text cannot be encoded as UTF-8
     */
    public WebSocketFrame(OpCode opCode, boolean fin, String payload) throws CharacterCodingException {
        this(opCode, fin, payload, null);
    }

    /**
     * Copy constructor. The payload and masking-key arrays are shared with
     * {@code clone}, not duplicated.
     */
    public WebSocketFrame(WebSocketFrame clone) {
        setOpCode(clone.getOpCode());
        setFin(clone.isFin());
        setBinaryPayload(clone.getBinaryPayload());
        setMaskingKey(clone.getMaskingKey());
    }

    /**
     * Builds one final, unmasked frame by concatenating the payloads of the
     * given fragments in iteration order.
     *
     * @param opCode    the opcode of the assembled frame
     * @param fragments the fragments to join
     * @throws WebSocketException with {@link CloseCode#MessageTooBig} if the
     *                            combined payload does not fit into a byte array
     */
    public WebSocketFrame(OpCode opCode, List<WebSocketFrame> fragments) throws WebSocketException {
        setOpCode(opCode);
        setFin(true);

        // Sum in a long so that many large fragments cannot overflow silently.
        long totalLength = 0;
        for (WebSocketFrame fragment : fragments) {
            totalLength += fragment.getBinaryPayload().length;
        }
        if (totalLength < 0 || totalLength > Integer.MAX_VALUE) {
            throw new WebSocketException(CloseCode.MessageTooBig, "Max frame length has been exceeded.");
        }

        byte[] merged = new byte[(int) totalLength];
        int offset = 0;
        for (WebSocketFrame fragment : fragments) {
            byte[] part = fragment.getBinaryPayload();
            System.arraycopy(part, 0, merged, offset, part.length);
            offset += part.length;
        }
        setBinaryPayload(merged);
    }

    // --------------------------------GETTERS---------------------------------

    /** @return the frame type */
    public OpCode getOpCode() {
        return opCode;
    }

    public void setOpCode(OpCode opcode) {
        this.opCode = opcode;
    }

    /** @return true if this frame is the final fragment of its message */
    public boolean isFin() {
        return fin;
    }

    public void setFin(boolean fin) {
        this.fin = fin;
    }

    /** @return true if a 4-byte masking key is set */
    public boolean isMasked() {
        return maskingKey != null && maskingKey.length == 4;
    }

    /** @return the masking key (internal array, not a copy), or null if unmasked */
    public byte[] getMaskingKey() {
        return maskingKey;
    }

    /**
     * Sets the masking key applied by {@link #write(OutputStream)}.
     *
     * @param maskingKey a 4-byte key, or null to send the frame unmasked
     * @throws IllegalArgumentException if the key is not null and not 4 bytes long
     */
    public void setMaskingKey(byte[] maskingKey) {
        if (maskingKey != null && maskingKey.length != 4) {
            throw new IllegalArgumentException("MaskingKey " + Arrays.toString(maskingKey) + " hasn't length 4");
        }
        this.maskingKey = maskingKey;
    }

    /** Removes the masking key so the frame is written unmasked. */
    public void setUnmasked() {
        setMaskingKey(null);
    }

    /** @return the payload bytes (internal array, not a copy) */
    public byte[] getBinaryPayload() {
        return payload;
    }

    /**
     * Replaces the payload and discards any cached text form.
     *
     * @param payload the new payload; must not be null
     */
    public void setBinaryPayload(byte[] payload) {
        this.payload = payload;
        this._payloadLength = payload.length;
        this._payloadString = null;
    }

    /**
     * Returns the payload decoded as UTF-8, decoding and caching it on first use.
     *
     * @throws RuntimeException wrapping the {@link CharacterCodingException} if
     *                          the payload is not valid UTF-8
     */
    public String getTextPayload() {
        if (_payloadString == null) {
            try {
                _payloadString = binary2Text(getBinaryPayload());
            } catch (CharacterCodingException e) {
                throw new RuntimeException("Undetected CharacterCodingException", e);
            }
        }
        return _payloadString;
    }

    /**
     * Sets the payload to the UTF-8 encoding of the given text.
     *
     * @throws CharacterCodingException if the text cannot be encoded as UTF-8
     */
    public void setTextPayload(String payload) throws CharacterCodingException {
        this.payload = text2Binary(payload);
        // The length is counted in encoded bytes, not in UTF-16 chars of the string.
        this._payloadLength = this.payload.length;
        this._payloadString = payload;
    }

    // --------------------------------SERIALIZATION---------------------------

    /**
     * Reads one complete frame from the stream, blocking until it is available.
     *
     * @param in the stream to read from
     * @return the parsed frame; a {@link CloseFrame} when the opcode is {@link OpCode#Close}
     * @throws EOFException               if the stream ends inside the frame
     * @throws WebSocketException         if the frame violates the framing rules
     *                                    checked here or is too large
     * @throws CharacterCodingException   if a text frame or close reason is not valid UTF-8
     * @throws IOException                on any other read failure
     */
    public static WebSocketFrame read(InputStream in) throws IOException {
        byte head = (byte) checkedRead(in.read());
        boolean fin = ((head & 0x80) != 0);
        OpCode opCode = OpCode.find((byte) (head & 0x0F));
        if ((head & 0x70) != 0) {
            throw new WebSocketException(CloseCode.ProtocolError, "The reserved bits (" + Integer.toBinaryString(head & 0x70) + ") must be 0.");
        }
        if (opCode == null) {
            throw new WebSocketException(CloseCode.ProtocolError, "Received frame with reserved/unknown opcode " + (head & 0x0F) + ".");
        } else if (opCode.isControlFrame() && !fin) {
            throw new WebSocketException(CloseCode.ProtocolError, "Fragmented control frame.");
        }

        WebSocketFrame frame = new WebSocketFrame(opCode, fin);
        frame.readPayloadInfo(in);
        frame.readPayload(in);
        if (frame.getOpCode() == OpCode.Close) {
            return new CloseFrame(frame);
        } else {
            return frame;
        }
    }

    /**
     * Passes through the result of an {@code InputStream.read} call, turning
     * the end-of-stream marker (a negative value) into an exception.
     */
    private static int checkedRead(int read) throws IOException {
        if (read < 0) {
            throw new EOFException();
        }
        return read;
    }


    /**
     * Reads the mask bit, the payload length (7, 16 or 64 bit form) and, if
     * present, the masking key.
     */
    private void readPayloadInfo(InputStream in) throws IOException {
        byte b = (byte) checkedRead(in.read());
        boolean masked = ((b & 0x80) != 0);

        _payloadLength = (byte) (0x7F & b);
        if (_payloadLength == 126) {
            // checkedRead must return int for this to work
            _payloadLength = (checkedRead(in.read()) << 8 | checkedRead(in.read())) & 0xFFFF;
            if (_payloadLength < 126) {
                throw new WebSocketException(CloseCode.ProtocolError, "Invalid data frame 2byte length. (not using minimal length encoding)");
            }
        } else if (_payloadLength == 127) {
            // Accumulate all eight big-endian bytes as a long. Shifting the low
            // bytes as int would sign-extend a set bit 31 into the upper half.
            long extendedLength = 0;
            for (int i = 0; i < 8; i++) {
                extendedLength = (extendedLength << 8) | checkedRead(in.read());
            }
            if (extendedLength < 65536) {
                throw new WebSocketException(CloseCode.ProtocolError, "Invalid data frame 4byte length. (not using minimal length encoding)");
            }
            if (extendedLength < 0 || extendedLength > Integer.MAX_VALUE) {
                throw new WebSocketException(CloseCode.MessageTooBig, "Max frame length has been exceeded.");
            }
            this._payloadLength = (int) extendedLength;
        }

        if (opCode.isControlFrame()) {
            if (_payloadLength > 125) {
                throw new WebSocketException(CloseCode.ProtocolError, "Control frame with payload length > 125 bytes.");
            }
            if (opCode == OpCode.Close && _payloadLength == 1) {
                throw new WebSocketException(CloseCode.ProtocolError, "Received close frame with payload len 1.");
            }
        }

        if (masked) {
            maskingKey = new byte[4];
            int read = 0;
            while (read < maskingKey.length) {
                read += checkedRead(in.read(maskingKey, read, maskingKey.length - read));
            }
        }
    }

    /**
     * Reads {@code _payloadLength} payload bytes, removes the mask if one was
     * announced, and validates text frames as UTF-8.
     */
    private void readPayload(InputStream in) throws IOException {
        payload = new byte[_payloadLength];
        int read = 0;
        while (read < _payloadLength) {
            read += checkedRead(in.read(payload, read, _payloadLength - read));
        }

        if (isMasked()) {
            for (int i = 0; i < payload.length; i++) {
                payload[i] ^= maskingKey[i % 4];
            }
        }

        // Test for Unicode errors.
        // NOTE(review): this also runs on a non-final Text fragment, which may
        // legitimately end in the middle of a multi-byte sequence - confirm
        // whether callers rely on the exception in that case.
        if (getOpCode() == OpCode.Text) {
            _payloadString = binary2Text(getBinaryPayload());
        }
    }

    /**
     * Serializes this frame to the stream and flushes it. If a masking key is
     * set, the key is written and the payload is XOR-masked on the way out; the
     * stored payload itself is not modified.
     *
     * @param out the stream to write to
     * @throws IOException if writing fails
     */
    public void write(OutputStream out) throws IOException {
        byte header = 0;
        if (fin) {
            header |= 0x80;
        }
        header |= opCode.getValue() & 0x0F;
        out.write(header);

        final byte[] data = getBinaryPayload();
        _payloadLength = data.length;
        final int maskBit = isMasked() ? 0x80 : 0;
        if (_payloadLength <= 125) {
            out.write(maskBit | _payloadLength);
        } else if (_payloadLength <= 0xFFFF) {
            out.write(maskBit | 126);
            out.write(_payloadLength >>> 8);
            out.write(_payloadLength);
        } else {
            out.write(maskBit | 127);
            // The length is an int, so the upper four bytes of the 64-bit field are zero.
            out.write(0);
            out.write(0);
            out.write(0);
            out.write(0);
            out.write(_payloadLength >>> 24);
            out.write(_payloadLength >>> 16);
            out.write(_payloadLength >>> 8);
            out.write(_payloadLength);
        }

        if (isMasked()) {
            out.write(maskingKey);
            // Mask into a bounded scratch buffer and write in chunks instead of
            // issuing one write call per payload byte.
            byte[] chunk = new byte[Math.min(data.length, 8192)];
            int offset = 0;
            while (offset < data.length) {
                int count = Math.min(chunk.length, data.length - offset);
                for (int i = 0; i < count; i++) {
                    chunk[i] = (byte) (data[offset + i] ^ maskingKey[(offset + i) % 4]);
                }
                out.write(chunk, 0, count);
                offset += count;
            }
        } else {
            out.write(data);
        }
        out.flush();
    }

    // --------------------------------ENCODING--------------------------------

    /** Charset used for text payloads. */
    public static final Charset TEXT_CHARSET = Charset.forName("UTF-8");

    /**
     * Shared decoder kept for backward compatibility. Charset decoders are
     * stateful and not safe for concurrent use, so the methods of this class no
     * longer use it; prefer {@code TEXT_CHARSET.newDecoder()}.
     */
    public static final CharsetDecoder TEXT_DECODER = TEXT_CHARSET.newDecoder();

    /**
     * Shared encoder kept for backward compatibility. Charset encoders are
     * stateful and not safe for concurrent use, so the methods of this class no
     * longer use it; prefer {@code TEXT_CHARSET.newEncoder()}.
     */
    public static final CharsetEncoder TEXT_ENCODER = TEXT_CHARSET.newEncoder();


    /**
     * Decodes the whole array as UTF-8.
     *
     * @throws CharacterCodingException if the bytes are not valid UTF-8
     */
    public static String binary2Text(byte[] payload) throws CharacterCodingException {
        // A fresh decoder per call keeps this method safe to call from several threads.
        return TEXT_CHARSET.newDecoder().decode(ByteBuffer.wrap(payload)).toString();
    }

    /**
     * Decodes {@code length} bytes starting at {@code offset} as UTF-8.
     *
     * @throws CharacterCodingException if the bytes are not valid UTF-8
     */
    public static String binary2Text(byte[] payload, int offset, int length) throws CharacterCodingException {
        return TEXT_CHARSET.newDecoder().decode(ByteBuffer.wrap(payload, offset, length)).toString();
    }

    /**
     * Encodes the text as UTF-8.
     *
     * @return a new array holding exactly the encoded bytes
     * @throws CharacterCodingException if the text cannot be encoded (for
     *                                  example an unpaired surrogate)
     */
    public static byte[] text2Binary(String payload) throws CharacterCodingException {
        ByteBuffer encoded = TEXT_CHARSET.newEncoder().encode(CharBuffer.wrap(payload));
        // The buffer's backing array may be larger than the encoded content,
        // so copy only the bytes between position and limit.
        byte[] bytes = new byte[encoded.remaining()];
        encoded.get(bytes);
        return bytes;
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder("WS[");
        sb.append(getOpCode());
        sb.append(", ").append(isFin() ? "fin" : "inter");
        sb.append(", ").append(isMasked() ? "masked" : "unmasked");
        sb.append(", ").append(payloadToString());
        sb.append(']');
        return sb.toString();
    }

    /**
     * Renders the payload for {@link #toString()}: the byte count followed by
     * up to 100 chars of text for text frames, or up to 50 bytes as two-digit
     * hex for everything else.
     */
    protected String payloadToString() {
        if (payload == null) {
            return "null";
        }
        final StringBuilder sb = new StringBuilder();
        sb.append('[').append(payload.length).append("b] ");
        if (getOpCode() == OpCode.Text) {
            String text = getTextPayload();
            if (text.length() > 100) {
                sb.append(text.substring(0, 100)).append("...");
            } else {
                sb.append(text);
            }
        } else {
            sb.append("0x");
            final int shown = Math.min(payload.length, 50);
            for (int i = 0; i < shown; ++i) {
                int value = payload[i] & 0xFF;
                // Zero-pad so every byte takes two digits and the dump is unambiguous.
                if (value < 0x10) {
                    sb.append('0');
                }
                sb.append(Integer.toHexString(value));
            }
            if (payload.length > 50) {
                sb.append("...");
            }
        }
        return sb.toString();
    }

    // --------------------------------CONSTANTS-------------------------------

    /** Frame types together with their 4-bit wire value. */
    public enum OpCode {
        Continuation(0), Text(1), Binary(2), Close(8), Ping(9), Pong(10);

        private final byte code;

        OpCode(int code) {
            this.code = (byte) code;
        }

        /** @return the wire value of this opcode */
        public byte getValue() {
            return code;
        }

        /** @return true for Close, Ping and Pong */
        public boolean isControlFrame() {
            return this == Close || this == Ping || this == Pong;
        }

        /**
         * @return the opcode with the given wire value, or null if there is none
         */
        public static OpCode find(byte value) {
            for (OpCode opcode : values()) {
                if (opcode.getValue() == value) {
                    return opcode;
                }
            }
            return null;
        }
    }

    /** Close status codes together with their numeric wire value. */
    public enum CloseCode {
        NormalClosure(1000), GoingAway(1001), ProtocolError(1002), UnsupportedData(1003), NoStatusRcvd(1005),
        AbnormalClosure(1006), InvalidFramePayloadData(1007), PolicyViolation(1008), MessageTooBig(1009),
        MandatoryExt(1010), InternalServerError(1011), TLSHandshake(1015);

        private final int code;

        CloseCode(int code) {
            this.code = code;
        }

        /** @return the numeric value of this close code */
        public int getValue() {
            return code;
        }

        /**
         * @return the close code with the given numeric value, or null if there is none
         */
        public static CloseCode find(int value) {
            for (CloseCode code : values()) {
                if (code.getValue() == value) {
                    return code;
                }
            }
            return null;
        }
    }

    // ------------------------------------------------------------------------

    /**
     * A Close frame whose payload, when at least two bytes long, is a 16-bit
     * close code followed by a UTF-8 reason.
     */
    public static class CloseFrame extends WebSocketFrame {
        private CloseCode _closeCode;
        private String _closeReason;
        // Numeric code as found in the payload; null when the payload carries none.
        private Integer _rawCloseCode;

        /**
         * Wraps an already parsed Close frame and extracts code and reason.
         *
         * @throws CharacterCodingException if the reason is not valid UTF-8
         */
        private CloseFrame(WebSocketFrame wrap) throws CharacterCodingException {
            super(wrap);
            assert wrap.getOpCode() == OpCode.Close;
            final byte[] data = wrap.getBinaryPayload();
            if (data.length >= 2) {
                int raw = (data[0] & 0xFF) << 8 | (data[1] & 0xFF);
                _rawCloseCode = Integer.valueOf(raw);
                _closeCode = CloseCode.find(raw);
                _closeReason = binary2Text(data, 2, data.length - 2);
            }
        }

        /**
         * Creates an unmasked Close frame.
         *
         * @param code        the close code, or null for a Close frame without payload
         * @param closeReason the reason text; null is treated as an empty reason
         * @throws CharacterCodingException if the reason cannot be encoded as UTF-8
         */
        public CloseFrame(CloseCode code, String closeReason) throws CharacterCodingException {
            super(OpCode.Close, true, generatePayload(code, closeReason));
            // Mirror what parsing the generated payload would yield, so the
            // getters work on locally created frames too.
            if (code != null) {
                _closeCode = code;
                _rawCloseCode = Integer.valueOf(code.getValue());
                _closeReason = closeReason == null ? "" : closeReason;
            }
        }

        /**
         * Builds the Close payload: two big-endian code bytes followed by the
         * UTF-8 reason, or an empty array when no code is given.
         */
        private static byte[] generatePayload(CloseCode code, String closeReason) throws CharacterCodingException {
            if (code == null) {
                return new byte[0];
            }
            byte[] reasonBytes = text2Binary(closeReason == null ? "" : closeReason);
            byte[] payload = new byte[reasonBytes.length + 2];
            payload[0] = (byte) ((code.getValue() >> 8) & 0xFF);
            payload[1] = (byte) ((code.getValue()) & 0xFF);
            System.arraycopy(reasonBytes, 0, payload, 2, reasonBytes.length);
            return payload;
        }

        @Override
        protected String payloadToString() {
            String codeText = _closeCode != null
                    ? _closeCode.toString()
                    : "UnknownCloseCode[" + _rawCloseCode + "]";
            boolean hasReason = _closeReason != null && !_closeReason.isEmpty();
            return codeText + (hasReason ? ": " + _closeReason : "");
        }

        /** @return the close code, or null if absent or not a known {@link CloseCode} */
        public CloseCode getCloseCode() {
            return _closeCode;
        }

        /** @return the close reason, or null if the frame carries no close code */
        public String getCloseReason() {
            return _closeReason;
        }
    }
}