From 8ad2512a267f6f2d22df2209e964f3422c20437e Mon Sep 17 00:00:00 2001 From: cruise Date: Tue, 1 Jul 2025 10:09:25 +0800 Subject: [PATCH 1/6] clean up the code --- Socket/SecureSocket.cpp | 415 +++++++++++++--------- Socket/SecureSocket.h | 227 +++++------- Socket/Socket.cpp | 505 ++++++++++++++++++--------- Socket/Socket.h | 202 +++++++---- Socket/TCPClient.cpp | 614 +++++++++++++-------------------- Socket/TCPClient.h | 97 +++--- Socket/TCPSSLClient.cpp | 430 ++++++++++++----------- Socket/TCPSSLClient.h | 63 ++-- Socket/TCPSSLServer.cpp | 422 ++++++++++++----------- Socket/TCPSSLServer.h | 50 ++- Socket/TCPServer.cpp | 743 +++++++++++++++++----------------------- Socket/TCPServer.h | 88 ++--- 12 files changed, 1971 insertions(+), 1885 deletions(-) diff --git a/Socket/SecureSocket.cpp b/Socket/SecureSocket.cpp index 3821400..a380cbf 100644 --- a/Socket/SecureSocket.cpp +++ b/Socket/SecureSocket.cpp @@ -1,221 +1,318 @@ /** -* @file SecureSocket.cpp -* @brief implementation of the Secure Socket class -* @author Mohamed Amine Mzoughi -*/ + * @file SecureSocket.cpp + * @brief implementation of the Secure Socket class + * @author Mohamed Amine Mzoughi + */ #ifdef OPENSSL - #include "SecureSocket.h" #include -#ifndef LINUX +#ifdef _WIN32 // to avoid link problems in prod/test program // Update : with the newer versions of OpenSSL, there's no need to include it //#include #endif -ASecureSocket::SecureSocketGlobalInitializer& ASecureSocket::SecureSocketGlobalInitializer::instance() +ASecureSocket::SSLSocket::SSLSocket() + : m_SockFd(INVALID_SOCKET) + , m_pSSL(nullptr) + , m_pCTXSSL(nullptr) + , m_pMTHDSSL(nullptr) { - static SecureSocketGlobalInitializer inst{}; - return inst; } -ASecureSocket::SecureSocketGlobalInitializer::SecureSocketGlobalInitializer() +ASecureSocket::SSLSocket::SSLSocket(SSLSocket&& Sockother) + : m_SockFd(Sockother.m_SockFd) + , m_pSSL(Sockother.m_pSSL) + , m_pCTXSSL(Sockother.m_pCTXSSL) + , m_pMTHDSSL(Sockother.m_pMTHDSSL) { - InitializeSSL(); + Sockother.m_SockFd = INVALID_SOCKET; + Sockother.m_pSSL = nullptr; + Sockother.m_pCTXSSL = nullptr; + Sockother.m_pMTHDSSL = nullptr; } -ASecureSocket::SecureSocketGlobalInitializer::~SecureSocketGlobalInitializer() +ASecureSocket::SSLSocket& ASecureSocket::SSLSocket::operator=(SSLSocket&& Sockother) { - DestroySSL(); + if (this != &Sockother) + { + m_SockFd = Sockother.m_SockFd; + m_pSSL = Sockother.m_pSSL; + m_pCTXSSL = Sockother.m_pCTXSSL; + m_pMTHDSSL = Sockother.m_pMTHDSSL; + + // reset Sockother + Sockother.m_SockFd = INVALID_SOCKET; + Sockother.m_pSSL = nullptr; + Sockother.m_pCTXSSL = nullptr; + Sockother.m_pMTHDSSL = nullptr; + } + return *this; } -/** -* @brief constructor of the Secure Socket -* -* @param oLogger - a callabck to a logger function void(const std::string&) -* @param eSSLVersion - SSL/TLS protocol version -* -*/ -ASecureSocket::ASecureSocket(const LogFnCallback& oLogger, - const OpenSSLProtocol eSSLVersion, - const SettingsFlag eSettings /*= ALL_FLAGS*/) : - ASocket(oLogger, eSettings), - m_eOpenSSLProtocol(eSSLVersion), - m_globalInitializer(SecureSocketGlobalInitializer::instance()) +ASecureSocket::SSLSocket::~SSLSocket() { + Disconnect(); } -/** -* @brief destructor of the secure socket object -* It's a pure virtual destructor but an implementation is provided below. -* this to avoid creating a dummy pure virtual method to transform the class -* to an abstract one. -*/ -ASecureSocket::~ASecureSocket() +void ASecureSocket::SSLSocket::Disconnect() { + SocketClose(m_SockFd); + + if (m_pSSL != nullptr) + { + /* send the close_notify alert to the peer. */ + SSL_shutdown(m_pSSL); // must be called before SSL_free + SSL_free(m_pSSL); + m_pSSL = nullptr; + } + + if (m_pCTXSSL != nullptr) + { + SSL_CTX_free(m_pCTXSSL); + m_pCTXSSL = nullptr; + } } -void ASecureSocket::SetUpCtxClient(SSLSocket& Socket) +bool ASecureSocket::SSLSocket::HasPending() const { - switch (m_eOpenSSLProtocol) - { - default: - case OpenSSLProtocol::TLS: - // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers - Socket.m_pMTHDSSL = const_cast(TLS_client_method()); - break; + return SSL_has_pending(m_pSSL) == 1; +} - case OpenSSLProtocol::SSL_V23: - Socket.m_pMTHDSSL = const_cast(SSLv23_client_method()); - break; +int ASecureSocket::SSLSocket::PendingBytes() const +{ + return SSL_pending(m_pSSL); +} - #ifndef LINUX - // deprecated in newer versions of OpenSSL - //case OpenSSLProtocol::SSL_V2: - //Socket.m_pMTHDSSL = const_cast(SSLv2_client_method()); - //break; - #endif +std::atomic ASecureSocket::s_iSecureSocketCount = ATOMIC_VAR_INIT(0); - // deprecated - /*case OpenSSLProtocol::SSL_V3: - Socket.m_pMTHDSSL = const_cast(SSLv3_client_method()); - break;*/ +/** + * @brief constructor of the Secure Socket + * + * @param oLogger - a callabck to a logger function void(const std::string&) + * @param eSSLVersion - SSL/TLS protocol version + */ +ASecureSocket::ASecureSocket(const LogFnCallback& oLogger, + const OpenSSLProtocol& eSSLVersion, + const SettingsFlag& eSettings /*= ALL_FLAGS*/) + : ASocket(oLogger, eSettings) + , m_eOpenSSLProtocol(eSSLVersion) +{ + int expected = 0; + if (s_iSecureSocketCount.compare_exchange_strong(expected, 1)) + { + // Initialize OpenSSL + InitializeSSL(); + } + else + { + s_iSecureSocketCount.fetch_add(1, std::memory_order_relaxed); + } +} - case OpenSSLProtocol::TLS_V1: - Socket.m_pMTHDSSL = const_cast(TLSv1_client_method()); - break; - } - Socket.m_pCTXSSL = SSL_CTX_new(Socket.m_pMTHDSSL); +/** + * @brief destructor of the secure socket object + * It's a pure virtual destructor but an implementation is provided below. + * this to avoid creating a dummy pure virtual method to transform the class + * to an abstract one. + */ +ASecureSocket::~ASecureSocket() +{ + int value = s_iSecureSocketCount.load(std::memory_order_relaxed); + + do + { + if (value == 0) + { + return; + } + + if (s_iSecureSocketCount.compare_exchange_weak(value, value - 1)) + { + if (value == 1) + { + DestroySSL(); + } + return; + } + } while (true); } -void ASecureSocket::SetUpCtxServer(SSLSocket& Socket) +void ASecureSocket::InitializeSSL() { - switch (m_eOpenSSLProtocol) - { - default: - case OpenSSLProtocol::TLS: - // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers - Socket.m_pMTHDSSL = const_cast(TLS_server_method()); - break; + /* Initialize malloc, free, etc for OpenSSL's use. */ + //CRYPTO_malloc_init(); - #ifndef LINUX - //case OpenSSLProtocol::SSL_V2: - //Socket.m_pMTHDSSL = const_cast(SSLv2_server_method()); - //break; - #endif + /* Initialize OpenSSL's SSL libraries: load encryption & hash algorithms for SSL */ + (void)SSL_library_init(); //always returns 1 - // deprecated - /*case OpenSSLProtocol::SSL_V3: - Socket.m_pMTHDSSL = const_cast(SSLv3_server_method()); - break;*/ + /* Load the error strings for good error reporting */ + SSL_load_error_strings(); - case OpenSSLProtocol::TLS_V1: - Socket.m_pMTHDSSL = const_cast(TLSv1_server_method()); - break; + /* Load BIO error strings. */ + //ERR_load_BIO_strings(); - case OpenSSLProtocol::SSL_V23: - Socket.m_pMTHDSSL = const_cast(SSLv23_server_method()); - break; - } - Socket.m_pCTXSSL = SSL_CTX_new(Socket.m_pMTHDSSL); + /* Load all available encryption algorithms. */ + OpenSSL_add_all_algorithms(); } -void ASecureSocket::InitializeSSL() +void ASecureSocket::DestroySSL() { - /* Initialize malloc, free, etc for OpenSSL's use. */ - //CRYPTO_malloc_init(); - - /* Initialize OpenSSL's SSL libraries: load encryption & hash algorithms for SSL */ - SSL_library_init(); + ERR_free_strings(); + EVP_cleanup(); +} - /* Load the error strings for good error reporting */ - SSL_load_error_strings(); +bool ASecureSocket::SetUpCtxClient(SSLSocket& Socket) +{ + switch (m_eOpenSSLProtocol) + { + default: + case OpenSSLProtocol::TLS: + // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers + Socket.m_pMTHDSSL = const_cast(TLS_client_method()); + break; + + case OpenSSLProtocol::SSL_V23: + Socket.m_pMTHDSSL = const_cast(SSLv23_client_method()); + break; + +#if 0 +#ifdef _WIN32 + // deprecated in newer versions of OpenSSL + case OpenSSLProtocol::SSL_V2: + Socket.m_pMTHDSSL = const_cast(SSLv2_client_method()); + break; +#endif - /* Load BIO error strings. */ - //ERR_load_BIO_strings(); + // deprecated + case OpenSSLProtocol::SSL_V3: + Socket.m_pMTHDSSL = const_cast(SSLv3_client_method()); + break; +#endif - /* Load all available encryption algorithms. */ - OpenSSL_add_all_algorithms(); + case OpenSSLProtocol::TLS_V1: + Socket.m_pMTHDSSL = const_cast(TLSv1_client_method()); + break; + } + + if (Socket.m_pMTHDSSL == nullptr) + { + //SocketLog("[WARN ]ASecureSocket, XXX_client_method failed[%lu:%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr)); + } + + Socket.m_pCTXSSL = SSL_CTX_new(Socket.m_pMTHDSSL); + if (Socket.m_pCTXSSL == nullptr) + { + //SocketLog("[ERROR]ASecureSocket, client SSL_CTX_new failed[%lu:%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr)); + //ERR_print_errors_fp(stdout); + return false; + } + + return true; } -void ASecureSocket::DestroySSL() +bool ASecureSocket::SetUpCtxServer(SSLSocket& Socket) { - ERR_free_strings(); - EVP_cleanup(); + switch (m_eOpenSSLProtocol) + { + default: + case OpenSSLProtocol::TLS: + // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers + Socket.m_pMTHDSSL = const_cast(TLS_server_method()); + break; + +#if 0 +#ifdef _WIN32 + case OpenSSLProtocol::SSL_V2: + Socket.m_pMTHDSSL = const_cast(SSLv2_server_method()); + break; +#endif + + // deprecated + case OpenSSLProtocol::SSL_V3: + Socket.m_pMTHDSSL = const_cast(SSLv3_server_method()); + break; +#endif + + case OpenSSLProtocol::TLS_V1: + Socket.m_pMTHDSSL = const_cast(TLSv1_server_method()); + break; + + case OpenSSLProtocol::SSL_V23: + Socket.m_pMTHDSSL = const_cast(SSLv23_server_method()); + break; + } + + if (Socket.m_pMTHDSSL == nullptr) + { + //SocketLog("[WARN ]ASecureSocket, XXX_server_method failed[%lu:%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr)); + } + + Socket.m_pCTXSSL = SSL_CTX_new(Socket.m_pMTHDSSL); + if (Socket.m_pCTXSSL == nullptr) + { + //SocketLog("[ERROR]ASecureSocket, server SSL_CTX_new failed[%lu:%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr)); + return false; + } + + //SSL_CTX_set_verify(Socket.m_pCTXSSL, SSL_VERIFY_NONE, nullptr); + return true; } void ASecureSocket::ShutdownSSL(SSLSocket& SSLSock) { - if (SSLSock.m_pSSL != nullptr) - { - /* send the close_notify alert to the peer. */ - SSL_shutdown(SSLSock.m_pSSL); // must be called before SSL_free - SSL_free(SSLSock.m_pSSL); - SSL_CTX_free(SSLSock.m_pCTXSSL); - - SSLSock.m_pSSL = nullptr; - } + SSLSock.Disconnect(); } const char* ASecureSocket::GetSSLErrorString(int iErrorCode) { - switch (iErrorCode) - { - case SSL_ERROR_NONE: - return "The TLS/SSL I/O operation completed."; - break; - - case SSL_ERROR_ZERO_RETURN: - return "The TLS/SSL connection has been closed."; - break; - - case SSL_ERROR_WANT_READ: - return "The read operation did not complete; " - "the same TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_WANT_WRITE: - return "The write operation did not complete; " - "the same TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_WANT_CONNECT: - return "The connect operation did not complete; " - "the same TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_WANT_ACCEPT: - return "The accept operation did not complete; " - "the same TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_WANT_X509_LOOKUP: - return "The operation did not complete because an application callback set" - " by SSL_CTX_set_client_cert_cb() has asked to be called again. " - "The TLS/SSL I/O function should be called again later."; - break; - - case SSL_ERROR_SYSCALL: - return "Some I/O error occurred. The OpenSSL error queue may contain" - " more information on the error."; - break; - - case SSL_ERROR_SSL: - return "A failure in the SSL library occurred, usually a protocol error. " - "The OpenSSL error queue contains more information on the error."; - break; - - default: - return "Unknown error !"; - break; - } + switch (iErrorCode) + { + case SSL_ERROR_NONE: + return "The TLS/SSL I/O operation completed."; + + case SSL_ERROR_ZERO_RETURN: + return "The TLS/SSL connection has been closed."; + + case SSL_ERROR_WANT_READ: + return "The read operation did not complete; " + "the same TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_WANT_WRITE: + return "The write operation did not complete; " + "the same TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_WANT_CONNECT: + return "The connect operation did not complete; " + "the same TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_WANT_ACCEPT: + return "The accept operation did not complete; " + "the same TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_WANT_X509_LOOKUP: + return "The operation did not complete because an application callback set" + " by SSL_CTX_set_client_cert_cb() has asked to be called again. " + "The TLS/SSL I/O function should be called again later."; + + case SSL_ERROR_SYSCALL: + return "Some I/O error occurred. The OpenSSL error queue may contain" + " more information on the error."; + + case SSL_ERROR_SSL: + return "A failure in the SSL library occurred, usually a protocol error. " + "The OpenSSL error queue contains more information on the error."; + + default: + return "Unknown error !"; + } } int ASecureSocket::AlwaysTrueCallback(X509_STORE_CTX* pCTX, void* pArg) { - return 1; + return 1; } #endif diff --git a/Socket/SecureSocket.h b/Socket/SecureSocket.h index dfac9ad..2c2fcb1 100644 --- a/Socket/SecureSocket.h +++ b/Socket/SecureSocket.h @@ -1,15 +1,16 @@ -/* -* @file SecureSocket.h -* @brief Abstract class to perform OpenSSL API global operations -* -* @author Mohamed Amine Mzoughi -* @date 2017-02-16 -*/ +/** + * @file SecureSocket.h + * @brief Abstract class to perform OpenSSL API global operations + * + * @author Mohamed Amine Mzoughi + * @date 2017-02-16 + */ #ifdef OPENSSL #ifndef INCLUDE_ASECURESOCKET_H_ #define INCLUDE_ASECURESOCKET_H_ +#include #ifdef OPENSSL #include #include @@ -21,137 +22,97 @@ class ASecureSocket : public ASocket { public: - enum class OpenSSLProtocol - { - #ifndef LINUX - //SSL_V2, // deprecated - #endif - //SSL_V3, // deprecated - TLS_V1, - SSL_V23, /* There is no SSL protocol version named SSLv23. The SSLv23_method() API - and its variants choose SSLv2, SSLv3, or TLSv1 for compatibility with the peer. */ - TLS // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers - }; - - struct SSLSocket - { - SSLSocket() : - m_SockFd(INVALID_SOCKET), - m_pSSL(nullptr), - m_pCTXSSL(nullptr), - m_pMTHDSSL(nullptr) - { - } - - // copy constructor and assignment operator are disabled - SSLSocket(const SSLSocket&) = delete; - SSLSocket& operator=(const SSLSocket&) = delete; - - // move constructor - SSLSocket(SSLSocket&& Sockother) : - m_SockFd(Sockother.m_SockFd), - m_pSSL(Sockother.m_pSSL), - m_pCTXSSL(Sockother.m_pCTXSSL), - m_pMTHDSSL(Sockother.m_pMTHDSSL) - { - Sockother.m_SockFd = INVALID_SOCKET; - Sockother.m_pSSL = nullptr; - Sockother.m_pCTXSSL = nullptr; - Sockother.m_pMTHDSSL = nullptr; - } - - // move assignment operator - SSLSocket& operator=(SSLSocket&& Sockother) - { - if (this != &Sockother) - { - m_SockFd = Sockother.m_SockFd; - m_pSSL = Sockother.m_pSSL; - m_pCTXSSL = Sockother.m_pCTXSSL; - m_pMTHDSSL = Sockother.m_pMTHDSSL; - - // reset Sockother - Sockother.m_SockFd = INVALID_SOCKET; - Sockother.m_pSSL = nullptr; - Sockother.m_pCTXSSL = nullptr; - Sockother.m_pMTHDSSL = nullptr; - } - return *this; - } - - Socket m_SockFd; - SSL* m_pSSL; - SSL_CTX* m_pCTXSSL; // SSL Context Structure - SSL_METHOD* m_pMTHDSSL; // used to create an SSL_CTX - }; - - /* Please provide your logger thread-safe routine, otherwise, you can turn off - * error log messages printing by not using the flag ALL_FLAGS or ENABLE_LOG */ - explicit ASecureSocket(const LogFnCallback& oLogger, - const OpenSSLProtocol eSSLVersion = OpenSSLProtocol::TLS, - const SettingsFlag eSettings = ALL_FLAGS); - virtual ~ASecureSocket() = 0; - - /* - * For the SSL server: - * Server's own certificate (mandatory) - * CA certificate (optional) - * - * For the SSL client: - * CA certificate (mandatory) - * Client's own certificate (optional) - */ - inline const std::string& GetSSLCertAuth() { return m_strCAFile; } - inline void SetSSLCerthAuth(const std::string& strPath) { m_strCAFile = strPath; } - - inline void SetSSLCertFile(const std::string& strPath) { m_strSSLCertFile = strPath; } - inline const std::string& GetSSLCertFile() const { return m_strSSLCertFile; } - - inline void SetSSLKeyFile(const std::string& strPath) { m_strSSLKeyFile = strPath; } - inline const std::string& GetSSLKeyFile() const { return m_strSSLKeyFile; } - - //void SetSSLKeyPassword(const std::string& strPwd) { m_strSSLKeyPwd = strPwd; } - //const std::string& GetSSLKeyPwd() const { return m_strSSLKeyPwd; } + enum class OpenSSLProtocol + { +#ifdef _WIN32 + //SSL_V2, // deprecated +#endif + //SSL_V3, // deprecated + TLS_V1, + SSL_V23, /* There is no SSL protocol version named SSLv23. The SSLv23_method() API + and its variants choose SSLv2, SSLv3, or TLSv1 for compatibility with the peer. */ + TLS // Standard Protocol as of 11/2018, OpenSSL will choose highest possible TLS standard between peers + }; + + struct SSLSocket + { + SSLSocket(); + ~SSLSocket(); + + // copy constructor and assignment operator are disabled + SSLSocket(const SSLSocket&) = delete; + SSLSocket& operator=(const SSLSocket&) = delete; + + // move constructor + SSLSocket(SSLSocket&& Sockother); + // move assignment operator + SSLSocket& operator=(SSLSocket&& Sockother); + + void Disconnect(); + + bool HasPending() const; + int PendingBytes() const; + + Socket m_SockFd; + SSL* m_pSSL; + SSL_CTX* m_pCTXSSL; // SSL Context Structure + SSL_METHOD* m_pMTHDSSL; // used to create an SSL_CTX + }; + + /** + * Please provide your logger thread-safe routine, otherwise, you can turn off + * error log messages printing by not using the flag ALL_FLAGS or ENABLE_LOG + */ + explicit ASecureSocket(const LogFnCallback& oLogger, + const OpenSSLProtocol& eSSLVersion = OpenSSLProtocol::TLS, + const SettingsFlag& eSettings = ALL_FLAGS); + virtual ~ASecureSocket(); + + /** + * For the SSL server: + * Server's own certificate (mandatory) + * CA certificate (optional) + * + * For the SSL client: + * CA certificate (mandatory) + * Client's own certificate (optional) + */ + inline const std::string& GetSSLCertAuth() { return m_strCAFile; } + inline void SetSSLCerthAuth(const std::string& strPath) { m_strCAFile = strPath; } + + inline void SetSSLCertFile(const std::string& strPath) { m_strSSLCertFile = strPath; } + inline const std::string& GetSSLCertFile() const { return m_strSSLCertFile; } + + inline void SetSSLKeyFile(const std::string& strPath) { m_strSSLKeyFile = strPath; } + inline const std::string& GetSSLKeyFile() const { return m_strSSLKeyFile; } + + //void SetSSLKeyPassword(const std::string& strPwd) { m_strSSLKeyPwd = strPwd; } + //const std::string& GetSSLKeyPwd() const { return m_strSSLKeyPwd; } protected: - // object methods - void SetUpCtxClient(SSLSocket& Socket); - void SetUpCtxServer(SSLSocket& Socket); - //void SetUpCtxCombined(SSLSocket& Socket); - - // class methods - static void ShutdownSSL(SSLSocket& SSLSocket); - static const char* GetSSLErrorString(int iErrorCode); - static int AlwaysTrueCallback(X509_STORE_CTX* pCTX, void* pArg); - - // non-static/object members - OpenSSLProtocol m_eOpenSSLProtocol; - std::string m_strCAFile; - std::string m_strSSLCertFile; - std::string m_strSSLKeyFile; - //std::string m_strSSLKeyPwd; - -private: - friend class SecureSocketGlobalInitializer; - class SecureSocketGlobalInitializer { - public: - static SecureSocketGlobalInitializer& instance(); - - SecureSocketGlobalInitializer(SecureSocketGlobalInitializer const&) = delete; - SecureSocketGlobalInitializer(SecureSocketGlobalInitializer&&) = delete; + // object methods + bool SetUpCtxClient(SSLSocket& Socket); + bool SetUpCtxServer(SSLSocket& Socket); + //void SetUpCtxCombined(SSLSocket& Socket); - SecureSocketGlobalInitializer& operator=(SecureSocketGlobalInitializer const&) = delete; - SecureSocketGlobalInitializer& operator=(SecureSocketGlobalInitializer&&) = delete; + // class methods + static void ShutdownSSL(SSLSocket& SSLSocket); + static const char* GetSSLErrorString(int iErrorCode); + static int AlwaysTrueCallback(X509_STORE_CTX* pCTX, void* pArg); - ~SecureSocketGlobalInitializer(); - - private: - SecureSocketGlobalInitializer(); - }; - SecureSocketGlobalInitializer& m_globalInitializer; +private: + static void InitializeSSL(); + static void DestroySSL(); - static void InitializeSSL(); - static void DestroySSL(); +protected: + // non-static/object members + OpenSSLProtocol m_eOpenSSLProtocol; + std::string m_strCAFile; + std::string m_strSSLCertFile; + std::string m_strSSLKeyFile; + //std::string m_strSSLKeyPwd; + + static std::atomic s_iSecureSocketCount; // Count of the actual secure socket sessions }; #endif diff --git a/Socket/Socket.cpp b/Socket/Socket.cpp index 15ec3bc..70e0454 100644 --- a/Socket/Socket.cpp +++ b/Socket/Socket.cpp @@ -1,209 +1,394 @@ /** -* @file Socket.cpp -* @brief implementation of the Socket class -* @author Mohamed Amine Mzoughi -*/ + * @file Socket.cpp + * @brief implementation of the Socket class + * @author Mohamed Amine Mzoughi + */ #include "Socket.h" - +#include +#include // va_start, etc. +#include // snprintf #include #include -#ifdef WINDOWS -WSADATA ASocket::s_wsaData; +#ifdef _WIN32 +// Static members initialization +std::atomic ASocket::s_iSocketCount = ATOMIC_VAR_INIT(0); +WSADATA ASocket::s_wsaData{}; #endif -ASocket::SocketGlobalInitializer& ASocket::SocketGlobalInitializer::instance() +/** + * @brief constructor of the Socket + * + * @param Logger - a callabck to a logger function void(const std::string&) + * + */ +ASocket::ASocket(const LogFnCallback& oLogger, SettingsFlag eSettings /*= ALL_FLAGS*/) + : m_oLog(oLogger) + , m_eSettingsFlags(eSettings) { - static SocketGlobalInitializer inst{}; - return inst; +#ifdef _WIN32 + int expected = 0; + if (s_iSocketCount.compare_exchange_strong(expected, 1)) + { + InitializeEnvironment(); + } + else + { + s_iSocketCount.fetch_add(1, std::memory_order_relaxed); + } +#endif } -ASocket::SocketGlobalInitializer::SocketGlobalInitializer() +/** + * @brief destructor of the socket object + * It's a pure virtual destructor but an implementation is provided below. + * this to avoid creating a dummy pure virtual method to transform the class + * to an abstract one. + */ +ASocket::~ASocket() { - // In windows, this will init the winsock DLL stuff -#ifdef WINDOWS - // MAKEWORD(2,2) version 2.2 of Winsock - int iWinSockInitResult = WSAStartup(MAKEWORD(2, 2), &s_wsaData); - - if (iWinSockInitResult != 0) - { - std::cerr << ASocket::StringFormat("[TCPClient][Error] WSAStartup failed : %d", iWinSockInitResult); - } +#ifdef _WIN32 + int value = s_iSocketCount.load(std::memory_order_relaxed); + + do + { + if (value == 0) + { + return; + } + + if (s_iSocketCount.compare_exchange_weak(value, value - 1)) + { + if (value == 1) + { + UnInitializeEnvironment(); + } + return; + } + } while(true); #endif } -ASocket::SocketGlobalInitializer::~SocketGlobalInitializer() +#ifdef _WIN32 +bool ASocket::InitializeEnvironment() +{ + // In windows, this will init the winsock DLL stuff + // MAKEWORD(2,2) version 2.2 of Winsock + int iWinSockInitResult = WSAStartup(MAKEWORD(2, 2), &s_wsaData); + if (iWinSockInitResult != NO_ERROR) + { + //SocketLog("[ERROR]ASocket, WSAStartup failed[%d:%s]", iWinSockInitResult, strerror(iWinSockInitResult)); + return false; + } + + if (LOBYTE(s_wsaData.wVersion) != 2 || HIBYTE(s_wsaData.wVersion) != 2) + { + //SocketLog("[ERROR]ASocket, could not find a usable version of winsock.dll[%x]", s_wsaData.wVersion); + return false; + } + + return true; +} + +void ASocket::UnInitializeEnvironment() +{ + /* call WSACleanup when done using the Winsock dll */ + WSACleanup(); +} +#endif + +int ASocket::GetSocketError() { -#ifdef WINDOWS - /* call WSACleanup when done using the Winsock dll */ - WSACleanup(); +#ifdef _WIN32 + return WSAGetLastError(); +#else + return errno; #endif } -/** -* @brief constructor of the Socket -* -* @param Logger - a callabck to a logger function void(const std::string&) -* -*/ -ASocket::ASocket(const LogFnCallback& oLogger, - const SettingsFlag eSettings /*= ALL_FLAGS*/) : - m_oLog(oLogger), - m_eSettingsFlags(eSettings), - m_globalInitializer(SocketGlobalInitializer::instance()) +char* ASocket::GaiStrerror(int ecode) { +#ifdef _WIN32 + return gai_strerrorA(ecode); +#else + return gai_strerror(ecode); +#endif +} +void ASocket::SocketClose(Socket& sd) +{ + if (sd == INVALID_SOCKET) + { +#ifdef _WIN32 + closesocket(sd); +#else + close(sd); +#endif + sd = INVALID_SOCKET; + } } /** -* @brief destructor of the socket object -* It's a pure virtual destructor but an implementation is provided below. -* this to avoid creating a dummy pure virtual method to transform the class -* to an abstract one. -*/ -ASocket::~ASocket() + * @brief returns a formatted string + * + * @param [in] strFormat string with one or many format specifiers + * @param [in] parameters to be placed in the format specifiers of strFormat + * + * @retval string formatted string + */ +std::string ASocket::StringFormat(const char* fmt, ...) { - + if (fmt == NULL) + { + return std::string(); + } + + va_list args; + va_start(args, fmt); + size_t len = std::vsnprintf(NULL, 0, fmt, args); + va_end(args); + std::vector vec(len + 1); + va_start(args, fmt); + std::vsnprintf(&vec[0], len + 1, fmt, args); + vec[len] = '\0'; + va_end(args); + return std::string(vec.data()); } /** -* @brief returns a formatted string -* -* @param [in] strFormat string with one or many format specifiers -* @param [in] parameters to be placed in the format specifiers of strFormat -* -* @retval string formatted string + * @brief waits for a socket's read status change + * + * @param [in] sd socket descriptor to be selected + * @param [in] msec waiting period in milliseconds, a value of 0 implies no timeout + * + * @retval int 0 on timeout, -1 on error and 1 on success. */ -std::string ASocket::StringFormat(const std::string strFormat, ...) +int ASocket::SelectSocket(Socket sd, size_t msec/* = ACCEPT_WAIT_INF_DELAY*/) { - va_list args; - va_start (args, strFormat); - size_t len = std::vsnprintf(NULL, 0, strFormat.c_str(), args); - va_end (args); - std::vector vec(len + 1); - va_start (args, strFormat); - std::vsnprintf(&vec[0], len + 1, strFormat.c_str(), args); - va_end (args); - return &vec[0]; + if (sd == INVALID_SOCKET) + { + return -1; + } + + struct timeval tval{}; + struct timeval* tvalptr = nullptr; + + if (msec != ACCEPT_WAIT_INF_DELAY) + { + tval.tv_sec = (long)msec / 1000; + tval.tv_usec = (msec % 1000) * 1000; + tvalptr = &tval; + } + + fd_set fd_reads{}; + FD_ZERO(&fd_reads); + FD_SET(sd, &fd_reads); + +#ifdef _WIN32 + Socket max_fd = 0; +#else + Socket max_fd = sd + 1; +#endif + + // block until socket is readable. + int res = select((int)max_fd, &fd_reads, nullptr, nullptr, tvalptr); + if (res == SOCKET_ERROR) + { + if (SOCKET_ERR_SELECT_RETRIABLE(GetSocketError())) + { + return 0; + } + + //SocketLog("[ERROR]ASocket, select failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), max_fd); + return -1; + } + + if (res == 0) + { + return 0; + } + +#if defined(__unix__) && defined(BSD) + if (!FD_ISSET(sd, &fd_reads)) + { + return 0; + } +#endif + + assert(FD_ISSET(sd, &fd_reads)); + assert(res == 1); + return 1; } /** -* @brief waits for a socket's read status change -* -* @param [in] sd socket descriptor to be selected -* @param [in] msec waiting period in milliseconds, a value of 0 implies no timeout -* -* @retval int 0 on timeout, -1 on error and 1 on success. -*/ -int ASocket::SelectSocket(const ASocket::Socket sd, const size_t msec) + * @brief waits for a set of sockets read status change + * + * @param [in] pSocketsToSelect pointer to an array of socket descriptors to be selected + * @param [in] count elements count of pSocketsToSelect + * @param [in] msec waiting period in milliseconds, a value of 0 implies no timeout + * @param [out] selectedIndex index of the socket that is ready to be read + * + * @retval int 0 on timeout, -1 on error and 1 on success. + */ +int ASocket::SelectSockets(const Socket* pSocketsToSelect, size_t count, size_t& selectedIndex, size_t msec/* = ACCEPT_WAIT_INF_DELAY*/) { - if (sd < 0) - { - return -1; - } + if (pSocketsToSelect == nullptr || count == 0) + { + return -1; + } + + struct timeval tval{}; + struct timeval* tvalptr = nullptr; + if (msec != ACCEPT_WAIT_INF_DELAY) + { + tval.tv_sec = (long)msec / 1000; + tval.tv_usec = (msec % 1000) * 1000; + tvalptr = &tval; + } + + fd_set fd_reads{}; + FD_ZERO(&fd_reads); + +#ifdef _WIN32 + Socket max_fd = 0; +#else + Socket max_fd = -1; +#endif - struct timeval tval; - struct timeval* tvalptr = nullptr; - fd_set rset; - int res; +#ifndef _WIN32 + for (size_t i = 0; i < count; i++) + { + if (pSocketsToSelect[i] != INVALID_SOCKET) + { + FD_SET(pSocketsToSelect[i], &fd_reads); + + if (pSocketsToSelect[i] > max_fd) + { + max_fd = pSocketsToSelect[i]; + } + } + } +#endif - if (msec > 0) - { - tval.tv_sec = msec / 1000; - tval.tv_usec = (msec % 1000) * 1000; - tvalptr = &tval; - } +#ifndef _WIN32 + max_fd += 1; +#endif - FD_ZERO(&rset); - FD_SET(sd, &rset); + // block until one socket is ready to read. + int res = select((int)max_fd, &fd_reads, nullptr, nullptr, tvalptr); + if (res == SOCKET_ERROR) + { + if (SOCKET_ERR_SELECT_RETRIABLE(GetSocketError())) + { + return 0; + } + + //SocketLog("[ERROR]ASocket, select failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), max_fd); + return -1; + } + + if (res == 0) + { + return 0; + } + + // find the first socket which has some activity. +#if defined(__unix__) && defined(BSD) + Socket firstSocket = INVALID_SOCKET; +#endif + for (size_t i = 0; i < count; ++i) + { + if (FD_ISSET(pSocketsToSelect[i], &fd_reads)) + { + selectedIndex = i; +#if defined(__unix__) && defined(BSD) + firstSocket = pSocketsToSelect[i]; +#endif + break; + } + } + +#if defined(__unix__) && defined(BSD) + if (firstSocket == INVALID_SOCKET) + { + return 0; + } +#endif - // block until socket is readable. - res = select(sd + 1, &rset, nullptr, nullptr, tvalptr); + return 1; +} - if (res <= 0) - return res; +bool ASocket::SetRcvTimeout(Socket sd, unsigned int msec_timeout) +{ +#ifndef _WIN32 + struct timeval t = TimevalFromMsec(msec_timeout); + return SetRcvTimeout(t); +#else + if (setsockopt(sd, SOL_SOCKET, SO_RCVTIMEO, (char*)&msec_timeout, sizeof(msec_timeout)) == SOCKET_ERROR) + { + //SocketLog("[ERROR]ASocket, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), sd, msec_timeout); + return false; + } + + return true; +#endif +} - if (!FD_ISSET(sd, &rset)) - return -1; +#ifndef _WIN32 +bool ASocket::SetRcvTimeout(Socket sd, const struct timeval& timeout) +{ + if (setsockopt(sd, SOL_SOCKET, SO_RCVTIMEO, (char*)&timeout, sizeof(timeout)) == SOCKET_ERROR) + { + //SocketLog("[ERROR]ASocket, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), sd, timeout.tv_sec, timeout.tv_usec); + return false; + } - return 1; + return true; } +#endif -/** -* @brief waits for a set of sockets read status change -* -* @param [in] pSocketsToSelect pointer to an array of socket descriptors to be selected -* @param [in] count elements count of pSocketsToSelect -* @param [in] msec waiting period in milliseconds, a value of 0 implies no timeout -* @param [out] selectedIndex index of the socket that is ready to be read -* -* @retval int 0 on timeout, -1 on error and 1 on success. -*/ -int ASocket::SelectSockets(const ASocket::Socket* pSocketsToSelect, const size_t count, - const size_t msec, size_t& selectedIndex) +bool ASocket::SetSndTimeout(Socket sd, unsigned int msec_timeout) { - if (!pSocketsToSelect || count == 0) - { - return -1; - } - - fd_set rset; - int res = -1; - - struct timeval tval; - struct timeval* tvalptr = nullptr; - if (msec > 0) - { - tval.tv_sec = msec / 1000; - tval.tv_usec = (msec % 1000) * 1000; - tvalptr = &tval; - } - - FD_ZERO(&rset); - - int max_fd = -1; - for (size_t i = 0; i < count; i++) - { - FD_SET(pSocketsToSelect[i], &rset); - - if (pSocketsToSelect[i] > max_fd) - { - max_fd = pSocketsToSelect[i]; - } - } - - // block until one socket is ready to read. - res = select(max_fd + 1, &rset, nullptr, nullptr, tvalptr); - - if (res <= 0) - return res; - - // find the first socket which has some activity. - for (size_t i = 0; i < count; i++) - { - if (FD_ISSET(pSocketsToSelect[i], &rset)) - { - selectedIndex = i; - return 1; - } - } - - return -1; +#ifndef _WIN32 + struct timeval t = TimevalFromMsec(msec_timeout); + return SetSndTimeout(t); +#else + if (setsockopt(sd, SOL_SOCKET, SO_SNDTIMEO, (char*)&msec_timeout, sizeof(msec_timeout)) == SOCKET_ERROR) + { + //SocketLog("[ERROR]ASocket, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), sd, msec_timeout); + return false; + } + + return true; +#endif } -/** -* @brief converts a value representing milliseconds into a struct timeval -* -* @param [time_msec] a time value in milliseconds -* -* @retval time_msec converted to struct timeval -*/ -struct timeval ASocket::TimevalFromMsec(unsigned int time_msec){ - struct timeval t; +#ifndef _WIN32 +bool ASocket::SetSndTimeout(Socket sd, const struct timeval& timeout) +{ + if (setsockopt(sd, SOL_SOCKET, SO_SNDTIMEO, (char*)&timeout, sizeof(timeout)) == SOCKET_ERROR) + { + //SocketLog("[ERROR]ASocket, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), sd, timeout.tv_sec, timeout.tv_usec); + return false; + } + + return true; +} +#endif - t.tv_sec = time_msec / 1000; - t.tv_usec = (time_msec % 1000) * 1000; +/** + * @brief converts a value representing milliseconds into a struct timeval + * + * @param [time_msec] a time value in milliseconds + * + * @retval time_msec converted to struct timeval + */ +struct timeval ASocket::TimevalFromMsec(unsigned int time_msec) +{ + struct timeval t{}; + t.tv_sec = (long)time_msec / 1000; + t.tv_usec = ((long)time_msec % 1000) * 1000; - return t; + return t; } diff --git a/Socket/Socket.h b/Socket/Socket.h index 0271d16..1869224 100644 --- a/Socket/Socket.h +++ b/Socket/Socket.h @@ -1,117 +1,173 @@ -/* -* @file Socket.h -* @brief Abstract class to perform API global operations -* -* @author Mohamed Amine Mzoughi -* @date 2017-02-10 -*/ +/** + * @file Socket.h + * @brief Abstract class to perform API global operations + * + * @author Mohamed Amine Mzoughi + * @date 2017-02-10 + */ #ifndef INCLUDE_ASOCKET_H_ #define INCLUDE_ASOCKET_H_ -#include // snprintf +#include +#include +#ifdef _WIN32 +#include +#endif +#include #include #include -#include -#include // va_start, etc. +#include #include - -#ifdef WINDOWS -#include -#include +#ifdef _WIN32 +#include +#include // Need to link with Ws2_32.lib #pragma comment(lib,"WS2_32.lib") #else #include -#include #include #include -#include -#include -#include -#include #include #include #include #endif -#include -#define ACCEPT_WAIT_INF_DELAY std::numeric_limits::max() +#define ACCEPT_WAIT_INF_DELAY (std::numeric_limits::max)() -class ASocket -{ -public: - // Public definitions - //typedef std::function ProgressFnCallback; - typedef std::function LogFnCallback; +#ifndef _WIN32 +#if EAGAIN == EWOULDBLOCK +#define SOCKET_ERR_IS_EAGAIN(e) ((e) == EAGAIN) +#else +#define SOCKET_ERR_IS_EAGAIN(e) ((e) == EAGAIN || (e) == EWOULDBLOCK) +#endif - // socket file descriptor id - #ifdef WINDOWS - typedef SOCKET Socket; - #else - typedef int Socket; - #define INVALID_SOCKET -1 - #endif +#define SOCKET_ERR_SELECT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == EINTR || (e) == ENOMEM) - enum SettingsFlag - { - NO_FLAGS = 0x00, - ENABLE_LOG = 0x01, - ALL_FLAGS = 0xFF - }; +#define SOCKET_ERR_RW_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == EINTR || (e) == ENOMEM) - /* Please provide your logger thread-safe routine, otherwise, you can turn off - * error log messages printing by not using the flag ALL_FLAGS or ENABLE_LOG */ - explicit ASocket(const LogFnCallback& oLogger, - const SettingsFlag eSettings = ALL_FLAGS); - virtual ~ASocket() = 0; +#define SOCKET_ERR_CONNECT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == EINTR || (e) == EINPROGRESS || (e) == EALREADY) - static int SelectSockets(const Socket* pSocketsToSelect, const size_t count, - const size_t msec, size_t& selectedIndex); +#define SOCKET_ERR_ACCEPT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == EINTR || (e) == ECONNABORTED || (e) == EPROTO) - static int SelectSocket(const Socket sd, const size_t msec); +#define SOCKET_ERR_CONNECT_REFUSED(e) \ + ((e) == ECONNREFUSED) - static struct timeval TimevalFromMsec(unsigned int time_msec); +#define SOCKET_ERR_ADDR_INUSE(e) \ + ((e) == EADDRINUSE) - // String Helpers - static std::string StringFormat(const std::string strFormat, ...); +#else +#define SOCKET_ERR_IS_EAGAIN(e) ((e) == WSAEWOULDBLOCK) -protected: - // Log printer callback - /*mutable*/const LogFnCallback m_oLog; +#define SOCKET_ERR_SELECT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == WSAEINTR || (e) == WSAEFAULT || (e) == WSAEINPROGRESS) - SettingsFlag m_eSettingsFlags; +#define SOCKET_ERR_RW_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == WSAEINTR || (e) == WSAENOBUFS) - #ifdef WINDOWS - static WSADATA s_wsaData; - #endif +#define SOCKET_ERR_CONNECT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == WSAEINTR || (e) == WSAEINPROGRESS || (e) == WSAEALREADY) -private: - friend class SocketGlobalInitializer; - class SocketGlobalInitializer { - public: - static SocketGlobalInitializer& instance(); +#define SOCKET_ERR_ACCEPT_RETRIABLE(e) \ + (SOCKET_ERR_IS_EAGAIN(e) || (e) == WSAEINTR || (e) == WSAECONNABORTED || (e) == WSAEPROTONOSUPPORT) + +#define SOCKET_ERR_CONNECT_REFUSED(e) \ + ((e) == WSAECONNREFUSED) + +#define SOCKET_ERR_ADDR_INUSE(e) \ + ((e) == WSAEADDRINUSE) + +#endif + +class ASocket +{ +public: + // Public definitions + //typedef std::function ProgressFnCallback; + typedef std::function LogFnCallback; - SocketGlobalInitializer(SocketGlobalInitializer const&) = delete; - SocketGlobalInitializer(SocketGlobalInitializer&&) = delete; + // socket file descriptor id +#ifdef _WIN32 + typedef SOCKET Socket; +#else + typedef int Socket; +#define INVALID_SOCKET -1 +#define SOCKET_ERROR -1 +#endif - SocketGlobalInitializer& operator=(SocketGlobalInitializer const&) = delete; - SocketGlobalInitializer& operator=(SocketGlobalInitializer&&) = delete; + static int GetSocketError(); + static char* GaiStrerror(int ecode); + static void SocketClose(Socket& sd); + + enum SettingsFlag + { + NO_FLAGS = 0x00, + ENABLE_LOG = 0x01, + ALL_FLAGS = 0xFF + }; + + /** + * Please provide your logger thread-safe routine, otherwise, you can turn off + * error log messages printing by not using the flag ALL_FLAGS or ENABLE_LOG + */ + explicit ASocket(const LogFnCallback& oLogger, SettingsFlag eSettings = ALL_FLAGS); + virtual ~ASocket(); + + static int SelectSockets(const Socket* pSocketsToSelect, size_t count, size_t& selectedIndex, size_t msec = ACCEPT_WAIT_INF_DELAY); + static int SelectSocket(Socket sd, size_t msec = ACCEPT_WAIT_INF_DELAY); + + // To disable timeout, set msec_timeout to 0. + static bool SetRcvTimeout(Socket sd, unsigned int msec_timeout); + static bool SetSndTimeout(Socket sd, unsigned int msec_timeout); + +#ifndef _WIN32 + static bool SetRcvTimeout(Socket sd, const struct timeval& timeout); + static bool SetSndTimeout(Socket sd, const struct timeval& timeout); +#endif + + static struct timeval TimevalFromMsec(unsigned int time_msec); - ~SocketGlobalInitializer(); +protected: + // String Helpers + static std::string StringFormat(const char* fmt, ...); - private: - SocketGlobalInitializer(); - }; - SocketGlobalInitializer& m_globalInitializer; +#ifdef _WIN32 +private: + static bool InitializeEnvironment(); + static void UnInitializeEnvironment(); +#endif + +protected: + // Log printer callback + /*mutable*/const LogFnCallback m_oLog; + + SettingsFlag m_eSettingsFlags; + +#ifdef _WIN32 +private: + static WSADATA s_wsaData; + static std::atomic s_iSocketCount; +#endif }; +#define SocketLog(fmt, ...) \ +do { \ + if (m_oLog && ((m_eSettingsFlags & ENABLE_LOG) == ENABLE_LOG)) \ + { \ + m_oLog(StringFormat(fmt, ##__VA_ARGS__)); \ + } \ +} while(0) + class EResolveError : public std::logic_error { public: - explicit EResolveError(const std::string &strMsg) : std::logic_error(strMsg) {} + explicit EResolveError(const std::string& strMsg) : std::logic_error(strMsg) {} }; #endif diff --git a/Socket/TCPClient.cpp b/Socket/TCPClient.cpp index e2438a6..5d266e8 100644 --- a/Socket/TCPClient.cpp +++ b/Socket/TCPClient.cpp @@ -5,419 +5,289 @@ */ #include "TCPClient.h" - -CTCPClient::CTCPClient(const LogFnCallback oLogger, - const SettingsFlag eSettings /*= ALL_FLAGS*/) : - ASocket(oLogger, eSettings), - m_eStatus(DISCONNECTED), - m_pResultAddrInfo(nullptr), - m_ConnectSocket(INVALID_SOCKET) - //m_uRetryCount(0), - //m_uRetryPeriod(0) +#include + +CTCPClient::CTCPClient(const LogFnCallback& oLogger, const SettingsFlag eSettings /*= ALL_FLAGS*/) + : ASocket(oLogger, eSettings) + , m_eStatus(DISCONNECTED) + , m_ConnectSocket(INVALID_SOCKET) + //, m_uRetryCount(0) + //, m_uRetryPeriod(0) + , m_pResultAddrInfo(nullptr) + , m_HintsAddrInfo() { +} +CTCPClient::~CTCPClient() +{ + Disconnect(); } // Method for setting receive timeout. Can be called after Connect -bool CTCPClient::SetRcvTimeout(unsigned int msec_timeout) { -#ifndef WINDOWS - struct timeval t = ASocket::TimevalFromMsec(msec_timeout); - - return this->SetRcvTimeout(t); -#else - int iErr; +bool CTCPClient::SetRcvTimeout(unsigned int msec_timeout) +{ + bool ret_val = ASocket::SetRcvTimeout(m_ConnectSocket, msec_timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPClient, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket, msec_timeout); + } - // it's expecting an int but it doesn't matter... - iErr = setsockopt(m_ConnectSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&msec_timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPClient::SetRcvTimeout : Socket error in SO_RCVTIMEO call to setsockopt."); + return ret_val; +} - return false; +#ifndef _WIN32 +bool CTCPClient::SetRcvTimeout(const struct timeval& timeout) +{ + bool ret_val = ASocket::SetRcvTimeout(m_ConnectSocket, timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPClient, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket, timeout.tv_sec, timeout.tv_usec); } - return true; -#endif + return ret_val; } +#endif -#ifndef WINDOWS -bool CTCPClient::SetRcvTimeout(struct timeval timeout) { - int iErr; +// Method for setting send timeout. Can be called after Connect +bool CTCPClient::SetSndTimeout(unsigned int msec_timeout) +{ + bool ret_val = ASocket::SetSndTimeout(m_ConnectSocket, msec_timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPClient, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket, msec_timeout); + } - iErr = setsockopt(m_ConnectSocket, SOL_SOCKET, SO_RCVTIMEO, (char*) &timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPClient::SetRcvTimeout : Socket error in SO_RCVTIMEO call to setsockopt."); + return ret_val; +} - return false; - } +#ifndef _WIN32 +bool CTCPClient::SetSndTimeout(const struct timeval& timeout) +{ + bool ret_val = ASocket::SetSndTimeout(m_ConnectSocket, timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPClient, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket, timeout.tv_sec, timeout.tv_usec); + } - return true; + return ret_val; } #endif -// Method for setting send timeout. Can be called after Connect -bool CTCPClient::SetSndTimeout(unsigned int msec_timeout) { -#ifndef WINDOWS - struct timeval t = ASocket::TimevalFromMsec(msec_timeout); +// Connexion au serveur +bool CTCPClient::Connect(const std::string& strServer, const std::string& strPort) +{ + if (m_eStatus == CONNECTED) + { + Disconnect(); + SocketLog("[WARN ]TCPClient, opening a new connexion. the last one was automatically closed[%s:%s]", strServer.c_str(), strPort.c_str()); + } + + memset(&m_HintsAddrInfo, 0, sizeof m_HintsAddrInfo); + /* AF_INET is used to specify the IPv4 address family. */ + m_HintsAddrInfo.ai_family = AF_INET; + /* SOCK_STREAM is used to specify a stream socket. */ + m_HintsAddrInfo.ai_socktype = SOCK_STREAM; + /* IPPROTO_TCP is used to specify the TCP protocol. */ + m_HintsAddrInfo.ai_protocol = IPPROTO_TCP; + + /* Resolve the server address and port */ + int iResult = getaddrinfo(strServer.c_str(), strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); + if (iResult != 0) + { + SocketLog("[ERROR]TCPClient, getaddrinfo failed[%d:%s][%s:%s]", iResult, GaiStrerror(iResult), strServer.c_str(), strPort.c_str()); + return false; + } - return this->SetSndTimeout(t); -#else - int iErr; + bool isOK = false; + + /* getaddrinfo() returns a list of address structures. + * Try each address until we successfully connect(2). + * If socket(2) (or connect(2)) fails, we (close the socket + * and) try the next address. */ + for (struct addrinfo* pResPtr = m_pResultAddrInfo; pResPtr != nullptr; pResPtr = pResPtr->ai_next) + { + // create socket + m_ConnectSocket = socket(pResPtr->ai_family, pResPtr->ai_socktype, pResPtr->ai_protocol); + if (m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[WARN ]TCPClient, create socket failed[%d:%s]", GetSocketError(), strerror(GetSocketError())); + continue; + } + + // Fixes windows 0.2 second delay sending (buffering) data. + int on = 1; + if (setsockopt(m_ConnectSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)) == SOCKET_ERROR) + { + SocketLog("[WARN ]TCPClient, setsockopt IPPROTO_TCP TCP_NODELAY failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket); + } + + // connexion to the server + if (connect(m_ConnectSocket, pResPtr->ai_addr, static_cast(pResPtr->ai_addrlen)) == SOCKET_ERROR) + { + int iErrCode = GetSocketError(); + if (!SOCKET_ERR_CONNECT_RETRIABLE(iErrCode)) + { + SocketLog("[WARN ]TCPClient, connect failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket); + assert(m_ConnectSocket != INVALID_SOCKET); + SocketClose(m_ConnectSocket); + continue; + } + } + + isOK = true; + m_eStatus = CONNECTED; + SocketLog("[INFO ]TCPClient, connected[%d]", m_ConnectSocket); + break; + } - // it's expecting an int but it doesn't matter... - iErr = setsockopt(m_ConnectSocket, SOL_SOCKET, SO_SNDTIMEO, (char*)&msec_timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPClient::SetSndTimeout : Socket error in SO_SNDTIMEO call to setsockopt."); + if (m_pResultAddrInfo != nullptr) + { + freeaddrinfo(m_pResultAddrInfo); /* No longer needed */ + m_pResultAddrInfo = nullptr; + } - return false; + /* No address succeeded */ + if (!isOK) + { + SocketLog("[ERROR]TCPClient, Connect failed[%s:%s]", strServer.c_str(), strPort.c_str()); } - return true; -#endif + return isOK; } -#ifndef WINDOWS -bool CTCPClient::SetSndTimeout(struct timeval timeout) { - int iErr; +/* ret > 0 : bytes received + * ret == 0 : connection closed + * ret < 0 : recv failed + */ +int CTCPClient::Receive(char* pData, size_t uSize, bool bReadFully /*= true*/) const +{ + if (m_eStatus != CONNECTED || m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPClient, recv failed[not connected to a server.]"); + return -1; + } - iErr = setsockopt(m_ConnectSocket, SOL_SOCKET, SO_SNDTIMEO, (char*) &timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPClient::SetSndTimeout : Socket error in SO_SNDTIMEO call to setsockopt."); + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPClient, recv failed[%d][%p:%zu]", m_ConnectSocket, pData, uSize); + return -2; + } - return false; - } +#if 0 +#ifdef _WIN32 + int tries = 0; +#endif +#endif - return true; -} + int total = 0; + bool isOK = true; + do + { + isOK = true; + int nRecvd = recv(m_ConnectSocket, pData + total, (int)uSize - total, 0); + if (nRecvd == SOCKET_ERROR) + { + isOK = false; + int iErrCode = GetSocketError(); + if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) + { + continue; + } +#if 0 +#ifdef _WIN32 + // On long messages, Windows recv sometimes fails with WSAENOBUFS, but + // will work if you try again. + if (WSAGetLastError() == WSAENOBUFS && (tries++ < 1000)) + { + Sleep(1); + continue; + } #endif +#endif + SocketLog("[ERROR]TCPClient, recv failed[%d:%s][%d]", iErrCode, strerror(iErrCode), m_ConnectSocket); + break; + } -// Connexion au serveur -bool CTCPClient::Connect(const std::string& strServer, const std::string& strPort) -{ - if (m_eStatus == CONNECTED) - { - Disconnect(); - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Warning] Opening a new connexion. The last one was automatically closed."); - } - - #ifdef WINDOWS - ZeroMemory(&m_HintsAddrInfo, sizeof(m_HintsAddrInfo)); - /* AF_INET is used to specify the IPv4 address family. */ - m_HintsAddrInfo.ai_family = AF_INET; - /* SOCK_STREAM is used to specify a stream socket. */ - m_HintsAddrInfo.ai_socktype = SOCK_STREAM; - /* IPPROTO_TCP is used to specify the TCP protocol. */ - m_HintsAddrInfo.ai_protocol = IPPROTO_TCP; - - /* Resolve the server address and port */ - int iResult = getaddrinfo(strServer.c_str(), strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); - if (iResult != 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] getaddrinfo failed : %d", iResult)); - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - - return false; - } - - // socket creation - m_ConnectSocket = socket(m_pResultAddrInfo->ai_family, // AF_INET - m_pResultAddrInfo->ai_socktype, // SOCK_STREAM - m_pResultAddrInfo->ai_protocol);// IPPROTO_TCP - - if (m_ConnectSocket == INVALID_SOCKET) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] socket failed : %d", WSAGetLastError())); - - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - return false; - } - - // Fixes windows 0.2 second delay sending (buffering) data. - int on = 1; - int iErr; - - iErr = setsockopt(m_ConnectSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)); - if (iErr == INVALID_SOCKET) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] Socket error in call to setsockopt"); - - closesocket(m_ConnectSocket); - freeaddrinfo(m_pResultAddrInfo); m_pResultAddrInfo = nullptr; - - return false; - } - - /* - SOCKET ConnectSocket = INVALID_SOCKET; - struct sockaddr_in clientService; - - ConnectSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (ConnectSocket == INVALID_SOCKET) { - printf("Error at socket(): %ld\n", WSAGetLastError()); - WSACleanup(); - return 1; - } - - // The sockaddr_in structure specifies the address family, - // IP address, and port of the server to be connected to. - clientService.sin_family = AF_INET; - clientService.sin_addr.s_addr = inet_addr("127.0.0.1"); - clientService.sin_port = htons(27015); - */ - - // connexion to the server - //unsigned uRetry = 0; - //do - //{ - iResult = connect(m_ConnectSocket, - m_pResultAddrInfo->ai_addr, - static_cast(m_pResultAddrInfo->ai_addrlen)); -//iResult = connect(m_ConnectSocket, (SOCKADDR*)&clientService, sizeof(clientService)); - - //if (iResult != SOCKET_ERROR) - //break; - - // retry mechanism - //if (uRetry < m_uRetryCount) - //if (m_eSettingsFlags & ENABLE_LOG) - /*m_oLog(StringFormat("[TCPClient][Error] connect retry %u after %u second(s)", - m_uRetryCount + 1, m_uRetryPeriod));*/ - - //if (m_uRetryPeriod > 0) - //{ - //for (unsigned uSec = 0; uSec < m_uRetryPeriod; uSec++) - //Sleep(1000); - //} - //} while (iResult == SOCKET_ERROR && ++uRetry < m_uRetryCount); - - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - - if (iResult != SOCKET_ERROR) - { - m_eStatus = CONNECTED; - return true; - } - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] Unable to connect to server : %d", WSAGetLastError())); - - #else - memset(&m_HintsAddrInfo, 0, sizeof m_HintsAddrInfo); - m_HintsAddrInfo.ai_family = AF_INET; // AF_INET or AF_INET6 to force version or use AF_UNSPEC - m_HintsAddrInfo.ai_socktype = SOCK_STREAM; - //m_HintsAddrInfo.ai_flags = 0; - //m_HintsAddrInfo.ai_protocol = 0; /* Any protocol */ - - int iAddrInfoRet = getaddrinfo(strServer.c_str(), strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); - if (iAddrInfoRet != 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] getaddrinfo failed : %s", gai_strerror(iAddrInfoRet))); - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - - return false; - } - - /* getaddrinfo() returns a list of address structures. - * Try each address until we successfully connect(2). - * If socket(2) (or connect(2)) fails, we (close the socket - * and) try the next address. */ - struct addrinfo* pResPtr = m_pResultAddrInfo; - for (pResPtr = m_pResultAddrInfo; pResPtr != nullptr; pResPtr = pResPtr->ai_next) - { - // create socket - m_ConnectSocket = socket(pResPtr->ai_family, pResPtr->ai_socktype, pResPtr->ai_protocol); - if (m_ConnectSocket < 0) // or == -1 - continue; - - // connexion to the server - int iConRet = connect(m_ConnectSocket, pResPtr->ai_addr, pResPtr->ai_addrlen); - if (iConRet >= 0) // or != -1 - { - /* Success */ - m_eStatus = CONNECTED; - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - - return true; - } - - close(m_ConnectSocket); - } - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); /* No longer needed */ - m_pResultAddrInfo = nullptr; - } - - /* No address succeeded */ - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] no such host."); - - #endif - - return false; -} + if (nRecvd == 0) + { + SocketLog("[INFO ]TCPClient, peer shut down[%d]", m_ConnectSocket); + break; + } -bool CTCPClient::Send(const char* pData, const size_t uSize) const -{ - if (!pData || !uSize) - return false; - - if (m_eStatus != CONNECTED) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] send failed : not connected to a server."); - - return false; - } - - int total = 0; - do - { - const int flags = 0; - int nSent; - - nSent = send(m_ConnectSocket, pData + total, uSize - total, flags); - - if (nSent < 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] Socket error in call to send."); - - return false; - } - total += nSent; - } while(total < uSize); - - return true; -} + total += nRecvd; -bool CTCPClient::Send(const std::string& strData) const -{ - return Send(strData.c_str(), strData.length()); + } while (bReadFully && (total < (int)uSize)); + + if (!isOK && total == 0) + { + return -1; + } + + return (int)total; } -bool CTCPClient::Send(const std::vector& Data) const +int CTCPClient::Send(const char* pData, size_t uSize) const { - return Send(Data.data(), Data.size()); + if (m_eStatus != CONNECTED || m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPClient, send failed[not connected to a server.]"); + return -1; + } + + if (pData == nullptr && uSize != 0) + { + SocketLog("[WARN ]TCPClient, send failed[%d][%p:%zu]", m_ConnectSocket, pData, uSize); + return -1; + } + + int total = 0; + do + { + int nSent = send(m_ConnectSocket, pData + total, (int)uSize - total, 0); + if (nSent == SOCKET_ERROR) + { + int iErrCode = GetSocketError(); + if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) + { + continue; + } + + SocketLog("[ERROR]TCPClient, send failed[%d:%s][%d]", iErrCode, strerror(iErrCode), m_ConnectSocket); + return -1; + } + + total += nSent; + } while (total < (int)uSize); + + return (int)total; } -/* ret > 0 : bytes received - * ret == 0 : connection closed - * ret < 0 : recv failed - */ -int CTCPClient::Receive(char* pData, const size_t uSize, bool bReadFully /*= true*/) const +int CTCPClient::Send(const std::string& strData) const { - if (!pData || !uSize) - return -2; - - if (m_eStatus != CONNECTED) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] recv failed : not connected to a server."); - - return -1; - } - - #ifdef WINDOWS - int tries = 0; - #endif - - int total = 0; - do - { - int nRecvd = recv(m_ConnectSocket, pData + total, uSize - total, 0); - - if (nRecvd == 0) - { - // peer shut down - break; - } - - #ifdef WINDOWS - if ((nRecvd < 0) && (WSAGetLastError() == WSAENOBUFS)) - { - // On long messages, Windows recv sometimes fails with WSAENOBUFS, but - // will work if you try again. - if ((tries++ < 1000)) - { - Sleep(1); - continue; - } - - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPClient][Error] Socket error in call to recv."); - - break; - } - #endif - - total += nRecvd; - - } while (bReadFully && (total < uSize)); - - return total; + return Send(strData.c_str(), strData.length()); } -bool CTCPClient::Disconnect() +int CTCPClient::Send(const std::vector& Data) const { - if (m_eStatus != CONNECTED) - return true; - - m_eStatus = DISCONNECTED; - - #ifdef WINDOWS - // shutdown the connection since no more data will be sent - int iResult = shutdown(m_ConnectSocket, SD_SEND); - if (iResult == SOCKET_ERROR) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPClient][Error] shutdown failed : %d", WSAGetLastError())); - - return false; - } - closesocket(m_ConnectSocket); - - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - #else - close(m_ConnectSocket); - #endif - - m_ConnectSocket = INVALID_SOCKET; - - return true; + return Send(Data.data(), Data.size()); } -CTCPClient::~CTCPClient() +void CTCPClient::Disconnect() { - if (m_eStatus == CONNECTED) - Disconnect(); + if (m_eStatus != CONNECTED) + { + m_eStatus = DISCONNECTED; + } + + if (m_ConnectSocket != INVALID_SOCKET) + { +#if 0//defined(_WIN32) + // shutdown the connection since no more data will be sent + if (shutdown(m_ConnectSocket, SD_SEND) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPClient, shutdown SD_SEND failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ConnectSocket); + } +#endif + SocketClose(m_ConnectSocket); + } } diff --git a/Socket/TCPClient.h b/Socket/TCPClient.h index 1c5db98..d78961d 100644 --- a/Socket/TCPClient.h +++ b/Socket/TCPClient.h @@ -9,14 +9,6 @@ #ifndef INCLUDE_TCPCLIENT_H_ #define INCLUDE_TCPCLIENT_H_ -#include -#include // size_t -#include -#include // strerror, strlen, memcpy, strcpy -#include -#include -#include -#include #include #include @@ -26,62 +18,51 @@ class CTCPSSLClient; class CTCPClient : public ASocket { - friend class CTCPSSLClient; - + friend class CTCPSSLClient; public: - explicit CTCPClient(const LogFnCallback oLogger, const SettingsFlag eSettings = ALL_FLAGS); - ~CTCPClient() override; - - // copy constructor and assignment operator are disabled - CTCPClient(const CTCPClient&) = delete; - CTCPClient& operator=(const CTCPClient&) = delete; - - // Setters - Getters (for unit tests) - /*inline*/// void SetProgressFnCallback(void* pOwner, const ProgressFnCallback& fnCallback); - /*inline*/// void SetProxy(const std::string& strProxy); - /*inline auto GetProgressFnCallback() const - { - return m_fnProgressCallback.target(); - } - inline void* GetProgressFnCallbackOwner() const { return m_ProgressStruct.pOwner; }*/ - //inline const std::string& GetProxy() const { return m_strProxy; } - //inline const unsigned char GetSettingsFlags() const { return m_eSettingsFlags; } - - // Session - bool Connect(const std::string& strServer, const std::string& strPort); // connect to a TCP server - bool Disconnect(); // disconnect from the TCP server - bool Send(const char* pData, const size_t uSize) const; // send data to a TCP server - bool Send(const std::string& strData) const; - bool Send(const std::vector& Data) const; - int Receive(char* pData, const size_t uSize, bool bReadFully = true) const; - - // To disable timeout, set msec_timeout to 0. - bool SetRcvTimeout(unsigned int msec_timeout); - bool SetSndTimeout(unsigned int msec_timeout); - -#ifndef WINDOWS - bool SetRcvTimeout(struct timeval Timeout); - bool SetSndTimeout(struct timeval Timeout); + explicit CTCPClient(const LogFnCallback& oLogger, const SettingsFlag eSettings = ALL_FLAGS); + ~CTCPClient() override; + + // copy constructor and assignment operator are disabled + CTCPClient(const CTCPClient&) = delete; + CTCPClient& operator=(const CTCPClient&) = delete; + + // Session + bool Connect(const std::string& strServer, const std::string& strPort); // connect to a TCP server + void Disconnect(); // disconnect from the TCP server + + int Receive(char* pData, size_t uSize, bool bReadFully = true) const; + int Send(const char* pData, size_t uSize) const; // send data to a TCP server + int Send(const std::string& strData) const; + int Send(const std::vector& Data) const; + + // To disable timeout, set msec_timeout to 0. + bool SetRcvTimeout(unsigned int msec_timeout); + bool SetSndTimeout(unsigned int msec_timeout); + +#ifndef _WIN32 + bool SetRcvTimeout(const struct timeval& timeout); + bool SetSndTimeout(const struct timeval& timeout); #endif - bool IsConnected() const { return m_eStatus == CONNECTED; } + bool IsConnected() const { return m_eStatus == CONNECTED; } - Socket GetSocketDescriptor() const { return m_ConnectSocket; } + Socket GetSocketDescriptor() const { return m_ConnectSocket; } protected: - enum SocketStatus - { - CONNECTED, - DISCONNECTED - }; - - SocketStatus m_eStatus; - Socket m_ConnectSocket; // ConnectSocket - //unsigned m_uRetryCount; - //unsigned m_uRetryPeriod; - - struct addrinfo* m_pResultAddrInfo; - struct addrinfo m_HintsAddrInfo; + enum SocketStatus + { + CONNECTED, + DISCONNECTED + }; + + SocketStatus m_eStatus; + Socket m_ConnectSocket; // ConnectSocket + //unsigned m_uRetryCount; + //unsigned m_uRetryPeriod; + + struct addrinfo* m_pResultAddrInfo; + struct addrinfo m_HintsAddrInfo; }; #endif diff --git a/Socket/TCPSSLClient.cpp b/Socket/TCPSSLClient.cpp index f8e8514..be78dc7 100644 --- a/Socket/TCPSSLClient.cpp +++ b/Socket/TCPSSLClient.cpp @@ -7,257 +7,281 @@ #ifdef OPENSSL #include "TCPSSLClient.h" -CTCPSSLClient::CTCPSSLClient(const LogFnCallback oLogger, - const OpenSSLProtocol eSSLVersion, - const SettingsFlag eSettings /*= ALL_FLAGS*/) : - ASecureSocket(oLogger, eSSLVersion, eSettings), - m_TCPClient(oLogger, eSettings) +CTCPSSLClient::CTCPSSLClient(const LogFnCallback& oLogger, + const OpenSSLProtocol& eSSLVersion, + const SettingsFlag eSettings /*= ALL_FLAGS*/) + : ASecureSocket(oLogger, eSSLVersion, eSettings) + , m_TCPClient(oLogger, eSettings) { - } -bool CTCPSSLClient::SetRcvTimeout(unsigned int msec_timeout){ - return m_TCPClient.SetRcvTimeout(msec_timeout); +CTCPSSLClient::~CTCPSSLClient() +{ + Disconnect(); } -bool CTCPSSLClient::SetSndTimeout(unsigned int msec_timeout){ - return m_TCPClient.SetSndTimeout(msec_timeout); +bool CTCPSSLClient::SetRcvTimeout(unsigned int msec_timeout) +{ + return m_TCPClient.SetRcvTimeout(msec_timeout); } -#ifndef WINDOWS -bool CTCPSSLClient::SetRcvTimeout(struct timeval timeout) { +#ifndef _WIN32 +bool CTCPSSLClient::SetRcvTimeout(struct timeval timeout) +{ return m_TCPClient.SetRcvTimeout(timeout); } +#endif + +bool CTCPSSLClient::SetSndTimeout(unsigned int msec_timeout) +{ + return m_TCPClient.SetSndTimeout(msec_timeout); +} -bool CTCPSSLClient::SetSndTimeout(struct timeval timeout){ - return m_TCPClient.SetSndTimeout(timeout); +#ifndef _WIN32 +bool CTCPSSLClient::SetSndTimeout(struct timeval timeout) +{ + return m_TCPClient.SetSndTimeout(timeout); } #endif // Connexion au serveur bool CTCPSSLClient::Connect(const std::string& strServer, const std::string& strPort) { - if (m_TCPClient.Connect(strServer, strPort)) - { - m_SSLConnectSocket.m_SockFd = m_TCPClient.m_ConnectSocket; - SetUpCtxClient(m_SSLConnectSocket); - - if (m_SSLConnectSocket.m_pCTXSSL == nullptr) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] SSL_CTX_new failed."); - //ERR_print_errors_fp(stdout); - return false; - } - - /* process SSL certificates */ - /* Load a client certificate into the SSL_CTX structure. */ - if (!m_strSSLCertFile.empty()) - { - if (SSL_CTX_use_certificate_file(m_SSLConnectSocket.m_pCTXSSL, - m_strSSLCertFile.c_str(), SSL_FILETYPE_PEM) <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Loading cert file failed."); - - return false; - } - } - /* Load trusted CA. Mandatory to verify server's certificate */ - if (!m_strCAFile.empty()) - { - if (!SSL_CTX_load_verify_locations(m_SSLConnectSocket.m_pCTXSSL, m_strCAFile.c_str(), nullptr)) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Loading CA file failed."); - - return false; - } - SSL_CTX_set_verify_depth(m_SSLConnectSocket.m_pCTXSSL, 1); - } - /* Load a private-key into the SSL_CTX structure. - * set key file that corresponds to the server or client certificate. - * In the SSL handshake, a certificate (which contains the public key) is transmitted to allow - * the peer to use it for encryption. The encrypted message sent from the peer can be decrypted - * only using the private key. */ - if (!m_strSSLKeyFile.empty()) - { - if (SSL_CTX_use_PrivateKey_file(m_SSLConnectSocket.m_pCTXSSL, - m_strSSLKeyFile.c_str(), SSL_FILETYPE_PEM) <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Loading key file failed."); + if (!m_TCPClient.Connect(strServer, strPort)) + { + SocketLog("[ERROR]TCPSSLClient, m_TCPClient Connect failed[:Unable to establish a TCP connection with the server.][%s:%s]", strServer.c_str(), strPort.c_str()); + return false; + } + + do + { + m_SSLConnectSocket.m_SockFd = m_TCPClient.m_ConnectSocket; + if (!SetUpCtxClient(m_SSLConnectSocket)) + { + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_new failed[%s:%s][%d]", strServer.c_str(), strPort.c_str(), m_SSLConnectSocket.m_SockFd); //ERR_print_errors_fp(stdout); - return false; - } - - /* verify private key */ - /*if (!SSL_CTX_check_private_key(m_SSLConnectSocket.m_pCTXSSL)) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Private key does not match the public certificate."); - return false; - }*/ - } - //SSL_CTX_set_cert_verify_callback(m_SSLConnectSocket.m_pCTXSSL, AlwaysTrueCallback, nullptr); - - /* create new SSL connection state */ - m_SSLConnectSocket.m_pSSL = SSL_new(m_SSLConnectSocket.m_pCTXSSL); - SSL_set_fd(m_SSLConnectSocket.m_pSSL, m_SSLConnectSocket.m_SockFd); - - /* initiate the TLS/SSL handshake with an TLS/SSL server */ - int iResult = SSL_connect(m_SSLConnectSocket.m_pSSL); - if (iResult > 0) - { - /* The data can now be transmitted securely over this connection. */ - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLClient][Info] Connected with '%s' encryption.", - SSL_get_cipher(m_SSLConnectSocket.m_pSSL))); - - /*if (SSL_get_peer_certificate(m_SSLConnectSocket.m_pSSL) != nullptr) - { - if (SSL_get_verify_result(m_SSLConnectSocket.m_pSSL) == X509_V_OK) + break; + } + + /* process SSL certificates */ + /* Load a client certificate into the SSL_CTX structure. */ + if (!m_strSSLCertFile.empty()) + { + if (SSL_CTX_use_certificate_file(m_SSLConnectSocket.m_pCTXSSL, m_strSSLCertFile.c_str(), SSL_FILETYPE_PEM) != 1) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("client verification with SSL_get_verify_result() succeeded."); - } - else + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_use_certificate_file failed[Loading cert file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd, m_strSSLCertFile.c_str()); + break; + } + } + + /* Load trusted CA. Mandatory to verify server's certificate */ + if (!m_strCAFile.empty()) + { + if (SSL_CTX_load_verify_locations(m_SSLConnectSocket.m_pCTXSSL, m_strCAFile.c_str(), nullptr) != 1) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("client verification with SSL_get_verify_result() failed.\n"); - - return false; + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_load_verify_locations failed[Loading CA file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd, m_strCAFile.c_str()); + break; } - } - else if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("the peer certificate was not presented.");*/ - return true; - } - // under Windows it creates problems - #ifdef LINUX - ERR_print_errors_fp(stdout); - #endif + SSL_CTX_set_verify_depth(m_SSLConnectSocket.m_pCTXSSL, 1); + } + + /** + * Load a private-key into the SSL_CTX structure. + * set key file that corresponds to the server or client certificate. + * In the SSL handshake, a certificate (which contains the public key) is transmitted to allow + * the peer to use it for encryption. The encrypted message sent from the peer can be decrypted + * only using the private key. + */ + if (!m_strSSLKeyFile.empty()) + { + if (SSL_CTX_use_PrivateKey_file(m_SSLConnectSocket.m_pCTXSSL, m_strSSLKeyFile.c_str(), SSL_FILETYPE_PEM) != 1) + { + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_use_PrivateKey_file failed[Loading key file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd, m_strSSLKeyFile.c_str()); + //ERR_print_errors_fp(stdout); + break; + } - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLClient][Error] SSL_connect failed (Error=%d | %s)", - iResult, GetSSLErrorString(SSL_get_error(m_SSLConnectSocket.m_pSSL, iResult)))); +#if 0 + /* verify private key */ + if (SSL_CTX_check_private_key(m_SSLConnectSocket.m_pCTXSSL) != 1) + { + SocketLog("[ERROR]TCPSSLClient, SSL_CTX_check_private_key failed[Private key does not match the public certificate.][%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd); + break; + } +#endif + } + //SSL_CTX_set_cert_verify_callback(m_SSLConnectSocket.m_pCTXSSL, AlwaysTrueCallback, nullptr); + + /* create new SSL connection state */ + m_SSLConnectSocket.m_pSSL = SSL_new(m_SSLConnectSocket.m_pCTXSSL); + if (m_SSLConnectSocket.m_pSSL == nullptr) + { + SocketLog("[ERROR]TCPSSLClient, SSL_new failed[%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd); + break; + } + + if (SSL_set_fd(m_SSLConnectSocket.m_pSSL, (int)m_SSLConnectSocket.m_SockFd) != 1) + { + SocketLog("[ERROR]TCPSSLClient, SSL_set_fd failed[%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), m_SSLConnectSocket.m_SockFd); + break; + } + + bool connectOK = false; + do + { + /* initiate the TLS/SSL handshake with an TLS/SSL server */ + int iResult = SSL_connect(m_SSLConnectSocket.m_pSSL); + if (iResult == 1) + { + connectOK = true; + break; + } - return false; - } + int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, iResult); + if (iErrCode != SSL_ERROR_WANT_CONNECT) + { + // under Windows it creates problems + SocketLog("[ERROR]TCPSSLClient, SSL_connect failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); +#ifndef _WIN32 + ERR_print_errors_fp(stdout); +#endif + break; + } + } while (1); + + if (!connectOK) + { + break; + } + + /* The data can now be transmitted securely over this connection. */ + SocketLog("[INFO ]TCPSSLClient, SSL_connect with '%s' encryption.", SSL_get_cipher(m_SSLConnectSocket.m_pSSL)); + +#if 0 + if (SSL_get_peer_certificate(m_SSLConnectSocket.m_pSSL) == nullptr) + { + SocketLog("[WARN ]TCPSSLClient, SSL_get_peer_certificate failed[the peer certificate was not presented.][%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + } + else + { + if (SSL_get_verify_result(m_SSLConnectSocket.m_pSSL) != X509_V_OK) + { + SocketLog("[ERROR]TCPSSLClient, client verification with SSL_get_verify_result failed.[%d]", m_SSLConnectSocket.m_SockFd); + break; + } - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] Unable to establish a TCP connection with the server."); + SocketLog("[WARN ]TCPSSLClient, client verification with SSL_get_verify_result succeeded.[%d]", m_SSLConnectSocket.m_SockFd); + } +#endif - return false; -} + return true; + } while (0); -bool CTCPSSLClient::Send(const char* pData, const size_t uSize) const -{ - if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] SSL send failed : not connected to an SSL server."); - - return false; - } - - int total = 0; - do - { - /* encrypt & send message */ - int nSent = SSL_write(m_SSLConnectSocket.m_pSSL, pData + total, uSize - total); - if (nSent <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLClient][Error] SSL_write failed (Error=%d | %s)", - nSent, GetSSLErrorString(SSL_get_error(m_SSLConnectSocket.m_pSSL, nSent)))); - - return false; - } - - total += nSent; - } while (total < uSize); - - return true; + Disconnect(); + return false; } -bool CTCPSSLClient::Send(const std::string& strData) const +int CTCPSSLClient::Receive(char* pData, size_t uSize, bool bReadFully /*= true*/) const { - return Send(strData.c_str(), strData.length()); -} + if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED || m_TCPClient.m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPSSLClient, SSL_read failed[not connected to a SSL server.]"); + return -1; + } + + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPSSLClient, SSL_read failed[%d][%p:%zu]", m_SSLConnectSocket.m_SockFd, pData, uSize); + return -2; + } + + int total = 0; + do + { + int nRecvd = SSL_read(m_SSLConnectSocket.m_pSSL, pData + total, (int)uSize - total); + if (nRecvd <= 0) + { + int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, nRecvd); + if (iErrCode == SSL_ERROR_WANT_READ) + { + continue; + } -bool CTCPSSLClient::Send(const std::vector& Data) const -{ - return Send(Data.data(), Data.size()); -} + SocketLog("[ERROR]TCPSSLClient, SSL_read failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + break; + } -bool CTCPSSLClient::HasPending() -{ - int pend; + total += nRecvd; - pend = SSL_has_pending(m_SSLConnectSocket.m_pSSL); + } while (bReadFully && (total < (int)uSize)); - return pend == 1; + return total; } -int CTCPSSLClient::PendingBytes() +int CTCPSSLClient::Send(const char* pData, size_t uSize) const { - int nPend; + if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED || m_TCPClient.m_ConnectSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPSSLClient, SSL_write failed[not connected to a SSL server.]"); + return -1; + } + + //OpenSSL 1.1.1 + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPSSLClient, send failed[%d][%p:%zu]", m_SSLConnectSocket.m_SockFd, pData, uSize); + return -2; + } + + int total = 0; + do + { + /* encrypt & send message */ + int nSent = SSL_write(m_SSLConnectSocket.m_pSSL, pData + total, (int)uSize - total); + if (nSent <= 0) + { + int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, nSent); + if (iErrCode == SSL_ERROR_WANT_WRITE) + { + continue; + } - nPend = SSL_pending(m_SSLConnectSocket.m_pSSL); + SocketLog("[ERROR]TCPSSLClient, SSL_write failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + return -1; + } - return nPend; + total += nSent; + } while (total < (int)uSize); + + return (int)total; } -int CTCPSSLClient::Receive(char* pData, const size_t uSize, bool bReadFully /*= true*/) const +int CTCPSSLClient::Send(const std::string& strData) const { - if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLClient][Error] SSL recv failed : not connected to a server."); - - return -1; - } - - int total = 0; - do - { - int nRecvd = SSL_read(m_SSLConnectSocket.m_pSSL, pData + total, uSize - total); - - if (nRecvd <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLClient][Error] SSL_read failed (Error=%d | %s)", - nRecvd, GetSSLErrorString(SSL_get_error(m_SSLConnectSocket.m_pSSL, nRecvd)))); - - break; - } - - total += nRecvd; - - } while (bReadFully && (total < uSize)); - - return total; + return Send(strData.c_str(), strData.length()); } -bool CTCPSSLClient::Disconnect() +int CTCPSSLClient::Send(const std::vector& Data) const { - if (m_TCPClient.m_eStatus != CTCPClient::CONNECTED) - return true; + return Send(Data.data(), Data.size()); +} - // send close_notify message to notify peer of the SSL closure. - ShutdownSSL(m_SSLConnectSocket); +bool CTCPSSLClient::HasPending() +{ + return m_SSLConnectSocket.HasPending(); +} - return m_TCPClient.Disconnect(); +int CTCPSSLClient::PendingBytes() +{ + return m_SSLConnectSocket.PendingBytes(); } -CTCPSSLClient::~CTCPSSLClient() +void CTCPSSLClient::Disconnect() { - if (m_TCPClient.m_eStatus == CTCPClient::CONNECTED) - { - Disconnect(); - m_TCPClient.Disconnect(); - } + // send close_notify message to notify peer of the SSL closure. + m_SSLConnectSocket.Disconnect(); + //ShutdownSSL(m_SSLConnectSocket); + m_TCPClient.Disconnect(); } #endif diff --git a/Socket/TCPSSLClient.h b/Socket/TCPSSLClient.h index 332ecce..633a01e 100644 --- a/Socket/TCPSSLClient.h +++ b/Socket/TCPSSLClient.h @@ -16,43 +16,40 @@ class CTCPSSLClient : public ASecureSocket { public: - explicit CTCPSSLClient(const LogFnCallback oLogger, - const OpenSSLProtocol eSSLVersion = OpenSSLProtocol::TLS, - const SettingsFlag eSettings = ALL_FLAGS); - ~CTCPSSLClient() override; - - CTCPSSLClient(const CTCPSSLClient&) = delete; - CTCPSSLClient& operator=(const CTCPSSLClient&) = delete; - - /* connect to a TCP SSL server */ - bool Connect(const std::string& strServer, const std::string& strPort); - - bool SetRcvTimeout(unsigned int timeout); - bool SetSndTimeout(unsigned int timeout); - -#ifndef WINDOWS - bool SetRcvTimeout(struct timeval timeout); - bool SetSndTimeout(struct timeval timeout); + explicit CTCPSSLClient(const LogFnCallback& oLogger, + const OpenSSLProtocol& eSSLVersion = OpenSSLProtocol::TLS, + const SettingsFlag eSettings = ALL_FLAGS); + ~CTCPSSLClient() override; + + CTCPSSLClient(const CTCPSSLClient&) = delete; + CTCPSSLClient& operator=(const CTCPSSLClient&) = delete; + + /* connect to a TCP SSL server */ + bool Connect(const std::string& strServer, const std::string& strPort); + /* disconnect from the SSL TCP server */ + void Disconnect(); + + int Receive(char* pData, size_t uSize, bool bReadFully = true) const; + /* send data to a TCP SSL server */ + int Send(const char* pData, size_t uSize) const; + int Send(const std::string& strData) const; + int Send(const std::vector& Data) const; + + bool SetRcvTimeout(unsigned int timeout); + bool SetSndTimeout(unsigned int timeout); + +#ifndef _WIN32 + bool SetRcvTimeout(struct timeval timeout); + bool SetSndTimeout(struct timeval timeout); #endif - /* disconnect from the SSL TCP server */ - bool Disconnect(); - - /* send data to a TCP SSL server */ - bool Send(const char* pData, const size_t uSize) const; - bool Send(const std::string& strData) const; - bool Send(const std::vector& Data) const; - - /* receive data from a TCP SSL server */ - bool HasPending(); - int PendingBytes(); - - int Receive(char* pData, const size_t uSize, bool bReadFully = true) const; + /* receive data from a TCP SSL server */ + bool HasPending(); + int PendingBytes(); protected: - CTCPClient m_TCPClient; - SSLSocket m_SSLConnectSocket; - + CTCPClient m_TCPClient; + SSLSocket m_SSLConnectSocket; }; #endif diff --git a/Socket/TCPSSLServer.cpp b/Socket/TCPSSLServer.cpp index 2658014..9a3324a 100644 --- a/Socket/TCPSSLServer.cpp +++ b/Socket/TCPSSLServer.cpp @@ -8,239 +8,279 @@ #include "TCPSSLServer.h" CTCPSSLServer::CTCPSSLServer(const LogFnCallback oLogger, - const std::string& strPort, - const OpenSSLProtocol eSSLVersion, - const SettingsFlag eSettings /*= ALL_FLAGS*/) - /*throw (EResolveError)*/ : - ASecureSocket(oLogger, eSSLVersion, eSettings), - m_TCPServer(oLogger, strPort, eSettings) + const std::string& strPort, + const OpenSSLProtocol eSSLVersion, + const SettingsFlag eSettings /*= ALL_FLAGS*/) /*throw (EResolveError)*/ + : ASecureSocket(oLogger, eSSLVersion, eSettings) + , m_TCPServer(oLogger, strPort, eSettings) { - } -bool CTCPSSLServer::SetRcvTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout){ - return m_TCPServer.SetRcvTimeout(ClientSocket.m_SockFd, msec_timeout); +CTCPSSLServer::~CTCPSSLServer() +{ + SocketClose(m_TCPServer.m_ListenSocket); } -bool CTCPSSLServer::SetSndTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout){ - return m_TCPServer.SetSndTimeout(ClientSocket.m_SockFd, msec_timeout); +bool CTCPSSLServer::SetRcvTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout) +{ + return m_TCPServer.SetRcvTimeout(ClientSocket.m_SockFd, msec_timeout); } -#ifndef WINDOWS -bool CTCPSSLServer::SetRcvTimeout(SSLSocket& ClientSocket, struct timeval timeout) { +#ifndef _WIN32 +bool CTCPSSLServer::SetRcvTimeout(SSLSocket& ClientSocket, struct timeval timeout) +{ return m_TCPServer.SetRcvTimeout(ClientSocket.m_SockFd, timeout); } - -bool CTCPSSLServer::SetSndTimeout(SSLSocket& ClientSocket, struct timeval timeout){ - return m_TCPServer.SetSndTimeout(ClientSocket.m_SockFd, timeout); -} #endif -// returns the socket of the accepted client -bool CTCPSSLServer::Listen(SSLSocket& ClientSocket, size_t msec /*= ACCEPT_WAIT_INF_DELAY*/) +bool CTCPSSLServer::SetSndTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout) { - if (m_TCPServer.Listen(ClientSocket.m_SockFd, msec)) - { - SetUpCtxServer(ClientSocket); - - if (ClientSocket.m_pCTXSSL == nullptr) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] SSL CTX failed."); - //ERR_print_errors_fp(stdout); - return false; - } - - //SSL_CTX_set_options(ClientSocket.m_pCTXSSL, SSL_OP_SINGLE_DH_USE); - //SSL_CTX_set_cert_verify_callback(ClientSocket.m_pCTXSSL, AlwaysTrueCallback, nullptr); - - /* Load server certificate into the SSL context. */ - if (!m_strSSLCertFile.empty()) - { - if (SSL_CTX_use_certificate_file(ClientSocket.m_pCTXSSL, - m_strSSLCertFile.c_str(), SSL_FILETYPE_PEM) <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Loading cert file failed."); - //ERR_print_errors_fp(stdout); - return false; - } - } - /* Load trusted CA file. */ - if (!m_strCAFile.empty()) - { - if (!SSL_CTX_load_verify_locations(ClientSocket.m_pCTXSSL, m_strCAFile.c_str(), nullptr)) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Loading CA file failed."); - - return false; - } - /* Set to require peer (client) certificate verification. */ - //SSL_CTX_set_verify(m_SSLConnectSocket.m_pCTXSSL, SSL_VERIFY_PEER, VerifyCallback); - /* Set the verification depth to 1 */ - SSL_CTX_set_verify_depth(ClientSocket.m_pCTXSSL, 1); - } - /* Load the server private-key into the SSL context. */ - if (!m_strSSLKeyFile.empty()) - { - if (SSL_CTX_use_PrivateKey_file(ClientSocket.m_pCTXSSL, - m_strSSLKeyFile.c_str(), SSL_FILETYPE_PEM) <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Loading key file failed."); - //ERR_print_errors_fp(stdout); - return false; - } - - // verify private key - /*if (!SSL_CTX_check_private_key(ClientSocket.m_pCTXSSL)) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Private key does not match the public certificate."); - - return false; - }*/ - } - - ClientSocket.m_pSSL = SSL_new(ClientSocket.m_pCTXSSL); - // set the socket directly into the SSL structure or we can use a BIO structure - SSL_set_fd(ClientSocket.m_pSSL, ClientSocket.m_SockFd); - - /* wait for a TLS/SSL client to initiate a TLS/SSL handshake */ - int iSSLErr = SSL_accept(ClientSocket.m_pSSL); - if (iSSLErr <= 0) - { - //Error occurred, log and close down ssl - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLServer][Error] accept failed. (Error=%d | %s)", - iSSLErr, GetSSLErrorString(SSL_get_error(ClientSocket.m_pSSL, iSSLErr)))); - - //if (iSSLErr < 0) - // under Windows it creates problems - #ifdef LINUX - ERR_print_errors_fp(stdout); - #endif - - ShutdownSSL(ClientSocket); - - return false; - } - - /* The TLS/SSL handshake is successfully completed and a TLS/SSL connection - * has been established. Now all reads and writes must use SSL. */ - // peer_cert = SSL_get_peer_certificate(ClientSocket.m_pSSL); - return true; - } - - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPSSLServer][Error] Unable to accept an incoming TCP connection with a client."); - - return false; + return m_TCPServer.SetSndTimeout(ClientSocket.m_SockFd, msec_timeout); } -bool CTCPSSLServer::HasPending(const SSLSocket& ClientSocket) +#ifndef _WIN32 +bool CTCPSSLServer::SetSndTimeout(SSLSocket& ClientSocket, struct timeval timeout) { - int pend; - - pend = SSL_has_pending(ClientSocket.m_pSSL); - - return pend == 1; + return m_TCPServer.SetSndTimeout(ClientSocket.m_SockFd, timeout); } +#endif -int CTCPSSLServer::PendingBytes(const SSLSocket& ClientSocket) +// returns the socket of the accepted client +int CTCPSSLServer::Listen(SSLSocket& ClientSocket, size_t msec /*= ACCEPT_WAIT_INF_DELAY*/) { - int nPend; - - nPend = SSL_pending(ClientSocket.m_pSSL); - - return nPend; + int ret_val = m_TCPServer.Listen(ClientSocket.m_SockFd, msec); + if (ret_val < 0) + { + SocketLog("[ERROR]TCPSSLServer, m_TCPServer Listen failed[:Unable to accept an incoming TCP connection with a client.][:%s]", m_TCPServer.m_strPort.c_str()); + return -1; + } + + if (ret_val == 0) + { + return 0; + } + + ret_val = 0; + do + { + if (!SetUpCtxServer(ClientSocket)) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_new failed[:%s][%d]", m_TCPServer.m_strPort.c_str(), ClientSocket.m_SockFd); + //ERR_print_errors_fp(stdout); + break; + } + + //SSL_CTX_set_options(ClientSocket.m_pCTXSSL, SSL_OP_SINGLE_DH_USE); + //SSL_CTX_set_cert_verify_callback(ClientSocket.m_pCTXSSL, AlwaysTrueCallback, nullptr); + + /* Load server certificate into the SSL context. */ + if (!m_strSSLCertFile.empty()) + { + if (SSL_CTX_use_certificate_file(ClientSocket.m_pCTXSSL, m_strSSLCertFile.c_str(), SSL_FILETYPE_PEM) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_use_certificate_file failed[Loading cert file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd, m_strSSLCertFile.c_str()); + //ERR_print_errors_fp(stdout); + break; + } + } + + /* Load trusted CA file. */ + if (!m_strCAFile.empty()) + { + if (SSL_CTX_load_verify_locations(ClientSocket.m_pCTXSSL, m_strCAFile.c_str(), nullptr) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_load_verify_locations failed[Loading CA file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd, m_strCAFile.c_str()); + break; + } + + /* Set to require peer (client) certificate verification. */ + //SSL_CTX_set_verify(m_SSLConnectSocket.m_pCTXSSL, SSL_VERIFY_PEER, VerifyCallback); + + /* Set the verification depth to 1 */ + SSL_CTX_set_verify_depth(ClientSocket.m_pCTXSSL, 1); + } + + /* Load the server private-key into the SSL context. */ + if (!m_strSSLKeyFile.empty()) + { + if (SSL_CTX_use_PrivateKey_file(ClientSocket.m_pCTXSSL, m_strSSLKeyFile.c_str(), SSL_FILETYPE_PEM) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_use_PrivateKey_file failed[Loading key file failed.][%lu:%s][%d][%s]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd, m_strSSLKeyFile.c_str()); + //ERR_print_errors_fp(stdout); + break; + } + +#if 0 + // verify private key + if (SSL_CTX_check_private_key(ClientSocket.m_pCTXSSL) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_CTX_check_private_key failed[Private key does not match the public certificate.][%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd); + break; + } +#endif + } + + ClientSocket.m_pSSL = SSL_new(ClientSocket.m_pCTXSSL); + if (ClientSocket.m_pSSL == nullptr) + { + SocketLog("[ERROR]TCPSSLServer, SSL_new failed[%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd); + break; + } + + // set the socket directly into the SSL structure or we can use a BIO structure + if (SSL_set_fd(ClientSocket.m_pSSL, (int)ClientSocket.m_SockFd) != 1) + { + SocketLog("[ERROR]TCPSSLServer, SSL_set_fd failed[%lu:%s][%d]", ERR_get_error(), ERR_error_string(ERR_get_error(), nullptr), ClientSocket.m_SockFd); + break; + } + + bool acceptOK = false; + do + { + /* wait for a TLS/SSL client to initiate a TLS/SSL handshake */ + int iSSLErr = SSL_accept(ClientSocket.m_pSSL); + if (iSSLErr == 1) + { + acceptOK = true; + break; + } + + //Error occurred, log and close down ssl + int iErrCode = SSL_get_error(ClientSocket.m_pSSL, iSSLErr); + if (iErrCode != SSL_ERROR_WANT_ACCEPT) + { + SocketLog("[ERROR]TCPSSLServer, SSL_accept failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); + // under Windows it creates problems +#ifndef _WIN32 + ERR_print_errors_fp(stdout); +#endif + break; + } + } while (1); + + if (!acceptOK) + { + break; + } + + /* The TLS/SSL handshake is successfully completed and a TLS/SSL connection + * has been established. Now all reads and writes must use SSL. */ + // peer_cert = SSL_get_peer_certificate(ClientSocket.m_pSSL); + SocketLog("[ERROR]TCPSSLServer, SSL_accept accepted[%d]", ClientSocket.m_SockFd); + return 1; + } while (0); + + Disconnect(ClientSocket); + return ret_val; } -/* When an SSL_read() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, - * it must be repeated with the same arguments.*/ -int CTCPSSLServer::Receive(const SSLSocket& ClientSocket, - char* pData, - const size_t uSize, - bool bReadFully /*= true*/) const +/** + * When an SSL_read() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, + * it must be repeated with the same arguments. + */ +int CTCPSSLServer::Receive(const SSLSocket& ClientSocket, char* pData, size_t uSize, bool bReadFully /*= true*/) const { - int total = 0; - do - { - int nRecvd = SSL_read(ClientSocket.m_pSSL, pData + total, uSize - total); - - if (nRecvd <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLServer][Error] SSL_read failed (Error=%d | %s)", - nRecvd, GetSSLErrorString(SSL_get_error(ClientSocket.m_pSSL, nRecvd)))); - - //ERR_print_errors_fp(stdout); - - break; - } + if (ClientSocket.m_SockFd == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPSSLServer, SSL_read failed[not a connection to SSL server.]"); + return -1; + } + + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPSSLServer, SSL_read failed[%d][%p:%zu]", ClientSocket.m_SockFd, pData, uSize); + return -2; + } + + int total = 0; + do + { + int nRecvd = SSL_read(ClientSocket.m_pSSL, pData + total, (int)uSize - total); + if (nRecvd <= 0) + { + int iErrCode = SSL_get_error(ClientSocket.m_pSSL, nRecvd); + if (iErrCode == SSL_ERROR_WANT_READ) + { + continue; + } + + SocketLog("[ERROR]TCPSSLServer, SSL_read failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); + //ERR_print_errors_fp(stdout); + break; + } - total += nRecvd; + total += nRecvd; - } while(bReadFully && (total < uSize)); + } while (bReadFully && (total < (int)uSize)); - return total; + return total; } /* When an SSL_write() operation has to be repeated because of SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, * it must be repeated with the same arguments. * When calling SSL_write() with uSize=0 bytes to be sent the behaviour is undefined. */ -bool CTCPSSLServer::Send(const SSLSocket& ClientSocket, const char* pData, const size_t uSize) const +int CTCPSSLServer::Send(const SSLSocket& ClientSocket, const char* pData, size_t uSize) const { - int total = 0; - do - { - int nSent; - - nSent = SSL_write(ClientSocket.m_pSSL, pData + total, uSize - total); - - if (nSent <= 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPSSLServer][Error] SSL_write failed (Error=%d | %s).", - nSent, GetSSLErrorString(SSL_get_error(ClientSocket.m_pSSL, nSent)))); - - return false; - } - total += nSent; - } while (total < uSize); - - return true; + if (ClientSocket.m_pSSL == nullptr || ClientSocket.m_SockFd == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPSSLServer, SSL_write failed[not a connection to SSL server.]"); + return -1; + } + + //OpenSSL 1.1.1 + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPSSLServer, SSL_write failed[%d][%p:%zu]", ClientSocket.m_SockFd, pData, uSize); + return -2; + } + + int total = 0; + do + { + int nSent = SSL_write(ClientSocket.m_pSSL, pData + total, (int)uSize - total); + if (nSent <= 0) + { + int iErrCode = SSL_get_error(ClientSocket.m_pSSL, nSent); + if (iErrCode == SSL_ERROR_WANT_WRITE) + { + continue; + } + + SocketLog("[ERROR]TCPSSLServer, SSL_write failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); + return -1; + } + + total += nSent; + } while (total < (int)uSize); + + return (int)total; } -bool CTCPSSLServer::Send(const SSLSocket& ClientSocket, const std::string& strData) const +int CTCPSSLServer::Send(const SSLSocket& ClientSocket, const std::string& strData) const { - bool ret; - - ret = Send(ClientSocket, strData.c_str(), strData.length()); - - return ret; + return Send(ClientSocket, strData.c_str(), strData.length()); } -bool CTCPSSLServer::Send(const SSLSocket& ClientSocket, const std::vector& Data) const +int CTCPSSLServer::Send(const SSLSocket& ClientSocket, const std::vector& Data) const { - bool ret; - - ret = Send(ClientSocket, Data.data(), Data.size()); - - return ret; + return Send(ClientSocket, Data.data(), Data.size()); } -bool CTCPSSLServer::Disconnect(SSLSocket& ClientSocket) const +bool CTCPSSLServer::HasPending(const SSLSocket& ClientSocket) { - // send close_notify message to notify peer of the SSL closure. - ShutdownSSL(ClientSocket); - - return m_TCPServer.Disconnect(ClientSocket.m_SockFd); + return ClientSocket.HasPending(); } -CTCPSSLServer::~CTCPSSLServer() +int CTCPSSLServer::PendingBytes(const SSLSocket& ClientSocket) { + return ClientSocket.PendingBytes(); +} +void CTCPSSLServer::Disconnect(SSLSocket& ClientSocket) const +{ + // send close_notify message to notify peer of the SSL closure. + ClientSocket.Disconnect(); + //ShutdownSSL(ClientSocket); + m_TCPServer.Disconnect(ClientSocket.m_SockFd); } #endif diff --git a/Socket/TCPSSLServer.h b/Socket/TCPSSLServer.h index 69c8b3e..ab0ba71 100644 --- a/Socket/TCPSSLServer.h +++ b/Socket/TCPSSLServer.h @@ -13,46 +13,42 @@ #include "SecureSocket.h" #include "TCPServer.h" -/* private inheritance from CTCPServer is replaced with composition to avoid +/* private inheritance from CTCPServer is replaced with composition to avoid * ambiguity on the log callable object */ class CTCPSSLServer : public ASecureSocket { public: - explicit CTCPSSLServer(const LogFnCallback oLogger, - const std::string& strPort, - const OpenSSLProtocol eSSLVersion = OpenSSLProtocol::TLS, - const SettingsFlag eSettings = ALL_FLAGS) - /*throw (EResolveError)*/; + explicit CTCPSSLServer(const LogFnCallback oLogger, + const std::string& strPort, + const OpenSSLProtocol eSSLVersion = OpenSSLProtocol::TLS, + const SettingsFlag eSettings = ALL_FLAGS) /*throw (EResolveError)*/; + ~CTCPSSLServer() override; - ~CTCPSSLServer() override; + CTCPSSLServer(const CTCPSSLServer&) = delete; + CTCPSSLServer& operator=(const CTCPSSLServer&) = delete; - CTCPSSLServer(const CTCPSSLServer&) = delete; - CTCPSSLServer& operator=(const CTCPSSLServer&) = delete; + int Listen(SSLSocket& ClientSocket, size_t msec = ACCEPT_WAIT_INF_DELAY); + void Disconnect(SSLSocket& ClientSocket) const; - bool Listen(SSLSocket& ClientSocket, size_t msec = ACCEPT_WAIT_INF_DELAY); + int Receive(const SSLSocket& ClientSocket, char* pData, size_t uSize, bool bReadFully = true) const; + int Send(const SSLSocket& ClientSocket, const char* pData, size_t uSize) const; + int Send(const SSLSocket& ClientSocket, const std::string& strData) const; + int Send(const SSLSocket& ClientSocket, const std::vector& Data) const; - bool SetRcvTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout); - bool SetSndTimeout(SSLSocket& ClientSocket, unsigned int timeout); - -#ifndef WINDOWS - bool SetRcvTimeout(SSLSocket& ClientSocket, struct timeval timeout); - bool SetSndTimeout(SSLSocket& ClientSocket, struct timeval timeout); -#endif - - bool HasPending(const SSLSocket& ClientSocket); - int PendingBytes(const SSLSocket& ClientSocket); - int Receive(const SSLSocket& ClientSocket, char* pData, - const size_t uSize, bool bReadFully = true) const; + bool SetRcvTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout); + bool SetSndTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout); - bool Send(const SSLSocket& ClientSocket, const char* pData, const size_t uSize) const; - bool Send(const SSLSocket& ClientSocket, const std::string& strData) const; - bool Send(const SSLSocket& ClientSocket, const std::vector& Data) const; +#ifndef _WIN32 + bool SetRcvTimeout(SSLSocket& ClientSocket, struct timeval timeout); + bool SetSndTimeout(SSLSocket& ClientSocket, struct timeval timeout); +#endif - bool Disconnect(SSLSocket& ClientSocket) const; + bool HasPending(const SSLSocket& ClientSocket); + int PendingBytes(const SSLSocket& ClientSocket); protected: - CTCPServer m_TCPServer; + CTCPServer m_TCPServer; }; diff --git a/Socket/TCPServer.cpp b/Socket/TCPServer.cpp index dff5eef..0a0a8a7 100644 --- a/Socket/TCPServer.cpp +++ b/Socket/TCPServer.cpp @@ -7,471 +7,370 @@ #include "TCPServer.h" CTCPServer::CTCPServer(const LogFnCallback oLogger, - /*const std::string& strAddr,*/ - const std::string& strPort, - const SettingsFlag eSettings /*= ALL_FLAGS*/) - /*throw (EResolveError)*/ : - ASocket(oLogger, eSettings), - m_ListenSocket(INVALID_SOCKET), -#ifdef WINDOWS - m_pResultAddrInfo(nullptr), -#endif - //m_strHost(strAddr), - m_strPort(strPort) { -#ifdef WINDOWS - // Resolve the server address and port - ZeroMemory(&m_HintsAddrInfo, sizeof(m_HintsAddrInfo)); - /* AF_INET is used to specify the IPv4 address family. */ - m_HintsAddrInfo.ai_family = AF_INET; - /* SOCK_STREAM is used to specify a stream socket. */ - m_HintsAddrInfo.ai_socktype = SOCK_STREAM; - /* IPPROTO_TCP is used to specify the TCP protocol. */ - m_HintsAddrInfo.ai_protocol = IPPROTO_TCP; - /* AI_PASSIVE flag indicates the caller intends to use the returned socket - * address structure in a call to the bind function.*/ - m_HintsAddrInfo.ai_flags = AI_PASSIVE; - - int iResult = getaddrinfo(nullptr, strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); - if (iResult != 0) - { - if (m_pResultAddrInfo != nullptr) - { - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - } - - throw EResolveError(StringFormat("[TCPServer][Error] getaddrinfo failed : %d", iResult)); - } -#else - // clear address structure - bzero((char*) &m_ServAddr, sizeof(m_ServAddr)); - - int iPort = atoi(strPort.c_str()); - - /* setup the host_addr structure for use in bind call */ - // server byte order - m_ServAddr.sin_family = AF_INET; - - // automatically be filled with current host's IP address - m_ServAddr.sin_addr.s_addr = INADDR_ANY; - //m_ServAddr.sin_addr.s_addr = inet_addr(strAddr.c_str()); // doesn't work ! - - // convert short integer value for port must be converted into network byte order - m_ServAddr.sin_port = htons(iPort); -#endif + /*const std::string& strAddr,*/ const std::string& strPort, + const SettingsFlag eSettings /*= ALL_FLAGS*/) /*throw (EResolveError)*/ + : ASocket(oLogger, eSettings) + , m_ListenSocket(INVALID_SOCKET) + //, m_strHost(strAddr) + , m_strPort(strPort) + , m_pResultAddrInfo(nullptr) + , m_HintsAddrInfo() +{ } -// Method for setting receive timeout. Can be called after Listen, using the previously created ClientSocket -bool CTCPServer::SetRcvTimeout(ASocket::Socket& ClientSocket, unsigned int msec_timeout) { -#ifndef WINDOWS - struct timeval t = ASocket::TimevalFromMsec(msec_timeout); - - return this->SetRcvTimeout(ClientSocket, t); -#else - int iErr; - - iErr = setsockopt(ClientSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&msec_timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::SetRcvTimeout : Socket error in SO_RCVTIMEO call to setsockopt."); - - return false; - } - - return true; -#endif +CTCPServer::~CTCPServer() +{ + SocketClose(m_ListenSocket); } -// Method for setting send timeout. Can be called after Listen, using the previously created ClientSocket -bool CTCPServer::SetSndTimeout(ASocket::Socket& ClientSocket, unsigned int msec_timeout) { -#ifndef WINDOWS - struct timeval t = ASocket::TimevalFromMsec(msec_timeout); - - return this->SetRcvTimeout(ClientSocket, t); -#else - int iErr; - - iErr = setsockopt(ClientSocket, SOL_SOCKET, SO_SNDTIMEO, (char*)&msec_timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::SetSndTimeout : Socket error in SO_SNDTIMEO call to setsockopt."); - - return false; - } - - return true; -#endif +bool CTCPServer::InitAddrInfo() +{ + // Resolve the server address and port + memset(&m_HintsAddrInfo, 0, sizeof(m_HintsAddrInfo)); + /* AF_INET is used to specify the IPv4 address family. */ + m_HintsAddrInfo.ai_family = AF_INET; + /* SOCK_STREAM is used to specify a stream socket. */ + m_HintsAddrInfo.ai_socktype = SOCK_STREAM; + /* IPPROTO_TCP is used to specify the TCP protocol. */ + m_HintsAddrInfo.ai_protocol = IPPROTO_TCP; + /* AI_PASSIVE flag indicates the caller intends to use the returned socket + * address structure in a call to the bind function.*/ + m_HintsAddrInfo.ai_flags = AI_PASSIVE; + + int iResult = getaddrinfo(nullptr, m_strPort.c_str(), &m_HintsAddrInfo, &m_pResultAddrInfo); + if (iResult != 0) + { + SocketLog("[ERROR]TCPServer, getaddrinfo failed[%s:%s][:%s]", iResult, GaiStrerror(iResult), m_strPort.c_str()); + return false; + } + + return true; } -#ifndef WINDOWS -bool CTCPServer::SetRcvTimeout(ASocket::Socket& ClientSocket, struct timeval Timeout) { - int iErr; - - iErr = setsockopt(ClientSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&Timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::SetRcvTimeout : Socket error in SO_RCVTIMEO call to setsockopt."); +// Method for setting receive timeout. Can be called after Listen, using the previously created ClientSocket +bool CTCPServer::SetRcvTimeout(Socket ClientSocket, unsigned int msec_timeout) +{ + bool ret_val = ASocket::SetRcvTimeout(ClientSocket, msec_timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), ClientSocket, msec_timeout); + } + + return ret_val; +} - return false; - } +#ifndef _WIN32 +bool CTCPServer::SetRcvTimeout(Socket ClientSocket, struct timeval timeout) +{ + bool ret_val = ASocket::SetRcvTimeout(ClientSocket, timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_RCVTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), ClientSocket, timeout.tv_sec, timeout.tv_usec); + } - return true; + return ret_val; } +#endif -bool CTCPServer::SetSndTimeout(ASocket::Socket& ClientSocket, struct timeval Timeout) { - int iErr; - - iErr = setsockopt(ClientSocket, SOL_SOCKET, SO_SNDTIMEO, (char*) &Timeout, sizeof(struct timeval)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::SetSndTimeout : Socket error in SO_SNDTIMEO call to setsockopt."); +// Method for setting send timeout. Can be called after Listen, using the previously created ClientSocket +bool CTCPServer::SetSndTimeout(Socket ClientSocket, unsigned int msec_timeout) +{ + bool ret_val = ASocket::SetSndTimeout(ClientSocket, msec_timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u]", GetSocketError(), strerror(GetSocketError()), ClientSocket, msec_timeout); + } + + return ret_val; +} - return false; - } +#ifndef _WIN32 +bool CTCPServer::SetSndTimeout(Socket ClientSocket, struct timeval timeout) +{ + bool ret_val = ASocket::SetSndTimeout(ClientSocket, timeout); + if (!ret_val) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_SNDTIMEO failed[%d:%s][%d][%u:%u]", GetSocketError(), strerror(GetSocketError()), ClientSocket, timeout.tv_sec, timeout.tv_usec); + } - return true; + return ret_val; } #endif // returns the socket of the accepted client // maxRcvTime and maxSendTime define timeouts in µs for receiving and sending over the socket. Using a negative value // will deactivate the timeout. 0 will set a zero timeout. -bool CTCPServer::Listen(ASocket::Socket& ClientSocket, size_t msec /*= ACCEPT_WAIT_INF_DELAY*/) { - ClientSocket = INVALID_SOCKET; - - // creates a socket to listen for incoming client connections if it doesn't already exist - if (m_ListenSocket == INVALID_SOCKET) { -#ifdef WINDOWS - m_ListenSocket = socket(m_pResultAddrInfo->ai_family, - m_pResultAddrInfo->ai_socktype, - m_pResultAddrInfo->ai_protocol); - - if (m_ListenSocket == INVALID_SOCKET) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] socket failed : %d", WSAGetLastError())); - freeaddrinfo(m_pResultAddrInfo); - m_pResultAddrInfo = nullptr; - return false; - } - - // Allow the socket to be bound to an address that is already in use - int opt = 1; - int iErr = 0; - - iErr = setsockopt(m_ListenSocket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(int)); - if (iErr < 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Socket error in call to setsockopt."); - - closesocket(m_ListenSocket); - freeaddrinfo(m_pResultAddrInfo); m_pResultAddrInfo = nullptr; - - m_ListenSocket = INVALID_SOCKET; - - return false; - } - - // bind the listen socket to the host address:port - int iResult = bind(m_ListenSocket, - m_pResultAddrInfo->ai_addr, - static_cast(m_pResultAddrInfo->ai_addrlen)); - - freeaddrinfo(m_pResultAddrInfo); // free memory allocated by getaddrinfo - m_pResultAddrInfo = nullptr; - - if (iResult == SOCKET_ERROR) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] bind failed : %d", WSAGetLastError())); - closesocket(m_ListenSocket); - m_ListenSocket = INVALID_SOCKET; - return false; - } -#else - - // create a socket - // socket(int domain, int type, int protocol) - m_ListenSocket = socket(AF_INET, SOCK_STREAM, 0/*IPPROTO_TCP*/); - if (m_ListenSocket < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] opening socket : %s", strerror(errno))); - - m_ListenSocket = INVALID_SOCKET; - return false; - } - - // Allow the socket to be bound to an address that is already in use - int opt = 1; - int iErr = 0; - - iErr = setsockopt(m_ListenSocket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(int)); - if (iErr < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Socket error in SO_REUSEADDR call to setsockopt."); - - close(m_ListenSocket); - m_ListenSocket = INVALID_SOCKET; - - return false; - } - - /* - iErr = setsockopt(m_ListenSocket, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&opt), sizeof(int)); - if (iErr < 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Socket error in SO_KEEPALIVE call to setsockopt."); - - close(m_ListenSocket); - m_ListenSocket = INVALID_SOCKET; - - return false; - } - */ - - // bind(int fd, struct sockaddr *local_addr, socklen_t addr_length) - // bind() passes file descriptor, the address structure, - // and the length of the address structure - // This bind() call will bind the socket to the current IP address on port, portno - int iResult = bind(m_ListenSocket, - reinterpret_cast(&m_ServAddr), - sizeof(m_ServAddr)); - if (iResult < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] bind failed : %s", strerror(errno))); - return false; - } +int CTCPServer::Listen(Socket& ClientSocket, size_t msec /*= ACCEPT_WAIT_INF_DELAY*/) +{ + ClientSocket = INVALID_SOCKET; + + bool isOK = false; + // creates a socket to listen for incoming client connections if it doesn't already exist + if (m_ListenSocket == INVALID_SOCKET) + { + if (m_pResultAddrInfo == nullptr && !InitAddrInfo()) + { + return -1; + } + + do + { + m_ListenSocket = socket(m_pResultAddrInfo->ai_family, m_pResultAddrInfo->ai_socktype, m_pResultAddrInfo->ai_protocol); + if (m_ListenSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPServer, create socket failed[%d:%s]", GetSocketError(), strerror(GetSocketError())); + break; + } + + // Allow the socket to be bound to an address that is already in use + int opt = 1; + if (setsockopt(m_ListenSocket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&opt), sizeof(opt)) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_REUSEADDR failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + +#if 0 +#ifdef SO_KEEPALIVE + if (setsockopt(m_ListenSocket, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&opt), sizeof(opt)) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, setsockopt SOL_SOCKET SO_KEEPALIVE failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } +#endif #endif - } - -#ifdef WINDOWS - sockaddr addrClient; - int iResult; - /* SOMAXCONN = allow max number of connexions in waiting */ - iResult = listen(m_ListenSocket, SOMAXCONN); - if (iResult == SOCKET_ERROR) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] listen failed : %d", WSAGetLastError())); - closesocket(m_ListenSocket); - m_ListenSocket = INVALID_SOCKET; - return false; - } - - if (msec != ACCEPT_WAIT_INF_DELAY) - { - int ret = SelectSocket(m_ListenSocket, msec); - if (ret == 0) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Timed out."); - - return false; - } - - if (ret == -1) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Error selecting socket."); - - return false; - } - } - - // accept client connection, the returned socket will be used for I/O operations - int iAddrLen = sizeof(addrClient); - ClientSocket = accept(m_ListenSocket, &addrClient, &iAddrLen); - if (ClientSocket == INVALID_SOCKET) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] accept failed : %d", WSAGetLastError())); - - return false; - } - - { - if (m_eSettingsFlags & ENABLE_LOG) - // TODO : a version that handles IPv6 - m_oLog( StringFormat("[TCPServer][Info] Incoming connection from '%s' port '%d'", - (addrClient.sa_family == AF_INET) ? inet_ntoa(((struct sockaddr_in*)&addrClient)->sin_addr) : "", - (addrClient.sa_family == AF_INET) ? ntohs(((struct sockaddr_in*)&addrClient)->sin_port) : 0)); - } - - //char buf1[256]; - //unsigned long len2 = 256UL; - //if (!WSAAddressToStringA(&addrClient, lenAddr, NULL, buf1, &len2)) - //if (m_eSettingsFlags & ENABLE_LOG) - //m_oLog(StringFormat("[TCPServer][Info] Connection from %s", buf1)); - -#else - // This listen() call tells the socket to listen to the incoming connections. - // The listen() function places all incoming connection into a backlog queue - // until accept() call accepts the connection. - // Here, we set the maximum size for the backlog queue to SOMAXCONN. - int iResult = listen(m_ListenSocket, SOMAXCONN); - if (iResult < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] listen failed : %s", strerror(errno))); - - return false; - } - - if (msec != ACCEPT_WAIT_INF_DELAY) { - int ret = SelectSocket(m_ListenSocket, msec); - if (ret == 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Timed out."); - - return false; - } - - if (ret == -1) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] CTCPServer::Listen : Error selecting socket."); - - return false; - } - } - - struct sockaddr_in ClientAddr; - // The accept() call actually accepts an incoming connection - socklen_t uClientLen = sizeof(ClientAddr); - - // This accept() function will write the connecting client's address info - // into the the address structure and the size of that structure is uClientLen. - // The accept() returns a new socket file descriptor for the accepted connection. - // So, the original socket file descriptor can continue to be used - // for accepting new connections while the new socker file descriptor is used for - // communicating with the connected client. - ClientSocket = accept(m_ListenSocket, - reinterpret_cast(&ClientAddr), - &uClientLen); - - if (ClientSocket < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] accept failed : %s", strerror(errno))); - - return false; - } - - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Info] Incoming connection from '%s' port '%d'", - inet_ntoa(ClientAddr.sin_addr), ntohs(ClientAddr.sin_port))); + + // bind the listen socket to the host address:port + if (bind(m_ListenSocket, m_pResultAddrInfo->ai_addr, static_cast(m_pResultAddrInfo->ai_addrlen)) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, bind failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + + // This listen() call tells the socket to listen to the incoming connections. + // The listen() function places all incoming connection into a backlog queue + // until accept() call accepts the connection. + // Here, we set the maximum size for the backlog queue to SOMAXCONN. + /* SOMAXCONN = allow max number of connexions in waiting */ + if (listen(m_ListenSocket, SOMAXCONN) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, listen failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + + SocketLog("[INFO ]TCPServer, listen succeed[%d]", m_ListenSocket); + isOK = true; + } while (0); + + // free memory allocated by getaddrinfo + //if (m_pResultAddrInfo != nullptr) + { + freeaddrinfo(m_pResultAddrInfo); + m_pResultAddrInfo = nullptr; + } + + if (!isOK/* && m_ListenSocket != INVALID_SOCKET*/) + { + SocketClose(m_ListenSocket); + return -1; + } + } + + do + { + //if (msec != ACCEPT_WAIT_INF_DELAY) + { + int ret = SelectSocket(m_ListenSocket, msec); + if (ret == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, select failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + + if (ret == 0) + { + //SocketLog("[INFO ]TCPServer, select timeout[%d]", m_ListenSocket); + return 0; + } + } + + // This accept() function will write the connecting client's address info + // into the the address structure and the size of that structure is uClientLen. + // The accept() returns a new socket file descriptor for the accepted connection. + // So, the original socket file descriptor can continue to be used + // for accepting new connections while the new socker file descriptor is used for + // communicating with the connected client. + // accept client connection, the returned socket will be used for I/O operations + + sockaddr addrClient{}; + int iAddrLen = (int)sizeof(addrClient); + ClientSocket = accept(m_ListenSocket, &addrClient, &iAddrLen); + if (ClientSocket == INVALID_SOCKET) + { + int iErrCode = GetSocketError(); + if (SOCKET_ERR_ACCEPT_RETRIABLE(iErrCode)) + { + return 0; + } + + SocketLog("[ERROR]TCPServer, accept failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), m_ListenSocket); + break; + } + +#if 0 + { + SocketLog("[INFO ]TCPServer, Incoming connection from[%s:%u]", + (addrClient.sa_family == AF_INET) ? inet_ntoa(((struct sockaddr_in*)&addrClient)->sin_addr) : "", + (addrClient.sa_family == AF_INET) ? ntohs(((struct sockaddr_in*)&addrClient)->sin_port) : 0); + } +#endif + +#if 0 + char buf1[256] = {}; + unsigned long len2 = 256UL; + if (WSAAddressToStringA(&addrClient, lenAddr, NULL, buf1, &len2) != SOCKET_ERROR) + { + SocketLog("[INFO ]TCPServer, Connection from[%s]", buf1); + } #endif - return true; + SocketLog("[INFO ]TCPServer, client_sock be accepted[%d]", ClientSocket); + return 1; + } while (0); + + SocketClose(m_ListenSocket); + return -1; } /* ret > 0 : bytes received * ret == 0 : connection closed * ret < 0 : recv failed */ -int CTCPServer::Receive(const CTCPServer::Socket ClientSocket, - char* pData, - const size_t uSize, - bool bReadFully /*= true*/) const { - if (ClientSocket < 0 || !pData || !uSize) - return -1; - -#ifdef WINDOWS - int tries = 0; +int CTCPServer::Receive(Socket ClientSocket, char* pData, size_t uSize, bool bReadFully /*= true*/) const +{ + if (ClientSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPServer, recv failed[not a connection to server.]"); + return -1; + } + + if (pData == nullptr || uSize == 0) + { + SocketLog("[ERROR]TCPServer, recv failed[%d][%p:%zu]", ClientSocket, pData, uSize); + return -2; + } + +#if 0 +#ifdef _WIN32 + int tries = 0; #endif - - int total = 0; - do { - int nRecvd = recv(ClientSocket, pData + total, uSize - total, 0); - - if (nRecvd == 0) { - // peer shut down - break; - } - -#ifdef WINDOWS - if ((nRecvd < 0) && (WSAGetLastError() == WSAENOBUFS)) - { - // On long messages, Windows recv sometimes fails with WSAENOBUFS, but - // will work if you try again. - if ((tries++ < 1000)) - { - Sleep(1); - continue; - } - - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] Socket error in call to recv."); - - break; - } #endif - total += nRecvd; - - } while (bReadFully && (total < uSize)); - - return total; -} + int total = 0; + bool isOK = true; + do + { + isOK = true; + int nRecvd = recv(ClientSocket, pData + total, (int)uSize - total, 0); + if (nRecvd == SOCKET_ERROR) + { + isOK = false; + int iErrCode = GetSocketError(); + if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) + { + continue; + } +#if 0 +#ifdef _WIN32 + // On long messages, Windows recv sometimes fails with WSAENOBUFS, but + // will work if you try again. + if (WSAGetLastError() == WSAENOBUFS && (tries++ < 1000)) + { + Sleep(1); + continue; + } +#endif +#endif -bool CTCPServer::Send(const Socket ClientSocket, const char* pData, size_t uSize) const { - if (ClientSocket < 0 || !pData || !uSize) - return false; + SocketLog("[ERROR]TCPServer, recv failed[%d:%s][%d]", iErrCode, strerror(iErrCode), ClientSocket); + break; + } - int total = 0; - do { - const int flags = 0; - int nSent; + if (nRecvd == 0) + { + SocketLog("[INFO ]TCPServer, peer shut down[%d]", ClientSocket); + break; + } - nSent = send(ClientSocket, pData + total, uSize - total, flags); + total += nRecvd; - if (nSent < 0) { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog("[TCPServer][Error] Socket error in call to send."); + } while (bReadFully && (total < (int)uSize)); - return false; - } - total += nSent; - } while (total < uSize); + if (!isOK && total == 0) + { + return -1; + } - return true; + return (int)total; } -bool CTCPServer::Send(const Socket ClientSocket, const std::string& strData) const { - return Send(ClientSocket, strData.c_str(), strData.length()); +int CTCPServer::Send(const Socket ClientSocket, const char* pData, size_t uSize) const +{ + if (ClientSocket == INVALID_SOCKET) + { + SocketLog("[ERROR]TCPServer, send failed[not a connection to server.]"); + return -1; + } + + if (pData == nullptr && uSize != 0) + { + SocketLog("[ERROR]TCPServer, send failed[%d][%p:%zu]", ClientSocket, pData, uSize); + return 0; + } + + int total = 0; + do + { + int nSent = send(ClientSocket, pData + total, (int)uSize - total, 0); + if (nSent == SOCKET_ERROR) + { + int iErrCode = GetSocketError(); + if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) + { + continue; + } + + SocketLog("[ERROR]TCPServer, send failed[%d:%s][%d]", iErrCode, strerror(iErrCode), ClientSocket); + return -1; + } + + total += nSent; + } while (total < (int)uSize); + + return (int)total; } -bool CTCPServer::Send(const Socket ClientSocket, const std::vector& Data) const { - return Send(ClientSocket, Data.data(), Data.size()); +int CTCPServer::Send(const Socket ClientSocket, const std::string& strData) const +{ + return Send(ClientSocket, strData.c_str(), strData.length()); } -bool CTCPServer::Disconnect(const CTCPServer::Socket ClientSocket) const { -#ifdef WINDOWS - // The shutdown function disables sends or receives on a socket. - int iResult = shutdown(ClientSocket, SD_RECEIVE); - - if (iResult == SOCKET_ERROR) - { - if (m_eSettingsFlags & ENABLE_LOG) - m_oLog(StringFormat("[TCPServer][Error] shutdown failed : %d", WSAGetLastError())); - - return false; - } - - closesocket(ClientSocket); -#else - - close(ClientSocket); - -#endif - - return true; +int CTCPServer::Send(const Socket ClientSocket, const std::vector& Data) const +{ + return Send(ClientSocket, Data.data(), Data.size()); } -CTCPServer::~CTCPServer() { -#ifdef WINDOWS - // close listen socket - closesocket(m_ListenSocket); -#else - close(m_ListenSocket); +void CTCPServer::Disconnect(Socket& ClientSocket) const +{ + if (ClientSocket != INVALID_SOCKET) + { +#if 0//defined(_WIN32) + // The shutdown function disables sends or receives on a socket. + if (shutdown(ClientSocket, SD_RECEIVE) == SOCKET_ERROR) + { + SocketLog("[ERROR]TCPServer, shutdown SD_RECEIVE failed[%d:%s][%d]", GetSocketError(), strerror(GetSocketError()), ClientSocket); + } #endif + SocketClose(ClientSocket); + } } diff --git a/Socket/TCPServer.h b/Socket/TCPServer.h index 311242a..11e409b 100644 --- a/Socket/TCPServer.h +++ b/Socket/TCPServer.h @@ -9,74 +9,54 @@ #ifndef INCLUDE_TCPSERVER_H_ #define INCLUDE_TCPSERVER_H_ -#include -#include // size_t -#include -#include // strerror, strlen, memcpy, strcpy -#include -#include -#include -#include #include #include #include "Socket.h" -#ifdef WINDOWS -#undef min -#undef max -#endif +class CTCPSSLServer; class CTCPServer : public ASocket { + friend class CTCPSSLServer; public: - explicit CTCPServer(const LogFnCallback oLogger, - /*const std::string& strAddr,*/ - const std::string& strPort, - const SettingsFlag eSettings = ALL_FLAGS) - /*throw (EResolveError)*/; - - ~CTCPServer() override; - - // copy constructor and assignment operator are disabled - CTCPServer(const CTCPServer&) = delete; - CTCPServer& operator=(const CTCPServer&) = delete; - - /* returns the socket of the accepted client, the waiting period can be set */ - bool Listen(Socket& ClientSocket, size_t msec = ACCEPT_WAIT_INF_DELAY); - - int Receive(const Socket ClientSocket, - char* pData, - const size_t uSize, - bool bReadFully = true) const; - - bool Send(const Socket ClientSocket, const char* pData, const size_t uSize) const; - bool Send(const Socket ClientSocket, const std::string& strData) const; - bool Send(const Socket ClientSocket, const std::vector& Data) const; - - bool Disconnect(const Socket ClientSocket) const; - - bool SetRcvTimeout(ASocket::Socket& ClientSocket, unsigned int msec_timeout); - bool SetSndTimeout(ASocket::Socket& ClientSocket, unsigned int msec_timeout); - -#ifndef WINDOWS - bool SetRcvTimeout(ASocket::Socket& ClientSocket, struct timeval Timeout); - bool SetSndTimeout(ASocket::Socket& ClientSocket, struct timeval Timeout); + explicit CTCPServer(const LogFnCallback oLogger, + /*const std::string& strAddr,*/ const std::string& strPort, + const SettingsFlag eSettings = ALL_FLAGS) /*throw (EResolveError)*/; + ~CTCPServer() override; + + // copy constructor and assignment operator are disabled + CTCPServer(const CTCPServer&) = delete; + CTCPServer& operator=(const CTCPServer&) = delete; + + /* returns the socket of the accepted client, the waiting period can be set */ + int Listen(Socket& ClientSocket, size_t msec = ACCEPT_WAIT_INF_DELAY); + void Disconnect(Socket& ClientSocket) const; + + int Receive(Socket ClientSocket, char* pData, size_t uSize, bool bReadFully = true) const; + int Send(Socket ClientSocket, const char* pData, size_t uSize) const; + int Send(Socket ClientSocket, const std::string& strData) const; + int Send(Socket ClientSocket, const std::vector& Data) const; + + bool SetRcvTimeout(Socket ClientSocket, unsigned int msec_timeout); + bool SetSndTimeout(Socket ClientSocket, unsigned int msec_timeout); + +#ifndef _WIN32 + bool SetRcvTimeout(Socket ClientSocket, struct timeval timeout); + bool SetSndTimeout(Socket ClientSocket, struct timeval timeout); #endif -protected: - Socket m_ListenSocket; +private: + bool InitAddrInfo(); - //std::string m_strHost; - std::string m_strPort; +protected: + Socket m_ListenSocket; - #ifdef WINDOWS - struct addrinfo* m_pResultAddrInfo; - struct addrinfo m_HintsAddrInfo; - #else - struct sockaddr_in m_ServAddr; - #endif + //std::string m_strHost; + std::string m_strPort; + struct addrinfo* m_pResultAddrInfo; + struct addrinfo m_HintsAddrInfo; }; #endif From 7abdb3f128793233cb3b05aef060b57a528ae77b Mon Sep 17 00:00:00 2001 From: cruise Date: Fri, 29 Aug 2025 15:06:15 +0800 Subject: [PATCH 2/6] corrected the decision algorithm for log switch --- Socket/Socket.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Socket/Socket.h b/Socket/Socket.h index 1869224..1ef8cad 100644 --- a/Socket/Socket.h +++ b/Socket/Socket.h @@ -158,7 +158,7 @@ class ASocket #define SocketLog(fmt, ...) \ do { \ - if (m_oLog && ((m_eSettingsFlags & ENABLE_LOG) == ENABLE_LOG)) \ + if (m_oLog && (m_eSettingsFlags & ENABLE_LOG)) \ { \ m_oLog(StringFormat(fmt, ##__VA_ARGS__)); \ } \ From f8b735272e33952877affadcdf41254574bb3c70 Mon Sep 17 00:00:00 2001 From: cruise Date: Sat, 11 Apr 2026 10:19:16 +0800 Subject: [PATCH 3/6] fixed a bug in the SocketClose function where the condition for checking the sd handle was reversed. --- Socket/Socket.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Socket/Socket.cpp b/Socket/Socket.cpp index 70e0454..c40f0c0 100644 --- a/Socket/Socket.cpp +++ b/Socket/Socket.cpp @@ -118,7 +118,7 @@ char* ASocket::GaiStrerror(int ecode) void ASocket::SocketClose(Socket& sd) { - if (sd == INVALID_SOCKET) + if (sd != INVALID_SOCKET) { #ifdef _WIN32 closesocket(sd); From 16b6e91d102ae58e4322216b665d4aede0b4b135 Mon Sep 17 00:00:00 2001 From: cruise Date: Sat, 11 Apr 2026 10:22:19 +0800 Subject: [PATCH 4/6] add the SOCKET_ERR_EPIPE macro. This macro indicates that when send < 0, and the error is EPIPE, we do not consider it an error, but leave it to recv to handle, so that recv can read the remaining unread data in the system buffer. --- Socket/Socket.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Socket/Socket.h b/Socket/Socket.h index 1ef8cad..67f544f 100644 --- a/Socket/Socket.h +++ b/Socket/Socket.h @@ -62,6 +62,9 @@ #define SOCKET_ERR_ADDR_INUSE(e) \ ((e) == EADDRINUSE) +#define SOCKET_ERR_EPIPE(e) \ + ((e) == EPIPE) + #else #define SOCKET_ERR_IS_EAGAIN(e) ((e) == WSAEWOULDBLOCK) @@ -83,6 +86,9 @@ #define SOCKET_ERR_ADDR_INUSE(e) \ ((e) == WSAEADDRINUSE) +#define SOCKET_ERR_EPIPE(e) \ + ((e) == WSAESHUTDOWN) + #endif class ASocket From adbfa7539785664b2e2e58b3d89704b4b05eba98 Mon Sep 17 00:00:00 2001 From: cruise Date: Sat, 11 Apr 2026 10:26:02 +0800 Subject: [PATCH 5/6] fixed a bug where the release order of Socket, SSL, and SSL_CTX was incorrect. --- Socket/SecureSocket.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Socket/SecureSocket.cpp b/Socket/SecureSocket.cpp index a380cbf..5fdd19d 100644 --- a/Socket/SecureSocket.cpp +++ b/Socket/SecureSocket.cpp @@ -60,8 +60,6 @@ ASecureSocket::SSLSocket::~SSLSocket() void ASecureSocket::SSLSocket::Disconnect() { - SocketClose(m_SockFd); - if (m_pSSL != nullptr) { /* send the close_notify alert to the peer. */ @@ -70,6 +68,8 @@ void ASecureSocket::SSLSocket::Disconnect() m_pSSL = nullptr; } + SocketClose(m_SockFd); + if (m_pCTXSSL != nullptr) { SSL_CTX_free(m_pCTXSSL); From 58b693f47488864a89ea4da394b7f5991dcd8d2a Mon Sep 17 00:00:00 2001 From: cruise Date: Sat, 11 Apr 2026 10:35:41 +0800 Subject: [PATCH 6/6] the specification addresses the handling of the Receive, Send, Connect, Listen, and Disconnect functions. --- Socket/TCPClient.cpp | 32 ++++++++++++++++---------- Socket/TCPSSLClient.cpp | 46 ++++++++++++++++++++++++++----------- Socket/TCPSSLServer.cpp | 51 ++++++++++++++++++++++++++++------------- Socket/TCPServer.cpp | 31 +++++++++++++++---------- 4 files changed, 107 insertions(+), 53 deletions(-) diff --git a/Socket/TCPClient.cpp b/Socket/TCPClient.cpp index 5d266e8..53a6ddc 100644 --- a/Socket/TCPClient.cpp +++ b/Socket/TCPClient.cpp @@ -180,14 +180,11 @@ int CTCPClient::Receive(char* pData, size_t uSize, bool bReadFully /*= true*/) c #endif int total = 0; - bool isOK = true; do { - isOK = true; int nRecvd = recv(m_ConnectSocket, pData + total, (int)uSize - total, 0); if (nRecvd == SOCKET_ERROR) { - isOK = false; int iErrCode = GetSocketError(); if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) { @@ -204,13 +201,23 @@ int CTCPClient::Receive(char* pData, size_t uSize, bool bReadFully /*= true*/) c } #endif #endif - SocketLog("[ERROR]TCPClient, recv failed[%d:%s][%d]", iErrCode, strerror(iErrCode), m_ConnectSocket); + if (total == 0) + { + SocketLog("[ERROR]TCPClient, recv failed[%d:%s][%d]", iErrCode, strerror(iErrCode), m_ConnectSocket); + return -1; + } + break; } if (nRecvd == 0) { - SocketLog("[INFO ]TCPClient, peer shut down[%d]", m_ConnectSocket); + if (total == 0) + { + SocketLog("[INFO ]TCPClient, peer shut down[%d]", m_ConnectSocket); + return 0; + } + break; } @@ -218,12 +225,7 @@ int CTCPClient::Receive(char* pData, size_t uSize, bool bReadFully /*= true*/) c } while (bReadFully && (total < (int)uSize)); - if (!isOK && total == 0) - { - return -1; - } - - return (int)total; + return total; } int CTCPClient::Send(const char* pData, size_t uSize) const @@ -252,6 +254,12 @@ int CTCPClient::Send(const char* pData, size_t uSize) const continue; } + if (SOCKET_ERR_EPIPE(iErrCode)) + { + SocketLog("[WARN ]TCPClient, send shutdowned SHUT_WR[%d:%s][%d]", iErrCode, strerror(iErrCode), m_ConnectSocket); + break; + } + SocketLog("[ERROR]TCPClient, send failed[%d:%s][%d]", iErrCode, strerror(iErrCode), m_ConnectSocket); return -1; } @@ -259,7 +267,7 @@ int CTCPClient::Send(const char* pData, size_t uSize) const total += nSent; } while (total < (int)uSize); - return (int)total; + return total; } int CTCPClient::Send(const std::string& strData) const diff --git a/Socket/TCPSSLClient.cpp b/Socket/TCPSSLClient.cpp index be78dc7..c4360d1 100644 --- a/Socket/TCPSSLClient.cpp +++ b/Socket/TCPSSLClient.cpp @@ -132,22 +132,24 @@ bool CTCPSSLClient::Connect(const std::string& strServer, const std::string& str { /* initiate the TLS/SSL handshake with an TLS/SSL server */ int iResult = SSL_connect(m_SSLConnectSocket.m_pSSL); - if (iResult == 1) + if (iResult <= 0) { - connectOK = true; - break; - } + int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, iResult); + if (iErrCode == SSL_ERROR_WANT_CONNECT || iErrCode == SSL_ERROR_WANT_WRITE || iErrCode == SSL_ERROR_WANT_READ) + { + continue; + } - int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, iResult); - if (iErrCode != SSL_ERROR_WANT_CONNECT) - { // under Windows it creates problems SocketLog("[ERROR]TCPSSLClient, SSL_connect failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); #ifndef _WIN32 - ERR_print_errors_fp(stdout); + //ERR_print_errors_fp(stdout); #endif break; } + + connectOK = true; + break; } while (1); if (!connectOK) @@ -203,12 +205,24 @@ int CTCPSSLClient::Receive(char* pData, size_t uSize, bool bReadFully /*= true*/ if (nRecvd <= 0) { int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, nRecvd); - if (iErrCode == SSL_ERROR_WANT_READ) + if (iErrCode == SSL_ERROR_WANT_READ || iErrCode == SSL_ERROR_WANT_WRITE) { continue; } - SocketLog("[ERROR]TCPSSLClient, SSL_read failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + if (total == 0) + { + m_iSSLErrCode = iErrCode; + if (iErrCode == SSL_ERROR_ZERO_RETURN) + { + SocketLog("[INFO ]TCPSSLClient, SSL_read failed(peer shut down)[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + return 0; + } + + SocketLog("[ERROR]TCPSSLClient, SSL_read failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + return -1; + } + break; } @@ -242,11 +256,17 @@ int CTCPSSLClient::Send(const char* pData, size_t uSize) const if (nSent <= 0) { int iErrCode = SSL_get_error(m_SSLConnectSocket.m_pSSL, nSent); - if (iErrCode == SSL_ERROR_WANT_WRITE) + if (iErrCode == SSL_ERROR_WANT_WRITE || iErrCode == SSL_ERROR_WANT_READ) { continue; } + if (iErrCode == SSL_ERROR_ZERO_RETURN) + { + //SocketLog("[WARN ]TCPSSLClient, SSL_write SSL_shutdowned[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); + break; + } + SocketLog("[ERROR]TCPSSLClient, SSL_write failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), m_SSLConnectSocket.m_SockFd); return -1; } @@ -254,7 +274,7 @@ int CTCPSSLClient::Send(const char* pData, size_t uSize) const total += nSent; } while (total < (int)uSize); - return (int)total; + return total; } int CTCPSSLClient::Send(const std::string& strData) const @@ -282,6 +302,6 @@ void CTCPSSLClient::Disconnect() // send close_notify message to notify peer of the SSL closure. m_SSLConnectSocket.Disconnect(); //ShutdownSSL(m_SSLConnectSocket); - m_TCPClient.Disconnect(); + //m_TCPClient.Disconnect(); } #endif diff --git a/Socket/TCPSSLServer.cpp b/Socket/TCPSSLServer.cpp index 9a3324a..992c931 100644 --- a/Socket/TCPSSLServer.cpp +++ b/Socket/TCPSSLServer.cpp @@ -18,7 +18,7 @@ CTCPSSLServer::CTCPSSLServer(const LogFnCallback oLogger, CTCPSSLServer::~CTCPSSLServer() { - SocketClose(m_TCPServer.m_ListenSocket); + //SocketClose(m_TCPServer.m_ListenSocket); } bool CTCPSSLServer::SetRcvTimeout(SSLSocket& ClientSocket, unsigned int msec_timeout) @@ -139,23 +139,25 @@ int CTCPSSLServer::Listen(SSLSocket& ClientSocket, size_t msec /*= ACCEPT_WAIT_I { /* wait for a TLS/SSL client to initiate a TLS/SSL handshake */ int iSSLErr = SSL_accept(ClientSocket.m_pSSL); - if (iSSLErr == 1) + if (iSSLErr <= 0) { - acceptOK = true; - break; - } + //Error occurred, log and close down ssl + int iErrCode = SSL_get_error(ClientSocket.m_pSSL, iSSLErr); + if (iErrCode == SSL_ERROR_WANT_ACCEPT || iErrCode == SSL_ERROR_WANT_READ || iErrCode == SSL_ERROR_WANT_WRITE) + { + continue; + } - //Error occurred, log and close down ssl - int iErrCode = SSL_get_error(ClientSocket.m_pSSL, iSSLErr); - if (iErrCode != SSL_ERROR_WANT_ACCEPT) - { SocketLog("[ERROR]TCPSSLServer, SSL_accept failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); // under Windows it creates problems #ifndef _WIN32 - ERR_print_errors_fp(stdout); + //ERR_print_errors_fp(stdout); #endif break; } + + acceptOK = true; + break; } while (1); if (!acceptOK) @@ -199,13 +201,24 @@ int CTCPSSLServer::Receive(const SSLSocket& ClientSocket, char* pData, size_t uS if (nRecvd <= 0) { int iErrCode = SSL_get_error(ClientSocket.m_pSSL, nRecvd); - if (iErrCode == SSL_ERROR_WANT_READ) + if (iErrCode == SSL_ERROR_WANT_READ || iErrCode == SSL_ERROR_WANT_WRITE) { continue; } - SocketLog("[ERROR]TCPSSLServer, SSL_read failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); - //ERR_print_errors_fp(stdout); + if (total == 0) + { + //ERR_print_errors_fp(stdout); + if (iErrCode == SSL_ERROR_ZERO_RETURN) + { + SocketLog("[INFO ]TCPSSLServer, SSL_read failed(peer shut down)[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); + return 0; + } + + SocketLog("[ERROR]TCPSSLServer, SSL_read failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); + return -1; + } + break; } @@ -241,11 +254,17 @@ int CTCPSSLServer::Send(const SSLSocket& ClientSocket, const char* pData, size_t if (nSent <= 0) { int iErrCode = SSL_get_error(ClientSocket.m_pSSL, nSent); - if (iErrCode == SSL_ERROR_WANT_WRITE) + if (iErrCode == SSL_ERROR_WANT_WRITE || iErrCode == SSL_ERROR_WANT_READ) { continue; } + if (iErrCode == SSL_ERROR_ZERO_RETURN) + { + SocketLog("[WARN ]TCPSSLServer, SSL_write SSL_shutdowned[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); + break; + } + SocketLog("[ERROR]TCPSSLServer, SSL_write failed[%d:%s][%d]", iErrCode, GetSSLErrorString(iErrCode), ClientSocket.m_SockFd); return -1; } @@ -253,7 +272,7 @@ int CTCPSSLServer::Send(const SSLSocket& ClientSocket, const char* pData, size_t total += nSent; } while (total < (int)uSize); - return (int)total; + return total; } int CTCPSSLServer::Send(const SSLSocket& ClientSocket, const std::string& strData) const @@ -281,6 +300,6 @@ void CTCPSSLServer::Disconnect(SSLSocket& ClientSocket) const // send close_notify message to notify peer of the SSL closure. ClientSocket.Disconnect(); //ShutdownSSL(ClientSocket); - m_TCPServer.Disconnect(ClientSocket.m_SockFd); + //m_TCPServer.Disconnect(ClientSocket.m_SockFd); } #endif diff --git a/Socket/TCPServer.cpp b/Socket/TCPServer.cpp index 0a0a8a7..0066096 100644 --- a/Socket/TCPServer.cpp +++ b/Socket/TCPServer.cpp @@ -267,14 +267,11 @@ int CTCPServer::Receive(Socket ClientSocket, char* pData, size_t uSize, bool bRe #endif int total = 0; - bool isOK = true; do { - isOK = true; int nRecvd = recv(ClientSocket, pData + total, (int)uSize - total, 0); if (nRecvd == SOCKET_ERROR) { - isOK = false; int iErrCode = GetSocketError(); if (SOCKET_ERR_RW_RETRIABLE(iErrCode)) { @@ -291,14 +288,23 @@ int CTCPServer::Receive(Socket ClientSocket, char* pData, size_t uSize, bool bRe } #endif #endif + if (total == 0) + { + SocketLog("[ERROR]TCPServer, recv failed[%d:%s][%d]", iErrCode, strerror(iErrCode), ClientSocket); + return -1; + } - SocketLog("[ERROR]TCPServer, recv failed[%d:%s][%d]", iErrCode, strerror(iErrCode), ClientSocket); break; } if (nRecvd == 0) { - SocketLog("[INFO ]TCPServer, peer shut down[%d]", ClientSocket); + if (total == 0) + { + SocketLog("[INFO ]TCPServer, peer shut down[%d]", ClientSocket); + return 0; + } + break; } @@ -306,12 +312,7 @@ int CTCPServer::Receive(Socket ClientSocket, char* pData, size_t uSize, bool bRe } while (bReadFully && (total < (int)uSize)); - if (!isOK && total == 0) - { - return -1; - } - - return (int)total; + return total; } int CTCPServer::Send(const Socket ClientSocket, const char* pData, size_t uSize) const @@ -340,6 +341,12 @@ int CTCPServer::Send(const Socket ClientSocket, const char* pData, size_t uSize) continue; } + if (SOCKET_ERR_EPIPE(iErrCode)) + { + //SocketLog("[WARN ]TCPServer, send shutdowned SHUT_WR[%d:%s][%d]", iErrCode, strerror(iErrCode), ClientSocket); + break; + } + SocketLog("[ERROR]TCPServer, send failed[%d:%s][%d]", iErrCode, strerror(iErrCode), ClientSocket); return -1; } @@ -347,7 +354,7 @@ int CTCPServer::Send(const Socket ClientSocket, const char* pData, size_t uSize) total += nSent; } while (total < (int)uSize); - return (int)total; + return total; } int CTCPServer::Send(const Socket ClientSocket, const std::string& strData) const