Commit aa340e1c authored by louiz’'s avatar louiz’

Separate the DNS resolution logic from the TCP communication logic

fix #3137
parent 1aa2c2d8
......@@ -2,7 +2,6 @@
#ifdef CARES_FOUND
#include <network/dns_socket_handler.hpp>
#include <network/tcp_socket_handler.hpp>
#include <network/dns_handler.hpp>
#include <network/poller.hpp>
......@@ -14,19 +13,6 @@
DNSHandler DNSHandler::instance;
using namespace std::string_literals;
void on_hostname4_resolved(void* arg, int status, int, struct hostent* hostent)
{
TCPSocketHandler* socket_handler = static_cast<TCPSocketHandler*>(arg);
socket_handler->on_hostname4_resolved(status, hostent);
}
void on_hostname6_resolved(void* arg, int status, int, struct hostent* hostent)
{
TCPSocketHandler* socket_handler = static_cast<TCPSocketHandler*>(arg);
socket_handler->on_hostname6_resolved(status, hostent);
}
DNSHandler::DNSHandler()
{
int ares_error;
......@@ -54,16 +40,15 @@ void DNSHandler::destroy()
::ares_library_cleanup();
}
void DNSHandler::gethostbyname(const std::string& name,
TCPSocketHandler* socket_handler, int family)
void DNSHandler::gethostbyname(const std::string& name, ares_host_callback callback,
void* data, int family)
{
socket_handler->free_cares_addrinfo();
if (family == AF_INET)
::ares_gethostbyname(this->channel, name.data(), family,
&::on_hostname4_resolved, socket_handler);
callback, data);
else
::ares_gethostbyname(this->channel, name.data(), family,
&::on_hostname6_resolved, socket_handler);
callback, data);
}
void DNSHandler::watch_dns_sockets(std::shared_ptr<Poller>& poller)
......
......@@ -13,9 +13,6 @@ class DNSSocketHandler;
# include <string>
# include <vector>
void on_hostname4_resolved(void* arg, int status, int, struct hostent* hostent);
void on_hostname6_resolved(void* arg, int status, int, struct hostent* hostent);
/**
* Class managing DNS resolution. It should only be statically instanciated
* once in SocketHandler. It manages ares channel and calls various
......@@ -27,8 +24,8 @@ class DNSHandler
public:
DNSHandler();
~DNSHandler() = default;
void gethostbyname(const std::string& name, TCPSocketHandler* socket_handler,
int family);
void gethostbyname(const std::string& name, ares_host_callback callback,
void* socket_handler, int family);
/**
* Call ares_fds to know what fd needs to be watched by the poller, create
* or destroy DNSSocketHandlers depending on the result.
......
#include <network/dns_handler.hpp>
#include <network/resolver.hpp>
#include <string.h>
#include <arpa/inet.h>
// remove me
#include <iostream>
using namespace std::string_literals;
Resolver::Resolver():
#ifdef CARES_FOUND
resolved4(false),
resolved6(false),
resolving(false),
cares_addrinfo(nullptr),
port{},
#endif
resolved(false),
error_msg{}
{
}
void Resolver::resolve(const std::string& hostname, const std::string& port,
SuccessCallbackType success_cb, ErrorCallbackType error_cb)
{
this->error_cb = error_cb;
this->success_cb = success_cb;
this->port = port;
this->start_resolving(hostname, port);
}
#ifdef CARES_FOUND
void Resolver::start_resolving(const std::string& hostname, const std::string& port)
{
std::cout << "start_resolving: " << hostname << port << std::endl;
this->resolving = true;
this->resolved = false;
this->resolved4 = false;
this->resolved6 = false;
this->error_msg.clear();
this->cares_addrinfo = nullptr;
auto hostname4_resolved = [](void* arg, int status, int,
struct hostent* hostent)
{
Resolver* resolver = static_cast<Resolver*>(arg);
resolver->on_hostname4_resolved(status, hostent);
};
auto hostname6_resolved = [](void* arg, int status, int,
struct hostent* hostent)
{
Resolver* resolver = static_cast<Resolver*>(arg);
resolver->on_hostname6_resolved(status, hostent);
};
DNSHandler::instance.gethostbyname(hostname, hostname6_resolved,
this, AF_INET6);
DNSHandler::instance.gethostbyname(hostname, hostname4_resolved,
this, AF_INET);
}
void Resolver::on_hostname4_resolved(int status, struct hostent* hostent)
{
this->resolved4 = true;
if (status == ARES_SUCCESS)
this->fill_ares_addrinfo4(hostent);
else
this->error_msg = ::ares_strerror(status);
if (this->resolved4 && this->resolved6)
this->on_resolved();
}
void Resolver::on_hostname6_resolved(int status, struct hostent* hostent)
{
this->resolved6 = true;
if (status == ARES_SUCCESS)
this->fill_ares_addrinfo6(hostent);
else
this->error_msg = ::ares_strerror(status);
if (this->resolved4 && this->resolved6)
this->on_resolved();
}
void Resolver::on_resolved()
{
this->resolved = true;
this->resolving = false;
if (!this->cares_addrinfo)
{
if (this->error_cb)
this->error_cb(this->error_msg.data());
}
else
{
this->addr.reset(this->cares_addrinfo);
if (this->success_cb)
this->success_cb(this->addr.get());
}
}
void Resolver::fill_ares_addrinfo4(const struct hostent* hostent)
{
std::cout << "fill_ares_addrinfo4" << this->port << std::endl;
struct addrinfo* prev = this->cares_addrinfo;
struct in_addr** address = reinterpret_cast<struct in_addr**>(hostent->h_addr_list);
while (*address)
{
// Create a new addrinfo list element, and fill it
struct addrinfo* current = new struct addrinfo;
current->ai_flags = 0;
current->ai_family = hostent->h_addrtype;
current->ai_socktype = SOCK_STREAM;
current->ai_protocol = 0;
current->ai_addrlen = sizeof(struct sockaddr_in);
struct sockaddr_in* addr = new struct sockaddr_in;
addr->sin_family = hostent->h_addrtype;
addr->sin_port = htons(strtoul(this->port.data(), nullptr, 10));
addr->sin_addr.s_addr = (*address)->s_addr;
current->ai_addr = reinterpret_cast<struct sockaddr*>(addr);
current->ai_next = nullptr;
current->ai_canonname = nullptr;
current->ai_next = prev;
this->cares_addrinfo = current;
prev = current;
++address;
}
}
void Resolver::fill_ares_addrinfo6(const struct hostent* hostent)
{
std::cout << "fill_ares_addrinfo6" << this->port << std::endl;
struct addrinfo* prev = this->cares_addrinfo;
struct in6_addr** address = reinterpret_cast<struct in6_addr**>(hostent->h_addr_list);
while (*address)
{
// Create a new addrinfo list element, and fill it
struct addrinfo* current = new struct addrinfo;
current->ai_flags = 0;
current->ai_family = hostent->h_addrtype;
current->ai_socktype = SOCK_STREAM;
current->ai_protocol = 0;
current->ai_addrlen = sizeof(struct sockaddr_in6);
struct sockaddr_in6* addr = new struct sockaddr_in6;
addr->sin6_family = hostent->h_addrtype;
addr->sin6_port = htons(strtoul(this->port.data(), nullptr, 10));
::memcpy(addr->sin6_addr.s6_addr, (*address)->s6_addr, 16);
addr->sin6_flowinfo = 0;
addr->sin6_scope_id = 0;
current->ai_addr = reinterpret_cast<struct sockaddr*>(addr);
current->ai_next = nullptr;
current->ai_canonname = nullptr;
current->ai_next = prev;
this->cares_addrinfo = current;
prev = current;
++address;
}
}
#else // ifdef CARES_FOUND
void Resolver::start_resolving(const std::string& hostname, const std::string& port)
{
// If the resolution fails, the addr will be unset
this->addr.reset(nullptr);
struct addrinfo hints;
memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_flags = 0;
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = 0;
struct addrinfo* addr_res = nullptr;
const int res = ::getaddrinfo(hostname.data(), port.data(),
&hints, &addr_res);
this->resolved = true;
if (res != 0)
{
this->error_msg = gai_strerror(res);
if (this->error_cb)
this->error_cb(this->error_msg.data());
}
else
{
this->addr.reset(addr_res);
if (this->success_cb)
this->success_cb(this->addr.get());
}
}
#endif // ifdef CARES_FOUND
std::string addr_to_string(const struct addrinfo* rp)
{
char buf[INET6_ADDRSTRLEN];
if (rp->ai_family == AF_INET)
return ::inet_ntop(rp->ai_family,
&reinterpret_cast<sockaddr_in*>(rp->ai_addr)->sin_addr,
buf, sizeof(buf));
else if (rp->ai_family == AF_INET6)
return ::inet_ntop(rp->ai_family,
&reinterpret_cast<sockaddr_in6*>(rp->ai_addr)->sin6_addr,
buf, sizeof(buf));
return {};
}
#ifndef RESOLVER_HPP_INCLUDED
#define RESOLVER_HPP_INCLUDED
#include "louloulibs.h"
#include <functional>
#include <memory>
#include <string>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
struct AddrinfoDeleter
{
void operator()(struct addrinfo* addr)
{
#ifdef CARES_FOUND
while (addr)
{
delete addr->ai_addr;
auto next = addr->ai_next;
delete addr;
addr = next;
}
#else
freeaddrinfo(addr);
#endif
}
};
class Resolver
{
public:
using ErrorCallbackType = std::function<void(const char*)>;
using SuccessCallbackType = std::function<void(const struct addrinfo*)>;
Resolver();
~Resolver() = default;
bool is_resolving() const
{
#ifdef CARES_FOUND
return this->resolving;
#else
return false;
#endif
}
bool is_resolved() const
{
return this->resolved;
}
const auto& get_result() const
{
return this->addr;
}
std::string get_error_message() const
{
return this->error_msg;
}
void clear()
{
#ifdef CARES_FOUND
this->resolved6 = false;
this->resolved4 = false;
#endif
this->resolved = false;
this->resolving = false;
this->addr.reset();
this->cares_addrinfo = nullptr;
this->port.clear();
this->error_msg.clear();
}
void resolve(const std::string& hostname, const std::string& port,
SuccessCallbackType success_cb, ErrorCallbackType error_cb);
private:
void start_resolving(const std::string& hostname, const std::string& port);
#ifdef CARES_FOUND
void on_hostname4_resolved(int status, struct hostent* hostent);
void on_hostname6_resolved(int status, struct hostent* hostent);
void fill_ares_addrinfo4(const struct hostent* hostent);
void fill_ares_addrinfo6(const struct hostent* hostent);
void on_resolved();
bool resolved4;
bool resolved6;
bool resolving;
/**
* When using c-ares to resolve the host asynchronously, we need the
* c-ares callbacks to fill a structure (a struct addrinfo, for
* compatibility with getaddrinfo and the rest of the code that works when
* c-ares is not used) with all returned values (for example an IPv6 and
* an IPv4). The pointer is given to the unique_ptr to manage its lifetime.
*/
struct addrinfo* cares_addrinfo;
std::string port;
#endif
/**
* Tells if we finished the resolution process. It doesn't indicate if it
* was successful (it is true even if the result is an error).
*/
bool resolved;
std::string error_msg;
std::unique_ptr<struct addrinfo, AddrinfoDeleter> addr;
ErrorCallbackType error_cb;
SuccessCallbackType success_cb;
Resolver(const Resolver&) = delete;
Resolver(Resolver&&) = delete;
Resolver& operator=(const Resolver&) = delete;
Resolver& operator=(Resolver&&) = delete;
};
std::string addr_to_string(const struct addrinfo* rp);
#endif /* RESOLVER_HPP_INCLUDED */
......@@ -11,13 +11,9 @@
#include <stdexcept>
#include <unistd.h>
#include <errno.h>
#include <netdb.h>
#include <cstring>
#include <fcntl.h>
#include <iostream>
#include <arpa/inet.h>
#ifdef BOTAN_FOUND
# include <botan/hex.h>
# include <botan/tls_exceptn.h>
......@@ -43,24 +39,9 @@ TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller):
use_tls(false),
connected(false),
connecting(false),
#ifdef CARES_FOUND
resolving(false),
resolved(false),
resolved4(false),
resolved6(false),
cares_addrinfo(nullptr),
cares_error(),
#endif
hostname_resolution_failed(false)
{}
TCPSocketHandler::~TCPSocketHandler()
{
#ifdef CARES_FOUND
this->free_cares_addrinfo();
#endif
}
void TCPSocketHandler::init_socket(const struct addrinfo* rp)
{
if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)
......@@ -91,17 +72,24 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po
{
// Get the addrinfo from getaddrinfo (or ares_gethostbyname), only if
// this is the first call of this function.
#ifdef CARES_FOUND
if (!this->resolved)
if (!this->resolver.is_resolved())
{
log_info("Trying to connect to " << address << ":" << port);
// Start the asynchronous process of resolving the hostname. Once
// the addresses have been found and `resolved` has been set to true
// (but connecting will still be false), TCPSocketHandler::connect()
// needs to be called, again.
this->resolving = true;
DNSHandler::instance.gethostbyname(address, this, AF_INET6);
DNSHandler::instance.gethostbyname(address, this, AF_INET);
this->resolver.resolve(address, port,
[this](const struct addrinfo*)
{
log_debug("Resolution success, calling connect() again");
this->connect();
},
[this](const char*)
{
log_debug("Resolution failed, calling connect() again");
this->connect();
});
return;
}
else
......@@ -109,39 +97,16 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po
// The c-ares resolved the hostname and the available addresses
// where saved in the cares_addrinfo linked list. Now, just use
// this list to try to connect.
addr_res = this->cares_addrinfo;
addr_res = this->resolver.get_result().get();
if (!addr_res)
{
this->hostname_resolution_failed = true;
const auto msg = this->cares_error;
const auto msg = this->resolver.get_error_message();
this->close();
this->on_connection_failed(msg);
return ;
}
}
#else
log_info("Trying to connect to " << address << ":" << port);
struct addrinfo hints;
memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_flags = 0;
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = 0;
const int res = ::getaddrinfo(address.c_str(), port.c_str(), &hints, &addr_res);
if (res != 0)
{
log_warning("getaddrinfo failed: "s + gai_strerror(res));
this->hostname_resolution_failed = true;
this->close();
this->on_connection_failed(gai_strerror(res));
return ;
}
// Make sure the alloced structure is always freed at the end of the
// function
sg.add_callback([&addr_res](){ freeaddrinfo(addr_res); });
#endif
}
else
{ // This function is called again, use the saved addrinfo structure,
......@@ -335,30 +300,18 @@ void TCPSocketHandler::close()
}
this->connected = false;
this->connecting = false;
#ifdef CARES_FOUND
this->resolving = false;
this->resolved = false;
this->resolved4 = false;
this->resolved6 = false;
this->free_cares_addrinfo();
this->cares_error.clear();
#endif
this->in_buf.clear();
this->out_buf.clear();
this->port.clear();
this->resolver.clear();
}
void TCPSocketHandler::display_resolved_ip(struct addrinfo* rp) const
{
char buf[INET6_ADDRSTRLEN];
if (rp->ai_family == AF_INET)
log_debug("Connecting to IP address " << ::inet_ntop(rp->ai_family,
&reinterpret_cast<sockaddr_in*>(rp->ai_addr)->sin_addr,
buf, sizeof(buf)));
log_debug("Trying IPv4 address " << addr_to_string(rp));
else if (rp->ai_family == AF_INET6)
log_debug("Connecting to IPv6 address " << ::inet_ntop(rp->ai_family,
&reinterpret_cast<sockaddr_in6*>(rp->ai_addr)->sin6_addr,
buf, sizeof(buf)));
log_debug("Trying IPv6 address " << addr_to_string(rp));
}
void TCPSocketHandler::send_data(std::string&& data)
......@@ -399,11 +352,7 @@ bool TCPSocketHandler::is_connected() const
bool TCPSocketHandler::is_connecting() const
{
#ifdef CARES_FOUND
return this->connecting || this->resolving;
#else
return this->connecting;
#endif
return this->connecting || this->resolver.is_resolving();
}
void* TCPSocketHandler::get_receive_buffer(const size_t) const
......@@ -506,122 +455,11 @@ void TCPSocketHandler::on_tls_activated()
}
void Permissive_Credentials_Manager::verify_certificate_chain(const std::string& type,
const std::string& purported_hostname,
const std::vector<Botan::X509_Certificate>&)
const std::string& purported_hostname,
const std::vector<Botan::X509_Certificate>&)
{ // TODO: Offer the admin to disallow connection on untrusted
// certificates
log_debug("Checking remote certificate (" << type << ") for hostname " << purported_hostname);
}
#endif // BOTAN_FOUND
#ifdef CARES_FOUND
void TCPSocketHandler::on_hostname4_resolved(int status, struct hostent* hostent)
{
this->resolved4 = true;
if (status == ARES_SUCCESS)
this->fill_ares_addrinfo4(hostent);
else
this->cares_error = ::ares_strerror(status);
if (this->resolved4 && this->resolved6)
{
this->resolved = true;
this->resolving = false;
this->connect();
}
}
void TCPSocketHandler::on_hostname6_resolved(int status, struct hostent* hostent)
{
this->resolved6 = true;
if (status == ARES_SUCCESS)
this->fill_ares_addrinfo6(hostent);
else
this->cares_error = ::ares_strerror(status);
if (this->resolved4 && this->resolved6)
{
this->resolved = true;
this->resolving = false;
this->connect();
}
}
void TCPSocketHandler::fill_ares_addrinfo4(const struct hostent* hostent)
{
struct addrinfo* prev = this->cares_addrinfo;
struct in_addr** address = reinterpret_cast<struct in_addr**>(hostent->h_addr_list);
while (*address)
{
// Create a new addrinfo list element, and fill it
struct addrinfo* current = new struct addrinfo;