#include "response.h"

#include "auth.h"

#include "libreichwein/base64.h"
#include "libreichwein/file.h"
#include "libreichwein/mime.h"
#include "libreichwein/os.h"

#include <boost/algorithm/string/predicate.hpp>

#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

using namespace std::placeholders;
using namespace Reichwein::Mime;
using namespace Reichwein::OS;
using namespace Reichwein::Base64;

namespace {

// Per-request view on the incoming request together with the configured Path
// (plugin name + parameters) that the request's host and target resolve to.
class RequestContext
{
private:
 request_type& m_req;
 std::string m_host;
 std::string m_target;
 Server& m_server;
 // NOTE(review): bound to the result of Config::GetPath(); assumes that returns a
 // reference into the configuration that outlives this context — confirm.
 const Path& m_path;

public:
 // can throw std::out_of_range if no configured path matches host/target
 RequestContext(request_type& req, Server& server)
  : m_req(req)
  , m_host(req["host"])
  , m_target(req.target())
  , m_server(server)
  , m_path(server.GetConfig().GetPath(server.GetSocket(), m_host, m_target))
 {
 }

 // Matched path entry: plugin path (GetPluginPath) plus its configured params.
 // Invariant: GetTarget() == GetPluginPath() + "/" + GetRelativePath()
 const Path& GetPath() const { return m_path; }

 std::string GetPluginName() const { return m_path.params.at("plugin"); } // can throw std::out_of_range
 std::string GetPluginPath() const { return m_path.requested; }
 std::string GetDocRoot() const { return m_path.params.at("target"); } // can throw std::out_of_range

 // Part of the target below the plugin path, without the separating '/'.
 // can throw std::runtime_error if the target does not start with the plugin path
 std::string GetRelativePath() const
 {
  if (!boost::starts_with(m_target, m_path.requested))
   throw std::runtime_error("Mismatch of target ("s + m_target + ") and plugin path(" + m_path.requested + ")"s);

  const auto prefix_size{m_path.requested.size()};
  if (m_target.size() > prefix_size && m_target[prefix_size] == '/')
   return m_target.substr(prefix_size + 1);

  return m_target.substr(prefix_size);
 }

 std::string GetPluginParam(const std::string& key) const { return m_path.params.at(key); } // can throw std::out_of_range
 plugin_type GetPlugin() const { return m_server.GetPlugin(m_path.params.at("plugin")); } // can throw std::out_of_range

 request_type& GetReq() const { return m_req; }
 std::string GetTarget() const { return m_target; }
 std::string GetHost() const { return m_host; }
 Server& GetServer() const { return m_server; }
 const Socket& GetSocket() const { return m_server.GetSocket(); }

 // Validates HTTP Basic credentials against the matched path's auth table.
 // Returns an error message, or "" on success or when the path has no auth entries.
 std::string isAuthenticated() const
 {
  const auto& auth{m_path.auth};
  if (auth.empty())
   return "";

  std::string authorization{m_req[http::field::authorization]};
  // The auth-scheme token is case-insensitive (RFC 7235, section 2.1).
  if (!boost::istarts_with(authorization, "Basic "))
   return "Bad Authorization Type";

  authorization = authorization.substr(6);
  authorization = decode64(authorization);

  size_t pos{authorization.find(':')};
  if (pos == authorization.npos)
   return "Bad Authorization Encoding";

  std::string login{authorization.substr(0, pos)};
  std::string password{authorization.substr(pos + 1)};

  // it->second contains the crypted/hashed password;
  // password is plain text to validate against that hash
  auto it{auth.find(login)};
  if (it == auth.end() || !Auth::validate(it->second, password))
   return "Bad Authorization";

  return "";
 }
}; // class RequestContext

// Appends "index.html" to directory-like paths (empty or ending in '/') so a
// MIME type can be derived for directory index requests.
std::string extend_index_html(std::string path)
{
 if (path.empty() || path.back() == '/')
  path.append("index.html");
 return path;
}

// Heuristic: any address containing ':' is treated as IPv6.
bool is_ipv6_address(const std::string& addr)
{
 return addr.find(':') != addr.npos;
}

// Server parameters that plugins can query via GetServerParam(), keyed by name.
std::unordered_map<std::string, std::function<std::string(Server&)>> GetServerParamFunctions{
 // following are the supported fields:
 {"address", [](Server& server) { return server.GetSocket().address; }},
 {"ipv6", [](Server& server) { return is_ipv6_address(server.GetSocket().address) ? "yes" : "no"; }},
 {"port", [](Server& server) { return server.GetSocket().port; }},
 {"statistics", [](Server& server) { return server.GetStatistics().getValues(); }},
 {"uptime_host", [](Server&) { return uptime_host(); }},
 {"uptime_webserver", [](Server&) { return uptime_process(); }},
 {"version", [](Server&) { return Server::VersionString; }},
};

// Returns the server parameter named key; throws std::runtime_error for unknown keys.
std::string GetServerParam(const std::string& key, Server& server)
{
 auto it = GetServerParamFunctions.find(key);
 if (it != GetServerParamFunctions.end())
  return it->second(server);

 throw std::runtime_error("Unsupported server param: "s + key);
}

// Computed request parameters that plugins can query via GetRequestParam(), keyed by name.
std::unordered_map<std::string, std::function<std::string(RequestContext&)>> GetRequestParamFunctions{
 // following are the supported fields:
 {"authorization", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq()[http::field::authorization]}; }},
 {"is_authenticated", [](RequestContext& req_ctx) { return req_ctx.isAuthenticated() == ""s ? "1"s : "0"s; }},
 {"body", [](RequestContext& req_ctx) { return req_ctx.GetReq().body(); }},
 {"content_length", [](RequestContext& req_ctx) { return std::to_string(req_ctx.GetReq().body().size()); }},
 {"content_type", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq()[http::field::content_type]}; }},
 {"doc_root", [](RequestContext& req_ctx) { return req_ctx.GetDocRoot(); }},
 {"host", [](RequestContext& req_ctx) { return req_ctx.GetHost(); }},
 {"http_accept", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq()[http::field::accept]}; }},
 {"http_accept_charset", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq()[http::field::accept_charset]}; }},
 {"http_accept_encoding", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq()[http::field::accept_encoding]}; }},
 {"http_accept_language", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq()[http::field::accept_language]}; }},
 {"http_connection", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq()[http::field::connection]}; }},
 {"http_host", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq()[http::field::host]}; }},
 {"http_user_agent", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq()[http::field::user_agent]}; }},
 {"http_version", [](RequestContext& req_ctx) {
   // Beast encodes the version as major * 10 + minor (e.g. 11 for HTTP/1.1)
   unsigned version{req_ctx.GetReq().version()};
   unsigned major{version / 10};
   unsigned minor{version % 10};
   return "HTTP/"s + std::to_string(major) + "."s + std::to_string(minor);
  }},
 {"https", [](RequestContext& req_ctx) { return req_ctx.GetSocket().protocol == SocketProtocol::HTTPS ? "on" : "off"; }},
 {"location", [](RequestContext& req_ctx) { return req_ctx.GetTarget(); }},
 {"method", [](RequestContext& req_ctx) { return std::string{req_ctx.GetReq().method_string()}; }},
 {"plugin_path", [](RequestContext& req_ctx) { return req_ctx.GetPluginPath(); }},
 {"rel_target", [](RequestContext& req_ctx) { return req_ctx.GetRelativePath(); }},
 {"target", [](RequestContext& req_ctx) { return req_ctx.GetTarget(); }}, // target == plugin_path / rel_target
};

// Resolves a request parameter for plugins, in this order:
//  1. computed values from GetRequestParamFunctions
//  2. configured plugin parameters of the matched path
//  3. raw request header fields (e.g. "host")
std::string GetRequestParam(const std::string& key, RequestContext& req_ctx)
{
 auto it = GetRequestParamFunctions.find(key);
 if (it != GetRequestParamFunctions.end())
  return it->second(req_ctx);

 try {
  return req_ctx.GetPluginParam(key);
 } catch (const std::out_of_range&) {
  // not a plugin parameter: fall through to request header fields
 }

 // Beast's fields::operator[] returns an empty value for absent fields instead of
 // throwing, so keys that are supported nowhere also end up here and yield "".
 return std::string{req_ctx.GetReq()[key]};
}

// Response header setters that plugins can invoke via SetResponseHeader(), keyed by name.
std::unordered_map<std::string, std::function<void(const std::string&, response_type&)>> SetResponseHeaderActions{
 { "cache_control", [](const std::string& value, response_type& res){ res.set(http::field::cache_control, value); } },
 { "content_disposition", [](const std::string& value, response_type& res){ res.set(http::field::content_disposition, value); } }, // e.g. attachment; ...
 { "content_type", [](const std::string& value, response_type& res){ res.set(http::field::content_type, value); } }, // e.g. text/html
 { "location", [](const std::string& value, response_type& res){ res.set(http::field::location, value); } }, // e.g. 301 Moved Permanently: new Location
 { "server", [](const std::string& value, response_type& res){ res.set(http::field::server, value); } }, // Server name/version string
 { "set_cookie", [](const std::string& value, response_type& res){ res.set(http::field::set_cookie, value); } },
 { "status", [](const std::string& value, response_type& res){
   try {
    res.result(unsigned(stoul(value)));
   } catch (...) {
    std::cerr << "Error: Bad status value: " << value << std::endl;
    res.result(400);
   }
  } }, // HTTP Status, e.g. "200" (OK)
 { "www_authenticate", [](const std::string& value, response_type& res){ res.set(http::field::www_authenticate, value); } },
};

// Applies the response header named key; throws std::runtime_error for unsupported keys.
void SetResponseHeader(const std::string& key, const std::string& value, response_type& res)
{
 auto it{SetResponseHeaderActions.find(key)};
 if (it == SetResponseHeaderActions.end())
  throw std::runtime_error("Unsupported response field: "s + key);

 it->second(value, res);
}

// Escapes characters with special meaning in HTML so request-derived text
// (host, target, exception messages) is shown literally in error pages.
std::string html_escape(const std::string& text)
{
 std::string result;
 result.reserve(text.size());
 for (const char c : text) {
  switch (c) {
   case '&': result += "&amp;"; break;
   case '<': result += "&lt;"; break;
   case '>': result += "&gt;"; break;
   case '"': result += "&quot;"; break;
   case '\'': result += "&#39;"; break;
   default: result += c; break;
  }
 }
 return result;
}

// Sets the HTTP status of res. For any status other than "200" the body is replaced
// by a small HTML error page; status values that cannot be parsed or applied fall
// back to 400. "200" leaves res untouched (already set at res init).
void SetHttpStatus(const std::string& status, const std::string& message, response_type& res)
{
 if (status == "200")
  return;

 try {
  res.result(unsigned(stoul(status)));
 } catch (...) {
  std::cerr << "Error: HttpStatus: Bad status value: " << status << std::endl;
  res.result(400);
 }
 res.set(http::field::content_type, "text/html");
 // message may contain untrusted request data: escape it to prevent markup injection
 res.body() = "<html><body><h1>"s + Server::VersionString + " Error</h1><p>"s + html_escape(status) + " "s + html_escape(message) + "</p></body></html>"s;
 res.prepare_payload();
}

// Applies SetHttpStatus() and returns the response, moving it out of res.
response_type HttpStatus(std::string status, std::string message, response_type& res)
{
 SetHttpStatus(status, message, res);
 return std::move(res);
}

// Used to return responses (errors or the final 200) by applying the status /
// error page and recording the request in the server statistics.
// Moves the response out of res.
response_type HttpStatusAndStats(std::string status, std::string message, RequestContext& req_ctx, response_type& res)
{
 // mutate in place: avoids copying the (possibly large) body just to discard it
 SetHttpStatus(status, message, res);

 req_ctx.GetServer().GetStatistics().count(
  req_ctx.GetReq().body().size(),
  res.body().size(),
  res.result_int() != 200,
  is_ipv6_address(req_ctx.GetSocket().address),
  req_ctx.GetSocket().protocol == SocketProtocol::HTTPS);

 return std::move(res);
}

// Enforces the path's Basic authentication unless the plugin handles auth itself.
// Returns a 401 response (with WWW-Authenticate) on failure, otherwise res unchanged.
// Moves the response out of res in both cases.
response_type handleAuth(RequestContext& req_ctx, response_type& res)
{
 if (!req_ctx.GetPlugin()->has_own_authentication()) {
  std::string message{req_ctx.isAuthenticated()};
  if (!message.empty()) {
   res.set(http::field::www_authenticate, "Basic realm=\"Reichwein.IT Webserver Login\"");
   return HttpStatusAndStats("401", message, req_ctx, res);
  }
 }
 return std::move(res);
}

} // anonymous namespace

// Builds the HTTP response for req by dispatching to the plugin configured for the
// request's host/target. Configuration and plugin errors result in a 400 page.
response_type response::generate_response(request_type& req, Server& server)
{
 response_type res{http::status::ok, req.version()};
 res.set(http::field::server, Server::VersionString);
 res.set(http::field::content_type, mime_type(extend_index_html(std::string(req.target()))));
 res.keep_alive(req.keep_alive());

 try {
  RequestContext req_ctx{req, server}; // can throw std::out_of_range

  res = handleAuth(req_ctx, res);
  if (res.result_int() / 100 == 4) // status 4xx
   return res;

  plugin_type plugin{req_ctx.GetPlugin()};

  // named objects: generate_page() receives these callbacks by reference
  std::function<std::string(const std::string&)> GetServerParamFunction{
   [&server](const std::string& key) { return GetServerParam(key, server); }};
  std::function<std::string(const std::string&)> GetRequestParamFunction{
   [&req_ctx](const std::string& key) { return GetRequestParam(key, req_ctx); }};
  std::function<void(const std::string&, const std::string&)> SetResponseHeaderFunction{
   [&res](const std::string& key, const std::string& value) { SetResponseHeader(key, value, res); }};

  std::string res_data{plugin->generate_page(GetServerParamFunction, GetRequestParamFunction, SetResponseHeaderFunction)};
  if (req.method() == http::verb::head)
   res.body().clear(); // HEAD: headers only
  else
   res.body() = std::move(res_data);
  res.prepare_payload();

  return HttpStatusAndStats("200", "OK", req_ctx, res);
 } catch (const std::out_of_range&) {
  return HttpStatus("400", "Bad request: "s + std::string{req["host"]} + ":"s + std::string{req.target()} + " unknown"s, res);
 } catch (const std::exception& ex) {
  return HttpStatus("400", "Bad request: "s + std::string{ex.what()}, res);
 }
}

// Returns "<doc root>/<relative path>" for a request whose matched path uses the
// "websocket" plugin; logs and returns "" if the plugin differs or nothing matches.
std::string response::get_websocket_address(request_type& req, Server& server)
{
 try {
  RequestContext req_ctx{req, server}; // can throw std::out_of_range

  if (req_ctx.GetPluginName() != "websocket") {
   std::cout << "Bad plugin configured for websocket request: " << req_ctx.GetPluginName() << std::endl;
   return {};
  }

  // For websocket paths the configured "target" param is host:port; append relative_path
  return req_ctx.GetDocRoot() + "/" + req_ctx.GetRelativePath();
 } catch (const std::exception& ex) {
  std::cout << "No matching configured target websocket found: " << ex.what() << std::endl;
  return {};
 }
}