#region Copyright notice and license

// Copyright 2015-2016 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

using CommandLine;
using CommandLine.Text;
using Grpc.Core;
using Grpc.Core.Logging;
using Grpc.Core.Utils;
using Grpc.Testing;

namespace Grpc.IntegrationTesting
{
    public class StressTestClient
    {
        static readonly ILogger Logger = GrpcEnvironment.Logger.ForType<StressTestClient>();
        const double SecondsToNanos = 1e9;

        private class ClientOptions
        {
            [Option("server_addresses", Default = "localhost:8080")]
            public string ServerAddresses { get; set; }

            [Option("test_cases", Default = "large_unary:100")]
            public string TestCases { get; set; }

            [Option("test_duration_secs", Default = -1)]
            public int TestDurationSecs { get; set; }

            [Option("num_channels_per_server", Default = 1)]
            public int NumChannelsPerServer { get; set; }

            [Option("num_stubs_per_channel", Default = 1)]
            public int NumStubsPerChannel { get; set; }

            [Option("metrics_port", Default = 8081)]
            public int MetricsPort { get; set; }
        }

        ClientOptions options;
        List<string> serverAddresses;
        Dictionary<string, int> weightedTestCases;
        WeightedRandomGenerator testCaseGenerator;

        // cancellation will be emitted once test_duration_secs has elapsed.
        CancellationTokenSource finishedTokenSource = new CancellationTokenSource();
        Histogram histogram = new Histogram(0.01, 60 * SecondsToNanos);

        private StressTestClient(ClientOptions options, List<string> serverAddresses, Dictionary<string, int> weightedTestCases)
        {
            this.options = options;
            this.serverAddresses = serverAddresses;
            this.weightedTestCases = weightedTestCases;
            this.testCaseGenerator = new WeightedRandomGenerator(this.weightedTestCases);
        }

        public static void Run(string[] args)
        {
            GrpcEnvironment.SetLogger(new ConsoleLogger());
            var parserResult = Parser.Default.ParseArguments<ClientOptions>(args)
                .WithNotParsed((x) => Environment.Exit(1))
                .WithParsed(options => {
                    GrpcPreconditions.CheckArgument(options.NumChannelsPerServer > 0);
                    GrpcPreconditions.CheckArgument(options.NumStubsPerChannel > 0);

                    var serverAddresses = options.ServerAddresses.Split(',');
                    GrpcPreconditions.CheckArgument(serverAddresses.Length > 0, "You need to provide at least one server address");

                    var testCases = ParseWeightedTestCases(options.TestCases);
                    GrpcPreconditions.CheckArgument(testCases.Count > 0, "You need to provide at least one test case");

                    var interopClient = new StressTestClient(options, serverAddresses.ToList(), testCases);
                    interopClient.Run().Wait();
                });
        }

        async Task Run()
        {
            var metricsServer = new Server()
            {
                Services = { MetricsService.BindService(new MetricsServiceImpl(histogram)) },
                Ports = { { "[::]", options.MetricsPort, ServerCredentials.Insecure } }
            };
            metricsServer.Start();

            if (options.TestDurationSecs >= 0)
            {
                finishedTokenSource.CancelAfter(TimeSpan.FromSeconds(options.TestDurationSecs));
            }

            var tasks = new List<Task>();
            var channels = new List<Channel>();
            foreach (var serverAddress in serverAddresses)
            {
                for (int i = 0; i < options.NumChannelsPerServer; i++)
                {
                    var channel = new Channel(serverAddress, ChannelCredentials.Insecure);
                    channels.Add(channel);
                    for (int j = 0; j < options.NumStubsPerChannel; j++)
                    {
                        var client = new TestService.TestServiceClient(channel);
                        var task = Task.Factory.StartNew(() => RunBodyAsync(client).GetAwaiter().GetResult(),
                            TaskCreationOptions.LongRunning);
                        tasks.Add(task);  
                    }
                }
            }
            await Task.WhenAll(tasks);

            foreach (var channel in channels)
            {
                await channel.ShutdownAsync();
            }

            await metricsServer.ShutdownAsync();
        }

        async Task RunBodyAsync(TestService.TestServiceClient client)
        {
            Logger.Info("Starting stress test client thread.");
            while (!finishedTokenSource.Token.IsCancellationRequested)
            {
                var testCase = testCaseGenerator.GetNext();

                var stopwatch = Stopwatch.StartNew();

                await RunTestCaseAsync(client, testCase);

                stopwatch.Stop();
                histogram.AddObservation(stopwatch.Elapsed.TotalSeconds * SecondsToNanos);
            }
            Logger.Info("Stress test client thread finished.");
        }

        async Task RunTestCaseAsync(TestService.TestServiceClient client, string testCase)
        {
            switch (testCase)
            {
                case "empty_unary":
                    InteropClient.RunEmptyUnary(client);
                    break;
                case "large_unary":
                    InteropClient.RunLargeUnary(client);
                    break;
                case "client_streaming":
                    await InteropClient.RunClientStreamingAsync(client);
                    break;
                case "server_streaming":
                    await InteropClient.RunServerStreamingAsync(client);
                    break;
                case "ping_pong":
                    await InteropClient.RunPingPongAsync(client);
                    break;
                case "empty_stream":
                    await InteropClient.RunEmptyStreamAsync(client);
                    break;
                case "cancel_after_begin":
                    await InteropClient.RunCancelAfterBeginAsync(client);
                    break;
                case "cancel_after_first_response":
                    await InteropClient.RunCancelAfterFirstResponseAsync(client);
                    break;
                case "timeout_on_sleeping_server":
                    await InteropClient.RunTimeoutOnSleepingServerAsync(client);
                    break;
                case "custom_metadata":
                    await InteropClient.RunCustomMetadataAsync(client);
                    break;
                case "status_code_and_message":
                    await InteropClient.RunStatusCodeAndMessageAsync(client);
                    break;
                default:
                    throw new ArgumentException("Unsupported test case  " + testCase);
            }
        }

        static Dictionary<string, int> ParseWeightedTestCases(string weightedTestCases)
        {
            var result = new Dictionary<string, int>();
            foreach (var weightedTestCase in weightedTestCases.Split(','))
            {
                var parts = weightedTestCase.Split(new char[] {':'}, 2);
                GrpcPreconditions.CheckArgument(parts.Length == 2, "Malformed test_cases option.");
                result.Add(parts[0], int.Parse(parts[1]));
            }
            return result;
        }

        class WeightedRandomGenerator
        {
            readonly Random random = new Random();
            readonly List<Tuple<int, string>> cumulativeSums;
            readonly int weightSum;

            public WeightedRandomGenerator(Dictionary<string, int> weightedItems)
            {
                cumulativeSums = new List<Tuple<int, string>>();
                weightSum = 0;
                foreach (var entry in weightedItems)
                {
                    weightSum += entry.Value;
                    cumulativeSums.Add(Tuple.Create(weightSum, entry.Key));
                }
            }

            public string GetNext()
            {
                int rand = random.Next(weightSum);
                foreach (var entry in cumulativeSums)
                {
                    if (rand < entry.Item1)
                    {
                        return entry.Item2;
                    }
                }
                throw new InvalidOperationException("GetNext() failed.");
            }
        }

        class MetricsServiceImpl : MetricsService.MetricsServiceBase 
        {
            const string GaugeName = "csharp_overall_qps";

            readonly Histogram histogram;
            readonly TimeStats timeStats = new TimeStats();

            public MetricsServiceImpl(Histogram histogram)
            {
                this.histogram = histogram;
            }

            public override Task<GaugeResponse> GetGauge(GaugeRequest request, ServerCallContext context)
            {
                if (request.Name == GaugeName)
                {
                    long qps = GetQpsAndReset();

                    return Task.FromResult(new GaugeResponse
                    {
                        Name = GaugeName,
                        LongValue = qps
                    });
                }
                throw new RpcException(new Status(StatusCode.InvalidArgument, "Gauge does not exist"));
            }

            public override async Task GetAllGauges(EmptyMessage request, IServerStreamWriter<GaugeResponse> responseStream, ServerCallContext context)
            {
                long qps = GetQpsAndReset();

                var response = new GaugeResponse
                {
                    Name = GaugeName,
                    LongValue = qps
                };
                await responseStream.WriteAsync(response);
            }

            long GetQpsAndReset()
            {
                var snapshot = histogram.GetSnapshot(true);
                var timeSnapshot = timeStats.GetSnapshot(true);

                return (long) (snapshot.Count / timeSnapshot.WallClockTime.TotalSeconds);
            }
        }
    }
}