From 45482c1e7f8b7dd7d55d2581469a65d9cd02b754 Mon Sep 17 00:00:00 2001 From: Roland Reichwein Date: Sat, 4 Feb 2023 19:02:53 +0100 Subject: Async session --- tests/test-whiteboard.cpp | 8 +- whiteboard.cpp | 519 +++++++++++++++++++++++++++------------------- whiteboard.h | 11 +- 3 files changed, 318 insertions(+), 220 deletions(-) diff --git a/tests/test-whiteboard.cpp b/tests/test-whiteboard.cpp index bed2e5a..c55c87c 100644 --- a/tests/test-whiteboard.cpp +++ b/tests/test-whiteboard.cpp @@ -160,10 +160,10 @@ public: } private: - boost::asio::io_context ioc_; - boost::asio::ip::tcp::resolver::results_type resolver_results_; - std::unique_ptr> ws_; - boost::asio::ip::tcp::endpoint ep_; + boost::asio::io_context ioc_; + boost::asio::ip::tcp::resolver::results_type resolver_results_; + std::unique_ptr> ws_; + boost::asio::ip::tcp::endpoint ep_; }; // diff --git a/whiteboard.cpp b/whiteboard.cpp index ce196f4..ea56cc3 100644 --- a/whiteboard.cpp +++ b/whiteboard.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -31,6 +32,8 @@ #include #include #include +#include +#include #include @@ -45,6 +48,7 @@ namespace pt = boost::property_tree; using namespace std::string_literals; +using namespace std::placeholders; namespace fs = std::filesystem; namespace { @@ -69,20 +73,6 @@ Whiteboard::Whiteboard() { } -// contents of cleanup thread; looping -void Whiteboard::storage_cleanup() -{ - while(true) { - { - std::lock_guard lock(m_storage_mutex); - if (!m_storage) - throw std::runtime_error("Storage not initialized"); - m_storage->cleanup(); - } - std::this_thread::sleep_for(std::chrono::hours(24)); - } -} - pt::ptree make_ptree(const std::initializer_list>& key_values) { pt::ptree ptree; @@ -99,205 +89,291 @@ std::string make_xml(const std::initializer_listgetRevision(id))}, - {"pos", std::to_string(m_storage->getCursorPos(id)) } - })}; - ptree.put_child("serverinfo.diff", diff.get_structure().get_child("diff")); - boost::beast::ostream(buffer) << Reichwein::XML::plain_xml(ptree); - std::lock_guard lock(m_websocket_mutex); - try { - ci->write(buffer.data()); - } catch (const std::exception& ex) { - std::cerr << "Warning: Notify getdiff write for " << ci << " not possible, id " << id << std::endl; - m_registry.dump(); - } - } - }); -} -void Whiteboard::notify_other_connections_pos(Whiteboard::connection& c, const std::string& id) +class session: public std::enable_shared_from_this { - std::for_each(m_registry.begin(id), m_registry.end(id), [&](const Whiteboard::connection& ci) - { - if (c != ci) { - boost::beast::flat_buffer buffer; - boost::beast::ostream(buffer) << make_xml({ - {"type", "getpos"}, - {"pos", std::to_string(m_storage->getCursorPos(id)) } - }); - std::lock_guard lock(m_websocket_mutex); - try { - ci->write(buffer.data()); - } catch (const std::exception& ex) { - std::cerr << "Warning: Notify getpos write for " << ci << " not possible, id " << id << std::endl; - m_registry.dump(); - } - } - }); -} - -std::string Whiteboard::handle_request(Whiteboard::connection& c, const std::string& request) -{ - try { - std::lock_guard lock(m_storage_mutex); - if (!m_storage) - throw std::runtime_error("Storage not initialized"); - - pt::ptree xml; - std::istringstream ss{request}; - pt::xml_parser::read_xml(ss, xml); - - std::string command {xml.get("request.command")}; - - if (command == "modify") { - std::string id {xml.get("request.id")}; - - int baserev {xml.get("request.baserev")}; - if (baserev != m_storage->getRevision(id)) - return make_xml({{"type", "error"}, {"message", "Bad base revision ("s + std::to_string(baserev) + "). Current: "s + std::to_string(m_storage->getRevision(id)) }}); - - pt::ptree ptree; - ptree.put_child("diff", xml.get_child("request.diff")); - Diff d{ptree}; - if (!d.empty()) { - std::string data {m_storage->getDocument(id)}; - data = d.apply(data); - - m_storage->setDocument(id, data); - m_registry.setId(c, id); - notify_other_connections_diff(c, id, d); - } - - int pos {xml.get("request.pos")}; - if (m_storage->getCursorPos(id) != pos) { - m_storage->setCursorPos(id, pos); - notify_other_connections_pos(c, id); - } - return make_xml({{"type", "modify"}, {"revision", std::to_string(m_storage->getRevision(id)) }}); - } else if (command == "cursorpos") { - std::string id {xml.get("request.id")}; - int pos {xml.get("request.pos")}; - if (m_storage->getCursorPos(id) != pos) { - m_storage->setCursorPos(id, pos); - notify_other_connections_pos(c, id); - } - return {}; - } else if (command == "getfile") { - std::string id {xml.get("request.id")}; - - std::string filedata; - try { - filedata = m_storage->getDocument(id); - } catch (const std::runtime_error&) { - m_storage->setDocument(id, filedata); - } - - if (filedata.size() > 30000000) - throw std::runtime_error("File too big"); - m_registry.setId(c, id); - - return make_xml({ - {"type", "getfile"}, - {"data", filedata}, - {"revision", std::to_string(m_storage->getRevision(id)) }, - {"pos", std::to_string(m_storage->getCursorPos(id)) } - }); - } else if (command == "getpos") { - std::string id {xml.get("request.id")}; - - return make_xml({ - {"type", "getpos"}, - {"pos", std::to_string(m_storage->getCursorPos(id)) } - }); - } else if (command == "newid") { - return make_xml({{"type", "newid"}, {"id", m_storage->generate_id()}}); - } else if (command == "qrcode") { - std::string url{xml.get("request.url")}; - - if (url.size() > 1000) - throw std::runtime_error("URL too big"); - - std::string pngdata {QRCode::getQRCode(url)}; - - return make_xml({{"type", "qrcode"}, {"png", Reichwein::Base64::encode64(pngdata)}}); - } else if (command == "getversion") { - return make_xml({ - {"type", "version"}, - {"version", WHITEBOARD_VERSION } - }); - } else if (command == "getstats") { - return make_xml({ - {"type", "stats" }, - {"dbsizegross", std::to_string(m_storage->dbsize_gross()) }, - {"dbsizenet", std::to_string(m_storage->dbsize_net()) }, - {"numberofdocuments", std::to_string(m_storage->getNumberOfDocuments()) }, - {"numberofconnections", std::to_string(m_registry.number_of_connections()) }, - }); - } else { - throw std::runtime_error("Bad command: "s + command); - } - - } catch (const std::exception& ex) { - return make_xml({{"type", "error"}, {"message", "Message handling error: "s + ex.what()}}); +public: + using connection = std::shared_ptr>; + + session(ConnectionRegistry& registry, Storage& storage, std::mutex& storage_mutex, std::mutex& websocket_mutex, boost::asio::ip::tcp::socket socket): + m_registry(registry), + m_storage(storage), + m_storage_mutex(storage_mutex), + m_websocket_mutex(websocket_mutex), + m_ws(std::make_shared(std::move(socket))), + m_connection_guard(m_registry, m_ws) + { } -} - -void Whiteboard::do_session(boost::asio::ip::tcp::socket socket) -{ - try { - // Construct the stream by moving in the socket - std::shared_ptr ws{std::make_shared>(std::move(socket))}; - ConnectionRegistry::RegistryGuard guard(m_registry, ws); + void do_read_handshake() + { // Set a decorator to change the Server of the handshake - ws->set_option(boost::beast::websocket::stream_base::decorator( + m_ws->set_option(boost::beast::websocket::stream_base::decorator( [](boost::beast::websocket::response_type& res) { res.set(boost::beast::http::field::server, std::string("Reichwein.IT Whiteboard")); })); - boost::beast::http::request_parser parser; - boost::beast::http::request req; - boost::beast::flat_buffer buffer; + boost::beast::http::async_read(m_ws->next_layer(), m_buffer, m_parser, boost::beast::bind_front_handler(&session::on_read_handshake, shared_from_this())); + } - boost::beast::http::read(ws->next_layer(), buffer, parser); - req = parser.get(); + void on_read_handshake(boost::beast::error_code ec, std::size_t bytes_transferred) + { + boost::ignore_unused(bytes_transferred); + if (ec) { + std::cerr << "Error on session handshake read: " << ec.message() << std::endl; + } else { + do_accept_handshake(); + } + } - ws->accept(req); + void do_accept_handshake() + { + m_req = m_parser.get(); - while (true) { - boost::beast::flat_buffer buffer; + m_ws->async_accept(m_req, boost::beast::bind_front_handler(&session::on_accept_handshake, shared_from_this())); + } - ws->read(buffer); + void on_accept_handshake(boost::beast::error_code ec) + { + if (ec) { + std::cerr << "Error on session handshake accept: " << ec.message() << std::endl; + } else { + do_read(); + } + } - ws->text(ws->got_text()); - std::string data(boost::asio::buffers_begin(buffer.data()), boost::asio::buffers_end(buffer.data())); - data = handle_request(ws, data); - if (buffer.data().size() > 0) { - buffer.consume(buffer.size()); - } - if (data.size() > 0) { - boost::beast::ostream(buffer) << data; - std::lock_guard lock(m_websocket_mutex); - ws->write(buffer.data()); + void do_read() + { + if (m_buffer.size() > 0) { + m_buffer.consume(m_buffer.size()); + } + + m_ws->async_read(m_buffer, boost::beast::bind_front_handler(&session::on_read, shared_from_this())); + } + + void on_read(boost::beast::error_code ec, std::size_t bytes_transferred) + { + boost::ignore_unused(bytes_transferred); + if (ec) { + std::cerr << "Error on session read: " << ec.message() << std::endl; + } else { + do_write(); + } + } + + void do_write() + { + m_ws->text(m_ws->got_text()); + std::string data(boost::asio::buffers_begin(m_buffer.data()), boost::asio::buffers_end(m_buffer.data())); + data = handle_request(data); + if (m_buffer.size() > 0) { + m_buffer.consume(m_buffer.size()); + } + if (data.size() > 0) { + boost::beast::ostream(m_buffer) << data; + std::lock_guard lock(m_websocket_mutex); + m_ws->async_write(m_buffer.data(), boost::beast::bind_front_handler(&session::on_write, shared_from_this())); + } else { + do_read(); + } + } + + void on_write(boost::beast::error_code ec, std::size_t bytes_transferred) + { + boost::ignore_unused(bytes_transferred); + + if (ec) { + std::cerr << "Error on session write: " << ec.message() << std::endl; + } else { + do_read(); + } + } + + void run() + { + do_read_handshake(); + } + + void notify_other_connections_diff(const std::string& id, const Diff& diff) + { + std::for_each(m_registry.begin(id), m_registry.end(id), [&](const connection& ci) + { + if (m_ws != ci) { + boost::beast::flat_buffer buffer; + pt::ptree ptree {make_ptree({ + {"type", "getdiff"}, + {"revision", std::to_string(m_storage.getRevision(id))}, + {"pos", std::to_string(m_storage.getCursorPos(id)) } + })}; + ptree.put_child("serverinfo.diff", diff.get_structure().get_child("diff")); + boost::beast::ostream(buffer) << Reichwein::XML::plain_xml(ptree); + std::lock_guard lock(m_websocket_mutex); + try { + ci->write(buffer.data()); + } catch (const std::exception& ex) { + std::cerr << "Warning: Notify getdiff write for " << ci << " not possible, id " << id << std::endl; + m_registry.dump(); + } + } + }); + } + + void notify_other_connections_pos(const std::string& id) + { + std::for_each(m_registry.begin(id), m_registry.end(id), [&](const connection& ci) + { + if (m_ws != ci) { + boost::beast::flat_buffer buffer; + boost::beast::ostream(buffer) << make_xml({ + {"type", "getpos"}, + {"pos", std::to_string(m_storage.getCursorPos(id)) } + }); + std::lock_guard lock(m_websocket_mutex); + try { + ci->write(buffer.data()); + } catch (const std::exception& ex) { + std::cerr << "Warning: Notify getpos write for " << ci << " not possible, id " << id << std::endl; + m_registry.dump(); + } + } + }); + } + + std::string handle_request(const std::string& request) + { + try { + std::lock_guard lock(m_storage_mutex); + + pt::ptree xml; + std::istringstream ss{request}; + pt::xml_parser::read_xml(ss, xml); + + std::string command {xml.get("request.command")}; + + if (command == "modify") { + std::string id {xml.get("request.id")}; + + int baserev {xml.get("request.baserev")}; + if (baserev != m_storage.getRevision(id)) + return make_xml({{"type", "error"}, {"message", "Bad base revision ("s + std::to_string(baserev) + "). Current: "s + std::to_string(m_storage.getRevision(id)) }}); + + pt::ptree ptree; + ptree.put_child("diff", xml.get_child("request.diff")); + Diff d{ptree}; + if (!d.empty()) { + std::string data {m_storage.getDocument(id)}; + data = d.apply(data); + + m_storage.setDocument(id, data); + m_registry.setId(m_ws, id); + notify_other_connections_diff(id, d); + } + + int pos {xml.get("request.pos")}; + if (m_storage.getCursorPos(id) != pos) { + m_storage.setCursorPos(id, pos); + notify_other_connections_pos(id); + } + return make_xml({{"type", "modify"}, {"revision", std::to_string(m_storage.getRevision(id)) }}); + } else if (command == "cursorpos") { + std::string id {xml.get("request.id")}; + int pos {xml.get("request.pos")}; + if (m_storage.getCursorPos(id) != pos) { + m_storage.setCursorPos(id, pos); + notify_other_connections_pos(id); + } + return {}; + } else if (command == "getfile") { + std::string id {xml.get("request.id")}; + + std::string filedata; + try { + filedata = m_storage.getDocument(id); + } catch (const std::runtime_error&) { + m_storage.setDocument(id, filedata); + } + + if (filedata.size() > 30000000) + throw std::runtime_error("File too big"); + m_registry.setId(m_ws, id); + + return make_xml({ + {"type", "getfile"}, + {"data", filedata}, + {"revision", std::to_string(m_storage.getRevision(id)) }, + {"pos", std::to_string(m_storage.getCursorPos(id)) } + }); + } else if (command == "getpos") { + std::string id {xml.get("request.id")}; + + return make_xml({ + {"type", "getpos"}, + {"pos", std::to_string(m_storage.getCursorPos(id)) } + }); + } else if (command == "newid") { + return make_xml({{"type", "newid"}, {"id", m_storage.generate_id()}}); + } else if (command == "qrcode") { + std::string url{xml.get("request.url")}; + + if (url.size() > 1000) + throw std::runtime_error("URL too big"); + + std::string pngdata {QRCode::getQRCode(url)}; + + return make_xml({{"type", "qrcode"}, {"png", Reichwein::Base64::encode64(pngdata)}}); + } else if (command == "getversion") { + return make_xml({ + {"type", "version"}, + {"version", WHITEBOARD_VERSION } + }); + } else if (command == "getstats") { + return make_xml({ + {"type", "stats" }, + {"dbsizegross", std::to_string(m_storage.dbsize_gross()) }, + {"dbsizenet", std::to_string(m_storage.dbsize_net()) }, + {"numberofdocuments", std::to_string(m_storage.getNumberOfDocuments()) }, + {"numberofconnections", std::to_string(m_registry.number_of_connections()) }, + }); + } else { + throw std::runtime_error("Bad command: "s + command); } + + } catch (const std::exception& ex) { + return make_xml({{"type", "error"}, {"message", "Message handling error: "s + ex.what()}}); } - } catch (boost::beast::system_error const& se) { - // This indicates that the session was closed - if (se.code() != boost::beast::websocket::error::closed && se.code() != boost::asio::error::eof) - std::cerr << "Boost system_error in session: " << se.code().message() << std::endl; - } catch (std::exception const& ex) { - std::cerr << "Error in session: " << ex.what() << std::endl; } +private: + ConnectionRegistry& m_registry; + Storage& m_storage; + std::mutex& m_storage_mutex; + std::mutex& m_websocket_mutex; + connection m_ws; + ConnectionRegistry::RegistryGuard m_connection_guard; + + boost::beast::http::request_parser m_parser; + boost::beast::http::request m_req; + boost::beast::flat_buffer m_buffer; + +}; + +void Whiteboard::do_accept() +{ + // The new connection gets its own strand + m_acceptor->async_accept(boost::asio::make_strand(*m_ioc), + std::bind(&Whiteboard::on_accept, this, _1, _2)); +} + +void Whiteboard::on_accept(boost::system::error_code ec, boost::asio::ip::tcp::socket socket) +{ + if (ec) { + std::cerr << "Error on accept: " << ec.message() << std::endl; + } else { + std::make_shared(m_registry, *m_storage, m_storage_mutex, m_websocket_mutex, std::move(socket))->run(); + } + + do_accept(); } // the actual main() for testability @@ -332,32 +408,55 @@ int Whiteboard::run(int argc, char* argv[]) exit(0); } - std::thread storage_cleanup_thread(std::bind(&Whiteboard::storage_cleanup, this)); - QRCode::init(); auto const address = boost::asio::ip::make_address(m_config->getListenAddress()); auto const port = static_cast(m_config->getListenPort()); // The io_context is required for all I/O - boost::asio::io_context ioc{m_config->getThreads()}; + m_ioc = std::make_unique(m_config->getThreads()); + + // for now, just terminate on SIGINT, SIGHUP and SIGTERM + boost::asio::signal_set signals(*m_ioc, SIGINT, SIGTERM, SIGHUP); + signals.async_wait([&](const boost::system::error_code& error, int signal_number){ + std::cout << "Terminating via signal " << signal_number << std::endl; + m_ioc->stop(); + }); + + // Storage cleanup once a day + boost::asio::steady_timer storage_cleanup_timer(*m_ioc, boost::asio::chrono::hours(24)); + std::function storage_cleanup_callback = + [&](const boost::system::error_code& error){ + std::lock_guard lock(m_storage_mutex); + if (!m_storage) + throw std::runtime_error("Storage not initialized"); + m_storage->cleanup(); + storage_cleanup_timer.expires_at(storage_cleanup_timer.expires_at() + boost::asio::chrono::hours(24)); + storage_cleanup_timer.async_wait(storage_cleanup_callback); + }; + storage_cleanup_timer.async_wait(storage_cleanup_callback); // The acceptor receives incoming connections - boost::asio::ip::tcp::acceptor acceptor{ioc, {address, port}}; - while (true) { - // This will receive the new connection - boost::asio::ip::tcp::socket socket{ioc}; - - // Block until we get a connection - acceptor.accept(socket); - - // Launch the session, transferring ownership of the socket - std::thread( - &Whiteboard::do_session, this, - std::move(socket)).detach(); + m_acceptor = std::make_unique(*m_ioc, boost::asio::ip::tcp::endpoint{address, port}); + + do_accept(); + + // Run the I/O service on the requested number of threads + std::vector v; + v.reserve(m_config->getThreads() - 1); + for (auto i = m_config->getThreads() - 1; i > 0; --i) { + v.emplace_back( + [&] + { + m_ioc->run(); + }); + } + m_ioc->run(); + + for (auto& t: v) { + t.join(); } - storage_cleanup_thread.join(); } catch (const std::exception& ex) { std::cerr << "Error: " << ex.what() << std::endl; } diff --git a/whiteboard.h b/whiteboard.h index c000e36..15d764a 100644 --- a/whiteboard.h +++ b/whiteboard.h @@ -26,11 +26,10 @@ private: ConnectionRegistry m_registry; - using connection = std::shared_ptr>; - std::string handle_request(connection& c, const std::string& request); - void notify_other_connections_diff(connection& c, const std::string& id, const Diff& diff); // notify all other id-related connections about changes - void notify_other_connections_pos(connection& c, const std::string& id); // notify all other id-related connections about changes - void do_session(boost::asio::ip::tcp::socket socket); - void storage_cleanup(); + std::unique_ptr m_ioc; + std::unique_ptr m_acceptor; + + void do_accept(); + void on_accept(boost::system::error_code ec, boost::asio::ip::tcp::socket socket); }; -- cgit v1.2.3