/* vim: set sw=2 ts=2 expandtab:
 *
 * Copyright (C) 2010 by Multi-Tech Systems
 *
 * Author: James Maki <jmaki@multitech.com>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 *
 */

#ifndef CONTAINER_BUFFER_H_
#define CONTAINER_BUFFER_H_

#include <pthread.h>
#include <list>
#include <sys/time.h>

#define SSL_SUPPORT 1

#if SSL_SUPPORT
#include <openssl/rand.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#endif

#include <poll.h>

#include <annex_common.h>
#include <annex.pb.h>

namespace mts {

/**
 * Frames, sends and receives a single annex::Container over a socket or an
 * SSL connection.
 *
 * Wire format (see SendStart/RecvStep): a 4-byte length in network byte
 * order (htonl/ntohl) followed by that many bytes of serialized Container.
 *
 * Call pattern implied by the return values (confirm against callers):
 *   - send:    SendStart(), then SendStep() until it returns something other
 *              than BUFFER_EVENT_STEP.
 *   - receive: RecvStart(), then RecvStep() until it returns something other
 *              than BUFFER_EVENT_STEP; on BUFFER_EVENT_COMPLETE call
 *              ParseContainer().
 *
 * In non-blocking mode poll() is called with a zero timeout, so a Step call
 * that cannot make progress returns BUFFER_EVENT_STEP right away. The timeout
 * is an inactivity timeout: the deadline is re-armed after every successful
 * read or write.
 *
 * The object does not own `sd` or `ssl` (the destructor neither closes nor
 * frees them). There is no internal locking and a single buffer is reused,
 * so only one transfer may be in progress per object.
 */
class ContainerBuffer {
 public:
  // Largest serialized Container that will be sent or accepted, in bytes
  // (1 << 15 == 32768; an earlier revision used 1 << 20).
  static const int kSerializedMax = (1 << 15);

  // Results of the Start/Step calls.
  enum BufferEvent {
    // RecvStart() in non-blocking mode: no data has arrived yet.
    BUFFER_EVENT_START = 1,
    // The whole frame has been sent or received.
    BUFFER_EVENT_COMPLETE = 2,
    // Blocking mode: one poll() waited timeout_msec without activity.
    // Non-blocking mode: the inactivity deadline has passed.
    BUFFER_EVENT_TIMEOUT = 3,
    // Not finished yet; call the matching Step function again.
    BUFFER_EVENT_STEP = 4,
    // Serialization, framing, poll() or I/O failure; details are logged.
    BUFFER_EVENT_SYS_ERROR = 5,
    // Internal result of Poll(): the descriptor is ready for the requested I/O.
    BUFFER_EVENT_READY = 6,
  };

  // Puts every member in a defined state, so a premature Step or
  // ParseContainer call does not read indeterminate values. buffer_ itself is
  // left uninitialized: the length fields bound which part of it is used.
  ContainerBuffer()
      : sd_(-1),
        ssl_(NULL),
        blocking_(false),
        buffer_len_(0),
        buffer_offset_(0),
        container_len_(0),
        nlen_(0),
        timeout_msec_(0),
        abstimeout_() {
  }

  // Does not close sd_ or free ssl_; both remain owned by the caller.
  ~ContainerBuffer() {
  }

  /**
   * Serializes `container` into the internal buffer behind its 4-byte
   * network-order length prefix and arms the inactivity deadline. Nothing
   * is written to the connection here; call SendStep() to transmit.
   *
   * @param sd            descriptor used for poll() and, when ssl is NULL,
   *                      for write().
   * @param ssl           connection to write through, or NULL for plain I/O.
   * @param timeout_msec  inactivity timeout in milliseconds.
   * @param container     message to send; must serialize to at most
   *                      kSerializedMax bytes.
   * @param blocking      true: poll() waits up to timeout_msec;
   *                      false: poll() returns immediately.
   * @return BUFFER_EVENT_STEP when ready to send, BUFFER_EVENT_SYS_ERROR if
   *         serialization fails (e.g. the message does not fit the buffer).
   */
#if SSL_SUPPORT
  BufferEvent SendStart(int sd, SSL *ssl, int timeout_msec, annex::Container *container, bool blocking) {
#else
  BufferEvent SendStart(int sd, char *ssl, int timeout_msec, annex::Container *container, bool blocking) {
#endif
    bool rc = container->SerializeToArray(
        buffer_ + sizeof(nlen_),
        sizeof(buffer_) - sizeof(nlen_)
    );
    if (!rc) {
      log_err("SerializeToArray failed");
      return BUFFER_EVENT_SYS_ERROR;
    }

    // SerializeToArray() computed and cached the byte size, so this is the
    // exact number of payload bytes just written.
    container_len_ = container->GetCachedSize();

    sd_ = sd;
    ssl_ = ssl;
    timeout_msec_ = timeout_msec;
    blocking_ = blocking;

    // Length prefix goes in front of the payload in network byte order.
    nlen_ = htonl(container_len_);
    memcpy(buffer_, &nlen_, sizeof(nlen_));

    buffer_len_ = container_len_ + sizeof(nlen_);
    buffer_offset_ = 0;
    abstimeout_tv_msec(&abstimeout_, timeout_msec_);

    return BUFFER_EVENT_STEP;
  }

  /**
   * Prepares to receive one frame on `sd`/`ssl` and arms the inactivity
   * deadline. Parameters have the same meaning as for SendStart().
   *
   * Blocking mode: returns BUFFER_EVENT_STEP without reading; call RecvStep().
   * Non-blocking mode: tries to read the first byte of the length prefix.
   *   - nothing available yet, or the deadline already passed: returns
   *     BUFFER_EVENT_START, so waiting for a message to begin is never
   *     reported as a timeout from here;
   *   - one byte read: returns BUFFER_EVENT_STEP; continue with RecvStep();
   *   - BUFFER_EVENT_SYS_ERROR is returned unchanged.
   */
#if SSL_SUPPORT
  BufferEvent RecvStart(int sd, SSL *ssl, int timeout_msec, bool blocking) {
#else
  BufferEvent RecvStart(int sd, char *ssl, int timeout_msec, bool blocking) {
#endif
    sd_ = sd;
    ssl_ = ssl;
    timeout_msec_ = timeout_msec;
    // Only the length prefix is expected until RecvStep() has parsed it.
    buffer_len_ = sizeof(nlen_);
    buffer_offset_ = 0;
    container_len_ = 0;
    blocking_ = blocking;
    abstimeout_tv_msec(&abstimeout_, timeout_msec_);

    if (blocking_) {
      return BUFFER_EVENT_STEP;
    }

    BufferEvent event = RecvBytes(1);
    switch (event) {
    case BUFFER_EVENT_STEP:
    case BUFFER_EVENT_TIMEOUT:
      return BUFFER_EVENT_START;
    case BUFFER_EVENT_COMPLETE:
      return BUFFER_EVENT_STEP;
    default:
      return event;
    }
  }

  /**
   * Writes as much of the pending frame as the connection accepts.
   *
   * @return BUFFER_EVENT_COMPLETE once every byte is written,
   *         BUFFER_EVENT_STEP (non-blocking) when the socket is not ready yet,
   *         BUFFER_EVENT_TIMEOUT, or BUFFER_EVENT_SYS_ERROR.
   */
  BufferEvent SendStep() {
    ssize_t cc;
    size_t len = buffer_len_ - buffer_offset_;
    BufferEvent event;

    while (len > 0) {
      if (ssl_) {
#if SSL_SUPPORT
        // The SSL path writes first and only polls when OpenSSL asks for it.
        // Retries pass the same pointer and length, as OpenSSL requires.
        cc = SSL_write(ssl_, buffer_ + buffer_offset_, static_cast<int>(len));
        if (cc < 0) {
          switch (SSL_get_error(ssl_, static_cast<int>(cc))) {
          case SSL_ERROR_WANT_READ:
            event = Poll(POLLIN);
            if (event != BUFFER_EVENT_READY) {
              return event;
            }
            continue;
          case SSL_ERROR_WANT_WRITE:
            event = Poll(POLLOUT);
            if (event != BUFFER_EVENT_READY) {
              return event;
            }
            continue;
          default:
            ssl_log_ssl_error("SSL_write failed", ssl_, cc);
            return BUFFER_EVENT_SYS_ERROR;
          }
        } else if (cc == 0) {
          log_err("connection closed unexpectedly");
          return BUFFER_EVENT_SYS_ERROR;
        }

        abstimeout_tv_msec(&abstimeout_, timeout_msec_);
#else
        // Without SSL support `cc` would never be assigned on this path.
        log_err("SSL support not compiled in");
        return BUFFER_EVENT_SYS_ERROR;
#endif
      } else {
        event = Poll(POLLOUT);
        if (event != BUFFER_EVENT_READY) {
          return event;
        }

        cc = write(sd_, buffer_ + buffer_offset_, len);
        if (cc <= 0) {
          log_debug("write failed: %m");
          return BUFFER_EVENT_SYS_ERROR;
        }

        // Progress was made: re-arm the inactivity deadline.
        abstimeout_tv_msec(&abstimeout_, timeout_msec_);
      }

      buffer_offset_ += cc;
      len -= cc;
    }

    return BUFFER_EVENT_COMPLETE;
  }

  /**
   * Reads the remaining bytes of the current frame. When the 4-byte length
   * prefix completes, it is validated (1..kSerializedMax) and the expected
   * frame length is extended by the payload size before reading on.
   *
   * @return BUFFER_EVENT_COMPLETE once the whole frame is buffered,
   *         BUFFER_EVENT_STEP (non-blocking) when no data is available yet,
   *         BUFFER_EVENT_TIMEOUT, or BUFFER_EVENT_SYS_ERROR (I/O error or
   *         invalid length prefix).
   */
  BufferEvent RecvStep() {
    for (;;) {
      BufferEvent event = RecvBytes(buffer_len_ - buffer_offset_);

      // Anything other than "the length prefix just completed" is final for
      // this call: partial progress, errors, or the full frame.
      if (event != BUFFER_EVENT_COMPLETE || buffer_len_ != sizeof(nlen_)) {
        return event;
      }

      memcpy(&nlen_, buffer_, sizeof(nlen_));
      // Validate as unsigned before narrowing to int, so that huge lengths
      // are reported as too large rather than as negative values.
      uint32_t wire_len = ntohl(nlen_);
      if (wire_len > static_cast<uint32_t>(kSerializedMax)) {
        log_err("container is too large: %u", wire_len);
        return BUFFER_EVENT_SYS_ERROR;
      }
      if (wire_len == 0) {
        log_err("container is too small: %u", wire_len);
        return BUFFER_EVENT_SYS_ERROR;
      }

      container_len_ = static_cast<int>(wire_len);
      buffer_len_ += container_len_;
    }
  }

  /**
   * Clears `container` and parses the payload of the most recently
   * buffered frame into it.
   *
   * @return true on success; false (logged) if protobuf parsing fails.
   */
  bool ParseContainer(annex::Container *container) {
    container->Clear();

    bool rc = container->ParseFromArray(buffer_ + sizeof(nlen_), container_len_);
    if (!rc) {
      log_err("ParseFromArray failed");
      return false;
    }

    return true;
  }

 private:
  // True once the inactivity deadline has passed. If gettimeofday() fails
  // the error is logged and the deadline is treated as not yet reached.
  bool AbsTimeout() {
    struct timeval now;

    if (gettimeofday(&now, NULL) < 0) {
      log_err("gettimeofday failed: %m");
      return false;
    }

    return timercmp(&abstimeout_, &now, <=);
  }

  // Waits for `event` (POLLIN or POLLOUT) on sd_. Returns:
  //   READY     - the requested event is set;
  //   STEP      - non-blocking, not ready yet, deadline not passed;
  //   TIMEOUT   - blocking poll() expired, or non-blocking deadline passed;
  //   SYS_ERROR - poll() failed, or only other revents (e.g. POLLERR,
  //               POLLHUP) were reported.
  BufferEvent Poll(int event) {
    int tmp;
    struct pollfd pfd;
    pfd.fd = sd_;
    pfd.events = event;
    pfd.revents = 0;

    // Blocking mode waits up to timeout_msec_ per call; non-blocking mode
    // only checks the current state.
    if (blocking_) {
      tmp = poll(&pfd, 1, timeout_msec_);
    } else {
      tmp = poll(&pfd, 1, 0);
    }
    if (tmp < 0) {
      log_err("poll failed: %m");
      return BUFFER_EVENT_SYS_ERROR;
    } else if (tmp == 0) {
      if (!blocking_) {
        if (AbsTimeout()) {
          return BUFFER_EVENT_TIMEOUT;
        }
        return BUFFER_EVENT_STEP;
      } else {
        log_debug("poll timeout");
        return BUFFER_EVENT_TIMEOUT;
      }
    }
    if (!(pfd.revents & event)) {
      log_debug("poll revents error %08X", pfd.revents);
      return BUFFER_EVENT_SYS_ERROR;
    }

    return BUFFER_EVENT_READY;
  }

  // Reads up to `len` more bytes into buffer_ at buffer_offset_, advancing
  // the offset. Returns COMPLETE when all `len` bytes have arrived;
  // otherwise the STEP/TIMEOUT/SYS_ERROR result that stopped it.
  BufferEvent RecvBytes(size_t len) {
    ssize_t cc;
    BufferEvent event;

    while (len > 0) {
      if (ssl_) {
#if SSL_SUPPORT
        cc = SSL_read(ssl_, buffer_ + buffer_offset_, static_cast<int>(len));
        if (cc < 0) {
          switch (SSL_get_error(ssl_, static_cast<int>(cc))) {
          case SSL_ERROR_WANT_READ:
            event = Poll(POLLIN);
            if (event != BUFFER_EVENT_READY) {
              return event;
            }
            continue;
          case SSL_ERROR_WANT_WRITE:
            event = Poll(POLLOUT);
            if (event != BUFFER_EVENT_READY) {
              return event;
            }
            continue;
          default:
            ssl_log_ssl_error("SSL_read failed", ssl_, cc);
            return BUFFER_EVENT_SYS_ERROR;
          }
        } else if (cc == 0) {
          log_err("connection closed unexpectedly");
          return BUFFER_EVENT_SYS_ERROR;
        }

        abstimeout_tv_msec(&abstimeout_, timeout_msec_);
#else
        // Without SSL support `cc` would never be assigned on this path.
        log_err("SSL support not compiled in");
        return BUFFER_EVENT_SYS_ERROR;
#endif
      } else {
        event = Poll(POLLIN);
        if (event != BUFFER_EVENT_READY) {
          return event;
        }

        cc = read(sd_, buffer_ + buffer_offset_, len);
        if (cc == 0) {
          // End of stream: errno is not meaningful here, so don't print %m.
          log_debug("read failed: connection closed");
          return BUFFER_EVENT_SYS_ERROR;
        }
        if (cc < 0) {
          log_debug("read failed: %m");
          return BUFFER_EVENT_SYS_ERROR;
        }

        // Progress was made: re-arm the inactivity deadline.
        abstimeout_tv_msec(&abstimeout_, timeout_msec_);
      }

      buffer_offset_ += cc;
      len -= cc;
    }

    return BUFFER_EVENT_COMPLETE;
  }

  int sd_;             // descriptor polled (and read/written when ssl_ is NULL)
#if SSL_SUPPORT
  SSL *ssl_;           // non-NULL: I/O goes through OpenSSL; not owned
#else
  char *ssl_;          // stand-in type so signatures match without SSL
#endif
  bool blocking_;      // true: poll() waits timeout_msec_; false: zero timeout

  // Frame storage: 4-byte network-order length followed by the payload.
  char buffer_[sizeof(uint32_t) + kSerializedMax];
  size_t buffer_len_;     // bytes of buffer_ making up the current frame
                          // (while receiving: only the prefix until parsed)
  size_t buffer_offset_;  // bytes of the frame transferred so far
  int container_len_;     // payload length in bytes
  uint32_t nlen_;         // length prefix in network byte order

  int timeout_msec_;           // inactivity timeout in milliseconds
  struct timeval abstimeout_;  // deadline checked by AbsTimeout()
};

}
#endif  // CONTAINER_BUFFER_H_
