/* * This file if part of nzbget * * Copyright (C) 2004 Sven Henkel * Copyright (C) 2007-2008 Andrei Prygounkov * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. * * $Revision$ * $Date$ * */ #ifdef HAVE_CONFIG_H #include #endif #ifdef WIN32 // SKIP_DEFAULT_WINDOWS_HEADERS prevents the including of , which includes "winsock.h", // but we need "winsock2.h" here (they conflicts with each other) #define SKIP_DEFAULT_WINDOWS_HEADERS #include "win32.h" #endif #include #include #include #ifdef WIN32 #include #include #else #include #include #include #include #include #include #endif #include "nzbget.h" #include "Connection.h" #include "Log.h" #include "TLS.h" static const int CONNECTION_READBUFFER_SIZE = 1024; #ifndef DISABLE_TLS bool Connection::bTLSLibInitialized = false; #endif void Connection::Init(bool bTLS) { debug("Initializing global connection data"); #ifdef WIN32 WSADATA wsaData; int err = WSAStartup(MAKEWORD(2, 0), &wsaData); if (err != 0) { error("Could not initialize socket library"); return; } if (LOBYTE(wsaData.wVersion) != 2 || HIBYTE( wsaData.wVersion ) != 0) { error("Could not initialize socket library"); WSACleanup(); return; } #endif #ifndef DISABLE_TLS if (bTLS) { debug("Initializing TLS library"); char* szErrStr; int iRes = tls_lib_init(&szErrStr); bTLSLibInitialized = iRes == TLS_EOK; if (!bTLSLibInitialized) { error("Could not initialize TLS library: %s", szErrStr ? szErrStr : "unknown error"); if (szErrStr) { free(szErrStr); } } } #endif } void Connection::Final() { debug("Finalizing global connection data"); #ifdef WIN32 WSACleanup(); #endif #ifndef DISABLE_TLS if (bTLSLibInitialized) { debug("Finalizing TLS library"); tls_lib_deinit(); } #endif } Connection::Connection(NetAddress* pNetAddress) { debug("Creating Connection"); m_pNetAddress = pNetAddress; m_eStatus = csDisconnected; m_iSocket = INVALID_SOCKET; m_iBufAvail = 0; m_iTimeout = 60; m_bSuppressErrors = true; m_szReadBuf = (char*)malloc(CONNECTION_READBUFFER_SIZE + 1); m_bAutoClose = true; #ifndef DISABLE_TLS m_pTLS = NULL; m_bTLSError = false; #endif } Connection::Connection(SOCKET iSocket, bool bAutoClose) { debug("Creating Connection"); m_pNetAddress = NULL; m_eStatus = csConnected; m_iSocket = iSocket; m_iBufAvail = 0; m_iTimeout = 60; m_bSuppressErrors = true; m_szReadBuf = (char*)malloc(CONNECTION_READBUFFER_SIZE + 1); m_bAutoClose = bAutoClose; #ifndef DISABLE_TLS m_pTLS = NULL; #endif } Connection::~Connection() { debug("Destroying Connection"); if (m_eStatus == csConnected && m_bAutoClose) { Disconnect(); } free(m_szReadBuf); #ifndef DISABLE_TLS if (m_pTLS) { free(m_pTLS); } #endif } bool Connection::Connect() { debug("Connecting"); if (m_eStatus == csConnected) return true; bool bRes = DoConnect(); if (bRes) m_eStatus = csConnected; else Connection::DoDisconnect(); return bRes; } bool Connection::Disconnect() { debug("Disconnecting"); if (m_eStatus == csDisconnected) return true; bool bRes = DoDisconnect(); m_eStatus = csDisconnected; m_iSocket = INVALID_SOCKET; m_iBufAvail = 0; return bRes; } int Connection::Bind() { debug("Binding"); if (m_eStatus == csListening) { return 0; } int iRes = DoBind(); if (iRes == 0) { m_eStatus = csListening; } return iRes; } int Connection::WriteLine(const char* pBuffer) { //debug("Connection::write(char* line)"); if (m_eStatus != csConnected) { return -1; } int iRes = DoWriteLine(pBuffer); return iRes; } int Connection::Send(const char* pBuffer, int iSize) { debug("Sending data"); if (m_eStatus != csConnected) { return -1; } int iRes = send(m_iSocket, pBuffer, iSize, 0); return iRes; } char* Connection::ReadLine(char* pBuffer, int iSize, int* pBytesRead) { if (m_eStatus != csConnected) { return NULL; } char* res = DoReadLine(pBuffer, iSize, pBytesRead); return res; } SOCKET Connection::Accept() { debug("Accepting connection"); if (m_eStatus != csListening) { return INVALID_SOCKET; } SOCKET iRes = DoAccept(); return iRes; } int Connection::Recv(char* pBuffer, int iSize) { debug("Receiving data"); memset(pBuffer, 0, iSize); int iReceived = recv(m_iSocket, pBuffer, iSize, 0); if (iReceived < 0) { ReportError("Could not receive data on socket", NULL, true, 0); } return iReceived; } bool Connection::RecvAll(char * pBuffer, int iSize) { debug("Receiving data (full buffer)"); memset(pBuffer, 0, iSize); char* pBufPtr = (char*)pBuffer; int NeedBytes = iSize; if (m_iBufAvail > 0) { int len = iSize > m_iBufAvail ? m_iBufAvail : iSize; memcpy(pBufPtr, m_szBufPtr, len); pBufPtr += len; m_szBufPtr += len; m_iBufAvail -= len; NeedBytes -= len; } // Read from the socket until nothing remains while (NeedBytes > 0) { int iReceived = recv(m_iSocket, pBufPtr, NeedBytes, 0); // Did the recv succeed? if (iReceived <= 0) { ReportError("Could not receive data on socket", NULL, true, 0); return false; } pBufPtr += iReceived; NeedBytes -= iReceived; } return true; } bool Connection::DoConnect() { debug("Do connecting"); struct addrinfo addr_hints, *addr_list, *addr; char iPortStr[sizeof(int) * 4 + 1]; //is enough to hold any converted int memset(&addr_hints, 0, sizeof(addr_hints)); addr_hints.ai_family = AF_UNSPEC; /* Allow IPv4 or IPv6 */ addr_hints.ai_socktype = SOCK_STREAM, sprintf(iPortStr, "%d", m_pNetAddress->GetPort()); int res = getaddrinfo(m_pNetAddress->GetHost(), iPortStr, &addr_hints, &addr_list); if (res != 0) { ReportError("Could not resolve hostname %s", m_pNetAddress->GetHost(), true, 0); return false; } m_iSocket = INVALID_SOCKET; for (addr = addr_list; addr != NULL; addr = addr->ai_next) { m_iSocket = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); if (m_iSocket != INVALID_SOCKET) { res = connect(m_iSocket , addr->ai_addr, addr->ai_addrlen); if (res != -1) { // Connection established break; } // Connection failed closesocket(m_iSocket); m_iSocket = INVALID_SOCKET; } } freeaddrinfo(addr_list); if (m_iSocket == INVALID_SOCKET) { ReportError("Connection to %s failed!", m_pNetAddress->GetHost(), true, 0); return false; } #ifdef WIN32 int MSecVal = m_iTimeout * 1000; int err = setsockopt(m_iSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&MSecVal, sizeof(MSecVal)); #else struct timeval TimeVal; TimeVal.tv_sec = m_iTimeout; TimeVal.tv_usec = 0; int err = setsockopt(m_iSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&TimeVal, sizeof(TimeVal)); #endif if (err != 0) { ReportError("setsockopt failed", NULL, true, 0); } return true; } bool Connection::DoDisconnect() { debug("Do disconnecting"); if (m_iSocket > 0) { closesocket(m_iSocket); m_iSocket = INVALID_SOCKET; #ifndef DISABLE_TLS if (m_pTLS) { CloseTLS(); } #endif } m_eStatus = csDisconnected; return true; } int Connection::DoWriteLine(const char* pBuffer) { //debug("Connection::doWrite()"); return send(m_iSocket, pBuffer, strlen(pBuffer), 0); } char* Connection::DoReadLine(char* pBuffer, int iSize, int* pBytesRead) { //debug( "Connection::DoReadLine()" ); char* pBufPtr = pBuffer; iSize--; // for trailing '0' int iBytesRead = 0; int iBufAvail = m_iBufAvail; // local variable is faster char* szBufPtr = m_szBufPtr; // local variable is faster while (iSize) { if (!iBufAvail) { iBufAvail = recv(m_iSocket, m_szReadBuf, CONNECTION_READBUFFER_SIZE, 0); if (iBufAvail < 0) { ReportError("Could not receive data on socket", NULL, true, 0); break; } else if (iBufAvail == 0) { break; } szBufPtr = m_szReadBuf; m_szReadBuf[iBufAvail] = '\0'; } int len = 0; char* p = (char*)memchr(szBufPtr, '\n', iBufAvail); if (p) { len = (int)(p - szBufPtr + 1); } else { len = iBufAvail; } if (len > iSize) { len = iSize; } memcpy(pBufPtr, szBufPtr, len); pBufPtr += len; szBufPtr += len; iBufAvail -= len; iBytesRead += len; iSize -= len; if (p) { break; } } *pBufPtr = '\0'; m_iBufAvail = iBufAvail > 0 ? iBufAvail : 0; // copy back to member m_szBufPtr = szBufPtr; // copy back to member if (pBytesRead) { *pBytesRead = iBytesRead; } if (pBufPtr == pBuffer) { return NULL; } return pBuffer; } int Connection::DoBind() { debug("Do binding"); struct addrinfo addr_hints, *addr_list, *addr; char iPortStr[sizeof(int) * 4 + 1]; // is enough to hold any converted int memset(&addr_hints, 0, sizeof(addr_hints)); addr_hints.ai_family = AF_UNSPEC; // Allow IPv4 or IPv6 addr_hints.ai_socktype = SOCK_STREAM, addr_hints.ai_flags = AI_PASSIVE; // For wildcard IP address sprintf(iPortStr, "%d", m_pNetAddress->GetPort()); int res = getaddrinfo(m_pNetAddress->GetHost(), iPortStr, &addr_hints, &addr_list); if (res != 0) { error( "Could not resolve hostname %s", m_pNetAddress->GetHost() ); return -1; } m_iSocket = INVALID_SOCKET; for (addr = addr_list; addr != NULL; addr = addr->ai_next) { m_iSocket = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); if (m_iSocket != INVALID_SOCKET) { int opt = 1; setsockopt(m_iSocket, SOL_SOCKET, SO_REUSEADDR, (char*)&opt, sizeof(opt)); res = bind(m_iSocket, addr->ai_addr, addr->ai_addrlen); if (res != -1) { // Connection established break; } // Connection failed closesocket(m_iSocket); m_iSocket = INVALID_SOCKET; } } freeaddrinfo(addr_list); if (m_iSocket == INVALID_SOCKET) { ReportError("Binding socket failed for %s", m_pNetAddress->GetHost(), true, 0); return -1; } if (listen(m_iSocket, 10) < 0) { ReportError("Listen on socket failed for %s", m_pNetAddress->GetHost(), true, 0); return -1; } return 0; } SOCKET Connection::DoAccept() { SOCKET iSocket = accept(GetSocket(), NULL, NULL); if (iSocket == INVALID_SOCKET && m_eStatus != csCancelled) { ReportError("Could not accept connection", NULL, true, 0); } return iSocket; } void Connection::Cancel() { debug("Cancelling connection"); if (m_iSocket != INVALID_SOCKET) { m_eStatus = csCancelled; int r = shutdown(m_iSocket, SHUT_RDWR); if (r == -1) { ReportError("Could not shutdown connection", NULL, true, 0); } } } void Connection::ReportError(const char* szMsgPrefix, const char* szMsgArg, bool PrintErrCode, int ErrCode) { #ifndef DISABLE_TLS if (m_bTLSError) { // TLS-Error was already reported m_bTLSError = false; return; } #endif char szErrPrefix[1024]; snprintf(szErrPrefix, 1024, szMsgPrefix, szMsgArg); szErrPrefix[1024-1] = '\0'; if (PrintErrCode) { if (ErrCode == 0) { #ifdef WIN32 ErrCode = WSAGetLastError(); #else ErrCode = errno; #endif } #ifdef WIN32 if (m_bSuppressErrors) { debug("%s: ErrNo %i", szErrPrefix, ErrCode); } else { error("%s: ErrNo %i", szErrPrefix, ErrCode); } #else const char* szErrMsg = hstrerror(ErrCode); if (m_bSuppressErrors) { debug("%s: ErrNo %i, %s", szErrPrefix, ErrCode, szErrMsg); } else { error("%s: ErrNo %i, %s", szErrPrefix, ErrCode, szErrMsg); } #endif } else { if (m_bSuppressErrors) { debug(szErrPrefix); } else { error(szErrPrefix); } } } #ifndef DISABLE_TLS bool Connection::CheckTLSResult(int iResultCode, char* szErrStr, const char* szErrMsgPrefix) { bool bOK = iResultCode == TLS_EOK; if (!bOK) { ReportError(szErrMsgPrefix, szErrStr ? szErrStr : "unknown error", false, 0); if (szErrStr) { free(szErrStr); } } return bOK; } bool Connection::StartTLS() { debug("Starting TLS"); if (m_pTLS) { free(m_pTLS); } m_pTLS = malloc(sizeof(tls_t)); tls_t* pTLS = (tls_t*)m_pTLS; memset(pTLS, 0, sizeof(tls_t)); tls_clear(pTLS); char* szErrStr; int iRes; iRes = tls_init(pTLS, NULL, NULL, NULL, 0, &szErrStr); if (!CheckTLSResult(iRes, szErrStr, "Could not initialize TLS-object: %s")) { return false; } debug("tls_start..."); iRes = tls_start(pTLS, (int)m_iSocket, NULL, 1, NULL, &szErrStr); debug("tls_start...%i", iRes); if (!CheckTLSResult(iRes, szErrStr, "Could not establish secure connection: %s")) { return false; } return true; } void Connection::CloseTLS() { tls_close((tls_t*)m_pTLS); free(m_pTLS); m_pTLS = NULL; } int Connection::recv(SOCKET s, char* buf, int len, int flags) { size_t iReceived = 0; if (m_pTLS) { m_bTLSError = false; char* szErrStr; int iRes = tls_getbuf((tls_t*)m_pTLS, buf, len, &iReceived, &szErrStr); if (!CheckTLSResult(iRes, szErrStr, "Could not read from TLS-socket: %s")) { m_bTLSError = true; return -1; } } else { iReceived = ::recv(s, buf, len, flags); } return iReceived; } int Connection::send(SOCKET s, const char* buf, int len, int flags) { if (m_pTLS) { m_bTLSError = false; char* szErrStr; int iRes = tls_putbuf((tls_t*)m_pTLS, buf, len, &szErrStr); if (!CheckTLSResult(iRes, szErrStr, "Could not send to TLS-socket: %s")) { m_bTLSError = true; return -1; } return 0; } else { int iRet = ::send(s, buf, len, flags); return iRet; } } #endif