diff --git a/pom.xml b/pom.xml index 9c4f82d5..c5af9229 100644 --- a/pom.xml +++ b/pom.xml @@ -52,6 +52,11 @@ Java-WebSocket 1.3.0 + + commons-codec + commons-codec + 1.10 + junit junit diff --git a/src/main/java/com/github/nkzawa/engineio/client/transports/Polling.java b/src/main/java/com/github/nkzawa/engineio/client/transports/Polling.java index 5fe520bc..3c75a923 100644 --- a/src/main/java/com/github/nkzawa/engineio/client/transports/Polling.java +++ b/src/main/java/com/github/nkzawa/engineio/client/transports/Polling.java @@ -177,7 +177,7 @@ public void run() { } }; - Parser.encodePayload(packets, new Parser.EncodeCallback() { + Parser.encodePayload(packets, true, new Parser.EncodeCallback() { @Override public void call(byte[] data) { self.doWrite(data, callbackfn); diff --git a/src/main/java/com/github/nkzawa/engineio/client/transports/WebSocket.java b/src/main/java/com/github/nkzawa/engineio/client/transports/WebSocket.java index f7a8c880..af6123dd 100644 --- a/src/main/java/com/github/nkzawa/engineio/client/transports/WebSocket.java +++ b/src/main/java/com/github/nkzawa/engineio/client/transports/WebSocket.java @@ -107,7 +107,7 @@ protected void write(Packet[] packets) { final WebSocket self = this; this.writable = false; for (Packet packet : packets) { - Parser.encodePacket(packet, new Parser.EncodeCallback() { + Parser.encodePacket(packet, true, new Parser.EncodeCallback() { @Override public void call(Object packet) { if (packet instanceof String) { diff --git a/src/main/java/com/github/nkzawa/engineio/parser/Parser.java b/src/main/java/com/github/nkzawa/engineio/parser/Parser.java index 04e0117f..4e9dfc6c 100644 --- a/src/main/java/com/github/nkzawa/engineio/parser/Parser.java +++ b/src/main/java/com/github/nkzawa/engineio/parser/Parser.java @@ -3,6 +3,7 @@ import com.github.nkzawa.utf8.UTF8; import com.github.nkzawa.utf8.UTF8Exception; +import org.apache.commons.codec.binary.Base64; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -42,13 +43,15 @@ public static void encodePacket(Packet packet, EncodeCallback callback) { encodePacket(packet, false, callback); } - public static void encodePacket(Packet packet, boolean utf8encode, EncodeCallback callback) { + public static void encodePacket(Packet packet, boolean supportsBinary, EncodeCallback callback) { + encodePacket(packet, supportsBinary, false, callback); + } + + public static void encodePacket(Packet packet, boolean supportsBinary, boolean utf8encode, EncodeCallback callback) { if (packet.data instanceof byte[]) { @SuppressWarnings("unchecked") Packet _packet = packet; - @SuppressWarnings("unchecked") - EncodeCallback _callback = callback; - encodeByteArray(_packet, _callback); + encodeByteArray(_packet, supportsBinary, callback); return; } @@ -63,26 +66,59 @@ public static void encodePacket(Packet packet, boolean utf8encode, EncodeCallbac _callback.call(encoded); } - private static void encodeByteArray(Packet packet, EncodeCallback callback) { + private static void encodeByteArray(Packet packet, boolean supportsBinary, EncodeCallback callback) { + if (!supportsBinary) { + @SuppressWarnings("unchecked") + EncodeCallback _callback = callback; + encodeBase64Packet(packet, _callback); + return; + } + byte[] data = packet.data; byte[] resultArray = new byte[1 + data.length]; resultArray[0] = packets.get(packet.type).byteValue(); System.arraycopy(data, 0, resultArray, 1, data.length); - callback.call(resultArray); + @SuppressWarnings("unchecked") + EncodeCallback _callback = callback; + _callback.call(resultArray); + } + + public static void encodeBase64Packet(Packet packet, EncodeCallback callback) { + byte[] data = packet.data; + String message = "b" + packets.get(packet.type); + message += Base64.encodeBase64String(data); + callback.call(message); + } + + public static Packet decodePacket(Object data) { + return decodePacket(data, false); } - public static Packet decodePacket(String data) { + public static Packet decodePacket(Object data, boolean utf8decode) { + if (data instanceof String) { + return decodePacket((String) data, utf8decode); + } else { + return decodePacket((byte[]) data); + } + } + + public static Packet decodePacket(String data) { return decodePacket(data, false); } - public static Packet decodePacket(String data, boolean utf8decode) { - int type; + public static Packet decodePacket(String data, boolean utf8decode) { + char charAt0; try { - type = Character.getNumericValue(data.charAt(0)); + charAt0 = data.charAt(0); } catch (IndexOutOfBoundsException e) { - type = -1; + charAt0 = '\u0000'; } + if (charAt0 == 'b') { + return decodeBase64Packet(data.substring(1)); + } + + int type = Character.getNumericValue(charAt0); if (utf8decode) { try { data = UTF8.decode(data); @@ -102,6 +138,12 @@ public static Packet decodePacket(String data, boolean utf8decode) { } } + public static Packet decodeBase64Packet(String msg) { + String type = packetslist.get(Character.getNumericValue(msg.charAt(0))); + byte[] data = Base64.decodeBase64(msg.substring(1)); + return new Packet(type, data); + } + public static Packet decodePacket(byte[] data) { int type = data[0]; byte[] intArray = new byte[data.length - 1]; @@ -109,7 +151,41 @@ public static Packet decodePacket(byte[] data) { return new Packet(packetslist.get(type), intArray); } - public static void encodePayload(Packet[] packets, EncodeCallback callback) { + public static void encodePayload(Packet[] packets, EncodeCallback callback) { + encodePayload(packets, false, callback); + } + + public static void encodePayload(Packet[] packets, boolean supportsBinary, EncodeCallback callback) { + if (supportsBinary) { + @SuppressWarnings("unchecked") + EncodeCallback _callback = callback; + encodePayloadAsBinary(packets, _callback); + return; + } + + @SuppressWarnings("unchecked") + EncodeCallback _callback = callback; + + if (packets.length == 0) { + _callback.call("0:"); + return; + } + + final StringBuilder results = new StringBuilder(); + + for (Packet packet : packets) { + encodePacket(packet, false, true, new EncodeCallback() { + @Override + public void call(String message) { + results.append(message.length()).append(":").append(message); + } + }); + } + + _callback.call(results.toString()); + } + + public static void encodePayloadAsBinary(Packet[] packets, EncodeCallback callback) { if (packets.length == 0) { callback.call(new byte[0]); return; @@ -118,7 +194,7 @@ public static void encodePayload(Packet[] packets, EncodeCallback callba final ArrayList results = new ArrayList(packets.length); for (Packet packet : packets) { - encodePacket(packet, true, new EncodeCallback() { + encodePacket(packet, true, true, new EncodeCallback() { @Override public void call(Object packet) { if (packet instanceof String) { @@ -149,7 +225,15 @@ public void call(Object packet) { callback.call(Buffer.concat(results.toArray(new byte[results.size()][]))); } - public static void decodePayload(String data, DecodePayloadCallback callback) { + public static void decodePayload(Object data, DecodePayloadCallback callback) { + if (data instanceof String) { + decodePayload((String)data, callback); + } else { + decodePayload((byte[])data, callback); + } + } + + public static void decodePayload(String data, DecodePayloadCallback callback) { if (data == null || data.length() == 0) { callback.call(err, 0, 1); return; @@ -179,7 +263,7 @@ public static void decodePayload(String data, DecodePayloadCallback call } if (msg.length() != 0) { - Packet packet = decodePacket(msg, true); + Packet packet = decodePacket(msg, true); if (err.type.equals(packet.type) && err.data.equals(packet.data)) { callback.call(err, 0, 1); return; @@ -246,9 +330,7 @@ public static void decodePayload(byte[] data, DecodePayloadCallback callback) { for (int i = 0; i < total; i++) { Object buffer = buffers.get(i); if (buffer instanceof String) { - @SuppressWarnings("unchecked") - DecodePayloadCallback _callback = callback; - _callback.call(decodePacket((String)buffer, true), i, total); + callback.call(decodePacket((String)buffer, true), i, total); } else if (buffer instanceof byte[]) { @SuppressWarnings("unchecked") DecodePayloadCallback _callback = callback; diff --git a/src/test/java/com/github/nkzawa/engineio/parser/ParserTest.java b/src/test/java/com/github/nkzawa/engineio/parser/ParserTest.java index 748b5b1c..4b3c6c0b 100644 --- a/src/test/java/com/github/nkzawa/engineio/parser/ParserTest.java +++ b/src/test/java/com/github/nkzawa/engineio/parser/ParserTest.java @@ -15,19 +15,19 @@ public class ParserTest { @Test public void encodeAsString() { - encodePacket(new Packet(Packet.MESSAGE, "test"), new EncodeCallback() { + encodePacket(new Packet(Packet.MESSAGE, "test"), new EncodeCallback() { @Override - public void call(String data) { - assertThat(data, isA(String.class)); + public void call(Object data) { + assertThat(data, instanceOf(String.class)); } }); } @Test public void decodeAsPacket() { - encodePacket(new Packet(Packet.MESSAGE, "test"), new EncodeCallback() { + encodePacket(new Packet(Packet.MESSAGE, "test"), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { assertThat(decodePacket(data), isA(Packet.class)); } }); @@ -35,9 +35,9 @@ public void call(String data) { @Test public void noData() { - encodePacket(new Packet(Packet.MESSAGE), new EncodeCallback() { + encodePacket(new Packet(Packet.MESSAGE), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { Packet p = decodePacket(data); assertThat(p.type, is(Packet.MESSAGE)); assertThat(p.data, is(nullValue())); @@ -47,9 +47,9 @@ public void call(String data) { @Test public void encodeOpenPacket() { - encodePacket(new Packet(Packet.OPEN, "{\"some\":\"json\"}"), new EncodeCallback() { + encodePacket(new Packet(Packet.OPEN, "{\"some\":\"json\"}"), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { Packet p = decodePacket(data); assertThat(p.type, is(Packet.OPEN)); assertThat(p.data, is("{\"some\":\"json\"}")); @@ -59,9 +59,9 @@ public void call(String data) { @Test public void encodeClosePacket() { - encodePacket(new Packet(Packet.CLOSE), new EncodeCallback() { + encodePacket(new Packet(Packet.CLOSE), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { Packet p = decodePacket(data); assertThat(p.type, is(Packet.CLOSE)); } @@ -70,9 +70,9 @@ public void call(String data) { @Test public void encodePingPacket() { - encodePacket(new Packet(Packet.PING, "1"), new EncodeCallback() { + encodePacket(new Packet(Packet.PING, "1"), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { Packet p = decodePacket(data); assertThat(p.type, is(Packet.PING)); assertThat(p.data, is("1")); @@ -82,9 +82,9 @@ public void call(String data) { @Test public void encodePongPacket() { - encodePacket(new Packet(Packet.PONG, "1"), new EncodeCallback() { + encodePacket(new Packet(Packet.PONG, "1"), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { Packet p = decodePacket(data); assertThat(p.type, is(Packet.PONG)); assertThat(p.data, is("1")); @@ -94,9 +94,9 @@ public void call(String data) { @Test public void encodeMessagePacket() { - encodePacket(new Packet(Packet.MESSAGE, "aaa"), new EncodeCallback() { + encodePacket(new Packet(Packet.MESSAGE, "aaa"), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { Packet p = decodePacket(data); assertThat(p.type, is(Packet.MESSAGE)); assertThat(p.data, is("aaa")); @@ -106,9 +106,9 @@ public void call(String data) { @Test public void encodeUTF8SpecialCharsMessagePacket() { - encodePacket(new Packet(Packet.MESSAGE, "utf8 — string"), new EncodeCallback() { + encodePacket(new Packet(Packet.MESSAGE, "utf8 — string"), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { Packet p = decodePacket(data); assertThat(p.type, is(Packet.MESSAGE)); assertThat(p.data, is("utf8 — string")); @@ -118,9 +118,9 @@ public void call(String data) { @Test public void encodeMessagePacketCoercingToString() { - encodePacket(new Packet(Packet.MESSAGE, 1), new EncodeCallback() { + encodePacket(new Packet(Packet.MESSAGE, 1), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { Packet p = decodePacket(data); assertThat(p.type, is(Packet.MESSAGE)); assertThat(p.data, is("1")); @@ -130,9 +130,9 @@ public void call(String data) { @Test public void encodeUpgradePacket() { - encodePacket(new Packet(Packet.UPGRADE), new EncodeCallback() { + encodePacket(new Packet(Packet.UPGRADE), new EncodeCallback() { @Override - public void call(String data) { + public void call(Object data) { Packet p = decodePacket(data); assertThat(p.type, is(Packet.UPGRADE)); } @@ -177,20 +177,20 @@ public void decodeInvalidUTF8() { } @Test - public void encodePayloads() { - encodePayload(new Packet[]{new Packet(Packet.PING), new Packet(Packet.PONG)}, new EncodeCallback() { + public void encodePayloadsAsString() { + encodePayload(new Packet[]{new Packet(Packet.PING), new Packet(Packet.PONG)}, new EncodeCallback() { @Override - public void call(byte[] data) { - assertThat(data, isA(byte[].class)); + public void call(Object data) { + assertThat(data, instanceOf(String.class)); } }); } @Test public void encodeAndDecodePayloads() { - encodePayload(new Packet[] {new Packet(Packet.MESSAGE, "a")}, new EncodeCallback() { + encodePayload(new Packet[] {new Packet(Packet.MESSAGE, "a")}, new EncodeCallback() { @Override - public void call(byte[] data) { + public void call(Object data) { decodePayload(data, new DecodePayloadCallback() { @Override public boolean call(Packet packet, int index, int total) { @@ -201,9 +201,9 @@ public boolean call(Packet packet, int index, int total) { }); } }); - encodePayload(new Packet[]{new Packet(Packet.MESSAGE, "a"), new Packet(Packet.PING)}, new EncodeCallback() { + encodePayload(new Packet[]{new Packet(Packet.MESSAGE, "a"), new Packet(Packet.PING)}, new EncodeCallback() { @Override - public void call(byte[] data) { + public void call(Object data) { decodePayload(data, new DecodePayloadCallback() { @Override public boolean call(Packet packet, int index, int total) { @@ -222,9 +222,9 @@ public boolean call(Packet packet, int index, int total) { @Test public void encodeAndDecodeEmptyPayloads() { - encodePayload(new Packet[] {}, new EncodeCallback() { + encodePayload(new Packet[] {}, new EncodeCallback() { @Override - public void call(byte[] data) { + public void call(Object data) { decodePayload(data, new DecodePayloadCallback() { @Override public boolean call(Packet packet, int index, int total) { @@ -340,9 +340,9 @@ public void encodeBinaryMessage() { for (int i = 0; i < data.length; i++) { data[0] = (byte)i; } - encodePacket(new Packet(Packet.MESSAGE, data), new EncodeCallback() { + encodePacket(new Packet(Packet.MESSAGE, data), new EncodeCallback() { @Override - public void call(byte[] encoded) { + public void call(Object encoded) { Packet p = decodePacket(encoded); assertThat(p.type, is(Packet.MESSAGE)); assertThat(p.data, is(data)); @@ -364,9 +364,9 @@ public void encodeBinaryContents() { encodePayload(new Packet[]{ new Packet(Packet.MESSAGE, firstBuffer), new Packet(Packet.MESSAGE, secondBuffer), - }, new EncodeCallback() { + }, new EncodeCallback() { @Override - public void call(byte[] data) { + public void call(Object data) { decodePayload(data, new DecodePayloadCallback() { @Override public boolean call(Packet packet, int index, int total) { @@ -385,12 +385,41 @@ public boolean call(Packet packet, int index, int total) { } @Test - public void encodeMixedBinaryAndStringContents() { + public void encodeMixedBinaryAndStringContentsAsBase64() { + final byte[] data = new byte[5]; + for (int i = 0 ; i < data.length; i++) { + data[i] = (byte)i; + } + encodePayload(new Packet[]{ + new Packet(Packet.MESSAGE, data), + new Packet(Packet.MESSAGE, "hello") + }, new EncodeCallback() { + @Override + public void call(Object encoded) { + decodePayload(encoded, new DecodePayloadCallback() { + @Override + public boolean call(Packet packet, int index, int total) { + boolean isLast = index + 1 == total; + assertThat(packet.type, is(Packet.MESSAGE)); + if (!isLast) { + assertThat((byte[])packet.data, is(data)); + } else { + assertThat((String)packet.data, is("hello")); + } + return true; + } + }); + } + }); + } + + @Test + public void encodeMixedBinaryAndStringContentsAsBinary() { final byte[] firstBuffer = new byte[123]; for (int i = 0 ; i < firstBuffer.length; i++) { - firstBuffer[0] = (byte)i; + firstBuffer[i] = (byte)i; } - encodePayload(new Packet[]{ + encodePayloadAsBinary(new Packet[]{ new Packet(Packet.MESSAGE, firstBuffer), new Packet(Packet.MESSAGE, "hello"), new Packet(Packet.CLOSE),