// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_inflater.h" #include <stdint.h> #include <string> #include <vector> #include "base/memory/ref_counted.h" #include "net/base/io_buffer.h" #include "net/websockets/websocket_deflater.h" #include "net/websockets/websocket_test_util.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { namespace { std::string ToString(IOBufferWithSize* buffer) { return std::string(buffer->data(), buffer->size()); } TEST(WebSocketInflaterTest, Construct) { WebSocketInflater inflater; ASSERT_TRUE(inflater.Initialize(15)); EXPECT_EQ(0u, inflater.CurrentOutputSize()); } TEST(WebSocketInflaterTest, InflateHelloTakeOverContext) { WebSocketInflater inflater; ASSERT_TRUE(inflater.Initialize(15)); scoped_refptr<IOBufferWithSize> actual1, actual2; ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7)); ASSERT_TRUE(inflater.Finish()); actual1 = inflater.GetOutput(inflater.CurrentOutputSize()); ASSERT_TRUE(actual1); EXPECT_EQ("Hello", ToString(actual1.get())); EXPECT_EQ(0u, inflater.CurrentOutputSize()); ASSERT_TRUE(inflater.AddBytes("\xf2\x00\x11\x00\x00", 5)); ASSERT_TRUE(inflater.Finish()); actual2 = inflater.GetOutput(inflater.CurrentOutputSize()); ASSERT_TRUE(actual2); EXPECT_EQ("Hello", ToString(actual2.get())); EXPECT_EQ(0u, inflater.CurrentOutputSize()); } TEST(WebSocketInflaterTest, InflateHelloSmallCapacity) { WebSocketInflater inflater(1, 1); ASSERT_TRUE(inflater.Initialize(15)); std::string actual; ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7)); ASSERT_TRUE(inflater.Finish()); for (size_t i = 0; i < 5; ++i) { ASSERT_EQ(1u, inflater.CurrentOutputSize()); scoped_refptr<IOBufferWithSize> buffer = inflater.GetOutput(1); ASSERT_TRUE(buffer); ASSERT_EQ(1, buffer->size()); actual += ToString(buffer.get()); } EXPECT_EQ("Hello", actual); EXPECT_EQ(0u, inflater.CurrentOutputSize()); } TEST(WebSocketInflaterTest, InflateHelloSmallCapacityGetTotalOutput) { WebSocketInflater inflater(1, 1); ASSERT_TRUE(inflater.Initialize(15)); scoped_refptr<IOBufferWithSize> actual; ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7)); ASSERT_TRUE(inflater.Finish()); ASSERT_EQ(1u, inflater.CurrentOutputSize()); actual = inflater.GetOutput(1024); EXPECT_EQ("Hello", ToString(actual)); EXPECT_EQ(0u, inflater.CurrentOutputSize()); } TEST(WebSocketInflaterTest, InflateInvalidData) { WebSocketInflater inflater; ASSERT_TRUE(inflater.Initialize(15)); EXPECT_FALSE(inflater.AddBytes("\xf2\x48\xcd\xc9INVALID DATA", 16)); } TEST(WebSocketInflaterTest, ChokedInvalidData) { WebSocketInflater inflater(1, 1); ASSERT_TRUE(inflater.Initialize(15)); EXPECT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9INVALID DATA", 16)); EXPECT_TRUE(inflater.Finish()); EXPECT_EQ(1u, inflater.CurrentOutputSize()); EXPECT_FALSE(inflater.GetOutput(1024)); } TEST(WebSocketInflaterTest, MultipleAddBytesCalls) { WebSocketInflater inflater; ASSERT_TRUE(inflater.Initialize(15)); std::string input("\xf2\x48\xcd\xc9\xc9\x07\x00", 7); scoped_refptr<IOBufferWithSize> actual; for (size_t i = 0; i < input.size(); ++i) { ASSERT_TRUE(inflater.AddBytes(&input[i], 1)); } ASSERT_TRUE(inflater.Finish()); actual = inflater.GetOutput(5); ASSERT_TRUE(actual); EXPECT_EQ("Hello", ToString(actual.get())); } TEST(WebSocketInflaterTest, Reset) { WebSocketInflater inflater; ASSERT_TRUE(inflater.Initialize(15)); scoped_refptr<IOBufferWithSize> actual1, actual2; ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7)); ASSERT_TRUE(inflater.Finish()); actual1 = inflater.GetOutput(inflater.CurrentOutputSize()); ASSERT_TRUE(actual1); EXPECT_EQ("Hello", ToString(actual1.get())); EXPECT_EQ(0u, inflater.CurrentOutputSize()); // Reset the stream with a block [BFINAL = 1, BTYPE = 00, LEN = 0] ASSERT_TRUE(inflater.AddBytes("\x01", 1)); ASSERT_TRUE(inflater.Finish()); ASSERT_EQ(0u, inflater.CurrentOutputSize()); ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7)); ASSERT_TRUE(inflater.Finish()); actual2 = inflater.GetOutput(inflater.CurrentOutputSize()); ASSERT_TRUE(actual2); EXPECT_EQ("Hello", ToString(actual2.get())); EXPECT_EQ(0u, inflater.CurrentOutputSize()); } TEST(WebSocketInflaterTest, ResetAndLostContext) { WebSocketInflater inflater; scoped_refptr<IOBufferWithSize> actual1, actual2; ASSERT_TRUE(inflater.Initialize(15)); ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7)); ASSERT_TRUE(inflater.Finish()); actual1 = inflater.GetOutput(inflater.CurrentOutputSize()); ASSERT_TRUE(actual1); EXPECT_EQ("Hello", ToString(actual1.get())); EXPECT_EQ(0u, inflater.CurrentOutputSize()); // Reset the stream with a block [BFINAL = 1, BTYPE = 00, LEN = 0] ASSERT_TRUE(inflater.AddBytes("\x01", 1)); ASSERT_TRUE(inflater.Finish()); ASSERT_EQ(0u, inflater.CurrentOutputSize()); // The context is already reset. ASSERT_FALSE(inflater.AddBytes("\xf2\x00\x11\x00\x00", 5)); } TEST(WebSocketInflaterTest, CallAddBytesAndFinishWithoutGetOutput) { WebSocketInflater inflater; scoped_refptr<IOBufferWithSize> actual1, actual2; ASSERT_TRUE(inflater.Initialize(15)); ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7)); ASSERT_TRUE(inflater.Finish()); EXPECT_EQ(5u, inflater.CurrentOutputSize()); // This is a test for detecting memory leaks with valgrind. } TEST(WebSocketInflaterTest, CallAddBytesAndFinishWithoutGetOutputChoked) { WebSocketInflater inflater(1, 1); scoped_refptr<IOBufferWithSize> actual1, actual2; ASSERT_TRUE(inflater.Initialize(15)); ASSERT_TRUE(inflater.AddBytes("\xf2\x48\xcd\xc9\xc9\x07\x00", 7)); ASSERT_TRUE(inflater.Finish()); EXPECT_EQ(1u, inflater.CurrentOutputSize()); // This is a test for detecting memory leaks with valgrind. } TEST(WebSocketInflaterTest, LargeRandomDeflateInflate) { const size_t size = 64 * 1024; LinearCongruentialGenerator generator(133); std::vector<char> input; std::vector<char> output; scoped_refptr<IOBufferWithSize> compressed; WebSocketDeflater deflater(WebSocketDeflater::TAKE_OVER_CONTEXT); ASSERT_TRUE(deflater.Initialize(8)); WebSocketInflater inflater(256, 256); ASSERT_TRUE(inflater.Initialize(8)); for (size_t i = 0; i < size; ++i) input.push_back(static_cast<char>(generator.Generate())); ASSERT_TRUE(deflater.AddBytes(&input[0], input.size())); ASSERT_TRUE(deflater.Finish()); compressed = deflater.GetOutput(deflater.CurrentOutputSize()); ASSERT_TRUE(compressed); ASSERT_EQ(0u, deflater.CurrentOutputSize()); ASSERT_TRUE(inflater.AddBytes(compressed->data(), compressed->size())); ASSERT_TRUE(inflater.Finish()); while (inflater.CurrentOutputSize() > 0) { scoped_refptr<IOBufferWithSize> uncompressed = inflater.GetOutput(inflater.CurrentOutputSize()); ASSERT_TRUE(uncompressed); output.insert(output.end(), uncompressed->data(), uncompressed->data() + uncompressed->size()); } EXPECT_EQ(output, input); } } // unnamed namespace } // namespace net