From 6c09d63c3d3de1e976ec901739d9d0cc5e7f6d89 Mon Sep 17 00:00:00 2001 From: Kitteh Date: Sun, 6 Jun 2021 12:56:15 +0100 Subject: [PATCH] Move socket/stream operations and handling to SocketManager.zig. --- src/SocketManager.zig | 99 ++++++++++++++++++++++++++++++++++ src/client.zig | 121 ++++++++---------------------------------- 2 files changed, 121 insertions(+), 99 deletions(-) create mode 100644 src/SocketManager.zig diff --git a/src/SocketManager.zig b/src/SocketManager.zig new file mode 100644 index 0000000..cfccf46 --- /dev/null +++ b/src/SocketManager.zig @@ -0,0 +1,99 @@ +const std = @import("std"); + +const QVariant = @import("./qtshit/types/QVariant.zig").QVariant; +const read = @import("./qtshit/read.zig"); +const write = @import("./qtshit/write.zig"); +const range = @import("./qtshit/utils/RangeIter.zig").range; +const tls = @import("./deps/iguanaTLS/src/main.zig"); + +pub const SocketManager = struct { + allocator: *std.mem.Allocator, + baseStream: *std.net.Stream, + + pub var tlsAllowed = !true; + pub var tlsConnected = !true; + + pub const TLSStream = tls.Client(std.net.Stream.Reader, std.net.Stream.Writer, tls.ciphersuites.all, false); + pub var tlsClient: TLSStream = undefined; + + pub fn deinit(s: *SocketManager) void {} + + pub fn setTLSAllowed(s: *SocketManager, value: bool) void { + tlsAllowed = value; + } + + pub fn initTLS(s: *SocketManager) !void { + if (!tlsConnected and tlsAllowed) { + var randBuf: [32]u8 = undefined; + try std.os.getrandom(&randBuf); + var rng = std.rand.DefaultCsprng.init(randBuf); + + var rand = blk: { + var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; + try std.os.getrandom(&seed); + break :blk &std.rand.DefaultCsprng.init(seed).random; + }; + + tlsClient = try tls.client_connect(.{ + .rand = rand, + .temp_allocator = s.allocator, + .reader = s.baseStream.reader(), + .writer = s.baseStream.writer(), + .cert_verifier = .none, + .ciphersuites = tls.ciphersuites.all, + }, "quassel.owo.monster"); + tlsConnected = true; + } + } + + fn _writeFrame(s: *SocketManager, writer: anytype, data: std.ArrayList(u8)) !void { + try write.writeUInt(writer, @intCast(u32, data.items.len)); + try writer.writeAll(data.items); + } + + pub fn writeFrame(s: *SocketManager, data: std.ArrayList(u8)) !void { + try s.initTLS(); + if (tlsConnected) { + var writer = tlsClient.writer(); + try s._writeFrame(writer, data); + } else { + var writer = s.baseStream.writer(); + try s._writeFrame(writer, data); + } + } + + fn _readFrame(s: *SocketManager, reader: anytype) !std.ArrayList(u8) { + var size = try read.readUInt(reader); + var data = std.ArrayList(u8).init(s.allocator); + var iter = range(u32, 0, size); + while (iter.next()) |i| { + const byte = try reader.readByte(); + try data.append(byte); + } + return data; + } + + pub fn readFrame(s: *SocketManager) !QVariant { + try s.initTLS(); + var data: std.ArrayList(u8) = undefined; + defer data.deinit(); + + if (tlsConnected) { + var reader = tlsClient.reader(); + data = try s._readFrame(reader); + } else { + var reader = s.baseStream.reader(); + data = try s._readFrame(reader); + } + + var fBS = std.io.fixedBufferStream(data.items); + return try read.readQVariant(fBS.reader(), s.allocator); + } +}; + +pub fn initSocketManager(allocator: *std.mem.Allocator, stream: *std.net.Stream) SocketManager { + return SocketManager{ + .allocator = allocator, + .baseStream = stream, + }; +} diff --git a/src/client.zig b/src/client.zig index 6be14ad..d74e836 100644 --- a/src/client.zig +++ b/src/client.zig @@ -1,18 +1,17 @@ const std = @import("std"); const BufferManager = @import("./BufferManager.zig"); +const SocketManager = @import("./SocketManager.zig"); + const read = @import("./qtshit/read.zig"); const write = @import("./qtshit/write.zig"); -const range = @import("./qtshit/utils/RangeIter.zig").range; -const QVariantType = @import("./qtshit/types/QVariant.zig").QVariant; +const QVariant = @import("./qtshit/types/QVariant.zig").QVariant; const prettyPrintQVariant = @import("./qtshit/utils/prettyPrintQVariant.zig").prettyPrintQVariant; const freeQVariant = @import("./qtshit/utils/free/freeQVariant.zig").freeQVariant; const QVariantMapToQVariantList = @import("./qtshit/utils/QVariantMapToQVariantList.zig").QVariantMapToQVariantList; const UserType = @import("./qtshit/types/UserType.zig"); -const tls = @import("./deps/iguanaTLS/src/main.zig"); - fn dumpDebug(name: []const u8, list: std.ArrayList(u8)) !void { std.debug.print("dumpDebug list len {d}\n", .{list.items.len}); @@ -28,105 +27,28 @@ fn dumpDebug(name: []const u8, list: std.ArrayList(u8)) !void { pub const Client = struct { allocator: *std.mem.Allocator, stream: *std.net.Stream, + socketManager: SocketManager.SocketManager, bufferManager: BufferManager.BufferManager, - pub var tlsAllowed = !true; - pub var tlsConnected = !true; - - pub const TLSStream = tls.Client(std.net.Stream.Reader, std.net.Stream.Writer, tls.ciphersuites.all, false); - pub var tlsClient: TLSStream = undefined; - pub fn deinit(s: *Client) void { s.bufferManager.deinit(); - } - - - pub fn initTLS(s: *Client) !void { - if (!tlsConnected and tlsAllowed) { - var randBuf: [32]u8 = undefined; - try std.os.getrandom(&randBuf); - var rng = std.rand.DefaultCsprng.init(randBuf); - - var rand = blk: { - var seed: [std.rand.DefaultCsprng.secret_seed_length]u8 = undefined; - try std.os.getrandom(&seed); - break :blk &std.rand.DefaultCsprng.init(seed).random; - }; - - tlsClient = try tls.client_connect(.{ - .rand = rand, - .temp_allocator = s.allocator, - .reader = s.stream.reader(), - .writer = s.stream.writer(), - .cert_verifier = .none, - .ciphersuites = tls.ciphersuites.all, - }, "quassel.owo.monster"); - tlsConnected = true; - } - } - - pub fn _writeFrame(s: *Client, writer: anytype, data: std.ArrayList(u8)) !void { - try write.writeUInt(writer, @intCast(u32, data.items.len)); - try writer.writeAll(data.items); - } - - pub fn writeFrame(s: *Client, data: std.ArrayList(u8)) !void { - try s.initTLS(); - if (tlsConnected) { - var writer = tlsClient.writer(); - try s._writeFrame(writer, data); - } else { - var writer = s.stream.writer(); - try s._writeFrame(writer, data); - } - } - - pub fn _readFrame(s: *Client, reader: anytype) !std.ArrayList(u8) { - var size = try read.readUInt(reader); - var data = std.ArrayList(u8).init(s.allocator); - var iter = range(u32, 0, size); - while (iter.next()) |i| { - const byte = try reader.readByte(); - try data.append(byte); - } - return data; - } - - pub fn readFrame(s: *Client) !QVariantType { - try s.initTLS(); - var data: std.ArrayList(u8) = undefined; - defer data.deinit(); - - if (tlsConnected) { - var reader = tlsClient.reader(); - data = try s._readFrame(reader); - } else { - var reader = s.stream.reader(); - data = try s._readFrame(reader); - } - - var fBS = std.io.fixedBufferStream(data.items); - return try read.readQVariant(fBS.reader(), s.allocator); + s.socketManager.deinit(); } pub fn handshake(s: *Client) !void { - //const magic = 0x42b33f00; - - //try write.writeUInt(s.stream.writer(), magic); - //try write.writeUInt(s.stream.writer(), 0x80000002); - try s.stream.writer().writeAll(&[_]u8{ 0x42, 0xb3, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x01, 0x80, 0x00, 0x00, 0x00 }); - - var flags = try read.readByte(s.stream.reader()); - var extra = try read.readShort(s.stream.reader()); - var version = try read.readSignedByte(s.stream.reader()); + try s.socketManager.baseStream.writer().writeAll(&[_]u8{ 0x42, 0xb3, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x01, 0x80, 0x00, 0x00, 0x00 }); + var flags = try read.readByte(s.socketManager.baseStream.reader()); + var extra = try read.readShort(s.socketManager.baseStream.reader()); + var version = try read.readSignedByte(s.socketManager.baseStream.reader()); std.debug.print("Handshake: flags={d} extra={d} version={d} \n", .{ flags, extra, version }); } + pub fn quassel_init_packet(s: *Client) !void { var data = std.ArrayList(u8).init(s.allocator); defer data.deinit(); - var map = std.StringHashMap(QVariantType).init(s.allocator); + var map = std.StringHashMap(QVariant).init(s.allocator); defer map.deinit(); try map.put("MsgType", .{ .String = "ClientInit" }); @@ -151,18 +73,18 @@ pub const Client = struct { .QVariantMap = map, }); - try s.writeFrame(data); + try s.socketManager.writeFrame(data); - var variant = try s.readFrame(); + var variant = try s.socketManager.readFrame(); defer freeQVariant(variant, s.allocator); - tlsAllowed = variant.QVariantMap.get("SupportSsl").?.Byte == 1; + s.socketManager.setTLSAllowed(variant.QVariantMap.get("SupportSsl").?.Byte == 1); } pub fn quassel_login(s: *Client, username: []const u8, password: []const u8) !void { var data = std.ArrayList(u8).init(s.allocator); defer data.deinit(); - var map = std.StringHashMap(QVariantType).init(s.allocator); + var map = std.StringHashMap(QVariant).init(s.allocator); defer map.deinit(); try map.put("MsgType", .{ .String = "ClientLogin" }); @@ -173,9 +95,9 @@ pub const Client = struct { .QVariantMap = map, }); - try s.writeFrame(data); + try s.socketManager.writeFrame(data); - var loginResponse = try s.readFrame(); + var loginResponse = try s.socketManager.readFrame(); defer freeQVariant(loginResponse, s.allocator); var loginResponseMap = loginResponse.QVariantMap; @@ -186,14 +108,14 @@ pub const Client = struct { } } - fn handle_session_init_packet(s: *Client, sessionState: std.StringHashMap(QVariantType)) !void { + fn handle_session_init_packet(s: *Client, sessionState: std.StringHashMap(QVariant)) !void { for (sessionState.get("BufferInfos").?.QVariantList) |qvar| { try s.bufferManager.addBufferInfo(qvar.UserType.BufferInfo); } } pub fn read_quassel_packet(s: *Client) !void { - var variant = try s.readFrame(); + var variant = try s.socketManager.readFrame(); defer freeQVariant(variant, s.allocator); switch (variant) { @@ -216,7 +138,7 @@ pub const Client = struct { var data = std.ArrayList(u8).init(s.allocator); defer data.deinit(); - var listItems = std.ArrayList(QVariantType).init(s.allocator); + var listItems = std.ArrayList(QVariant).init(s.allocator); defer listItems.deinit(); try listItems.append(.{ .Int = 2 }); @@ -228,7 +150,7 @@ pub const Client = struct { .QVariantList = listItems.items, }); - try s.writeFrame(data); + try s.socketManager.writeFrame(data); } }; @@ -237,5 +159,6 @@ pub fn initClient(allocator: *std.mem.Allocator, stream: *std.net.Stream) Client .allocator = allocator, .stream = stream, .bufferManager = BufferManager.initBufferManager(allocator), + .socketManager = SocketManager.initSocketManager(allocator, stream) }; }