socket_handler.cpp 12 KB
Newer Older
1 2
#include <network/socket_handler.hpp>

3
#include <utils/timed_events.hpp>
4 5 6
#include <utils/scopeguard.hpp>
#include <network/poller.hpp>

louiz’'s avatar
louiz’ committed
7
#include <logger/logger.hpp>
8
#include <sys/socket.h>
louiz’'s avatar
louiz’ committed
9
#include <sys/types.h>
louiz’'s avatar
louiz’ committed
10
#include <stdexcept>
louiz’'s avatar
louiz’ committed
11
#include <unistd.h>
12
#include <stdlib.h>
louiz’'s avatar
louiz’ committed
13
#include <errno.h>
louiz’'s avatar
louiz’ committed
14
#include <netdb.h>
15
#include <cstring>
16
#include <fcntl.h>
louiz’'s avatar
louiz’ committed
17
#include <stdio.h>
18 19 20

#include <iostream>

louiz’'s avatar
louiz’ committed
21 22
#ifdef BOTAN_FOUND
# include <botan/hex.h>
23 24 25 26 27 28

// TLS objects shared by every SocketHandler instance (static members):
// start_tls() hands the RNG, credential manager, policy and in-memory
// session manager to each Botan::TLS::Client it creates.
Botan::AutoSeeded_RNG SocketHandler::rng;
Permissive_Credentials_Manager SocketHandler::credential_manager;
Botan::TLS::Policy SocketHandler::policy;
Botan::TLS::Session_Manager_In_Memory SocketHandler::session_manager(SocketHandler::rng);

louiz’'s avatar
louiz’ committed
29
#endif
louiz’'s avatar
louiz’ committed
30

31 32 33 34
// Maximum number of iovec entries on_send() gathers into a single
// sendmsg() call. Fall back to 8 when the system headers do not define it.
#ifndef UIO_FASTIOV
# define UIO_FASTIOV 8
#endif

louiz’'s avatar
louiz’ committed
35
using namespace std::string_literals;
36
using namespace std::chrono_literals;
louiz’'s avatar
louiz’ committed
37 38 39

namespace ph = std::placeholders;

40
SocketHandler::SocketHandler(std::shared_ptr<Poller> poller):
  socket(-1),
  // The shared_ptr parameter is already our own copy, so move it into the
  // member instead of paying for a second (atomic) reference-count bump.
  poller(std::move(poller)),
  use_tls(false),
  connected(false),
  connecting(false)
{}
47

louiz’'s avatar
louiz’ committed
48
void SocketHandler::init_socket(const struct addrinfo* rp)
49
{
louiz’'s avatar
louiz’ committed
50
  if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)
51
    throw std::runtime_error("Could not create socket: "s + strerror(errno));
louiz’'s avatar
louiz’ committed
52 53 54
  int optval = 1;
  if (::setsockopt(this->socket, SOL_SOCKET, SO_KEEPALIVE, &optval, sizeof(optval)) == -1)
    log_warning("Failed to enable TCP keepalive on socket: " << strerror(errno));
55 56 57 58 59 60
  // Set the socket on non-blocking mode.  This is useful to receive a EAGAIN
  // error when connect() would block, to not block the whole process if a
  // remote is not responsive.
  const int existing_flags = ::fcntl(this->socket, F_GETFL, 0);
  if ((existing_flags == -1) ||
      (::fcntl(this->socket, F_SETFL, existing_flags | O_NONBLOCK) == -1))
louiz’'s avatar
louiz’ committed
61
    throw std::runtime_error("Could not initialize socket: "s + strerror(errno));
62 63
}

louiz’'s avatar
louiz’ committed
64
/**
 * Start (or continue) connecting to address:port, optionally with TLS.
 *
 * On the first call, the candidate addresses come from getaddrinfo() and
 * are tried in order. If a non-blocking connect() is still in progress,
 * the chosen address is saved in this->addrinfo, a 5s timeout is armed,
 * and the function is expected to be called again (through connect())
 * once the socket becomes writable.
 *
 * Reports the outcome through on_connected() or on_connection_failed().
 */
void SocketHandler::connect(const std::string& address, const std::string& port, const bool tls)
{
  this->address = address;
  this->port = port;
  this->use_tls = tls;

  utils::ScopeGuard sg;

  struct addrinfo* addr_res;

  if (!this->connecting)
    {
      log_info("Trying to connect to " << address << ":" << port);
      // Get the addrinfo from getaddrinfo, only if this is the first call
      // of this function.
      struct addrinfo hints;
      memset(&hints, 0, sizeof(struct addrinfo));
      hints.ai_flags = 0;
      hints.ai_family = AF_UNSPEC;
      hints.ai_socktype = SOCK_STREAM;
      hints.ai_protocol = 0;

      const int res = ::getaddrinfo(address.c_str(), port.c_str(), &hints, &addr_res);

      if (res != 0)
        {
          log_warning("getaddrinfo failed: "s + gai_strerror(res));
          this->close();
          this->on_connection_failed(gai_strerror(res));
          return ;
        }
      // Make sure the alloced structure is always freed at the end of the
      // function
      sg.add_callback([&addr_res](){ freeaddrinfo(addr_res); });
    }
  else
    { // This function is called again, use the saved addrinfo structure,
      // instead of re-doing the whole getaddrinfo process.
      addr_res = &this->addrinfo;
    }

  // errno of the last failed attempt. It is captured immediately because
  // logging, ::close() and close() may overwrite errno before it is reported.
  int last_error = 0;

  for (struct addrinfo* rp = addr_res; rp; rp = rp->ai_next)
    {
      if (!this->connecting)
        {
          try {
            this->init_socket(rp);
          }
          catch (const std::runtime_error& error) {
            last_error = errno;
            log_error("Failed to init socket: " << error.what());
            break;
          }
        }

      if (::connect(this->socket, rp->ai_addr, rp->ai_addrlen) == 0
          || errno == EISCONN)
        {
          log_info("Connection success.");
          TimedEventsManager::instance().cancel("connection_timeout"s +
                                                std::to_string(this->socket));
          this->poller->add_socket_handler(this);
          this->connected = true;
          this->connecting = false;
#ifdef BOTAN_FOUND
          if (this->use_tls)
            this->start_tls();
#endif
          this->on_connected();
          return ;
        }
      else if (errno == EINPROGRESS || errno == EALREADY)
        {   // retry this process later, when the socket
            // is ready to be written on.
          this->connecting = true;
          this->poller->add_socket_handler(this);
          this->poller->watch_send_events(this);
          // Save the addrinfo structure, to use it on the next call
          this->ai_addrlen = rp->ai_addrlen;
          memcpy(&this->ai_addr, rp->ai_addr, this->ai_addrlen);
          memcpy(&this->addrinfo, rp, sizeof(struct addrinfo));
          this->addrinfo.ai_addr = &this->ai_addr;
          this->addrinfo.ai_next = nullptr;
          // ai_canonname would point into the getaddrinfo() result that is
          // freed when this function returns; never keep it in the copy.
          this->addrinfo.ai_canonname = nullptr;
          // If the connection has not succeeded or failed in 5s, we consider
          // it to have failed
          TimedEventsManager::instance().add_event(
                TimedEvent(std::chrono::steady_clock::now() + 5s,
                           std::bind(&SocketHandler::on_connection_timeout, this),
                           "connection_timeout"s + std::to_string(this->socket)));
          return ;
        }
      last_error = errno;
      log_info("Connection failed:" << strerror(last_error));
      if (!this->connecting)
        {
          // This socket was created by init_socket() above and never handed
          // to the poller. Release it before init_socket() opens a new one
          // for the next candidate address; otherwise the descriptor leaks.
          ::close(this->socket);
          this->socket = -1;
        }
    }
  log_error("All connection attempts failed.");
  this->close();
  this->on_connection_failed(strerror(last_error));
  return ;
}

161 162 163 164 165 166
void SocketHandler::on_connection_timeout()
{
  // Called by the timed event armed in connect() while a connection was
  // still in progress: abandon the attempt and report it as failed.
  this->close();
  const char* const reason = "connection timed out";
  this->on_connection_failed(reason);
}

167 168
void SocketHandler::connect()
{
louiz’'s avatar
louiz’ committed
169
  this->connect(this->address, this->port, this->use_tls);
170 171
}

172
void SocketHandler::on_recv()
louiz’'s avatar
louiz’ committed
173 174 175 176 177 178 179 180 181 182
{
#ifdef BOTAN_FOUND
  if (this->use_tls)
    this->tls_recv();
  else
#endif
    this->plain_recv();
}

void SocketHandler::plain_recv()
183
{
184 185 186
  static constexpr size_t buf_size = 4096;
  char buf[buf_size];
  void* recv_buf = this->get_receive_buffer(buf_size);
187

188 189 190
  if (recv_buf == nullptr)
    recv_buf = buf;

louiz’'s avatar
louiz’ committed
191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
  const ssize_t size = this->do_recv(recv_buf, buf_size);

  if (size > 0)
    {
      if (buf == recv_buf)
        {
          // data needs to be placed in the in_buf string, because no buffer
          // was provided to receive that data directly. The in_buf buffer
          // will be handled in parse_in_buffer()
          this->in_buf += std::string(buf, size);
        }
      this->parse_in_buffer(size);
    }
}

ssize_t SocketHandler::do_recv(void* recv_buf, const size_t buf_size)
{
208
  ssize_t size = ::recv(this->socket, recv_buf, buf_size, 0);
209
  if (0 == size)
210 211 212 213
    {
      this->on_connection_close();
      this->close();
    }
louiz’'s avatar
louiz’ committed
214
  else if (-1 == size)
215 216
    {
      log_warning("Error while reading from socket: " << strerror(errno));
217
      if (this->connecting)
louiz’'s avatar
louiz’ committed
218 219 220 221
        {
          this->close();
          this->on_connection_failed(strerror(errno));
        }
222
      else
223
        {
louiz’'s avatar
louiz’ committed
224 225
          this->close();
          this->on_connection_close();
226
        }
227
    }
louiz’'s avatar
louiz’ committed
228
  return size;
229 230 231 232
}

void SocketHandler::on_send()
{
233 234 235 236 237 238 239 240 241
  struct iovec msg_iov[UIO_FASTIOV] = {};
  struct msghdr msg{nullptr, 0,
      msg_iov,
      0, nullptr, 0, 0};
  for (std::string& s: this->out_buf)
    {
      // unconsting the content of s is ok, sendmsg will never modify it
      msg_iov[msg.msg_iovlen].iov_base = const_cast<char*>(s.data());
      msg_iov[msg.msg_iovlen].iov_len = s.size();
242 243
      if (++msg.msg_iovlen == UIO_FASTIOV)
        break;
244 245 246
    }
  ssize_t res = ::sendmsg(this->socket, &msg, MSG_NOSIGNAL);
  if (res < 0)
247
    {
248
      log_error("sendmsg failed: " << strerror(errno));
249
      this->on_connection_close();
250 251 252 253
      this->close();
    }
  else
    {
254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
      // remove all the strings that were successfully sent.
      for (auto it = this->out_buf.begin();
           it != this->out_buf.end();)
        {
          if (static_cast<size_t>(res) >= (*it).size())
            {
              res -= (*it).size();
              it = this->out_buf.erase(it);
            }
          else
            {
              // If one string has partially been sent, we use substr to
              // crop it
              if (res > 0)
                (*it) = (*it).substr(res, std::string::npos);
              break;
            }
        }
272 273 274 275 276 277 278
      if (this->out_buf.empty())
        this->poller->stop_watching_send_events(this);
    }
}

void SocketHandler::close()
{
279 280
  TimedEventsManager::instance().cancel("connection_timeout"s +
                                        std::to_string(this->socket));
281 282 283 284 285 286 287
  if (this->connected || this->connecting)
    this->poller->remove_socket_handler(this->get_socket());
  if (this->socket != -1)
    {
      ::close(this->socket);
      this->socket = -1;
    }
288
  this->connected = false;
289
  this->connecting = false;
290 291 292
  this->in_buf.clear();
  this->out_buf.clear();
  this->port.clear();
293 294 295 296 297 298 299 300
}

socket_t SocketHandler::get_socket() const
{
  // The raw descriptor; -1 when no socket is open (constructor, close()).
  const socket_t fd = this->socket;
  return fd;
}

void SocketHandler::send_data(std::string&& data)
louiz’'s avatar
louiz’ committed
301 302 303 304 305 306 307 308 309 310
{
#ifdef BOTAN_FOUND
  if (this->use_tls)
    this->tls_send(std::move(data));
  else
#endif
    this->raw_send(std::move(data));
}

void SocketHandler::raw_send(std::string&& data)
311
{
312 313 314
  if (data.empty())
    return ;
  this->out_buf.emplace_back(std::move(data));
315 316 317 318 319 320 321 322
  if (this->connected)
    this->poller->watch_send_events(this);
}

void SocketHandler::send_pending_data()
{
  // Nothing to do unless connected with something queued in out_buf.
  if (!this->connected || this->out_buf.empty())
    return;
  this->poller->watch_send_events(this);
}
324 325 326 327 328

bool SocketHandler::is_connected() const
{
  // True once connect() succeeded, until close() resets it.
  const bool state = this->connected;
  return state;
}
329 330 331 332 333

bool SocketHandler::is_connecting() const
{
  // True while a non-blocking connect() is still in progress.
  const bool state = this->connecting;
  return state;
}
334 335 336 337 338

void* SocketHandler::get_receive_buffer(const size_t /* size */) const
{
  // Default: no caller-provided buffer, so plain_recv() falls back to its
  // local buffer and appends the data to in_buf.
  void* const no_buffer = nullptr;
  return no_buffer;
}
louiz’'s avatar
louiz’ committed
339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424

#ifdef BOTAN_FOUND
/**
 * Create the Botan TLS client for this connection, wiring its output,
 * data, alert and handshake callbacks to this handler and using the shared
 * static RNG, credentials, policy and session manager.
 */
void SocketHandler::start_tls()
{
  // getaddrinfo() also accepts service names (e.g. "ircs") as the port, on
  // which std::stoul throws; values above 65535 would be silently narrowed.
  // Neither case must throw out of connect(), so fall back to 0 (the
  // default port argument of Server_Information, presumably "unspecified").
  unsigned long port_number = 0;
  try {
    port_number = std::stoul(this->port);
  }
  catch (const std::logic_error&) { // std::invalid_argument, std::out_of_range
    port_number = 0;
  }
  if (port_number > 65535)
    port_number = 0;
  Botan::TLS::Server_Information server_info(this->address, "irc",
                                             static_cast<unsigned short>(port_number));
  this->tls = std::make_unique<Botan::TLS::Client>(
      std::bind(&SocketHandler::tls_output_fn, this, ph::_1, ph::_2),
      std::bind(&SocketHandler::tls_data_cb, this, ph::_1, ph::_2),
      std::bind(&SocketHandler::tls_alert_cb, this, ph::_1, ph::_2, ph::_3),
      std::bind(&SocketHandler::tls_handshake_cb, this, ph::_1),
      session_manager, credential_manager, policy,
      rng, server_info, Botan::TLS::Protocol_Version::latest_tls_version());
}

void SocketHandler::tls_recv()
{
  // Read raw bytes from the socket and feed them to the TLS engine; if that
  // completes activation of the channel, notify on_tls_activated().
  static constexpr size_t buf_size = 4096;
  char raw_bytes[buf_size];

  const ssize_t received = this->do_recv(raw_bytes, buf_size);
  if (received <= 0)
    return;

  const bool active_before = this->tls->is_active();
  this->tls->received_data(reinterpret_cast<const Botan::byte*>(raw_bytes),
                           static_cast<size_t>(received));
  if (!active_before && this->tls->is_active())
    this->on_tls_activated();
}

/**
 * Send data through the TLS channel.
 *
 * While the channel is not active, data is appended to pre_buf. This
 * includes the time before start_tls() has created this->tls, e.g. data
 * queued while the TCP connection is still being established. The next call
 * made while the channel is active (on_tls_activated() makes one) first
 * flushes pre_buf, then sends data.
 */
void SocketHandler::tls_send(std::string&& data)
{
  // this->tls is only created by start_tls() once connected; check it
  // before dereferencing instead of crashing on early sends.
  if (this->tls && this->tls->is_active())
    {
      if (!this->pre_buf.empty())
        {
          this->tls->send(reinterpret_cast<const Botan::byte*>(this->pre_buf.data()),
                          this->pre_buf.size());
          this->pre_buf.clear();
        }
      if (!data.empty())
        this->tls->send(reinterpret_cast<const Botan::byte*>(data.data()),
                        data.size());
      // The former "became active during send" check was computed inside
      // this is_active() branch and could never be true, so it was removed.
    }
  else
    this->pre_buf += data;
}

/**
 * Botan callback for decrypted application data: append it to in_buf and
 * let parse_in_buffer() process it.
 */
void SocketHandler::tls_data_cb(const Botan::byte* data, size_t size)
{
  // Append in place instead of constructing a temporary std::string.
  this->in_buf.append(reinterpret_cast<const char*>(data), size);
  if (!this->in_buf.empty())
    this->parse_in_buffer(size);
}

void SocketHandler::tls_output_fn(const Botan::byte* data, size_t size)
{
  // Botan callback for encrypted records: queue them for a raw socket write.
  const char* const bytes = reinterpret_cast<const char*>(data);
  std::string record(bytes, size);
  this->raw_send(std::move(record));
}

void SocketHandler::tls_alert_cb(Botan::TLS::Alert alert, const Botan::byte* /* payload */, size_t /* payload_size */)
{
  // TLS alerts are only logged here; the payload is ignored.
  log_debug("tls_alert: " << alert.type_string());
}

bool SocketHandler::tls_handshake_cb(const Botan::TLS::Session& session)
{
  // Log the negotiated parameters, plus the session id / ticket when
  // present. Returning true accepts the session.
  log_debug("Handshake with " << session.server_info().hostname() << " complete."
            << " Version: " << session.version().to_string()
            << " using " << session.ciphersuite().to_string());
  const auto& id = session.session_id();
  if (!id.empty())
    log_debug("Session ID " << Botan::hex_encode(id));
  const auto& ticket = session.session_ticket();
  if (!ticket.empty())
    log_debug("Session ticket " << Botan::hex_encode(ticket));
  return true;
}

void SocketHandler::on_tls_activated()
{
  // An empty send makes tls_send() flush whatever was buffered in pre_buf
  // while the channel was not yet active.
  this->send_data(std::string());
}
#endif // BOTAN_FOUND