using System.Buffers;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Channels;

namespace Sandbox;

/// <summary>
/// A WebSocket client for connecting to external services.
/// </summary>
/// <remarks>
/// Event handlers will be called on the synchronization context that Connect was called on.
/// </remarks>
public sealed class WebSocket : IDisposable
{
	/// <summary>
	/// Event handler which processes text messages from the WebSocket service.
	/// </summary>
	/// <param name="message">The message text that was received.</param>
	public delegate void MessageReceivedHandler( string message );

	/// <summary>
	/// Event handler which processes binary messages from the WebSocket service.
	/// </summary>
	/// <param name="data">
	/// The binary message data that was received. It is only valid for the duration of the call;
	/// the backing buffer is returned to a pool afterwards.
	/// </param>
	public delegate void DataReceivedHandler( Span<byte> data );

	/// <summary>
	/// Event handler which fires when the WebSocket disconnects from the server.
	/// </summary>
	/// <param name="status">The close status code from the server, or 0 if there was none. See known values here: https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent</param>
	/// <param name="reason">The reason string for closing the connection. This may not be populated, may be from the server, or may be a client exception message.</param>
	public delegate void DisconnectedHandler( int status, string reason );

	/// <summary>
	/// An outgoing message waiting in the send queue.
	/// </summary>
	private class Message
	{
		public WebSocketMessageType Type;
		public ArraySegment<byte> Data; // this must be returnable to the pool
	}

	// Private lock object: locking on `this` would let outside code contend on (or deadlock with) our lock.
	private readonly object _lock = new();

	[SkipHotload] private CancellationTokenSource _cts;
	[SkipHotload] private ClientWebSocket _socket;
	[SkipHotload] private readonly Channel<Message> _outgoing;
	private readonly int _maxMessageSize;
	private bool _dispatchedDisconnect;
	private bool _isShuttingDown;

	/// <summary>
	/// Returns true as long as a WebSocket connection is established.
	/// </summary>
	public bool IsConnected => _socket?.State == WebSocketState.Open;

	/// <summary>
	/// Get the sub-protocol that was negotiated during the opening handshake.
	/// </summary>
	public string SubProtocol => _socket?.SubProtocol;

	/// <summary>
	/// Event which fires when a text message is received from the server.
	/// </summary>
	public event MessageReceivedHandler OnMessageReceived;

	/// <summary>
	/// Event which fires when a binary message is received from the server.
	/// </summary>
	public event DataReceivedHandler OnDataReceived;

	/// <summary>
	/// Event which fires when the connection to the WebSocket service is lost, for any reason.
	/// </summary>
	public event DisconnectedHandler OnDisconnected;

	/// <summary>
	/// Enable or disable compression for the websocket. If the server supports it, compression will be enabled for all messages.
	/// Note: compression is disabled by default, and can be dangerous if you are sending secrets across the network.
	/// </summary>
	/// <exception cref="ObjectDisposedException">The WebSocket has been disposed.</exception>
	public bool EnableCompression
	{
		set
		{
			// Previously this dereferenced a null _socket after Dispose (NullReferenceException).
			EnsureNotDisposed();

			if ( value )
			{
				_socket.Options.DangerousDeflateOptions = new WebSocketDeflateOptions
				{
					ServerContextTakeover = false,
					ClientMaxWindowBits = 15
				};
			}
			else
			{
				_socket.Options.DangerousDeflateOptions = null;
			}
		}
	}

	/// <summary>
	/// Initializes a new WebSocket client.
	/// </summary>
	/// <param name="maxMessageSize">
	/// The maximum message size to allow from the server, in bytes. Default 64 KiB.
	/// Values below 4096 are raised to 4096.
	/// </param>
	/// <exception cref="ArgumentOutOfRangeException"><paramref name="maxMessageSize"/> is not positive or exceeds 4 MiB.</exception>
	public WebSocket( int maxMessageSize = 64 * 1024 )
	{
		if ( maxMessageSize <= 0 || maxMessageSize > 4 * 1024 * 1024 )
		{
			throw new ArgumentOutOfRangeException( nameof( maxMessageSize ) );
		}

		_maxMessageSize = Math.Max( maxMessageSize, 4096 );
		_cts = new CancellationTokenSource();
		_socket = new ClientWebSocket();
		_outgoing = Channel.CreateBounded<Message>( new BoundedChannelOptions( 10 )
		{
			SingleReader = true,
			SingleWriter = false,
		} );

		// auto-disposes this when the TaskSource generation changes
		TaskSource.Cancellation.Register( () =>
		{
			_isShuttingDown = true;
			Dispose();
		} );
	}

	~WebSocket()
	{
		Dispose();
	}

	/// <summary>
	/// Cleans up resources used by the WebSocket client. This will also immediately close the connection if it is currently open.
	/// </summary>
	public void Dispose()
	{
		lock ( _lock )
		{
			DispatchDisconnect( WebSocketCloseStatus.Empty, "Disposed" );

			_cts?.Cancel();
			_cts?.Dispose();
			_cts = null;

			_socket?.Dispose();
			_socket = null;

			_outgoing.Writer.TryComplete();
		}

		GC.SuppressFinalize( this );
	}

	/// <summary>
	/// Add a sub-protocol to be negotiated during the WebSocket connection handshake.
	/// </summary>
	/// <param name="protocol">The sub-protocol name to request.</param>
	/// <exception cref="InvalidOperationException">The WebSocket has already started connecting.</exception>
	public void AddSubProtocol( string protocol )
	{
		EnsureNotDisposed();

		if ( _socket.State != WebSocketState.None )
		{
			throw new InvalidOperationException( "Cannot add sub-protocols while the WebSocket is connected." );
		}

		_socket.Options.AddSubProtocol( protocol );
	}

	/// <summary>
	/// Establishes a connection to an external WebSocket service.
	/// </summary>
	/// <param name="websocketUri">The WebSocket URI to connect to. For example, "ws://hostname.local:1280/" for unencrypted WebSocket or "wss://hostname.local:1281/" for encrypted.</param>
	/// <param name="ct">A <see cref="CancellationToken"/> which allows the connection attempt to be aborted if necessary.</param>
	/// <returns>A <see cref="Task"/> which completes when the connection is established, or throws if it failed to connect.</returns>
	public Task Connect( string websocketUri, CancellationToken ct = default ) => Connect( websocketUri, null, ct );

	/// <summary>
	/// Establishes a connection to an external WebSocket service.
	/// </summary>
	/// <param name="websocketUri">The WebSocket URI to connect to. For example, "ws://hostname.local:1280/" for unencrypted WebSocket or "wss://hostname.local:1281/" for encrypted.</param>
	/// <param name="headers">Headers to send with the connection request.</param>
	/// <param name="ct">A <see cref="CancellationToken"/> which allows the connection attempt to be aborted if necessary.</param>
	/// <returns>A <see cref="Task"/> which completes when the connection is established, or throws if it failed to connect.</returns>
	public async Task Connect( string websocketUri, Dictionary<string, string> headers, CancellationToken ct = default )
	{
		EnsureNotDisposed();

		if ( _socket.State != WebSocketState.None )
		{
			throw new InvalidOperationException( "Connect may only be called once per WebSocket instance." );
		}

		var uri = ParseWebSocketUri( websocketUri );
		if ( !Http.IsAllowed( uri ) )
		{
			throw new InvalidOperationException( $"Access to '{websocketUri}' is not allowed." );
		}

		if ( headers != null )
		{
			foreach ( var (key, value) in headers )
			{
				if ( !Http.IsHeaderAllowed( key ) )
				{
					throw new InvalidOperationException( $"Not allowed to set header '{key}'." );
				}

				_socket.Options.SetRequestHeader( key, value );
			}
		}

		_socket.Options.SetRequestHeader( "User-Agent", Http.UserAgent );
		_socket.Options.SetRequestHeader( "Referer", Http.Referrer );

		// Abort the handshake either when the caller cancels or when this instance is disposed.
		using var linkedCt = CancellationTokenSource.CreateLinkedTokenSource( _cts.Token, ct );
		await _socket.ConnectAsync( uri, linkedCt.Token );

		// Both loops are async void and are started here, so they resume on this synchronization context.
		SendLoop();
		ReceiveLoop();
	}

	/// <summary>
	/// Sends a text message to the WebSocket server.
	/// </summary>
	/// <param name="message">The message text to send. Must not be null.</param>
	/// <returns>A <see cref="ValueTask"/> which completes when the message was queued to be sent.</returns>
	public ValueTask Send( string message )
	{
		EnsureNotDisposed();

		if ( message == null )
		{
			throw new ArgumentNullException( nameof( message ) );
		}

		var byteCount = Encoding.UTF8.GetByteCount( message );
		var buffer = ArrayPool<byte>.Shared.Rent( byteCount );
		var length = Encoding.UTF8.GetBytes( message, buffer );

		return _outgoing.Writer.WriteAsync( new Message
		{
			Type = WebSocketMessageType.Text,
			Data = new ArraySegment<byte>( buffer, 0, length ),
		} );
	}

	/// <summary>
	/// Sends a binary message to the WebSocket server.
	/// </summary>
	/// <remarks>
	/// The <see cref="Send(ArraySegment{byte})"/> and <see cref="Send(Span{byte})"/> overloads allow sending subsections of byte arrays.
	/// </remarks>
	/// <param name="data">The message data to send. Must not be null.</param>
	/// <returns>A <see cref="ValueTask"/> which completes when the message was queued to be sent.</returns>
	public ValueTask Send( byte[] data )
	{
		EnsureNotDisposed();

		if ( data == null )
		{
			throw new ArgumentNullException( nameof( data ) );
		}

		return Send( data.AsSpan() );
	}

	/// <summary>
	/// Sends a binary message to the WebSocket server.
	/// </summary>
	/// <param name="data">The message data to send. Must not be null.</param>
	/// <returns>A <see cref="ValueTask"/> which completes when the message was queued to be sent.</returns>
	public ValueTask Send( ArraySegment<byte> data )
	{
		EnsureNotDisposed();

		if ( data.Array == null )
		{
			throw new ArgumentNullException( nameof( data ) );
		}

		return Send( data.AsSpan() );
	}

	/// <summary>
	/// Sends a binary message to the WebSocket server.
	/// </summary>
	/// <param name="data">The message data to send. It is copied into a pooled buffer before queueing.</param>
	/// <returns>A <see cref="ValueTask"/> which completes when the message was queued to be sent.</returns>
	public ValueTask Send( Span<byte> data )
	{
		EnsureNotDisposed();

		var buffer = ArrayPool<byte>.Shared.Rent( data.Length );
		data.CopyTo( buffer );

		var message = new Message
		{
			Type = WebSocketMessageType.Binary,
			Data = new ArraySegment<byte>( buffer, 0, data.Length ),
		};

		return _outgoing.Writer.WriteAsync( message, _cts.Token );
	}

	/// <summary>
	/// Reads whole messages from the socket until cancelled or disconnected, dispatching each one
	/// to <see cref="OnMessageReceived"/> or <see cref="OnDataReceived"/>.
	/// </summary>
	private async void ReceiveLoop()
	{
		var ct = _cts.Token;

		while ( !ct.IsCancellationRequested )
		{
			byte[] buffer = null;
			try
			{
				buffer = ArrayPool<byte>.Shared.Rent( _maxMessageSize );

				WebSocketMessageType? type = null;
				var length = 0;

				while ( true )
				{
					// Rent may return an array larger than requested, so bound the receive window by
					// _maxMessageSize rather than buffer.Length. Otherwise the configured limit is silently raised.
					var receiveSegment = new ArraySegment<byte>( buffer, length, _maxMessageSize - length );
					var result = await _socket.ReceiveAsync( receiveSegment, ct );

					if ( result.MessageType == WebSocketMessageType.Close )
					{
						Disconnect( result.CloseStatus, result.CloseStatusDescription );
						return;
					}

					if ( type == null )
					{
						type = result.MessageType;
					}
					else if ( result.MessageType != type.Value )
					{
						throw new InvalidOperationException( "WebSocket message type changed unexpectedly" );
					}

					length += result.Count;

					if ( result.EndOfMessage )
					{
						break;
					}

					// The window is full but the frame sequence continues, so the message is too large.
					if ( length >= _maxMessageSize )
					{
						throw new InvalidOperationException( "WebSocket message exceeds max message size limit" );
					}
				}

				DispatchReceived( type.Value, buffer, length );
			}
			catch ( WebSocketException e ) when ( e.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely )
			{
				/*
				 * Conna: only disconnect here if the connection is actually closed. We can sometimes get
				 * premature close messages but we're still connected.
				 */
				if ( !IsConnected )
				{
					Disconnect( WebSocketCloseStatus.InvalidMessageType, "Unexpected Closure" );
					return;
				}
			}
			catch ( Exception e )
			{
				if ( !ct.IsCancellationRequested )
				{
					Log.Error( e );
				}

				// Tear down unless Dispose already did (it cancels ct). Checking only IsConnected skipped
				// OnDisconnected whenever the socket had already left the Open state (e.g. aborted),
				// leaving the send loop and token source dangling with no notification.
				if ( IsConnected || !ct.IsCancellationRequested )
				{
					Disconnect( WebSocketCloseStatus.InternalServerError, e.Message );
				}

				return;
			}
			finally
			{
				if ( buffer != null )
				{
					ArrayPool<byte>.Shared.Return( buffer );
				}
			}
		}
	}

	/// <summary>
	/// Raises the appropriate received event for one complete message. Handler exceptions are logged, not propagated.
	/// Nothing is dispatched once the TaskSource-driven shutdown has started.
	/// </summary>
	private void DispatchReceived( WebSocketMessageType type, byte[] buffer, int length )
	{
		if ( _isShuttingDown )
			return;

		var data = new Span<byte>( buffer, 0, length );

		if ( type == WebSocketMessageType.Text )
		{
			var messageText = Encoding.UTF8.GetString( data );

			try
			{
				OnMessageReceived?.Invoke( messageText );
			}
			catch ( Exception e )
			{
				Log.Error( e );
			}
		}
		else
		{
			try
			{
				OnDataReceived?.Invoke( data );
			}
			catch ( Exception e )
			{
				Log.Error( e );
			}
		}
	}

	/// <summary>
	/// Drains the outgoing queue onto the socket until cancelled, returning each pooled buffer after it is sent.
	/// Any send failure disconnects the client.
	/// </summary>
	private async void SendLoop()
	{
		var ct = _cts.Token;

		while ( !ct.IsCancellationRequested )
		{
			byte[] data = null;
			try
			{
				var message = await _outgoing.Reader.ReadAsync( ct );
				data = message.Data.Array;

				await _socket.SendAsync( message.Data, message.Type, true, ct );
			}
			catch ( Exception e )
			{
				if ( !ct.IsCancellationRequested )
				{
					Log.Error( e );
				}

				Disconnect( WebSocketCloseStatus.ProtocolError, e.Message );
			}
			finally
			{
				if ( data != null )
				{
					ArrayPool<byte>.Shared.Return( data );
				}
			}
		}
	}

	/// <summary>
	/// Raises <see cref="OnDisconnected"/> (at most once) and then disposes this instance.
	/// </summary>
	private void Disconnect( WebSocketCloseStatus? status, string reason )
	{
		DispatchDisconnect( status, reason );
		Dispose();
	}

	/// <summary>
	/// Raises <see cref="OnDisconnected"/> the first time it is called. Later calls are no-ops.
	/// The event is suppressed during TaskSource-driven shutdown.
	/// </summary>
	private void DispatchDisconnect( WebSocketCloseStatus? status, string reason )
	{
		if ( _dispatchedDisconnect )
			return;

		_dispatchedDisconnect = true;

		if ( _isShuttingDown )
			return;

		try
		{
			OnDisconnected?.Invoke( (int)status.GetValueOrDefault( 0 ), reason );
		}
		catch ( Exception e )
		{
			Log.Error( e );
		}
	}

	/// <summary>
	/// Throws <see cref="ObjectDisposedException"/> if <see cref="Dispose"/> has already run.
	/// </summary>
	private void EnsureNotDisposed()
	{
		lock ( _lock )
		{
			if ( _cts == null )
			{
				throw new ObjectDisposedException( nameof( WebSocket ) );
			}
		}
	}

	/// <summary>
	/// Parses and validates a WebSocket URI, accepting only absolute ws:// and wss:// URIs.
	/// </summary>
	private static Uri ParseWebSocketUri( string websocketUri )
	{
		if ( string.IsNullOrEmpty( websocketUri ) )
		{
			throw new ArgumentNullException( nameof( websocketUri ) );
		}

		if ( !Uri.TryCreate( websocketUri, UriKind.Absolute, out var uri ) )
		{
			throw new ArgumentException( "WebSocket URI is not a valid URI.", nameof( websocketUri ) );
		}

		if ( uri.Scheme != "ws" && uri.Scheme != "wss" )
		{
			throw new ArgumentException( "WebSocket URI must use the ws:// or wss:// scheme.", nameof( websocketUri ) );
		}

		return uri;
	}
}