Basic protocol support for websocket.

Websockets client can send a Protocol which the server can agree to. This isn't
as fancy as it sounds. We just send a specific header on websocket handshake
and then read the response header.
This commit is contained in:
Karl Seguin
2026-04-13 11:21:59 +08:00
parent dc8e917084
commit 28a7e7fe45

View File

@@ -54,6 +54,7 @@ _got_upgrade: bool = false,
_conn: ?*http.Connection,
_http_client: *HttpClient,
_req_headers: http.Headers,
// buffered outgoing messages
_send_queue: std.ArrayList(Message) = .empty,
@@ -66,6 +67,9 @@ _recv_buffer: std.ArrayList(u8) = .empty,
_close_code: u16 = 1000,
_close_reason: []const u8 = "",
// negotiated protocol
_protocol: []const u8 = "",
// Event handlers
_on_open: ?js.Function.Temp = null,
_on_message: ?js.Function.Temp = null,
@@ -84,13 +88,21 @@ pub const BinaryType = enum {
arraybuffer,
};
pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket {
if (protocols_) |protocols| {
if (protocols.len > 0) {
log.warn(.not_implemented, "WS protocols", .{ .protocols = protocols });
fn isValidProtocol(protocol: []const u8) bool {
if (protocol.len == 0) return false;
for (protocol) |c| {
// Control characters
if (c <= 31 or c == 127) return false;
// Separators per RFC 2616
switch (c) {
'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t' => return false,
else => {},
}
}
return true;
}
pub fn init(url: []const u8, protocols: [][]const u8, page: *Page) !*WebSocket {
{
if (url.len < 6) {
return error.SyntaxError;
@@ -103,6 +115,11 @@ pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket {
if (std.mem.indexOfScalar(u8, url, '#') != null) {
return error.SyntaxError;
}
for (protocols) |protocol| {
if (!isValidProtocol(protocol)) {
return error.SyntaxError;
}
}
}
const arena = try page.getArena(.medium, "WebSocket");
@@ -124,12 +141,21 @@ pub fn init(url: []const u8, protocols_: ?[]const u8, page: *Page) !*WebSocket {
try conn.setWriteCallback(receivedDataCallback);
try conn.setHeaderCallback(receivedHeaderCallback);
var headers = try http_client.newHeaders();
errdefer headers.deinit();
if (protocols.len > 0) {
const header = try std.fmt.allocPrintSentinel(arena, "Sec-WebSocket-Protocol: {s}", .{try std.mem.join(arena, ", ", protocols)}, 0);
try headers.add(header);
try conn.setHeaders(&headers);
}
const self = try page._factory.eventTargetWithAllocator(arena, WebSocket{
._page = page,
._conn = conn,
._arena = arena,
._proto = undefined,
._url = resolved_url,
._req_headers = headers,
._http_client = http_client,
});
conn.transport = .{ .websocket = self };
@@ -206,6 +232,7 @@ pub fn disconnected(self: *WebSocket, err_: ?anyerror) void {
fn cleanup(self: *WebSocket) void {
if (self._conn) |conn| {
self._http_client.removeConn(conn);
self._req_headers.deinit();
self._conn = null;
self.releaseRef(self._page._session);
self._send_queue.clearRetainingCapacity();
@@ -356,6 +383,10 @@ pub fn getBinaryType(self: *const WebSocket) []const u8 {
return @tagName(self._binary_type);
}
pub fn getProtocol(self: *const WebSocket) []const u8 {
return self._protocol;
}
pub fn setBinaryType(self: *WebSocket, value: []const u8) void {
if (std.meta.stringToEnum(BinaryType, value)) |bt| {
self._binary_type = bt;
@@ -653,23 +684,24 @@ fn receivedHeaderCallback(buffer: [*]const u8, header_count: usize, buf_len: usi
return buf_len;
}
if (self._got_upgrade) {
// dont' care about headers once we've gotten the upgrade header
return buf_len;
}
const colon = std.mem.indexOfScalarPos(u8, header, 0, ':') orelse {
// weird, continue...
return buf_len;
};
if (std.ascii.eqlIgnoreCase(header[0..colon], "upgrade") == false) {
return buf_len;
}
const header_name = header[0..colon];
const value = std.mem.trim(u8, header[colon + 1 ..], " \t\r\n");
if (std.ascii.eqlIgnoreCase(value, "websocket")) {
self._got_upgrade = true;
if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) {
if (std.ascii.eqlIgnoreCase(value, "websocket")) {
self._got_upgrade = true;
}
} else if (std.ascii.eqlIgnoreCase(header_name, "sec-websocket-protocol")) {
// TODO, we should validate this against our sent list.
self._protocol = self._arena.dupe(u8, value) catch |err| {
log.err(.websocket, "dupe protocol", .{ .err = err });
return 0;
};
}
return buf_len;
@@ -713,7 +745,7 @@ pub const JsApi = struct {
pub const bufferedAmount = bridge.accessor(WebSocket.getBufferedAmount, null, .{});
pub const binaryType = bridge.accessor(WebSocket.getBinaryType, WebSocket.setBinaryType, .{});
pub const protocol = bridge.property("", .{ .template = false });
pub const protocol = bridge.accessor(WebSocket.getProtocol, null, .{});
pub const extensions = bridge.property("", .{ .template = false });
pub const onopen = bridge.accessor(WebSocket.getOnOpen, WebSocket.setOnOpen, .{});