From 28a7e7fe45dfd9ef17cc779c55af640e1ef0e1ae Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Mon, 13 Apr 2026 11:21:59 +0800 Subject: [PATCH] Basic protocol support for websocket. Websockets client can send a Protocol which the server can agree to. This isn't as fancy as it sounds. We just send a specific header on websocket handshake and then read the response header. --- src/browser/webapi/net/WebSocket.zig | 64 +++++++++++++++++++++------- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/src/browser/webapi/net/WebSocket.zig b/src/browser/webapi/net/WebSocket.zig index aef8809a..5f0c09ac 100644 --- a/src/browser/webapi/net/WebSocket.zig +++ b/src/browser/webapi/net/WebSocket.zig @@ -54,6 +54,7 @@ _got_upgrade: bool = false, _conn: ?*http.Connection, _http_client: *HttpClient, +_req_headers: http.Headers, // buffered outgoing messages _send_queue: std.ArrayList(Message) = .empty, @@ -66,6 +67,9 @@ _recv_buffer: std.ArrayList(u8) = .empty, _close_code: u16 = 1000, _close_reason: []const u8 = "", +// negotiated protocol +_protocol: []const u8 = "", + // Event handlers _on_open: ?js.Function.Temp = null, _on_message: ?js.Function.Temp = null, @@ -84,13 +88,21 @@ pub const BinaryType = enum { arraybuffer, }; -pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket { - if (protocols_) |protocols| { - if (protocols.len > 0) { - log.warn(.not_implemented, "WS protocols", .{ .protocols = protocols }); +fn isValidProtocol(protocol: []const u8) bool { + if (protocol.len == 0) return false; + for (protocol) |c| { + // Control characters + if (c <= 31 or c == 127) return false; + // Separators per RFC 2616 + switch (c) { + '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t' => return false, + else => {}, } } + return true; +} +pub fn init(url: []const u8, protocols: [][]const u8, page: *Page) !*WebSocket { { if (url.len < 6) { return error.SyntaxError; @@ -103,6 +115,11 @@ pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket { if (std.mem.indexOfScalar(u8, url, '#') != null) { return error.SyntaxError; } + for (protocols) |protocol| { + if (!isValidProtocol(protocol)) { + return error.SyntaxError; + } + } } const arena = try page.getArena(.medium, "WebSocket"); @@ -124,12 +141,21 @@ pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket { try conn.setWriteCallback(receivedDataCallback); try conn.setHeaderCallback(receivedHeaderCallback); + var headers = try http_client.newHeaders(); + errdefer headers.deinit(); + if (protocols.len > 0) { + const header = try std.fmt.allocPrintSentinel(arena, "Sec-WebSocket-Protocol: {s}", .{try std.mem.join(arena, ", ", protocols)}, 0); + try headers.add(header); + try conn.setHeaders(&headers); + } + const self = try page._factory.eventTargetWithAllocator(arena, WebSocket{ ._page = page, ._conn = conn, ._arena = arena, ._proto = undefined, ._url = resolved_url, + ._req_headers = headers, ._http_client = http_client, }); conn.transport = .{ .websocket = self }; @@ -206,6 +232,7 @@ pub fn disconnected(self: *WebSocket, err_: ?anyerror) void { fn cleanup(self: *WebSocket) void { if (self._conn) |conn| { self._http_client.removeConn(conn); + self._req_headers.deinit(); self._conn = null; self.releaseRef(self._page._session); self._send_queue.clearRetainingCapacity(); @@ -356,6 +383,10 @@ pub fn getBinaryType(self: *const WebSocket) []const u8 { return @tagName(self._binary_type); } +pub fn getProtocol(self: *const WebSocket) []const u8 { + return self._protocol; +} + pub fn setBinaryType(self: *WebSocket, value: []const u8) void { if (std.meta.stringToEnum(BinaryType, value)) |bt| { self._binary_type = bt; @@ -653,23 +684,24 @@ fn receivedHeaderCallback(buffer: [*]const u8, header_count: usize, buf_len: usi return buf_len; } - if (self._got_upgrade) { - // dont' care about headers once we've gotten the upgrade header - return buf_len; - } - const colon = std.mem.indexOfScalarPos(u8, header, 0, ':') orelse { // weird, continue... return buf_len; }; - if (std.ascii.eqlIgnoreCase(header[0..colon], "upgrade") == false) { - return buf_len; - } - + const header_name = header[0..colon]; const value = std.mem.trim(u8, header[colon + 1 ..], " \t\r\n"); - if (std.ascii.eqlIgnoreCase(value, "websocket")) { - self._got_upgrade = true; + + if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) { + if (std.ascii.eqlIgnoreCase(value, "websocket")) { + self._got_upgrade = true; + } + } else if (std.ascii.eqlIgnoreCase(header_name, "sec-websocket-protocol")) { + // TODO, we should validate this against our sent list. + self._protocol = self._arena.dupe(u8, value) catch |err| { + log.err(.websocket, "dupe protocol", .{ .err = err }); + return 0; + }; } return buf_len; @@ -713,7 +745,7 @@ pub const JsApi = struct { pub const bufferedAmount = bridge.accessor(WebSocket.getBufferedAmount, null, .{}); pub const binaryType = bridge.accessor(WebSocket.getBinaryType, WebSocket.setBinaryType, .{}); - pub const protocol = bridge.property("", .{ .template = false }); + pub const protocol = bridge.accessor(WebSocket.getProtocol, null, .{}); pub const extensions = bridge.property("", .{ .template = false }); pub const onopen = bridge.accessor(WebSocket.getOnOpen, WebSocket.setOnOpen, .{});