From d2fdc28c85977938ea8edb89ed05fe775ba967f9 Mon Sep 17 00:00:00 2001 From: Andrey Prygunkov Date: Sun, 30 Dec 2012 15:27:38 +0000 Subject: [PATCH] refactor: reworked Connection-class: fully encapsulted sockets; better read/write methods --- BinRpc.cpp | 278 ++++++++++++++++++++-------------------------- BinRpc.h | 12 +- Connection.cpp | 90 +++++++++------ Connection.h | 20 ++-- Frontend.cpp | 16 +-- RemoteClient.cpp | 133 +++++++--------------- RemoteServer.cpp | 64 ++++------- RemoteServer.h | 5 +- WebDownloader.cpp | 2 +- WebServer.cpp | 8 +- WebServer.h | 2 - XmlRpc.cpp | 1 - XmlRpc.h | 2 - 13 files changed, 272 insertions(+), 361 deletions(-) diff --git a/BinRpc.cpp b/BinRpc.cpp index df928ae6..aa0cbc36 100644 --- a/BinRpc.cpp +++ b/BinRpc.cpp @@ -89,29 +89,27 @@ const unsigned int g_iMessageRequestSizes[] = //***************************************************************** // BinProcessor +BinRpcProcessor::BinRpcProcessor() +{ + m_MessageBase.m_iSignature = (int)NZBMESSAGE_SIGNATURE; +} + void BinRpcProcessor::Execute() { // Read the first package which needs to be a request - int iBytesReceived = recv(m_iSocket, ((char*)&m_MessageBase) + sizeof(m_MessageBase.m_iSignature), sizeof(m_MessageBase) - sizeof(m_MessageBase.m_iSignature), 0); - if (iBytesReceived < 0) + if (!m_pConnection->Recv(((char*)&m_MessageBase) + sizeof(m_MessageBase.m_iSignature), sizeof(m_MessageBase) - sizeof(m_MessageBase.m_iSignature))) { - return; - } - - // Make sure this is a nzbget request from a client - if ((int)ntohl(m_MessageBase.m_iSignature) != (int)NZBMESSAGE_SIGNATURE) - { - warn("Non-nzbget request received on port %i from %s", g_pOptions->GetControlPort(), m_szClientIP); + warn("Non-nzbget request received on port %i from %s", g_pOptions->GetControlPort(), m_pConnection->GetRemoteAddr()); return; } if (strcmp(m_MessageBase.m_szPassword, g_pOptions->GetControlPassword())) { - warn("nzbget request received on port %i from %s, but password invalid", g_pOptions->GetControlPort(), m_szClientIP); + warn("nzbget request received on port %i from %s, but password invalid", g_pOptions->GetControlPort(), m_pConnection->GetRemoteAddr()); return; } - debug("%s request received from %s", g_szMessageRequestNames[ntohl(m_MessageBase.m_iType)], m_szClientIP); + debug("%s request received from %s", g_szMessageRequestNames[ntohl(m_MessageBase.m_iType)], m_pConnection->GetRemoteAddr()); Dispatch(); } @@ -202,7 +200,7 @@ void BinRpcProcessor::Dispatch() if (command) { - command->SetSocket(m_iSocket); + command->SetConnection(m_pConnection); command->SetMessageBase(&m_MessageBase); command->Execute(); delete command; @@ -224,8 +222,8 @@ void BinCommand::SendBoolResponse(bool bSuccess, const char* szText) BoolResponse.m_iTrailingDataLength = htonl(iTextLen); // Send the request answer - send(m_iSocket, (char*) &BoolResponse, sizeof(BoolResponse), 0); - send(m_iSocket, (char*)szText, iTextLen, 0); + m_pConnection->Send((char*) &BoolResponse, sizeof(BoolResponse)); + m_pConnection->Send((char*)szText, iTextLen); } bool BinCommand::ReceiveRequest(void* pBuffer, int iSize) @@ -234,8 +232,7 @@ bool BinCommand::ReceiveRequest(void* pBuffer, int iSize) iSize -= sizeof(SNZBRequestBase); if (iSize > 0) { - int iBytesReceived = recv(m_iSocket, ((char*)pBuffer) + sizeof(SNZBRequestBase), iSize, 0); - if (iBytesReceived != iSize) + if (!m_pConnection->Recv(((char*)pBuffer) + sizeof(SNZBRequestBase), iSize)) { error("invalid request"); return false; @@ -343,57 +340,44 @@ void DownloadBinCommand::Execute() } char* pRecvBuffer = (char*)malloc(ntohl(DownloadRequest.m_iTrailingDataLength) + 1); - char* pBufPtr = pRecvBuffer; - // Read from the socket until nothing remains - int iResult = 0; - int NeedBytes = ntohl(DownloadRequest.m_iTrailingDataLength); - while (NeedBytes > 0) + if (!m_pConnection->Recv(pRecvBuffer, ntohl(DownloadRequest.m_iTrailingDataLength))) { - iResult = recv(m_iSocket, pBufPtr, NeedBytes, 0); - // Did the recv succeed? - if (iResult <= 0) - { - error("invalid request"); - break; - } - pBufPtr += iResult; - NeedBytes -= iResult; + error("invalid request"); + free(pRecvBuffer); + return; } - - if (NeedBytes == 0) + + int iPriority = ntohl(DownloadRequest.m_iPriority); + bool bAddPaused = ntohl(DownloadRequest.m_bAddPaused); + + NZBFile* pNZBFile = NZBFile::CreateFromBuffer(DownloadRequest.m_szFilename, DownloadRequest.m_szCategory, pRecvBuffer, ntohl(DownloadRequest.m_iTrailingDataLength)); + + if (pNZBFile) { - int iPriority = ntohl(DownloadRequest.m_iPriority); - bool bAddPaused = ntohl(DownloadRequest.m_bAddPaused); - - NZBFile* pNZBFile = NZBFile::CreateFromBuffer(DownloadRequest.m_szFilename, DownloadRequest.m_szCategory, pRecvBuffer, ntohl(DownloadRequest.m_iTrailingDataLength)); - - if (pNZBFile) + info("Request: Queue collection %s", DownloadRequest.m_szFilename); + + for (NZBFile::FileInfos::iterator it = pNZBFile->GetFileInfos()->begin(); it != pNZBFile->GetFileInfos()->end(); it++) { - info("Request: Queue collection %s", DownloadRequest.m_szFilename); - - for (NZBFile::FileInfos::iterator it = pNZBFile->GetFileInfos()->begin(); it != pNZBFile->GetFileInfos()->end(); it++) - { - FileInfo* pFileInfo = *it; - pFileInfo->SetPriority(iPriority); - pFileInfo->SetPaused(bAddPaused); - } - - g_pQueueCoordinator->AddNZBFileToQueue(pNZBFile, ntohl(DownloadRequest.m_bAddFirst)); - delete pNZBFile; - - char tmp[1024]; - snprintf(tmp, 1024, "Collection %s added to queue", Util::BaseFileName(DownloadRequest.m_szFilename)); - tmp[1024-1] = '\0'; - SendBoolResponse(true, tmp); - } - else - { - char tmp[1024]; - snprintf(tmp, 1024, "Download Request failed for %s", Util::BaseFileName(DownloadRequest.m_szFilename)); - tmp[1024-1] = '\0'; - SendBoolResponse(false, tmp); + FileInfo* pFileInfo = *it; + pFileInfo->SetPriority(iPriority); + pFileInfo->SetPaused(bAddPaused); } + + g_pQueueCoordinator->AddNZBFileToQueue(pNZBFile, ntohl(DownloadRequest.m_bAddFirst)); + delete pNZBFile; + + char tmp[1024]; + snprintf(tmp, 1024, "Collection %s added to queue", Util::BaseFileName(DownloadRequest.m_szFilename)); + tmp[1024-1] = '\0'; + SendBoolResponse(true, tmp); + } + else + { + char tmp[1024]; + snprintf(tmp, 1024, "Download Request failed for %s", Util::BaseFileName(DownloadRequest.m_szFilename)); + tmp[1024-1] = '\0'; + SendBoolResponse(false, tmp); } free(pRecvBuffer); @@ -633,12 +617,12 @@ void ListBinCommand::Execute() } // Send the request answer - send(m_iSocket, (char*) &ListResponse, sizeof(ListResponse), 0); + m_pConnection->Send((char*) &ListResponse, sizeof(ListResponse)); // Send the data if (bufsize > 0) { - send(m_iSocket, buf, bufsize, 0); + m_pConnection->Send(buf, bufsize); } if (buf) @@ -724,12 +708,12 @@ void LogBinCommand::Execute() LogResponse.m_iTrailingDataLength = htonl(bufsize); // Send the request answer - send(m_iSocket, (char*) &LogResponse, sizeof(LogResponse), 0); + m_pConnection->Send((char*) &LogResponse, sizeof(LogResponse)); // Send the data if (bufsize > 0) { - send(m_iSocket, buf, bufsize, 0); + m_pConnection->Send(buf, bufsize); } free(buf); @@ -760,24 +744,13 @@ void EditQueueBinCommand::Execute() } char* pBuf = (char*)malloc(iBufLength); - - // Read from the socket until nothing remains - char* pBufPtr = pBuf; - int NeedBytes = iBufLength; - int iResult = 0; - while (NeedBytes > 0) + + if (!m_pConnection->Recv(pBuf, iBufLength)) { - iResult = recv(m_iSocket, pBufPtr, NeedBytes, 0); - // Did the recv succeed? - if (iResult <= 0) - { - error("invalid request"); - break; - } - pBufPtr += iResult; - NeedBytes -= iResult; + error("invalid request"); + free(pBuf); + return; } - bool bOK = NeedBytes == 0; if (iNrIDEntries <= 0 && iNrNameEntries <= 0) { @@ -785,45 +758,44 @@ void EditQueueBinCommand::Execute() return; } - if (bOK) + char* szText = iTextLen > 0 ? pBuf : NULL; + int32_t* pIDs = (int32_t*)(pBuf + iTextLen); + char* pNames = (pBuf + iTextLen + iNrIDEntries * sizeof(int32_t)); + + IDList cIDList; + NameList cNameList; + + if (iNrIDEntries > 0) { - char* szText = iTextLen > 0 ? pBuf : NULL; - int32_t* pIDs = (int32_t*)(pBuf + iTextLen); - char* pNames = (pBuf + iTextLen + iNrIDEntries * sizeof(int32_t)); - - IDList cIDList; - NameList cNameList; - - if (iNrIDEntries > 0) + cIDList.reserve(iNrIDEntries); + for (int i = 0; i < iNrIDEntries; i++) { - cIDList.reserve(iNrIDEntries); - for (int i = 0; i < iNrIDEntries; i++) - { - cIDList.push_back(ntohl(pIDs[i])); - } + cIDList.push_back(ntohl(pIDs[i])); } + } - if (iNrNameEntries > 0) + if (iNrNameEntries > 0) + { + cNameList.reserve(iNrNameEntries); + for (int i = 0; i < iNrNameEntries; i++) { - cNameList.reserve(iNrNameEntries); - for (int i = 0; i < iNrNameEntries; i++) - { - cNameList.push_back(pNames); - pNames += strlen(pNames) + 1; - } + cNameList.push_back(pNames); + pNames += strlen(pNames) + 1; } + } - if (iAction < eRemoteEditActionPostMoveOffset) - { - bOK = g_pQueueCoordinator->GetQueueEditor()->EditList( - iNrIDEntries > 0 ? &cIDList : NULL, - iNrNameEntries > 0 ? &cNameList : NULL, - (QueueEditor::EMatchMode)iMatchMode, bSmartOrder, (QueueEditor::EEditAction)iAction, iOffset, szText); - } - else - { - bOK = g_pPrePostProcessor->QueueEditList(&cIDList, (PrePostProcessor::EEditAction)iAction, iOffset); - } + bool bOK = false; + + if (iAction < eRemoteEditActionPostMoveOffset) + { + bOK = g_pQueueCoordinator->GetQueueEditor()->EditList( + iNrIDEntries > 0 ? &cIDList : NULL, + iNrNameEntries > 0 ? &cNameList : NULL, + (QueueEditor::EMatchMode)iMatchMode, bSmartOrder, (QueueEditor::EEditAction)iAction, iOffset, szText); + } + else + { + bOK = g_pPrePostProcessor->QueueEditList(&cIDList, (PrePostProcessor::EEditAction)iAction, iOffset); } free(pBuf); @@ -926,12 +898,12 @@ void PostQueueBinCommand::Execute() PostQueueResponse.m_iTrailingDataLength = htonl(bufsize); // Send the request answer - send(m_iSocket, (char*) &PostQueueResponse, sizeof(PostQueueResponse), 0); + m_pConnection->Send((char*) &PostQueueResponse, sizeof(PostQueueResponse)); // Send the data if (bufsize > 0) { - send(m_iSocket, buf, bufsize, 0); + m_pConnection->Send(buf, bufsize); } free(buf); @@ -946,50 +918,36 @@ void WriteLogBinCommand::Execute() } char* pRecvBuffer = (char*)malloc(ntohl(WriteLogRequest.m_iTrailingDataLength) + 1); - char* pBufPtr = pRecvBuffer; - - // Read from the socket until nothing remains - int iResult = 0; - int NeedBytes = ntohl(WriteLogRequest.m_iTrailingDataLength); - pRecvBuffer[NeedBytes] = '\0'; - while (NeedBytes > 0) + + if (!m_pConnection->Recv(pRecvBuffer, ntohl(WriteLogRequest.m_iTrailingDataLength))) { - iResult = recv(m_iSocket, pBufPtr, NeedBytes, 0); - // Did the recv succeed? - if (iResult <= 0) - { - error("invalid request"); + error("invalid request"); + free(pRecvBuffer); + return; + } + + bool OK = true; + switch ((Message::EKind)ntohl(WriteLogRequest.m_iKind)) + { + case Message::mkDetail: + detail(pRecvBuffer); break; - } - pBufPtr += iResult; - NeedBytes -= iResult; - } - - if (NeedBytes == 0) - { - bool OK = true; - switch ((Message::EKind)ntohl(WriteLogRequest.m_iKind)) - { - case Message::mkDetail: - detail(pRecvBuffer); - break; - case Message::mkInfo: - info(pRecvBuffer); - break; - case Message::mkWarning: - warn(pRecvBuffer); - break; - case Message::mkError: - error(pRecvBuffer); - break; - case Message::mkDebug: - debug(pRecvBuffer); - break; - default: - OK = false; - } - SendBoolResponse(OK, OK ? "Message added to log" : "Invalid message-kind"); + case Message::mkInfo: + info(pRecvBuffer); + break; + case Message::mkWarning: + warn(pRecvBuffer); + break; + case Message::mkError: + error(pRecvBuffer); + break; + case Message::mkDebug: + debug(pRecvBuffer); + break; + default: + OK = false; } + SendBoolResponse(OK, OK ? "Message added to log" : "Invalid message-kind"); free(pRecvBuffer); } @@ -1092,12 +1050,12 @@ void HistoryBinCommand::Execute() HistoryResponse.m_iTrailingDataLength = htonl(bufsize); // Send the request answer - send(m_iSocket, (char*) &HistoryResponse, sizeof(HistoryResponse), 0); + m_pConnection->Send((char*) &HistoryResponse, sizeof(HistoryResponse)); // Send the data if (bufsize > 0) { - send(m_iSocket, buf, bufsize, 0); + m_pConnection->Send(buf, bufsize); } free(buf); @@ -1202,12 +1160,12 @@ void UrlQueueBinCommand::Execute() UrlQueueResponse.m_iTrailingDataLength = htonl(bufsize); // Send the request answer - send(m_iSocket, (char*) &UrlQueueResponse, sizeof(UrlQueueResponse), 0); + m_pConnection->Send((char*) &UrlQueueResponse, sizeof(UrlQueueResponse)); // Send the data if (bufsize > 0) { - send(m_iSocket, buf, bufsize, 0); + m_pConnection->Send(buf, bufsize); } free(buf); diff --git a/BinRpc.h b/BinRpc.h index 1f0fd51f..73d2673e 100644 --- a/BinRpc.h +++ b/BinRpc.h @@ -33,23 +33,21 @@ class BinRpcProcessor { private: - SOCKET m_iSocket; SNZBRequestBase m_MessageBase; - const char* m_szClientIP; + Connection* m_pConnection; void Dispatch(); public: + BinRpcProcessor(); void Execute(); - void SetSocket(SOCKET iSocket) { m_iSocket = iSocket; } - void SetSignature(int iSignature) { m_MessageBase.m_iSignature = iSignature; } - void SetClientIP(const char* szClientIP) { m_szClientIP = szClientIP; } + void SetConnection(Connection* pConnection) { m_pConnection = pConnection; } }; class BinCommand { protected: - SOCKET m_iSocket; + Connection* m_pConnection; SNZBRequestBase* m_pMessageBase; bool ReceiveRequest(void* pBuffer, int iSize); @@ -58,7 +56,7 @@ protected: public: virtual ~BinCommand() {} virtual void Execute() = 0; - void SetSocket(SOCKET iSocket) { m_iSocket = iSocket; } + void SetConnection(Connection* pConnection) { m_pConnection = pConnection; } void SetMessageBase(SNZBRequestBase* pMessageBase) { m_pMessageBase = pMessageBase; } }; diff --git a/Connection.cpp b/Connection.cpp index c6994972..02879fef 100644 --- a/Connection.cpp +++ b/Connection.cpp @@ -144,7 +144,6 @@ Connection::Connection(const char* szHost, int iPort, bool bTLS) m_iTimeout = 60; m_bSuppressErrors = true; m_szReadBuf = (char*)malloc(CONNECTION_READBUFFER_SIZE + 1); - m_bAutoClose = true; #ifndef DISABLE_TLS m_pTLS = NULL; m_bTLSError = false; @@ -156,22 +155,22 @@ Connection::Connection(const char* szHost, int iPort, bool bTLS) } } -Connection::Connection(SOCKET iSocket, bool bAutoClose) +Connection::Connection(SOCKET iSocket, bool bTLS) { debug("Creating Connection"); m_szHost = NULL; m_iPort = 0; - m_bTLS = false; + m_bTLS = bTLS; m_eStatus = csConnected; m_iSocket = iSocket; m_iBufAvail = 0; m_iTimeout = 60; m_bSuppressErrors = true; m_szReadBuf = (char*)malloc(CONNECTION_READBUFFER_SIZE + 1); - m_bAutoClose = bAutoClose; #ifndef DISABLE_TLS m_pTLS = NULL; + m_bTLSError = false; #endif } @@ -179,14 +178,12 @@ Connection::~Connection() { debug("Destroying Connection"); + Disconnect(); + if (m_szHost) { free(m_szHost); } - if (m_bAutoClose) - { - Disconnect(); - } free(m_szReadBuf); #ifndef DISABLE_TLS if (m_pTLS) @@ -270,18 +267,27 @@ int Connection::WriteLine(const char* pBuffer) return iRes; } -int Connection::Send(const char* pBuffer, int iSize) +bool Connection::Send(const char* pBuffer, int iSize) { debug("Sending data"); if (m_eStatus != csConnected) { - return -1; + return false; } - int iRes = send(m_iSocket, pBuffer, iSize, 0); + int iBytesSent = 0; + while (iBytesSent < iSize) + { + int iRes = send(m_iSocket, pBuffer + iBytesSent, iSize-iBytesSent, 0); + if (iRes <= 0) + { + return false; + } + iBytesSent += iRes; + } - return iRes; + return true; } char* Connection::ReadLine(char* pBuffer, int iSize, int* pBytesRead) @@ -296,21 +302,27 @@ char* Connection::ReadLine(char* pBuffer, int iSize, int* pBytesRead) return res; } -SOCKET Connection::Accept() +Connection* Connection::Accept() { debug("Accepting connection"); if (m_eStatus != csListening) { - return INVALID_SOCKET; + return NULL; } SOCKET iRes = DoAccept(); + if (iRes == INVALID_SOCKET) + { + return NULL; + } + + Connection* pCon = new Connection(iRes, m_bTLS); - return iRes; + return pCon; } -int Connection::Recv(char* pBuffer, int iSize) +int Connection::TryRecv(char* pBuffer, int iSize) { debug("Receiving data"); @@ -326,7 +338,7 @@ int Connection::Recv(char* pBuffer, int iSize) return iReceived; } -bool Connection::RecvAll(char * pBuffer, int iSize) +bool Connection::Recv(char * pBuffer, int iSize) { debug("Receiving data (full buffer)"); @@ -470,10 +482,7 @@ bool Connection::DoDisconnect() closesocket(m_iSocket); m_iSocket = INVALID_SOCKET; #ifndef DISABLE_TLS - if (m_pTLS) - { - CloseTLS(); - } + CloseTLS(); #endif } @@ -666,7 +675,7 @@ int Connection::DoBind() SOCKET Connection::DoAccept() { - SOCKET iSocket = accept(GetSocket(), NULL, NULL); + SOCKET iSocket = accept(m_iSocket, NULL, NULL); if (iSocket == INVALID_SOCKET && m_eStatus != csCancelled) { @@ -780,13 +789,12 @@ bool Connection::StartTLS() m_pTLS = malloc(sizeof(tls_t)); tls_t* pTLS = (tls_t*)m_pTLS; memset(pTLS, 0, sizeof(tls_t)); - tls_clear(pTLS); char* szErrStr; int iRes; - + iRes = tls_init(pTLS, NULL, NULL, NULL, 0, &szErrStr); - if (!CheckTLSResult(iRes, szErrStr, "Could not initialize TLS-object: %s")) + if (!CheckTLSResult(iRes, szErrStr, "Could not initialize secure connection: %s")) { return false; } @@ -804,9 +812,12 @@ bool Connection::StartTLS() void Connection::CloseTLS() { - tls_close((tls_t*)m_pTLS); - free(m_pTLS); - m_pTLS = NULL; + if (m_pTLS) + { + tls_close((tls_t*)m_pTLS); + free(m_pTLS); + m_pTLS = NULL; + } } int Connection::recv(SOCKET s, char* buf, int len, int flags) @@ -818,7 +829,7 @@ int Connection::recv(SOCKET s, char* buf, int len, int flags) m_bTLSError = false; char* szErrStr; int iRes = tls_getbuf((tls_t*)m_pTLS, buf, len, &iReceived, &szErrStr); - if (!CheckTLSResult(iRes, szErrStr, "Could not read from TLS-socket: %s")) + if (!CheckTLSResult(iRes, szErrStr, "TLS-error: %s")) { m_bTLSError = true; return -1; @@ -838,12 +849,12 @@ int Connection::send(SOCKET s, const char* buf, int len, int flags) m_bTLSError = false; char* szErrStr; int iRes = tls_putbuf((tls_t*)m_pTLS, buf, len, &szErrStr); - if (!CheckTLSResult(iRes, szErrStr, "Could not send to TLS-socket: %s")) + if (!CheckTLSResult(iRes, szErrStr, "TLS-error: %s")) { m_bTLSError = true; return -1; } - return 0; + return len; } else { @@ -902,3 +913,20 @@ unsigned int Connection::ResolveHostAddr(const char* szHost) return uaddr; } #endif + +const char* Connection::GetRemoteAddr() +{ + struct sockaddr_in PeerName; + int iPeerNameLength = sizeof(PeerName); + if (getpeername(m_iSocket, (struct sockaddr*)&PeerName, (SOCKLEN_T*) &iPeerNameLength) >= 0) + { +#ifdef WIN32 + strncpy(m_szRemoteAddr, sizeof(m_szRemoteAddr), inet_ntoa(PeerName.sin_addr)); +#else + inet_ntop(AF_INET, &PeerName.sin_addr, m_szRemoteAddr, sizeof(m_szRemoteAddr)); +#endif + } + m_szRemoteAddr[sizeof(m_szRemoteAddr)-1] = '\0'; + + return m_szRemoteAddr; +} diff --git a/Connection.h b/Connection.h index 3823b3ef..d647a097 100644 --- a/Connection.h +++ b/Connection.h @@ -43,7 +43,9 @@ public: csListening, csCancelled }; - + + typedef void TLS; + protected: char* m_szHost; int m_iPort; @@ -55,9 +57,9 @@ protected: EStatus m_eStatus; int m_iTimeout; bool m_bSuppressErrors; - bool m_bAutoClose; + char m_szRemoteAddr[20]; #ifndef DISABLE_TLS - void* m_pTLS; + TLS* m_pTLS; static bool bTLSLibInitialized; bool m_bTLSError; #endif @@ -67,6 +69,7 @@ protected: #endif #endif + Connection(SOCKET iSocket, bool bTLS); void ReportError(const char* szMsgPrefix, const char* szMsgArg, bool PrintErrCode, int herrno); virtual bool DoConnect(); virtual bool DoDisconnect(); @@ -86,29 +89,28 @@ protected: public: Connection(const char* szHost, int iPort, bool bTLS); - Connection(SOCKET iSocket, bool bAutoClose); virtual ~Connection(); static void Init(); static void Final(); bool Connect(); bool Disconnect(); int Bind(); - int Send(const char* pBuffer, int iSize); - int Recv(char* pBuffer, int iSize); - bool RecvAll(char* pBuffer, int iSize); + bool Send(const char* pBuffer, int iSize); + bool Recv(char* pBuffer, int iSize); + int TryRecv(char* pBuffer, int iSize); char* ReadLine(char* pBuffer, int iSize, int* pBytesRead); void ReadBuffer(char** pBuffer, int *iBufLen); int WriteLine(const char* pBuffer); - SOCKET Accept(); + Connection* Accept(); void Cancel(); const char* GetHost() { return m_szHost; } int GetPort() { return m_iPort; } bool GetTLS() { return m_bTLS; } - SOCKET GetSocket() { return m_iSocket; } void SetTimeout(int iTimeout) { m_iTimeout = iTimeout; } EStatus GetStatus() { return m_eStatus; } void SetSuppressErrors(bool bSuppressErrors) { m_bSuppressErrors = bSuppressErrors; } bool GetSuppressErrors() { return m_bSuppressErrors; } + const char* GetRemoteAddr(); #ifndef DISABLE_TLS bool StartTLS(); #endif diff --git a/Frontend.cpp b/Frontend.cpp index 0d1bf40b..cba79805 100644 --- a/Frontend.cpp +++ b/Frontend.cpp @@ -261,15 +261,15 @@ bool Frontend::RequestMessages() LogRequest.m_iIDFrom = 0; } - if (connection.Send((char*)(&LogRequest), sizeof(LogRequest)) < 0) + if (!connection.Send((char*)(&LogRequest), sizeof(LogRequest))) { return false; } // Now listen for the returned log SNZBLogResponse LogResponse; - int iResponseLen = connection.Recv((char*) &LogResponse, sizeof(LogResponse)); - if (iResponseLen != sizeof(LogResponse) || + bool bRead = connection.Recv((char*) &LogResponse, sizeof(LogResponse)); + if (!bRead || (int)ntohl(LogResponse.m_MessageBase.m_iSignature) != (int)NZBMESSAGE_SIGNATURE || ntohl(LogResponse.m_MessageBase.m_iStructSize) != sizeof(LogResponse)) { @@ -280,7 +280,7 @@ bool Frontend::RequestMessages() if (ntohl(LogResponse.m_iTrailingDataLength) > 0) { pBuf = (char*)malloc(ntohl(LogResponse.m_iTrailingDataLength)); - if (!connection.RecvAll(pBuf, ntohl(LogResponse.m_iTrailingDataLength))) + if (!connection.Recv(pBuf, ntohl(LogResponse.m_iTrailingDataLength))) { free(pBuf); return false; @@ -325,15 +325,15 @@ bool Frontend::RequestFileList() ListRequest.m_bFileList = htonl(m_bFileList); ListRequest.m_bServerState = htonl(m_bSummary); - if (connection.Send((char*)(&ListRequest), sizeof(ListRequest)) < 0) + if (!connection.Send((char*)(&ListRequest), sizeof(ListRequest))) { return false; } // Now listen for the returned list SNZBListResponse ListResponse; - int iResponseLen = connection.Recv((char*) &ListResponse, sizeof(ListResponse)); - if (iResponseLen != sizeof(ListResponse) || + bool bRead = connection.Recv((char*) &ListResponse, sizeof(ListResponse)); + if (!bRead || (int)ntohl(ListResponse.m_MessageBase.m_iSignature) != (int)NZBMESSAGE_SIGNATURE || ntohl(ListResponse.m_MessageBase.m_iStructSize) != sizeof(ListResponse)) { @@ -344,7 +344,7 @@ bool Frontend::RequestFileList() if (ntohl(ListResponse.m_iTrailingDataLength) > 0) { pBuf = (char*)malloc(ntohl(ListResponse.m_iTrailingDataLength)); - if (!connection.RecvAll(pBuf, ntohl(ListResponse.m_iTrailingDataLength))) + if (!connection.Recv(pBuf, ntohl(ListResponse.m_iTrailingDataLength))) { free(pBuf); return false; diff --git a/RemoteClient.cpp b/RemoteClient.cpp index a71eac0b..0c9fa0c3 100644 --- a/RemoteClient.cpp +++ b/RemoteClient.cpp @@ -137,35 +137,21 @@ bool RemoteClient::ReceiveBoolResponse() SNZBDownloadResponse BoolResponse; memset(&BoolResponse, 0, sizeof(BoolResponse)); - int iResponseLen = m_pConnection->Recv((char*)&BoolResponse, sizeof(BoolResponse)); - if (iResponseLen != sizeof(BoolResponse) || + bool bRead = m_pConnection->Recv((char*)&BoolResponse, sizeof(BoolResponse)); + if (!bRead || (int)ntohl(BoolResponse.m_MessageBase.m_iSignature) != (int)NZBMESSAGE_SIGNATURE || ntohl(BoolResponse.m_MessageBase.m_iStructSize) != sizeof(BoolResponse)) { - if (iResponseLen < 0) - { - printf("No response received (timeout)\n"); - } - else - { - printf("Invalid response received: either not nzbget-server or wrong server version\n"); - } + printf("No response or invalid response (timeout, not nzbget-server or wrong nzbget-server version)\n"); return false; } int iTextLen = ntohl(BoolResponse.m_iTrailingDataLength); char* buf = (char*)malloc(iTextLen); - iResponseLen = m_pConnection->Recv(buf, iTextLen); - if (iResponseLen != iTextLen) + bRead = m_pConnection->Recv(buf, iTextLen); + if (!bRead) { - if (iResponseLen < 0) - { - printf("No response received (timeout)\n"); - } - else - { - printf("Invalid response received: either not nzbget-server or wrong server version\n"); - } + printf("No response or invalid response (timeout, not nzbget-server or wrong nzbget-server version)\n"); free(buf); return false; } @@ -208,7 +194,7 @@ bool RemoteClient::RequestServerDownload(const char* szFilename, const char* szC } DownloadRequest.m_szCategory[NZBREQUESTFILENAMESIZE-1] = '\0'; - if (m_pConnection->Send((char*)(&DownloadRequest), sizeof(DownloadRequest)) < 0) + if (!m_pConnection->Send((char*)(&DownloadRequest), sizeof(DownloadRequest))) { perror("m_pConnection->Send"); OK = false; @@ -332,7 +318,7 @@ bool RemoteClient::RequestServerList(bool bFiles, bool bGroups, const char* szPa ListRequest.m_szPattern[NZBREQUESTFILENAMESIZE-1] = '\0'; } - if (m_pConnection->Send((char*)(&ListRequest), sizeof(ListRequest)) < 0) + if (!m_pConnection->Send((char*)(&ListRequest), sizeof(ListRequest))) { perror("m_pConnection->Send"); return false; @@ -342,19 +328,12 @@ bool RemoteClient::RequestServerList(bool bFiles, bool bGroups, const char* szPa // Now listen for the returned list SNZBListResponse ListResponse; - int iResponseLen = m_pConnection->Recv((char*) &ListResponse, sizeof(ListResponse)); - if (iResponseLen != sizeof(ListResponse) || + bool bRead = m_pConnection->Recv((char*) &ListResponse, sizeof(ListResponse)); + if (!bRead || (int)ntohl(ListResponse.m_MessageBase.m_iSignature) != (int)NZBMESSAGE_SIGNATURE || ntohl(ListResponse.m_MessageBase.m_iStructSize) != sizeof(ListResponse)) { - if (iResponseLen < 0) - { - printf("No response received (timeout)\n"); - } - else - { - printf("Invalid response received: either not nzbget-server or wrong server version\n"); - } + printf("No response or invalid response (timeout, not nzbget-server or wrong nzbget-server version)\n"); return false; } @@ -362,7 +341,7 @@ bool RemoteClient::RequestServerList(bool bFiles, bool bGroups, const char* szPa if (ntohl(ListResponse.m_iTrailingDataLength) > 0) { pBuf = (char*)malloc(ntohl(ListResponse.m_iTrailingDataLength)); - if (!m_pConnection->RecvAll(pBuf, ntohl(ListResponse.m_iTrailingDataLength))) + if (!m_pConnection->Recv(pBuf, ntohl(ListResponse.m_iTrailingDataLength))) { free(pBuf); return false; @@ -700,7 +679,7 @@ bool RemoteClient::RequestServerLog(int iLines) LogRequest.m_iLines = htonl(iLines); LogRequest.m_iIDFrom = 0; - if (m_pConnection->Send((char*)(&LogRequest), sizeof(LogRequest)) < 0) + if (!m_pConnection->Send((char*)(&LogRequest), sizeof(LogRequest))) { perror("m_pConnection->Send"); return false; @@ -710,19 +689,12 @@ bool RemoteClient::RequestServerLog(int iLines) // Now listen for the returned log SNZBLogResponse LogResponse; - int iResponseLen = m_pConnection->Recv((char*) &LogResponse, sizeof(LogResponse)); - if (iResponseLen != sizeof(LogResponse) || + bool bRead = m_pConnection->Recv((char*) &LogResponse, sizeof(LogResponse)); + if (!bRead || (int)ntohl(LogResponse.m_MessageBase.m_iSignature) != (int)NZBMESSAGE_SIGNATURE || ntohl(LogResponse.m_MessageBase.m_iStructSize) != sizeof(LogResponse)) { - if (iResponseLen < 0) - { - printf("No response received (timeout)\n"); - } - else - { - printf("Invalid response received: either not nzbget-server or wrong server version\n"); - } + printf("No response or invalid response (timeout, not nzbget-server or wrong nzbget-server version)\n"); return false; } @@ -730,7 +702,7 @@ bool RemoteClient::RequestServerLog(int iLines) if (ntohl(LogResponse.m_iTrailingDataLength) > 0) { pBuf = (char*)malloc(ntohl(LogResponse.m_iTrailingDataLength)); - if (!m_pConnection->RecvAll(pBuf, ntohl(LogResponse.m_iTrailingDataLength))) + if (!m_pConnection->Recv(pBuf, ntohl(LogResponse.m_iTrailingDataLength))) { free(pBuf); return false; @@ -793,7 +765,7 @@ bool RemoteClient::RequestServerPauseUnpause(bool bPause, eRemotePauseUnpauseAct PauseUnpauseRequest.m_bPause = htonl(bPause); PauseUnpauseRequest.m_iAction = htonl(iAction); - if (m_pConnection->Send((char*)(&PauseUnpauseRequest), sizeof(PauseUnpauseRequest)) < 0) + if (!m_pConnection->Send((char*)(&PauseUnpauseRequest), sizeof(PauseUnpauseRequest))) { perror("m_pConnection->Send"); m_pConnection->Disconnect(); @@ -814,7 +786,7 @@ bool RemoteClient::RequestServerSetDownloadRate(float fRate) InitMessageBase(&SetDownloadRateRequest.m_MessageBase, eRemoteRequestSetDownloadRate, sizeof(SetDownloadRateRequest)); SetDownloadRateRequest.m_iDownloadRate = htonl((unsigned int)(fRate * 1024)); - if (m_pConnection->Send((char*)(&SetDownloadRateRequest), sizeof(SetDownloadRateRequest)) < 0) + if (!m_pConnection->Send((char*)(&SetDownloadRateRequest), sizeof(SetDownloadRateRequest))) { perror("m_pConnection->Send"); m_pConnection->Disconnect(); @@ -834,7 +806,7 @@ bool RemoteClient::RequestServerDumpDebug() SNZBDumpDebugRequest DumpDebugInfo; InitMessageBase(&DumpDebugInfo.m_MessageBase, eRemoteRequestDumpDebug, sizeof(DumpDebugInfo)); - if (m_pConnection->Send((char*)(&DumpDebugInfo), sizeof(DumpDebugInfo)) < 0) + if (!m_pConnection->Send((char*)(&DumpDebugInfo), sizeof(DumpDebugInfo))) { perror("m_pConnection->Send"); m_pConnection->Disconnect(); @@ -919,7 +891,7 @@ bool RemoteClient::RequestServerEditQueue(eRemoteEditAction iAction, int iOffset } bool OK = false; - if (m_pConnection->Send((char*)(&EditQueueRequest), sizeof(EditQueueRequest)) < 0) + if (!m_pConnection->Send((char*)(&EditQueueRequest), sizeof(EditQueueRequest))) { perror("m_pConnection->Send"); } @@ -943,7 +915,7 @@ bool RemoteClient::RequestServerShutdown() SNZBShutdownRequest ShutdownRequest; InitMessageBase(&ShutdownRequest.m_MessageBase, eRemoteRequestShutdown, sizeof(ShutdownRequest)); - bool OK = m_pConnection->Send((char*)(&ShutdownRequest), sizeof(ShutdownRequest)) >= 0; + bool OK = m_pConnection->Send((char*)(&ShutdownRequest), sizeof(ShutdownRequest)); if (OK) { OK = ReceiveBoolResponse(); @@ -964,7 +936,7 @@ bool RemoteClient::RequestServerReload() SNZBReloadRequest ReloadRequest; InitMessageBase(&ReloadRequest.m_MessageBase, eRemoteRequestReload, sizeof(ReloadRequest)); - bool OK = m_pConnection->Send((char*)(&ReloadRequest), sizeof(ReloadRequest)) >= 0; + bool OK = m_pConnection->Send((char*)(&ReloadRequest), sizeof(ReloadRequest)); if (OK) { OK = ReceiveBoolResponse(); @@ -985,7 +957,7 @@ bool RemoteClient::RequestServerVersion() SNZBVersionRequest VersionRequest; InitMessageBase(&VersionRequest.m_MessageBase, eRemoteRequestVersion, sizeof(VersionRequest)); - bool OK = m_pConnection->Send((char*)(&VersionRequest), sizeof(VersionRequest)) >= 0; + bool OK = m_pConnection->Send((char*)(&VersionRequest), sizeof(VersionRequest)); if (OK) { OK = ReceiveBoolResponse(); @@ -1006,7 +978,7 @@ bool RemoteClient::RequestPostQueue() SNZBPostQueueRequest PostQueueRequest; InitMessageBase(&PostQueueRequest.m_MessageBase, eRemoteRequestPostQueue, sizeof(PostQueueRequest)); - if (m_pConnection->Send((char*)(&PostQueueRequest), sizeof(PostQueueRequest)) < 0) + if (!m_pConnection->Send((char*)(&PostQueueRequest), sizeof(PostQueueRequest))) { perror("m_pConnection->Send"); return false; @@ -1016,19 +988,12 @@ bool RemoteClient::RequestPostQueue() // Now listen for the returned list SNZBPostQueueResponse PostQueueResponse; - int iResponseLen = m_pConnection->Recv((char*) &PostQueueResponse, sizeof(PostQueueResponse)); - if (iResponseLen != sizeof(PostQueueResponse) || + bool bRead = m_pConnection->Recv((char*) &PostQueueResponse, sizeof(PostQueueResponse)); + if (!bRead || (int)ntohl(PostQueueResponse.m_MessageBase.m_iSignature) != (int)NZBMESSAGE_SIGNATURE || ntohl(PostQueueResponse.m_MessageBase.m_iStructSize) != sizeof(PostQueueResponse)) { - if (iResponseLen < 0) - { - printf("No response received (timeout)\n"); - } - else - { - printf("Invalid response received: either not nzbget-server or wrong server version\n"); - } + printf("No response or invalid response (timeout, not nzbget-server or wrong nzbget-server version)\n"); return false; } @@ -1036,7 +1001,7 @@ bool RemoteClient::RequestPostQueue() if (ntohl(PostQueueResponse.m_iTrailingDataLength) > 0) { pBuf = (char*)malloc(ntohl(PostQueueResponse.m_iTrailingDataLength)); - if (!m_pConnection->RecvAll(pBuf, ntohl(PostQueueResponse.m_iTrailingDataLength))) + if (!m_pConnection->Recv(pBuf, ntohl(PostQueueResponse.m_iTrailingDataLength))) { free(pBuf); return false; @@ -1097,7 +1062,7 @@ bool RemoteClient::RequestWriteLog(int iKind, const char* szText) int iLength = strlen(szText) + 1; WriteLogRequest.m_iTrailingDataLength = htonl(iLength); - if (m_pConnection->Send((char*)(&WriteLogRequest), sizeof(WriteLogRequest)) < 0) + if (!m_pConnection->Send((char*)(&WriteLogRequest), sizeof(WriteLogRequest))) { perror("m_pConnection->Send"); return false; @@ -1118,7 +1083,7 @@ bool RemoteClient::RequestScan(bool bSyncMode) ScanRequest.m_bSyncMode = htonl(bSyncMode); - bool OK = m_pConnection->Send((char*)(&ScanRequest), sizeof(ScanRequest)) >= 0; + bool OK = m_pConnection->Send((char*)(&ScanRequest), sizeof(ScanRequest)); if (OK) { OK = ReceiveBoolResponse(); @@ -1139,7 +1104,7 @@ bool RemoteClient::RequestHistory() SNZBHistoryRequest HistoryRequest; InitMessageBase(&HistoryRequest.m_MessageBase, eRemoteRequestHistory, sizeof(HistoryRequest)); - if (m_pConnection->Send((char*)(&HistoryRequest), sizeof(HistoryRequest)) < 0) + if (!m_pConnection->Send((char*)(&HistoryRequest), sizeof(HistoryRequest))) { perror("m_pConnection->Send"); return false; @@ -1149,19 +1114,12 @@ bool RemoteClient::RequestHistory() // Now listen for the returned list SNZBHistoryResponse HistoryResponse; - int iResponseLen = m_pConnection->Recv((char*) &HistoryResponse, sizeof(HistoryResponse)); - if (iResponseLen != sizeof(HistoryResponse) || + bool bRead = m_pConnection->Recv((char*) &HistoryResponse, sizeof(HistoryResponse)); + if (!bRead || (int)ntohl(HistoryResponse.m_MessageBase.m_iSignature) != (int)NZBMESSAGE_SIGNATURE || ntohl(HistoryResponse.m_MessageBase.m_iStructSize) != sizeof(HistoryResponse)) { - if (iResponseLen < 0) - { - printf("No response received (timeout)\n"); - } - else - { - printf("Invalid response received: either not nzbget-server or wrong server version\n"); - } + printf("No response or invalid response (timeout, not nzbget-server or wrong nzbget-server version)\n"); return false; } @@ -1169,7 +1127,7 @@ bool RemoteClient::RequestHistory() if (ntohl(HistoryResponse.m_iTrailingDataLength) > 0) { pBuf = (char*)malloc(ntohl(HistoryResponse.m_iTrailingDataLength)); - if (!m_pConnection->RecvAll(pBuf, ntohl(HistoryResponse.m_iTrailingDataLength))) + if (!m_pConnection->Recv(pBuf, ntohl(HistoryResponse.m_iTrailingDataLength))) { free(pBuf); return false; @@ -1257,7 +1215,7 @@ bool RemoteClient::RequestServerDownloadUrl(const char* szURL, const char* szNZB } DownloadUrlRequest.m_szNZBFilename[NZBREQUESTFILENAMESIZE-1] = '\0'; - bool OK = m_pConnection->Send((char*)(&DownloadUrlRequest), sizeof(DownloadUrlRequest)) >= 0; + bool OK = m_pConnection->Send((char*)(&DownloadUrlRequest), sizeof(DownloadUrlRequest)); if (OK) { OK = ReceiveBoolResponse(); @@ -1278,7 +1236,7 @@ bool RemoteClient::RequestUrlQueue() SNZBUrlQueueRequest UrlQueueRequest; InitMessageBase(&UrlQueueRequest.m_MessageBase, eRemoteRequestUrlQueue, sizeof(UrlQueueRequest)); - if (m_pConnection->Send((char*)(&UrlQueueRequest), sizeof(UrlQueueRequest)) < 0) + if (!m_pConnection->Send((char*)(&UrlQueueRequest), sizeof(UrlQueueRequest))) { perror("m_pConnection->Send"); return false; @@ -1288,19 +1246,12 @@ bool RemoteClient::RequestUrlQueue() // Now listen for the returned list SNZBUrlQueueResponse UrlQueueResponse; - int iResponseLen = m_pConnection->Recv((char*) &UrlQueueResponse, sizeof(UrlQueueResponse)); - if (iResponseLen != sizeof(UrlQueueResponse) || + bool bRead = m_pConnection->Recv((char*) &UrlQueueResponse, sizeof(UrlQueueResponse)); + if (!bRead || (int)ntohl(UrlQueueResponse.m_MessageBase.m_iSignature) != (int)NZBMESSAGE_SIGNATURE || ntohl(UrlQueueResponse.m_MessageBase.m_iStructSize) != sizeof(UrlQueueResponse)) { - if (iResponseLen < 0) - { - printf("No response received (timeout)\n"); - } - else - { - printf("Invalid response received: either not nzbget-server or wrong server version\n"); - } + printf("No response or invalid response (timeout, not nzbget-server or wrong nzbget-server version)\n"); return false; } @@ -1308,7 +1259,7 @@ bool RemoteClient::RequestUrlQueue() if (ntohl(UrlQueueResponse.m_iTrailingDataLength) > 0) { pBuf = (char*)malloc(ntohl(UrlQueueResponse.m_iTrailingDataLength)); - if (!m_pConnection->RecvAll(pBuf, ntohl(UrlQueueResponse.m_iTrailingDataLength))) + if (!m_pConnection->Recv(pBuf, ntohl(UrlQueueResponse.m_iTrailingDataLength))) { free(pBuf); return false; diff --git a/RemoteServer.cpp b/RemoteServer.cpp index 2692fe5b..66bf999d 100644 --- a/RemoteServer.cpp +++ b/RemoteServer.cpp @@ -47,6 +47,7 @@ #include "WebServer.h" #include "Log.h" #include "Options.h" +#include "Util.h" extern Options* g_pOptions; @@ -86,13 +87,13 @@ void RemoteServer::Run() bBind = m_pConnection->Bind() == 0; } - // Accept connections and store the "new" socket value - SOCKET iSocket = INVALID_SOCKET; + // Accept connections and store the new Connection + Connection* pAcceptedConnection = NULL; if (bBind) { - iSocket = m_pConnection->Accept(); + pAcceptedConnection = m_pConnection->Accept(); } - if (!bBind || iSocket == INVALID_SOCKET) + if (!bBind || pAcceptedConnection == NULL) { // Remote server could not bind or accept connection, waiting 1/2 sec and try again if (IsStopped()) @@ -107,7 +108,7 @@ void RemoteServer::Run() RequestProcessor* commandThread = new RequestProcessor(); commandThread->SetAutoDestroy(true); - commandThread->SetSocket(iSocket); + commandThread->SetConnection(pAcceptedConnection); commandThread->Start(); } if (m_pConnection) @@ -134,32 +135,22 @@ void RemoteServer::Stop() //***************************************************************** // RequestProcessor +RequestProcessor::~RequestProcessor() +{ + m_pConnection->Disconnect(); + delete m_pConnection; +} + void RequestProcessor::Run() { - // Read the first 4 bytes to determine request type bool bOK = false; - int iSignature = 0; - int iBytesReceived = recv(m_iSocket, (char*)&iSignature, sizeof(iSignature), 0); - if (iBytesReceived < 0) - { - return; - } - // Info - connection received -#ifdef WIN32 - char* ip = NULL; -#else - char ip[20]; -#endif - struct sockaddr_in PeerName; - int iPeerNameLength = sizeof(PeerName); - if (getpeername(m_iSocket, (struct sockaddr*)&PeerName, (SOCKLEN_T*) &iPeerNameLength) >= 0) + // Read the first 4 bytes to determine request type + int iSignature = 0; + if (!m_pConnection->Recv((char*)&iSignature, 4)) { -#ifdef WIN32 - ip = inet_ntoa(PeerName.sin_addr); -#else - inet_ntop(AF_INET, &PeerName.sin_addr, ip, sizeof(ip)); -#endif + warn("Non-nzbget request received on port %i from %s", g_pOptions->GetControlPort(), m_pConnection->GetRemoteAddr()); + return; } if ((int)ntohl(iSignature) == (int)NZBMESSAGE_SIGNATURE) @@ -167,9 +158,7 @@ void RequestProcessor::Run() // binary request received bOK = true; BinRpcProcessor processor; - processor.SetSocket(m_iSocket); - processor.SetSignature(iSignature); - processor.SetClientIP(ip); + processor.SetConnection(m_pConnection); processor.Execute(); } else if (!strncmp((char*)&iSignature, "POST", 4) || @@ -177,9 +166,8 @@ void RequestProcessor::Run() !strncmp((char*)&iSignature, "OPTI", 4)) { // HTTP request received - Connection con(m_iSocket, false); char szBuffer[1024]; - if (con.ReadLine(szBuffer, sizeof(szBuffer), NULL)) + if (m_pConnection->ReadLine(szBuffer, sizeof(szBuffer), NULL)) { WebProcessor::EHttpMethod eHttpMethod = WebProcessor::hmGet; char* szUrl = szBuffer; @@ -201,8 +189,7 @@ void RequestProcessor::Run() debug("url: %s", szUrl); WebProcessor processor; - processor.SetConnection(&con); - processor.SetClientIP(ip); + processor.SetConnection(m_pConnection); processor.SetUrl(szUrl); processor.SetHttpMethod(eHttpMethod); processor.Execute(); @@ -210,15 +197,8 @@ void RequestProcessor::Run() } } - if (!bOK && iBytesReceived > 0) + if (!bOK) { - warn("Non-nzbget request received on port %i from %s", g_pOptions->GetControlPort(), ip); + warn("Non-nzbget request received on port %i from %s", g_pOptions->GetControlPort(), m_pConnection->GetRemoteAddr()); } - - if (!bOK && iBytesReceived == 0) - { - debug("empty request received on port %i from %s", g_pOptions->GetControlPort(), ip); - } - - closesocket(m_iSocket); } diff --git a/RemoteServer.h b/RemoteServer.h index 7cb531df..3f34d9bf 100644 --- a/RemoteServer.h +++ b/RemoteServer.h @@ -45,11 +45,12 @@ public: class RequestProcessor : public Thread { private: - SOCKET m_iSocket; + Connection* m_pConnection; public: + ~RequestProcessor(); virtual void Run(); - void SetSocket(SOCKET iSocket) { m_iSocket = iSocket; }; + void SetConnection(Connection* pConnection) { m_pConnection = pConnection; } }; #endif diff --git a/WebDownloader.cpp b/WebDownloader.cpp index bd6963bc..1e7c750f 100644 --- a/WebDownloader.cpp +++ b/WebDownloader.cpp @@ -390,7 +390,7 @@ WebDownloader::EStatus WebDownloader::DownloadBody() m_pConnection->ReadBuffer(&szBuffer, &iLen); if (iLen == 0) { - iLen = m_pConnection->Recv(szLineBuf, LineBufSize); + iLen = m_pConnection->TryRecv(szLineBuf, LineBufSize); szBuffer = szLineBuf; } diff --git a/WebServer.cpp b/WebServer.cpp index ea13fd05..2567b669 100644 --- a/WebServer.cpp +++ b/WebServer.cpp @@ -59,7 +59,6 @@ static const int MAX_UNCOMPRESSED_SIZE = 500; WebProcessor::WebProcessor() { m_pConnection = NULL; - m_szClientIP = NULL; m_szRequest = NULL; m_szUrl = NULL; m_szOrigin = NULL; @@ -199,7 +198,7 @@ void WebProcessor::Execute() if (pw) *pw++ = '\0'; if (strcmp(szAuthInfo, "nzbget") || strcmp(pw, g_pOptions->GetControlPassword())) { - warn("request received on port %i from %s, but password invalid", g_pOptions->GetControlPort(), m_szClientIP); + warn("request received on port %i from %s, but password invalid", g_pOptions->GetControlPort(), m_pConnection->GetRemoteAddr()); SendAuthResponse(); return; } @@ -210,7 +209,7 @@ void WebProcessor::Execute() m_szRequest = (char*)malloc(iContentLen + 1); m_szRequest[iContentLen] = '\0'; - if (!m_pConnection->RecvAll(m_szRequest, iContentLen)) + if (!m_pConnection->Recv(m_szRequest, iContentLen)) { free(m_szRequest); error("invalid-request: could not read data"); @@ -219,7 +218,7 @@ void WebProcessor::Execute() debug("Request=%s", m_szRequest); } - debug("request received from %s", m_szClientIP); + debug("request received from %s", m_pConnection->GetRemoteAddr()); Dispatch(); } @@ -236,7 +235,6 @@ void WebProcessor::Dispatch() { XmlRpcProcessor processor; processor.SetRequest(m_szRequest); - processor.SetClientIP(m_szClientIP); processor.SetHttpMethod(m_eHttpMethod == hmGet ? XmlRpcProcessor::hmGet : XmlRpcProcessor::hmPost); processor.SetUrl(m_szUrl); processor.Execute(); diff --git a/WebServer.h b/WebServer.h index a2570d24..019b7bbd 100644 --- a/WebServer.h +++ b/WebServer.h @@ -40,7 +40,6 @@ public: private: Connection* m_pConnection; - const char* m_szClientIP; char* m_szRequest; char* m_szUrl; EHttpMethod m_eHttpMethod; @@ -63,7 +62,6 @@ public: void SetConnection(Connection* pConnection) { m_pConnection = pConnection; } void SetUrl(const char* szUrl); void SetHttpMethod(EHttpMethod eHttpMethod) { m_eHttpMethod = eHttpMethod; } - void SetClientIP(const char* szClientIP) { m_szClientIP = szClientIP; } }; #endif diff --git a/XmlRpc.cpp b/XmlRpc.cpp index 2906ebdc..4b9e42d4 100644 --- a/XmlRpc.cpp +++ b/XmlRpc.cpp @@ -93,7 +93,6 @@ void StringBuilder::Append(const char* szStr) XmlRpcProcessor::XmlRpcProcessor() { - m_szClientIP = NULL; m_szRequest = NULL; m_eProtocol = rpUndefined; m_eHttpMethod = hmPost; diff --git a/XmlRpc.h b/XmlRpc.h index 96b59d93..9ac1949f 100644 --- a/XmlRpc.h +++ b/XmlRpc.h @@ -61,7 +61,6 @@ public: }; private: - const char* m_szClientIP; char* m_szRequest; const char* m_szContentType; ERpcProtocol m_eProtocol; @@ -80,7 +79,6 @@ public: void Execute(); void SetHttpMethod(EHttpMethod eHttpMethod) { m_eHttpMethod = eHttpMethod; } void SetUrl(const char* szUrl); - void SetClientIP(const char* szClientIP) { m_szClientIP = szClientIP; } void SetRequest(char* szRequest) { m_szRequest = szRequest; } const char* GetResponse() { return m_cResponse.GetBuffer(); } const char* GetContentType() { return m_szContentType; }