package websocket.test;

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.Reader;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletRequestWrapper;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;

import com.caucho.websocket.WebSocketContext;
import com.caucho.websocket.WebSocketServletRequest;
import com.lmax.commons.threading.DaemonThreadFactory;

import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Test servlet that accepts requests carrying the {@code websocket-test} sub-protocol and starts a web socket
 * running a simple echo-sequence protocol.
 *
 * <p>The server pushes an increasing counter as a text message. When the client echoes the current counter
 * value back, the next value is queued one second later. Requests with any other sub-protocol get a 400.
 */
public class TestWebsocketServlet extends HttpServlet
{
    private static final Logger LOGGER = LoggerFactory.getLogger(TestWebsocketServlet.class);

    private static final String WEB_SOCKET_PROTOCOL = "Sec-WebSocket-Protocol";
    private static final String SUB_PROTOCOL = "websocket-test";

    /**
     * Shared by all connections. It runs each connection's publisher loop and the delayed "queue next message"
     * tasks.
     *
     * <p>NOTE(review): each open connection occupies one pool thread for its whole lifetime (its publisher blocks
     * on the queue). With 16 threads, many concurrent connections could starve the delayed tasks. Confirm this is
     * acceptable for the test load.
     */
    private final ScheduledExecutorService scheduledExecutorService = Executors.newScheduledThreadPool(16, DaemonThreadFactory.INSTANCE);

    /**
     * Shuts down the shared executor when the servlet is taken out of service, so its threads do not outlive
     * the servlet (for example across redeploys).
     */
    @Override
    public void destroy()
    {
        scheduledExecutorService.shutdownNow();
        super.destroy();
    }

    /**
     * Validates the requested sub-protocol, echoes it in the response and starts the web socket.
     *
     * @throws ServletException if the request (after unwrapping any wrappers) does not support web sockets
     */
    @Override
    protected void service(final HttpServletRequest req, final HttpServletResponse resp) throws ServletException, IOException
    {
        final String protocol = req.getHeader(WEB_SOCKET_PROTOCOL);
        if (!SUB_PROTOCOL.equals(protocol))
        {
            resp.setStatus(HttpServletResponse.SC_BAD_REQUEST);
            return;
        }

        resp.setHeader(WEB_SOCKET_PROTOCOL, SUB_PROTOCOL);

        final WebSocketServletRequest wsReq = unwrapWebSocketRequest(req);
        wsReq.startWebSocket(new TestWebSocketPushListener());
    }

    /**
     * Finds the {@link WebSocketServletRequest} behind {@code req}. It follows any number of
     * {@link ServletRequestWrapper} layers, because filters may wrap the request more than once.
     *
     * @throws ServletException if no request in the wrapper chain is a {@link WebSocketServletRequest}
     */
    private static WebSocketServletRequest unwrapWebSocketRequest(final ServletRequest req) throws ServletException
    {
        ServletRequest current = req;
        while (!(current instanceof WebSocketServletRequest) && current instanceof ServletRequestWrapper)
        {
            current = ((ServletRequestWrapper)current).getRequest();
        }

        if (current instanceof WebSocketServletRequest)
        {
            return (WebSocketServletRequest)current;
        }
        throw new ServletException("Invalid servlet request, needs to support web sockets");
    }


    /**
     * Per-connection listener. It sends the counter to the client and queues the next value once the client
     * echoes the current one.
     */
    public class TestWebSocketPushListener implements com.caucho.websocket.WebSocketListener
    {
        /** Last sequence number queued for sending; the client is expected to echo this value back. */
        private final AtomicInteger counter = new AtomicInteger();
        /** Messages waiting to be written by the publisher loop. */
        private final BlockingQueue<String> queue = new ArrayBlockingQueue<>(4);

        private volatile boolean running = true;
        /** Handle to this connection's publisher task, used to cancel it; null until {@link #onStart} submits it. */
        private volatile Future<?> publisherFuture;

        private volatile WebSocketContext context;

        @Override
        public void onStart(final WebSocketContext context)
        {
            this.context = context;

            queueNextMessage();

            final Future<?> future = scheduledExecutorService.submit(this::runPublisher);
            publisherFuture = future;
            // stopPublisher() writes `running` and then reads `publisherFuture`; here we write the future and then
            // read `running`. Both fields are volatile, so at least one side sees the other's write and the
            // publisher cannot be left blocked on the queue forever.
            if (!running)
            {
                future.cancel(true);
            }
        }

        /**
         * Parses the client's echo as an integer. On a match with the current counter, the next message is
         * queued after one second; otherwise the mismatch is logged.
         */
        @Override
        public void onReadText(final WebSocketContext context, final Reader is)
        {
            try
            {
                final int clientCounter = Integer.parseInt(IOUtils.toString(is));
                final int serverCounter = counter.get();
                if (clientCounter == serverCounter)
                {
                    scheduledExecutorService.schedule(this::queueNextMessage, 1, TimeUnit.SECONDS);
                }
                else
                {
                    LOGGER.error("Client did not echo expected sequence: expected={} got={}", serverCounter, clientCounter);
                }
            }
            catch (final IOException | NumberFormatException e)
            {
                LOGGER.warn("Problem reading client message.", e);
            }
        }

        @Override
        public void onReadBinary(final WebSocketContext context, final InputStream is)
        {
            // Binary frames are not part of this test protocol and are ignored.
        }

        @Override
        public void onClose(final WebSocketContext context)
        {
            stopPublisher();
        }

        @Override
        public void onDisconnect(final WebSocketContext context)
        {
            stopPublisher();
        }

        @Override
        public void onTimeout(final WebSocketContext context)
        {
            stopPublisher();
        }

        /** Increments the counter and enqueues its new value; the value is dropped if the queue is full. */
        private void queueNextMessage()
        {
            queue.offer(String.valueOf(counter.incrementAndGet()));
        }

        /**
         * Stops the publisher loop. It is safe to call more than once, and before the publisher has been
         * submitted or started.
         */
        private void stopPublisher()
        {
            running = false;
            final Future<?> future = publisherFuture;
            if (future != null)
            {
                // cancel(true) interrupts the worker only while it is still running this task. That wakes it from
                // queue.take() without ever touching an unrelated task that later reuses the same pool thread.
                future.cancel(true);
            }
        }

        /**
         * Publisher loop: writes queued messages to the client until stopped. An interrupt (from
         * {@link #stopPublisher()} or executor shutdown) ends the loop. A send failure is logged and also ends it.
         */
        public void runPublisher()
        {
            try
            {
                while (running)
                {
                    sendOneMessage();
                }
            }
            catch (final InterruptedException e)
            {
                // Interruption is the stop signal; restore the flag rather than swallowing it.
                Thread.currentThread().interrupt();
            }
            catch (final IOException e)
            {
                // The executor would capture a thrown exception in the task's Future and nobody reads it,
                // so report the failure here instead.
                running = false;
                LOGGER.warn("Failed to send message to client; stopping publisher.", e);
            }
        }

        /** Blocks until a message is queued, then writes it as a single text message. */
        void sendOneMessage() throws InterruptedException, IOException
        {
            final String msg = queue.take();

            try (final PrintWriter out = context.startTextMessage())
            {
                out.print(msg);
            }
        }
    }
}
