Commit 0168b96b authored by louiz’'s avatar louiz’

Add postgresql support

parent 5b27cee9
...@@ -130,6 +130,12 @@ elseif(NOT WITHOUT_SQLITE3) ...@@ -130,6 +130,12 @@ elseif(NOT WITHOUT_SQLITE3)
find_package(SQLITE3) find_package(SQLITE3)
endif() endif()
if(WITH_POSTGRESQL)
find_package(PQ REQUIRED)
elseif(NOT WITHOUT_POSTGRESQL)
find_package(PQ)
endif()
# #
## Set all the include directories, depending on what libraries are used ## Set all the include directories, depending on what libraries are used
# #
...@@ -193,6 +199,7 @@ if(SQLITE3_FOUND) ...@@ -193,6 +199,7 @@ if(SQLITE3_FOUND)
add_library(database OBJECT ${source_database}) add_library(database OBJECT ${source_database})
include_directories(database ${SQLITE3_INCLUDE_DIRS}) include_directories(database ${SQLITE3_INCLUDE_DIRS})
include_directories(database ${PQ_INCLUDE_DIRS})
set(USE_DATABASE TRUE) set(USE_DATABASE TRUE)
else() else()
add_library(database OBJECT "") add_library(database OBJECT "")
...@@ -261,7 +268,9 @@ if(LIBIDN_FOUND) ...@@ -261,7 +268,9 @@ if(LIBIDN_FOUND)
endif() endif()
if(USE_DATABASE) if(USE_DATABASE)
target_link_libraries(${PROJECT_NAME} ${SQLITE3_LIBRARIES}) target_link_libraries(${PROJECT_NAME} ${SQLITE3_LIBRARIES})
target_link_libraries(${PROJECT_NAME} ${PQ_LIBRARIES})
target_link_libraries(test_suite ${SQLITE3_LIBRARIES}) target_link_libraries(test_suite ${SQLITE3_LIBRARIES})
target_link_libraries(test_suite ${PQ_LIBRARIES})
endif() endif()
# Define a __FILENAME__ macro with the relative path (from the base project directory) # Define a __FILENAME__ macro with the relative path (from the base project directory)
......
...@@ -32,10 +32,12 @@ libiconv_ ...@@ -32,10 +32,12 @@ libiconv_
libuuid_ libuuid_
Generate unique IDs Generate unique IDs
sqlite3_ (option, but highly recommended) sqlite3_
Provides a way to store various options in a (sqlite3) database. Each user or
of the gateway can store their own values (for example their prefered port, libpq_
or their IRC password). Without this dependency, many interesting features Provides a way to store various options in a database. Each user of the
gateway can store their own values (for example their prefered port, or
their IRC password). Without this dependency, many interesting features
are missing. are missing.
libidn_ (optional, but recommended) libidn_ (optional, but recommended)
...@@ -165,3 +167,4 @@ to use biboumi. ...@@ -165,3 +167,4 @@ to use biboumi.
.. _systemd: https://www.freedesktop.org/wiki/Software/systemd/ .. _systemd: https://www.freedesktop.org/wiki/Software/systemd/
.. _biboumi.1.rst: doc/biboumi.1.rst .. _biboumi.1.rst: doc/biboumi.1.rst
.. _gcrypt: https://www.gnu.org/software/libgcrypt/ .. _gcrypt: https://www.gnu.org/software/libgcrypt/
.. _libpq: https://www.postgresql.org/docs/current/static/libpq.html
# - Find libpq
# Find the postgresql front end library
#
# This module defines the following variables:
# PQ_FOUND - True if library and include directory are found
# If set to TRUE, the following are also defined:
# PQ_INCLUDE_DIRS - The directory where to find the header file
# PQ_LIBRARIES - Where to find the library file
#
# For conveniance, these variables are also set. They have the same values
# than the variables above. The user can thus choose his/her prefered way
# to write them.
# PQ_LIBRARY
# PQ_INCLUDE_DIR
#
# This file is in the public domain
include(FindPkgConfig)
if(NOT PQ_FOUND)
pkg_check_modules(PQ libpq)
endif()
if(NOT PQ_FOUND)
find_path(PQ_INCLUDE_DIRS NAMES libpq-fe.h
DOC "The libpq include directory")
find_library(PQ_LIBRARIES NAMES pq
DOC "The pq library")
# Use some standard module to handle the QUIETLY and REQUIRED arguments, and
# set PQ_FOUND to TRUE if these two variables are set.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(PQ REQUIRED_VARS PQ_LIBRARIES PQ_INCLUDE_DIRS)
if(PQ_FOUND)
set(PQ_LIBRARY ${PQ_LIBRARIES} CACHE INTERNAL "")
set(PQ_INCLUDE_DIR ${PQ_INCLUDE_DIRS} CACHE INTERNAL "")
set(PQ_FOUND ${PQ_FOUND} CACHE INTERNAL "")
endif()
endif()
mark_as_advanced(PQ_INCLUDE_DIRS PQ_LIBRARIES)
...@@ -13,5 +13,10 @@ struct Column ...@@ -13,5 +13,10 @@ struct Column
T value{}; T value{};
}; };
struct Id: Column<std::size_t> { static constexpr auto name = "id_"; struct Id: Column<std::size_t> {
static constexpr auto options = "PRIMARY KEY AUTOINCREMENT"; }; static constexpr std::size_t unset_value = static_cast<std::size_t>(-1);
static constexpr auto name = "id_";
static constexpr auto options = "PRIMARY KEY";
Id(): Column<std::size_t>(-1) {}
};
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <database/query.hpp> #include <database/query.hpp>
#include <database/table.hpp> #include <database/table.hpp>
#include <database/statement.hpp>
#include <string> #include <string>
...@@ -15,20 +16,17 @@ struct CountQuery: public Query ...@@ -15,20 +16,17 @@ struct CountQuery: public Query
this->body += std::move(name); this->body += std::move(name);
} }
int64_t execute(sqlite3* db) int64_t execute(DatabaseEngine& db)
{ {
auto statement = this->prepare(db); auto statement = db.prepare(this->body);
int64_t res = 0; int64_t res = 0;
if (sqlite3_step(statement.get()) == SQLITE_ROW) if (statement->step() != StepResult::Error)
res = sqlite3_column_int64(statement.get(), 0); res = statement->get_column_int64(0);
else else
{ {
log_error("Count request didn’t return a result"); log_error("Count request didn’t return a result");
return 0; return 0;
} }
if (sqlite3_step(statement.get()) != SQLITE_DONE)
log_warning("Count request returned more than one result.");
return res; return res;
} }
}; };
...@@ -7,16 +7,21 @@ ...@@ -7,16 +7,21 @@
#include <utils/time.hpp> #include <utils/time.hpp>
#include <config/config.hpp> #include <config/config.hpp>
#include <database/sqlite3_engine.hpp>
#include <database/postgresql_engine.hpp>
#include <database/engine.hpp>
#include <database/index.hpp> #include <database/index.hpp>
#include <memory>
#include <sqlite3.h> #include <sqlite3.h>
sqlite3* Database::db; std::unique_ptr<DatabaseEngine> Database::db;
Database::MucLogLineTable Database::muc_log_lines("MucLogLine_"); Database::MucLogLineTable Database::muc_log_lines("muclogline_");
Database::GlobalOptionsTable Database::global_options("GlobalOptions_"); Database::GlobalOptionsTable Database::global_options("globaloptions_");
Database::IrcServerOptionsTable Database::irc_server_options("IrcServerOptions_"); Database::IrcServerOptionsTable Database::irc_server_options("ircserveroptions_");
Database::IrcChannelOptionsTable Database::irc_channel_options("IrcChannelOptions_"); Database::IrcChannelOptionsTable Database::irc_channel_options("ircchanneloptions_");
Database::RosterTable Database::roster("roster"); Database::RosterTable Database::roster("roster");
std::map<Database::CacheKey, Database::EncodingIn::real_type> Database::encoding_in_cache{}; std::map<Database::CacheKey, Database::EncodingIn::real_type> Database::encoding_in_cache{};
...@@ -29,27 +34,26 @@ void Database::open(const std::string& filename) ...@@ -29,27 +34,26 @@ void Database::open(const std::string& filename)
// Try to open the specified database. // Try to open the specified database.
// Close and replace the previous database pointer if it succeeded. If it did // Close and replace the previous database pointer if it succeeded. If it did
// not, just leave things untouched // not, just leave things untouched
sqlite3* new_db; std::unique_ptr<DatabaseEngine> new_db;
auto res = sqlite3_open_v2(filename.data(), &new_db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); static const auto psql_prefix = "postgresql://"s;
Database::close(); if (filename.substr(0, psql_prefix.size()) == psql_prefix)
if (res != SQLITE_OK) new_db = PostgresqlEngine::open("dbname="s + filename.substr(psql_prefix.size()));
{ else
log_error("Failed to open database file ", filename, ": ", sqlite3_errmsg(new_db)); new_db = Sqlite3Engine::open(filename);
sqlite3_close(new_db); if (!new_db)
throw std::runtime_error(""); return;
} Database::db = std::move(new_db);
Database::db = new_db; Database::muc_log_lines.create(*Database::db);
Database::muc_log_lines.create(Database::db); Database::muc_log_lines.upgrade(*Database::db);
Database::muc_log_lines.upgrade(Database::db); Database::global_options.create(*Database::db);
Database::global_options.create(Database::db); Database::global_options.upgrade(*Database::db);
Database::global_options.upgrade(Database::db); Database::irc_server_options.create(*Database::db);
Database::irc_server_options.create(Database::db); Database::irc_server_options.upgrade(*Database::db);
Database::irc_server_options.upgrade(Database::db); Database::irc_channel_options.create(*Database::db);
Database::irc_channel_options.create(Database::db); Database::irc_channel_options.upgrade(*Database::db);
Database::irc_channel_options.upgrade(Database::db); Database::roster.create(*Database::db);
Database::roster.create(Database::db); Database::roster.upgrade(*Database::db);
Database::roster.upgrade(Database::db); create_index<Database::Owner, Database::IrcChanName, Database::IrcServerName>(*Database::db, "archive_index", Database::muc_log_lines.get_name());
create_index<Database::Owner, Database::IrcChanName, Database::IrcServerName>(Database::db, "archive_index", Database::muc_log_lines.get_name());
} }
...@@ -59,7 +63,7 @@ Database::GlobalOptions Database::get_global_options(const std::string& owner) ...@@ -59,7 +63,7 @@ Database::GlobalOptions Database::get_global_options(const std::string& owner)
request.where() << Owner{} << "=" << owner; request.where() << Owner{} << "=" << owner;
Database::GlobalOptions options{Database::global_options.get_name()}; Database::GlobalOptions options{Database::global_options.get_name()};
auto result = request.execute(Database::db); auto result = request.execute(*Database::db);
if (result.size() == 1) if (result.size() == 1)
options = result.front(); options = result.front();
else else
...@@ -73,7 +77,7 @@ Database::IrcServerOptions Database::get_irc_server_options(const std::string& o ...@@ -73,7 +77,7 @@ Database::IrcServerOptions Database::get_irc_server_options(const std::string& o
request.where() << Owner{} << "=" << owner << " and " << Server{} << "=" << server; request.where() << Owner{} << "=" << owner << " and " << Server{} << "=" << server;
Database::IrcServerOptions options{Database::irc_server_options.get_name()}; Database::IrcServerOptions options{Database::irc_server_options.get_name()};
auto result = request.execute(Database::db); auto result = request.execute(*Database::db);
if (result.size() == 1) if (result.size() == 1)
options = result.front(); options = result.front();
else else
...@@ -91,7 +95,7 @@ Database::IrcChannelOptions Database::get_irc_channel_options(const std::string& ...@@ -91,7 +95,7 @@ Database::IrcChannelOptions Database::get_irc_channel_options(const std::string&
" and " << Server{} << "=" << server <<\ " and " << Server{} << "=" << server <<\
" and " << Channel{} << "=" << channel; " and " << Channel{} << "=" << channel;
Database::IrcChannelOptions options{Database::irc_channel_options.get_name()}; Database::IrcChannelOptions options{Database::irc_channel_options.get_name()};
auto result = request.execute(Database::db); auto result = request.execute(*Database::db);
if (result.size() == 1) if (result.size() == 1)
options = result.front(); options = result.front();
else else
...@@ -186,7 +190,7 @@ std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owne ...@@ -186,7 +190,7 @@ std::vector<Database::MucLogLine> Database::get_muc_logs(const std::string& owne
if (limit >= 0) if (limit >= 0)
request.limit() << limit; request.limit() << limit;
auto result = request.execute(Database::db); auto result = request.execute(*Database::db);
return {result.crbegin(), result.crend()}; return {result.crbegin(), result.crend()};
} }
...@@ -207,7 +211,7 @@ void Database::delete_roster_item(const std::string& local, const std::string& r ...@@ -207,7 +211,7 @@ void Database::delete_roster_item(const std::string& local, const std::string& r
query << " WHERE " << Database::RemoteJid{} << "=" << remote << \ query << " WHERE " << Database::RemoteJid{} << "=" << remote << \
" AND " << Database::LocalJid{} << "=" << local; " AND " << Database::LocalJid{} << "=" << local;
query.execute(Database::db); // query.execute(*Database::db);
} }
bool Database::has_roster_item(const std::string& local, const std::string& remote) bool Database::has_roster_item(const std::string& local, const std::string& remote)
...@@ -216,7 +220,7 @@ bool Database::has_roster_item(const std::string& local, const std::string& remo ...@@ -216,7 +220,7 @@ bool Database::has_roster_item(const std::string& local, const std::string& remo
query.where() << Database::LocalJid{} << "=" << local << \ query.where() << Database::LocalJid{} << "=" << local << \
" and " << Database::RemoteJid{} << "=" << remote; " and " << Database::RemoteJid{} << "=" << remote;
auto res = query.execute(Database::db); auto res = query.execute(*Database::db);
return !res.empty(); return !res.empty();
} }
...@@ -226,20 +230,19 @@ std::vector<Database::RosterItem> Database::get_contact_list(const std::string& ...@@ -226,20 +230,19 @@ std::vector<Database::RosterItem> Database::get_contact_list(const std::string&
auto query = Database::roster.select(); auto query = Database::roster.select();
query.where() << Database::LocalJid{} << "=" << local; query.where() << Database::LocalJid{} << "=" << local;
return query.execute(Database::db); return query.execute(*Database::db);
} }
std::vector<Database::RosterItem> Database::get_full_roster() std::vector<Database::RosterItem> Database::get_full_roster()
{ {
auto query = Database::roster.select(); auto query = Database::roster.select();
return query.execute(Database::db); return query.execute(*Database::db);
} }
void Database::close() void Database::close()
{ {
sqlite3_close(Database::db); Database::db.release();
Database::db = nullptr;
} }
std::string Database::gen_uuid() std::string Database::gen_uuid()
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#include <database/column.hpp> #include <database/column.hpp>
#include <database/count_query.hpp> #include <database/count_query.hpp>
#include <database/engine.hpp>
#include <utils/optional_bool.hpp> #include <utils/optional_bool.hpp>
#include <chrono> #include <chrono>
...@@ -25,11 +27,11 @@ class Database ...@@ -25,11 +27,11 @@ class Database
struct Owner: Column<std::string> { static constexpr auto name = "owner_"; }; struct Owner: Column<std::string> { static constexpr auto name = "owner_"; };
struct IrcChanName: Column<std::string> { static constexpr auto name = "ircChanName_"; }; struct IrcChanName: Column<std::string> { static constexpr auto name = "ircchanname_"; };
struct Channel: Column<std::string> { static constexpr auto name = "channel_"; }; struct Channel: Column<std::string> { static constexpr auto name = "channel_"; };
struct IrcServerName: Column<std::string> { static constexpr auto name = "ircServerName_"; }; struct IrcServerName: Column<std::string> { static constexpr auto name = "ircservername_"; };
struct Server: Column<std::string> { static constexpr auto name = "server_"; }; struct Server: Column<std::string> { static constexpr auto name = "server_"; };
...@@ -44,30 +46,30 @@ class Database ...@@ -44,30 +46,30 @@ class Database
struct Ports: Column<std::string> { static constexpr auto name = "ports_"; struct Ports: Column<std::string> { static constexpr auto name = "ports_";
Ports(): Column<std::string>("6667") {} }; Ports(): Column<std::string>("6667") {} };
struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsPorts_"; struct TlsPorts: Column<std::string> { static constexpr auto name = "tlsports_";
TlsPorts(): Column<std::string>("6697;6670") {} }; TlsPorts(): Column<std::string>("6697;6670") {} };
struct Username: Column<std::string> { static constexpr auto name = "username_"; }; struct Username: Column<std::string> { static constexpr auto name = "username_"; };
struct Realname: Column<std::string> { static constexpr auto name = "realname_"; }; struct Realname: Column<std::string> { static constexpr auto name = "realname_"; };
struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterConnectionCommand_"; }; struct AfterConnectionCommand: Column<std::string> { static constexpr auto name = "afterconnectioncommand_"; };
struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedFingerprint_"; }; struct TrustedFingerprint: Column<std::string> { static constexpr auto name = "trustedfingerprint_"; };
struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingOut_"; }; struct EncodingOut: Column<std::string> { static constexpr auto name = "encodingout_"; };
struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingIn_"; }; struct EncodingIn: Column<std::string> { static constexpr auto name = "encodingin_"; };
struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxHistoryLength_"; struct MaxHistoryLength: Column<int> { static constexpr auto name = "maxhistorylength_";
MaxHistoryLength(): Column<int>(20) {} }; MaxHistoryLength(): Column<int>(20) {} };
struct RecordHistory: Column<bool> { static constexpr auto name = "recordHistory_"; struct RecordHistory: Column<bool> { static constexpr auto name = "recordhistory_";
RecordHistory(): Column<bool>(true) {}}; RecordHistory(): Column<bool>(true) {}};
struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordHistory_"; }; struct RecordHistoryOptional: Column<OptionalBool> { static constexpr auto name = "recordhistory_"; };
struct VerifyCert: Column<bool> { static constexpr auto name = "verifyCert_"; struct VerifyCert: Column<bool> { static constexpr auto name = "verifycert_";
VerifyCert(): Column<bool>(true) {} }; VerifyCert(): Column<bool>(true) {} };
struct Persistent: Column<bool> { static constexpr auto name = "persistent_"; struct Persistent: Column<bool> { static constexpr auto name = "persistent_";
...@@ -134,7 +136,7 @@ class Database ...@@ -134,7 +136,7 @@ class Database
static int64_t count(const TableType& table) static int64_t count(const TableType& table)
{ {
CountQuery query{table.get_name()}; CountQuery query{table.get_name()};
return query.execute(Database::db); return query.execute(*Database::db);
} }
static MucLogLineTable muc_log_lines; static MucLogLineTable muc_log_lines;
...@@ -142,7 +144,7 @@ class Database ...@@ -142,7 +144,7 @@ class Database
static IrcServerOptionsTable irc_server_options; static IrcServerOptionsTable irc_server_options;
static IrcChannelOptionsTable irc_channel_options; static IrcChannelOptionsTable irc_channel_options;
static RosterTable roster; static RosterTable roster;
static sqlite3* db; static std::unique_ptr<DatabaseEngine> db;
/** /**
* Some caches, to avoid doing very frequent query requests for a few options. * Some caches, to avoid doing very frequent query requests for a few options.
...@@ -177,6 +179,11 @@ class Database ...@@ -177,6 +179,11 @@ class Database
Database::encoding_in_cache.clear(); Database::encoding_in_cache.clear();
} }
static auto raw_exec(const std::string& query)
{
Database::db->raw_exec(query);
}
private: private:
static std::string gen_uuid(); static std::string gen_uuid();
static std::map<CacheKey, EncodingIn::real_type> encoding_in_cache; static std::map<CacheKey, EncodingIn::real_type> encoding_in_cache;
......
#pragma once
/**
* Interface to provide non-portable behaviour, specific to each
* database engine we want to support.
*
* Everything else (all portable stuf) should go outside of this class.
*/
#include <database/statement.hpp>
#include <memory>
#include <string>
#include <vector>
#include <tuple>
#include <set>
class DatabaseEngine
{
public:
DatabaseEngine() = default;
DatabaseEngine(const DatabaseEngine&) = delete;
DatabaseEngine& operator=(const DatabaseEngine&) = delete;
DatabaseEngine(DatabaseEngine&&) = delete;
DatabaseEngine& operator=(DatabaseEngine&&) = delete;
virtual std::set<std::string> get_all_columns_from_table(const std::string& table_name) = 0;
virtual std::tuple<bool, std::string> raw_exec(const std::string& query) = 0;
virtual std::unique_ptr<Statement> prepare(const std::string& query) = 0;
virtual void extract_last_insert_rowid(Statement& statement) = 0;
virtual std::string get_returning_id_sql_string(const std::string&)
{
return {};
}
virtual std::string id_column_type() = 0;
int64_t last_inserted_rowid{-1};
};
#pragma once #pragma once
#include <sqlite3.h> #include <database/engine.hpp>
#include <string> #include <string>
#include <tuple> #include <tuple>
...@@ -25,18 +25,14 @@ add_column_name(std::string& out) ...@@ -25,18 +25,14 @@ add_column_name(std::string& out)
} }
template <typename... Columns> template <typename... Columns>
void create_index(sqlite3* db, const std::string& name, const std::string& table) void create_index(DatabaseEngine& db, const std::string& name, const std::string& table)
{ {
std::string res{"CREATE INDEX IF NOT EXISTS "}; std::string query{"CREATE INDEX IF NOT EXISTS "};
res += name + " ON " + table + "("; query += name + " ON " + table + "(";
add_column_name<0, Columns...>(res); add_column_name<0, Columns...>(query);
res += ")"; query += ")";
char* error; auto result = db.raw_exec(query);
const auto result = sqlite3_exec(db, res.data(), nullptr, nullptr, &error); if (std::get<0>(result) == false)
if (result != SQLITE_OK) log_error("Error executing query: ", std::get<1>(result));
{
log_error("Error executing query: ", error);
sqlite3_free(error);
}
} }
...@@ -12,62 +12,63 @@ ...@@ -12,62 +12,63 @@
#include <sqlite3.h> #include <sqlite3.h>
template <int N, typename ColumnType, typename... T> template <std::size_t N=0, typename... T>
typename std::enable_if<!std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type typename std::enable_if<N < sizeof...(T), void>::type
actual_bind(Statement& statement, std::vector<std::string>& params, const std::tuple<T...>&) update_autoincrement_id(std::tuple<T...>& columns, Statement& statement)
{ {
const auto value = params.front(); using ColumnType = typename std::decay<decltype(std::get<N>(columns))>::type;
params.erase(params.begin()); if (std::is_same<ColumnType, Id>::value)
if (sqlite3_bind_text(statement.get(), N + 1, value.data(), static_cast<int>(value.size()), SQLITE_TRANSIENT) != SQLITE_OK)
log_error("Failed to bind ", value, " to param ", N);
}
template <int N, typename ColumnType, typename... T>
typename std::enable_if<std::is_same<std::decay_t<ColumnType>, Id>::value, void>::type