#include "tcpstream.hpp"

#include <stdexcept>

#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <poll.h>

using namespace netlib;

// Creates a stream bound to the given remote address without opening a
// socket. sockfd starts at 0, the value this file uses to mean "closed";
// connect() has to be called before any data can be transferred.
TcpStream::TcpStream(SockAddr _remote)
    : remote(_remote),
      sockfd(0)
{
}
    
// Convenience constructor: builds a SockAddr from an IP address and a port
// and delegates to TcpStream(SockAddr). No socket is opened here.
TcpStream::TcpStream(IpAddr remoteAddress, uint16_t port)
    : TcpStream{SockAddr{remoteAddress, port}}
{ }

// Convenience constructor: builds a SockAddr from a textual address and a
// port and delegates to TcpStream(SockAddr). How the string is interpreted is
// decided by SockAddr (not visible in this file). No socket is opened here.
TcpStream::TcpStream(const std::string &remoteAddress, uint16_t port)
    : TcpStream{SockAddr{remoteAddress, port}}
{ }

// Convenience constructor: builds a SockAddr from a single string and
// delegates to TcpStream(SockAddr). Presumably the string carries both the
// address and the port; the exact format is defined by SockAddr -- TODO
// confirm against its declaration. No socket is opened here.
TcpStream::TcpStream(const std::string &remoteAddressPort)
    : TcpStream{SockAddr{remoteAddressPort}}
{ }

// Releases the socket (if one is open) when the stream goes out of scope.
TcpStream::~TcpStream()
{
    // Qualified call: always this class's close(), stated explicitly.
    TcpStream::close();
}

// Opens a TCP connection to the stored remote address.
//
// Throws std::runtime_error when
//   - the stream already holds an open socket,
//   - the remote address is neither IPv4 nor IPv6,
//   - the socket cannot be created, or
//   - the connection attempt fails.
// On every failure path sockfd stays at 0 ("closed"), so isClosed() remains
// accurate and connect() may be called again.
void TcpStream::connect()
{
    if (sockfd != 0)
        throw std::runtime_error("Can't call connect on open socket");

    int af;
    socklen_t sock_len;

    // Pick the address family and the matching sockaddr size.
    if (remote.address.type == IpAddr::Type::V4)
    {
        af = AF_INET;
        sock_len = sizeof(sockaddr_in);
    }
    else if (remote.address.type == IpAddr::Type::V6)
    {
        af = AF_INET6;
        sock_len = sizeof(sockaddr_in6);
    }
    else
    {
        throw std::runtime_error("Can't connect to IpAddr::Type::Undef");
    }

    // Keep the descriptor in a local until the connection is established, so
    // a failure can never leave a stale value (e.g. -1) in sockfd.
    const int fd = socket(af, SOCK_STREAM, 0);

    if (fd <= 0)
    {
        // Descriptor 0 is valid for the OS but collides with the "closed"
        // sentinel used by this class; release it instead of leaking it.
        if (fd == 0)
            ::close(fd);

        throw std::runtime_error("Creating TCP Socket failed");
    }

    if (::connect(fd, &remote.raw_sockaddr.generic, sock_len) != 0)
    {
        ::close(fd);
        throw std::runtime_error("Connecting TCP Socket failed");
    }

    // Publish the descriptor only once the stream is fully usable.
    sockfd = fd;
}

// Closes the underlying socket, if any, and marks the stream as closed.
// Calling it on an already closed stream does nothing.
void TcpStream::close()
{
    // 0 is the "no socket" marker, so there is nothing to release.
    if (sockfd == 0)
        return;

    ::close(sockfd);
    sockfd = 0;
}

// Writes up to `len` bytes from `data` to the socket with a single write().
//
// Returns the number of bytes actually written, which may be smaller than
// `len`. Throws std::runtime_error if the stream is closed or if the write
// fails; in the latter case the socket is closed before throwing.
ssize_t TcpStream::send(const void *data, size_t len)
{
    if (sockfd == 0)
    {
        throw std::runtime_error("Can't write to closed socket");
    }

    const ssize_t written = ::write(sockfd, data, len);
    if (written >= 0)
    {
        return written;
    }

    // The write failed: drop the socket, then report the error.
    close();
    throw std::runtime_error("Error while writing to socket");
}

// Writes exactly `len` bytes from `data` to the socket, calling write() as
// many times as needed.
//
// Throws std::runtime_error if the stream is closed or if any write fails;
// on a write failure the socket is closed before throwing.
void TcpStream::sendAll(const void *data, size_t len)
{
    if (sockfd == 0)
    {
        throw std::runtime_error("Can't write to closed socket");
    }

    // Walk through the buffer: `cursor` is the next byte to send and
    // `remaining` is how many bytes are still outstanding.
    const uint8_t *cursor = static_cast<const uint8_t*>(data);
    size_t remaining = len;

    while (remaining > 0)
    {
        const ssize_t written = ::write(sockfd, cursor, remaining);

        if (written < 0)
        {
            close();
            throw std::runtime_error("Error while writing to socket");
        }

        cursor += written;
        remaining -= static_cast<size_t>(written);
    }
}

// Sends the complete contents of `str` (without a terminating NUL) through
// sendAll(); errors are reported by sendAll().
void TcpStream::sendAllString(const std::string &str)
{
    const char *const bytes = str.data();
    const size_t count = str.length();

    sendAll(bytes, count);
}

// Reads up to `len` bytes into `data` with a single read() call.
//
// Returns the number of bytes stored; 0 is whatever read() reports as 0
// (end of stream). Throws std::runtime_error if the stream is closed or if
// the read fails; in the latter case the socket is closed before throwing.
ssize_t TcpStream::read(void *data, size_t len)
{
    if (sockfd == 0)
    {
        throw std::runtime_error("Can't read from closed socket");
    }

    const ssize_t received = ::read(sockfd, data, len);
    if (received >= 0)
    {
        return received;
    }

    // The read failed: drop the socket, then report the error.
    close();
    throw std::runtime_error("Error while reading from socket");
}

// Reads until `len` bytes have been stored in `data` or read() reports end
// of stream, whichever comes first.
//
// Returns the number of bytes stored; this is smaller than `len` only when
// the stream ended early. A `len` of 0 returns 0 without touching the socket.
// Throws std::runtime_error if the stream is closed or if a read fails; on a
// read failure the socket is closed before throwing.
ssize_t TcpStream::readAll(void *data, size_t len)
{
    if (sockfd == 0)
        throw std::runtime_error("Can't read from closed socket");

    uint8_t *const buffer = static_cast<uint8_t*>(data);
    size_t bytesReadTotal = 0;

    // Stop as soon as the buffer is full. Relying on a trailing zero-length
    // read() to end the loop costs an extra system call and, since a
    // zero-length read is allowed to report errors, could close the socket
    // after the data was already received completely.
    while (bytesReadTotal < len)
    {
        const ssize_t bytesRead = ::read(sockfd, buffer + bytesReadTotal, len - bytesReadTotal);

        // End of stream: return what has been collected so far.
        if (bytesRead == 0) break;

        if (bytesRead < 0)
        {
            close();
            throw std::runtime_error("Error while reading from socket");
        }

        bytesReadTotal += static_cast<size_t>(bytesRead);
    }

    return static_cast<ssize_t>(bytesReadTotal);
}

// Reads up to `len` bytes into `data`, waiting at most `timeoutMs`
// milliseconds for data to become available.
//
// A timeout of 0 or less falls back to the plain blocking read().
// Returns the number of bytes stored, or 0 if the timeout expired first
// (note that read() reporting end of stream also yields 0).
// Throws std::runtime_error if the stream is closed, or if poll() or read()
// fails; on such a failure the socket is closed before throwing.
ssize_t TcpStream::readTimeout(void *data, size_t len, int timeoutMs)
{
    if (timeoutMs <= 0)
    {
        return read(data, len);
    }

    if (sockfd == 0)
    {
        throw std::runtime_error("Can't read from closed socket");
    }

    pollfd waitFor{};
    waitFor.fd = sockfd;
    waitFor.events = POLLIN;

    // Block until data is available or the timeout is reached.
    const int ready = poll(&waitFor, 1, timeoutMs);

    // A poll error occurred.
    if (ready < 0)
    {
        close();
        throw std::runtime_error("Error while reading from socket");
    }

    // A timeout occurred.
    if (ready == 0)
    {
        return 0;
    }

    const ssize_t received = ::read(sockfd, data, len);
    if (received >= 0)
    {
        return received;
    }

    close();
    throw std::runtime_error("Error while reading from socket");
}

// Reads until `len` bytes have been stored in `data`, read() reports end of
// stream, or no data arrives for `timeoutMs` milliseconds.
//
// The timeout applies to each individual wait for data, not to the call as a
// whole. A timeout of 0 or less falls back to the blocking readAll().
//
// Returns
//   - the number of bytes stored when the buffer was filled or the stream
//     ended, or
//   - the negated number of bytes stored so far when a wait timed out (this
//     is 0 if nothing had been received yet).
// A `len` of 0 returns 0 without touching the socket.
// Throws std::runtime_error if the stream is closed, or if poll() or read()
// fails; on such a failure the socket is closed before throwing.
ssize_t TcpStream::readAllTimeout(void *data, size_t len, int timeoutMs)
{
    if (timeoutMs <= 0) return readAll(data, len);

    if (sockfd == 0)
        throw std::runtime_error("Can't read from closed socket");

    pollfd pfd = {0};
    pfd.fd = sockfd;
    pfd.events = POLLIN;

    uint8_t *const buffer = static_cast<uint8_t*>(data);
    size_t bytesReadTotal = 0;

    // Stop as soon as the buffer is full. Polling once more after the last
    // byte would block for up to timeoutMs and then report a timeout even
    // though the requested amount of data was received completely.
    while (bytesReadTotal < len)
    {
        // Block until data is available or the timeout is reached.
        const int res = poll(&pfd, 1, timeoutMs);

        // A timeout occurred: signal it with the negated byte count.
        if (res == 0) return -static_cast<ssize_t>(bytesReadTotal);

        // A poll error occurred.
        if (res < 0)
        {
            close();
            throw std::runtime_error("Error while reading from socket");
        }

        const ssize_t bytesRead = ::read(sockfd, buffer + bytesReadTotal, len - bytesReadTotal);

        // End of stream: return what has been collected so far.
        if (bytesRead == 0) break;

        if (bytesRead < 0)
        {
            close();
            throw std::runtime_error("Error while reading from socket");
        }

        bytesReadTotal += static_cast<size_t>(bytesRead);
    }

    return static_cast<ssize_t>(bytesReadTotal);
}

// Returns a reference to the remote address this stream was created with.
// The reference points at a member, so it is only valid while this
// TcpStream object exists.
const SockAddr & TcpStream::getRemoteAddr() const
{
    return this->remote;
}

// Reports whether the stream currently has no socket, i.e. whether sockfd
// holds 0, the value this file uses to mean "closed".
bool TcpStream::isClosed() const
{
    const bool hasNoSocket = (sockfd == 0);
    return hasNoSocket;
}