Commit b86547dc authored by louiz’'s avatar louiz’

Implement async DNS resolution using c-ares

fix #2533
parent a1713572
...@@ -39,6 +39,12 @@ elseif(NOT WITHOUT_BOTAN) ...@@ -39,6 +39,12 @@ elseif(NOT WITHOUT_BOTAN)
find_package(BOTAN) find_package(BOTAN)
endif() endif()
if(WITH_CARES)
find_package(CARES REQUIRED)
elseif(NOT WITHOUT_CARES)
find_package(CARES)
endif()
# #
## Get the software version ## Get the software version
# #
...@@ -84,6 +90,10 @@ if(BOTAN_FOUND) ...@@ -84,6 +90,10 @@ if(BOTAN_FOUND)
include_directories(SYSTEM ${BOTAN_INCLUDE_DIRS}) include_directories(SYSTEM ${BOTAN_INCLUDE_DIRS})
endif() endif()
if(CARES_FOUND)
include_directories(${CARES_INCLUDE_DIRS})
endif()
set(POLLER_DOCSTRING "Choose the poller between POLL and EPOLL (Linux-only)") set(POLLER_DOCSTRING "Choose the poller between POLL and EPOLL (Linux-only)")
if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
set(POLLER "EPOLL" CACHE STRING ${POLLER_DOCSTRING}) set(POLLER "EPOLL" CACHE STRING ${POLLER_DOCSTRING})
...@@ -145,6 +155,9 @@ target_link_libraries(network logger) ...@@ -145,6 +155,9 @@ target_link_libraries(network logger)
if(BOTAN_FOUND) if(BOTAN_FOUND)
target_link_libraries(network ${BOTAN_LIBRARIES}) target_link_libraries(network ${BOTAN_LIBRARIES})
endif() endif()
if(CARES_FOUND)
target_link_libraries(network ${CARES_LIBRARIES})
endif()
# #
## irclib ## irclib
......
...@@ -4,4 +4,5 @@ ...@@ -4,4 +4,5 @@
#cmakedefine SYSTEMD_FOUND #cmakedefine SYSTEMD_FOUND
#cmakedefine POLLER ${POLLER} #cmakedefine POLLER ${POLLER}
#cmakedefine BOTAN_FOUND #cmakedefine BOTAN_FOUND
#cmakedefine CARES_FOUND
#cmakedefine BIBOUMI_VERSION "${BIBOUMI_VERSION}" #cmakedefine BIBOUMI_VERSION "${BIBOUMI_VERSION}"
#include <network/tcp_socket_handler.hpp>
#include <xmpp/xmpp_component.hpp> #include <xmpp/xmpp_component.hpp>
#include <utils/timed_events.hpp> #include <utils/timed_events.hpp>
#include <network/poller.hpp> #include <network/poller.hpp>
...@@ -11,6 +10,10 @@ ...@@ -11,6 +10,10 @@
#include <signal.h> #include <signal.h>
#ifdef CARES_FOUND
# include <network/dns_handler.hpp>
#endif
// A flag set by the SIGINT signal handler. // A flag set by the SIGINT signal handler.
static volatile std::atomic<bool> stop(false); static volatile std::atomic<bool> stop(false);
// Flag set by the SIGUSR1/2 signal handler. // Flag set by the SIGUSR1/2 signal handler.
...@@ -95,6 +98,10 @@ int main(int ac, char** av) ...@@ -95,6 +98,10 @@ int main(int ac, char** av)
xmpp_component->start(); xmpp_component->start();
#ifdef CARES_FOUND
DNSHandler::instance.watch_dns_sockets(p);
#endif
auto timeout = TimedEventsManager::instance().get_timeout(); auto timeout = TimedEventsManager::instance().get_timeout();
while (p->poll(timeout) != -1) while (p->poll(timeout) != -1)
{ {
...@@ -108,6 +115,9 @@ int main(int ac, char** av) ...@@ -108,6 +115,9 @@ int main(int ac, char** av)
exiting = true; exiting = true;
stop.store(false); stop.store(false);
xmpp_component->shutdown(); xmpp_component->shutdown();
#ifdef CARES_FOUND
DNSHandler::instance.destroy();
#endif
// Cancel the timer for an potential reconnection // Cancel the timer for an potential reconnection
TimedEventsManager::instance().cancel("XMPP reconnection"); TimedEventsManager::instance().cancel("XMPP reconnection");
} }
...@@ -153,6 +163,10 @@ int main(int ac, char** av) ...@@ -153,6 +163,10 @@ int main(int ac, char** av)
xmpp_component->close(); xmpp_component->close();
if (exiting && p->size() == 1 && xmpp_component->is_document_open()) if (exiting && p->size() == 1 && xmpp_component->is_document_open())
xmpp_component->close_document(); xmpp_component->close_document();
#ifdef CARES_FOUND
if (!exiting)
DNSHandler::instance.watch_dns_sockets(p);
#endif
if (exiting) // If we are exiting, do not wait for any timed event if (exiting) // If we are exiting, do not wait for any timed event
timeout = utils::no_timeout; timeout = utils::no_timeout;
else else
......
#include <config.h>
#ifdef CARES_FOUND
#include <network/dns_socket_handler.hpp>
#include <network/tcp_socket_handler.hpp>
#include <network/dns_handler.hpp>
#include <network/poller.hpp>
#include <algorithm>
#include <stdexcept>
DNSHandler DNSHandler::instance;
using namespace std::string_literals;
void on_hostname4_resolved(void* arg, int status, int, struct hostent* hostent)
{
TCPSocketHandler* socket_handler = static_cast<TCPSocketHandler*>(arg);
socket_handler->on_hostname4_resolved(status, hostent);
}
void on_hostname6_resolved(void* arg, int status, int, struct hostent* hostent)
{
TCPSocketHandler* socket_handler = static_cast<TCPSocketHandler*>(arg);
socket_handler->on_hostname6_resolved(status, hostent);
}
DNSHandler::DNSHandler()
{
int ares_error;
if ((ares_error = ::ares_library_init(ARES_LIB_INIT_ALL)) != 0)
throw std::runtime_error("Failed to initialize c-ares lib: "s + ares_strerror(ares_error));
if ((ares_error = ::ares_init(&this->channel)) != ARES_SUCCESS)
throw std::runtime_error("Failed to initialize c-ares channel: "s + ares_strerror(ares_error));
}
ares_channel& DNSHandler::get_channel()
{
return this->channel;
}
void DNSHandler::destroy()
{
this->socket_handlers.clear();
::ares_destroy(this->channel);
::ares_library_cleanup();
}
void DNSHandler::gethostbyname(const std::string& name,
TCPSocketHandler* socket_handler, int family)
{
socket_handler->free_cares_addrinfo();
if (family == AF_INET)
::ares_gethostbyname(this->channel, name.data(), family,
&::on_hostname4_resolved, socket_handler);
else
::ares_gethostbyname(this->channel, name.data(), family,
&::on_hostname6_resolved, socket_handler);
}
void DNSHandler::watch_dns_sockets(std::shared_ptr<Poller>& poller)
{
fd_set readers;
fd_set writers;
FD_ZERO(&readers);
FD_ZERO(&writers);
int ndfs = ::ares_fds(this->channel, &readers, &writers);
// For each existing DNS socket, see if we are still supposed to watch it,
// if not then erase it
this->socket_handlers.erase(
std::remove_if(this->socket_handlers.begin(), this->socket_handlers.end(),
[&readers](const auto& dns_socket)
{
return !FD_ISSET(dns_socket->get_socket(), &readers);
}),
this->socket_handlers.end());
for (auto i = 0; i < ndfs; ++i)
{
bool read = FD_ISSET(i, &readers);
bool write = FD_ISSET(i, &writers);
// Look for the DNSSocketHandler with this fd
auto it = std::find_if(this->socket_handlers.begin(),
this->socket_handlers.end(),
[i](const auto& socket_handler)
{
return i == socket_handler->get_socket();
});
if (!read && !write) // No need to read or write to it
{ // If found, erase it and stop watching it because it is not
// needed anymore
if (it != this->socket_handlers.end())
// The socket destructor removes it from the poller
this->socket_handlers.erase(it);
}
else // We need to write and/or read to it
{ // If not found, create it because we need to watch it
if (it == this->socket_handlers.end())
{
this->socket_handlers.emplace_front(std::make_unique<DNSSocketHandler>(poller, i));
it = this->socket_handlers.begin();
}
poller->add_socket_handler(it->get());
if (write)
poller->watch_send_events(it->get());
}
}
}
#endif /* CARES_FOUND */
#ifndef DNS_HANDLER_HPP_INCLUDED
#define DNS_HANDLER_HPP_INCLUDED
#include <config.h>
#ifdef CARES_FOUND
class TCPSocketHandler;
class Poller;
class DNSSocketHandler;
# include <ares.h>
# include <memory>
# include <string>
# include <list>
void on_hostname4_resolved(void* arg, int status, int, struct hostent* hostent);
void on_hostname6_resolved(void* arg, int status, int, struct hostent* hostent);
/**
* Class managing DNS resolution. It should only be statically instanciated
* once in SocketHandler. It manages ares channel and calls various
* functions of that library.
*/
class DNSHandler
{
public:
DNSHandler();
~DNSHandler() = default;
void gethostbyname(const std::string& name, TCPSocketHandler* socket_handler,
int family);
/**
* Call ares_fds to know what fd needs to be watched by the poller, create
* or destroy DNSSocketHandlers depending on the result.
*/
void watch_dns_sockets(std::shared_ptr<Poller>& poller);
/**
* Destroy and stop watching all the DNS sockets. Then de-init the channel
* and library.
*/
void destroy();
ares_channel& get_channel();
static DNSHandler instance;
private:
/**
* The list of sockets that needs to be watched, according to the last
* call to ares_fds. DNSSocketHandlers are added to it or removed from it
* in the watch_dns_sockets() method
*/
std::list<std::unique_ptr<DNSSocketHandler>> socket_handlers;
ares_channel channel;
DNSHandler(const DNSHandler&) = delete;
DNSHandler(DNSHandler&&) = delete;
DNSHandler& operator=(const DNSHandler&) = delete;
DNSHandler& operator=(DNSHandler&&) = delete;
};
#endif /* CARES_FOUND */
#endif /* DNS_HANDLER_HPP_INCLUDED */
#include <config.h>
#ifdef CARES_FOUND
#include <network/dns_socket_handler.hpp>
#include <network/dns_handler.hpp>
#include <network/poller.hpp>
#include <ares.h>
DNSSocketHandler::DNSSocketHandler(std::shared_ptr<Poller> poller,
const socket_t socket):
SocketHandler(poller, socket)
{
}
DNSSocketHandler::~DNSSocketHandler()
{
}
void DNSSocketHandler::connect()
{
}
void DNSSocketHandler::on_recv()
{
// always stop watching send and read events. We will re-watch them if the
// next call to ares_fds tell us to
this->poller->remove_socket_handler(this->socket);
::ares_process_fd(DNSHandler::instance.get_channel(), this->socket, ARES_SOCKET_BAD);
}
void DNSSocketHandler::on_send()
{
// always stop watching send and read events. We will re-watch them if the
// next call to ares_fds tell us to
this->poller->remove_socket_handler(this->socket);
::ares_process_fd(DNSHandler::instance.get_channel(), ARES_SOCKET_BAD, this->socket);
}
bool DNSSocketHandler::is_connected() const
{
return true;
}
#endif /* CARES_FOUND */
#ifndef DNS_SOCKET_HANDLER_HPP
# define DNS_SOCKET_HANDLER_HPP
#include <config.h>
#ifdef CARES_FOUND
#include <network/socket_handler.hpp>
#include <ares.h>
/**
* Manage a socket returned by ares_fds. We do not create, open or close the
* socket ourself: this is done by c-ares. We just call ares_process_fd()
* with the correct parameters, depending on what can be done on that socket
* (Poller reported it to be writable or readeable)
*/
class DNSSocketHandler: public SocketHandler
{
public:
explicit DNSSocketHandler(std::shared_ptr<Poller> poller, const socket_t socket);
~DNSSocketHandler();
/**
* Just call dns_process_fd, c-ares will do its work of send()ing or
* recv()ing the data it wants on that socket.
*/
void on_recv() override final;
void on_send() override final;
/**
* Do nothing, because we are always considered to be connected, since the
* connection is done by c-ares and not by us.
*/
void connect() override final;
/**
* Always true, see the comment for connect()
*/
bool is_connected() const override final;
private:
DNSSocketHandler(const DNSSocketHandler&) = delete;
DNSSocketHandler(DNSSocketHandler&&) = delete;
DNSSocketHandler& operator=(const DNSSocketHandler&) = delete;
DNSSocketHandler& operator=(DNSSocketHandler&&) = delete;
};
#endif // CARES_FOUND
#endif // DNS_SOCKET_HANDLER_HPP
#ifndef SOCKET_HANDLER_HPP #ifndef SOCKET_HANDLER_HPP
# define SOCKET_HANDLER_HPP # define SOCKET_HANDLER_HPP
#include <config.h>
#include <memory> #include <memory>
class Poller; class Poller;
...@@ -19,6 +20,7 @@ public: ...@@ -19,6 +20,7 @@ public:
virtual void on_send() = 0; virtual void on_send() = 0;
virtual void connect() = 0; virtual void connect() = 0;
virtual bool is_connected() const = 0; virtual bool is_connected() const = 0;
socket_t get_socket() const socket_t get_socket() const
{ return this->socket; } { return this->socket; }
......
#include <network/tcp_socket_handler.hpp> #include <network/tcp_socket_handler.hpp>
#include <network/dns_handler.hpp>
#include <utils/timed_events.hpp> #include <utils/timed_events.hpp>
#include <utils/scopeguard.hpp> #include <utils/scopeguard.hpp>
...@@ -42,8 +43,22 @@ TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller): ...@@ -42,8 +43,22 @@ TCPSocketHandler::TCPSocketHandler(std::shared_ptr<Poller> poller):
use_tls(false), use_tls(false),
connected(false), connected(false),
connecting(false) connecting(false)
#ifdef CARES_FOUND
,resolved(false),
resolved4(false),
resolved6(false),
cares_addrinfo(nullptr),
cares_error()
#endif
{} {}
TCPSocketHandler::~TCPSocketHandler()
{
#ifdef CARES_FOUND
this->free_cares_addrinfo();
#endif
}
void TCPSocketHandler::init_socket(const struct addrinfo* rp) void TCPSocketHandler::init_socket(const struct addrinfo* rp)
{ {
if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1) if ((this->socket = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol)) == -1)
...@@ -72,9 +87,35 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po ...@@ -72,9 +87,35 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po
if (!this->connecting) if (!this->connecting)
{ {
// Get the addrinfo from getaddrinfo (or ares_gethostbyname), only if
// this is the first call of this function.
#ifdef CARES_FOUND
if (!this->resolved)
{
log_info("Trying to connect to " << address << ":" << port);
// Start the asynchronous process of resolving the hostname. Once
// the addresses have been found and `resolved` has been set to true
// (but connecting will still be false), TCPSocketHandler::connect()
// needs to be called, again.
DNSHandler::instance.gethostbyname(address, this, AF_INET6);
DNSHandler::instance.gethostbyname(address, this, AF_INET);
return;
}
else
{
// The c-ares resolved the hostname and the available addresses
// where saved in the cares_addrinfo linked list. Now, just use
// this list to try to connect.
addr_res = this->cares_addrinfo;
if (!addr_res)
{
this->close();
this->on_connection_failed(this->cares_error);
return ;
}
}
#else
log_info("Trying to connect to " << address << ":" << port); log_info("Trying to connect to " << address << ":" << port);
// Get the addrinfo from getaddrinfo, only if this is the first call
// of this function.
struct addrinfo hints; struct addrinfo hints;
memset(&hints, 0, sizeof(struct addrinfo)); memset(&hints, 0, sizeof(struct addrinfo));
hints.ai_flags = 0; hints.ai_flags = 0;
...@@ -94,6 +135,7 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po ...@@ -94,6 +135,7 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po
// Make sure the alloced structure is always freed at the end of the // Make sure the alloced structure is always freed at the end of the
// function // function
sg.add_callback([&addr_res](){ freeaddrinfo(addr_res); }); sg.add_callback([&addr_res](){ freeaddrinfo(addr_res); });
#endif
} }
else else
{ // This function is called again, use the saved addrinfo structure, { // This function is called again, use the saved addrinfo structure,
...@@ -144,9 +186,9 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po ...@@ -144,9 +186,9 @@ void TCPSocketHandler::connect(const std::string& address, const std::string& po
// If the connection has not succeeded or failed in 5s, we consider // If the connection has not succeeded or failed in 5s, we consider
// it to have failed // it to have failed
TimedEventsManager::instance().add_event( TimedEventsManager::instance().add_event(
TimedEvent(std::chrono::steady_clock::now() + 5s, TimedEvent(std::chrono::steady_clock::now() + 5s,
std::bind(&TCPSocketHandler::on_connection_timeout, this), std::bind(&TCPSocketHandler::on_connection_timeout, this),
"connection_timeout"s + std::to_string(this->socket))); "connection_timeout"s + std::to_string(this->socket)));
return ; return ;
} }
log_info("Connection failed:" << strerror(errno)); log_info("Connection failed:" << strerror(errno));
...@@ -321,7 +363,11 @@ bool TCPSocketHandler::is_connected() const ...@@ -321,7 +363,11 @@ bool TCPSocketHandler::is_connected() const
bool TCPSocketHandler::is_connecting() const bool TCPSocketHandler::is_connecting() const
{ {
#ifdef CARES_FOUND
return this->connecting || !this->resolved;
#else
return this->connecting; return this->connecting;
#endif
} }
void* TCPSocketHandler::get_receive_buffer(const size_t) const void* TCPSocketHandler::get_receive_buffer(const size_t) const
...@@ -413,4 +459,114 @@ void TCPSocketHandler::on_tls_activated() ...@@ -413,4 +459,114 @@ void TCPSocketHandler::on_tls_activated()
{ {
this->send_data(""); this->send_data("");
} }
#endif // BOTAN_FOUND #endif // BOTAN_FOUND
#ifdef CARES_FOUND
void TCPSocketHandler::on_hostname4_resolved(int status, struct hostent* hostent)
{
this->resolved4 = true;
if (status == ARES_SUCCESS)
this->fill_ares_addrinfo4(hostent);
else
this->cares_error = ::ares_strerror(status);
if (this->resolved4 && this->resolved6)
{
this->resolved = true;
this->connect();
}
}
void TCPSocketHandler::on_hostname6_resolved(int status, struct hostent* hostent)
{
this->resolved6 = true;
if (status == ARES_SUCCESS)
this->fill_ares_addrinfo6(hostent);
else
this->cares_error = ::ares_strerror(status);
if (this->resolved4 && this->resolved6)
{
this->resolved = true;
this->connect();
}
}
void TCPSocketHandler::fill_ares_addrinfo4(const struct hostent* hostent)
{
struct addrinfo* prev = this->cares_addrinfo;
struct in_addr** address = reinterpret_cast<struct in_addr**>(hostent->h_addr_list);
while (*address)
{
// Create a new addrinfo list element, and fill it
struct addrinfo* current = new struct addrinfo;
current->ai_flags = 0;
current->ai_family = hostent->h_addrtype;
current->ai_socktype = SOCK_STREAM;
current->ai_protocol = 0;
current->ai_addrlen = sizeof(struct sockaddr_in);
struct sockaddr_in* addr = new struct sockaddr_in;
addr->sin_family = hostent->h_addrtype;
addr->sin_port = htons(strtoul(this->port.data(), nullptr, 10));
addr->sin_addr.s_addr = (*address)->s_addr;
current->ai_addr = reinterpret_cast<struct sockaddr*>(addr);
current->ai_next = nullptr;
current->ai_canonname = nullptr;
current->ai_next = prev;
this->cares_addrinfo = current;
prev = current;
++address;
}
}
void TCPSocketHandler::fill_ares_addrinfo6(const struct hostent* hostent)
{
struct addrinfo* prev = this->cares_addrinfo;
struct in6_addr** address = reinterpret_cast<struct in6_addr**>(hostent->h_addr_list);
while (*address)
{
// Create a new addrinfo list element, and fill it
struct addrinfo* current = new struct addrinfo;
current->ai_flags = 0;
current->ai_family = hostent->h_addrtype;
current->ai_socktype = SOCK_STREAM;
current->ai_protocol = 0;
current->ai_addrlen = sizeof(struct sockaddr_in6);
struct sockaddr_in6* addr = new struct sockaddr_in6;
addr->sin6_family = hostent->h_addrtype;
addr->sin6_port = htons(strtoul(this->port.data(), nullptr, 10));
::memcpy(addr->sin6_addr.s6_addr, (*address)->s6_addr, 16);
addr->sin6_flowinfo = 0;
addr->sin6_scope_id = 0;
current->ai_addr = reinterpret_cast<struct sockaddr*>(addr);
current->ai_next = nullptr;
current->ai_canonname = nullptr;