# BSD 2-Clause License # # Apprise - Push Notification Library. # Copyright (c) 2026, Chris Caron # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, # this list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. from __future__ import annotations import contextlib import ipaddress import select import socket import ssl import time from typing import Optional, Union from ..exception import AppriseException, AppriseInvalidData from ..logger import logger TimeoutType = Optional[ Union[float, tuple[Optional[float], Optional[float]]] ] class AppriseSocketError(AppriseException): """Raised for socket or TLS related failures.""" class SocketTransport: """ TCP client transport with optional TLS upgrade. Behaviour: - secure=False (default): plain TCP - secure=True: upgrade to TLS (immediately in connect(), or manually via start_tls()) - verify=True (default): validate certificate chain and hostname using a certifi CA bundle - verify=False: accept invalid or self-signed certs Timeout behaviour (requests-compatible): - timeout=float => (connect, read) both set to float - timeout=(connect, read) => tuple form - None => no defaults (connect/read can block indefinitely) """ def __init__( self, host: str, port: int, bind_addr: Optional[str] = None, bind_port: Optional[int] = None, secure: bool = False, verify: bool = True, timeout: TimeoutType = 10.0, retries: int = 0, ) -> None: self.host = host self.port = int(port) self.bind_addr = bind_addr self.bind_port = bind_port self.secure = bool(secure) self.verify = bool(verify) self.retries = retries self._connect_timeout, self._read_timeout = \ self._coerce_timeout(timeout) self._sock: Optional[socket.socket] = None self._rfile = None self._wfile = None self._is_tls: bool = False # True once we have successfully read or written data since the last # connect(). Used to decide whether reconnect attempts are allowed. self._had_io: bool = False self.local_addr: Optional[tuple[str, int]] = None self.remote_addr: Optional[tuple[str, int]] = None @staticmethod def _coerce_timeout( timeout: TimeoutType) -> tuple[Optional[float], Optional[float]]: """ Coerce requests-style timeout into (connect_timeout, read_timeout). """ if timeout is None: return None, None if isinstance(timeout, (int, float)): t = float(timeout) if t < 0: raise AppriseInvalidData("timeout must be >= 0") return t, t if isinstance(timeout, tuple) and len(timeout) == 2: connect_t, read_t = timeout if connect_t is not None: connect_t = float(connect_t) if connect_t < 0: raise AppriseInvalidData("connect timeout must be >= 0") if read_t is not None: read_t = float(read_t) if read_t < 0: raise AppriseInvalidData("read timeout must be >= 0") return connect_t, read_t raise AppriseInvalidData( "timeout must be None, a float, or a (connect, read) tuple" ) @property def connected(self) -> bool: return self._sock is not None @property def is_tls(self) -> bool: return self._is_tls def close(self) -> None: """Close the socket and associated file wrappers.""" try: if self._wfile is not None: with contextlib.suppress(Exception): self._wfile.flush() with contextlib.suppress(Exception): self._wfile.close() finally: self._wfile = None try: if self._rfile is not None: with contextlib.suppress(Exception): self._rfile.close() finally: self._rfile = None if self._sock is not None: try: with contextlib.suppress(Exception): self._sock.shutdown(socket.SHUT_RDWR) self._sock.close() finally: self._sock = None self._is_tls = False self._had_io = False self.local_addr = None self.remote_addr = None def _refresh_wrappers(self) -> None: """Rebuild file wrappers, required after TLS upgrade.""" if self._sock is None: self._rfile = None self._wfile = None return self._rfile = self._sock.makefile("rb", buffering=0) self._wfile = self._sock.makefile("wb", buffering=0) def can_read(self, timeout: float = 0.0) -> Optional[bool]: """Return True if readable, False if not, None if closed or error.""" if self._sock is None: return None try: r, _, x = select.select( [self._sock], [], [self._sock], float(timeout)) except OSError: self.close() return None if x: self.close() return None return bool(r) def can_write(self, timeout: float = 0.0) -> Optional[bool]: """Return True if writable, False if not, None if closed or error.""" if self._sock is None: return None try: _, w, x = \ select.select([], [self._sock], [self._sock], float(timeout)) except OSError: self.close() return None if x: self.close() return None return bool(w) def connect(self) -> None: """ Establish TCP connection, optionally upgrade to TLS immediately if secure=True. """ logger.trace( "Socket connect IN: host=%s port=%d secure=%s verify=%s", self.host, self.port, self.secure, self.verify, ) self.close() sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if self.bind_addr is not None or self.bind_port is not None: sock.bind( (self.bind_addr or "127.0.0.1", int(self.bind_port or 0))) if self._connect_timeout is not None: sock.settimeout(self._connect_timeout) # Establish our connection sock.connect((self.host, self.port)) # We control I/O blocking explicitly with select() sock.settimeout(None) self._sock = sock self._is_tls = False self._had_io = False if self.secure: self.start_tls() self.local_addr = self._sock.getsockname() self.remote_addr = self._sock.getpeername() self._refresh_wrappers() logger.debug( "Socket connected: local=%s remote=%s tls=%s", self.local_addr, self.remote_addr, self._is_tls, ) except Exception as e: with contextlib.suppress(Exception): sock.close() self._sock = None self._had_io = False logger.debug("Socket connect exception: %s", e) raise AppriseSocketError(str(e)) from e def _server_hostname_for_tls(self) -> str: """ Determine hostname used for SNI and hostname verification. If verify=True and host is an IP address, attempt reverse DNS lookup. """ host = self.host if not self.verify: return host try: ipaddress.ip_address(host) except ValueError: return host try: name, _, _ = socket.gethostbyaddr(host) return name.rstrip(".") if name else host except Exception: return host def _build_ssl_context(self) -> ssl.SSLContext: """Build SSL context using certifi bundle when verify=True.""" # Enforce TLS 1.2+ to avoid TLSv1/TLSv1.1 negotiation. # We explicitly enforce TLS >= 1.2 for Python 3.9+ compatibility. if self.verify: import certifi ctx = ssl.create_default_context(cafile=certifi.where()) ctx.check_hostname = True ctx.verify_mode = ssl.CERT_REQUIRED else: # Still enforce modern TLS even when certificate verification is # disabled, since protocol downgrade is independent of trust. ctx = ssl.create_default_context() ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE # Enforce TLS 1.2+ to avoid TLSv1/TLSv1.1 negotiation. try: ctx.minimum_version = ssl.TLSVersion.TLSv1_2 except Exception: # Fallback for very old Python/OpenSSL combinations with contextlib.suppress(Exception): ctx.options |= ssl.OP_NO_TLSv1 with contextlib.suppress(Exception): ctx.options |= ssl.OP_NO_TLSv1_1 # Disable TLS-level compression (mitigates CRIME-style attacks). with contextlib.suppress(Exception): ctx.options |= ssl.OP_NO_COMPRESSION return ctx def start_tls(self) -> None: """Upgrade an existing TCP connection to TLS.""" if self._sock is None: raise AppriseSocketError("No active connection to upgrade") if self._is_tls: return server_hostname = self._server_hostname_for_tls() logger.trace("Starting TLS upgrade: sni=%s", server_hostname) try: ctx = self._build_ssl_context() tls_sock = ctx.wrap_socket( self._sock, server_hostname=server_hostname, ) tls_sock.setblocking(False) self._sock = tls_sock self._is_tls = True self.local_addr = self._sock.getsockname() self.remote_addr = self._sock.getpeername() self._refresh_wrappers() logger.trace( "TLS upgrade complete: local=%s remote=%s", self.local_addr, self.remote_addr, ) except ssl.SSLError as e: self.close() logger.debug("TLS negotiation exception: %s", e) raise AppriseSocketError(f"TLS negotiation failed: {e}") from e except OSError as e: self.close() logger.debug("TLS negotiation exception: %s", e) raise AppriseSocketError(str(e)) from e def _attempt_reconnect( self, retries: int, action: str, exc: Exception, ) -> bool: """ Attempt to reconnect and allow the caller to retry. Args: retries: Remaining reconnect attempts permitted (<= 0 disables). action: A short label (e.g. "read" or "write") for logging. exc: The exception that triggered the reconnect attempt. Returns: True if a reconnect was performed and the caller should retry. """ # Respect the caller's retry budget if int(retries) <= 0: return False # Only retry if we have previously completed useful I/O since the last # connect(). This prevents retrying the first failed read/write after # connect. if not self._had_io: return False logger.warning( "Socket %s failed, reconnecting and retrying", action ) logger.debug("Socket %s exception: %s", action, exc) try: self.close() self.connect() except Exception as e: logger.debug("Socket reconnect exception: %s", e) return False return True def read( self, max_bytes: int = 32768, blocking: bool = False, timeout: Optional[float] = None, retries: Optional[int] = None, ) -> bytes: """ Read up to max_bytes bytes. blocking=False: - returns immediately with available data, or b"" if none blocking=True: - waits up to timeout seconds (or instance read timeout if timeout is None), then reads once - if both are None, waits indefinitely retries: - number of reconnect attempts permitted if the socket goes stale after prior successful I/O. Defaults to None (which takes value globally passed into the class) """ if self._sock is None: return b"" # Compute retry attempts; treat retries=0 as explicit 0 retry_count = self.retries if retries is None else int(retries) attempts = max(0, retry_count) + 1 # Derive wait timeout (None means wait indefinitely) wait_timeout = \ self._read_timeout if timeout is None else float(timeout) # We manage readiness via select, socket stays non-blocking self._sock.setblocking(False) while attempts: attempts -= 1 try: if not blocking: try: data = self._sock.recv(int(max_bytes)) if data == b"": raise AppriseSocketError( "Connection lost during read") self._had_io = True return data except (BlockingIOError, ssl.SSLWantReadError, ssl.SSLWantWriteError): return b"" # blocking=True path: wait for readability, then recv if wait_timeout is None: # Wait indefinitely but periodically confirm socket health while True: ready = self.can_read(0.5) if ready is None: raise AppriseSocketError("Socket closed") if ready: break else: ready = self.can_read(wait_timeout) if not ready: return b"" # Even after select says readable, TLS may still raise # WANT_READ/WRITE. Loop until we either receive data, timeout, # or the socket closes. while True: try: data = self._sock.recv(int(max_bytes)) if data == b"": raise AppriseSocketError( "Connection lost during read") self._had_io = True return data except (ssl.SSLWantReadError, ssl.SSLWantWriteError, BlockingIOError): if wait_timeout is None: continue # Avoid busy loop if not self.can_read(min(0.25, wait_timeout)): return b"" except (AppriseSocketError, OSError, ssl.SSLError) as e: # Normalise and log logger.warning("Socket read failed") logger.debug("Socket read exception: %s", e) # Only close on hard errors; WANT_READ/WRITE handled above if isinstance(e, OSError) \ and not isinstance(e, ssl.SSLWantReadError) \ and not isinstance(e, ssl.SSLWantWriteError): self.close() err: Exception = e # Reconnect only if we've had prior useful I/O if self._attempt_reconnect( retries=attempts, action="read", exc=err, ): # In blocking mode with no timeout (wait indefinitely), # perform an immediate read attempt after reconnect. # This avoids relying solely on can_read(), and it keeps # edge cases (like stale sockets) recoverable. if blocking and wait_timeout is None \ and self._sock is not None: try: data = self._sock.recv(int(max_bytes)) if data == b"": raise AppriseSocketError( "Connection lost during read") self._had_io = True return data except (BlockingIOError, ssl.SSLWantReadError, ssl.SSLWantWriteError): # No data yet; fall back to retry loop pass continue if isinstance(err, AppriseSocketError): raise err from None raise AppriseSocketError(str(err)) from err raise AppriseSocketError("Socket read failed") def write( self, data: bytes, flush: bool = True, timeout: Optional[float] = None, retries: Optional[int] = None, ) -> int: """ Write bytes to the socket. timeout: - if None, uses instance read timeout - if both are None, blocks until complete retries: - number of reconnect attempts permitted if the socket goes stale after prior successful I/O. Defaults to None (which takes value globally passed into the class) """ if self._sock is None: raise AppriseSocketError("No active connection") if not isinstance(data, (bytes, bytearray, memoryview)): raise AppriseInvalidData("write() expects bytes-like data") # Loop-based retry avoids recursion and keeps state obvious retry_count = self.retries if retries is None else int(retries) attempts = max(0, retry_count) + 1 while attempts: attempts -= 1 view = memoryview(data) total_sent = 0 op_timeout = ( self._read_timeout if timeout is None else float(timeout) ) deadline = ( None if op_timeout is None else (time.monotonic() + op_timeout) ) try: self._sock.setblocking(deadline is None) while total_sent < len(view): if deadline is not None: remaining = deadline - time.monotonic() if remaining <= 0: raise AppriseSocketError( "Timed out during write" ) writable = self.can_write(remaining) if not writable: raise AppriseSocketError( "Timed out waiting for writable socket" ) sent = self._sock.send(view[total_sent:]) if sent <= 0: raise AppriseSocketError( "Connection lost during write" ) total_sent += sent if flush and self._wfile is not None: self._wfile.flush() if total_sent > 0: self._had_io = True return total_sent except (AppriseSocketError, OSError) as e: logger.warning("Socket write failed") logger.debug("Socket write exception: %s", e) # Normalise: any OSError implies the socket is toast if isinstance(e, OSError): self.close() if self._attempt_reconnect( retries=attempts, action="write", exc=e, ): continue if isinstance(e, AppriseSocketError): raise raise AppriseSocketError(str(e)) from e raise AppriseSocketError("Socket write failed")