From 0385adfa5b6ec639e49a33f0d2a64f45ec26e7fb Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 4 Jun 2026 09:20:13 -0400 Subject: [PATCH 1/6] ql docs: add QLV2 overview and protocol reference --- QLV2_overview.md | 67 +++++++++ QL_V2.md | 376 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 443 insertions(+) create mode 100644 QLV2_overview.md create mode 100644 QL_V2.md diff --git a/QLV2_overview.md b/QLV2_overview.md new file mode 100644 index 0000000..cf267c2 --- /dev/null +++ b/QLV2_overview.md @@ -0,0 +1,67 @@ +# QuantumLink V2 + +QLV2 is designed around the shortcomings of QLV1. + +QLV1 worked, but it treated each message too much like a standalone encrypted blob. That made routing hard, pairing clunky, repeated messages expensive, reliability awkward, and application encoding too baked into the protocol. + +QLV2 moves QuantumLink toward authenticated encrypted sessions with reliable duplex byte streams. + +## Problems And Solutions + +### The protocol assumed specific product roles + +**Problem:** QLV1 was effectively shaped around the Prime/Envoy relationship. It was not a generic protocol for arbitrary peers; the wire model and pairing flow assumed who the participants were and what direction the relationship moved in. + +**Solution:** QLV2 models peers generically. A peer has a `PeerBundle` representing its public identity, including its QID, independent of whether it is Passport, Envoy, KeyOS, a mobile app, or something else. + +### Messages were not routable + +**Problem:** QLV1 messages were not cheaply introspectable. A peer or transport adapter could not look at a message and know where it was supposed to go. + +**Solution:** QLV2 records have public, but verified, routing headers. Known-peer records expose sender and recipient QIDs, and session records expose enough metadata to route them to the right session. + +This lets a single connection, like BLE, multiplex different senders and recipients on both ends. For example, iPhone/Android can have multiple apps using the same BLE connection, while KeyOS can still know which sender produced a record and which destination app should receive it. + +This also opens the door to QL-level packet forwarding, because peers can cheaply inspect a record and forward it to the right destination without needing to understand or decrypt the payload. + +### Encryption overhead was too high + +**Problem:** The minimum payload size for a QLV1 message is about 6.6KB. + +**Solution:** QLV2 amortizes the KEM and encryption setup cost into a session. Handshake messages are still large, but steady-state message overhead drops to roughly `35..42` bytes, depending on varint encoding size. + +### Key compromise could expose old messages + +**Problem:** QLV1 had no built-in key rotation model. If a long-term key was compromised, old messages were at risk. + +**Solution:** QLV2 uses Noise-style session handshakes so every session gets unique encryption keys. Compromising one session does not automatically compromise future sessions. + +### Reliability lived above the protocol + +**Problem:** QLV1 was fundamentally unreliable, so reliability had to be rebuilt by each higher-level API. Any message flow that needed dependable delivery had to invent its own retry/reliability behavior in userspace. + +**Solution:** QLV2 is built around reliable streams. One session can carry many duplex byte streams, and reliability is solved once at the QL layer instead of repeatedly in user/application space. + +### Encoding was baked into the protocol + +**Problem:** QLV1 was tied to a specific CBOR codec, even when the protocol did not really need that. This added serialization/deserialization cost and extra memory pressure. + +**Solution:** QLV2 uses binary framing internally and exposes byte streams to user space. Since QL is fundamentally moving bytes, the implementation can use zero-copy byte views where possible instead of copying payloads through a serialization layer. Applications can still layer whatever encoding they want on top: JSON, CBOR, XML, or something else. That choice belongs above QL. + +### RPC patterns had to be reinvented per API + +**Problem:** QLV1 mixed protocol transport with application workflow shape. If a feature needed request/response behavior, progress updates, downloads, uploads, or subscriptions, that behavior had to be manually modeled in its message types. The protocol did not provide reusable workflow primitives, so every feature had to encode its own control flow. + +**Solution:** QLV2 makes reliable byte streams the primitive. `ql-rpc` sits above QLV2 and gives those streams common RPC modalities: request/response, notification, upload, download, subscription, and duplex. QL stays focused on peer identity, session establishment, encryption, routing, and reliable byte transport. + +### Pairing lived in userspace + +**Problem:** QLV1 did not treat peer establishment as a protocol concern. Pairing had to be implemented as an application message flow, which made it directional, informal, and hard to generalize beyond the original Prime/Envoy convention. + +**Solution:** QLV2 lifts peer establishment into the protocol. Pairing becomes one protocol-supported way to establish a session, not a one-off userspace convention. + +QLV2 has different session establishment modalities for different starting states: + +- `XX`: first-time pairing using a small out-of-band `PairingToken` +- `IK`: the initiator already knows the responder +- `KK`: both peers already know each other diff --git a/QL_V2.md b/QL_V2.md new file mode 100644 index 0000000..0062c7e --- /dev/null +++ b/QL_V2.md @@ -0,0 +1,376 @@ +# QuantumLink V2 + +QuantumLink V2 is a peer-to-peer protocol for authenticated encrypted sessions carrying multiplexed duplex byte streams. + +It operates on whole QL records. Packetization, fragmentation, batching, and reassembly belong to the transport adapter, not to QLv2 itself. + +## Design goals +1. [Ephemeral peer sessions](#handshake): short-lived keys for encryption +2. [Forward secrecy](#security-properties): losing a long-term private key does not reveal old session data +3. [Minimal authenticated header](#record-and-frame-wire-format): keep routing visible, but authenticated +4. [QL-level reliability](#acknowledgment-and-retransmission): `ack` means received, decrypted, and accepted +5. [Duplex byte streams](#streams): avoid cross-stream head-of-line blocking and keep backpressure local +6. [Efficient wire format](#record-and-frame-wire-format): keep steady-state traffic compact +7. [Hardware-backed cryptography](#security-properties): allow platform-specific crypto implementations +8. Shared core state machine: keep implementation consistent across platforms + +## Non-goals + +QLv2 is not: + +- a packet framing format +- a generic reliability layer for arbitrary raw datagrams +- a globally ordered message bus + +## Core terms + +- `peer`: one QLv2 endpoint +- `QID`: a stable 16-byte peer identifier +- `peer bundle`: public peer information: `version`, `qid`, `capabilities`, and ML-KEM public key +- `pairing token`: an out-of-band secret that authorizes an `XX` pairing attempt +- `pairing_id`: the visible identifier derived from a pairing token and carried on `XX` records +- `session`: one live encrypted channel with directional keys and directional connection IDs +- `record`: one complete QLv2 wire unit +- `frame`: one logical item inside a session record +- `stream`: one duplex byte stream inside a session +- `route_id`: the application route carried once on the first initiator `StreamData` frame for a stream +- `stream origin`: the peer that opened the stream +- `origin lane`: bytes sent by the stream origin +- `return lane`: bytes sent back toward the stream origin + +## Record And Frame Wire Format + +QLv2 has two record types: + +- `handshake record`: used only during setup +- `session record`: used after the handshake completes + +Handshake records are large because they carry ML-KEM material. Session records are small and can carry multiple frames, including frames for different streams. + +All whole-record sizes below include the outer 2-byte record header: `version` plus `record type`. + +QLv2 uses QUIC-style variable-length integers for several steady-state fields. A varint is 1, 2, 4, or 8 bytes and can represent values in the range `0..2^62-1`. This keeps small values compact while allowing very large record and stream number spaces. + +Today, varints are used for: + +- session record `seq` +- `Ack.largest_acked` +- `Ack.block_count` +- `Ack.first_range_len` +- `Ack.gap` +- `Ack.range_len` +- `StreamData.stream_id` +- `StreamData.offset` +- `StreamData.route_id` when present +- `StreamData.bytes_len` +- `StreamWindow.stream_id` +- `StreamWindow.maximum_offset` +- `StreamClose.stream_id` + +### Handshake records + +QLv2 has two routed known-peer handshakes and one pairing handshake: + +- `IK` and `KK` carry a visible `sender` and `recipient` QID +- `XX` carries a visible `pairing_id` + +#### IK + +Used when the initiator already knows the responder bundle. + +| Record | Size | Purpose | +| --- | ---: | --- | +| `IK1` | 4785 bytes | start a handshake toward a known responder | +| `IK2` | 3195 bytes | complete `IK` and establish the session | + +#### KK + +Used when both peers already know each other. + +| Record | Size | Purpose | +| --- | ---: | --- | +| `KK1` | 3179 bytes | start a handshake between already-known peers | +| `KK2` | 3195 bytes | complete `KK` and establish the session | + +#### XX + +Used when the initiator has received an out of band pairing token, and neither peer knows each other. + +| Record | Size | Purpose | +| --- | ---: | --- | +| `XX1` | 1595 bytes | start pairing | +| `XX2` | 3201 bytes | send responder static identity and ciphertext | +| `XX3` | 3217 bytes | send initiator static identity and ciphertext | +| `XX4` | 1611 bytes | complete `XX` and establish the session | + +### Session records + +`session record size = 35..42 + sum(frame sizes)` + +There is no explicit AEAD nonce on the wire. The record `seq` is used to derive the nonce. + +| Fixed part | Size | Purpose | +| --- | ---: | --- | +| version | 1 byte | protocol version | +| record type | 1 byte | identifies a session record | +| `connection_id` | 16 bytes | route the record to the current session | +| `seq` | 1..8 bytes | varint record identity for ack and retransmit | +| AEAD auth tag | 16 bytes | authenticate the encrypted body | +| fixed overhead total | 35..42 bytes | overhead before any frames | + +The visible session header is authenticated as AEAD AAD but is not encrypted. + +### Session frames + +| Frame | Size | Purpose | +| --- | ---: | --- | +| `Ping` | 1 byte | keep the session alive when idle | +| `Unpair` | 1 byte | forget the currently bound peer and abort the session | +| `Ack` | `4+` bytes | acknowledge received session records with ACK ranges | +| `StreamWindow` | `3..17` bytes | extend per-stream send credit | +| `StreamClose` | `5..12` bytes | abort one stream lane or both lanes | +| `Close` | 3 bytes | close the whole session | +| `StreamData` | `5..34 + payload_len` bytes | carry stream bytes, optional opener route, and optional `fin` | + +`StreamData` is the main steady-state frame: + +`1 kind + varint(stream_id) + varint(offset) + 1 flags + optional varint(route_id) + varint(bytes_len) + payload_len` + +The flags byte carries: + +- `fin` +- `header present` + +Some useful minimum whole-record sizes for single-frame records: + +| Record | Size | Meaning | +| --- | ---: | --- | +| `Ping` only | 36 bytes | idle keepalive | +| `Unpair` only | 36 bytes | peer unpair | +| `Ack` only | 39 bytes | smallest selective ack | +| `Close` only | 38 bytes | session shutdown | +| empty `StreamData` without route header | 40 bytes | empty data or empty `fin` on an existing stream | +| empty opener `StreamData` with a 1-byte `route_id` | 41 bytes | open a new stream without payload bytes | + +## Handshake + +QLv2 currently supports three Noise-style handshake patterns: + +- `IK`: 2 messages, initiator already knows the responder bundle +- `KK`: 2 messages, both peers already know each other +- `XX`: 4 messages, peers authenticate through an out-of-band pairing token and exchange static identity during the handshake + +The handshake covers peer authentication and session establishment. + +Each successful handshake does five things: + +1. authenticate which peer we are talking to +2. derive a fresh transmit key and receive key +3. derive a directional transmit `connection_id` and receive `connection_id` +4. bind transport parameters into the transcript +5. produce a `handshake_hash` for the completed exchange + +Today the only transport parameter is: + +- initial per-stream receive window + +Future transport parameters could include session-wide byte credit or record-size limits. + +Each handshake attempt carries: + +- `handshake_id`: identifies one attempt and lets stale replies be ignored +- transport parameters + +`valid_until` is not currently part of the wire format. Handshake attempts instead expire by local timer. + +### Pattern summary + +- `IK` lets the responder learn the initiator during handshake completion. The initiator still needs the responder bundle before it can start. +- `KK` requires both peers to already know each other. +- `XX` requires the responder to be armed for pairing and to recognize the visible `pairing_id` derived from the expected pairing token. + +### Handshake rules + +- attempts are identified by `handshake_id` +- handshake messages are not retransmitted in place +- simultaneous starts must converge deterministically +- if `IK` and `KK` race, `IK` wins +- same-pattern races break ties by ordering the initial ephemeral public keys +- `XX` requires out-of-band authorization and uses visible `pairing_id` for lookup + +### Session establishment points + +- `IK` and `KK` complete after message 2 (1 RT) +- `XX` completes after 4 messages (2 RTT) + +## Session Model + +After the handshake, peers exchange encrypted session records. + +Each session record has: + +- one visible `connection_id` +- one visible `seq` +- one encrypted body containing one or more frames + +One session record may carry: + +- only control frames +- only stream data +- a mixture of frames for multiple streams + +This is the core steady-state model: records are the encrypted transport unit, frames are the logical items inside them. + +## Acknowledgment And Retransmission + +`Ack` is record-level, not stream-level. + +An `Ack` means the peer: + +- received that session record +- decrypted it with the current session key +- accepted its `seq` + +The ACK wire format is range-based, not bitmap-based. It carries: + +- `largest_acked` +- `block_count` +- `first_range_len` +- zero or more `(gap, range_len)` blocks + +Ranges are encoded from highest sequence numbers down to lowest sequence numbers. + +Receivers track a recent accepted record window so they can: + +- reject duplicates +- ignore records that are too old +- emit selective ACK ranges + +Pending ACK state is also range-based. If there are too many disjoint ranges, older low ranges may be dropped. An emitted ACK may also be truncated by the remaining record budget. + +Retransmission works at the frame level: + +- every emitted session record gets a fresh `seq` +- retransmit timers start only after the local transport confirms that it accepted the write +- if a record is considered lost, the FSM restores its frames +- those frames are packed into a new record with a new `seq` + +QLv2 does not resend the same logical record identity. + +There is no explicit `Nack` frame. Loss is inferred from timeout or from later ACK state that no longer includes a record. + +Pure ACK-only records are fire-and-forget: they are not themselves retransmitted. + +Example: + +`seq = 10` + +| Frame | Contents | +| --- | --- | +| `StreamData` | `stream_id=4 offset=0 bytes="hello"` | + +The sender receives more bytes for that stream before `seq = 10` is acked: + +| Pending new frame | Contents | +| --- | --- | +| `StreamData` | `stream_id=4 offset=5 bytes=" world"` | + +If `seq = 10` is considered lost, its frame is restored and packed again with a new record sequence: + +`seq = 11` + +| Frame | Contents | +| --- | --- | +| `StreamData` | `stream_id=4 offset=0 bytes="hello"` | +| `StreamData` | `stream_id=4 offset=5 bytes=" world"` | + +## Streams + +Streams are the application primitive. + +A stream has two independent lanes: + +- origin lane +- return lane + +Important properties: + +- either peer can open a stream +- stream IDs are split by parity derived from QID ordering, so both peers can open streams without collision +- stream IDs increase monotonically within each parity namespace and must not repeat within a session +- ordering is preserved within a stream lane +- different streams can make progress independently +- record loss on one stream does not block unrelated streams + +There is no separate open frame. + +Locally, opening a stream allocates: + +- a new `stream_id` +- an application `route_id` + +On the wire, the stream opener carries that `route_id` once, in the first initiator `StreamData` frame at `offset = 0`, using the optional `StreamHeader`. + +`StreamData` carries: + +- `stream_id` +- `offset` +- optional `StreamHeader { route_id }` +- `fin` +- bytes + +`StreamHeader` is only valid on the first initiator `StreamData` frame for a stream, at `offset = 0`. + +`fin` is graceful completion of one lane. It says "no more bytes on this lane" without aborting the other lane. + +## Flow Control + +Flow control is per stream. + +During the handshake, each peer advertises an initial per-stream receive window. That becomes the initial send credit the remote peer can use on each stream. + +`StreamWindow` extends that credit by advertising a larger absolute `maximum_offset`. + +In practice, a stream is writable only when both are true: + +- local send buffering has room +- peer-advertised stream credit allows more bytes + +Receive credit advances when the local application commits read bytes, not merely when bytes become readable. That is when the FSM emits a `StreamWindow` update. + +## Close And Liveness + +`StreamClose` aborts a stream early. Semantically it can target: + +- the origin lane +- the return lane +- both lanes + +`Close` aborts the whole session. + +`Unpair` is stronger than `Close`: + +- it forgets the currently bound peer locally +- it aborts the active session immediately +- it may emit one final outbound `Unpair` frame +- reconnect does not resume until a peer is paired again + +Idle sessions may send `Ping`. The peer does not answer with another ping; normal record acknowledgment is enough. + +Sessions also have local timers for: + +- handshake timeout +- delayed ack emission +- session record retransmit timeout +- keepalive ping interval +- peer silence timeout + +If peer silence exceeds the configured timeout, the session closes with timeout. + +## Security Properties + +The current handshake family is ML-KEM-based and post-quantum focused. + +Session payloads are encrypted and authenticated. The session header stays visible so the receiver can route the record, but it is still authenticated as AEAD AAD. + +QLv2 also provides forward secrecy in the following sense: even if an attacker later obtains a peer's long-term ML-KEM private key, they still cannot decrypt messages from earlier completed sessions. From 71702fbfe3489494f8ba2d44d566f0a0b5060f0e Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 4 Jun 2026 09:22:32 -0400 Subject: [PATCH 2/6] ql-wire: add wire-format definitions --- Cargo.lock | 197 ++++ Cargo.toml | 10 +- ql-wire/Cargo.toml | 27 + ql-wire/src/bytes.rs | 175 ++++ ql-wire/src/codec.rs | 245 +++++ ql-wire/src/crypto.rs | 47 + ql-wire/src/encrypted/ack.rs | 453 +++++++++ ql-wire/src/encrypted/builder.rs | 172 ++++ ql-wire/src/encrypted/close.rs | 55 ++ ql-wire/src/encrypted/mod.rs | 178 ++++ ql-wire/src/encrypted/route_id.rs | 55 ++ ql-wire/src/encrypted/stream_close.rs | 116 +++ ql-wire/src/encrypted/stream_data.rs | 135 +++ ql-wire/src/encrypted/stream_id.rs | 35 + ql-wire/src/encrypted/stream_window.rs | 29 + ql-wire/src/encrypted_message.rs | 97 ++ ql-wire/src/error.rs | 33 + ql-wire/src/handshake/ik.rs | 376 ++++++++ ql-wire/src/handshake/kk.rs | 352 +++++++ ql-wire/src/handshake/meta.rs | 48 + ql-wire/src/handshake/mod.rs | 590 ++++++++++++ ql-wire/src/handshake/pairing.rs | 83 ++ ql-wire/src/handshake/transport_params.rs | 38 + ql-wire/src/handshake/xx.rs | 612 +++++++++++++ ql-wire/src/header.rs | 121 +++ ql-wire/src/identity.rs | 192 ++++ ql-wire/src/lib.rs | 45 + ql-wire/src/nonce.rs | 13 + ql-wire/src/pq.rs | 159 ++++ ql-wire/src/qid.rs | 44 + ql-wire/src/record.rs | 254 +++++ ql-wire/src/testing.rs | 181 ++++ ql-wire/src/tests.rs | 1017 +++++++++++++++++++++ ql-wire/src/varint.rs | 181 ++++ 34 files changed, 6364 insertions(+), 1 deletion(-) create mode 100644 ql-wire/Cargo.toml create mode 100644 ql-wire/src/bytes.rs create mode 100644 ql-wire/src/codec.rs create mode 100644 ql-wire/src/crypto.rs create mode 100644 ql-wire/src/encrypted/ack.rs create mode 100644 ql-wire/src/encrypted/builder.rs create mode 100644 ql-wire/src/encrypted/close.rs create mode 100644 ql-wire/src/encrypted/mod.rs create mode 100644 ql-wire/src/encrypted/route_id.rs create mode 100644 ql-wire/src/encrypted/stream_close.rs create mode 100644 ql-wire/src/encrypted/stream_data.rs create mode 100644 ql-wire/src/encrypted/stream_id.rs create mode 100644 ql-wire/src/encrypted/stream_window.rs create mode 100644 ql-wire/src/encrypted_message.rs create mode 100644 ql-wire/src/error.rs create mode 100644 ql-wire/src/handshake/ik.rs create mode 100644 ql-wire/src/handshake/kk.rs create mode 100644 ql-wire/src/handshake/meta.rs create mode 100644 ql-wire/src/handshake/mod.rs create mode 100644 ql-wire/src/handshake/pairing.rs create mode 100644 ql-wire/src/handshake/transport_params.rs create mode 100644 ql-wire/src/handshake/xx.rs create mode 100644 ql-wire/src/header.rs create mode 100644 ql-wire/src/identity.rs create mode 100644 ql-wire/src/lib.rs create mode 100644 ql-wire/src/nonce.rs create mode 100644 ql-wire/src/pq.rs create mode 100644 ql-wire/src/qid.rs create mode 100644 ql-wire/src/record.rs create mode 100644 ql-wire/src/testing.rs create mode 100644 ql-wire/src/tests.rs create mode 100644 ql-wire/src/varint.rs diff --git a/Cargo.lock b/Cargo.lock index f144305..c2c3b23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -532,6 +532,17 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "core-models" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "657f625ff361906f779745d08375ae3cc9fef87a35fba5f22874cf773010daf4" +dependencies = [ + "hax-lib", + "pastey", + "rand 0.9.2", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -1081,6 +1092,43 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "hax-lib" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "543f93241d32b3f00569201bfce9d7a93c92c6421b23c77864ac929dc947b9fc" +dependencies = [ + "hax-lib-macros", + "num-bigint", + "num-traits", +] + +[[package]] +name = "hax-lib-macros" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8755751e760b11021765bb04cb4a6c4e24742688d9f3aa14c2079638f537b0f" +dependencies = [ + "hax-lib-macros-types", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "hax-lib-macros-types" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f177c9ae8ea456e2f71ff3c1ea47bf4464f772a05133fcbba56cd5ba169035a2" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "serde_json", + "uuid", +] + [[package]] name = "hermit-abi" version = "0.5.2" @@ -1356,6 +1404,84 @@ version = "0.2.175" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" +[[package]] +name = "libcrux-aesgcm" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99f2a019dab4097585a7d4f5b9deebe46cd1e628b16a5bc4cb0ce35e1da334e6" +dependencies = [ + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-secrets", + "libcrux-traits", +] + +[[package]] +name = "libcrux-intrinsics" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1b5db005ff8001e026b73a6842ee81bbef8ec5ff0e1915a67ae65fd2a9fafa5" +dependencies = [ + "core-models", + "hax-lib", +] + +[[package]] +name = "libcrux-ml-kem" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aca7de713c6dddcf7aaf76e8ef9dc0097c8d7ce23a8eadf04c8761734714e184" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-secrets", + "libcrux-sha3", + "libcrux-traits", + "rand 0.9.2", + "tls_codec", +] + +[[package]] +name = "libcrux-platform" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d9e21d7ed31a92ac539bd69a8c970b183ee883872d2d19ce27036e24cb8ecc4" +dependencies = [ + "libc", +] + +[[package]] +name = "libcrux-secrets" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ce650f3041b44ba40d4263852347d007cd2cd9d1cc856a6f6c8b2e10c3fd40b" +dependencies = [ + "hax-lib", +] + +[[package]] +name = "libcrux-sha3" +version = "0.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c50f6e04a184511b782c5cc1eb6a227c6d36f2c935e93d698655a93a99696b5" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-traits", +] + +[[package]] +name = "libcrux-traits" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e4fa89f3f5e34b47f928b22b1b78395a0d4ec23b1f583db635f128159d65f" +dependencies = [ + "libcrux-secrets", + "rand 0.9.2", +] + [[package]] name = "libm" version = "0.2.15" @@ -1469,6 +1595,16 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-bigint-dig" version = "0.8.6" @@ -1625,6 +1761,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pastey" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee67f1008b1ba2321834326597b8e186293b049a023cdef258527550b9935b4" + [[package]] name = "pbkdf2" version = "0.12.2" @@ -1810,6 +1952,28 @@ dependencies = [ "elliptic-curve", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -1863,6 +2027,17 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "ql-wire" +version = "0.1.0" +dependencies = [ + "bytes", + "getrandom 0.2.16", + "libcrux-aesgcm", + "libcrux-ml-kem", + "sha2", +] + [[package]] name = "quantum-link-macros" version = "0.1.0" @@ -2436,6 +2611,27 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tls_codec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de2e01245e2bb89d6f05801c564fa27624dbd7b1846859876c7dad82e90bf6b" +dependencies = [ + "tls_codec_derive", + "zeroize", +] + +[[package]] +name = "tls_codec_derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "tokio" version = "1.47.1" @@ -2517,6 +2713,7 @@ version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ + "getrandom 0.3.3", "js-sys", "wasm-bindgen", ] diff --git a/Cargo.toml b/Cargo.toml index 0fd0e75..8aad910 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,12 @@ [workspace] resolver = "2" -members = ["api", "backup-shard", "btp", "quantum-link-macros"] +members = [ + "api", + "backup-shard", + "btp", + "ql-wire", + "quantum-link-macros", +] [workspace.package] homepage = "https://github.com/Foundation-Devices/foundation-api" @@ -14,6 +20,7 @@ dcbor = { version = "0.23.3" } gstp = { version = "0.11.0" } chrono = "0.4" +bytes = "1" getrandom = { version = "0.2" } insta = { version = "1.43.2" } thiserror = { version = "2" } @@ -24,6 +31,7 @@ backup-shard = { path = "backup-shard" } btp = { path = "btp" } foundation-api = { path = "api" } quantum-link-macros = { path = "quantum-link-macros" } +ql-wire = { path = "ql-wire" } [patch.crates-io] pqcrypto-traits = { git = "https://github.com/Foundation-Devices/pqcrypto", rev = "ebadf71214f67cb970242fa1053b4acb65767737" } diff --git a/ql-wire/Cargo.toml b/ql-wire/Cargo.toml new file mode 100644 index 0000000..399846c --- /dev/null +++ b/ql-wire/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "ql-wire" +version = "0.1.0" +edition = "2021" +description = "QuantumLink protocol wire format" +license = "Proprietary" + +[features] +test-utils = [ + "dep:getrandom", + "dep:libcrux-aesgcm", + "dep:libcrux-ml-kem", + "dep:sha2", +] + +[dependencies] +bytes = { workspace = true } +getrandom = { workspace = true, optional = true } +libcrux-aesgcm = { version = "0.0.7", optional = true } +libcrux-ml-kem = { version = "0.0.7", optional = true } +sha2 = { version = "0.10", optional = true } + +[dev-dependencies] +getrandom = { workspace = true } +libcrux-aesgcm = "0.0.7" +libcrux-ml-kem = "0.0.7" +sha2 = "0.10" diff --git a/ql-wire/src/bytes.rs b/ql-wire/src/bytes.rs new file mode 100644 index 0000000..9fecf5e --- /dev/null +++ b/ql-wire/src/bytes.rs @@ -0,0 +1,175 @@ +use core::ops::{Deref, DerefMut}; + +use bytes::{Buf, Bytes}; + +/// A mutable or immutable byte slice owner used by the wire parser. +pub trait ByteSlice: Deref + Sized { + /// Splits the current byte view at `mid`. + /// + /// Returns `Err(self)` when `mid` is out of bounds. + fn split_at(self, mid: usize) -> Result<(Self, Self), Self>; +} + +/// A mutable reference to bytes. +pub trait ByteSliceMut: ByteSlice + DerefMut {} + +impl ByteSliceMut for B where B: ByteSlice + DerefMut {} + +impl ByteSlice for &[u8] { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok(<[u8]>::split_at(self, mid)) + } else { + Err(self) + } + } +} + +impl ByteSlice for &mut [u8] { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok(<[u8]>::split_at_mut(self, mid)) + } else { + Err(self) + } + } +} + +impl ByteSlice for Bytes { + #[inline] + fn split_at(self, mid: usize) -> Result<(Self, Self), Self> { + if mid <= self.len() { + Ok((self.slice(..mid), self.slice(mid..))) + } else { + Err(self) + } + } +} + +/// A byte container that can expose a replayable [`Buf`] view for encoding. +pub trait BufView { + type Buf<'a>: Buf + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_>; + + fn is_empty(&self) -> bool { + self.buf().remaining() == 0 + } +} + +impl BufView for &T { + type Buf<'a> + = T::Buf<'a> + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + (*self).buf() + } +} + +impl BufView for &mut T { + type Buf<'a> + = T::Buf<'a> + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + (**self).buf() + } +} + +impl BufView for [u8] { + type Buf<'a> + = &'a [u8] + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + self + } +} + +impl BufView for [u8; N] { + type Buf<'a> + = &'a [u8] + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + self.as_slice() + } +} + +impl BufView for Vec { + type Buf<'a> + = &'a [u8] + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + self.as_slice() + } +} + +impl BufView for Bytes { + type Buf<'a> + = &'a [u8] + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + self.as_ref() + } +} + +#[cfg(test)] +mod tests { + use bytes::Buf; + + use super::{BufView, ByteSlice, ByteSliceMut}; + + #[test] + fn shared_slice_split_at() { + let bytes: &[u8] = b"abcdef"; + let (left, right) = ByteSlice::split_at(bytes, 2).unwrap(); + assert_eq!(left, b"ab"); + assert_eq!(right, b"cdef"); + } + + #[test] + fn mutable_slice_split_at() { + let mut bytes = *b"abcdef"; + let (left, right) = ByteSlice::split_at(&mut bytes[..], 2).unwrap(); + assert_eq!(left, b"ab"); + assert_eq!(right, b"cdef"); + } + + #[test] + fn mutable_split_trait_is_implemented() { + fn assert_split_mut(_value: T) {} + + let mut bytes = [0u8; 4]; + assert_split_mut(&mut bytes[..]); + } + + #[test] + fn split_at_rejects_out_of_bounds_index() { + let bytes: &[u8] = b"abcdef"; + assert!(ByteSlice::split_at(bytes, 7).is_err()); + } + + #[test] + fn slice_buf_view_is_contiguous() { + let bytes: &[u8] = b"abcdef"; + let mut buf = bytes.buf(); + assert_eq!(buf.remaining(), 6); + assert_eq!(buf.chunk(), b"abcdef"); + buf.advance(6); + assert!(!buf.has_remaining()); + } +} diff --git a/ql-wire/src/codec.rs b/ql-wire/src/codec.rs new file mode 100644 index 0000000..0245ef6 --- /dev/null +++ b/ql-wire/src/codec.rs @@ -0,0 +1,245 @@ +use bytes::BufMut; + +use crate::{ByteSlice, WireError}; + +pub trait WireEncode { + fn encoded_len(&self) -> usize; + + fn encode(&self, out: &mut W); + + fn encode_vec(&self) -> Vec { + let mut out = Vec::with_capacity(self.encoded_len()); + self.encode(&mut out); + debug_assert_eq!(out.len(), self.encoded_len()); + out + } +} + +pub trait WireDecode: Sized { + fn decode(reader: &mut Reader) -> Result; + + fn decode_bytes(bytes: B) -> Result { + let mut reader = Reader::new(bytes); + Self::decode(&mut reader) + } + + fn decode_exact(bytes: B) -> Result { + let mut reader = Reader::new(bytes); + let value = Self::decode(&mut reader)?; + if reader.is_empty() { + Ok(value) + } else { + Err(WireError::InvalidPayload) + } + } +} + +impl WireDecode for [u8; N] { + fn decode(reader: &mut Reader) -> Result { + let bytes = reader.take_bytes(N)?; + let mut out = [0u8; N]; + out.copy_from_slice(&bytes); + Ok(out) + } +} + +impl WireEncode for [u8; N] { + fn encoded_len(&self) -> usize { + N + } + + fn encode(&self, out: &mut W) { + out.put_slice(self); + } +} + +impl WireDecode for Box<[u8; N]> { + fn decode(reader: &mut Reader) -> Result { + let bytes = reader.take_bytes(N)?; + let mut out = Self::new_uninit(); + let src = bytes.as_ptr(); + let dst = out.as_mut_ptr().cast::(); + // SAFETY: `take_bytes(N)` guarantees the source has exactly `N` bytes. + unsafe { + std::ptr::copy_nonoverlapping(src, dst, N); + Ok(out.assume_init()) + } + } +} + +impl WireEncode for Box<[u8; N]> { + fn encoded_len(&self) -> usize { + N + } + + fn encode(&self, out: &mut W) { + out.put_slice(self.as_ref()); + } +} + +impl WireEncode for [u8] { + fn encoded_len(&self) -> usize { + self.len() + } + + fn encode(&self, out: &mut W) { + out.put_slice(self); + } +} + +impl WireDecode for u8 { + fn decode(reader: &mut Reader) -> Result { + Ok(reader.take_bytes(1)?[0]) + } +} + +impl WireEncode for u8 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(*self); + } +} + +impl WireDecode for u16 { + fn decode(reader: &mut Reader) -> Result { + Ok(Self::from_be_bytes(reader.decode()?)) + } +} + +impl WireEncode for u16 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u16(*self); + } +} + +impl WireDecode for u32 { + fn decode(reader: &mut Reader) -> Result { + Ok(Self::from_be_bytes(reader.decode()?)) + } +} + +impl WireEncode for u32 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u32(*self); + } +} + +impl WireDecode for u64 { + fn decode(reader: &mut Reader) -> Result { + Ok(Self::from_be_bytes(reader.decode()?)) + } +} + +impl WireEncode for u64 { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u64(*self); + } +} + +impl WireDecode for bool { + fn decode(reader: &mut Reader) -> Result { + match reader.decode::()? { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl WireEncode for bool { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(u8::from(*self)); + } +} + +impl WireEncode for Option { + fn encoded_len(&self) -> usize { + 1 + self.as_ref().map_or(0, WireEncode::encoded_len) + } + + fn encode(&self, out: &mut W) { + match self { + None => out.put_u8(0), + Some(inner) => { + out.put_u8(1); + inner.encode(out); + } + } + } +} + +impl> WireDecode for Option { + fn decode(reader: &mut Reader) -> Result { + match reader.decode::()? { + 0 => Ok(None), + 1 => Ok(Some(reader.decode::()?)), + _ => Err(WireError::InvalidPayload), + } + } +} + +#[derive(Clone)] +pub struct Reader { + remaining: Option, +} + +impl Reader { + pub fn new(bytes: B) -> Self { + Self { + remaining: Some(bytes), + } + } + + pub fn is_empty(&self) -> bool { + self.remaining.as_ref().unwrap().is_empty() + } + + pub fn remaining_len(&self) -> usize { + self.remaining.as_ref().unwrap().len() + } + + pub fn take_bytes(&mut self, len: usize) -> Result { + let remaining = self.remaining.take().unwrap(); + match remaining.split_at(len) { + Ok((head, tail)) => { + self.remaining = Some(tail); + Ok(head) + } + Err(remaining) => { + self.remaining = Some(remaining); + Err(WireError::InvalidPayload) + } + } + } + + pub fn take_rest(&mut self) -> B { + self.take_bytes(self.remaining_len()).unwrap() + } + + #[inline] + pub fn decode(&mut self) -> Result + where + T: WireDecode, + { + T::decode(self) + } +} diff --git a/ql-wire/src/crypto.rs b/ql-wire/src/crypto.rs new file mode 100644 index 0000000..96ace38 --- /dev/null +++ b/ql-wire/src/crypto.rs @@ -0,0 +1,47 @@ +use crate::{ + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, SessionKey, + ENCRYPTED_MESSAGE_AUTH_SIZE, +}; + +pub trait QlRandom { + fn fill_random_bytes(&self, out: &mut [u8]); +} + +pub trait QlHash { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32]; +} + +pub trait QlAead { + fn aes256_gcm_encrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + + fn aes256_gcm_decrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool; +} + +pub trait QlKem { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair; + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey); + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + ciphertext: &MlKemCiphertext, + ) -> SessionKey; +} + +pub trait QlCrypto: QlRandom + QlHash + QlAead + QlKem {} + +impl QlCrypto for T where T: QlRandom + QlHash + QlAead + QlKem {} diff --git a/ql-wire/src/encrypted/ack.rs b/ql-wire/src/encrypted/ack.rs new file mode 100644 index 0000000..2eb34b3 --- /dev/null +++ b/ql-wire/src/encrypted/ack.rs @@ -0,0 +1,453 @@ +use std::{fmt, ops::RangeInclusive}; + +use crate::{codec, ByteSlice, RecordSeq, VarInt, WireEncode, WireError}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RecordAck { + largest_acked: RecordSeq, + first_range_len: VarInt, + blocks: Box<[RecordAckBlock]>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RecordAckBlock { + pub gap: VarInt, + pub range_len: VarInt, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RecordAckRangeError { + Empty, + InvertedRange, + NotCanonical, +} + +impl RecordAck { + /// Build a record ACK from canonical ranges ordered from highest to lowest sequence number. + /// + /// Ranges must be: + /// - non-empty + /// - individually valid (`start <= end`) + /// - strictly descending + /// - separated by at least one missing sequence number + pub fn from_ranges(ranges: I) -> Result + where + I: IntoIterator>, + { + let mut builder = RecordAckBuilder::new(); + for range in ranges { + let pushed = builder.try_push_range(range, usize::MAX)?; + if !pushed { + unreachable!("record ack should fit inside usize::MAX"); + } + } + builder.build() + } + + pub fn ranges(&self) -> RecordAckRangeIter<'_> { + RecordAckRangeIter { + largest_acked: self.largest_acked.into_inner(), + first_range_len: Some(self.first_range_len), + previous_start: None, + blocks: self.blocks.iter(), + } + } + + pub fn contains(&self, seq: u64) -> bool { + let Ok(seq) = RecordSeq::from_u64(seq) else { + return false; + }; + self.ranges().any(|range| range.contains(&seq)) + } + + fn block_count_len(block_count: usize) -> usize { + VarInt::try_from(block_count).unwrap().encoded_len() + } +} + +impl RecordAckBlock { + fn encoded_len(&self) -> usize { + self.gap.encoded_len() + self.range_len.encoded_len() + } +} + +impl fmt::Display for RecordAckRangeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Empty => f.write_str("record ack requires at least one acknowledged range"), + Self::InvertedRange => { + f.write_str("record ack range start must be less than or equal to end") + } + Self::NotCanonical => f.write_str( + "record ack ranges must be passed in descending, disjoint order with a gap between adjacent ranges", + ), + } + } +} + +impl std::error::Error for RecordAckRangeError {} + +pub struct RecordAckRangeIter<'a> { + largest_acked: u64, + first_range_len: Option, + previous_start: Option, + blocks: std::slice::Iter<'a, RecordAckBlock>, +} + +impl Iterator for RecordAckRangeIter<'_> { + type Item = RangeInclusive; + + fn next(&mut self) -> Option { + if let Some(first_range_len) = self.first_range_len.take() { + let end = self.largest_acked; + let start = end - first_range_len.into_inner(); + self.previous_start = Some(start); + return Some(RecordSeq::from_u64(start).unwrap()..=RecordSeq::from_u64(end).unwrap()); + } + + let block = self.blocks.next()?; + let previous_start = self + .previous_start + .expect("first ack range is always yielded"); + // gap is encoded as missing_count - 1, so decoding steps back by gap + 2. + let end = previous_start - block.gap.into_inner() - 2; + let start = end - block.range_len.into_inner(); + self.previous_start = Some(start); + Some(RecordSeq::from_u64(start).unwrap()..=RecordSeq::from_u64(end).unwrap()) + } +} + +impl WireEncode for RecordAck { + fn encoded_len(&self) -> usize { + self.largest_acked.encoded_len() + + Self::block_count_len(self.blocks.len()) + + self.first_range_len.encoded_len() + + self + .blocks + .iter() + .map(RecordAckBlock::encoded_len) + .sum::() + } + + fn encode(&self, out: &mut W) { + self.largest_acked.encode(out); + VarInt::try_from(self.blocks.len()).unwrap().encode(out); + self.first_range_len.encode(out); + for block in &self.blocks { + block.gap.encode(out); + block.range_len.encode(out); + } + } +} + +impl codec::WireDecode for RecordAck { + fn decode(reader: &mut codec::Reader) -> Result { + let largest_acked = reader.decode()?; + let block_count = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + let first_range_len = reader.decode::()?; + let mut blocks = Vec::with_capacity(block_count); + for _ in 0..block_count { + blocks.push(RecordAckBlock { + gap: reader.decode::()?, + range_len: reader.decode::()?, + }); + } + + let ack = Self { + largest_acked, + first_range_len, + blocks: blocks.into_boxed_slice(), + }; + + // validate + { + let mut previous_start = ack + .largest_acked + .into_inner() + .checked_sub(ack.first_range_len.into_inner()) + .ok_or(WireError::InvalidPayload)?; + + for block in &ack.blocks { + let end = previous_start + .checked_sub( + block + .gap + .into_inner() + .checked_add(2) + .ok_or(WireError::InvalidPayload)?, + ) + .ok_or(WireError::InvalidPayload)?; + previous_start = end + .checked_sub(block.range_len.into_inner()) + .ok_or(WireError::InvalidPayload)?; + } + } + Ok(ack) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct RecordAckBuilder { + largest_acked: Option, + first_range_len: Option, + blocks: Vec, + previous_start: Option, + wire_len: usize, +} + +impl RecordAckBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn try_push_range( + &mut self, + range: RangeInclusive, + max_wire_size: usize, + ) -> Result { + let start = range.start().into_inner(); + let end = range.end().into_inner(); + if start > end { + return Err(RecordAckRangeError::InvertedRange); + } + + let range_len = VarInt::from_u64(end - start).unwrap(); + if let Some(previous_start) = self.previous_start { + if end.saturating_add(1) >= previous_start { + return Err(RecordAckRangeError::NotCanonical); + } + + let gap = previous_start + .checked_sub(end) + .and_then(|delta| delta.checked_sub(2)) + .expect("canonical ack ranges stay separated by at least one sequence"); + let block = RecordAckBlock { + gap: VarInt::from_u64(gap).unwrap(), + range_len, + }; + let current_block_count_len = RecordAck::block_count_len(self.blocks.len()); + let next_block_count_len = RecordAck::block_count_len(self.blocks.len() + 1); + let next_wire_len = self.wire_len + + (next_block_count_len - current_block_count_len) + + block.encoded_len(); + if next_wire_len > max_wire_size { + return Ok(false); + } + + self.previous_start = Some(start); + self.wire_len = next_wire_len; + self.blocks.push(block); + return Ok(true); + } + + let largest_acked = RecordSeq::from_u64(end).unwrap(); + let wire_len = + largest_acked.encoded_len() + RecordAck::block_count_len(0) + range_len.encoded_len(); + if wire_len > max_wire_size { + return Ok(false); + } + + self.largest_acked = Some(largest_acked); + self.first_range_len = Some(range_len); + self.previous_start = Some(start); + self.wire_len = wire_len; + Ok(true) + } + + pub fn build(self) -> Result { + let Some(largest_acked) = self.largest_acked else { + return Err(RecordAckRangeError::Empty); + }; + + Ok(RecordAck { + largest_acked, + first_range_len: self.first_range_len.unwrap(), + blocks: self.blocks.into_boxed_slice(), + }) + } +} +#[cfg(test)] +mod tests { + use std::ops::RangeInclusive; + + use super::{RecordAck, RecordAckBlock, RecordAckBuilder, RecordAckRangeError}; + use crate::{RecordSeq, VarInt, WireDecode, WireEncode, WireError}; + + fn seq(value: u64) -> RecordSeq { + RecordSeq::from_u64(value).unwrap() + } + + fn ack_range(start: u64, end: u64) -> RangeInclusive { + seq(start)..=seq(end) + } + + fn varint(value: u64) -> VarInt { + VarInt::from_u64(value).unwrap() + } + + #[test] + fn encode_decode_round_trip() { + let ack = + RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) + .unwrap(); + let encoded = ack.encode_vec(); + + assert_eq!(RecordAck::decode_exact(encoded.as_slice()).unwrap(), ack); + } + + #[test] + fn wire_fields_match_gap_encoding() { + let ack = + RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) + .unwrap(); + + assert_eq!(ack.largest_acked, seq(100)); + assert_eq!(ack.first_range_len, varint(5)); + assert_eq!( + ack.blocks.as_ref(), + &[ + RecordAckBlock { + gap: varint(1), + range_len: varint(2), + }, + RecordAckBlock { + gap: varint(8), + range_len: varint(0), + } + ] + ); + } + + #[test] + fn builder_matches_from_ranges() { + let mut builder = RecordAckBuilder::new(); + assert!(builder + .try_push_range(ack_range(95, 100), usize::MAX) + .unwrap()); + assert!(builder + .try_push_range(ack_range(90, 92), usize::MAX) + .unwrap()); + assert!(builder + .try_push_range(ack_range(80, 80), usize::MAX) + .unwrap()); + + assert_eq!( + builder.build().unwrap(), + RecordAck::from_ranges([ack_range(95, 100), ack_range(90, 92), ack_range(80, 80)]) + .unwrap() + ); + } + + #[test] + fn builder_stops_when_budget_is_exhausted() { + let first_only = RecordAck::from_ranges([ack_range(95, 100)]).unwrap(); + let mut builder = RecordAckBuilder::new(); + + assert!(builder + .try_push_range(ack_range(95, 100), first_only.encoded_len()) + .unwrap()); + assert!(!builder + .try_push_range(ack_range(90, 92), first_only.encoded_len()) + .unwrap()); + assert_eq!(builder.build().unwrap(), first_only); + } + + #[test] + fn builder_rejects_non_canonical_ranges() { + let mut builder = RecordAckBuilder::new(); + assert!(builder + .try_push_range(ack_range(95, 100), usize::MAX) + .unwrap()); + assert_eq!( + builder.try_push_range(ack_range(90, 95), usize::MAX), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn rejects_unsorted_ranges() { + assert_eq!( + RecordAck::from_ranges([ack_range(90, 92), ack_range(95, 100)]), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn rejects_touching_ranges() { + assert_eq!( + RecordAck::from_ranges([ack_range(10, 12), ack_range(7, 9)]), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn rejects_overlapping_ranges() { + assert_eq!( + RecordAck::from_ranges([ack_range(10, 12), ack_range(8, 11)]), + Err(RecordAckRangeError::NotCanonical) + ); + } + + #[test] + fn contains_matches_range_membership() { + let ack = RecordAck::from_ranges([ + ack_range(150, 163), + ack_range(105, 110), + ack_range(100, 100), + ]) + .unwrap(); + + assert!(ack.contains(100)); + assert!(ack.contains(107)); + assert!(ack.contains(163)); + assert!(!ack.contains(99)); + assert!(!ack.contains(104)); + assert!(!ack.contains(164)); + } + + #[test] + fn empty_ack_is_rejected() { + assert_eq!(RecordAck::from_ranges([]), Err(RecordAckRangeError::Empty)); + } + + #[test] + fn inverted_range_is_rejected() { + assert_eq!( + RecordAck::from_ranges([ack_range(5, 4)]), + Err(RecordAckRangeError::InvertedRange) + ); + } + + #[test] + fn decode_rejects_underflowing_ack_blocks() { + let encoded = vec![ + 42, // largest_acked + 1, // block_count + 0, // first_range_len + 41, // gap: implies a missing run larger than largest_acked + 0, // range_len + ]; + + assert_eq!( + RecordAck::decode_exact(encoded.as_slice()), + Err(WireError::InvalidPayload) + ); + } + + #[test] + fn decode_rejects_truncated_payload() { + assert_eq!( + RecordAck::decode_exact(&[][..]), + Err(WireError::InvalidPayload) + ); + + let encoded = RecordAck::from_ranges([ack_range(42, 42)]) + .unwrap() + .encode_vec(); + assert_eq!( + RecordAck::decode_exact(&encoded[..encoded.len() - 1]), + Err(WireError::InvalidPayload) + ); + } +} diff --git a/ql-wire/src/encrypted/builder.rs b/ql-wire/src/encrypted/builder.rs new file mode 100644 index 0000000..4293323 --- /dev/null +++ b/ql-wire/src/encrypted/builder.rs @@ -0,0 +1,172 @@ +use bytes::BufMut; + +use super::{RecordAck, SessionClose, SessionFrame, StreamClose, StreamData, StreamWindow}; +use crate::{ + BufView, ConnectionId, Nonce, QlCrypto, RecordSeq, RecordType, SessionHeader, SessionKey, + WireEncode, QL_WIRE_VERSION, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionRecordBuilder { + seq: RecordSeq, + prefix_len: usize, + max_capacity: usize, + bytes: Vec, +} + +impl SessionRecordBuilder { + pub const MIN_CAPACITY: usize = 1 + + 1 + + ConnectionId::SIZE + + RecordSeq::MAX_ENCODED_LEN + + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn new(seq: RecordSeq, max_capacity: usize) -> Self { + let prefix_len = + 1 + 1 + ConnectionId::SIZE + seq.encoded_len() + crate::ENCRYPTED_MESSAGE_AUTH_SIZE; + assert!(max_capacity >= prefix_len); + Self { + seq, + prefix_len, + max_capacity, + bytes: Vec::new(), + } + } + + pub fn seq(&self) -> RecordSeq { + self.seq + } + + pub fn prefix_len(&self) -> usize { + self.prefix_len + } + + pub fn max_capacity(&self) -> usize { + self.max_capacity + } + + pub fn len(&self) -> usize { + self.bytes.len().saturating_sub(self.prefix_len) + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn remaining_capacity(&self) -> usize { + self.max_capacity + .saturating_sub(self.bytes.len().max(self.prefix_len)) + } + + pub fn bytes(&self) -> &[u8] { + self.bytes.get(self.prefix_len..).unwrap_or_default() + } + + pub fn push_ping(&mut self) -> bool { + self.push_empty_frame(super::SessionFrameKind::Ping) + } + + pub fn push_unpair(&mut self) -> bool { + self.push_empty_frame(super::SessionFrameKind::Unpair) + } + + pub fn push_ack(&mut self, ack: &RecordAck) -> bool { + self.push_frame_payload(super::SessionFrameKind::Ack, ack) + } + + pub fn push_stream_data(&mut self, frame: &StreamData) -> bool { + self.push_frame_payload(super::SessionFrameKind::StreamData, frame) + } + + pub fn push_stream_window(&mut self, frame: &StreamWindow) -> bool { + self.push_frame_payload(super::SessionFrameKind::StreamWindow, frame) + } + + pub fn push_stream_close(&mut self, frame: &StreamClose) -> bool { + self.push_frame_payload(super::SessionFrameKind::StreamClose, frame) + } + + pub fn push_close(&mut self, close: &SessionClose) -> bool { + self.push_frame_payload(super::SessionFrameKind::Close, close) + } + + pub fn push_frame(&mut self, frame: &SessionFrame) -> bool { + match frame { + SessionFrame::Ping => self.push_ping(), + SessionFrame::Unpair => self.push_unpair(), + SessionFrame::Ack(frame) => self.push_ack(frame), + SessionFrame::StreamData(frame) => self.push_stream_data(frame), + SessionFrame::StreamWindow(frame) => self.push_stream_window(frame), + SessionFrame::StreamClose(frame) => self.push_stream_close(frame), + SessionFrame::Close(close) => self.push_close(close), + } + } + + pub fn encrypt( + mut self, + crypto: &impl QlCrypto, + connection_id: ConnectionId, + session_key: &SessionKey, + ) -> Vec { + self.ensure_prefix_capacity(0); + let header = SessionHeader { + connection_id, + seq: self.seq, + }; + let aad = header.aad(); + let nonce = Nonce::from_counter(self.seq.into_inner()); + let auth = crypto.aes256_gcm_encrypt( + session_key, + &nonce, + &aad, + &mut self.bytes[self.prefix_len..], + ); + + let mut prefix = &mut self.bytes[..self.prefix_len]; + prefix[0] = QL_WIRE_VERSION; + prefix[1] = RecordType::Session as u8; + prefix = &mut prefix[2..]; + header.encode(&mut prefix); + auth.encode(&mut prefix); + debug_assert!(prefix.is_empty()); + self.bytes + } + + fn push_wire_size(&mut self, wire_size: usize, encode: impl FnOnce(&mut Vec)) -> bool { + if !self.can_push_len(wire_size) { + return false; + } + self.ensure_prefix_capacity(wire_size); + let start = self.bytes.len(); + encode(&mut self.bytes); + debug_assert_eq!(self.bytes.len(), start + wire_size); + true + } + + fn push_empty_frame(&mut self, kind: super::SessionFrameKind) -> bool { + self.push_wire_size(1, |out| out.put_u8(kind as u8)) + } + + fn push_frame_payload( + &mut self, + kind: super::SessionFrameKind, + payload: &T, + ) -> bool { + let payload_wire_size = payload.encoded_len(); + self.push_wire_size(1 + payload_wire_size, |out| { + out.put_u8(kind as u8); + payload.encode(out); + }) + } + + fn can_push_len(&self, len: usize) -> bool { + len <= self.remaining_capacity() + } + + fn ensure_prefix_capacity(&mut self, additional_body_len: usize) { + if self.bytes.is_empty() { + self.bytes.reserve(self.prefix_len + additional_body_len); + self.bytes.resize(self.prefix_len, 0); + } + } +} diff --git a/ql-wire/src/encrypted/close.rs b/ql-wire/src/encrypted/close.rs new file mode 100644 index 0000000..e0860d7 --- /dev/null +++ b/ql-wire/src/encrypted/close.rs @@ -0,0 +1,55 @@ +use crate::{codec, codec::Reader, ByteSlice, WireEncode, WireError}; + +/// closes the whole session immediately with a close code. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionClose { + pub code: SessionCloseCode, +} + +impl SessionClose { + pub const WIRE_SIZE: usize = size_of::(); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct SessionCloseCode(pub u16); + +impl SessionCloseCode { + pub const CANCELLED: Self = Self(0); + pub const PROTOCOL: Self = Self(1); + pub const TIMEOUT: Self = Self(2); +} + +impl WireEncode for SessionCloseCode { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for SessionCloseCode { + fn decode(reader: &mut Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl codec::WireDecode for SessionClose { + fn decode(reader: &mut Reader) -> Result { + Ok(Self { + code: reader.decode()?, + }) + } +} + +impl WireEncode for SessionClose { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.code.encode(out); + } +} diff --git a/ql-wire/src/encrypted/mod.rs b/ql-wire/src/encrypted/mod.rs new file mode 100644 index 0000000..563f9de --- /dev/null +++ b/ql-wire/src/encrypted/mod.rs @@ -0,0 +1,178 @@ +use crate::{ + codec, encrypted_message::EncryptedMessage, BufView, ByteSlice, Nonce, QlCrypto, Reader, + SessionHeader, SessionKey, WireDecode, WireEncode, WireError, +}; + +mod ack; +mod builder; +mod close; +mod route_id; +mod stream_close; +mod stream_data; +mod stream_id; +mod stream_window; + +pub use ack::*; +pub use builder::*; +pub use close::*; +pub use route_id::*; +pub use stream_close::*; +pub use stream_data::*; +pub use stream_id::*; +pub use stream_window::*; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionFrame { + // todo: do we need ping as explicit frame? + Ping, + Unpair, + Ack(RecordAck), + StreamData(StreamData), + StreamWindow(StreamWindow), + StreamClose(StreamClose), + Close(SessionClose), +} + +impl WireDecode for SessionFrame { + fn decode(reader: &mut Reader) -> Result { + let kind = reader.decode::()?; + let frame = match kind { + SessionFrameKind::Ping => Self::Ping, + SessionFrameKind::Unpair => Self::Unpair, + SessionFrameKind::Ack => Self::Ack(reader.decode::()?), + SessionFrameKind::StreamData => Self::StreamData(reader.decode::>()?), + SessionFrameKind::StreamWindow => Self::StreamWindow(reader.decode::()?), + SessionFrameKind::StreamClose => Self::StreamClose(reader.decode::()?), + SessionFrameKind::Close => Self::Close(reader.decode::()?), + }; + Ok(frame) + } +} + +impl SessionFrame { + fn kind(&self) -> SessionFrameKind { + match self { + Self::Ping => SessionFrameKind::Ping, + Self::Unpair => SessionFrameKind::Unpair, + Self::Ack(_) => SessionFrameKind::Ack, + Self::StreamData(_) => SessionFrameKind::StreamData, + Self::StreamWindow(_) => SessionFrameKind::StreamWindow, + Self::StreamClose(_) => SessionFrameKind::StreamClose, + Self::Close(_) => SessionFrameKind::Close, + } + } +} + +impl SessionFrame { + pub fn into_owned(self) -> SessionFrame> { + match self { + Self::Ping => SessionFrame::Ping, + Self::Unpair => SessionFrame::Unpair, + Self::Ack(frame) => SessionFrame::Ack(frame), + Self::StreamData(frame) => SessionFrame::StreamData(frame.into_owned()), + Self::StreamWindow(frame) => SessionFrame::StreamWindow(frame), + Self::StreamClose(frame) => SessionFrame::StreamClose(frame), + Self::Close(frame) => SessionFrame::Close(frame), + } + } +} + +impl WireEncode for SessionFrame { + fn encoded_len(&self) -> usize { + 1 + match self { + Self::Ping | Self::Unpair => 0, + Self::Ack(frame) => frame.encoded_len(), + Self::StreamData(frame) => frame.encoded_len(), + Self::StreamWindow(frame) => frame.encoded_len(), + Self::StreamClose(frame) => frame.encoded_len(), + Self::Close(frame) => frame.encoded_len(), + } + } + + fn encode(&self, out: &mut W) { + out.put_u8(self.kind() as u8); + match self { + Self::Ping | Self::Unpair => {} + Self::Ack(frame) => frame.encode(out), + Self::StreamData(frame) => frame.encode(out), + Self::StreamWindow(frame) => frame.encode(out), + Self::StreamClose(frame) => frame.encode(out), + Self::Close(frame) => frame.encode(out), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum SessionFrameKind { + Ping = 1, + Ack = 2, + StreamData = 3, + StreamWindow = 4, + StreamClose = 5, + Close = 6, + Unpair = 7, +} + +impl TryFrom for SessionFrameKind { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Ping), + 2 => Ok(Self::Ack), + 3 => Ok(Self::StreamData), + 4 => Ok(Self::StreamWindow), + 5 => Ok(Self::StreamClose), + 6 => Ok(Self::Close), + 7 => Ok(Self::Unpair), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl codec::WireDecode for SessionFrameKind { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +pub fn parse_session_frames(bytes: B) -> SessionFrameIter { + SessionFrameIter { + reader: Reader::new(bytes), + } +} + +pub fn decode_session_frames(bytes: &[u8]) -> Result>>, WireError> { + parse_session_frames(bytes) + .map(|frame| frame.map(SessionFrame::into_owned)) + .collect() +} + +#[derive(Clone)] +pub struct SessionFrameIter { + reader: Reader, +} + +impl Iterator for SessionFrameIter { + type Item = Result, WireError>; + + fn next(&mut self) -> Option { + if self.reader.is_empty() { + None + } else { + Some(self.reader.decode::>()) + } + } +} + +pub fn decrypt_record>( + crypto: &impl QlCrypto, + header: &SessionHeader, + encrypted: EncryptedMessage, + session_key: &SessionKey, +) -> Result { + let aad = header.aad(); + let nonce = Nonce::from_counter(header.seq.into_inner()); + encrypted.decrypt_in_place(crypto, session_key, &nonce, &aad) +} diff --git a/ql-wire/src/encrypted/route_id.rs b/ql-wire/src/encrypted/route_id.rs new file mode 100644 index 0000000..6b91a52 --- /dev/null +++ b/ql-wire/src/encrypted/route_id.rs @@ -0,0 +1,55 @@ +use crate::{ByteSlice, Reader, VarInt, VarIntBoundsExceeded, WireDecode, WireEncode, WireError}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RouteId(pub VarInt); + +impl RouteId { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn from_u32(value: u32) -> Self { + Self(VarInt::from_u32(value)) + } + + pub fn from_u64(value: u64) -> Result { + Ok(Self(VarInt::from_u64(value)?)) + } + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } +} + +impl WireEncode for RouteId { + fn encoded_len(&self) -> usize { + self.0.size() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl WireDecode for RouteId { + fn decode(reader: &mut Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl From for RouteId { + fn from(value: VarInt) -> Self { + Self(value) + } +} + +impl From for RouteId { + fn from(value: u32) -> Self { + Self::from_u32(value) + } +} + +impl std::fmt::Display for RouteId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/ql-wire/src/encrypted/stream_close.rs b/ql-wire/src/encrypted/stream_close.rs new file mode 100644 index 0000000..20ddb87 --- /dev/null +++ b/ql-wire/src/encrypted/stream_close.rs @@ -0,0 +1,116 @@ +use super::StreamId; +use crate::{codec, ByteSlice, WireEncode, WireError}; + +/// aborts one or both lanes of a stream with a close code +/// +/// stream origin is the peer that opened the stream +/// origin lane carries bytes sent by the stream origin +/// return lane carries bytes sent back toward the stream origin +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamClose { + pub stream_id: StreamId, + pub target: CloseTarget, + pub code: StreamCloseCode, +} + +impl StreamClose {} + +impl WireEncode for StreamClose { + fn encoded_len(&self) -> usize { + self.stream_id.encoded_len() + self.target.encoded_len() + self.code.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.stream_id.encode(out); + self.target.encode(out); + self.code.encode(out); + } +} + +impl codec::WireDecode for StreamClose { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + stream_id: reader.decode()?, + target: reader.decode()?, + code: reader.decode()?, + }) + } +} + +/// selects which stream lane a [`StreamClose`] applies to +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum CloseTarget { + /// close the lane sent by the stream origin + Origin = 1, + /// close the lane sent back toward the stream origin + Return = 2, + /// close both stream lanes + Both = 3, +} + +impl CloseTarget { + pub const fn to_wire(self) -> u8 { + self as u8 + } +} + +impl WireEncode for CloseTarget { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.to_wire().encode(out); + } +} + +impl TryFrom for CloseTarget { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Origin), + 2 => Ok(Self::Return), + 3 => Ok(Self::Both), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl codec::WireDecode for CloseTarget { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct StreamCloseCode(pub u16); + +impl StreamCloseCode { + /// the stream was aborted intentionally before graceful completion + pub const CANCELLED: Self = Self(0); +} + +impl codec::WireDecode for StreamCloseCode { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl WireEncode for StreamCloseCode { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl std::fmt::Display for StreamCloseCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/ql-wire/src/encrypted/stream_data.rs b/ql-wire/src/encrypted/stream_data.rs new file mode 100644 index 0000000..9174fe5 --- /dev/null +++ b/ql-wire/src/encrypted/stream_data.rs @@ -0,0 +1,135 @@ +use bytes::Buf; + +use super::{RouteId, StreamId}; +use crate::{codec, BufView, ByteSlice, VarInt, WireDecode, WireEncode, WireError}; + +/// carries bytes for a stream and may finish that sending direction. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamData { + pub stream_id: StreamId, + pub offset: VarInt, + pub header: Option, + pub fin: bool, + pub bytes: B, +} + +impl StreamData { + pub const MIN_WIRE_SIZE: usize = StreamId::MAX_ENCODED_LEN + + VarInt::MAX_SIZE + + size_of::() + + StreamHeader::MAX_WIRE_SIZE + + VarInt::MAX_SIZE; +} + +impl WireDecode for StreamData { + fn decode(reader: &mut codec::Reader) -> Result { + let stream_id = reader.decode()?; + let offset: VarInt = reader.decode()?; + let flags = reader.decode::()?; + let fin = (flags & flag::FIN) != 0; + let has_header = (flags & flag::HEADER) != 0; + let header = if has_header { + Some(reader.decode()?) + } else { + None + }; + let bytes_len = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + + Ok(Self { + stream_id, + offset, + header, + fin, + bytes: reader.take_bytes(bytes_len)?, + }) + } +} + +impl StreamData { + pub fn into_owned(self) -> StreamData> + where + B: ByteSlice, + { + StreamData { + stream_id: self.stream_id, + offset: self.offset, + header: self.header, + fin: self.fin, + bytes: self.bytes.to_vec(), + } + } +} + +impl WireEncode for StreamData { + fn encoded_len(&self) -> usize { + let bytes = self.bytes.buf(); + let bytes_len = bytes.remaining(); + self.stream_id.encoded_len() + + self.offset.encoded_len() + + size_of::() + + self.header.as_ref().map_or(0, WireEncode::encoded_len) + + VarInt::try_from(bytes_len).unwrap().encoded_len() + + bytes_len + } + + fn encode(&self, out: &mut W) { + debug_assert!( + self.offset.into_inner() == 0 || self.header.is_none(), + "stream header is only valid at offset 0" + ); + + self.stream_id.encode(out); + self.offset.encode(out); + let mut flags = 0; + if self.fin { + flags |= flag::FIN; + } + if self.header.is_some() { + flags |= flag::HEADER; + } + flags.encode(out); + if let Some(header) = &self.header { + header.encode(out); + } + let mut bytes = self.bytes.buf(); + VarInt::try_from(bytes.remaining()).unwrap().encode(out); + while bytes.has_remaining() { + let chunk = bytes.chunk(); + out.put_slice(chunk); + bytes.advance(chunk.len()); + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamHeader { + pub route_id: RouteId, +} + +impl StreamHeader { + pub const MAX_WIRE_SIZE: usize = RouteId::MAX_ENCODED_LEN; +} + +impl WireDecode for StreamHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + route_id: reader.decode()?, + }) + } +} + +impl WireEncode for StreamHeader { + fn encoded_len(&self) -> usize { + self.route_id.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.route_id.encode(out); + } +} + +mod flag { + pub const FIN: u8 = 0x01; + pub const HEADER: u8 = 0x02; +} diff --git a/ql-wire/src/encrypted/stream_id.rs b/ql-wire/src/encrypted/stream_id.rs new file mode 100644 index 0000000..0700225 --- /dev/null +++ b/ql-wire/src/encrypted/stream_id.rs @@ -0,0 +1,35 @@ +use crate::{ByteSlice, Reader, VarInt, WireDecode, WireEncode, WireError}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct StreamId(pub VarInt); + +impl StreamId { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } +} + +impl WireEncode for StreamId { + fn encoded_len(&self) -> usize { + self.0.size() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl WireDecode for StreamId { + fn decode(reader: &mut Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl std::fmt::Display for StreamId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} diff --git a/ql-wire/src/encrypted/stream_window.rs b/ql-wire/src/encrypted/stream_window.rs new file mode 100644 index 0000000..6a2274f --- /dev/null +++ b/ql-wire/src/encrypted/stream_window.rs @@ -0,0 +1,29 @@ +use super::StreamId; +use crate::{codec, ByteSlice, VarInt, WireEncode, WireError}; + +/// advertises the highest byte offset the peer may send on a stream. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamWindow { + pub stream_id: StreamId, + pub maximum_offset: VarInt, +} + +impl WireEncode for StreamWindow { + fn encoded_len(&self) -> usize { + self.stream_id.encoded_len() + self.maximum_offset.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.stream_id.encode(out); + self.maximum_offset.encode(out); + } +} + +impl codec::WireDecode for StreamWindow { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + stream_id: reader.decode()?, + maximum_offset: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/encrypted_message.rs b/ql-wire/src/encrypted_message.rs new file mode 100644 index 0000000..9e11d3d --- /dev/null +++ b/ql-wire/src/encrypted_message.rs @@ -0,0 +1,97 @@ +use crate::{ + codec, ByteSlice, Nonce, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, + ENCRYPTED_MESSAGE_AUTH_SIZE, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptedMessage { + pub auth: [u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + pub ciphertext: B, +} + +impl EncryptedMessage { + pub const AUTH_SIZE: usize = ENCRYPTED_MESSAGE_AUTH_SIZE; + pub const HEADER_LEN: usize = Self::AUTH_SIZE; + + pub fn into_owned(self) -> EncryptedMessage> + where + B: ByteSlice, + { + EncryptedMessage { + auth: self.auth, + ciphertext: self.ciphertext.to_vec(), + } + } +} + +impl WireDecode for EncryptedMessage { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + auth: reader.decode()?, + ciphertext: reader.take_rest(), + }) + } +} + +impl> EncryptedMessage { + pub fn decrypt( + &self, + crypto: &impl QlCrypto, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + ) -> Result, WireError> { + let mut plaintext = self.ciphertext.as_ref().to_vec(); + if !crypto.aes256_gcm_decrypt(key, nonce, aad, &mut plaintext, &self.auth) { + return Err(WireError::DecryptFailed); + } + Ok(plaintext) + } +} + +impl> WireEncode for EncryptedMessage { + fn encoded_len(&self) -> usize { + Self::HEADER_LEN + self.ciphertext.as_ref().len() + } + + fn encode(&self, out: &mut W) { + self.auth.encode(out); + self.ciphertext.as_ref().encode(out); + } +} + +impl> EncryptedMessage { + pub fn decrypt_in_place( + mut self, + crypto: &impl QlCrypto, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + ) -> Result { + let ciphertext = self.ciphertext.as_mut(); + if !crypto.aes256_gcm_decrypt(key, nonce, aad, ciphertext, &self.auth) { + return Err(WireError::DecryptFailed); + } + Ok(self.ciphertext) + } +} + +impl EncryptedMessage> { + pub fn encrypt( + crypto: &impl QlCrypto, + key: &SessionKey, + mut plaintext: Vec, + nonce: &Nonce, + aad: &[u8], + ) -> Self { + let auth = crypto.aes256_gcm_encrypt(key, nonce, aad, &mut plaintext); + Self { + auth, + ciphertext: plaintext, + } + } + + pub fn decode(bytes: &[u8]) -> Result { + Ok(EncryptedMessage::decode_exact(bytes)?.into_owned()) + } +} diff --git a/ql-wire/src/error.rs b/ql-wire/src/error.rs new file mode 100644 index 0000000..8da1eec --- /dev/null +++ b/ql-wire/src/error.rs @@ -0,0 +1,33 @@ +use core::fmt; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum WireError { + InvalidPayload, + InvalidHandshakeHeader, + InvalidHandshakeMeta, + InvalidPairingId, + InvalidRemoteBundle, + InvalidTransportParams, + Expired, + DecryptFailed, + InvalidState, +} + +impl fmt::Display for WireError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let message = match self { + Self::InvalidPayload => "invalid payload", + Self::InvalidHandshakeHeader => "invalid handshake header", + Self::InvalidHandshakeMeta => "invalid handshake meta", + Self::InvalidPairingId => "invalid pairing id", + Self::InvalidRemoteBundle => "invalid remote bundle", + Self::InvalidTransportParams => "invalid transport params", + Self::Expired => "expired", + Self::DecryptFailed => "decryption failed", + Self::InvalidState => "invalid state", + }; + f.write_str(message) + } +} + +impl std::error::Error for WireError {} diff --git a/ql-wire/src/handshake/ik.rs b/ql-wire/src/handshake/ik.rs new file mode 100644 index 0000000..628e30e --- /dev/null +++ b/ql-wire/src/handshake/ik.rs @@ -0,0 +1,376 @@ +use super::{ + decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, + finalize_handshake, generate_ephemeral_keypair, init_ik_symmetric, initialize_handshake_meta, + mix_hash_ephemeral, mix_hash_routed_handshake, require_handshake_meta, + EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, + FinalizedHandshake, HandshakeHeader, Role, SymmetricState, TransportParams, +}; +use crate::{ + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, + QlIdentity, WireEncode, WireError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ik1 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub skem_ciphertext: MlKemCiphertext, + pub ephemeral: EphemeralPublicKey, + pub static_bundle: EncryptedPeerBundle, +} + +impl codec::WireDecode for Ik1 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + ephemeral: reader.decode()?, + static_bundle: reader.decode()?, + }) + } +} + +impl WireEncode for Ik1 { + fn encoded_len(&self) -> usize { + HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EphemeralPublicKey::WIRE_SIZE + + self.static_bundle.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + self.ephemeral.encode(out); + self.static_bundle.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Ik2 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub ekem_ciphertext: MlKemCiphertext, + pub skem_ciphertext: EncryptedMlKemCiphertext, +} + +impl Ik2 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE; +} + +impl codec::WireDecode for Ik2 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + ekem_ciphertext: reader.decode()?, + skem_ciphertext: reader.decode()?, + }) + } +} + +impl WireEncode for Ik2 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.ekem_ciphertext.encode(out); + self.skem_ciphertext.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum IkStep { + Send1, + Recv1, + Send2, + Recv2, + Done, +} + +#[derive(Debug, Clone)] +pub struct IkHandshake { + role: Role, + step: IkStep, + symmetric: SymmetricState, + local: QlIdentity, + remote_bundle: Option, + local_ephemeral: Option, + remote_ephemeral: Option, + handshake_meta: Option, + local_transport_params: TransportParams, + remote_transport_params: Option, +} + +impl IkHandshake { + pub fn new_initiator( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_bundle: PeerBundle, + local_transport_params: TransportParams, + ) -> Self { + let symmetric = init_ik_symmetric(crypto, &remote_bundle); + Self { + role: Role::Initiator, + step: IkStep::Send1, + symmetric, + local, + remote_bundle: Some(remote_bundle), + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn new_responder( + crypto: &impl QlCrypto, + local: QlIdentity, + expected_remote: Option, + local_transport_params: TransportParams, + ) -> Self { + let symmetric = init_ik_symmetric(crypto, &local.bundle()); + Self { + role: Role::Responder, + step: IkStep::Recv1, + symmetric, + local, + remote_bundle: expected_remote, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn is_finished(&self) -> bool { + self.step == IkStep::Done + } + + fn outbound_header(&self) -> Result { + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + Ok(HandshakeHeader { + sender: self.local.qid, + recipient: remote_bundle.qid, + }) + } + + fn ensure_inbound_recipient(&self, header: HandshakeHeader) -> Result<(), WireError> { + if header.recipient == self.local.qid { + Ok(()) + } else { + Err(WireError::InvalidPayload) + } + } + + fn ensure_known_remote_sender(&self, header: HandshakeHeader) -> Result<(), WireError> { + if let Some(remote_bundle) = self.remote_bundle.as_ref() { + if header.sender != remote_bundle.qid { + return Err(WireError::InvalidPayload); + } + } + Ok(()) + } + + pub fn write_1( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != IkStep::Send1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let header = self.outbound_header()?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Ik1, + meta, + self.local_transport_params, + ); + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + self.symmetric.mix_hash(crypto, skem_ciphertext.as_bytes()); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + + let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + self.local_ephemeral = Some(local_ephemeral); + self.step = IkStep::Recv2; + Ok(Ik1 { + header, + meta, + transport_params: self.local_transport_params, + skem_ciphertext, + ephemeral: public, + static_bundle, + }) + } + + pub fn write_2( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != IkStep::Send2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.outbound_header()?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Ik2, + meta, + self.local_transport_params, + ); + let remote_ephemeral = self + .remote_ephemeral + .clone() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = IkStep::Done; + Ok(Ik2 { + header, + meta, + transport_params: self.local_transport_params, + ekem_ciphertext, + skem_ciphertext, + }) + } + + pub fn read_1(&mut self, crypto: &impl QlCrypto, message: &Ik1) -> Result<(), WireError> { + if self.step != IkStep::Recv1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + self.ensure_inbound_recipient(message.header)?; + self.ensure_known_remote_sender(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Ik1, + message.meta, + message.transport_params, + ); + self.symmetric + .mix_hash(crypto, message.skem_ciphertext.as_bytes()); + let skem_secret = + crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + if remote_bundle.qid != message.header.sender { + return Err(WireError::InvalidPayload); + } + match self.remote_bundle.as_ref() { + Some(expected) if expected != &remote_bundle => { + return Err(WireError::InvalidPayload); + } + Some(_) => {} + None => self.remote_bundle = Some(remote_bundle), + } + self.remote_transport_params = Some(message.transport_params); + self.step = IkStep::Send2; + Ok(()) + } + + pub fn read_2(&mut self, crypto: &impl QlCrypto, message: &Ik2) -> Result<(), WireError> { + if self.step != IkStep::Recv2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_recipient(message.header)?; + self.ensure_known_remote_sender(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Ik2, + message.meta, + message.transport_params, + ); + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = + crypto.mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.remote_transport_params = Some(message.transport_params); + self.step = IkStep::Done; + Ok(()) + } + + pub fn finalize(self, crypto: &impl QlCrypto) -> Result { + if !self.is_finished() { + return Err(WireError::InvalidState); + } + let remote_bundle = self.remote_bundle.ok_or(WireError::InvalidState)?; + let remote_transport_params = self + .remote_transport_params + .ok_or(WireError::InvalidState)?; + Ok(finalize_handshake( + crypto, + &self.symmetric, + self.role, + remote_bundle, + remote_transport_params, + )) + } +} diff --git a/ql-wire/src/handshake/kk.rs b/ql-wire/src/handshake/kk.rs new file mode 100644 index 0000000..2ad5ee2 --- /dev/null +++ b/ql-wire/src/handshake/kk.rs @@ -0,0 +1,352 @@ +use super::{ + decrypt_mlkem_ciphertext, encrypt_mlkem_ciphertext, finalize_handshake, + generate_ephemeral_keypair, init_kk_symmetric, initialize_handshake_meta, mix_hash_ephemeral, + mix_hash_routed_handshake, require_handshake_meta, EncryptedMlKemCiphertext, EphemeralKeyPair, + EphemeralPublicKey, FinalizedHandshake, HandshakeHeader, Role, SymmetricState, TransportParams, +}; +use crate::{ + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PeerBundle, QlCrypto, + QlIdentity, WireEncode, WireError, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Kk1 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub skem_ciphertext: MlKemCiphertext, + pub ephemeral: EphemeralPublicKey, +} + +impl Kk1 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EphemeralPublicKey::WIRE_SIZE; +} + +impl codec::WireDecode for Kk1 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + ephemeral: reader.decode()?, + }) + } +} + +impl WireEncode for Kk1 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + self.ephemeral.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Kk2 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub transport_params: TransportParams, + pub ekem_ciphertext: MlKemCiphertext, + pub skem_ciphertext: EncryptedMlKemCiphertext, +} + +impl Kk2 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE; +} + +impl codec::WireDecode for Kk2 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + transport_params: reader.decode()?, + ekem_ciphertext: reader.decode()?, + skem_ciphertext: reader.decode()?, + }) + } +} + +impl WireEncode for Kk2 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.transport_params.encode(out); + self.ekem_ciphertext.encode(out); + self.skem_ciphertext.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum KkStep { + Send1, + Recv1, + Send2, + Recv2, + Done, +} + +#[derive(Debug, Clone)] +pub struct KkHandshake { + role: Role, + step: KkStep, + symmetric: SymmetricState, + local: QlIdentity, + remote_bundle: PeerBundle, + local_ephemeral: Option, + remote_ephemeral: Option, + handshake_meta: Option, + local_transport_params: TransportParams, + remote_transport_params: Option, +} + +impl KkHandshake { + pub fn new_initiator( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_bundle: PeerBundle, + local_transport_params: TransportParams, + ) -> Self { + let symmetric = init_kk_symmetric(crypto, &local.bundle(), &remote_bundle); + Self { + role: Role::Initiator, + step: KkStep::Send1, + symmetric, + local, + remote_bundle, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn new_responder( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_bundle: PeerBundle, + local_transport_params: TransportParams, + ) -> Self { + let symmetric = init_kk_symmetric(crypto, &remote_bundle, &local.bundle()); + Self { + role: Role::Responder, + step: KkStep::Recv1, + symmetric, + local, + remote_bundle, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn is_finished(&self) -> bool { + self.step == KkStep::Done + } + + fn outbound_header(&self) -> HandshakeHeader { + HandshakeHeader { + sender: self.local.qid, + recipient: self.remote_bundle.qid, + } + } + + fn inbound_header(&self) -> HandshakeHeader { + HandshakeHeader { + sender: self.remote_bundle.qid, + recipient: self.local.qid, + } + } + + fn ensure_inbound_header(&self, header: HandshakeHeader) -> Result<(), WireError> { + if header == self.inbound_header() { + Ok(()) + } else { + Err(WireError::InvalidPayload) + } + } + + pub fn write_1( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != KkStep::Send1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + let header = self.outbound_header(); + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk1, + meta, + self.local_transport_params, + ); + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); + self.symmetric + .encrypt_and_hash(crypto, skem_ciphertext.as_bytes())?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let public = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &public); + + self.local_ephemeral = Some(local_ephemeral); + self.step = KkStep::Recv2; + Ok(Kk1 { + header, + meta, + transport_params: self.local_transport_params, + skem_ciphertext, + ephemeral: public, + }) + } + + pub fn write_2( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != KkStep::Send2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.outbound_header(); + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Kk2, + meta, + self.local_transport_params, + ); + let remote_ephemeral = self + .remote_ephemeral + .clone() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&self.remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = KkStep::Done; + Ok(Kk2 { + header, + meta, + transport_params: self.local_transport_params, + ekem_ciphertext, + skem_ciphertext, + }) + } + + pub fn read_1(&mut self, crypto: &impl QlCrypto, message: &Kk1) -> Result<(), WireError> { + if self.step != KkStep::Recv1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + self.ensure_inbound_header(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Kk1, + message.meta, + message.transport_params, + ); + self.symmetric + .decrypt_and_hash(crypto, message.skem_ciphertext.as_bytes())?; + let skem_secret = + crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &message.skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + self.remote_ephemeral = Some(message.ephemeral.clone()); + self.remote_transport_params = Some(message.transport_params); + self.step = KkStep::Send2; + Ok(()) + } + + pub fn read_2(&mut self, crypto: &impl QlCrypto, message: &Kk2) -> Result<(), WireError> { + if self.step != KkStep::Recv2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(message.header)?; + mix_hash_routed_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Kk2, + message.meta, + message.transport_params, + ); + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = + crypto.mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.remote_transport_params = Some(message.transport_params); + self.step = KkStep::Done; + Ok(()) + } + + pub fn finalize(self, crypto: &impl QlCrypto) -> Result { + if !self.is_finished() { + return Err(WireError::InvalidState); + } + let remote_transport_params = self + .remote_transport_params + .ok_or(WireError::InvalidState)?; + Ok(finalize_handshake( + crypto, + &self.symmetric, + self.role, + self.remote_bundle, + remote_transport_params, + )) + } +} diff --git a/ql-wire/src/handshake/meta.rs b/ql-wire/src/handshake/meta.rs new file mode 100644 index 0000000..8cb0cf9 --- /dev/null +++ b/ql-wire/src/handshake/meta.rs @@ -0,0 +1,48 @@ +use crate::{codec, ByteSlice, WireEncode, WireError}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct HandshakeId(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct HandshakeMeta { + pub handshake_id: HandshakeId, +} + +impl codec::WireDecode for HandshakeId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl WireEncode for HandshakeId { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl HandshakeMeta { + pub const WIRE_SIZE: usize = size_of::(); +} + +impl WireEncode for HandshakeMeta { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.handshake_id.encode(out); + } +} + +impl codec::WireDecode for HandshakeMeta { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + handshake_id: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/handshake/mod.rs b/ql-wire/src/handshake/mod.rs new file mode 100644 index 0000000..a9b7cf8 --- /dev/null +++ b/ql-wire/src/handshake/mod.rs @@ -0,0 +1,590 @@ +use crate::{ + codec, ByteSlice, ConnectionId, HandshakeKind, MlKemCiphertext, MlKemKeyPair, MlKemPublicKey, + Nonce, PeerBundle, QlCrypto, SessionKey, WireDecode, WireEncode, WireError, + ENCRYPTED_MESSAGE_AUTH_SIZE, QID, +}; + +mod ik; +mod kk; +mod meta; +mod pairing; +mod transport_params; +mod xx; + +pub use ik::{Ik1, Ik2, IkHandshake}; +pub use kk::{Kk1, Kk2, KkHandshake}; +pub use meta::{HandshakeId, HandshakeMeta}; +pub use pairing::{PairingId, PairingToken}; +pub use transport_params::TransportParams; +pub use xx::{Xx1, Xx2, Xx3, Xx4, XxHandshake}; + +const SHA256_BLOCK_LEN: usize = 64; +const PROTOCOL_IK: &[u8] = b"ql-wire:pq-ik:v1"; +const PROTOCOL_KK: &[u8] = b"ql-wire:pq-kk:v1"; +const PROTOCOL_XX: &[u8] = b"ql-wire:pq-xx:v1"; +const CONNECTION_ID_DOMAIN: &[u8] = b"ql-wire:conn-id:v1"; +const HANDSHAKE_PREAMBLE_DOMAIN: &[u8] = b"ql-wire:handshake-preamble:v1"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct HandshakeHeader { + pub sender: QID, + pub recipient: QID, +} + +impl HandshakeHeader { + pub const WIRE_SIZE: usize = QID::SIZE * 2; +} + +impl WireEncode for HandshakeHeader { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.sender.encode(out); + self.recipient.encode(out); + } +} + +impl codec::WireDecode for HandshakeHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + sender: reader.decode()?, + recipient: reader.decode()?, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EphemeralPublicKey { + pub mlkem_public_key: MlKemPublicKey, +} + +impl EphemeralPublicKey { + pub const WIRE_SIZE: usize = MlKemPublicKey::SIZE; +} + +impl WireEncode for EphemeralPublicKey { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.mlkem_public_key.encode(out); + } +} + +impl codec::WireDecode for EphemeralPublicKey { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + mlkem_public_key: reader.decode()?, + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptedMlKemCiphertext(pub Box<[u8; Self::WIRE_SIZE]>); + +impl EncryptedMlKemCiphertext { + pub const WIRE_SIZE: usize = MlKemCiphertext::SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn new(data: Box<[u8; Self::WIRE_SIZE]>) -> Self { + Self(data) + } + + pub fn as_bytes(&self) -> &[u8; Self::WIRE_SIZE] { + self.0.as_ref() + } +} + +impl WireEncode for EncryptedMlKemCiphertext { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); + } +} + +impl codec::WireDecode for EncryptedMlKemCiphertext { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EncryptedPeerBundle(pub Box<[u8]>); + +impl EncryptedPeerBundle { + pub const MAX_WIRE_SIZE: usize = PeerBundle::MAX_WIRE_SIZE + ENCRYPTED_MESSAGE_AUTH_SIZE; + + pub fn as_bytes(&self) -> &[u8] { + self.0.as_ref() + } +} + +impl WireEncode for EncryptedPeerBundle { + fn encoded_len(&self) -> usize { + self.0.len() + } + + fn encode(&self, out: &mut W) { + self.as_bytes().encode(out); + } +} + +impl codec::WireDecode for EncryptedPeerBundle { + fn decode(reader: &mut codec::Reader) -> Result { + let data = reader.take_rest(); + if data.len() > Self::MAX_WIRE_SIZE { + return Err(WireError::InvalidPayload); + } + Ok(Self(data.to_vec().into_boxed_slice())) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FinalizedHandshake { + pub tx_key: SessionKey, + pub rx_key: SessionKey, + pub tx_connection_id: ConnectionId, + pub rx_connection_id: ConnectionId, + pub handshake_hash: [u8; 32], + pub remote_bundle: PeerBundle, + /// Transport parameters advertised by the remote peer + pub remote_transport_params: TransportParams, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Role { + Initiator, + Responder, +} + +#[derive(Debug, Clone)] +struct EphemeralKeyPair { + mlkem: MlKemKeyPair, +} + +impl EphemeralKeyPair { + fn public(&self) -> EphemeralPublicKey { + EphemeralPublicKey { + mlkem_public_key: self.mlkem.public.clone(), + } + } +} + +#[derive(Debug, Clone)] +struct CipherState { + key: Option, + nonce: u64, +} + +impl CipherState { + fn new() -> Self { + Self { + key: None, + nonce: 0, + } + } + + fn initialize_key(&mut self, key: SessionKey) { + self.key = Some(key); + self.nonce = 0; + } + + fn has_key(&self) -> bool { + self.key.is_some() + } + + fn encrypt( + &mut self, + crypto: &impl QlCrypto, + aad: &[u8], + plaintext: &[u8], + ) -> Result, WireError> { + let key = self.key.as_ref().ok_or(WireError::InvalidState)?; + let nonce = Nonce::from_counter(self.nonce); + let mut ciphertext = Vec::with_capacity(plaintext.len() + ENCRYPTED_MESSAGE_AUTH_SIZE); + ciphertext.extend_from_slice(plaintext); + let auth = crypto.aes256_gcm_encrypt(key, &nonce, aad, &mut ciphertext); + self.nonce = self.nonce.wrapping_add(1); + ciphertext.extend_from_slice(&auth); + Ok(ciphertext) + } + + fn decrypt( + &mut self, + crypto: &impl QlCrypto, + aad: &[u8], + ciphertext: &[u8], + ) -> Result, WireError> { + if ciphertext.len() < ENCRYPTED_MESSAGE_AUTH_SIZE { + return Err(WireError::InvalidPayload); + } + let split = ciphertext.len() - ENCRYPTED_MESSAGE_AUTH_SIZE; + let (ciphertext, auth) = ciphertext.split_at(split); + let mut plaintext = ciphertext.to_vec(); + let key = self.key.as_ref().ok_or(WireError::InvalidState)?; + let nonce = Nonce::from_counter(self.nonce); + let mut auth_tag = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + auth_tag.copy_from_slice(auth); + if !crypto.aes256_gcm_decrypt(key, &nonce, aad, &mut plaintext, &auth_tag) { + return Err(WireError::DecryptFailed); + } + self.nonce = self.nonce.wrapping_add(1); + Ok(plaintext) + } +} + +#[derive(Debug, Clone)] +struct SymmetricState { + chaining_key: [u8; 32], + handshake_hash: [u8; 32], + cipher: CipherState, +} + +impl SymmetricState { + fn new(crypto: &impl QlCrypto, protocol_name: &[u8]) -> Self { + let h = crypto.sha256(&[protocol_name]); + Self { + chaining_key: h, + handshake_hash: h, + cipher: CipherState::new(), + } + } + + fn mix_hash(&mut self, crypto: &impl QlCrypto, data: &[u8]) { + self.handshake_hash = crypto.sha256(&[&self.handshake_hash, data]); + } + + fn mix_key(&mut self, crypto: &impl QlCrypto, input_key_material: &[u8]) { + let (chaining_key, cipher_key) = hkdf2(crypto, &self.chaining_key, input_key_material); + self.chaining_key = chaining_key; + self.cipher.initialize_key(cipher_key); + } + + fn mix_key_and_hash(&mut self, crypto: &impl QlCrypto, input_key_material: &[u8]) { + let (chaining_key, hash_input, cipher_key) = + hkdf3(crypto, &self.chaining_key, input_key_material); + self.chaining_key = chaining_key; + self.mix_hash(crypto, &hash_input); + self.cipher.initialize_key(cipher_key); + } + + fn encrypt_and_hash( + &mut self, + crypto: &impl QlCrypto, + plaintext: &[u8], + ) -> Result, WireError> { + if self.cipher.has_key() { + let ciphertext = self + .cipher + .encrypt(crypto, &self.handshake_hash, plaintext)?; + self.mix_hash(crypto, &ciphertext); + Ok(ciphertext) + } else { + self.mix_hash(crypto, plaintext); + Ok(plaintext.to_vec()) + } + } + + fn decrypt_and_hash( + &mut self, + crypto: &impl QlCrypto, + ciphertext: &[u8], + ) -> Result, WireError> { + if self.cipher.has_key() { + let plaintext = self + .cipher + .decrypt(crypto, &self.handshake_hash, ciphertext)?; + self.mix_hash(crypto, ciphertext); + Ok(plaintext) + } else { + self.mix_hash(crypto, ciphertext); + Ok(ciphertext.to_vec()) + } + } + + fn split_for_role(&self, crypto: &impl QlCrypto, role: Role) -> (SessionKey, SessionKey) { + let temp_key = hmac_sha256(crypto, &self.chaining_key, &[&[]]); + let k1 = SessionKey::from_data(hmac_sha256(crypto, &temp_key, &[&[1]])); + let k2 = SessionKey::from_data(hmac_sha256(crypto, &temp_key, &[k1.as_bytes(), &[2]])); + match role { + Role::Initiator => (k1, k2), + Role::Responder => (k2, k1), + } + } +} + +fn init_kk_symmetric( + crypto: &impl QlCrypto, + initiator_bundle: &PeerBundle, + responder_bundle: &PeerBundle, +) -> SymmetricState { + let mut symmetric = SymmetricState::new(crypto, PROTOCOL_KK); + symmetric.mix_hash(crypto, &initiator_bundle.encode_vec()); + symmetric.mix_hash(crypto, &responder_bundle.encode_vec()); + symmetric +} + +fn init_ik_symmetric(crypto: &impl QlCrypto, responder_bundle: &PeerBundle) -> SymmetricState { + let mut symmetric = SymmetricState::new(crypto, PROTOCOL_IK); + symmetric.mix_hash(crypto, &responder_bundle.encode_vec()); + symmetric +} + +fn init_xx_symmetric(crypto: &impl QlCrypto) -> SymmetricState { + SymmetricState::new(crypto, PROTOCOL_XX) +} + +fn mix_psk_pairing_token( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + pairing_token: PairingToken, +) { + symmetric.mix_key_and_hash(crypto, &pairing_token.psk(crypto)); +} + +fn generate_ephemeral_keypair(crypto: &impl QlCrypto) -> EphemeralKeyPair { + EphemeralKeyPair { + mlkem: crypto.mlkem_generate_keypair(), + } +} + +fn mix_hash_ephemeral( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + public: &EphemeralPublicKey, +) { + symmetric.mix_hash(crypto, public.mlkem_public_key.as_bytes()); +} + +fn mix_hash_routed_handshake( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + header: HandshakeHeader, + kind: HandshakeKind, + meta: HandshakeMeta, + transport_params: TransportParams, +) { + mix_hash_handshake_preamble( + symmetric, + crypto, + &header.encode_vec(), + kind, + meta, + transport_params, + ); +} + +fn mix_hash_pairing_handshake( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + header: HandshakeHeader, + kind: HandshakeKind, + meta: HandshakeMeta, + pairing_id: PairingId, + transport_params: TransportParams, +) { + let mut preamble = header.encode_vec(); + pairing_id.encode(&mut preamble); + mix_hash_handshake_preamble(symmetric, crypto, &preamble, kind, meta, transport_params); +} + +fn mix_hash_handshake_preamble( + symmetric: &mut SymmetricState, + crypto: &impl QlCrypto, + header: &[u8], + kind: HandshakeKind, + meta: HandshakeMeta, + transport_params: TransportParams, +) { + symmetric.mix_hash(crypto, HANDSHAKE_PREAMBLE_DOMAIN); + symmetric.mix_hash(crypto, header); + symmetric.mix_hash(crypto, &[kind as u8]); + symmetric.mix_hash(crypto, &meta.encode_vec()); + symmetric.mix_hash(crypto, &transport_params.encode_vec()); +} + +fn initialize_handshake_meta( + expected: &mut Option, + meta: HandshakeMeta, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored != meta => Err(WireError::InvalidHandshakeMeta), + Some(_) => Ok(()), + None => { + *expected = Some(meta); + Ok(()) + } + } +} + +fn require_handshake_meta( + expected: Option<&HandshakeMeta>, + meta: HandshakeMeta, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored == meta => Ok(()), + _ => Err(WireError::InvalidHandshakeMeta), + } +} + +fn initialize_transport_params( + expected: &mut Option, + transport_params: TransportParams, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored != transport_params => Err(WireError::InvalidTransportParams), + Some(_) => Ok(()), + None => { + *expected = Some(transport_params); + Ok(()) + } + } +} + +fn require_transport_params( + expected: Option<&TransportParams>, + transport_params: TransportParams, +) -> Result<(), WireError> { + match expected { + Some(stored) if *stored == transport_params => Ok(()), + _ => Err(WireError::InvalidTransportParams), + } +} + +fn encrypt_peer_bundle( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + bundle: &PeerBundle, +) -> Result { + let ciphertext = symmetric.encrypt_and_hash(crypto, &bundle.encode_vec())?; + Ok(EncryptedPeerBundle(ciphertext.into_boxed_slice())) +} + +fn decrypt_peer_bundle( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + bundle: &EncryptedPeerBundle, +) -> Result { + let plaintext = symmetric.decrypt_and_hash(crypto, bundle.as_bytes())?; + let bundle = PeerBundle::decode_exact(plaintext.as_slice())?; + if !bundle.qid_matches_public_key(crypto) { + return Err(WireError::InvalidRemoteBundle); + } + Ok(bundle) +} + +fn encrypt_mlkem_ciphertext( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + ciphertext: &MlKemCiphertext, +) -> Result { + let encrypted = symmetric.encrypt_and_hash(crypto, ciphertext.as_bytes())?; + let out: Box<[u8; EncryptedMlKemCiphertext::WIRE_SIZE]> = + encrypted.try_into().map_err(|_| WireError::InvalidState)?; + Ok(EncryptedMlKemCiphertext::new(out)) +} + +fn decrypt_mlkem_ciphertext( + crypto: &impl QlCrypto, + symmetric: &mut SymmetricState, + ciphertext: &EncryptedMlKemCiphertext, +) -> Result { + let plaintext = symmetric.decrypt_and_hash(crypto, ciphertext.as_bytes())?; + let out: Box<[u8; MlKemCiphertext::SIZE]> = plaintext + .try_into() + .map_err(|_| WireError::InvalidPayload)?; + Ok(MlKemCiphertext::new(out)) +} + +fn finalize_handshake( + crypto: &impl QlCrypto, + symmetric: &SymmetricState, + role: Role, + remote_bundle: PeerBundle, + remote_transport_params: TransportParams, +) -> FinalizedHandshake { + let handshake_hash = symmetric.handshake_hash; + let (tx_key, rx_key) = symmetric.split_for_role(crypto, role); + let (initiator_rx, responder_rx) = derive_connection_ids(crypto, &handshake_hash); + let (tx_connection_id, rx_connection_id) = match role { + Role::Initiator => (responder_rx, initiator_rx), + Role::Responder => (initiator_rx, responder_rx), + }; + FinalizedHandshake { + tx_key, + rx_key, + tx_connection_id, + rx_connection_id, + handshake_hash, + remote_bundle, + remote_transport_params, + } +} + +fn derive_connection_ids( + crypto: &impl QlCrypto, + handshake_hash: &[u8; 32], +) -> (ConnectionId, ConnectionId) { + let initiator = crypto.sha256(&[CONNECTION_ID_DOMAIN, handshake_hash, b"initiator-rx"]); + let responder = crypto.sha256(&[CONNECTION_ID_DOMAIN, handshake_hash, b"responder-rx"]); + let mut initiator_rx = [0u8; ConnectionId::SIZE]; + let mut responder_rx = [0u8; ConnectionId::SIZE]; + initiator_rx.copy_from_slice(&initiator[..ConnectionId::SIZE]); + responder_rx.copy_from_slice(&responder[..ConnectionId::SIZE]); + ( + ConnectionId::from_data(initiator_rx), + ConnectionId::from_data(responder_rx), + ) +} + +fn hkdf2( + crypto: &impl QlCrypto, + chaining_key: &[u8; 32], + input_key_material: &[u8], +) -> ([u8; 32], SessionKey) { + let temp_key = hmac_sha256(crypto, chaining_key, &[input_key_material]); + let out1 = hmac_sha256(crypto, &temp_key, &[&[1]]); + let out2 = hmac_sha256(crypto, &temp_key, &[&out1, &[2]]); + (out1, SessionKey::from_data(out2)) +} + +fn hkdf3( + crypto: &impl QlCrypto, + chaining_key: &[u8; 32], + input_key_material: &[u8], +) -> ([u8; 32], [u8; 32], SessionKey) { + let temp_key = hmac_sha256(crypto, chaining_key, &[input_key_material]); + let out1 = hmac_sha256(crypto, &temp_key, &[&[1]]); + let out2 = hmac_sha256(crypto, &temp_key, &[&out1, &[2]]); + let out3 = hmac_sha256(crypto, &temp_key, &[&out2, &[3]]); + (out1, out2, SessionKey::from_data(out3)) +} + +fn hmac_sha256(crypto: &impl QlCrypto, key: &[u8], parts: &[&[u8]]) -> [u8; 32] { + let mut key_block = [0u8; SHA256_BLOCK_LEN]; + if key.len() > SHA256_BLOCK_LEN { + key_block[..32].copy_from_slice(&crypto.sha256(&[key])); + } else { + key_block[..key.len()].copy_from_slice(key); + } + + let mut ipad = [0x36u8; SHA256_BLOCK_LEN]; + let mut opad = [0x5cu8; SHA256_BLOCK_LEN]; + for (dst, src) in ipad.iter_mut().zip(key_block.iter()) { + *dst ^= *src; + } + for (dst, src) in opad.iter_mut().zip(key_block.iter()) { + *dst ^= *src; + } + + let mut inner_parts: Vec<&[u8]> = Vec::with_capacity(parts.len() + 1); + inner_parts.push(&ipad); + inner_parts.extend_from_slice(parts); + let inner = crypto.sha256(&inner_parts); + crypto.sha256(&[&opad, &inner]) +} diff --git a/ql-wire/src/handshake/pairing.rs b/ql-wire/src/handshake/pairing.rs new file mode 100644 index 0000000..237f066 --- /dev/null +++ b/ql-wire/src/handshake/pairing.rs @@ -0,0 +1,83 @@ +use std::fmt::{self, Display, Formatter}; + +use crate::{codec, ByteSlice, QlCrypto, WireEncode, WireError}; + +const PAIRING_ID_DOMAIN: &[u8] = b"ql-wire:pairing-id:v1"; +const PAIRING_PSK_DOMAIN: &[u8] = b"ql-wire:pairing-psk:v1"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct PairingToken(pub [u8; Self::SIZE]); + +impl PairingToken { + pub const SIZE: usize = 16; + + pub fn id(&self, crypto: &impl QlCrypto) -> PairingId { + let hash = crypto.sha256(&[PAIRING_ID_DOMAIN, &self.0]); + let mut id = [0u8; PairingId::SIZE]; + id.copy_from_slice(&hash[..PairingId::SIZE]); + PairingId(id) + } + + pub(super) fn psk(&self, crypto: &impl QlCrypto) -> [u8; 32] { + crypto.sha256(&[PAIRING_PSK_DOMAIN, &self.0]) + } +} + +impl Display for PairingToken { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + for byte in self.0 { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl WireEncode for PairingToken { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for PairingToken { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct PairingId(pub [u8; Self::SIZE]); + +impl PairingId { + pub const SIZE: usize = 16; +} + +impl Display for PairingId { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + for byte in self.0 { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl WireEncode for PairingId { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for PairingId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} diff --git a/ql-wire/src/handshake/transport_params.rs b/ql-wire/src/handshake/transport_params.rs new file mode 100644 index 0000000..bfd0d42 --- /dev/null +++ b/ql-wire/src/handshake/transport_params.rs @@ -0,0 +1,38 @@ +use crate::{codec, ByteSlice, WireEncode, WireError}; + +/// Session parameters advertised in the handshake +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TransportParams { + /// Initial per-stream receive credit granted to the remote peer + pub initial_stream_receive_window: u32, +} + +impl TransportParams { + pub const WIRE_SIZE: usize = size_of::(); +} + +impl WireEncode for TransportParams { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.initial_stream_receive_window.encode(out); + } +} + +impl Default for TransportParams { + fn default() -> Self { + Self { + initial_stream_receive_window: 16 * 1024, + } + } +} + +impl codec::WireDecode for TransportParams { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + initial_stream_receive_window: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/handshake/xx.rs b/ql-wire/src/handshake/xx.rs new file mode 100644 index 0000000..0b6452d --- /dev/null +++ b/ql-wire/src/handshake/xx.rs @@ -0,0 +1,612 @@ +use super::{ + decrypt_mlkem_ciphertext, decrypt_peer_bundle, encrypt_mlkem_ciphertext, encrypt_peer_bundle, + finalize_handshake, generate_ephemeral_keypair, init_xx_symmetric, initialize_handshake_meta, + initialize_transport_params, mix_hash_ephemeral, mix_hash_pairing_handshake, + mix_psk_pairing_token, require_handshake_meta, require_transport_params, + EncryptedMlKemCiphertext, EncryptedPeerBundle, EphemeralKeyPair, EphemeralPublicKey, + FinalizedHandshake, HandshakeHeader, Role, SymmetricState, TransportParams, +}; +use crate::{ + codec, ByteSlice, HandshakeKind, HandshakeMeta, MlKemCiphertext, PairingId, PairingToken, + PeerBundle, QlCrypto, QlIdentity, WireEncode, WireError, QID, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx1 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub pairing_id: PairingId, + pub transport_params: TransportParams, + pub ephemeral: EphemeralPublicKey, +} + +impl Xx1 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + EphemeralPublicKey::WIRE_SIZE; +} + +impl codec::WireDecode for Xx1 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + pairing_id: reader.decode()?, + transport_params: reader.decode()?, + ephemeral: reader.decode()?, + }) + } +} + +impl WireEncode for Xx1 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.pairing_id.encode(out); + self.transport_params.encode(out); + self.ephemeral.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx2 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub pairing_id: PairingId, + pub transport_params: TransportParams, + pub ekem_ciphertext: MlKemCiphertext, + pub static_bundle: EncryptedPeerBundle, +} + +impl codec::WireDecode for Xx2 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + pairing_id: reader.decode()?, + transport_params: reader.decode()?, + ekem_ciphertext: reader.decode()?, + static_bundle: reader.decode()?, + }) + } +} + +impl WireEncode for Xx2 { + fn encoded_len(&self) -> usize { + HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + MlKemCiphertext::SIZE + + self.static_bundle.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.pairing_id.encode(out); + self.transport_params.encode(out); + self.ekem_ciphertext.encode(out); + self.static_bundle.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx3 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub pairing_id: PairingId, + pub transport_params: TransportParams, + pub skem_ciphertext: EncryptedMlKemCiphertext, + pub static_bundle: EncryptedPeerBundle, +} + +impl codec::WireDecode for Xx3 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + pairing_id: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + static_bundle: reader.decode()?, + }) + } +} + +impl WireEncode for Xx3 { + fn encoded_len(&self) -> usize { + HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE + + self.static_bundle.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.pairing_id.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + self.static_bundle.encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Xx4 { + pub header: HandshakeHeader, + pub meta: HandshakeMeta, + pub pairing_id: PairingId, + pub transport_params: TransportParams, + pub skem_ciphertext: EncryptedMlKemCiphertext, +} + +impl Xx4 { + pub const WIRE_SIZE: usize = HandshakeHeader::WIRE_SIZE + + HandshakeMeta::WIRE_SIZE + + PairingId::SIZE + + TransportParams::WIRE_SIZE + + EncryptedMlKemCiphertext::WIRE_SIZE; +} + +impl codec::WireDecode for Xx4 { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + meta: reader.decode()?, + pairing_id: reader.decode()?, + transport_params: reader.decode()?, + skem_ciphertext: reader.decode()?, + }) + } +} + +impl WireEncode for Xx4 { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.meta.encode(out); + self.pairing_id.encode(out); + self.transport_params.encode(out); + self.skem_ciphertext.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum XxStep { + Send1, + Recv1, + Send2, + Recv2, + Send3, + Recv3, + Send4, + Recv4, + Done, +} + +#[derive(Debug, Clone)] +pub struct XxHandshake { + role: Role, + step: XxStep, + symmetric: SymmetricState, + local: QlIdentity, + remote_qid: QID, + pairing_token: PairingToken, + remote_bundle: Option, + local_ephemeral: Option, + remote_ephemeral: Option, + handshake_meta: Option, + local_transport_params: TransportParams, + remote_transport_params: Option, +} + +impl XxHandshake { + pub fn new_initiator( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_qid: QID, + pairing_token: PairingToken, + local_transport_params: TransportParams, + ) -> Self { + Self { + role: Role::Initiator, + step: XxStep::Send1, + symmetric: init_xx_symmetric(crypto), + local, + remote_qid, + pairing_token, + remote_bundle: None, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn new_responder( + crypto: &impl QlCrypto, + local: QlIdentity, + remote_qid: QID, + pairing_token: PairingToken, + local_transport_params: TransportParams, + ) -> Self { + Self { + role: Role::Responder, + step: XxStep::Recv1, + symmetric: init_xx_symmetric(crypto), + local, + remote_qid, + pairing_token, + remote_bundle: None, + local_ephemeral: None, + remote_ephemeral: None, + handshake_meta: None, + local_transport_params, + remote_transport_params: None, + } + } + + pub fn is_finished(&self) -> bool { + self.step == XxStep::Done + } + + pub fn pairing_token(&self) -> PairingToken { + self.pairing_token + } + + pub fn pairing_id(&self, crypto: &impl QlCrypto) -> PairingId { + self.pairing_token.id(crypto) + } + + pub fn remote_qid(&self) -> QID { + self.remote_qid + } + + pub fn remote_bundle(&self) -> Option<&PeerBundle> { + self.remote_bundle.as_ref() + } + + fn header(&self) -> HandshakeHeader { + HandshakeHeader { + sender: self.local.qid, + recipient: self.remote_qid, + } + } + + fn ensure_inbound_header( + &self, + crypto: &impl QlCrypto, + header: HandshakeHeader, + pairing_id: PairingId, + ) -> Result<(), WireError> { + if header.sender != self.remote_qid || header.recipient != self.local.qid { + return Err(WireError::InvalidHandshakeHeader); + } + if pairing_id != self.pairing_token.id(crypto) { + return Err(WireError::InvalidPairingId); + } + Ok(()) + } + + fn ensure_remote_bundle(&self, bundle: &PeerBundle) -> Result<(), WireError> { + if bundle.qid == self.remote_qid { + Ok(()) + } else { + Err(WireError::InvalidRemoteBundle) + } + } + + pub fn write_1( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, meta)?; + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx1, + meta, + pairing_id, + self.local_transport_params, + ); + mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); + + let local_ephemeral = generate_ephemeral_keypair(crypto); + let ephemeral = local_ephemeral.public(); + mix_hash_ephemeral(&mut self.symmetric, crypto, &ephemeral); + + self.local_ephemeral = Some(local_ephemeral); + self.step = XxStep::Recv2; + Ok(Xx1 { + header, + meta, + pairing_id, + transport_params: self.local_transport_params, + ephemeral, + }) + } + + pub fn read_1(&mut self, crypto: &impl QlCrypto, message: &Xx1) -> Result<(), WireError> { + if self.step != XxStep::Recv1 { + return Err(WireError::InvalidState); + } + initialize_handshake_meta(&mut self.handshake_meta, message.meta)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx1, + message.meta, + message.pairing_id, + message.transport_params, + ); + mix_psk_pairing_token(&mut self.symmetric, crypto, self.pairing_token); + mix_hash_ephemeral(&mut self.symmetric, crypto, &message.ephemeral); + + self.remote_ephemeral = Some(message.ephemeral.clone()); + initialize_transport_params(&mut self.remote_transport_params, message.transport_params)?; + self.step = XxStep::Send2; + Ok(()) + } + + pub fn write_2( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx2, + meta, + pairing_id, + self.local_transport_params, + ); + + let remote_ephemeral = self + .remote_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + let (ekem_ciphertext, ekem_secret) = + crypto.mlkem_encapsulate(&remote_ephemeral.mlkem_public_key); + self.symmetric.mix_hash(crypto, ekem_ciphertext.as_bytes()); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + self.step = XxStep::Recv3; + Ok(Xx2 { + header, + meta, + pairing_id, + transport_params: self.local_transport_params, + ekem_ciphertext, + static_bundle, + }) + } + + pub fn read_2(&mut self, crypto: &impl QlCrypto, message: &Xx2) -> Result<(), WireError> { + if self.step != XxStep::Recv2 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx2, + message.meta, + message.pairing_id, + message.transport_params, + ); + + let local_ephemeral = self + .local_ephemeral + .as_ref() + .ok_or(WireError::InvalidState)?; + self.symmetric + .mix_hash(crypto, message.ekem_ciphertext.as_bytes()); + let ekem_secret = + crypto.mlkem_decapsulate(&local_ephemeral.mlkem.private, &message.ekem_ciphertext); + self.symmetric.mix_key(crypto, ekem_secret.as_bytes()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + self.ensure_remote_bundle(&remote_bundle)?; + self.remote_bundle = Some(remote_bundle); + initialize_transport_params(&mut self.remote_transport_params, message.transport_params)?; + self.step = XxStep::Send3; + Ok(()) + } + + pub fn write_3( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send3 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx3, + meta, + pairing_id, + self.local_transport_params, + ); + + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let static_bundle = encrypt_peer_bundle(crypto, &mut self.symmetric, &self.local.bundle())?; + + self.step = XxStep::Recv4; + Ok(Xx3 { + header, + meta, + pairing_id, + transport_params: self.local_transport_params, + skem_ciphertext, + static_bundle, + }) + } + + pub fn read_3(&mut self, crypto: &impl QlCrypto, message: &Xx3) -> Result<(), WireError> { + if self.step != XxStep::Recv3 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; + require_transport_params( + self.remote_transport_params.as_ref(), + message.transport_params, + )?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx3, + message.meta, + message.pairing_id, + message.transport_params, + ); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + let remote_bundle = + decrypt_peer_bundle(crypto, &mut self.symmetric, &message.static_bundle)?; + self.ensure_remote_bundle(&remote_bundle)?; + self.remote_bundle = Some(remote_bundle); + self.step = XxStep::Send4; + Ok(()) + } + + pub fn write_4( + &mut self, + crypto: &impl QlCrypto, + meta: HandshakeMeta, + ) -> Result { + if self.step != XxStep::Send4 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), meta)?; + let header = self.header(); + let pairing_id = self.pairing_token.id(crypto); + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + header, + HandshakeKind::Xx4, + meta, + pairing_id, + self.local_transport_params, + ); + + let remote_bundle = self.remote_bundle.as_ref().ok_or(WireError::InvalidState)?; + let (skem_ciphertext, skem_secret) = + crypto.mlkem_encapsulate(&remote_bundle.mlkem_public_key); + let skem_ciphertext = + encrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &skem_ciphertext)?; + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = XxStep::Done; + Ok(Xx4 { + header, + meta, + pairing_id, + transport_params: self.local_transport_params, + skem_ciphertext, + }) + } + + pub fn read_4(&mut self, crypto: &impl QlCrypto, message: &Xx4) -> Result<(), WireError> { + if self.step != XxStep::Recv4 { + return Err(WireError::InvalidState); + } + require_handshake_meta(self.handshake_meta.as_ref(), message.meta)?; + self.ensure_inbound_header(crypto, message.header, message.pairing_id)?; + require_transport_params( + self.remote_transport_params.as_ref(), + message.transport_params, + )?; + mix_hash_pairing_handshake( + &mut self.symmetric, + crypto, + message.header, + HandshakeKind::Xx4, + message.meta, + message.pairing_id, + message.transport_params, + ); + + let skem_ciphertext = + decrypt_mlkem_ciphertext(crypto, &mut self.symmetric, &message.skem_ciphertext)?; + let skem_secret = crypto.mlkem_decapsulate(&self.local.mlkem_private_key, &skem_ciphertext); + self.symmetric + .mix_key_and_hash(crypto, skem_secret.as_bytes()); + + self.step = XxStep::Done; + Ok(()) + } + + pub fn finalize(self, crypto: &impl QlCrypto) -> Result { + if !self.is_finished() { + return Err(WireError::InvalidState); + } + let remote_bundle = self.remote_bundle.ok_or(WireError::InvalidState)?; + let remote_transport_params = self + .remote_transport_params + .ok_or(WireError::InvalidState)?; + Ok(finalize_handshake( + crypto, + &self.symmetric, + self.role, + remote_bundle, + remote_transport_params, + )) + } +} diff --git a/ql-wire/src/header.rs b/ql-wire/src/header.rs new file mode 100644 index 0000000..88764c0 --- /dev/null +++ b/ql-wire/src/header.rs @@ -0,0 +1,121 @@ +use ::bytes::BufMut; + +use crate::{ + codec, ByteSlice, VarInt, VarIntBoundsExceeded, WireEncode, WireError, QL_WIRE_VERSION, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SessionHeader { + pub connection_id: ConnectionId, + pub seq: RecordSeq, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RecordSeq(pub VarInt); + +impl RecordSeq { + pub const MAX_ENCODED_LEN: usize = VarInt::MAX_SIZE; + + pub const fn from_u32(value: u32) -> Self { + Self(VarInt::from_u32(value)) + } + + pub fn from_u64(value: u64) -> Result { + Ok(Self(VarInt::from_u64(value)?)) + } + + pub const fn into_inner(self) -> u64 { + self.0.into_inner() + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct ConnectionId(pub [u8; Self::SIZE]); + +impl ConnectionId { + pub const SIZE: usize = 16; + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} + +impl codec::WireDecode for RecordSeq { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} + +impl WireEncode for RecordSeq { + fn encoded_len(&self) -> usize { + self.0.size() + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for ConnectionId { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::from_data(reader.decode()?)) + } +} + +impl WireEncode for ConnectionId { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl SessionHeader { + pub const MAX_ENCODED_LEN: usize = ConnectionId::SIZE + RecordSeq::MAX_ENCODED_LEN; + const AAD_DOMAIN: &[u8] = b"ql-wire:session-aad:v1"; + const AAD_RECORD_KIND_SESSION: u8 = 1; + + pub fn aad(&self) -> Vec { + let aad_len = Self::AAD_DOMAIN.len() + + size_of::() + + size_of::() + + ConnectionId::SIZE + + self.seq.encoded_len(); + let mut aad = Vec::with_capacity(aad_len); + aad.put_slice(Self::AAD_DOMAIN); + aad.put_u8(QL_WIRE_VERSION); + aad.put_u8(Self::AAD_RECORD_KIND_SESSION); + self.connection_id.encode(&mut aad); + self.seq.encode(&mut aad); + debug_assert_eq!(aad.len(), aad_len); + aad + } +} + +impl WireEncode for SessionHeader { + fn encoded_len(&self) -> usize { + ConnectionId::SIZE + self.seq.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.connection_id.encode(out); + self.seq.encode(out); + } +} + +impl codec::WireDecode for SessionHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + connection_id: reader.decode()?, + seq: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/identity.rs b/ql-wire/src/identity.rs new file mode 100644 index 0000000..bdc54b2 --- /dev/null +++ b/ql-wire/src/identity.rs @@ -0,0 +1,192 @@ +use std::ops::Deref; + +use crate::{ + codec, ByteSlice, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, QlCrypto, QlHash, VarInt, + WireEncode, WireError, QID, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PeerBundle { + pub version: u16, + pub qid: QID, + pub capabilities: u32, + pub mlkem_public_key: MlKemPublicKey, + pub name: QlName, +} + +impl PeerBundle { + pub const VERSION: u16 = 1; + pub const FIXED_WIRE_SIZE: usize = + size_of::() + QID::SIZE + size_of::() + MlKemPublicKey::SIZE; + pub const MAX_WIRE_SIZE: usize = Self::FIXED_WIRE_SIZE + VarInt::MAX_SIZE + QlName::MAX_LEN; + + pub fn qid_matches_public_key(&self, crypto: &impl QlHash) -> bool { + self.qid.matches_public_key(crypto, &self.mlkem_public_key) + } +} + +impl WireEncode for PeerBundle { + fn encoded_len(&self) -> usize { + Self::FIXED_WIRE_SIZE + self.name.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.version.encode(out); + self.qid.encode(out); + self.capabilities.encode(out); + self.mlkem_public_key.encode(out); + self.name.encode(out); + } +} + +impl codec::WireDecode for PeerBundle { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + version: reader.decode()?, + qid: reader.decode()?, + capabilities: reader.decode()?, + mlkem_public_key: reader.decode()?, + name: reader.decode()?, + }) + } +} + +#[derive(Debug, Clone)] +pub struct QlIdentity { + pub qid: QID, + pub mlkem_private_key: MlKemPrivateKey, + pub mlkem_public_key: MlKemPublicKey, + pub capabilities: u32, + pub name: QlName, +} + +impl QlIdentity { + pub const FIXED_WIRE_SIZE: usize = + QID::SIZE + MlKemPrivateKey::SIZE + MlKemPublicKey::SIZE + size_of::(); + pub const MAX_WIRE_SIZE: usize = Self::FIXED_WIRE_SIZE + VarInt::MAX_SIZE + QlName::MAX_LEN; + + pub fn new( + crypto: &impl QlHash, + mlkem_private_key: MlKemPrivateKey, + mlkem_public_key: MlKemPublicKey, + name: impl Into, + ) -> Result { + let name = QlName::new(name)?; + let qid = QID::derive(crypto, &mlkem_public_key); + Ok(Self { + qid, + mlkem_private_key, + mlkem_public_key, + capabilities: 0, + name, + }) + } + + #[must_use] + pub fn with_capabilities(mut self, capabilities: u32) -> Self { + self.capabilities = capabilities; + self + } + + pub fn with_name(mut self, name: impl Into) -> Result { + self.name = QlName::new(name)?; + Ok(self) + } + + pub fn bundle(&self) -> PeerBundle { + PeerBundle { + version: PeerBundle::VERSION, + qid: self.qid, + capabilities: self.capabilities, + mlkem_public_key: self.mlkem_public_key.clone(), + name: self.name.clone(), + } + } +} + +impl WireEncode for QlIdentity { + fn encoded_len(&self) -> usize { + Self::FIXED_WIRE_SIZE + self.name.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.qid.encode(out); + self.mlkem_private_key.as_bytes().encode(out); + self.mlkem_public_key.encode(out); + self.capabilities.encode(out); + self.name.encode(out); + } +} + +impl codec::WireDecode for QlIdentity { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + qid: reader.decode()?, + mlkem_private_key: MlKemPrivateKey::new(reader.decode()?), + mlkem_public_key: reader.decode()?, + capabilities: reader.decode()?, + name: reader.decode()?, + }) + } +} + +pub fn generate_identity( + crypto: &impl QlCrypto, + name: impl Into, +) -> Result { + let MlKemKeyPair { + private: mlkem_private_key, + public: mlkem_public_key, + } = crypto.mlkem_generate_keypair(); + QlIdentity::new(crypto, mlkem_private_key, mlkem_public_key, name) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QlName(String); + +impl QlName { + pub const MAX_LEN: usize = 256; + + pub fn new(name: impl Into) -> Result { + let name = name.into(); + if name.is_empty() || name.len() > Self::MAX_LEN { + return Err(WireError::InvalidPayload); + } + Ok(Self(name)) + } +} + +impl Deref for QlName { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl WireEncode for QlName { + fn encoded_len(&self) -> usize { + let len = VarInt::try_from(self.0.len()).unwrap(); + len.encoded_len() + self.0.len() + } + + fn encode(&self, out: &mut W) { + VarInt::try_from(self.0.len()) + .expect("identity name length fits in varint") + .encode(out); + self.0.as_bytes().encode(out); + } +} + +impl codec::WireDecode for QlName { + fn decode(reader: &mut codec::Reader) -> Result { + let len = usize::try_from(reader.decode::()?.into_inner()) + .map_err(|_| WireError::InvalidPayload)?; + if len == 0 || len > Self::MAX_LEN { + return Err(WireError::InvalidPayload); + } + let bytes = reader.take_bytes(len)?; + let name = std::str::from_utf8(&bytes).map_err(|_| WireError::InvalidPayload)?; + QlName::new(name) + } +} diff --git a/ql-wire/src/lib.rs b/ql-wire/src/lib.rs new file mode 100644 index 0000000..1713745 --- /dev/null +++ b/ql-wire/src/lib.rs @@ -0,0 +1,45 @@ +//! +//! QuantumLink protocol wire format +//! + +#![allow(clippy::too_many_arguments)] + +mod bytes; +mod codec; +mod crypto; +mod encrypted; +mod encrypted_message; +mod error; +mod handshake; +mod header; +mod identity; +mod nonce; +mod pq; +mod qid; +mod record; +#[cfg(any(feature = "test-utils", test))] +mod testing; +mod varint; + +pub use bytes::*; +pub use codec::*; +pub use crypto::*; +pub use encrypted::*; +pub use encrypted_message::*; +pub use error::*; +pub use handshake::*; +pub use header::*; +pub use identity::*; +pub use nonce::*; +pub use pq::*; +pub use qid::*; +pub use record::*; +#[cfg(any(feature = "test-utils", test))] +pub use testing::*; +pub use varint::*; + +pub const QL_WIRE_VERSION: u8 = 1; +pub const ENCRYPTED_MESSAGE_AUTH_SIZE: usize = 16; + +#[cfg(test)] +mod tests; diff --git a/ql-wire/src/nonce.rs b/ql-wire/src/nonce.rs new file mode 100644 index 0000000..c7e6d79 --- /dev/null +++ b/ql-wire/src/nonce.rs @@ -0,0 +1,13 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct Nonce(pub [u8; Self::SIZE]); + +impl Nonce { + pub const SIZE: usize = 12; + + pub fn from_counter(counter: u64) -> Self { + let mut nonce = [0u8; Self::SIZE]; + nonce[4..].copy_from_slice(&counter.to_le_bytes()); + Self(nonce) + } +} diff --git a/ql-wire/src/pq.rs b/ql-wire/src/pq.rs new file mode 100644 index 0000000..327ef7c --- /dev/null +++ b/ql-wire/src/pq.rs @@ -0,0 +1,159 @@ +use crate::{codec, ByteSlice, WireEncode, WireError}; + +pub const ML_KEM_SUITE_TAG: &[u8] = b"ml-kem-1024"; + +// ql-wire fixes the protocol to ML-KEM-1024 on the wire, but the host +// platform is free to satisfy QlKem with any backend that produces the same +// serialized sizes. +const ML_KEM_1024_SHARED_SECRET_SIZE: usize = 32; +const ML_KEM_1024_PUBLIC_KEY_SIZE: usize = 1568; +const ML_KEM_1024_PRIVATE_KEY_SIZE: usize = 3168; +const ML_KEM_1024_CIPHERTEXT_SIZE: usize = 1568; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SessionKey([u8; Self::SIZE]); + +impl SessionKey { + pub const SIZE: usize = ML_KEM_1024_SHARED_SECRET_SIZE; + + pub const fn from_data(data: [u8; Self::SIZE]) -> Self { + Self(data) + } + + pub const fn data(&self) -> &[u8; Self::SIZE] { + &self.0 + } + + pub const fn as_bytes(&self) -> &[u8; Self::SIZE] { + &self.0 + } +} + +impl AsRef<[u8]> for SessionKey { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl Drop for SessionKey { + fn drop(&mut self) { + self.0.fill(0); + } +} + +impl WireEncode for SessionKey { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for SessionKey { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::from_data(reader.decode()?)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemPublicKey(Box<[u8; MlKemPublicKey::SIZE]>); + +impl MlKemPublicKey { + pub const SIZE: usize = ML_KEM_1024_PUBLIC_KEY_SIZE; + + pub fn new(data: Box<[u8; Self::SIZE]>) -> Self { + Self(data) + } + + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() + } +} + +impl Drop for MlKemPublicKey { + fn drop(&mut self) { + self.0.as_mut().fill(0); + } +} + +impl codec::WireDecode for MlKemPublicKey { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) + } +} + +impl WireEncode for MlKemPublicKey { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemPrivateKey(Box<[u8; MlKemPrivateKey::SIZE]>); + +impl MlKemPrivateKey { + pub const SIZE: usize = ML_KEM_1024_PRIVATE_KEY_SIZE; + + pub fn new(data: Box<[u8; Self::SIZE]>) -> Self { + Self(data) + } + + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() + } +} + +impl Drop for MlKemPrivateKey { + fn drop(&mut self) { + self.0.as_mut().fill(0); + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemCiphertext(Box<[u8; MlKemCiphertext::SIZE]>); + +impl MlKemCiphertext { + pub const SIZE: usize = ML_KEM_1024_CIPHERTEXT_SIZE; + + pub fn new(data: Box<[u8; Self::SIZE]>) -> Self { + Self(data) + } + + pub fn as_bytes(&self) -> &[u8; Self::SIZE] { + self.0.as_ref() + } +} + +impl Drop for MlKemCiphertext { + fn drop(&mut self) { + self.0.as_mut().fill(0); + } +} + +impl codec::WireDecode for MlKemCiphertext { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self::new(reader.decode()?)) + } +} + +impl WireEncode for MlKemCiphertext { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.as_ref().encode(out); + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MlKemKeyPair { + pub private: MlKemPrivateKey, + pub public: MlKemPublicKey, +} diff --git a/ql-wire/src/qid.rs b/ql-wire/src/qid.rs new file mode 100644 index 0000000..55c6684 --- /dev/null +++ b/ql-wire/src/qid.rs @@ -0,0 +1,44 @@ +use crate::{codec, ByteSlice, MlKemPublicKey, QlHash, WireEncode, WireError, ML_KEM_SUITE_TAG}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct QID(pub [u8; Self::SIZE]); + +impl QID { + pub const SIZE: usize = 16; + + pub fn derive(crypto: &impl QlHash, mlkem_public_key: &MlKemPublicKey) -> Self { + let digest = crypto.sha256(&[ + b"quantum-link qid v1", + ML_KEM_SUITE_TAG, + mlkem_public_key.as_bytes(), + ]); + let mut qid = [0u8; Self::SIZE]; + qid.copy_from_slice(&digest[..Self::SIZE]); + Self(qid) + } + + pub fn matches_public_key( + &self, + crypto: &impl QlHash, + mlkem_public_key: &MlKemPublicKey, + ) -> bool { + *self == Self::derive(crypto, mlkem_public_key) + } +} + +impl WireEncode for QID { + fn encoded_len(&self) -> usize { + Self::SIZE + } + + fn encode(&self, out: &mut W) { + self.0.encode(out); + } +} + +impl codec::WireDecode for QID { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self(reader.decode()?)) + } +} diff --git a/ql-wire/src/record.rs b/ql-wire/src/record.rs new file mode 100644 index 0000000..163a1bf --- /dev/null +++ b/ql-wire/src/record.rs @@ -0,0 +1,254 @@ +use crate::{ + codec, + encrypted_message::EncryptedMessage, + handshake::{Ik1, Ik2, Kk1, Kk2, Xx1, Xx2, Xx3, Xx4}, + ByteSlice, SessionHeader, WireDecode, WireEncode, WireError, QL_WIRE_VERSION, +}; + +pub fn encode_record(out: &mut W, record_type: RecordType, body: &T) +where + W: bytes::BufMut + ?Sized, + T: WireEncode + ?Sized, +{ + RecordHeader { + version: QL_WIRE_VERSION, + record_type, + } + .encode(out); + body.encode(out); +} + +pub fn encode_record_vec(record_type: RecordType, body: &T) -> Vec { + let mut out = Vec::with_capacity(RecordHeader::WIRE_SIZE + body.encoded_len()); + encode_record(&mut out, record_type, body); + out +} + +pub fn decode_record(bytes: B) -> Result<(RecordHeader, T), WireError> +where + T: WireDecode, + B: ByteSlice, +{ + let mut reader = codec::Reader::new(bytes); + Ok((reader.decode()?, reader.decode()?)) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RecordHeader { + pub version: u8, + pub record_type: RecordType, +} + +impl RecordHeader { + pub const WIRE_SIZE: usize = size_of::() + size_of::(); +} + +impl WireDecode for RecordHeader { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + version: reader.decode()?, + record_type: reader.decode()?, + }) + } +} + +impl WireEncode for RecordHeader { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + out.put_u8(self.version); + self.record_type.encode(out); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum RecordType { + Handshake = 1, + Session = 2, +} + +impl TryFrom for RecordType { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Handshake), + 2 => Ok(Self::Session), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl WireDecode for RecordType { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +impl WireEncode for RecordType { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(*self as u8); + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlHandshakeRecord { + Ik1(Ik1), + Ik2(Ik2), + Kk1(Kk1), + Kk2(Kk2), + Xx1(Xx1), + Xx2(Xx2), + Xx3(Xx3), + Xx4(Xx4), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum HandshakeKind { + Ik1 = 1, + Ik2 = 2, + Kk1 = 3, + Kk2 = 4, + Xx1 = 5, + Xx2 = 6, + Xx3 = 7, + Xx4 = 8, +} + +impl TryFrom for HandshakeKind { + type Error = WireError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::Ik1), + 2 => Ok(Self::Ik2), + 3 => Ok(Self::Kk1), + 4 => Ok(Self::Kk2), + 5 => Ok(Self::Xx1), + 6 => Ok(Self::Xx2), + 7 => Ok(Self::Xx3), + 8 => Ok(Self::Xx4), + _ => Err(WireError::InvalidPayload), + } + } +} + +impl WireDecode for HandshakeKind { + fn decode(reader: &mut codec::Reader) -> Result { + reader.decode::()?.try_into() + } +} + +impl WireEncode for HandshakeKind { + fn encoded_len(&self) -> usize { + size_of::() + } + + fn encode(&self, out: &mut W) { + out.put_u8(*self as u8); + } +} + +impl QlHandshakeRecord { + pub fn kind(&self) -> HandshakeKind { + match self { + Self::Ik1(_) => HandshakeKind::Ik1, + Self::Ik2(_) => HandshakeKind::Ik2, + Self::Kk1(_) => HandshakeKind::Kk1, + Self::Kk2(_) => HandshakeKind::Kk2, + Self::Xx1(_) => HandshakeKind::Xx1, + Self::Xx2(_) => HandshakeKind::Xx2, + Self::Xx3(_) => HandshakeKind::Xx3, + Self::Xx4(_) => HandshakeKind::Xx4, + } + } +} + +impl WireEncode for QlHandshakeRecord { + fn encoded_len(&self) -> usize { + self.kind().encoded_len() + + match self { + Self::Ik1(message) => message.encoded_len(), + Self::Ik2(message) => message.encoded_len(), + Self::Kk1(message) => message.encoded_len(), + Self::Kk2(message) => message.encoded_len(), + Self::Xx1(message) => message.encoded_len(), + Self::Xx2(message) => message.encoded_len(), + Self::Xx3(message) => message.encoded_len(), + Self::Xx4(message) => message.encoded_len(), + } + } + + fn encode(&self, out: &mut W) { + self.kind().encode(out); + match self { + Self::Ik1(message) => message.encode(out), + Self::Ik2(message) => message.encode(out), + Self::Kk1(message) => message.encode(out), + Self::Kk2(message) => message.encode(out), + Self::Xx1(message) => message.encode(out), + Self::Xx2(message) => message.encode(out), + Self::Xx3(message) => message.encode(out), + Self::Xx4(message) => message.encode(out), + } + } +} + +impl WireDecode for QlHandshakeRecord { + fn decode(reader: &mut codec::Reader) -> Result { + let kind = reader.decode::()?; + match kind { + HandshakeKind::Ik1 => Ok(Self::Ik1(reader.decode()?)), + HandshakeKind::Ik2 => Ok(Self::Ik2(reader.decode()?)), + HandshakeKind::Kk1 => Ok(Self::Kk1(reader.decode()?)), + HandshakeKind::Kk2 => Ok(Self::Kk2(reader.decode()?)), + HandshakeKind::Xx1 => Ok(Self::Xx1(reader.decode()?)), + HandshakeKind::Xx2 => Ok(Self::Xx2(reader.decode()?)), + HandshakeKind::Xx3 => Ok(Self::Xx3(reader.decode()?)), + HandshakeKind::Xx4 => Ok(Self::Xx4(reader.decode()?)), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct QlSessionRecord { + pub header: SessionHeader, + pub payload: EncryptedMessage, +} + +impl> WireEncode for QlSessionRecord { + fn encoded_len(&self) -> usize { + self.header.encoded_len() + self.payload.encoded_len() + } + + fn encode(&self, out: &mut W) { + self.header.encode(out); + self.payload.encode(out); + } +} + +impl QlSessionRecord { + pub fn into_owned(self) -> QlSessionRecord> { + QlSessionRecord { + header: self.header, + payload: self.payload.into_owned(), + } + } +} + +impl WireDecode for QlSessionRecord { + fn decode(reader: &mut codec::Reader) -> Result { + Ok(Self { + header: reader.decode()?, + payload: reader.decode()?, + }) + } +} diff --git a/ql-wire/src/testing.rs b/ql-wire/src/testing.rs new file mode 100644 index 0000000..a1223c1 --- /dev/null +++ b/ql-wire/src/testing.rs @@ -0,0 +1,181 @@ +use libcrux_aesgcm::AesGcm256Key; +use libcrux_ml_kem::mlkem1024; +use sha2::{Digest, Sha256}; + +use crate::{ + MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey, Nonce, QlAead, QlCrypto, + QlHash, QlIdentity, QlKem, QlRandom, SessionKey, ENCRYPTED_MESSAGE_AUTH_SIZE, +}; + +#[derive(Debug, Default, Clone, Copy)] +pub struct SoftwareCrypto; + +#[derive(Debug, Default, Clone, Copy)] +pub struct NoopCrypto; + +pub fn test_identities(crypto: &impl QlCrypto) -> (QlIdentity, QlIdentity) { + ( + crate::generate_identity(crypto, "alice").unwrap(), + crate::generate_identity(crypto, "bob").unwrap(), + ) +} + +impl QlRandom for SoftwareCrypto { + fn fill_random_bytes(&self, out: &mut [u8]) { + getrandom::getrandom(out).unwrap(); + } +} + +impl QlHash for SoftwareCrypto { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { + let mut hasher = Sha256::new(); + for part in parts { + hasher.update(part); + } + hasher.finalize().into() + } +} + +impl QlAead for SoftwareCrypto { + fn aes256_gcm_encrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { + let key: AesGcm256Key = (*key.data()).into(); + let plaintext = buffer.to_vec(); + let mut auth = [0u8; ENCRYPTED_MESSAGE_AUTH_SIZE]; + key.encrypt( + buffer, + (&mut auth).into(), + (&nonce.0).into(), + aad, + &plaintext, + ) + .unwrap(); + auth + } + + fn aes256_gcm_decrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool { + let key: AesGcm256Key = (*key.data()).into(); + let ciphertext = buffer.to_vec(); + key.decrypt(buffer, (&nonce.0).into(), aad, &ciphertext, auth_tag.into()) + .is_ok() + } +} + +impl QlKem for SoftwareCrypto { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + let key_pair = mlkem1024::generate_key_pair(random_array(self)); + let mut public = [0u8; MlKemPublicKey::SIZE]; + public.copy_from_slice(key_pair.pk()); + let mut private = [0u8; MlKemPrivateKey::SIZE]; + private.copy_from_slice(key_pair.sk()); + + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new(private)), + public: MlKemPublicKey::new(Box::new(public)), + } + } + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + let public_key = public_key.as_bytes().into(); + let (ciphertext_value, shared_value) = + mlkem1024::encapsulate(&public_key, random_array(self)); + let mut ciphertext = [0u8; MlKemCiphertext::SIZE]; + ciphertext.copy_from_slice(ciphertext_value.as_slice()); + let mut shared = [0u8; SessionKey::SIZE]; + shared.copy_from_slice(shared_value.as_slice()); + ( + MlKemCiphertext::new(Box::new(ciphertext)), + SessionKey::from_data(shared), + ) + } + + fn mlkem_decapsulate( + &self, + private_key: &MlKemPrivateKey, + ciphertext: &MlKemCiphertext, + ) -> SessionKey { + let private_key = private_key.as_bytes().into(); + let ciphertext = ciphertext.as_bytes().into(); + let shared = mlkem1024::decapsulate(&private_key, &ciphertext); + let mut out = [0u8; SessionKey::SIZE]; + out.copy_from_slice(shared.as_slice()); + SessionKey::from_data(out) + } +} + +impl QlRandom for NoopCrypto { + fn fill_random_bytes(&self, out: &mut [u8]) { + out.fill(0); + } +} + +impl QlHash for NoopCrypto { + fn sha256(&self, _parts: &[&[u8]]) -> [u8; 32] { + [0; 32] + } +} + +impl QlAead for NoopCrypto { + fn aes256_gcm_encrypt( + &self, + _key: &SessionKey, + _nonce: &Nonce, + _aad: &[u8], + _buffer: &mut [u8], + ) -> [u8; ENCRYPTED_MESSAGE_AUTH_SIZE] { + [0; ENCRYPTED_MESSAGE_AUTH_SIZE] + } + + fn aes256_gcm_decrypt( + &self, + _key: &SessionKey, + _nonce: &Nonce, + _aad: &[u8], + _buffer: &mut [u8], + _auth_tag: &[u8; ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool { + false + } +} + +impl QlKem for NoopCrypto { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + MlKemKeyPair { + private: MlKemPrivateKey::new(Box::new([0; MlKemPrivateKey::SIZE])), + public: MlKemPublicKey::new(Box::new([0; MlKemPublicKey::SIZE])), + } + } + + fn mlkem_encapsulate(&self, _public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + ( + MlKemCiphertext::new(Box::new([0; MlKemCiphertext::SIZE])), + SessionKey::from_data([0; SessionKey::SIZE]), + ) + } + + fn mlkem_decapsulate( + &self, + _private_key: &MlKemPrivateKey, + _ciphertext: &MlKemCiphertext, + ) -> SessionKey { + SessionKey::from_data([0; SessionKey::SIZE]) + } +} + +fn random_array(crypto: &impl QlRandom) -> [u8; L] { + let mut out = [0u8; L]; + crypto.fill_random_bytes(&mut out); + out +} diff --git a/ql-wire/src/tests.rs b/ql-wire/src/tests.rs new file mode 100644 index 0000000..a09a36b --- /dev/null +++ b/ql-wire/src/tests.rs @@ -0,0 +1,1017 @@ +use std::ops::RangeInclusive; + +use super::*; + +fn decode_handshake_record(bytes: &[u8]) -> QlHandshakeRecord { + decode_record(bytes).unwrap().1 +} + +fn decode_session_record(bytes: &[u8]) -> QlSessionRecord> { + let (_, record) = decode_record::, _>(bytes).unwrap(); + record.into_owned() +} + +fn qid(byte: u8) -> QID { + QID([byte; QID::SIZE]) +} + +fn varint(value: u64) -> VarInt { + VarInt::from_u64(value).unwrap() +} + +fn record_seq(value: u64) -> RecordSeq { + RecordSeq(varint(value)) +} + +fn record_ack_range(start: u64, end: u64) -> RangeInclusive { + record_seq(start)..=record_seq(end) +} + +fn stream_id(value: u64) -> StreamId { + StreamId(varint(value)) +} + +fn handshake_meta(id: u32) -> HandshakeMeta { + HandshakeMeta { + handshake_id: HandshakeId(id), + } +} + +fn handshake_transport_params(window: u32) -> TransportParams { + TransportParams { + initial_stream_receive_window: window, + } +} + +fn handshake_header(sender: u8, recipient: u8) -> HandshakeHeader { + HandshakeHeader { + sender: qid(sender), + recipient: qid(recipient), + } +} + +fn pairing_token(byte: u8) -> PairingToken { + PairingToken([byte; PairingToken::SIZE]) +} + +fn pairing_id(byte: u8) -> PairingId { + PairingId([byte; PairingId::SIZE]) +} + +fn xx_header(sender: u8, recipient: u8) -> HandshakeHeader { + HandshakeHeader { + sender: qid(sender), + recipient: qid(recipient), + } +} + +fn encrypt_record( + crypto: &impl QlCrypto, + header: SessionHeader, + session_key: &SessionKey, + body: &[SessionFrame>], +) -> QlSessionRecord> { + let mut builder = SessionRecordBuilder::new(header.seq, usize::MAX); + for frame in body { + let pushed = builder.push_frame(frame); + debug_assert!(pushed); + } + decode_session_record( + builder + .encrypt(crypto, header.connection_id, session_key) + .as_slice(), + ) +} + +#[test] +fn peer_bundle_round_trip() { + let crypto = SoftwareCrypto; + let identity = generate_identity(&crypto, "alice") + .unwrap() + .with_capabilities(0x55aa_33cc); + let bundle = identity.bundle(); + + let encoded = bundle.encode_vec(); + let decoded = PeerBundle::decode_exact(encoded.as_slice()).unwrap(); + + assert_eq!(decoded, bundle); + assert_eq!(&*decoded.name, "alice"); +} + +#[test] +fn identity_name_validation() { + assert_eq!( + QlName::new("a".repeat(QlName::MAX_LEN)).unwrap().len(), + QlName::MAX_LEN + ); + assert!(matches!(QlName::new(""), Err(WireError::InvalidPayload))); + assert!(matches!( + QlName::new("a".repeat(QlName::MAX_LEN + 1)), + Err(WireError::InvalidPayload) + )); +} + +#[test] +fn qid_derives_from_mlkem_public_key() { + let crypto = SoftwareCrypto; + let public_key = MlKemPublicKey::new(Box::new([42; MlKemPublicKey::SIZE])); + let qid = QID::derive(&crypto, &public_key); + + let digest = crypto.sha256(&[ + b"quantum-link qid v1", + ML_KEM_SUITE_TAG, + public_key.as_bytes(), + ]); + let mut expected = [0u8; QID::SIZE]; + expected.copy_from_slice(&digest[..QID::SIZE]); + + assert_eq!(qid, QID(expected)); + assert!(qid.matches_public_key(&crypto, &public_key)); +} + +#[test] +fn qid_changes_when_mlkem_public_key_changes() { + let crypto = SoftwareCrypto; + let first = MlKemPublicKey::new(Box::new([1; MlKemPublicKey::SIZE])); + let second = MlKemPublicKey::new(Box::new([2; MlKemPublicKey::SIZE])); + + assert_ne!(QID::derive(&crypto, &first), QID::derive(&crypto, &second)); +} + +#[test] +fn peer_bundle_detects_tampered_qid() { + let crypto = SoftwareCrypto; + let identity = generate_identity(&crypto, "alice").unwrap(); + let mut bundle = identity.bundle(); + + bundle.qid = qid(9); + + assert!(!bundle.qid_matches_public_key(&crypto)); +} + +#[test] +fn handshake_record_round_trip_supports_ik_kk_and_xx() { + let ik = QlHandshakeRecord::Ik1(Ik1 { + header: handshake_header(1, 2), + meta: handshake_meta(1), + transport_params: handshake_transport_params(65_536), + skem_ciphertext: MlKemCiphertext::new(Box::new([7; MlKemCiphertext::SIZE])), + ephemeral: EphemeralPublicKey { + mlkem_public_key: MlKemPublicKey::new(Box::new([9; MlKemPublicKey::SIZE])), + }, + static_bundle: EncryptedPeerBundle(vec![13; 64].into_boxed_slice()), + }); + let ik_encoded = encode_record_vec(RecordType::Handshake, &ik); + assert_eq!( + RecordHeader::decode_bytes(ik_encoded.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Handshake, + } + ); + assert_eq!(decode_handshake_record(ik_encoded.as_slice()), ik); + + let kk = QlHandshakeRecord::Kk1(Kk1 { + header: handshake_header(1, 2), + meta: handshake_meta(2), + transport_params: handshake_transport_params(131_072), + skem_ciphertext: MlKemCiphertext::new(Box::new([11; MlKemCiphertext::SIZE])), + ephemeral: EphemeralPublicKey { + mlkem_public_key: MlKemPublicKey::new(Box::new([15; MlKemPublicKey::SIZE])), + }, + }); + let kk_encoded = encode_record_vec(RecordType::Handshake, &kk); + assert_eq!( + RecordHeader::decode_bytes(kk_encoded.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Handshake, + } + ); + assert_eq!(decode_handshake_record(kk_encoded.as_slice()), kk); + + let xx = QlHandshakeRecord::Xx1(Xx1 { + header: xx_header(1, 2), + meta: handshake_meta(3), + pairing_id: pairing_id(3), + transport_params: handshake_transport_params(196_608), + ephemeral: EphemeralPublicKey { + mlkem_public_key: MlKemPublicKey::new(Box::new([17; MlKemPublicKey::SIZE])), + }, + }); + let xx_encoded = encode_record_vec(RecordType::Handshake, &xx); + assert_eq!( + RecordHeader::decode_bytes(xx_encoded.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Handshake, + } + ); + assert_eq!(decode_handshake_record(xx_encoded.as_slice()), xx); +} + +#[test] +fn ik_handshake_rejects_tampered_handshake_meta() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, TransportParams::default()); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(77)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(77)) + .unwrap(); + m2.meta.handshake_id = HandshakeId(78); + + assert_eq!( + initiator_state.read_2(&crypto, &m2), + Err(WireError::InvalidHandshakeMeta) + ); +} + +#[test] +fn kk_handshake_rejects_tampered_handshake_header() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = KkHandshake::new_responder( + &crypto, + responder, + initiator.bundle(), + TransportParams::default(), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(88)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(88)) + .unwrap(); + m2.header = handshake_header(9, 1); + + assert_eq!( + initiator_state.read_2(&crypto, &m2), + Err(WireError::InvalidPayload) + ); +} + +#[test] +fn ik_handshake_rejects_tampered_transport_params() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + handshake_transport_params(4096), + ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, handshake_transport_params(8192)); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(89)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(89)) + .unwrap(); + m2.transport_params.initial_stream_receive_window += 1; + + assert_eq!( + initiator_state.read_2(&crypto, &m2), + Err(WireError::DecryptFailed) + ); +} + +#[test] +fn ik_handshake_rejects_tampered_handshake_header() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder, None, TransportParams::default()); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(90)) + .unwrap(); + m1.header.sender = qid(9); + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::DecryptFailed) + ); +} + +#[test] +fn ik_handshake_rejects_bound_remote_bundle_mismatch() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let bogus = generate_identity(&crypto, "bogus").unwrap(); + + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator, + responder.bundle(), + TransportParams::default(), + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder, + Some(bogus.bundle()), + TransportParams::default(), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(91)) + .unwrap(); + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidPayload) + ); +} + +#[test] +fn ik_handshake_round_trip_derives_matching_transport_and_learns_remote() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let initiator_params = handshake_transport_params(4096); + let responder_params = handshake_transport_params(8192); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + initiator_params, + ); + let mut responder_state = + IkHandshake::new_responder(&crypto, responder.clone(), None, responder_params); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(11)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(11)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + +#[test] +fn ik_handshake_round_trip_derives_matching_transport_with_bound_responder() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let initiator_params = handshake_transport_params(16_384); + let responder_params = handshake_transport_params(32_768); + let mut initiator_state = IkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + initiator_params, + ); + let mut responder_state = IkHandshake::new_responder( + &crypto, + responder.clone(), + Some(initiator.bundle()), + responder_params, + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(12)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(12)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + +#[test] +fn kk_handshake_round_trip_derives_matching_transport() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let initiator_params = handshake_transport_params(24_576); + let responder_params = handshake_transport_params(49_152); + let mut initiator_state = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + initiator_params, + ); + let mut responder_state = KkHandshake::new_responder( + &crypto, + responder.clone(), + initiator.bundle(), + responder_params, + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(21)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(21)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + +#[test] +fn kk_handshake_rejects_tampered_transport_params() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut initiator_state = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + handshake_transport_params(12288), + ); + let mut responder_state = KkHandshake::new_responder( + &crypto, + responder, + initiator.bundle(), + handshake_transport_params(24576), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(22)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let mut m2 = responder_state + .write_2(&crypto, handshake_meta(22)) + .unwrap(); + m2.transport_params.initial_stream_receive_window += 1; + + assert_eq!( + initiator_state.read_2(&crypto, &m2), + Err(WireError::DecryptFailed) + ); +} + +#[test] +fn xx_handshake_rejects_tampered_pairing_id() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let token = pairing_token(7); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + TransportParams::default(), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder, + initiator.qid, + token, + TransportParams::default(), + ); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(31)) + .unwrap(); + m1.pairing_id = pairing_id(8); + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidPairingId) + ); +} + +#[test] +fn xx_handshake_rejects_tampered_sender_or_recipient() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let token = pairing_token(7); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + TransportParams::default(), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.qid, + token, + TransportParams::default(), + ); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(31)) + .unwrap(); + m1.header.sender = responder.qid; + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidHandshakeHeader) + ); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + TransportParams::default(), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.qid, + token, + TransportParams::default(), + ); + + let mut m1 = initiator_state + .write_1(&crypto, handshake_meta(31)) + .unwrap(); + m1.header.recipient = initiator.qid; + + assert_eq!( + responder_state.read_1(&crypto, &m1), + Err(WireError::InvalidHandshakeHeader) + ); +} + +#[test] +fn xx_handshake_rejects_repeated_transport_param_change() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let token = pairing_token(9); + + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + handshake_transport_params(12_288), + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder, + initiator.qid, + token, + handshake_transport_params(24_576), + ); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(32)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(32)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + + let mut m3 = initiator_state + .write_3(&crypto, handshake_meta(32)) + .unwrap(); + m3.transport_params.initial_stream_receive_window += 1; + + assert_eq!( + responder_state.read_3(&crypto, &m3), + Err(WireError::InvalidTransportParams) + ); +} + +#[test] +fn xx_handshake_round_trip_derives_matching_transport_and_learns_remote() { + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + let token = pairing_token(10); + + let initiator_params = handshake_transport_params(28_672); + let responder_params = handshake_transport_params(57_344); + let mut initiator_state = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + initiator_params, + ); + let mut responder_state = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.qid, + token, + responder_params, + ); + + assert_eq!(initiator_state.pairing_token(), token); + assert_eq!(responder_state.pairing_token(), token); + assert_eq!(initiator_state.pairing_id(&crypto), token.id(&crypto)); + assert_eq!(responder_state.pairing_id(&crypto), token.id(&crypto)); + assert!(initiator_state.remote_bundle().is_none()); + assert!(responder_state.remote_bundle().is_none()); + + let m1 = initiator_state + .write_1(&crypto, handshake_meta(33)) + .unwrap(); + responder_state.read_1(&crypto, &m1).unwrap(); + + let m2 = responder_state + .write_2(&crypto, handshake_meta(33)) + .unwrap(); + initiator_state.read_2(&crypto, &m2).unwrap(); + assert_eq!(initiator_state.remote_bundle(), Some(&responder.bundle())); + assert!(responder_state.remote_bundle().is_none()); + + let m3 = initiator_state + .write_3(&crypto, handshake_meta(33)) + .unwrap(); + responder_state.read_3(&crypto, &m3).unwrap(); + assert_eq!(responder_state.remote_bundle(), Some(&initiator.bundle())); + + let m4 = responder_state + .write_4(&crypto, handshake_meta(33)) + .unwrap(); + initiator_state.read_4(&crypto, &m4).unwrap(); + + let initiator_final = initiator_state.finalize(&crypto).unwrap(); + let responder_final = responder_state.finalize(&crypto).unwrap(); + + assert_eq!( + initiator_final.handshake_hash, + responder_final.handshake_hash + ); + assert_eq!(initiator_final.tx_key, responder_final.rx_key); + assert_eq!(initiator_final.rx_key, responder_final.tx_key); + assert_eq!( + initiator_final.tx_connection_id, + responder_final.rx_connection_id + ); + assert_eq!( + initiator_final.rx_connection_id, + responder_final.tx_connection_id + ); + assert_eq!(initiator_final.remote_bundle, responder.bundle()); + assert_eq!(responder_final.remote_bundle, initiator.bundle()); + assert_eq!(initiator_final.remote_transport_params, responder_params); + assert_eq!(responder_final.remote_transport_params, initiator_params); +} + +#[test] +fn encrypted_session_record_round_trip_uses_connection_id_header() { + let crypto = SoftwareCrypto; + let header = SessionHeader { + connection_id: ConnectionId::from_data([0x44; ConnectionId::SIZE]), + seq: record_seq(11), + }; + let body = vec![ + SessionFrame::Ping, + SessionFrame::Unpair, + SessionFrame::Ack( + RecordAck::from_ranges([record_ack_range(20, 23), record_ack_range(12, 13)]).unwrap(), + ), + SessionFrame::StreamWindow(StreamWindow { + stream_id: stream_id(9), + maximum_offset: varint(65_536), + }), + SessionFrame::StreamData(StreamData { + stream_id: stream_id(9), + offset: varint(1024), + header: None, + bytes: b"hello".to_vec(), + fin: true, + }), + SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(9), + target: CloseTarget::Both, + code: StreamCloseCode::CANCELLED, + }), + SessionFrame::Close(SessionClose { + code: SessionCloseCode::TIMEOUT, + }), + ]; + let session_key = SessionKey::from_data([7; SessionKey::SIZE]); + let record = encrypt_record(&crypto, header, &session_key, &body); + + let bytes = encode_record_vec(RecordType::Session, &record); + assert_eq!( + RecordHeader::decode_bytes(bytes.as_slice()).unwrap(), + RecordHeader { + version: QL_WIRE_VERSION, + record_type: RecordType::Session, + } + ); + let decoded = decode_session_record(bytes.as_slice()); + assert_eq!(decoded.header, header); + let encrypted = decoded.payload; + + let decrypted = + encrypted::decrypt_record(&crypto, &header, encrypted.clone(), &session_key).unwrap(); + assert_eq!(decode_session_frames(&decrypted).unwrap(), body); + + let wrong_header = SessionHeader { + connection_id: ConnectionId::from_data([0x99; ConnectionId::SIZE]), + seq: header.seq, + }; + assert_eq!( + encrypted::decrypt_record(&crypto, &wrong_header, encrypted.clone(), &session_key), + Err(WireError::DecryptFailed) + ); + + let wrong_seq_header = SessionHeader { + connection_id: header.connection_id, + seq: record_seq(header.seq.into_inner() + 1), + }; + assert_eq!( + encrypted::decrypt_record(&crypto, &wrong_seq_header, encrypted, &session_key), + Err(WireError::DecryptFailed) + ); +} + +#[test] +fn session_varint_fields_expand_at_expected_boundaries() { + let short_header = SessionHeader { + connection_id: ConnectionId::from_data([0x11; ConnectionId::SIZE]), + seq: record_seq(63), + }; + let long_header = SessionHeader { + connection_id: ConnectionId::from_data([0x11; ConnectionId::SIZE]), + seq: record_seq(64), + }; + + assert_eq!(short_header.encode_vec().len(), ConnectionId::SIZE + 1); + assert_eq!(long_header.encode_vec().len(), ConnectionId::SIZE + 2); + + let frame = StreamData { + stream_id: stream_id(64), + offset: varint(16_384), + header: None, + fin: true, + bytes: b"abc".to_vec(), + }; + let encoded = frame.encode_vec(); + + assert_eq!( + StreamData::decode_exact(encoded.as_slice()) + .unwrap() + .into_owned(), + frame + ); +} + +#[test] +fn protocol_record_size_breakdown() { + fn print_size(label: &str, size: usize) { + println!("{label:<32}: {size} bytes"); + } + + let crypto = SoftwareCrypto; + let (initiator, responder) = test_identities(&crypto); + + let mut ik_initiator = IkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); + let mut ik_responder = + IkHandshake::new_responder(&crypto, responder.clone(), None, TransportParams::default()); + + let ik1 = ik_initiator.write_1(&crypto, handshake_meta(101)).unwrap(); + ik_responder.read_1(&crypto, &ik1).unwrap(); + + let ik2 = ik_responder.write_2(&crypto, handshake_meta(101)).unwrap(); + ik_initiator.read_2(&crypto, &ik2).unwrap(); + + let ik1 = QlHandshakeRecord::Ik1(ik1); + let ik2 = QlHandshakeRecord::Ik2(ik2); + + let mut kk_initiator = KkHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.bundle(), + TransportParams::default(), + ); + let mut kk_responder = KkHandshake::new_responder( + &crypto, + responder.clone(), + initiator.bundle(), + TransportParams::default(), + ); + + let kk1 = kk_initiator.write_1(&crypto, handshake_meta(201)).unwrap(); + kk_responder.read_1(&crypto, &kk1).unwrap(); + + let kk2 = kk_responder.write_2(&crypto, handshake_meta(201)).unwrap(); + kk_initiator.read_2(&crypto, &kk2).unwrap(); + + let kk1 = QlHandshakeRecord::Kk1(kk1); + let kk2 = QlHandshakeRecord::Kk2(kk2); + + let token = pairing_token(0x42); + let mut xx_initiator = XxHandshake::new_initiator( + &crypto, + initiator.clone(), + responder.qid, + token, + TransportParams::default(), + ); + let mut xx_responder = XxHandshake::new_responder( + &crypto, + responder.clone(), + initiator.qid, + token, + TransportParams::default(), + ); + + let xx1 = xx_initiator.write_1(&crypto, handshake_meta(301)).unwrap(); + xx_responder.read_1(&crypto, &xx1).unwrap(); + + let xx2 = xx_responder.write_2(&crypto, handshake_meta(301)).unwrap(); + xx_initiator.read_2(&crypto, &xx2).unwrap(); + + let xx3 = xx_initiator.write_3(&crypto, handshake_meta(301)).unwrap(); + xx_responder.read_3(&crypto, &xx3).unwrap(); + + let xx4 = xx_responder.write_4(&crypto, handshake_meta(301)).unwrap(); + xx_initiator.read_4(&crypto, &xx4).unwrap(); + + let xx1 = QlHandshakeRecord::Xx1(xx1); + let xx2 = QlHandshakeRecord::Xx2(xx2); + let xx3 = QlHandshakeRecord::Xx3(xx3); + let xx4 = QlHandshakeRecord::Xx4(xx4); + + let session = ik_initiator.finalize(&crypto).unwrap(); + let session_ping = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(1), + }, + &session.tx_key, + &[SessionFrame::Ping], + ); + let session_ack = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(2), + }, + &session.tx_key, + &[SessionFrame::Ack( + RecordAck::from_ranges([record_ack_range(6, 6), record_ack_range(1, 2)]).unwrap(), + )], + ); + let session_unpair = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(3), + }, + &session.tx_key, + &[SessionFrame::Unpair], + ); + let session_stream_empty = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(4), + }, + &session.tx_key, + &[SessionFrame::StreamData(StreamData { + stream_id: stream_id(1), + offset: varint(0), + header: None, + fin: false, + bytes: Vec::new(), + })], + ); + let session_close = encrypt_record( + &crypto, + SessionHeader { + connection_id: session.tx_connection_id, + seq: record_seq(5), + }, + &session.tx_key, + &[SessionFrame::Close(SessionClose { + code: SessionCloseCode::PROTOCOL, + })], + ); + + print_size("ql-wire peer bundle", initiator.bundle().encode_vec().len()); + print_size("ql-wire mlkem public key", MlKemPublicKey::SIZE); + print_size("ql-wire mlkem ciphertext", MlKemCiphertext::SIZE); + print_size("ql-wire pq ik1", ik1.encode_vec().len()); + print_size("ql-wire pq ik2", ik2.encode_vec().len()); + print_size("ql-wire pq kk1", kk1.encode_vec().len()); + print_size("ql-wire pq kk2", kk2.encode_vec().len()); + print_size("ql-wire pq xx1", xx1.encode_vec().len()); + print_size("ql-wire pq xx2", xx2.encode_vec().len()); + print_size("ql-wire pq xx3", xx3.encode_vec().len()); + print_size("ql-wire pq xx4", xx4.encode_vec().len()); + print_size("ql-wire session ping", session_ping.encode_vec().len()); + print_size("ql-wire session ack", session_ack.encode_vec().len()); + print_size("ql-wire session unpair", session_unpair.encode_vec().len()); + print_size( + "ql-wire session stream empty", + session_stream_empty.encode_vec().len(), + ); + print_size("ql-wire session close", session_close.encode_vec().len()); +} diff --git a/ql-wire/src/varint.rs b/ql-wire/src/varint.rs new file mode 100644 index 0000000..7a39bd1 --- /dev/null +++ b/ql-wire/src/varint.rs @@ -0,0 +1,181 @@ +use core::fmt; + +use bytes::BufMut; + +use crate::{ByteSlice, Reader, WireDecode, WireEncode, WireError}; + +/// An integer less than 2^62 encoded with QUIC variable-length integer rules. +#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct VarInt(pub(crate) u64); + +impl VarInt { + /// The largest representable value. + pub const MAX: Self = Self((1u64 << 62) - 1); + /// The largest encoded value length. + pub const MAX_SIZE: usize = 8; + pub const MIN_SIZE: usize = 1; + + /// Construct a `VarInt` infallibly from a `u32`. + pub const fn from_u32(x: u32) -> Self { + Self(x as u64) + } + + /// Construct a `VarInt` from a `u64`. + pub fn from_u64(x: u64) -> Result { + if x < (1u64 << 62) { + Ok(Self(x)) + } else { + Err(VarIntBoundsExceeded) + } + } + + /// Create a `VarInt` without checking the bounds. + /// + /// # Safety + /// + /// `x` must be less than 2^62. + pub const unsafe fn from_u64_unchecked(x: u64) -> Self { + Self(x) + } + + /// Extract the inner integer value. + pub const fn into_inner(self) -> u64 { + self.0 + } + + /// Return the number of bytes required to encode this value. + pub const fn size(self) -> usize { + let x = self.0; + if x < (1u64 << 6) { + 1 + } else if x < (1u64 << 14) { + 2 + } else if x < (1u64 << 30) { + 4 + } else { + 8 + } + } +} + +impl From for u64 { + fn from(value: VarInt) -> Self { + value.0 + } +} + +impl From for VarInt { + fn from(value: u8) -> Self { + Self(value.into()) + } +} + +impl From for VarInt { + fn from(value: u16) -> Self { + Self(value.into()) + } +} + +impl From for VarInt { + fn from(value: u32) -> Self { + Self(value.into()) + } +} + +impl TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + + fn try_from(value: u64) -> Result { + Self::from_u64(value) + } +} + +impl TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + + fn try_from(value: u128) -> Result { + Self::from_u64(value.try_into().map_err(|_| VarIntBoundsExceeded)?) + } +} + +impl TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + + fn try_from(value: usize) -> Result { + Self::from_u64(value as u64) + } +} + +impl fmt::Debug for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct VarIntBoundsExceeded; + +impl fmt::Display for VarIntBoundsExceeded { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("value too large for varint encoding") + } +} + +impl std::error::Error for VarIntBoundsExceeded {} + +impl WireDecode for VarInt { + fn decode(reader: &mut Reader) -> Result { + let first = reader.decode::()?; + let tag = first >> 6; + let first = first & 0b0011_1111; + let value = match tag { + 0b00 => u64::from(first), + 0b01 => { + let mut buf = [0; 2]; + buf[0] = first; + buf[1] = reader.decode()?; + u64::from(u16::from_be_bytes(buf)) + } + 0b10 => { + let mut buf = [0; 4]; + buf[0] = first; + buf[1..].copy_from_slice(&reader.take_bytes(3)?); + u64::from(u32::from_be_bytes(buf)) + } + 0b11 => { + let mut buf = [0; 8]; + buf[0] = first; + buf[1..].copy_from_slice(&reader.take_bytes(7)?); + u64::from_be_bytes(buf) + } + _ => unreachable!(), + }; + + // SAFETY: the decoded value is guaranteed to fit in the 62-bit varint range. + Ok(unsafe { Self::from_u64_unchecked(value) }) + } +} + +impl WireEncode for VarInt { + fn encoded_len(&self) -> usize { + self.size() + } + + #[allow(clippy::cast_possible_truncation)] + fn encode(&self, out: &mut W) { + let x = self.into_inner(); + match self.size() { + 1 => out.put_u8(x as u8), + 2 => out.put_u16((0b01 << 14) | x as u16), + 4 => out.put_u32((0b10 << 30) | x as u32), + 8 => out.put_u64((0b11 << 62) | x), + _ => unreachable!("malformed varint"), + } + } +} From 1a09d0a5bd6ca907a18d0fa74a54aff13c9cf2ae Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 4 Jun 2026 09:24:20 -0400 Subject: [PATCH 3/6] ql-fsm: add synchronous protocol state machine --- Cargo.lock | 144 ++- Cargo.toml | 2 + ql-fsm/Cargo.toml | 15 + ql-fsm/src/error.rs | 124 +++ ql-fsm/src/fsm.rs | 267 +++++ ql-fsm/src/handshake/ik.rs | 119 +++ ql-fsm/src/handshake/kk.rs | 118 +++ ql-fsm/src/handshake/mod.rs | 140 +++ ql-fsm/src/handshake/xx.rs | 207 ++++ ql-fsm/src/lib.rs | 334 ++++++ ql-fsm/src/pairing.rs | 38 + ql-fsm/src/session/ack_tracker.rs | 266 +++++ ql-fsm/src/session/mod.rs | 1041 +++++++++++++++++++ ql-fsm/src/session/range_set.rs | 221 ++++ ql-fsm/src/session/remote_stream_history.rs | 60 ++ ql-fsm/src/session/state.rs | 140 +++ ql-fsm/src/session/stream_ops.rs | 147 +++ ql-fsm/src/session/stream_parity.rs | 44 + ql-fsm/src/session/stream_rx.rs | 428 ++++++++ ql-fsm/src/session/stream_tx.rs | 579 +++++++++++ ql-fsm/src/session/tests.rs | 869 ++++++++++++++++ ql-fsm/src/session/tracked.rs | 29 + ql-fsm/src/state.rs | 139 +++ ql-fsm/src/tests/handshake.rs | 388 +++++++ ql-fsm/src/tests/mod.rs | 351 +++++++ ql-fsm/src/tests/proptest.rs | 1001 ++++++++++++++++++ ql-fsm/src/tests/session.rs | 532 ++++++++++ 27 files changed, 7741 insertions(+), 2 deletions(-) create mode 100644 ql-fsm/Cargo.toml create mode 100644 ql-fsm/src/error.rs create mode 100644 ql-fsm/src/fsm.rs create mode 100644 ql-fsm/src/handshake/ik.rs create mode 100644 ql-fsm/src/handshake/kk.rs create mode 100644 ql-fsm/src/handshake/mod.rs create mode 100644 ql-fsm/src/handshake/xx.rs create mode 100644 ql-fsm/src/lib.rs create mode 100644 ql-fsm/src/pairing.rs create mode 100644 ql-fsm/src/session/ack_tracker.rs create mode 100644 ql-fsm/src/session/mod.rs create mode 100644 ql-fsm/src/session/range_set.rs create mode 100644 ql-fsm/src/session/remote_stream_history.rs create mode 100644 ql-fsm/src/session/state.rs create mode 100644 ql-fsm/src/session/stream_ops.rs create mode 100644 ql-fsm/src/session/stream_parity.rs create mode 100644 ql-fsm/src/session/stream_rx.rs create mode 100644 ql-fsm/src/session/stream_tx.rs create mode 100644 ql-fsm/src/session/tests.rs create mode 100644 ql-fsm/src/session/tracked.rs create mode 100644 ql-fsm/src/state.rs create mode 100644 ql-fsm/src/tests/handshake.rs create mode 100644 ql-fsm/src/tests/mod.rs create mode 100644 ql-fsm/src/tests/proptest.rs create mode 100644 ql-fsm/src/tests/session.rs diff --git a/Cargo.lock b/Cargo.lock index c2c3b23..016c88c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -293,6 +293,21 @@ dependencies = [ "thiserror", ] +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bitcoin-io" version = "0.1.3" @@ -326,9 +341,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.9.3" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34efbcccd345379ca2868b2b2c9d3782e9cc58ba87bc7d79d5b53d9c9ae6f25d" +checksum = "84d7ced0ae9557296835c32bf1b1e02b44c746701f898460fb000d7eaa84f00a" [[package]] name = "blake2" @@ -820,6 +835,22 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + [[package]] name = "ff" version = "0.13.1" @@ -878,6 +909,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -1488,6 +1525,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.0" @@ -1983,6 +2026,25 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744" +dependencies = [ + "bit-set", + "bit-vec", + "bitflags", + "num-traits", + "rand 0.9.2", + "rand_chacha 0.9.0", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + [[package]] name = "provenance-mark" version = "0.16.0" @@ -2027,6 +2089,16 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "ql-fsm" +version = "0.1.0" +dependencies = [ + "bytes", + "indexmap", + "proptest", + "ql-wire", +] + [[package]] name = "ql-wire" version = "0.1.0" @@ -2047,6 +2119,12 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.40" @@ -2130,6 +2208,15 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.3", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" @@ -2262,12 +2349,37 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustversion" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "rusty-fork" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "ryu" version = "1.0.20" @@ -2557,6 +2669,19 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.3", + "once_cell", + "rustix", + "windows-sys", +] + [[package]] name = "thiserror" version = "2.0.17" @@ -2652,6 +2777,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -2724,6 +2855,15 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wait-timeout" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" +dependencies = [ + "libc", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index 8aad910..84dde3e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "api", "backup-shard", "btp", + "ql-fsm", "ql-wire", "quantum-link-macros", ] @@ -31,6 +32,7 @@ backup-shard = { path = "backup-shard" } btp = { path = "btp" } foundation-api = { path = "api" } quantum-link-macros = { path = "quantum-link-macros" } +ql-fsm = { path = "ql-fsm" } ql-wire = { path = "ql-wire" } [patch.crates-io] diff --git a/ql-fsm/Cargo.toml b/ql-fsm/Cargo.toml new file mode 100644 index 0000000..be937f0 --- /dev/null +++ b/ql-fsm/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "ql-fsm" +version = "0.1.0" +edition = "2021" +description = "QuantumLink Sans-IO protocol finite state machine" +license = "Proprietary" + +[dependencies] +bytes = { workspace = true } +indexmap = "2" +ql-wire = { workspace = true } + +[dev-dependencies] +proptest = "1.6" +ql-wire = { workspace = true, features = ["test-utils"] } diff --git a/ql-fsm/src/error.rs b/ql-fsm/src/error.rs new file mode 100644 index 0000000..9bf2a91 --- /dev/null +++ b/ql-fsm/src/error.rs @@ -0,0 +1,124 @@ +use std::{ + error::Error, + fmt::{Display, Formatter}, +}; + +use ql_wire::{PairingId, WireError}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ReceiveError { + InvalidRecordHeader(WireError), + InvalidRecordVersion, + InvalidHandshakeRecord(WireError), + InvalidSessionRecord(WireError), + InvalidSessionConnectionId, + InvalidSessionPayload(WireError), + InvalidIkHandshake(WireError), + InvalidKkHandshake(WireError), + InvalidXxHandshake(WireError), + InvalidRemoteBundle, + InvalidQid, + NoPeer, + NoSession, + NotPairingMode, + InvalidPairingId { + expected: PairingId, + actual: PairingId, + }, +} + +impl Display for ReceiveError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::InvalidRecordHeader(error) => write!(f, "invalid record header: {error}"), + Self::InvalidRecordVersion => f.write_str("invalid record version"), + Self::InvalidHandshakeRecord(error) => { + write!(f, "invalid handshake record: {error}") + } + Self::InvalidSessionRecord(error) => write!(f, "invalid session record: {error}"), + Self::InvalidSessionConnectionId => f.write_str("invalid session connection id"), + Self::InvalidSessionPayload(error) => write!(f, "invalid session payload: {error}"), + Self::InvalidIkHandshake(error) => write!(f, "invalid ik handshake: {error}"), + Self::InvalidKkHandshake(error) => write!(f, "invalid kk handshake: {error}"), + Self::InvalidXxHandshake(error) => write!(f, "invalid xx handshake: {error}"), + Self::InvalidRemoteBundle => f.write_str("invalid remote bundle"), + Self::InvalidQid => f.write_str("invalid qid"), + Self::NoPeer => f.write_str("no bound peer"), + Self::NoSession => f.write_str("no active session"), + Self::NotPairingMode => f.write_str("not in pairing mode"), + Self::InvalidPairingId { expected, actual } => { + write!( + f, + "invalid pairing id: expected {expected}, actual {actual}" + ) + } + } + } +} + +impl std::error::Error for ReceiveError {} + +impl From for ReceiveError { + fn from(_: NoSessionError) -> Self { + Self::NoSession + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NoPeerError; + +impl Display for NoPeerError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str("no peer bound") + } +} + +impl Error for NoPeerError {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NoSessionError; + +impl Display for NoSessionError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "no session") + } +} + +impl Error for NoSessionError {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamError { + MissingStream, + NotWritable, + NoSession, +} + +impl Display for StreamError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let message = match self { + Self::MissingStream => "missing stream", + Self::NotWritable => "stream is not writable", + Self::NoSession => "no session", + }; + f.write_str(message) + } +} + +impl Error for StreamError {} + +impl From for StreamError { + fn from(_: NoSessionError) -> Self { + Self::NoSession + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct CommitReadError; + +impl Display for CommitReadError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "invalid read commit") + } +} + +impl Error for CommitReadError {} diff --git a/ql-fsm/src/fsm.rs b/ql-fsm/src/fsm.rs new file mode 100644 index 0000000..036a336 --- /dev/null +++ b/ql-fsm/src/fsm.rs @@ -0,0 +1,267 @@ +use std::{collections::VecDeque, time::Instant}; + +use bytes::Bytes; +use ql_wire::{self as wire, QlCrypto, RouteId, SessionCloseCode, StreamId, WireDecode}; + +use crate::{ + handshake, + session::{self, SessionEvent, TerminalFrame}, + state::LinkState, + Event, NoPeerError, NoSessionError, OutboundWrite, QlFsm, ReceiveError, StreamError, WriteId, +}; + +pub struct EventSink<'a> { + events: &'a mut VecDeque, + termination: Option, +} + +impl<'a> EventSink<'a> { + fn new(events: &'a mut VecDeque) -> Self { + Self { + events, + termination: None, + } + } +} + +impl session::EventSink for EventSink<'_> { + fn emit(&mut self, event: SessionEvent) { + match event { + SessionEvent::Unpaired => { + self.termination = Some(TerminalFrame::Unpair); + } + SessionEvent::Opened { + stream_id, + route_id, + } => { + self.events.push_back(Event::Opened { + stream_id, + route_id, + }); + } + SessionEvent::Readable(stream_id) => { + self.events.push_back(Event::Readable(stream_id)); + } + SessionEvent::Writable(stream_id) => { + self.events.push_back(Event::Writable(stream_id)); + } + SessionEvent::Finished(stream_id) => { + self.events.push_back(Event::Finished(stream_id)); + } + SessionEvent::OutboundFinished(stream_id) => { + self.events.push_back(Event::OutboundFinished(stream_id)); + } + SessionEvent::Closed(frame) => { + self.events.push_back(Event::Closed(frame)); + } + SessionEvent::WritableClosed(frame) => { + self.events.push_back(Event::WritableClosed(frame)); + } + SessionEvent::SessionClosed(close) => { + self.termination = Some(TerminalFrame::Close(close.clone())); + self.events.push_back(Event::SessionClosed(close)); + } + } + } +} + +pub fn handle_bind_peer(fsm: &mut QlFsm, peer: ql_wire::PeerBundle) { + fsm.state.handshake = None; + fsm.state.link = LinkState::Idle; + fsm.state.peer = Some(peer); +} + +pub fn unpair(fsm: &mut QlFsm) { + let had_peer = fsm.state.peer.is_some(); + fsm.state.handshake = None; + fsm.state.armed_pairing_token = None; + + if let Some(conn) = fsm.state.link.connected_mut() { + let mut emit = EventSink::new(&mut fsm.events); + conn.session.unpair(&mut emit); + } else { + fsm.state.link = LinkState::Idle; + } + + if had_peer { + emit_peer_status(fsm, crate::PeerStatus::Unpaired); + } + fsm.state.peer = None; +} + +pub fn handle_disarm_pairing(fsm: &mut QlFsm) { + fsm.state.armed_pairing_token = None; + handshake::handle_disarm_pairing(fsm); +} + +pub fn handle_connect_xx(fsm: &mut QlFsm, invite: crate::PairingInvite, crypto: &impl QlCrypto) { + handshake::handle_connect_xx(fsm, invite, crypto); +} + +pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + handshake::handle_connect_ik(fsm, crypto) +} + +pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + handshake::handle_connect_kk(fsm, crypto) +} + +pub fn receive( + fsm: &mut QlFsm, + mut bytes: Vec, + crypto: &impl QlCrypto, +) -> Result<(), ReceiveError> { + let mut reader = wire::Reader::new(bytes.as_mut_slice()); + let header = + wire::RecordHeader::decode(&mut reader).map_err(ReceiveError::InvalidRecordHeader)?; + + if header.version != wire::QL_WIRE_VERSION { + return Err(ReceiveError::InvalidRecordVersion); + } + + match header.record_type { + wire::RecordType::Handshake => { + let record = wire::QlHandshakeRecord::decode(&mut reader) + .map_err(ReceiveError::InvalidHandshakeRecord)?; + handshake::handle_handshake_record(fsm, crypto, &record) + } + wire::RecordType::Session => { + let termination = { + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + let (decrypt_len, seq) = { + let record = wire::QlSessionRecord::decode(&mut reader) + .map_err(ReceiveError::InvalidSessionRecord)?; + if record.header.connection_id != conn.transport.rx_connection_id { + return Err(ReceiveError::InvalidSessionConnectionId); + } + let payload = wire::decrypt_record( + crypto, + &record.header, + record.payload, + &conn.transport.rx_key, + ) + .map_err(ReceiveError::InvalidSessionPayload)?; + (payload.len(), record.header.seq) + }; + + let len = bytes.len(); + let plaintext = Bytes::from(bytes).slice(len - decrypt_len..); + let frames = wire::parse_session_frames(plaintext); + + let mut emit = EventSink::new(events); + conn.session.receive(state.now, seq, frames, &mut emit); + emit.termination + }; + + if matches!(termination, Some(TerminalFrame::Unpair)) { + if fsm.state.peer.is_some() { + emit_peer_status(fsm, crate::PeerStatus::Unpaired); + } + fsm.state.handshake = None; + fsm.state.armed_pairing_token = None; + fsm.state.peer = None; + } + Ok(()) + } + } +} + +pub fn on_timer(fsm: &mut QlFsm) { + handshake::handle_timer(fsm); + + let QlFsm { state, events, .. } = fsm; + let Some(conn) = state.link.connected_mut() else { + return; + }; + + let mut emit = EventSink::new(events); + conn.session.on_timer(state.now, &mut emit); +} + +pub fn next_deadline(fsm: &QlFsm) -> Option { + [ + handshake::next_handshake_deadline(fsm), + fsm.state + .link + .connected() + .and_then(|state| state.session.next_deadline()), + ] + .into_iter() + .flatten() + .min() +} + +pub fn take_next_write(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Option { + if let Some(record) = fsm.state.handshake.take() { + let record = wire::encode_record_vec(ql_wire::RecordType::Handshake, &record); + return Some(OutboundWrite { + record, + write_id: None, + }); + } + + let QlFsm { state, .. } = fsm; + let conn = state.link.connected_mut()?; + + let (write_id, builder) = conn.session.take_next_write(state.now)?; + let record = builder.encrypt( + crypto, + conn.transport.tx_connection_id, + &conn.transport.tx_key, + ); + if conn.session.is_closed() && matches!(fsm.state.link, LinkState::Connected(_)) { + fsm.state.link = LinkState::Idle; + emit_peer_status(fsm, fsm.state.link.status()); + } + Some(OutboundWrite { + record, + write_id: write_id.map(WriteId), + }) +} + +pub fn complete_write(fsm: &mut QlFsm, write_id: WriteId, success: bool) { + let QlFsm { state, .. } = fsm; + if let Some(conn) = state.link.connected_mut() { + conn.session.complete_write(state.now, write_id.0, success); + } +} + +pub fn close_session(fsm: &mut QlFsm, code: SessionCloseCode) { + let QlFsm { state, events, .. } = fsm; + let Some(conn) = state.link.connected_mut() else { + return; + }; + let mut emit = EventSink::new(events); + conn.session.close(code, &mut emit); +} + +pub fn open_stream( + fsm: &mut QlFsm, + route_id: RouteId, +) -> Result, NoSessionError> { + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + let inner = conn.session.open_stream(route_id, EventSink::new(events))?; + Ok(crate::StreamOps { inner }) +} + +pub fn stream(fsm: &mut QlFsm, stream_id: StreamId) -> Result, StreamError> { + let QlFsm { state, events, .. } = fsm; + let conn = state.link.connected_mut_or_err()?; + let inner = conn.session.stream(stream_id, EventSink::new(events))?; + Ok(crate::StreamOps { inner }) +} + +pub fn queue_ping(fsm: &mut QlFsm) -> Result<(), NoSessionError> { + let conn = fsm.state.link.connected_mut_or_err()?; + conn.session.queue_ping() +} + +pub fn poll_event(fsm: &mut QlFsm) -> Option { + fsm.events.pop_front() +} + +pub fn emit_peer_status(fsm: &mut QlFsm, status: crate::PeerStatus) { + fsm.events.push_back(Event::PeerStatusChanged(status)); +} diff --git a/ql-fsm/src/handshake/ik.rs b/ql-fsm/src/handshake/ik.rs new file mode 100644 index 0000000..7e6ebd1 --- /dev/null +++ b/ql-fsm/src/handshake/ik.rs @@ -0,0 +1,119 @@ +use ql_wire::{self as wire, Ik1, Ik2, PeerBundle, QlCrypto, QlHandshakeRecord}; + +use super::{ + emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, +}; +use crate::{ + state::{IkInitiatorState, LinkState, SessionTransport}, + QlFsm, ReceiveError, +}; + +pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle) { + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::IkHandshake::new_initiator( + crypto, + fsm.identity.clone(), + peer, + super::local_transport_params(fsm), + ); + let message = handshake.write_1(crypto, meta).unwrap(); + + fsm.state.link = LinkState::IkInitiator(IkInitiatorState { + handshake_id: meta.handshake_id, + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now + fsm.config.handshake_timeout, + }); + enqueue_handshake(fsm, QlHandshakeRecord::Ik1(message)); + emit_peer_status(fsm, fsm.state.link.status()); +} + +pub fn handle_ik1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Ik1, +) -> Result<(), ReceiveError> { + if should_ignore_inbound(fsm, message) { + return Ok(()); + } + if message.header.recipient != fsm.identity.qid { + return Err(ReceiveError::InvalidQid); + } + if let Some(peer) = fsm.state.peer.as_ref() { + if message.header.sender != peer.qid { + return Err(ReceiveError::InvalidQid); + } + } + + reset_connected_session_if_needed(fsm); + + let mut handshake = wire::IkHandshake::new_responder( + crypto, + fsm.identity.clone(), + fsm.state.peer.clone(), + super::local_transport_params(fsm), + ); + handshake + .read_1(crypto, message) + .map_err(ReceiveError::InvalidIkHandshake)?; + let outbound = handshake + .write_2(crypto, message.meta) + .map_err(ReceiveError::InvalidIkHandshake)?; + let (transport, remote_bundle) = SessionTransport::from_finalized( + handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidIkHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Ik2(outbound)); + Ok(()) +} + +pub fn handle_ik2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Ik2, +) -> Result<(), ReceiveError> { + { + let LinkState::IkInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_2(crypto, message) + .map_err(ReceiveError::InvalidIkHandshake)?; + } + + let LinkState::IkInitiator(state) = fsm.state.link.take() else { + unreachable!("active IK initiator was checked above"); + }; + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidIkHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle) +} + +pub fn should_ignore_inbound(fsm: &QlFsm, message: &Ik1) -> bool { + match &fsm.state.link { + LinkState::Idle + | LinkState::Connected(_) + | LinkState::KkInitiator(_) + | LinkState::XxInitiator(_) + | LinkState::XxResponder(_) => false, + LinkState::IkInitiator(state) => { + if fsm.state.peer.as_ref().map(|peer| peer.qid) != Some(message.header.sender) { + return false; + } + super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) + } + } +} diff --git a/ql-fsm/src/handshake/kk.rs b/ql-fsm/src/handshake/kk.rs new file mode 100644 index 0000000..e78c8a6 --- /dev/null +++ b/ql-fsm/src/handshake/kk.rs @@ -0,0 +1,118 @@ +use ql_wire::{self as wire, Kk1, Kk2, PeerBundle, QlCrypto, QlHandshakeRecord}; + +use super::{ + emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, +}; +use crate::{ + state::{KkInitiatorState, LinkState, SessionTransport}, + QlFsm, ReceiveError, +}; + +pub fn start_initiator(fsm: &mut QlFsm, crypto: &impl QlCrypto, peer: PeerBundle) { + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::KkHandshake::new_initiator( + crypto, + fsm.identity.clone(), + peer, + super::local_transport_params(fsm), + ); + let message = handshake.write_1(crypto, meta).unwrap(); + + fsm.state.link = LinkState::KkInitiator(KkInitiatorState { + handshake_id: meta.handshake_id, + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now + fsm.config.handshake_timeout, + }); + enqueue_handshake(fsm, QlHandshakeRecord::Kk1(message)); + emit_peer_status(fsm, fsm.state.link.status()); +} + +pub fn handle_kk1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Kk1, +) -> Result<(), ReceiveError> { + if should_ignore_inbound(fsm, message) { + return Ok(()); + } + + let Some(peer) = fsm.state.peer.clone() else { + return Err(ReceiveError::NoPeer); + }; + if message.header.recipient != fsm.identity.qid || message.header.sender != peer.qid { + return Err(ReceiveError::InvalidQid); + } + + reset_connected_session_if_needed(fsm); + + let mut handshake = wire::KkHandshake::new_responder( + crypto, + fsm.identity.clone(), + peer, + super::local_transport_params(fsm), + ); + handshake + .read_1(crypto, message) + .map_err(ReceiveError::InvalidKkHandshake)?; + let outbound = handshake + .write_2(crypto, message.meta) + .map_err(ReceiveError::InvalidKkHandshake)?; + let (transport, remote_bundle) = SessionTransport::from_finalized( + handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidKkHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Kk2(outbound)); + Ok(()) +} + +pub fn handle_kk2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Kk2, +) -> Result<(), ReceiveError> { + { + let LinkState::KkInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_2(crypto, message) + .map_err(ReceiveError::InvalidKkHandshake)?; + } + + let LinkState::KkInitiator(state) = fsm.state.link.take() else { + unreachable!("active KK initiator was checked above"); + }; + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidKkHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle) +} + +pub fn should_ignore_inbound(fsm: &QlFsm, message: &Kk1) -> bool { + match &fsm.state.link { + LinkState::Idle + | LinkState::Connected(_) + | LinkState::XxInitiator(_) + | LinkState::XxResponder(_) => false, + LinkState::IkInitiator(_) => true, + LinkState::KkInitiator(state) => { + if fsm.state.peer.as_ref().map(|peer| peer.qid) != Some(message.header.sender) { + return false; + } + super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) + } + } +} diff --git a/ql-fsm/src/handshake/mod.rs b/ql-fsm/src/handshake/mod.rs new file mode 100644 index 0000000..1881f66 --- /dev/null +++ b/ql-fsm/src/handshake/mod.rs @@ -0,0 +1,140 @@ +mod ik; +mod kk; +mod xx; + +use ql_wire::{self as wire, EphemeralPublicKey, HandshakeMeta, QlCrypto, QlHandshakeRecord}; + +use crate::{ + fsm::emit_peer_status, + session::{SessionConfig, SessionFsm, StreamParity}, + state::{ConnectedState, LinkState, SessionTransport}, + Event, NoPeerError, QlFsm, ReceiveError, +}; + +pub fn handle_connect_ik(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + let peer = fsm.state.peer.clone().ok_or(NoPeerError)?; + prepare_for_outbound_connect(fsm); + ik::start_initiator(fsm, crypto, peer); + Ok(()) +} + +pub fn handle_connect_kk(fsm: &mut QlFsm, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + let peer = fsm.state.peer.clone().ok_or(NoPeerError)?; + prepare_for_outbound_connect(fsm); + kk::start_initiator(fsm, crypto, peer); + Ok(()) +} + +pub fn handle_connect_xx(fsm: &mut QlFsm, invite: crate::PairingInvite, crypto: &impl QlCrypto) { + prepare_for_outbound_connect(fsm); + xx::start_initiator(fsm, crypto, invite.token, invite.qid); +} + +pub fn next_handshake_meta(fsm: &mut QlFsm) -> HandshakeMeta { + let handshake_id = wire::HandshakeId(fsm.state.next_control_id); + fsm.state.next_control_id = fsm.state.next_control_id.wrapping_add(1); + HandshakeMeta { handshake_id } +} + +pub fn enqueue_handshake(fsm: &mut QlFsm, record: QlHandshakeRecord) { + debug_assert!(fsm.state.handshake.is_none()); + fsm.state.handshake = Some(record); +} + +pub fn handle_disarm_pairing(fsm: &mut QlFsm) { + xx::disarm_pairing(fsm); +} + +fn local_transport_params(fsm: &QlFsm) -> wire::TransportParams { + wire::TransportParams { + initial_stream_receive_window: fsm.config.session_stream_receive_buffer_size, + } +} + +pub fn prepare_for_outbound_connect(fsm: &mut QlFsm) { + fsm.state.handshake = None; + reset_connected_session_if_needed(fsm); +} + +pub fn handle_handshake_record( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + record: &QlHandshakeRecord, +) -> Result<(), ReceiveError> { + match record { + QlHandshakeRecord::Ik1(message) => ik::handle_ik1(fsm, crypto, message), + QlHandshakeRecord::Ik2(message) => ik::handle_ik2(fsm, crypto, message), + QlHandshakeRecord::Kk1(message) => kk::handle_kk1(fsm, crypto, message), + QlHandshakeRecord::Kk2(message) => kk::handle_kk2(fsm, crypto, message), + QlHandshakeRecord::Xx1(message) => xx::handle_xx1(fsm, crypto, message), + QlHandshakeRecord::Xx2(message) => xx::handle_xx2(fsm, crypto, message), + QlHandshakeRecord::Xx3(message) => xx::handle_xx3(fsm, crypto, message), + QlHandshakeRecord::Xx4(message) => xx::handle_xx4(fsm, crypto, message), + } +} + +pub fn handle_timer(fsm: &mut QlFsm) { + let Some(deadline) = fsm.state.link.handshake_deadline() else { + return; + }; + if deadline > fsm.state.now { + return; + } + + fsm.state.link = LinkState::Idle; + fsm.state.handshake = None; + emit_peer_status(fsm, fsm.state.link.status()); +} + +pub fn next_handshake_deadline(fsm: &QlFsm) -> Option { + fsm.state.link.handshake_deadline() +} + +pub fn finish_handshake( + fsm: &mut QlFsm, + transport: SessionTransport, + remote_bundle: wire::PeerBundle, +) -> Result<(), ReceiveError> { + let qid = remote_bundle.qid; + if let Some(peer) = fsm.state.peer.as_ref() { + if peer != &remote_bundle { + return Err(ReceiveError::InvalidRemoteBundle); + } + } else { + fsm.state.peer = Some(remote_bundle); + fsm.events.push_back(Event::NewPeer); + } + + let config = &fsm.config; + let session = SessionFsm::new( + SessionConfig { + local_parity: StreamParity::for_local(fsm.identity.qid, qid), + record_max_size: config.session_record_max_size, + ack_delay: config.session_record_ack_delay, + retransmit_timeout: config.session_record_retransmit_timeout, + keepalive_interval: config.session_keepalive_interval, + peer_timeout: config.session_peer_timeout, + stream_send_buffer_size: config.session_stream_send_buffer_size, + stream_receive_buffer_size: config.session_stream_receive_buffer_size, + accepted_record_window: config.session_accepted_record_window, + pending_ack_range_limit: config.session_pending_ack_range_limit, + initial_peer_stream_receive_window: transport + .remote_transport_params + .initial_stream_receive_window, + }, + fsm.state.now, + ); + fsm.state.link = LinkState::Connected(ConnectedState { transport, session }); + emit_peer_status(fsm, fsm.state.link.status()); + Ok(()) +} + +pub fn reset_connected_session_if_needed(fsm: &mut QlFsm) { + if matches!(fsm.state.link, LinkState::Connected(_)) { + fsm.state.link = LinkState::Idle; + } +} + +fn local_start_wins(local: &EphemeralPublicKey, inbound: &EphemeralPublicKey) -> bool { + local.mlkem_public_key.as_bytes() <= inbound.mlkem_public_key.as_bytes() +} diff --git a/ql-fsm/src/handshake/xx.rs b/ql-fsm/src/handshake/xx.rs new file mode 100644 index 0000000..c9a289e --- /dev/null +++ b/ql-fsm/src/handshake/xx.rs @@ -0,0 +1,207 @@ +use ql_wire::{self as wire, PairingToken, QlCrypto, QlHandshakeRecord, Xx1, Xx2, Xx3, Xx4, QID}; + +use super::{ + emit_peer_status, enqueue_handshake, finish_handshake, reset_connected_session_if_needed, +}; +use crate::{ + state::{LinkState, SessionTransport, XxInitiatorState, XxResponderState}, + QlFsm, ReceiveError, +}; + +pub fn start_initiator( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + token: PairingToken, + remote_qid: QID, +) { + let meta = super::next_handshake_meta(fsm); + let mut handshake = wire::XxHandshake::new_initiator( + crypto, + fsm.identity.clone(), + remote_qid, + token, + super::local_transport_params(fsm), + ); + let message = handshake.write_1(crypto, meta).unwrap(); + + fsm.state.link = LinkState::XxInitiator(XxInitiatorState { + handshake_id: meta.handshake_id, + initial_ephemeral: message.ephemeral.clone(), + handshake, + deadline: fsm.state.now + fsm.config.handshake_timeout, + }); + enqueue_handshake(fsm, QlHandshakeRecord::Xx1(message)); + emit_peer_status(fsm, fsm.state.link.status()); +} + +pub fn handle_xx1( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx1, +) -> Result<(), ReceiveError> { + if should_ignore_inbound(fsm, crypto, message) { + return Ok(()); + } + match fsm.state.armed_pairing_token { + Some(expected) if expected.id(crypto) != message.pairing_id => { + Err(ReceiveError::InvalidPairingId { + expected: expected.id(crypto), + actual: message.pairing_id, + }) + } + Some(_) + if message.header.recipient != fsm.identity.qid + || message.header.sender == fsm.identity.qid => + { + Err(ReceiveError::InvalidQid) + } + Some(token) => { + reset_connected_session_if_needed(fsm); + + let mut handshake = wire::XxHandshake::new_responder( + crypto, + fsm.identity.clone(), + message.header.sender, + token, + super::local_transport_params(fsm), + ); + handshake + .read_1(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + let outbound = handshake + .write_2(crypto, message.meta) + .map_err(ReceiveError::InvalidXxHandshake)?; + fsm.state.link = LinkState::XxResponder(XxResponderState { + handshake, + handshake_meta: message.meta, + deadline: fsm.state.now + fsm.config.handshake_timeout, + }); + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx2(outbound)); + Ok(()) + } + None => Err(ReceiveError::NotPairingMode), + } +} + +pub fn handle_xx2( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx2, +) -> Result<(), ReceiveError> { + { + let LinkState::XxInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_2(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + let outbound = state + .handshake + .write_3(crypto, message.meta) + .map_err(ReceiveError::InvalidXxHandshake)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx3(outbound)); + } + + Ok(()) +} + +pub fn handle_xx3( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx3, +) -> Result<(), ReceiveError> { + let LinkState::XxResponder(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_meta.handshake_id { + return Ok(()); + } + + state + .handshake + .read_3(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + let handshake_meta = state.handshake_meta; + let LinkState::XxResponder(mut state) = fsm.state.link.take() else { + unreachable!("active XX responder was checked above"); + }; + let outbound = state + .handshake + .write_4(crypto, handshake_meta) + .map_err(ReceiveError::InvalidXxHandshake)?; + fsm.state.handshake = None; + enqueue_handshake(fsm, QlHandshakeRecord::Xx4(outbound)); + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidXxHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle) +} + +pub fn handle_xx4( + fsm: &mut QlFsm, + crypto: &impl QlCrypto, + message: &Xx4, +) -> Result<(), ReceiveError> { + { + let LinkState::XxInitiator(state) = &mut fsm.state.link else { + return Ok(()); + }; + + if message.meta.handshake_id != state.handshake_id { + return Ok(()); + } + + state + .handshake + .read_4(crypto, message) + .map_err(ReceiveError::InvalidXxHandshake)?; + } + + let LinkState::XxInitiator(state) = fsm.state.link.take() else { + unreachable!("active XX initiator was checked above"); + }; + let (transport, remote_bundle) = SessionTransport::from_finalized( + state + .handshake + .finalize(crypto) + .map_err(ReceiveError::InvalidXxHandshake)?, + ); + finish_handshake(fsm, transport, remote_bundle) +} + +pub fn disarm_pairing(fsm: &mut QlFsm) { + if matches!(fsm.state.link, LinkState::XxResponder(_)) { + fsm.state.link = LinkState::Idle; + fsm.state.handshake = None; + } +} + +pub fn should_ignore_inbound(fsm: &QlFsm, crypto: &impl QlCrypto, message: &Xx1) -> bool { + match &fsm.state.link { + LinkState::Idle | LinkState::Connected(_) => false, + LinkState::IkInitiator(_) | LinkState::KkInitiator(_) | LinkState::XxResponder(_) => true, + LinkState::XxInitiator(state) => { + if state.handshake.pairing_id(crypto) != message.pairing_id { + return false; + } + if message.header.recipient != fsm.identity.qid + || message.header.sender != state.handshake.remote_qid() + { + return false; + } + super::local_start_wins(&state.initial_ephemeral, &message.ephemeral) + } + } +} diff --git a/ql-fsm/src/lib.rs b/ql-fsm/src/lib.rs new file mode 100644 index 0000000..3067efd --- /dev/null +++ b/ql-fsm/src/lib.rs @@ -0,0 +1,334 @@ +//! sync finite state machine for QuantumLink protocol +//! +//! a caller drives `QlFsm` inside its own event loop +//! +//! inputs to that loop usually include +//! - app actions like `bind_peer`, `connect_ik`, `connect_kk`, `connect_xx`, `open_stream`, or +//! `stream` +//! - inbound transport bytes passed to `receive` +//! - a deadline expiring, handled by calling `on_timer` +//! - transport write results passed to `complete_write` +//! +//! outputs from `QlFsm` are +//! - outbound session and handshake records from `take_next_write` +//! - queued `QlFsmEvent`s returned by `poll_event` after `connect_ik`, `connect_kk`, +//! `connect_xx`, `receive`, and `on_timer` +//! +//! call `next_deadline` after handling current inputs and any queued outputs +//! use it to decide how long the outer loop can wait before `on_timer` must run +//! another input may arrive before that deadline, which is fine + +mod error; +mod fsm; +mod handshake; +mod pairing; +mod session; +pub(crate) mod state; +#[cfg(test)] +mod tests; + +use std::{ + collections::VecDeque, + time::{Duration, Instant}, +}; + +pub use bytes::Bytes; +pub use error::*; +pub use pairing::PairingInvite; +use ql_wire::{ + PairingToken, PeerBundle, QlCrypto, QlIdentity, RouteId, SessionClose, SessionCloseCode, + StreamClose, StreamId, +}; +pub use session::{SessionEvent, StreamReadIter, StreamWriter}; + +use crate::state::{LinkState, QlFsmState}; + +/// connection state for the bound peer +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PeerStatus { + /// no active encrypted session + Disconnected, + /// we are driving the handshake + Initiator, + /// the encrypted session is up + Connected, + /// the bound peer was forgotten immediately + /// + /// unpair is abortive and best-effort. the binding is removed immediately + /// and one final write may remain: a record containing only `SessionFrame::Unpair` + Unpaired, +} + +/// events emitted by `QlFsm` +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Event { + /// a peer was learned during handshake completion + NewPeer, + /// the peer changed lifecycle state + PeerStatusChanged(PeerStatus), + /// a stream was opened + Opened { + stream_id: StreamId, + route_id: RouteId, + }, + /// a stream has bytes ready to read + Readable(StreamId), + /// a stream has room for more local writes + Writable(StreamId), + /// the peer finished writing this stream and no more bytes remain to read + Finished(StreamId), + /// our local FIN was acknowledged by the peer at the session layer + OutboundFinished(StreamId), + /// a stream was closed + Closed(StreamClose), + /// local writes on this stream are closed + WritableClosed(StreamClose), + /// the encrypted session was closed + /// + /// session close is abortive and best-effort. the session ends immediately + /// one final write remains: a record containing only `SessionFrame::Close` + /// the FSM does not wait for an ack for that record + SessionClosed(SessionClose), +} + +/// handle for a session write returned by `QlFsm::take_next_write` +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct WriteId(pub(crate) u64); + +/// outbound record produced by `QlFsm` +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct OutboundWrite { + /// wire bytes to hand to the transport + pub record: Vec, + /// write handle that must be completed exactly once + pub write_id: Option, +} + +pub struct StreamOps<'a> { + inner: session::StreamOps<'a, fsm::EventSink<'a>>, +} + +impl StreamOps<'_> { + /// returns this stream's identifier + pub fn stream_id(&self) -> StreamId { + self.inner.stream_id() + } + + /// returns the readable stream bytes as owned `Bytes` views without consuming them + pub fn read(&self) -> StreamReadIter<'_> { + self.inner.read() + } + + /// returns how many bytes can be read from the stream + pub fn readable_bytes(&self) -> usize { + self.inner.readable_bytes() + } + + /// marks previously read bytes as consumed + pub fn commit_read(&mut self, len: usize) -> Result<(), CommitReadError> { + self.inner.commit_read(len) + } + + /// returns a writer if the local write side is still open + pub fn writer(&mut self) -> Option> { + self.inner.writer() + } + + /// closes the origin lane, return lane, or both lanes of the stream + pub fn close(&mut self, target: ql_wire::CloseTarget, code: ql_wire::StreamCloseCode) { + self.inner.close(target, code); + } +} + +/// timing and buffering knobs for `QlFsm` +#[derive(Debug, Clone, Copy)] +pub struct QlFsmConfig { + /// overall time limit for one handshake attempt + pub handshake_timeout: Duration, + /// delay before sending a pure record ack + pub session_record_ack_delay: Duration, + /// how long to wait before resending unacked session records + pub session_record_retransmit_timeout: Duration, + /// idle delay before sending a keepalive ping + pub session_keepalive_interval: Duration, + /// how long to wait before declaring the peer dead + pub session_peer_timeout: Duration, + /// maximum total wire size for one session record, including header and auth tag + pub session_record_max_size: usize, + /// maximum bytes buffered locally for one stream send side + pub session_stream_send_buffer_size: usize, + /// maximum bytes buffered locally for one stream receive side + pub session_stream_receive_buffer_size: u32, + /// how many accepted record sequence numbers to retain for duplicate detection + pub session_accepted_record_window: u64, + /// maximum disjoint pending ACK ranges to retain before dropping the oldest low ranges + pub session_pending_ack_range_limit: usize, +} + +impl Default for QlFsmConfig { + fn default() -> Self { + let s = session::SessionConfig::default(); + Self { + handshake_timeout: Duration::from_secs(5), + session_record_ack_delay: s.ack_delay, + session_record_retransmit_timeout: s.retransmit_timeout, + session_keepalive_interval: s.keepalive_interval, + session_peer_timeout: s.peer_timeout, + session_record_max_size: s.record_max_size, + session_stream_send_buffer_size: s.stream_send_buffer_size, + session_stream_receive_buffer_size: s.stream_receive_buffer_size, + session_accepted_record_window: s.accepted_record_window, + session_pending_ack_range_limit: s.pending_ack_range_limit, + } + } +} + +/// synchronous driver for peer binding, handshake, and encrypted streams +pub struct QlFsm { + config: QlFsmConfig, + identity: QlIdentity, + state: QlFsmState, + events: VecDeque, +} + +impl QlFsm { + /// creates a new `QlFsm` + pub fn new(config: QlFsmConfig, identity: QlIdentity, now: Instant) -> Self { + Self { + config, + identity, + state: QlFsmState { + next_control_id: 1, + peer: None, + armed_pairing_token: None, + handshake: None, + link: LinkState::Idle, + now, + }, + events: VecDeque::new(), + } + } + + /// binds the remote peer + pub fn bind_peer(&mut self, peer: PeerBundle) { + fsm::handle_bind_peer(self, peer); + } + + /// returns the currently bound peer, if any + pub fn peer(&self) -> Option<&PeerBundle> { + self.state.peer.as_ref() + } + + /// arms acceptance of inbound xx pairings for a single token + pub fn arm_pairing(&mut self, token: PairingToken) { + self.state.armed_pairing_token = Some(token); + } + + pub fn pairing_token(&self) -> Option<&PairingToken> { + self.state.armed_pairing_token.as_ref() + } + + /// disarms inbound xx pairing and rejects any in-flight inbound xx responder state + pub fn disarm_pairing(&mut self) { + fsm::handle_disarm_pairing(self); + } + + /// starts an outbound xx handshake using a pairing invite + pub fn connect_xx(&mut self, now: Instant, invite: PairingInvite, crypto: &impl QlCrypto) { + self.state.now = now; + fsm::handle_connect_xx(self, invite, crypto); + } + + /// starts an IK handshake with the currently bound peer + pub fn connect_ik(&mut self, now: Instant, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + self.state.now = now; + fsm::handle_connect_ik(self, crypto) + } + + /// starts a KK handshake with the currently bound peer + pub fn connect_kk(&mut self, now: Instant, crypto: &impl QlCrypto) -> Result<(), NoPeerError> { + self.state.now = now; + fsm::handle_connect_kk(self, crypto) + } + + /// handles one inbound wire message + pub fn receive( + &mut self, + now: Instant, + bytes: Vec, + crypto: &impl QlCrypto, + ) -> Result<(), ReceiveError> { + self.state.now = now; + fsm::receive(self, bytes, crypto) + } + + /// returns the next queued event, if any + pub fn poll_event(&mut self) -> Option { + fsm::poll_event(self) + } + + /// advances time-based state + pub fn on_timer(&mut self, now: Instant) { + self.state.now = now; + fsm::on_timer(self); + } + + /// returns the next timer deadline, if any + pub fn next_deadline(&self) -> Option { + fsm::next_deadline(self) + } + + pub fn has_shutdown_work(&self) -> bool { + self.state + .link + .connected() + .is_some_and(|state| state.session.has_shutdown_work()) + } + + /// returns the next outbound record + /// + /// if `write_id` is `Some`, call `complete_write` exactly once + /// + /// if it is `None`, the record is fire-and-forget + pub fn take_next_write( + &mut self, + now: Instant, + crypto: &impl QlCrypto, + ) -> Option { + self.state.now = now; + fsm::take_next_write(self, crypto) + } + + /// completes a `SessionWriteId` from `take_next_write` with the transport outcome + /// + /// call this at most once for each returned `SessionWriteId` + pub fn complete_write(&mut self, now: Instant, write_id: WriteId, success: bool) { + self.state.now = now; + fsm::complete_write(self, write_id, success); + } + + /// closes the current encrypted session locally + pub fn close_session(&mut self, code: SessionCloseCode) { + fsm::close_session(self, code); + } + + /// forgets the bound peer locally and may emit one final outbound `SessionFrame::Unpair` + pub fn unpair(&mut self) { + fsm::unpair(self); + } + + /// opens a new outgoing stream + pub fn open_stream(&mut self, route_id: RouteId) -> Result, NoSessionError> { + fsm::open_stream(self, route_id) + } + + /// returns a facade for an open stream + pub fn stream(&mut self, stream_id: StreamId) -> Result, StreamError> { + fsm::stream(self, stream_id) + } + + /// queues a ping on the active session + pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { + fsm::queue_ping(self) + } +} diff --git a/ql-fsm/src/pairing.rs b/ql-fsm/src/pairing.rs new file mode 100644 index 0000000..4b8361b --- /dev/null +++ b/ql-fsm/src/pairing.rs @@ -0,0 +1,38 @@ +use ql_wire::{ByteSlice, PairingToken, Reader, WireDecode, WireEncode, WireError, QID}; + +/// Out-of-band invite consumed by the initiator of an XX pairing +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PairingInvite { + pub qid: QID, + pub token: PairingToken, +} + +impl PairingInvite { + pub const VERSION: u8 = 1; + pub const WIRE_SIZE: usize = size_of::() + QID::SIZE + PairingToken::SIZE; +} + +impl WireEncode for PairingInvite { + fn encoded_len(&self) -> usize { + Self::WIRE_SIZE + } + + fn encode(&self, out: &mut W) { + Self::VERSION.encode(out); + self.qid.encode(out); + self.token.encode(out); + } +} + +impl WireDecode for PairingInvite { + fn decode(reader: &mut Reader) -> Result { + if reader.decode::()? != Self::VERSION { + return Err(WireError::InvalidPayload); + } + + Ok(Self { + qid: reader.decode()?, + token: reader.decode()?, + }) + } +} diff --git a/ql-fsm/src/session/ack_tracker.rs b/ql-fsm/src/session/ack_tracker.rs new file mode 100644 index 0000000..a75b5c6 --- /dev/null +++ b/ql-fsm/src/session/ack_tracker.rs @@ -0,0 +1,266 @@ +use std::{ops::RangeInclusive, time::Instant}; + +use ql_wire::{RecordAck, RecordAckBuilder, RecordSeq}; + +use super::range_set::RangeSet; + +#[derive(Debug, Clone)] +pub struct AckTracker { + accepted_records: RangeSet, + pending_ack: RangeSet, + ack_state: AckState, + accepted_record_window: u64, + pending_ack_range_limit: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PendingAck { + pub ack: RecordAck, + pub due_at: Instant, + pub includes_all_pending: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReceiveOutcome { + New, + Duplicate, + TooOld, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum AckState { + Idle, + Dirty { due_at: Instant }, +} + +impl AckTracker { + pub fn new(accepted_record_window: u64, pending_ack_range_limit: usize) -> Self { + Self { + accepted_records: RangeSet::new(), + pending_ack: RangeSet::new(), + ack_state: AckState::Idle, + accepted_record_window: accepted_record_window.max(1), + pending_ack_range_limit: pending_ack_range_limit.max(1), + } + } + + pub fn insert(&mut self, seq: RecordSeq) -> ReceiveOutcome { + let seq = seq.into_inner(); + let largest_accepted = self.accepted_records.max(); + if largest_accepted.is_some_and(|largest| seq < self.accepted_cutoff(largest)) { + return ReceiveOutcome::TooOld; + } + if self.accepted_records.contains(seq) { + self.pending_ack.insert(single_range(seq)); + self.trim_pending_ack_ranges(); + return ReceiveOutcome::Duplicate; + } + + self.accepted_records.insert(single_range(seq)); + self.trim_accepted_records(); + + self.pending_ack.insert(single_range(seq)); + self.trim_pending_ack_ranges(); + + ReceiveOutcome::New + } + + pub fn ack_deadline(&self) -> Option { + match self.ack_state { + AckState::Idle => None, + AckState::Dirty { due_at } => Some(due_at), + } + } + + pub fn schedule_ack(&mut self, due_at: Instant) { + self.ack_state = match self.ack_state { + AckState::Dirty { due_at: old } => AckState::Dirty { + due_at: due_at.min(old), + }, + AckState::Idle => AckState::Dirty { due_at }, + }; + } + + pub fn pending_ack(&self, max_wire_size: usize) -> Option { + let due_at = self.ack_deadline()?; + if max_wire_size == 0 || self.pending_ack.range_count() == 0 { + return None; + } + + let total_range_count = self.pending_ack.range_count(); + let mut ack = RecordAckBuilder::new(); + let mut selected_range_count = 0usize; + + for range in self.pending_ack.iter_rev() { + let pushed = ack + .try_push_range(to_ack_range(range), max_wire_size) + .unwrap(); + if !pushed { + break; + } + selected_range_count += 1; + } + + (selected_range_count != 0).then(|| PendingAck { + ack: ack.build().unwrap(), + due_at, + includes_all_pending: total_range_count == selected_range_count, + }) + } + + pub fn on_ack_emitted(&mut self, pending_ack: &PendingAck) { + self.retire_acked_ranges(&pending_ack.ack); + if pending_ack.includes_all_pending || self.pending_ack.range_count() == 0 { + self.ack_state = AckState::Idle; + } + } + + pub fn retire_acked_ranges(&mut self, ack: &RecordAck) { + for range in ack.ranges() { + self.pending_ack.remove(from_ack_range(range)); + } + if self.pending_ack.range_count() == 0 { + self.ack_state = AckState::Idle; + } + } + + pub fn clear_ack_state(&mut self) { + self.ack_state = AckState::Idle; + } + + pub fn restore_acked_ranges(&mut self, ack: &RecordAck, due_at: Instant) { + for range in ack.ranges() { + self.pending_ack.insert(from_ack_range(range)); + } + self.trim_pending_ack_ranges(); + self.schedule_ack(due_at); + } + + fn accepted_cutoff(&self, largest_accepted: u64) -> u64 { + largest_accepted + .saturating_add(1) + .saturating_sub(self.accepted_record_window) + } + + fn trim_accepted_records(&mut self) { + let Some(largest_accepted) = self.accepted_records.max() else { + return; + }; + let cutoff = self.accepted_cutoff(largest_accepted); + self.accepted_records.remove(0..cutoff); + } + + fn trim_pending_ack_ranges(&mut self) { + while self.pending_ack.range_count() > self.pending_ack_range_limit { + self.pending_ack.pop_min(); + } + } +} + +fn single_range(seq: u64) -> std::ops::Range { + seq..seq.checked_add(1).unwrap() +} + +fn to_ack_range(range: std::ops::Range) -> RangeInclusive { + let end = range.end.checked_sub(1).unwrap(); + RecordSeq::from_u64(range.start).unwrap()..=RecordSeq::from_u64(end).unwrap() +} + +fn from_ack_range(range: RangeInclusive) -> std::ops::Range { + let start = range.start().into_inner(); + let end = range.end().into_inner().checked_add(1).unwrap(); + start..end +} + +#[cfg(test)] +mod tests { + use std::time::{Duration, Instant}; + + use ql_wire::RecordSeq; + + use super::{AckTracker, PendingAck, ReceiveOutcome}; + + fn seq(value: u64) -> RecordSeq { + RecordSeq::from_u64(value).unwrap() + } + + fn ack_ranges(pending_ack: &PendingAck) -> Vec<(u64, u64)> { + pending_ack + .ack + .ranges() + .map(|range| (range.start().into_inner(), range.end().into_inner())) + .collect() + } + + #[test] + fn contiguous_records_emit_one_ack_range() { + let now = Instant::now(); + let mut ack_tracker = AckTracker::new(128, 8); + + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(11)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(12)), ReceiveOutcome::New); + + ack_tracker.schedule_ack(now); + let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(&pending_ack), vec![(10, 12)]); + } + + #[test] + fn sparse_records_emit_descending_ack_ranges() { + let now = Instant::now(); + let mut ack_tracker = AckTracker::new(128, 8); + + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(15)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(16)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(12)), ReceiveOutcome::New); + + ack_tracker.schedule_ack(now + Duration::from_millis(5)); + let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(&pending_ack), vec![(15, 16), (12, 12), (10, 10)]); + } + + #[test] + fn accepted_record_window_evicts_old_sequences() { + let mut ack_tracker = AckTracker::new(4, 8); + + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(15)), ReceiveOutcome::New); + + assert_eq!(ack_tracker.insert(seq(10)), ReceiveOutcome::TooOld); + } + + #[test] + fn pending_ack_range_limit_drops_oldest_low_ranges() { + let now = Instant::now(); + let mut ack_tracker = AckTracker::new(128, 2); + + assert_eq!(ack_tracker.insert(seq(1)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(3)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(5)), ReceiveOutcome::New); + + ack_tracker.schedule_ack(now); + let pending_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(&pending_ack), vec![(5, 5), (3, 3)]); + } + + #[test] + fn retire_acked_ranges_removes_only_exact_snapshot() { + let now = Instant::now(); + let mut ack_tracker = AckTracker::new(128, 8); + + assert_eq!(ack_tracker.insert(seq(1)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(3)), ReceiveOutcome::New); + assert_eq!(ack_tracker.insert(seq(5)), ReceiveOutcome::New); + ack_tracker.schedule_ack(now); + + let first_ack = ack_tracker.pending_ack(4).unwrap(); + assert_eq!(ack_ranges(&first_ack), vec![(5, 5)]); + ack_tracker.on_ack_emitted(&first_ack); + ack_tracker.retire_acked_ranges(&first_ack.ack); + + let second_ack = ack_tracker.pending_ack(usize::MAX).unwrap(); + assert_eq!(ack_ranges(&second_ack), vec![(3, 3), (1, 1)]); + } +} diff --git a/ql-fsm/src/session/mod.rs b/ql-fsm/src/session/mod.rs new file mode 100644 index 0000000..5518775 --- /dev/null +++ b/ql-fsm/src/session/mod.rs @@ -0,0 +1,1041 @@ +pub use self::{state::TerminalFrame, stream_ops::*, stream_parity::*, stream_rx::*}; + +mod ack_tracker; +mod range_set; +mod remote_stream_history; +mod state; +mod stream_ops; +mod stream_parity; +mod stream_rx; +mod stream_tx; +mod tracked; + +#[cfg(test)] +mod tests; + +use std::time::{Duration, Instant}; + +use bytes::Bytes; +use indexmap::IndexMap; +use ql_wire::{ + CloseTarget, RecordAck, RecordSeq, RouteId, SessionClose, SessionCloseCode, SessionFrame, + SessionRecordBuilder, StreamClose, StreamData, StreamHeader, StreamId, StreamWindow, VarInt, + WireError, +}; + +use self::{ + ack_tracker::{AckTracker, PendingAck, ReceiveOutcome}, + remote_stream_history::RemoteStreamHistory, + state::{InboundState, OutboundState, SessionPhase, SessionState, StreamRole, StreamState}, + stream_tx::StreamTxRange, + tracked::{TrackedFrame, TrackedRecord, TrackedStreamData}, +}; +use crate::{NoSessionError, StreamError}; + +#[derive(Debug, Clone, Copy)] +pub struct SessionConfig { + pub local_parity: StreamParity, + pub record_max_size: usize, + pub ack_delay: Duration, + pub retransmit_timeout: Duration, + pub keepalive_interval: Duration, + pub peer_timeout: Duration, + pub stream_send_buffer_size: usize, + pub stream_receive_buffer_size: u32, + pub initial_peer_stream_receive_window: u32, + pub accepted_record_window: u64, + pub pending_ack_range_limit: usize, +} + +impl Default for SessionConfig { + fn default() -> Self { + Self { + local_parity: StreamParity::Even, + record_max_size: 8 * 1024, + ack_delay: Duration::from_millis(5), + retransmit_timeout: Duration::from_millis(150), + keepalive_interval: Duration::from_secs(10), + peer_timeout: Duration::from_secs(30), + stream_send_buffer_size: 16 * 1024, + stream_receive_buffer_size: 16 * 1024, + initial_peer_stream_receive_window: 16 * 1024, + accepted_record_window: 4096, + pending_ack_range_limit: 64, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionEvent { + Opened { + stream_id: StreamId, + route_id: RouteId, + }, + Readable(StreamId), + Writable(StreamId), + Finished(StreamId), + OutboundFinished(StreamId), + Closed(StreamClose), + WritableClosed(StreamClose), + SessionClosed(SessionClose), + Unpaired, +} + +pub trait EventSink { + fn emit(&mut self, event: SessionEvent); +} + +impl EventSink for F +where + F: FnMut(SessionEvent), +{ + fn emit(&mut self, event: SessionEvent) { + self(event); + } +} + +pub struct SessionFsm { + config: SessionConfig, + state: SessionState, +} + +impl SessionFsm { + pub fn new(mut config: SessionConfig, now: Instant) -> Self { + config.record_max_size = config + .record_max_size + .max(SessionRecordBuilder::MIN_CAPACITY); + config.stream_send_buffer_size = config.stream_send_buffer_size.max(1); + config.stream_receive_buffer_size = config.stream_receive_buffer_size.max(1); + config.accepted_record_window = config.accepted_record_window.max(1); + config.pending_ack_range_limit = config.pending_ack_range_limit.max(1); + Self { + config, + state: SessionState { + last_activity_at: now, + last_inbound_at: now, + phase: SessionPhase::Open, + next_stream_ordinal: 0, + next_record_seq: RecordSeq::from_u32(0), + next_write_id: 0, + tracked_records: IndexMap::default(), + ack_tracker: AckTracker::new( + config.accepted_record_window, + config.pending_ack_range_limit, + ), + pending_ping: false, + streams: IndexMap::default(), + next_stream_index: 0, + remote_stream_history: RemoteStreamHistory::new(config.local_parity.remote()), + }, + } + } + + pub fn open_stream( + &mut self, + route_id: RouteId, + sink: E, + ) -> Result, NoSessionError> + where + E: EventSink, + { + self.ensure_session_open()?; + let stream_id = self + .config + .local_parity + .make_stream_id(self.state.next_stream_ordinal); + self.state.next_stream_ordinal = self.state.next_stream_ordinal.saturating_add(1); + self.state.streams.insert( + stream_id, + StreamState::new( + StreamRole::Initiator, + Some(route_id), + self.config.stream_receive_buffer_size, + self.config.initial_peer_stream_receive_window, + ), + ); + let stream_index = self.state.streams.len() - 1; + Ok(StreamOps::new(self, stream_id, stream_index, sink)) + } + + pub fn stream( + &mut self, + stream_id: StreamId, + sink: E, + ) -> Result, StreamError> + where + E: EventSink, + { + self.ensure_session_open()?; + let Some(stream_index) = self.state.streams.get_index_of(&stream_id) else { + return Err(StreamError::MissingStream); + }; + + Ok(StreamOps::new(self, stream_id, stream_index, sink)) + } + + pub fn queue_ping(&mut self) -> Result<(), NoSessionError> { + self.ensure_session_open()?; + self.state.pending_ping = true; + Ok(()) + } + + pub fn close(&mut self, code: SessionCloseCode, sink: &mut impl EventSink) { + if self.state.phase != SessionPhase::Open { + return; + } + + self.begin_termination(TerminalFrame::Close(SessionClose { code }), sink); + } + + pub fn unpair(&mut self, sink: &mut impl EventSink) { + if self.state.phase != SessionPhase::Open { + return; + } + + self.begin_termination(TerminalFrame::Unpair, sink); + } + + pub fn is_closed(&self) -> bool { + self.state.phase == SessionPhase::Closed + } + + pub fn receive(&mut self, now: Instant, seq: RecordSeq, frames: I, sink: &mut impl EventSink) + where + I: IntoIterator, WireError>>, + { + if self.state.phase != SessionPhase::Open { + return; + } + + self.state.last_activity_at = now; + self.state.last_inbound_at = now; + self.collect_timeouts(now); + + match self.state.ack_tracker.insert(seq) { + ReceiveOutcome::TooOld => return, + ReceiveOutcome::Duplicate => { + self.schedule_ack(now, true); + return; + } + ReceiveOutcome::New => {} + } + + let mut ack_eliciting = false; + + for frame in frames { + let Ok(frame) = frame else { + self.close(SessionCloseCode::PROTOCOL, sink); + return; + }; + ack_eliciting |= !matches!(frame, SessionFrame::Ack(_)); + match frame { + SessionFrame::Ping => {} + SessionFrame::Unpair => { + self.unpair(sink); + return; + } + SessionFrame::Ack(ack) => self.process_record_ack(&ack, sink), + SessionFrame::StreamData(frame) => { + if self.handle_stream_data(frame, sink).is_err() { + self.close(SessionCloseCode::PROTOCOL, sink); + return; + } + } + SessionFrame::StreamWindow(frame) => self.handle_stream_window(&frame, sink), + SessionFrame::StreamClose(frame) => { + if self.handle_stream_close(&frame, sink).is_err() { + self.close(SessionCloseCode::PROTOCOL, sink); + return; + } + } + SessionFrame::Close(close) => { + self.close(close.code, sink); + return; + } + } + } + + if ack_eliciting { + self.schedule_ack(now, false); + } + } + + pub fn complete_write(&mut self, now: Instant, write_id: u64, success: bool) { + if !self.state.phase.is_open() { + return; + } + if success { + let Some(record) = self.state.tracked_records.get_mut(&write_id) else { + return; + }; + if record.sent_at.is_some() { + return; + } + self.state.last_activity_at = now; + record.sent_at = Some(now); + } else { + if self + .state + .tracked_records + .get(&write_id) + .is_some_and(|record| record.sent_at.is_some()) + { + return; + } + let Some(record) = self.state.tracked_records.shift_remove(&write_id) else { + return; + }; + restore_tracked_record( + now, + &mut self.state.ack_tracker, + &mut self.state.pending_ping, + &mut self.state.streams, + record, + ); + } + } + + pub fn on_timer(&mut self, now: Instant, sink: &mut impl EventSink) { + if !self.state.phase.is_open() { + return; + } + self.collect_timeouts(now); + if !self.config.peer_timeout.is_zero() + && self.state.last_inbound_at + self.config.peer_timeout <= now + { + self.close(SessionCloseCode::TIMEOUT, sink); + return; + } + if self.state.phase == SessionPhase::Open + && !self.config.keepalive_interval.is_zero() + && self.state.last_activity_at + self.config.keepalive_interval <= now + { + self.state.pending_ping = true; + } + } + + pub fn next_deadline(&self) -> Option { + if !self.state.phase.is_open() { + return None; + } + let ack_deadline = self.state.ack_tracker.ack_deadline(); + let retransmit_deadline = self + .state + .tracked_records + .values() + .filter_map(|record| { + record + .sent_at + .map(|sent_at| sent_at + self.config.retransmit_timeout) + }) + .min(); + let is_open = self.state.phase.is_open(); + let keepalive_deadline = + (is_open && !self.config.keepalive_interval.is_zero() && !self.state.pending_ping) + .then_some(self.state.last_activity_at + self.config.keepalive_interval); + let peer_timeout_deadline = (is_open && !self.config.peer_timeout.is_zero()) + .then_some(self.state.last_inbound_at + self.config.peer_timeout); + [ + ack_deadline, + retransmit_deadline, + keepalive_deadline, + peer_timeout_deadline, + ] + .into_iter() + .flatten() + .min() + } + + pub fn has_shutdown_work(&self) -> bool { + matches!(self.state.phase, SessionPhase::Terminating(_)) + || self.state.ack_tracker.ack_deadline().is_some() + || !self.state.tracked_records.is_empty() + } + + pub fn take_next_write(&mut self, now: Instant) -> Option<(Option, SessionRecordBuilder)> { + match &self.state.phase { + SessionPhase::Terminating(frame) => { + let seq = self.state.next_record_seq; + next_seq(&mut self.state.next_record_seq); + let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); + match frame { + TerminalFrame::Close(close) => { + assert!(builder.push_close(close), "builder has capacity"); + } + TerminalFrame::Unpair => { + assert!(builder.push_unpair(), "builder has capacity"); + } + } + self.state.phase = SessionPhase::Closed; + return Some((None, builder)); + } + SessionPhase::Closed => { + return None; + } + SessionPhase::Open => {} + } + self.collect_timeouts(now); + + let (builder, outbound) = self.build_next_record(now)?; + + let should_track = outbound.ping_included + || !outbound.window_updates.is_empty() + || !outbound.frames.is_empty(); + let write_id = should_track.then(|| { + let write_id = self.state.next_write_id; + self.state.next_write_id = self.state.next_write_id.wrapping_add(1); + self.state.tracked_records.insert(write_id, outbound); + write_id + }); + + Some((write_id, builder)) + } + + fn build_next_record(&mut self, now: Instant) -> Option<(SessionRecordBuilder, TrackedRecord)> { + let seq = self.state.next_record_seq; + let mut builder = SessionRecordBuilder::new(seq, self.config.record_max_size); + let mut outbound = TrackedRecord { + seq, + frames: Vec::new(), + ack: None, + ping_included: false, + window_updates: Vec::new(), + sent_at: None, + }; + + self.push_next_pending_stream_close(&mut builder, &mut outbound); + + if self.state.pending_ping && builder.push_ping() { + self.state.pending_ping = false; + outbound.ping_included = true; + } + + self.push_next_pending_stream_window(&mut builder, &mut outbound); + + self.push_next_stream_data(&mut builder, &mut outbound); + + if let Some(pending_ack) = self.pending_ack(builder.remaining_capacity()) { + if (!builder.is_empty() || pending_ack.due_at <= now) + && builder.push_ack(&pending_ack.ack) + { + self.state.ack_tracker.on_ack_emitted(&pending_ack); + outbound.ack = Some(pending_ack.ack); + } + } + + if builder.is_empty() { + return None; + } + + next_seq(&mut self.state.next_record_seq); + Some((builder, outbound)) + } + + fn begin_termination(&mut self, frame: TerminalFrame, sink: &mut impl EventSink) { + match &frame { + TerminalFrame::Close(close) => sink.emit(SessionEvent::SessionClosed(close.clone())), + TerminalFrame::Unpair => sink.emit(SessionEvent::Unpaired), + } + + self.state.phase = SessionPhase::Terminating(frame); + self.state.tracked_records.clear(); + self.state.ack_tracker.clear_ack_state(); + self.clear_streams(); + } + + fn push_next_pending_stream_close( + &mut self, + builder: &mut SessionRecordBuilder, + outbound: &mut TrackedRecord, + ) { + let len = self.state.streams.len(); + if len == 0 { + return; + } + + let start = self.state.next_stream_index % len; + for offset in 0..len { + let index = (start + offset) % len; + let stream = self.state.streams.get_index_mut(index).unwrap().1; + let Some(close) = stream.pending_close.as_ref() else { + continue; + }; + if !builder.push_stream_close(close) { + break; + } + + outbound.frames.push(TrackedFrame::StreamClose( + stream.pending_close.take().unwrap(), + )); + } + } + + fn push_next_pending_stream_window( + &mut self, + builder: &mut SessionRecordBuilder, + outbound: &mut TrackedRecord, + ) { + let len = self.state.streams.len(); + if len == 0 { + return; + } + + let start = self.state.next_stream_index % len; + for offset in 0..len { + let index = (start + offset) % len; + let (&stream_id, stream) = self.state.streams.get_index_mut(index).unwrap(); + if !stream.pending_window { + continue; + } + let frame = StreamWindow { + stream_id, + maximum_offset: VarInt::from_u64(stream.recv_limit()).unwrap(), + }; + if !builder.push_stream_window(&frame) { + break; + } + + stream.pending_window = false; + stream.advertised_max_offset = frame.maximum_offset.into_inner(); + outbound + .window_updates + .push((stream_id, frame.maximum_offset.into_inner())); + } + } + + fn push_next_stream_data( + &mut self, + builder: &mut SessionRecordBuilder, + outbound: &mut TrackedRecord, + ) { + const OVERHEAD: usize = 1 + StreamData::>::MIN_WIRE_SIZE; + + let len = self.state.streams.len(); + if len == 0 { + return; + } + + let start = self.state.next_stream_index % len; + let mut next_index = start; + + for offset in 0..len { + let Some(max_payload) = builder.remaining_capacity().checked_sub(OVERHEAD) else { + break; + }; + + let index = (start + offset) % len; + let (&stream_id, stream) = self.state.streams.get_index_mut(index).unwrap(); + if matches!(stream.outbound_state, OutboundState::Closed) { + continue; + } + let Some(candidate) = stream.tx.poll_transmit(max_payload, stream.peer_max_offset) + else { + continue; + }; + let offset = + VarInt::from_u64(candidate.offset).expect("stream offsets must fit ql-wire varint"); + let frame = StreamData { + stream_id, + offset, + header: if matches!(stream.role, StreamRole::Initiator) && candidate.offset == 0 { + stream.route_id.map(|route_id| StreamHeader { route_id }) + } else { + None + }, + fin: candidate.fin, + bytes: stream.tx.ranged_bytes(candidate), + }; + let res = builder.push_stream_data(&frame); + assert!(res, "builder has capacity"); + + if candidate.fin { + stream.outbound_state = OutboundState::Finished; + } + outbound + .frames + .push(TrackedFrame::StreamData(TrackedStreamData { + stream_id, + offset: candidate.offset, + len: candidate.len, + fin: candidate.fin, + })); + next_index = (index + 1) % len; + } + + self.state.next_stream_index = next_index; + } + + fn ensure_session_open(&self) -> Result<(), NoSessionError> { + if self.state.phase == SessionPhase::Open { + Ok(()) + } else { + Err(NoSessionError) + } + } + + fn process_record_ack(&mut self, ack: &RecordAck, sink: &mut impl EventSink) { + let stream_send_buffer_size = self.config.stream_send_buffer_size; + let acked_records = self + .state + .tracked_records + .extract_if(.., |_, record| { + record.sent_at.is_some() && ack.contains(record.seq.into_inner()) + }) + .map(|(_, record)| record) + .collect::>(); + + for record in acked_records { + for frame in &record.frames { + acknowledge_tracked_frame( + &mut self.state.streams, + stream_send_buffer_size, + frame, + sink, + ); + } + } + self.reap_reapable_streams(); + } + + fn schedule_ack(&mut self, now: Instant, immediate: bool) { + self.state.ack_tracker.schedule_ack(if immediate { + now + } else { + now + self.config.ack_delay + }); + } + + fn pending_ack(&self, remaining_capacity: usize) -> Option { + let max_ack_wire_size = remaining_capacity.checked_sub(1)?; + self.state.ack_tracker.pending_ack(max_ack_wire_size) + } + + fn collect_timeouts(&mut self, now: Instant) { + let retransmit_timeout = self.config.retransmit_timeout; + for (_, record) in self.state.tracked_records.extract_if(.., |_, record| { + record + .sent_at + .is_some_and(|sent_at| sent_at + retransmit_timeout <= now) + }) { + restore_tracked_record( + now, + &mut self.state.ack_tracker, + &mut self.state.pending_ping, + &mut self.state.streams, + record, + ); + } + } + + fn handle_stream_data( + &mut self, + frame: StreamData, + sink: &mut impl EventSink, + ) -> Result<(), ()> { + let StreamData { + stream_id, + offset, + header, + fin, + bytes, + } = frame; + let stream = match self.state.streams.get_mut(&stream_id) { + Some(stream) => stream, + None => match self.create_remote_stream(stream_id)? { + Some(stream) => stream, + None => return Ok(()), + }, + }; + + let frame_offset = offset.into_inner(); + let Some(frame_end) = frame_offset.checked_add(bytes.len() as u64) else { + return Err(()); + }; + let readable_before = stream.readable_bytes(); + let was_finished = matches!(stream.inbound_state, InboundState::Finished); + + let opened_route = match (stream.role, stream.route_id, header, frame_offset) { + (StreamRole::Responder, None, Some(header), 0) => { + stream.route_id = Some(header.route_id); + Some(header.route_id) + } + (StreamRole::Initiator, _, Some(_), _) + | (StreamRole::Responder, None, Some(_), _) + | (StreamRole::Responder, None, None, 0) => return Err(()), + _ => None, + }; + + match stream.inbound_state { + InboundState::Open => {} + InboundState::Discarding | InboundState::Closed(_) => return Ok(()), + InboundState::Finished => { + // finished stream should always have a final offset + let Some(final_offset) = stream.rx.final_offset() else { + debug_assert!(false, "finished stream must retain final offset"); + return Ok(()); + }; + + // retransmitted data for an already-finished stream is fine as long as it stays + // within the finalized byte range and any repeated FIN lands on that same offset. + if (!frame.fin || frame_end == final_offset) && frame_end <= final_offset { + if let Some(route_id) = opened_route { + sink.emit(SessionEvent::Opened { + stream_id, + route_id, + }); + if readable_before > 0 { + sink.emit(SessionEvent::Readable(stream_id)); + } else { + sink.emit(SessionEvent::Finished(stream_id)); + } + } + return Ok(()); + } + + return Err(()); + } + } + + let outcome = stream.rx.insert(frame_offset, fin, bytes).map_err(|_| ())?; + + if outcome.became_complete { + stream.inbound_state = InboundState::Finished; + } + + if let Some(route_id) = opened_route { + sink.emit(SessionEvent::Opened { + stream_id, + route_id, + }); + } + + if stream.route_id.is_some() && readable_before == 0 && stream.readable_bytes() > 0 { + sink.emit(SessionEvent::Readable(stream_id)); + } + + if stream.route_id.is_some() + && !was_finished + && matches!(stream.inbound_state, InboundState::Finished) + && stream.readable_bytes() == 0 + { + sink.emit(SessionEvent::Finished(stream_id)); + } + + self.try_reap_stream(stream_id); + Ok(()) + } + + fn handle_stream_window(&mut self, frame: &StreamWindow, sink: &mut impl EventSink) { + let Some(stream) = self.state.streams.get_mut(&frame.stream_id) else { + return; + }; + + let was_full = stream.send_capacity(self.config.stream_send_buffer_size) == 0; + let maximum_offset = frame.maximum_offset.into_inner(); + if maximum_offset > stream.peer_max_offset { + stream.peer_max_offset = maximum_offset; + } + if was_full && stream.send_capacity(self.config.stream_send_buffer_size) > 0 { + sink.emit(SessionEvent::Writable(frame.stream_id)); + } + } + + fn handle_stream_close( + &mut self, + frame: &StreamClose, + sink: &mut impl EventSink, + ) -> Result<(), ()> { + let stream_id = frame.stream_id; + let stream = match self.state.streams.get_mut(&stream_id) { + Some(stream) => stream, + None => match self.create_remote_stream(stream_id)? { + Some(stream) => stream, + None => return Ok(()), + }, + }; + + if Self::target_affects_inbound(stream.role, frame.target) + && !matches!( + stream.inbound_state, + InboundState::Closed(_) | InboundState::Discarding + ) + { + stream.inbound_state = InboundState::Closed(frame.clone()); + stream.reset_recv(); + sink.emit(SessionEvent::Closed(frame.clone())); + } + if Self::target_affects_outbound(stream.role, frame.target) + && !matches!(stream.outbound_state, OutboundState::Closed) + { + stream.outbound_state = OutboundState::Closed; + stream.tx.clear(); + stream.pending_close = None; + sink.emit(SessionEvent::WritableClosed(frame.clone())); + } + self.try_reap_stream(frame.stream_id); + Ok(()) + } + + fn apply_local_close_to_stream(stream: &mut StreamState, target: CloseTarget) { + if Self::target_affects_inbound(stream.role, target) { + stream.inbound_state = InboundState::Discarding; + stream.reset_recv(); + } + if Self::target_affects_outbound(stream.role, target) { + stream.outbound_state = OutboundState::Closed; + stream.tx.clear(); + } + } + + fn target_affects_inbound(role: StreamRole, target: CloseTarget) -> bool { + matches!(target, CloseTarget::Both) || role.inbound_target() == target + } + + fn target_affects_outbound(role: StreamRole, target: CloseTarget) -> bool { + matches!(target, CloseTarget::Both) || role.outbound_target() == target + } + + fn stream_is_reapable(&self, stream_id: StreamId, stream: &StreamState) -> bool { + let tracked_refs_stream = self.state.tracked_records.values().any(|record| { + record.window_updates.iter().any(|(id, _)| *id == stream_id) + || record.frames.iter().any(|frame| match frame { + TrackedFrame::StreamData(frame) => frame.stream_id == stream_id, + TrackedFrame::StreamClose(frame) => frame.stream_id == stream_id, + }) + }); + if tracked_refs_stream { + return false; + } + + if !stream.tx.is_empty() + || stream.pending_close.is_some() + || stream.pending_window + || stream.readable_bytes() > 0 + || stream.rx.buffered_end_offset() > stream.rx.start_offset() + { + return false; + } + + matches!( + stream.inbound_state, + InboundState::Finished | InboundState::Closed(_) | InboundState::Discarding + ) && matches!( + stream.outbound_state, + OutboundState::Finished | OutboundState::Closed + ) + } + + fn reap_reapable_streams(&mut self) { + let mut index = 0usize; + while index < self.state.streams.len() { + let stream_id = *self.state.streams.get_index(index).unwrap().0; + let len_before = self.state.streams.len(); + self.try_reap_stream(stream_id); + if self.state.streams.len() == len_before { + index += 1; + } + } + } + + fn try_reap_stream(&mut self, stream_id: StreamId) { + let Some(index) = self.state.streams.get_index_of(&stream_id) else { + return; + }; + self.try_reap_stream_at(stream_id, index); + } + + fn try_reap_stream_at(&mut self, stream_id: StreamId, index: usize) { + let Some((indexed_stream_id, stream)) = self.state.streams.get_index(index) else { + return; + }; + debug_assert_eq!(*indexed_stream_id, stream_id); + if !self.stream_is_reapable(stream_id, stream) { + return; + } + self.reap_stream_at(index); + } + + fn reap_stream_at(&mut self, index: usize) { + self.state.streams.shift_remove_index(index); + + if self.state.streams.is_empty() { + self.state.next_stream_index = 0; + return; + } + if index < self.state.next_stream_index { + self.state.next_stream_index -= 1; + } + if self.state.next_stream_index >= self.state.streams.len() { + self.state.next_stream_index %= self.state.streams.len(); + } + } + + fn clear_streams(&mut self) { + self.state.next_stream_index = 0; + self.state.streams.clear(); + } + + fn create_remote_stream( + &mut self, + stream_id: StreamId, + ) -> Result, ()> { + match classify_missing_stream( + self.config.local_parity, + self.state.next_stream_ordinal, + stream_id, + &mut self.state.remote_stream_history, + ) { + MissingStreamAction::Create => {} + MissingStreamAction::Ignore => return Ok(None), + MissingStreamAction::FailProtocol => { + return Err(()); + } + } + + let stream = self + .state + .streams + .entry(stream_id) + .insert_entry(StreamState::new( + StreamRole::Responder, + None, + self.config.stream_receive_buffer_size, + self.config.initial_peer_stream_receive_window, + )); + + Ok(Some(stream.into_mut())) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MissingStreamAction { + Create, + Ignore, + FailProtocol, +} + +fn classify_missing_stream( + local_parity: StreamParity, + next_stream_ordinal: u32, + stream_id: StreamId, + remote_stream_history: &mut RemoteStreamHistory, +) -> MissingStreamAction { + if !local_parity.remote().matches(stream_id) { + return if local_stream_was_opened(local_parity, next_stream_ordinal, stream_id) { + MissingStreamAction::Ignore + } else { + MissingStreamAction::FailProtocol + }; + } + + if remote_stream_history.observe(stream_id) { + MissingStreamAction::Ignore + } else { + MissingStreamAction::Create + } +} + +fn local_stream_was_opened( + local_parity: StreamParity, + next_stream_ordinal: u32, + stream_id: StreamId, +) -> bool { + local_parity.matches(stream_id) + && stream_id.into_inner() + < local_parity + .make_stream_id(next_stream_ordinal) + .into_inner() +} + +fn restore_tracked_record( + now: Instant, + ack_tracker: &mut AckTracker, + pending_ping: &mut bool, + streams: &mut IndexMap, + record: TrackedRecord, +) { + if let Some(ack) = &record.ack { + ack_tracker.restore_acked_ranges(ack, now); + } + if record.ping_included { + *pending_ping = true; + } + for (stream_id, maximum_offset) in record.window_updates { + if let Some(stream) = streams.get_mut(&stream_id) { + if stream.recv_limit() >= maximum_offset { + stream.pending_window = true; + } + } + } + for frame in record.frames { + requeue_tracked_frame(streams, frame); + } +} + +fn requeue_tracked_frame(streams: &mut IndexMap, frame: TrackedFrame) { + match frame { + TrackedFrame::StreamClose(close) => restore_stream_close(streams, close), + TrackedFrame::StreamData(frame) => restore_stream_data(streams, frame), + } +} + +fn restore_stream_close(streams: &mut IndexMap, close: StreamClose) { + if let Some(stream) = streams.get_mut(&close.stream_id) { + stream.pending_close = Some(close); + } +} + +fn restore_stream_data(streams: &mut IndexMap, frame: TrackedStreamData) { + if let Some(stream) = streams.get_mut(&frame.stream_id) { + if matches!(stream.outbound_state, OutboundState::Closed) { + return; + } + stream.tx.retransmit(stream_tx::StreamTxRange { + offset: frame.offset, + len: frame.len, + fin: frame.fin, + }); + if frame.fin && matches!(stream.outbound_state, OutboundState::Finished) { + stream.outbound_state = OutboundState::FinQueued; + } + } +} + +fn acknowledge_tracked_frame( + streams: &mut IndexMap, + stream_send_buffer_size: usize, + frame: &TrackedFrame, + sink: &mut impl EventSink, +) { + match frame { + TrackedFrame::StreamClose(_) => {} + TrackedFrame::StreamData(frame) => { + let stream_id = frame.stream_id; + if let Some(stream) = streams.get_mut(&stream_id) { + let was_full = stream.send_capacity(stream_send_buffer_size) == 0; + let had_unacked_fin = frame.fin && stream.tx.has_unacked_fin(); + stream.tx.ack(StreamTxRange { + offset: frame.offset, + len: frame.len, + fin: frame.fin, + }); + if was_full && stream.send_capacity(stream_send_buffer_size) > 0 { + sink.emit(SessionEvent::Writable(stream_id)); + } + if had_unacked_fin && !stream.tx.has_unacked_fin() { + sink.emit(SessionEvent::OutboundFinished(stream_id)); + } + } + } + } +} + +#[inline] +#[track_caller] +fn next_seq(seq: &mut RecordSeq) { + *seq = seq + .into_inner() + .checked_add(1) + .and_then(|next| RecordSeq::from_u64(next).ok()) + .expect("record sequence overflow"); +} diff --git a/ql-fsm/src/session/range_set.rs b/ql-fsm/src/session/range_set.rs new file mode 100644 index 0000000..53d6626 --- /dev/null +++ b/ql-fsm/src/session/range_set.rs @@ -0,0 +1,221 @@ +use std::{ + cmp, + collections::BTreeMap, + ops::{ + Bound::{Excluded, Included}, + Range, + }, +}; + +/// A set of `u64` values optimized for long runs and random insert/delete. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct RangeSet(BTreeMap); + +impl RangeSet { + pub fn new() -> Self { + Self::default() + } + + pub fn insert(&mut self, mut x: Range) -> bool { + if x.is_empty() { + return false; + } + + if let Some((start, end)) = self.before(x.start) { + if end >= x.end { + return false; + } else if end >= x.start { + self.0.remove(&start); + x.start = start; + } + } + + while let Some((next_start, next_end)) = self.after(x.start) { + if next_start > x.end { + break; + } + self.0.remove(&next_start); + x.end = cmp::max(next_end, x.end); + } + + self.0.insert(x.start, x.end); + true + } + + pub fn remove(&mut self, x: Range) -> bool { + if x.is_empty() { + return false; + } + + let before = match self.before(x.start) { + Some((start, end)) if end > x.start => { + self.0.remove(&start); + if start < x.start { + self.0.insert(start, x.start); + } + if end > x.end { + self.0.insert(x.end, end); + } + if end >= x.end { + return true; + } + true + } + Some(_) | None => false, + }; + + let mut after = false; + while let Some((start, end)) = self.after(x.start) { + if start >= x.end { + break; + } + after = true; + self.0.remove(&start); + if end > x.end { + self.0.insert(x.end, end); + break; + } + } + + before || after + } + + pub fn min(&self) -> Option { + self.0.first_key_value().map(|(&start, _)| start) + } + + pub fn max(&self) -> Option { + self.0 + .last_key_value() + .map(|(_, &end)| end.checked_sub(1).unwrap()) + } + + pub fn contains(&self, x: u64) -> bool { + self.before(x).is_some_and(|(_, end)| end > x) + } + + pub fn range_count(&self) -> usize { + self.0.len() + } + + pub fn iter(&self) -> Iter<'_> { + Iter(self.0.iter()) + } + + pub fn iter_rev(&self) -> RevIter<'_> { + RevIter(self.0.iter().rev()) + } + + pub fn peek_min(&self) -> Option> { + let (&start, &end) = self.0.iter().next()?; + Some(start..end) + } + + pub fn pop_min(&mut self) -> Option> { + let result = self.peek_min()?; + self.0.remove(&result.start); + Some(result) + } + + #[cfg(test)] + pub fn peek_max(&self) -> Option> { + let (&start, &end) = self.0.iter().next_back()?; + Some(start..end) + } + + #[cfg(test)] + pub fn pop_max(&mut self) -> Option> { + let result = self.peek_max()?; + self.0.remove(&result.start); + Some(result) + } + + /// find closest range to `x` that begins at or before it + fn before(&self, x: u64) -> Option<(u64, u64)> { + self.0 + .range((Included(0), Included(x))) + .next_back() + .map(|(&start, &end)| (start, end)) + } + + /// find the closest range to `x` that begins after it + fn after(&self, x: u64) -> Option<(u64, u64)> { + self.0 + .range((Excluded(x), Included(u64::MAX))) + .next() + .map(|(&start, &end)| (start, end)) + } +} + +pub struct Iter<'a>(std::collections::btree_map::Iter<'a, u64, u64>); + +impl Iterator for Iter<'_> { + type Item = Range; + + fn next(&mut self) -> Option { + self.0.next().map(|(&start, &end)| start..end) + } +} + +pub struct RevIter<'a>(std::iter::Rev>); + +impl Iterator for RevIter<'_> { + type Item = Range; + + fn next(&mut self) -> Option { + self.0.next().map(|(&start, &end)| start..end) + } +} + +#[cfg(test)] +mod tests { + use super::RangeSet; + + #[test] + fn insert_merges_overlaps() { + let mut set = RangeSet::new(); + assert!(set.insert(10..20)); + assert!(set.insert(30..40)); + assert!(set.insert(15..35)); + assert_eq!(set.iter().collect::>(), vec![10..40]); + } + + #[test] + fn remove_splits_ranges() { + let mut set = RangeSet::new(); + set.insert(10..40); + assert!(set.remove(20..30)); + assert_eq!(set.iter().collect::>(), vec![10..20, 30..40]); + } + + #[test] + fn reverse_iteration_visits_highest_range_first() { + let mut set = RangeSet::new(); + set.insert(10..20); + set.insert(30..40); + set.insert(50..60); + + assert_eq!( + set.iter_rev().collect::>(), + vec![50..60, 30..40, 10..20] + ); + assert_eq!(set.peek_max(), Some(50..60)); + assert_eq!(set.pop_max(), Some(50..60)); + assert_eq!(set.iter().collect::>(), vec![10..20, 30..40]); + } + + #[test] + fn contains_and_max_reflect_current_membership() { + let mut set = RangeSet::new(); + set.insert(10..20); + set.insert(30..31); + + assert!(!set.contains(9)); + assert!(set.contains(10)); + assert!(set.contains(19)); + assert!(!set.contains(20)); + assert_eq!(set.min(), Some(10)); + assert_eq!(set.max(), Some(30)); + assert_eq!(set.range_count(), 2); + } +} diff --git a/ql-fsm/src/session/remote_stream_history.rs b/ql-fsm/src/session/remote_stream_history.rs new file mode 100644 index 0000000..76c1e8b --- /dev/null +++ b/ql-fsm/src/session/remote_stream_history.rs @@ -0,0 +1,60 @@ +use ql_wire::StreamId; + +use super::{range_set::RangeSet, stream_parity::StreamParity}; + +#[derive(Debug)] +pub struct RemoteStreamHistory { + parity: StreamParity, + seen: RangeSet, +} + +impl RemoteStreamHistory { + pub fn new(parity: StreamParity) -> Self { + Self { + parity, + seen: RangeSet::new(), + } + } + + /// returns true when this remote stream id was already observed before + /// panics if `stream_id` is wrong stream parity + #[allow(clippy::range_plus_one)] + pub fn observe(&mut self, stream_id: StreamId) -> bool { + let ordinal = self + .stream_ordinal(stream_id) + .expect("remote stream history used with wrong stream parity"); + !self.seen.insert(ordinal..ordinal + 1) + } + + fn stream_ordinal(&self, stream_id: StreamId) -> Option { + let delta = stream_id + .into_inner() + .checked_sub(u64::from(self.parity.first_stream_id()))?; + if delta % 2 != 0 { + return None; + } + Some(delta / 2) + } +} + +#[cfg(test)] +mod tests { + use super::RemoteStreamHistory; + use crate::session::stream_parity::StreamParity; + + #[test] + fn observe() { + let parity = StreamParity::Even; + let mut history = RemoteStreamHistory::new(parity); + + assert!(!history.observe(parity.make_stream_id(2))); + assert!(!history.observe(parity.make_stream_id(5))); + assert!(!history.observe(parity.make_stream_id(0))); + assert!(!history.observe(parity.make_stream_id(4))); + assert!(history.observe(parity.make_stream_id(2))); + assert!(!history.observe(parity.make_stream_id(1))); + assert!(history.observe(parity.make_stream_id(5))); + assert!(!history.observe(parity.make_stream_id(3))); + assert!(history.observe(parity.make_stream_id(0))); + } +} diff --git a/ql-fsm/src/session/state.rs b/ql-fsm/src/session/state.rs new file mode 100644 index 0000000..b63140a --- /dev/null +++ b/ql-fsm/src/session/state.rs @@ -0,0 +1,140 @@ +use std::time::Instant; + +use indexmap::IndexMap; +use ql_wire::{CloseTarget, RecordSeq, RouteId, SessionClose, StreamClose, StreamId}; + +use super::{ + ack_tracker::AckTracker, remote_stream_history::RemoteStreamHistory, stream_rx::StreamRx, + stream_tx::StreamTx, tracked::TrackedRecord, +}; + +pub struct SessionState { + pub last_activity_at: Instant, + pub last_inbound_at: Instant, + pub phase: SessionPhase, + pub next_stream_ordinal: u32, + pub next_record_seq: RecordSeq, + pub next_write_id: u64, + pub tracked_records: IndexMap, + pub ack_tracker: AckTracker, + pub pending_ping: bool, + pub streams: IndexMap, + pub next_stream_index: usize, + pub remote_stream_history: RemoteStreamHistory, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionPhase { + Open, + Terminating(TerminalFrame), + Closed, +} + +impl SessionPhase { + pub fn is_open(&self) -> bool { + self == &Self::Open + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TerminalFrame { + Close(SessionClose), + Unpair, +} + +#[derive(Debug)] +pub struct StreamState { + pub role: StreamRole, + pub route_id: Option, + pub rx: StreamRx, + pub tx: StreamTx, + pub pending_close: Option, + pub peer_max_offset: u64, + pub outbound_state: OutboundState, + pub inbound_state: InboundState, + pub advertised_max_offset: u64, + pub pending_window: bool, +} + +impl StreamState { + pub fn new( + role: StreamRole, + route_id: Option, + receive_buffer_size: u32, + initial_peer_stream_receive_window: u32, + ) -> Self { + let receive_buffer_size = receive_buffer_size as usize; + Self { + role, + route_id, + tx: StreamTx::new(), + pending_close: None, + peer_max_offset: u64::from(initial_peer_stream_receive_window), + outbound_state: OutboundState::Open, + inbound_state: InboundState::Open, + rx: StreamRx::new(receive_buffer_size), + advertised_max_offset: receive_buffer_size as u64, + pending_window: false, + } + } + + pub fn is_writable(&self) -> bool { + matches!(self.outbound_state, OutboundState::Open) + } + + pub fn send_capacity(&self, send_buffer_size: usize) -> usize { + send_buffer_size.saturating_sub(self.tx.buffered_len()) + } + + pub fn readable_bytes(&self) -> usize { + self.rx.readable_len() + } + + pub fn recv_limit(&self) -> u64 { + self.rx + .start_offset() + .saturating_add(self.rx.max_buffered() as u64) + } + + pub fn reset_recv(&mut self) { + self.rx = StreamRx::with_start_offset(self.rx.start_offset(), self.rx.max_buffered()); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamRole { + Initiator, + Responder, +} + +impl StreamRole { + pub fn outbound_target(self) -> CloseTarget { + match self { + Self::Initiator => CloseTarget::Origin, + Self::Responder => CloseTarget::Return, + } + } + + pub fn inbound_target(self) -> CloseTarget { + match self { + Self::Initiator => CloseTarget::Return, + Self::Responder => CloseTarget::Origin, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum OutboundState { + Open, + FinQueued, + Finished, + Closed, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InboundState { + Open, + Finished, + Closed(StreamClose), + Discarding, +} diff --git a/ql-fsm/src/session/stream_ops.rs b/ql-fsm/src/session/stream_ops.rs new file mode 100644 index 0000000..548189b --- /dev/null +++ b/ql-fsm/src/session/stream_ops.rs @@ -0,0 +1,147 @@ +use ql_wire::{CloseTarget, StreamClose, StreamCloseCode, StreamId}; + +use super::{ + state::{InboundState, StreamState}, + stream_rx::StreamReadIter, + EventSink, SessionEvent, SessionFsm, +}; +use crate::CommitReadError; + +pub struct StreamOps<'a, E> { + session: &'a mut SessionFsm, + emit: E, + stream_id: StreamId, + stream_index: usize, + reap_on_drop: bool, +} + +impl<'a, E: EventSink> StreamOps<'a, E> { + pub(super) fn new( + session: &'a mut SessionFsm, + stream_id: StreamId, + stream_index: usize, + emit: E, + ) -> Self { + Self { + session, + emit, + stream_id, + stream_index, + reap_on_drop: false, + } + } + + /// returns this stream's identifier + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + /// returns the readable stream bytes as owned `Bytes` views without consuming them + pub fn read(&self) -> StreamReadIter<'_> { + self.stream().rx.bytes() + } + + /// returns how many bytes can be read from the stream + pub fn readable_bytes(&self) -> usize { + self.stream().readable_bytes() + } + + /// marks previously read bytes as consumed + pub fn commit_read(&mut self, len: usize) -> Result<(), CommitReadError> { + let stream_id = self.stream_id; + let emit_finished = { + let stream = self.stream_mut(); + if len > stream.readable_bytes() { + return Err(CommitReadError); + } + stream.rx.consume(len); + if stream.recv_limit() > stream.advertised_max_offset { + stream.pending_window = true; + } + stream.route_id.is_some() + && matches!(stream.inbound_state, InboundState::Finished) + && stream.readable_bytes() == 0 + }; + if emit_finished { + self.emit.emit(SessionEvent::Finished(stream_id)); + } + self.reap_on_drop = true; + Ok(()) + } + + /// returns a writer if the local write side is still open + pub fn writer(&mut self) -> Option> { + let send_buffer_size = self.session.config.stream_send_buffer_size; + let stream = self.stream_mut(); + if !stream.is_writable() { + return None; + } + Some(StreamWriter::new(stream, send_buffer_size)) + } + + /// closes the origin lane, return lane, or both lanes of the stream + pub fn close(&mut self, target: CloseTarget, code: StreamCloseCode) { + let stream_id = self.stream_id; + let stream = self.stream_mut(); + SessionFsm::apply_local_close_to_stream(stream, target); + stream.pending_close = Some(StreamClose { + stream_id, + target, + code, + }); + self.reap_on_drop = true; + } + + fn stream(&self) -> &StreamState { + &self.session.state.streams[self.stream_index] + } + + fn stream_mut(&mut self) -> &mut StreamState { + &mut self.session.state.streams[self.stream_index] + } +} + +impl Drop for StreamOps<'_, E> { + fn drop(&mut self) { + if !self.reap_on_drop { + return; + } + + self.session + .try_reap_stream_at(self.stream_id, self.stream_index); + } +} + +pub struct StreamWriter<'a> { + stream: &'a mut StreamState, + send_buffer_size: usize, +} + +impl<'a> StreamWriter<'a> { + pub(super) fn new(stream: &'a mut StreamState, send_buffer_size: usize) -> Self { + Self { + stream, + send_buffer_size, + } + } + + /// returns how many bytes can still be buffered for local writes + pub fn capacity(&self) -> usize { + self.stream.send_capacity(self.send_buffer_size) + } + + /// appends as many bytes as possible and returns the accepted count + pub fn write(&mut self, bytes: &mut bytes::Bytes) -> usize { + let accepted = bytes.len().min(self.capacity()); + if accepted > 0 { + self.stream.tx.append(bytes.split_to(accepted)); + } + accepted + } + + /// marks the local write side as finished + pub fn finish(self) { + self.stream.tx.queue_fin(); + self.stream.outbound_state = super::state::OutboundState::FinQueued; + } +} diff --git a/ql-fsm/src/session/stream_parity.rs b/ql-fsm/src/session/stream_parity.rs new file mode 100644 index 0000000..70f6077 --- /dev/null +++ b/ql-fsm/src/session/stream_parity.rs @@ -0,0 +1,44 @@ +use ql_wire::{StreamId, QID}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamParity { + Even, + Odd, +} + +impl StreamParity { + pub fn for_local(local: QID, peer: QID) -> Self { + match local.0.cmp(&peer.0) { + std::cmp::Ordering::Less | std::cmp::Ordering::Equal => Self::Even, + std::cmp::Ordering::Greater => Self::Odd, + } + } + + pub const fn first_stream_id(self) -> u32 { + match self { + Self::Even => 0, + Self::Odd => 1, + } + } + + pub const fn matches(self, stream_id: StreamId) -> bool { + match self { + Self::Even => stream_id.into_inner() % 2 == 0, + Self::Odd => stream_id.into_inner() % 2 == 1, + } + } + + pub const fn remote(self) -> Self { + match self { + Self::Even => Self::Odd, + Self::Odd => Self::Even, + } + } + + pub fn make_stream_id(self, ordinal: u32) -> StreamId { + StreamId(ql_wire::VarInt::from_u32( + self.first_stream_id() + .saturating_add(ordinal.saturating_mul(2)), + )) + } +} diff --git a/ql-fsm/src/session/stream_rx.rs b/ql-fsm/src/session/stream_rx.rs new file mode 100644 index 0000000..0f5a8ea --- /dev/null +++ b/ql-fsm/src/session/stream_rx.rs @@ -0,0 +1,428 @@ +use std::collections::{btree_map, BTreeMap}; + +use bytes::{Buf, Bytes}; + +/// reassembles one stream direction from out-of-order byte ranges. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamRx { + start_offset: u64, + chunks: BTreeMap, + final_offset: Option, + max_buffered: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct InsertOutcome { + pub newly_readable_bytes: usize, + pub became_complete: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StreamRxError { + OffsetOverflow, + OutOfWindow, + InconsistentFinalOffset, + FinalOffsetBeforeBufferedData, + BeyondFinalOffset, +} + +impl StreamRx { + pub fn new(max_buffered: usize) -> Self { + Self::with_start_offset(0, max_buffered) + } + + pub fn with_start_offset(start_offset: u64, max_buffered: usize) -> Self { + Self { + start_offset, + chunks: BTreeMap::new(), + final_offset: None, + max_buffered, + } + } + + pub fn start_offset(&self) -> u64 { + self.start_offset + } + + pub fn buffered_end_offset(&self) -> u64 { + self.chunks + .last_key_value() + .map_or(self.start_offset, |(&offset, bytes)| { + offset + bytes.len() as u64 + }) + } + + pub fn final_offset(&self) -> Option { + self.final_offset + } + + pub fn max_buffered(&self) -> usize { + self.max_buffered + } + + pub fn readable_len(&self) -> usize { + let mut cursor = self.start_offset; + for (&offset, bytes) in self.chunks.range(self.start_offset..) { + if offset > cursor { + break; + } + + let end = offset + bytes.len() as u64; + if end > cursor { + cursor = end; + } + } + + usize::try_from(cursor - self.start_offset).expect("readable prefix exceeds usize") + } + + pub fn bytes(&self) -> StreamReadIter<'_> { + StreamReadIter { + inner: self.chunks.range(self.start_offset..), + cursor: self.start_offset, + remaining: self.readable_len(), + } + } + + pub fn is_complete(&self) -> bool { + matches!(self.final_offset, Some(final_offset) + if final_offset == self.buffered_end_offset() + && final_offset == self.start_offset + self.readable_len() as u64) + } + + pub fn insert( + &mut self, + offset: u64, + fin: bool, + mut bytes: Bytes, + ) -> Result { + let end = offset + .checked_add(bytes.len() as u64) + .ok_or(StreamRxError::OffsetOverflow)?; + + let was_complete = self.is_complete(); + let old_readable = self.readable_len(); + + if fin { + self.set_or_validate_final_offset(end)?; + } + if let Some(final_offset) = self.final_offset { + if end > final_offset { + return Err(StreamRxError::BeyondFinalOffset); + } + } + + if bytes.is_empty() || end <= self.start_offset { + return Ok(self.insert_outcome(was_complete, old_readable)); + } + + let effective_offset = offset.max(self.start_offset); + let trim_front = + usize::try_from(effective_offset - offset).expect("front trim exceeds usize"); + bytes.advance(trim_front); + if bytes.is_empty() { + return Ok(self.insert_outcome(was_complete, old_readable)); + } + + let effective_end = effective_offset + bytes.len() as u64; + self.ensure_within_window(effective_end)?; + self.insert_chunk(effective_offset, bytes); + + Ok(self.insert_outcome(was_complete, old_readable)) + } + + pub fn consume(&mut self, len: usize) { + let readable = self.readable_len(); + debug_assert!(len <= readable, "consume beyond readable bytes"); + if len > readable { + return; + } + + let new_start = self.start_offset.saturating_add(len as u64); + while let Some((&offset, bytes)) = self.chunks.first_key_value() { + let end = offset + bytes.len() as u64; + if end <= new_start { + self.chunks.pop_first(); + continue; + } + if offset < new_start { + let (offset, mut bytes) = self.chunks.pop_first().unwrap(); + bytes.advance(usize::try_from(new_start - offset).expect("trim exceeds usize")); + self.chunks.insert(new_start, bytes); + } + break; + } + + self.start_offset = new_start; + } + + fn insert_outcome(&self, was_complete: bool, old_readable: usize) -> InsertOutcome { + InsertOutcome { + newly_readable_bytes: self.readable_len().saturating_sub(old_readable), + became_complete: !was_complete && self.is_complete(), + } + } + + fn set_or_validate_final_offset(&mut self, final_offset: u64) -> Result<(), StreamRxError> { + if let Some(existing) = self.final_offset { + return if existing == final_offset { + Ok(()) + } else { + Err(StreamRxError::InconsistentFinalOffset) + }; + } + + let buffered_end = self.buffered_end_offset(); + if final_offset < buffered_end { + return Err(StreamRxError::FinalOffsetBeforeBufferedData); + } + + self.final_offset = Some(final_offset); + Ok(()) + } + + fn ensure_within_window(&self, end: u64) -> Result<(), StreamRxError> { + let attempted = end.saturating_sub(self.start_offset); + if attempted > self.max_buffered as u64 { + return Err(StreamRxError::OutOfWindow); + } + Ok(()) + } + + fn insert_chunk(&mut self, mut offset: u64, mut bytes: Bytes) { + if bytes.is_empty() { + return; + } + + if let Some((&existing_offset, existing)) = self.chunks.range(..offset).next_back() { + let existing_end = existing_offset + existing.len() as u64; + if existing_end > offset { + let overlap = + usize::try_from((existing_end - offset).min(bytes.len() as u64)).unwrap(); + bytes.advance(overlap); + offset += overlap as u64; + } + } + + if bytes.is_empty() { + return; + } + + let end = offset + bytes.len() as u64; + let overlapping = self + .chunks + .range(offset..end) + .map(|(&chunk_offset, _)| chunk_offset) + .collect::>(); + + for chunk_offset in overlapping { + let chunk_end = chunk_offset + self.chunks[&chunk_offset].len() as u64; + + if chunk_offset > offset { + let len = usize::try_from(chunk_offset - offset).expect("gap exceeds usize"); + self.chunks.insert(offset, bytes.slice(..len)); + bytes.advance(len); + offset = chunk_offset; + } + + let overlap = usize::try_from((chunk_end - offset).min(bytes.len() as u64)).unwrap(); + bytes.advance(overlap); + offset += overlap as u64; + + if bytes.is_empty() { + return; + } + } + + self.chunks.insert(offset, bytes); + } +} + +#[derive(Debug, Clone)] +pub struct StreamReadIter<'a> { + inner: btree_map::Range<'a, u64, Bytes>, + cursor: u64, + remaining: usize, +} + +impl Iterator for StreamReadIter<'_> { + type Item = Bytes; + + fn next(&mut self) -> Option { + while self.remaining > 0 { + let (&offset, bytes) = self.inner.next()?; + if offset > self.cursor { + self.remaining = 0; + return None; + } + + let skip = usize::try_from(self.cursor.saturating_sub(offset)) + .expect("read cursor exceeds usize"); + if skip >= bytes.len() { + continue; + } + + let len = (bytes.len() - skip).min(self.remaining); + self.remaining -= len; + self.cursor += len as u64; + return Some(bytes.slice(skip..skip + len)); + } + + None + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{InsertOutcome, StreamRx, StreamRxError}; + + pub fn copy_readable(rx: &StreamRx) -> Vec { + let readable = rx.readable_len(); + let mut out = Vec::with_capacity(readable); + for chunk in rx.bytes() { + out.extend_from_slice(&chunk); + } + out + } + + fn bytes(bytes: &'static [u8]) -> Bytes { + Bytes::from_static(bytes) + } + + #[test] + fn contiguous_insert_becomes_readable_and_complete() { + let mut rx = StreamRx::new(64); + + let outcome = rx.insert(0, true, bytes(b"hello")).unwrap(); + + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 5, + became_complete: true, + } + ); + assert_eq!(rx.readable_len(), 5); + assert_eq!(copy_readable(&rx), b"hello"); + assert_eq!(rx.final_offset, Some(5)); + assert!(rx.is_complete()); + } + + #[test] + fn out_of_order_insert_tracks_gap_until_prefix_is_filled() { + let mut rx = StreamRx::new(64); + + let first = rx.insert(5, true, bytes(b" world")).unwrap(); + assert_eq!( + first, + InsertOutcome { + newly_readable_bytes: 0, + became_complete: false, + } + ); + assert_eq!(rx.readable_len(), 0); + + let second = rx.insert(0, false, bytes(b"hello")).unwrap(); + assert_eq!( + second, + InsertOutcome { + newly_readable_bytes: 11, + became_complete: true, + } + ); + assert_eq!(copy_readable(&rx), b"hello world"); + assert!(rx.is_complete()); + } + + #[test] + fn duplicate_insert_is_ignored_if_bytes_match() { + let mut rx = StreamRx::new(64); + + rx.insert(0, false, bytes(b"hello")).unwrap(); + let duplicate = rx.insert(0, false, bytes(b"hello")).unwrap(); + + assert_eq!( + duplicate, + InsertOutcome { + newly_readable_bytes: 0, + became_complete: false, + } + ); + assert_eq!(copy_readable(&rx), b"hello"); + } + + #[test] + fn consume_advances_start_offset_and_trims_old_prefix() { + let mut rx = StreamRx::new(64); + + rx.insert(0, false, bytes(b"abcd")).unwrap(); + rx.consume(2); + assert_eq!(rx.start_offset(), 2); + assert_eq!(copy_readable(&rx), b"cd"); + + let outcome = rx.insert(1, true, bytes(b"bcde")).unwrap(); + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 1, + became_complete: true, + } + ); + assert_eq!(copy_readable(&rx), b"cde"); + assert_eq!(rx.final_offset, Some(5)); + assert!(rx.is_complete()); + } + + #[test] + fn insert_can_fill_multiple_gaps_without_rebuilding_state() { + let mut rx = StreamRx::new(64); + + rx.insert(0, false, bytes(b"ab")).unwrap(); + rx.insert(4, false, bytes(b"ef")).unwrap(); + rx.insert(8, true, bytes(b"ij")).unwrap(); + + let outcome = rx.insert(2, false, bytes(b"cdefgh")).unwrap(); + + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 8, + became_complete: true, + } + ); + + assert_eq!(copy_readable(&rx), b"abcdefghij"); + assert!(rx.is_complete()); + } + + #[test] + fn heavily_fragmented_inserts_stay_valid() { + let mut rx = StreamRx::new(64); + + rx.insert(1, false, bytes(b"b")).unwrap(); + rx.insert(3, false, bytes(b"d")).unwrap(); + rx.insert(5, false, bytes(b"f")).unwrap(); + rx.insert(7, false, bytes(b"h")).unwrap(); + rx.insert(9, true, bytes(b"j")).unwrap(); + + let outcome = rx.insert(0, false, bytes(b"abcdefghi")).unwrap(); + assert_eq!( + outcome, + InsertOutcome { + newly_readable_bytes: 10, + became_complete: true, + } + ); + assert_eq!(copy_readable(&rx), b"abcdefghij"); + assert!(rx.is_complete()); + } + + #[test] + fn out_of_window_insert_is_rejected() { + let mut rx = StreamRx::new(4); + let error = rx.insert(5, false, bytes(b"a")).unwrap_err(); + assert_eq!(error, StreamRxError::OutOfWindow); + } +} diff --git a/ql-fsm/src/session/stream_tx.rs b/ql-fsm/src/session/stream_tx.rs new file mode 100644 index 0000000..1553392 --- /dev/null +++ b/ql-fsm/src/session/stream_tx.rs @@ -0,0 +1,579 @@ +use std::{collections::VecDeque, ops::Range}; + +use bytes::{Buf, Bytes}; +use ql_wire::BufView; + +use super::range_set::RangeSet; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StreamTx { + chunks: VecDeque, + buffered_len: usize, + base_offset: u64, + unsent: u64, + acked: RangeSet, + retransmits: RangeSet, + final_offset: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct TrackedFinalOffset { + offset: u64, + state: SendState, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SendState { + Unsent, + Sent, + Lost, + Acked, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StreamTxRange { + pub offset: u64, + pub len: usize, + pub fin: bool, +} + +#[derive(Debug, Clone, Copy)] +pub struct StreamTxBytes<'a> { + inner: &'a VecDeque, + offset: usize, + len: usize, +} + +pub struct StreamTxBuf<'a> { + inner: std::collections::vec_deque::Iter<'a, Bytes>, + skip: usize, + remaining: usize, + current: &'a [u8], +} + +impl BufView for StreamTxBytes<'_> { + type Buf<'a> + = StreamTxBuf<'a> + where + Self: 'a; + + fn buf(&self) -> Self::Buf<'_> { + let mut buf = StreamTxBuf { + inner: self.inner.iter(), + skip: self.offset, + remaining: self.len, + current: &[], + }; + buf.refill(); + buf + } +} + +impl StreamTxBuf<'_> { + fn refill(&mut self) { + if self.remaining == 0 { + self.current = &[]; + return; + } + + for chunk in self.inner.by_ref() { + if self.skip >= chunk.len() { + self.skip -= chunk.len(); + continue; + } + + let chunk = &chunk[self.skip..]; + self.skip = 0; + if chunk.is_empty() { + continue; + } + + let len = chunk.len().min(self.remaining); + self.current = &chunk[..len]; + return; + } + + self.current = &[]; + } +} + +impl Buf for StreamTxBuf<'_> { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.current + } + + fn advance(&mut self, cnt: usize) { + let remaining = self.remaining; + assert!( + cnt <= remaining, + "cannot advance past remaining bytes: {cnt} > {remaining}", + ); + + self.remaining -= cnt; + let mut cnt = cnt; + while cnt > 0 { + if cnt < self.current.len() { + self.current = &self.current[cnt..]; + return; + } + + cnt -= self.current.len(); + self.refill(); + } + + if self.remaining == 0 { + self.current = &[]; + } + } +} + +impl StreamTx { + pub fn new() -> Self { + Self { + chunks: VecDeque::new(), + buffered_len: 0, + base_offset: 0, + unsent: 0, + acked: RangeSet::new(), + retransmits: RangeSet::new(), + final_offset: None, + } + } + + pub fn buffered_len(&self) -> usize { + self.buffered_len + } + + pub fn end_offset(&self) -> u64 { + self.base_offset + self.buffered_len as u64 + } + + pub fn is_empty(&self) -> bool { + self.buffered_len == 0 && self.final_offset.is_none() + } + + pub fn append(&mut self, bytes: Bytes) { + if bytes.is_empty() { + return; + } + + self.buffered_len += bytes.len(); + self.chunks.push_back(bytes); + } + + pub fn queue_fin(&mut self) { + self.final_offset = Some(TrackedFinalOffset { + offset: self.end_offset(), + state: SendState::Unsent, + }); + } + + pub fn has_unacked_fin(&self) -> bool { + self.final_offset + .is_some_and(|final_offset| final_offset.state != SendState::Acked) + } + + pub fn poll_transmit( + &mut self, + max_payload: usize, + peer_max_offset: u64, + ) -> Option { + let budget_end = |start: u64| { + start + .saturating_add(max_payload as u64) + .min(peer_max_offset) + }; + + // prefer the lowest lost bytes before sending new bytes + if let Some(range) = self.retransmits.peek_min() { + let mut end = range.end.min(budget_end(range.start)); + + // extend only when lost bytes end where unsent bytes begin + if end == range.end && range.end == self.unsent { + end = self.end_offset().min(budget_end(range.start)); + } + + if end > range.start { + let range = self.retransmits.pop_min().unwrap(); + if end < range.end { + self.retransmits.insert(end..range.end); + } + + // mark any new bytes in this frame as sent + self.unsent = self.unsent.max(end); + return Some(StreamTxRange { + offset: range.start, + len: usize::try_from(end - range.start).unwrap(), + fin: self.poll_fin(end), + }); + } + } + + // send bytes that have not been sent yet + if self.unsent < self.end_offset() { + let end = self.end_offset().min(budget_end(self.unsent)); + if end > self.unsent { + let start = self.unsent; + self.unsent = end; + return Some(StreamTxRange { + offset: start, + len: usize::try_from(end - start).unwrap(), + fin: self.poll_fin(end), + }); + } + } + + // send a fin after all data has been sent + let final_offset = + self.final_offset + .as_mut() + .filter(|TrackedFinalOffset { offset, state }| { + (*state == SendState::Lost || *state == SendState::Unsent) + && *offset <= peer_max_offset + })?; + final_offset.state = SendState::Sent; + Some(StreamTxRange { + offset: final_offset.offset, + len: 0, + fin: true, + }) + } + + pub fn ranged_bytes(&self, range: StreamTxRange) -> StreamTxBytes<'_> { + let offset = usize::try_from(range.offset - self.base_offset).unwrap(); + let len = range.len.min(self.buffered_len.saturating_sub(offset)); + StreamTxBytes { + inner: &self.chunks, + offset, + len, + } + } + + pub fn retransmit(&mut self, range: StreamTxRange) { + if let Some(range) = self.clamp_sent_range(range.offset, range.len) { + Self::insert_not_acked(&self.acked, &mut self.retransmits, range); + } + if range.fin { + self.mark_fin_lost(); + } + } + + pub fn ack(&mut self, range: StreamTxRange) { + if let Some(range) = self.clamp_buffered_range(range.offset, range.len) { + self.acked.insert(range.clone()); + self.retransmits.remove(range); + self.trim_acked_prefix(); + } + if range.fin { + if let Some(final_offset) = self.final_offset.as_mut() { + final_offset.state = SendState::Acked; + } + } + self.trim_acked_fin(); + } + + pub fn clear(&mut self) { + self.chunks.clear(); + self.buffered_len = 0; + self.unsent = self.base_offset; + self.acked = RangeSet::new(); + self.retransmits = RangeSet::new(); + self.final_offset = None; + } + + fn clamp_buffered_range(&self, offset: u64, len: usize) -> Option> { + if len == 0 { + return None; + } + let start = offset.max(self.base_offset); + let end = offset.saturating_add(len as u64).min(self.end_offset()); + (start < end).then_some(start..end) + } + + fn clamp_sent_range(&self, offset: u64, len: usize) -> Option> { + if len == 0 { + return None; + } + let start = offset.max(self.base_offset); + let end = offset.saturating_add(len as u64).min(self.unsent); + (start < end).then_some(start..end) + } + + fn insert_not_acked(acked_set: &RangeSet, target: &mut RangeSet, range: Range) { + let mut cursor = range.start; + for acked in acked_set.iter() { + if acked.end <= cursor { + continue; + } + if acked.start >= range.end { + break; + } + if cursor < acked.start { + target.insert(cursor..acked.start.min(range.end)); + } + cursor = cursor.max(acked.end); + if cursor >= range.end { + break; + } + } + if cursor < range.end { + target.insert(cursor..range.end); + } + } + + fn poll_fin(&mut self, offset: u64) -> bool { + let Some(final_offset) = self.final_offset.as_mut() else { + return false; + }; + if matches!(final_offset.state, SendState::Lost | SendState::Unsent) + && final_offset.offset == offset + { + final_offset.state = SendState::Sent; + true + } else { + false + } + } + + fn mark_fin_lost(&mut self) { + if let Some(final_offset) = self.final_offset.as_mut() { + if final_offset.state != SendState::Acked { + final_offset.state = SendState::Lost; + } + } + } + + fn trim_acked_prefix(&mut self) { + while self.acked.min() == Some(self.base_offset) { + let prefix = self.acked.pop_min().unwrap(); + let mut to_advance = usize::try_from(prefix.end - prefix.start).unwrap(); + self.buffered_len -= to_advance; + while to_advance > 0 { + let front = self + .chunks + .front_mut() + .expect("expected buffered chunks for acked prefix"); + if front.len() <= to_advance { + to_advance -= front.len(); + self.chunks.pop_front(); + } else { + front.advance(to_advance); + to_advance = 0; + } + } + self.base_offset = prefix.end; + } + } + + fn trim_acked_fin(&mut self) { + if self.final_offset.is_some_and(|final_offset| { + final_offset.state == SendState::Acked + && final_offset.offset == self.base_offset + && self.buffered_len == 0 + }) { + self.final_offset = None; + } + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{StreamTx, StreamTxRange}; + + #[test] + fn append_tracks_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abc")); + tx.append(Bytes::from_static(b"de")); + + assert_eq!( + tx.poll_transmit(8, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 5, + fin: false, + }) + ); + } + + #[test] + fn lost_range_is_selected_before_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdef")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + + assert_eq!( + tx.poll_transmit(3, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 3, + fin: false, + }) + ); + } + + #[test] + fn lost_range_coalesces_contiguous_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abc")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + tx.append(Bytes::from_static(b"def")); + + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 6, + fin: false, + }) + ); + assert_eq!(tx.poll_transmit(6, u64::MAX), None); + } + + #[test] + fn lost_range_coalesces_only_new_bytes_that_fit() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abc")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + tx.append(Bytes::from_static(b"def")); + + assert_eq!( + tx.poll_transmit(5, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 5, + fin: false, + }) + ); + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 5, + len: 1, + fin: false, + }) + ); + } + + #[test] + fn non_contiguous_lost_range_does_not_coalesce_unsent_bytes() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdef")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + let _second = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.retransmit(first); + tx.append(Bytes::from_static(b"ghi")); + + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 0, + len: 3, + fin: false, + }) + ); + assert_eq!( + tx.poll_transmit(6, u64::MAX), + Some(StreamTxRange { + offset: 6, + len: 3, + fin: false, + }) + ); + } + + #[test] + fn acked_prefix_is_trimmed() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdef")); + + let first = tx.poll_transmit(3, u64::MAX).unwrap(); + tx.ack(first); + + assert_eq!( + tx.poll_transmit(3, u64::MAX), + Some(StreamTxRange { + offset: 3, + len: 3, + fin: false, + }) + ); + } + + #[test] + fn empty_fin_is_tracked_separately() { + let mut tx = StreamTx::new(); + tx.queue_fin(); + + let range = tx.poll_transmit(16, u64::MAX).unwrap(); + assert_eq!( + range, + StreamTxRange { + offset: 0, + len: 0, + fin: true, + } + ); + + tx.ack(range); + assert!(tx.is_empty()); + } + + #[test] + fn subrange_updates_split_merged_in_flight_segments() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdefghijkl")); + + let _first = tx.poll_transmit(4, u64::MAX).unwrap(); + let second = tx.poll_transmit(4, u64::MAX).unwrap(); + let _third = tx.poll_transmit(4, u64::MAX).unwrap(); + + tx.retransmit(second); + + assert_eq!( + tx.poll_transmit(4, u64::MAX), + Some(StreamTxRange { + offset: 4, + len: 4, + fin: false, + }) + ); + } + + #[test] + fn acked_subrange_is_not_reopened_by_stale_timeout() { + let mut tx = StreamTx::new(); + tx.append(Bytes::from_static(b"abcdefghijklmnop")); + + let _first = tx.poll_transmit(4, u64::MAX).unwrap(); + let second = tx.poll_transmit(4, u64::MAX).unwrap(); + let third = tx.poll_transmit(4, u64::MAX).unwrap(); + let _fourth = tx.poll_transmit(4, u64::MAX).unwrap(); + + tx.ack(second); + tx.retransmit(second); + tx.retransmit(third); + + assert_eq!( + tx.poll_transmit(4, u64::MAX), + Some(StreamTxRange { + offset: 8, + len: 4, + fin: false, + }) + ); + } +} diff --git a/ql-fsm/src/session/tests.rs b/ql-fsm/src/session/tests.rs new file mode 100644 index 0000000..f1f2987 --- /dev/null +++ b/ql-fsm/src/session/tests.rs @@ -0,0 +1,869 @@ +use std::time::{Duration, Instant}; + +use bytes::Bytes; +use ql_wire::{ + decode_session_frames, parse_session_frames, CloseTarget, RecordAck, RecordSeq, RouteId, + SessionFrame, SessionRecordBuilder, StreamClose, StreamCloseCode, StreamData, StreamHeader, + StreamId, VarInt, QID, +}; + +use super::{SessionConfig, SessionEvent, SessionFsm}; +use crate::session::stream_parity::StreamParity; + +fn seq(value: u64) -> RecordSeq { + RecordSeq::from_u64(value).unwrap() +} + +fn stream_id(value: u64) -> StreamId { + StreamId(VarInt::from_u64(value).unwrap()) +} + +fn offset(value: u64) -> VarInt { + VarInt::from_u64(value).unwrap() +} + +fn route_id(value: u64) -> RouteId { + RouteId::from_u64(value).unwrap() +} + +fn record_ack(seq: RecordSeq) -> RecordAck { + RecordAck::from_ranges([seq..=seq]).unwrap() +} + +const REFUSED: StreamCloseCode = StreamCloseCode(1); +const TIMEOUT: StreamCloseCode = StreamCloseCode(2); + +fn header(value: u64) -> StreamHeader { + StreamHeader { + route_id: route_id(value), + } +} + +fn opened(stream_id: StreamId) -> SessionEvent { + SessionEvent::Opened { + stream_id, + route_id: route_id(1), + } +} + +fn open_stream_id(fsm: &mut SessionFsm) -> StreamId { + fsm.open_stream(route_id(1), |_| {}).unwrap().stream_id() +} + +fn write_stream_bytes(fsm: &mut SessionFsm, stream_id: StreamId, bytes: &[u8]) -> usize { + let mut bytes = Bytes::copy_from_slice(bytes); + let mut stream = fsm.stream(stream_id, |_| {}).unwrap(); + let mut writer = stream.writer().unwrap(); + writer.write(&mut bytes) +} + +fn read_stream_all(fsm: &mut SessionFsm, stream_id: StreamId) -> Vec { + let mut stream = fsm.stream(stream_id, |_| {}).unwrap(); + let out = stream.read().flatten().collect::>(); + stream.commit_read(out.len()).unwrap(); + out +} + +fn read_stream_all_with_events( + fsm: &mut SessionFsm, + stream_id: StreamId, + events: &mut Vec, +) -> Vec { + let mut stream = fsm.stream(stream_id, |event| events.push(event)).unwrap(); + let out = stream.read().flatten().collect::>(); + stream.commit_read(out.len()).unwrap(); + out +} + +fn next_outbound( + fsm: &mut SessionFsm, + now: Instant, +) -> Option<(RecordSeq, Vec>>)> { + let (write_id, builder) = fsm.take_next_write(now)?; + if let Some(write_id) = write_id { + fsm.complete_write(now, write_id, true); + } + Some(( + builder.seq(), + decode_session_frames(builder.bytes()).unwrap(), + )) +} + +fn drain_outbound( + fsm: &mut SessionFsm, + now: Instant, + limit: usize, +) -> Vec<(RecordSeq, Vec>>)> { + let mut records = Vec::new(); + for _ in 0..limit { + let Some(record) = next_outbound(fsm, now) else { + return records; + }; + records.push(record); + } + + panic!("session did not quiesce within outbound limit"); +} + +fn receive_events( + fsm: &mut SessionFsm, + now: Instant, + seq: RecordSeq, + record: &[SessionFrame>], +) -> Vec { + let mut builder = SessionRecordBuilder::new(seq, usize::MAX); + for frame in record { + assert!(builder.push_frame(frame)); + } + let bytes = Bytes::from(builder.bytes().to_vec()); + let frames = parse_session_frames(bytes); + let mut events = Vec::new(); + let mut emit = |event| events.push(event); + fsm.receive(now, seq, frames, &mut emit); + events +} + +#[test] +fn outbound_record_seq_increments_monotonically() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"one"), 3); + let (first_seq, _) = next_outbound(&mut fsm, now).unwrap(); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"two"), 3); + let (second_seq, _) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + + assert_eq!(first_seq, seq(0)); + assert_eq!(second_seq, seq(1)); +} + +#[test] +fn retransmit_uses_new_record_seq() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"retry"), 5); + let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); + + let mut emit = |_| {}; + fsm.on_timer(now + Duration::from_millis(200), &mut emit); + let (retried_seq, retried) = next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); + + assert_ne!(first_seq, retried_seq); + assert_eq!(first, retried); +} + +#[test] +fn lost_record_on_one_stream_does_not_block_another_stream() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + record_max_size: 80 + SessionRecordBuilder::MIN_CAPACITY, + ..SessionConfig::default() + }, + now, + ); + let stream_id_a = open_stream_id(&mut fsm); + let stream_id_b = open_stream_id(&mut fsm); + let payload_a = vec![b'a'; 40]; + let payload_b = vec![b'b'; 40]; + + assert_eq!(write_stream_bytes(&mut fsm, stream_id_a, &payload_a), 40); + assert_eq!(write_stream_bytes(&mut fsm, stream_id_b, &payload_b), 40); + + let (first_seq, first) = next_outbound(&mut fsm, now).unwrap(); + let (second_seq, _second) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + assert_ne!(first_seq, second_seq); + assert!(first.iter().any( + |frame| matches!(frame, SessionFrame::StreamData(frame) if frame.stream_id == stream_id_a) + )); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id_b, b"b-2"), 3); + let (_third_seq, third) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + + let stream_ids: Vec<_> = third + .iter() + .filter_map(|frame| match frame { + SessionFrame::StreamData(frame) => Some(frame.stream_id), + _ => None, + }) + .collect(); + assert_eq!(stream_ids, vec![stream_id_b]); +} + +#[test] +fn ack_reopens_write_capacity() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + stream_send_buffer_size: 4, + ..SessionConfig::default() + }, + now, + ); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"abcd"), 4); + let (record_seq, _record) = next_outbound(&mut fsm, now).unwrap(); + + let mut events = Vec::new(); + let mut emit = |event| events.push(event); + fsm.receive( + now + Duration::from_millis(1), + seq(9), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + &mut emit, + ); + + assert!(events.contains(&SessionEvent::Writable(stream_id))); + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"z"), 1); +} + +#[test] +fn ack_of_fin_emits_outbound_finished_once() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"done"), 4); + fsm.stream(stream_id, |_| {}) + .unwrap() + .writer() + .unwrap() + .finish(); + + let (record_seq, record) = next_outbound(&mut fsm, now).unwrap(); + assert!(matches!( + record.as_slice(), + [SessionFrame::StreamData(StreamData { + stream_id: id, + fin: true, + .. + })] if *id == stream_id + )); + + let mut events = Vec::new(); + { + let mut emit = |event| events.push(event); + fsm.receive( + now + Duration::from_millis(1), + seq(9), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + &mut emit, + ); + } + assert_eq!(events, vec![SessionEvent::OutboundFinished(stream_id)]); + + { + let mut emit = |event| events.push(event); + fsm.receive( + now + Duration::from_millis(2), + seq(10), + std::iter::once(Ok(SessionFrame::Ack(record_ack(record_seq)))), + &mut emit, + ); + } + assert_eq!(events, vec![SessionEvent::OutboundFinished(stream_id)]); +} + +#[test] +fn commit_stream_read_is_what_advances_stream_window() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + local_parity: StreamParity::Even, + ack_delay: Duration::ZERO, + ..SessionConfig::default() + }, + now, + ); + let stream_id = stream_id(1); + let data = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"hi".to_vec(), + })]; + let events = receive_events(&mut fsm, now, seq(7), &data); + assert_eq!( + events, + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let first = decode_session_frames(builder.bytes()).unwrap(); + assert!(write_id.is_none()); + assert!(matches!(first.as_slice(), [SessionFrame::Ack(_)])); + + let read = fsm + .stream(stream_id, |_| {}) + .unwrap() + .read() + .map(|chunk| chunk.len()) + .sum::(); + assert_eq!(read, 2); + + assert!(next_outbound(&mut fsm, now + Duration::from_millis(2)).is_none()); + + fsm.stream(stream_id, |_| {}) + .unwrap() + .commit_read(2) + .unwrap(); + let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(3)).unwrap(); + assert!(matches!( + second.as_slice(), + [SessionFrame::StreamWindow(window)] if window.stream_id == stream_id + )); +} + +#[test] +fn pure_ack_only_records_are_fire_and_forget() { + let now = Instant::now(); + let config = SessionConfig { + ack_delay: Duration::ZERO, + ..SessionConfig::default() + }; + let retransmit_timeout = config.retransmit_timeout; + let mut fsm = SessionFsm::new(config, now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"hi".to_vec(), + })]; + + let _ = receive_events(&mut fsm, now, seq(7), &record); + + let (write_id, builder) = fsm.take_next_write(now + Duration::from_millis(1)).unwrap(); + let ack = decode_session_frames(builder.bytes()).unwrap(); + assert!(write_id.is_none()); + assert!(matches!(ack.as_slice(), [SessionFrame::Ack(_)])); + + let mut emit = |_| {}; + fsm.on_timer( + now + retransmit_timeout + Duration::from_millis(1), + &mut emit, + ); + assert!(fsm + .take_next_write(now + retransmit_timeout + Duration::from_millis(1)) + .is_none()); +} + +#[test] +fn inbound_stream_data_emits_opened_and_readable() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(ql_wire::StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: true, + bytes: b"hello".to_vec(), + })]; + + let events = receive_events(&mut fsm, now, seq(0), &record); + assert_eq!( + events, + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + let mut events = Vec::new(); + assert_eq!( + read_stream_all_with_events(&mut fsm, stream_id, &mut events), + b"hello".to_vec() + ); + assert_eq!(events, vec![SessionEvent::Finished(stream_id)]); +} + +#[test] +fn inbound_empty_fin_emits_finished_immediately() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: true, + bytes: Vec::new(), + })]; + + let events = receive_events(&mut fsm, now, seq(0), &record); + assert_eq!( + events, + vec![opened(stream_id), SessionEvent::Finished(stream_id)] + ); +} + +#[test] +fn remote_stream_close_is_reliable_and_retried() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = open_stream_id(&mut fsm); + + fsm.stream(stream_id, |_| {}) + .unwrap() + .close(CloseTarget::Both, StreamCloseCode::CANCELLED); + + let (write_id, builder) = fsm.take_next_write(now).unwrap(); + fsm.complete_write(now, write_id.expect("stream close should be tracked"), true); + let first = decode_session_frames(builder.bytes()).unwrap(); + assert!(matches!( + first.as_slice(), + [SessionFrame::StreamClose(StreamClose { stream_id: id, .. })] if *id == stream_id + )); + + let mut emit = |_| {}; + fsm.on_timer(now + Duration::from_millis(200), &mut emit); + let (_retried_seq, retried) = + next_outbound(&mut fsm, now + Duration::from_millis(200)).unwrap(); + assert_eq!(first, retried); +} + +#[test] +fn stream_ids_follow_even_odd_xid_ordering() { + let now = Instant::now(); + let even = StreamParity::for_local(QID([1; QID::SIZE]), QID([2; QID::SIZE])); + let odd = StreamParity::for_local(QID([2; QID::SIZE]), QID([1; QID::SIZE])); + + let even_id = SessionFsm::new( + SessionConfig { + local_parity: even, + ..SessionConfig::default() + }, + now, + ) + .open_stream(route_id(1), |_| {}) + .unwrap() + .stream_id(); + let odd_id = SessionFsm::new( + SessionConfig { + local_parity: odd, + ..SessionConfig::default() + }, + now, + ) + .open_stream(route_id(1), |_| {}) + .unwrap() + .stream_id(); + + assert_eq!(even_id.into_inner() % 2, 0); + assert_eq!(odd_id.into_inner() % 2, 1); +} + +#[test] +fn duplicate_stream_data_is_not_redelivered() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"hi".to_vec(), + })]; + let _ = receive_events(&mut fsm, now, seq(1), &record); + let _ = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); + + assert_eq!(read_stream_all(&mut fsm, stream_id), b"hi".to_vec()); +} + +#[test] +fn duplicate_remote_close_after_reap_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let close = StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: StreamCloseCode(9), + }; + let record = vec![SessionFrame::StreamClose(close.clone())]; + + let first = receive_events(&mut fsm, now, seq(1), &record); + assert_eq!( + first, + vec![ + SessionEvent::Closed(close.clone()), + SessionEvent::WritableClosed(close), + ] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); + assert!(second.is_empty()); +} + +#[test] +fn late_remote_stream_data_after_close_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let close = vec![SessionFrame::StreamClose(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(9), + })]; + let data = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"hello".to_vec(), + })]; + + let first = receive_events(&mut fsm, now, seq(1), &close); + assert_eq!( + first, + vec![ + SessionEvent::Closed(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(9), + }), + SessionEvent::WritableClosed(StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode(9), + }), + ] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &data); + assert!(second.is_empty()); +} + +#[test] +fn duplicate_finished_remote_data_after_reap_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: true, + bytes: b"hello".to_vec(), + })]; + + let first = receive_events(&mut fsm, now, seq(1), &record); + assert_eq!( + first, + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + let mut events = Vec::new(); + assert_eq!( + read_stream_all_with_events(&mut fsm, stream_id, &mut events), + b"hello".to_vec() + ); + assert_eq!(events, vec![SessionEvent::Finished(stream_id)]); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); + assert!(second.is_empty()); +} + +#[test] +fn duplicate_finished_remote_data_before_read_is_ignored() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let stream_id = stream_id(1); + let record = vec![SessionFrame::StreamData(StreamData { + stream_id, + offset: offset(0), + header: Some(header(1)), + fin: true, + bytes: b"hello".to_vec(), + })]; + + let first = receive_events(&mut fsm, now, seq(1), &record); + assert_eq!( + first, + vec![opened(stream_id), SessionEvent::Readable(stream_id)] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &record); + assert!(second.is_empty()); + let mut events = Vec::new(); + assert_eq!( + read_stream_all_with_events(&mut fsm, stream_id, &mut events), + b"hello".to_vec() + ); + assert_eq!(events, vec![SessionEvent::Finished(stream_id)]); +} + +#[test] +fn out_of_order_remote_stream_first_observations_still_open_once_each() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + let close3 = vec![SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(3), + target: CloseTarget::Both, + code: REFUSED, + })]; + let close1 = vec![SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: TIMEOUT, + })]; + + let first = receive_events(&mut fsm, now, seq(1), &close3); + assert_eq!( + first, + vec![ + SessionEvent::Closed(StreamClose { + stream_id: stream_id(3), + target: CloseTarget::Both, + code: REFUSED, + }), + SessionEvent::WritableClosed(StreamClose { + stream_id: stream_id(3), + target: CloseTarget::Both, + code: REFUSED, + }), + ] + ); + + let second = receive_events(&mut fsm, now + Duration::from_millis(1), seq(2), &close1); + assert_eq!( + second, + vec![ + SessionEvent::Closed(StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: TIMEOUT, + }), + SessionEvent::WritableClosed(StreamClose { + stream_id: stream_id(1), + target: CloseTarget::Both, + code: TIMEOUT, + }), + ] + ); + + let third = receive_events(&mut fsm, now + Duration::from_millis(2), seq(3), &close3); + assert!(third.is_empty()); +} + +#[test] +fn invalid_remote_stream_close_closes_session() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + + let invalid = vec![SessionFrame::StreamClose(StreamClose { + stream_id: stream_id(0), + target: CloseTarget::Both, + code: StreamCloseCode(9), + })]; + let events = receive_events(&mut fsm, now, seq(1), &invalid); + + assert_eq!( + events, + vec![SessionEvent::SessionClosed(ql_wire::SessionClose { + code: ql_wire::SessionCloseCode::PROTOCOL, + })] + ); +} + +#[test] +fn close_does_not_ack_rejected_record_seq() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + ack_delay: Duration::ZERO, + ..SessionConfig::default() + }, + now, + ); + + let invalid = vec![SessionFrame::StreamData(StreamData { + stream_id: stream_id(0), + offset: offset(0), + header: Some(header(1)), + fin: false, + bytes: b"bad".to_vec(), + })]; + let events = receive_events(&mut fsm, now, seq(7), &invalid); + assert_eq!( + events, + vec![SessionEvent::SessionClosed(ql_wire::SessionClose { + code: ql_wire::SessionCloseCode::PROTOCOL, + })] + ); + + let valid_after_close = vec![SessionFrame::Ping]; + let events = receive_events( + &mut fsm, + now + Duration::from_millis(1), + seq(8), + &valid_after_close, + ); + assert!(events.is_empty()); + + let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert!(matches!(outbound.as_slice(), [SessionFrame::Close(_)])); +} + +#[test] +fn inbound_unpair_emits_final_unpair_frame() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + + let events = receive_events(&mut fsm, now, seq(1), &[SessionFrame::Unpair]); + assert_eq!(events, vec![SessionEvent::Unpaired]); + assert!(!fsm.is_closed()); + + let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(1)).unwrap(); + assert!(matches!(outbound.as_slice(), [SessionFrame::Unpair])); + assert!(fsm.is_closed()); +} + +#[test] +fn terminating_session_ignores_inbound_frames() { + let now = Instant::now(); + let mut fsm = SessionFsm::new(SessionConfig::default(), now); + + let mut events = Vec::new(); + fsm.unpair(&mut |event| events.push(event)); + assert_eq!(events, vec![SessionEvent::Unpaired]); + + let ignored = receive_events( + &mut fsm, + now + Duration::from_millis(1), + seq(1), + &[SessionFrame::Ping], + ); + assert!(ignored.is_empty()); + + let (_seq, outbound) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert!(matches!(outbound.as_slice(), [SessionFrame::Unpair])); + assert!(fsm.is_closed()); +} + +#[test] +fn initial_peer_stream_receive_window_limits_first_send() { + let now = Instant::now(); + let mut fsm = SessionFsm::new( + SessionConfig { + initial_peer_stream_receive_window: 3, + ..SessionConfig::default() + }, + now, + ); + let stream_id = open_stream_id(&mut fsm); + + assert_eq!(write_stream_bytes(&mut fsm, stream_id, b"hello"), 5); + let (_first_seq, first) = next_outbound(&mut fsm, now).unwrap(); + assert!(matches!( + first.as_slice(), + [SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" + )); + + let events = receive_events( + &mut fsm, + now + Duration::from_millis(1), + seq(9), + &[SessionFrame::StreamWindow(ql_wire::StreamWindow { + stream_id, + maximum_offset: offset(5), + })], + ); + assert!(events.is_empty()); + + let (_second_seq, second) = next_outbound(&mut fsm, now + Duration::from_millis(2)).unwrap(); + assert!(second.iter().any(|frame| { + matches!( + frame, + SessionFrame::StreamData(frame) + if frame.stream_id == stream_id + && frame.offset == offset(3) + && frame.bytes.as_slice() == b"lo" + ) + })); +} + +#[test] +fn sparse_out_of_order_ack_ranges_page_and_quiesce() { + let now = Instant::now(); + let sender_config = SessionConfig { + local_parity: StreamParity::Even, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 40, + ack_delay: Duration::from_millis(5), + retransmit_timeout: Duration::from_millis(25), + stream_send_buffer_size: 8 * 1024, + initial_peer_stream_receive_window: 8 * 1024, + ..SessionConfig::default() + }; + let receiver_config = SessionConfig { + local_parity: StreamParity::Odd, + record_max_size: SessionRecordBuilder::MIN_CAPACITY + 10, + ack_delay: Duration::from_millis(1), + retransmit_timeout: Duration::from_millis(25), + pending_ack_range_limit: 512, + initial_peer_stream_receive_window: 8 * 1024, + ..SessionConfig::default() + }; + let mut sender = SessionFsm::new(sender_config, now); + let mut receiver = SessionFsm::new(receiver_config, now); + + let stream_id = open_stream_id(&mut sender); + let payload = vec![b'x'; 2048]; + assert_eq!( + write_stream_bytes(&mut sender, stream_id, &payload), + payload.len() + ); + + let originals = drain_outbound(&mut sender, now, 4096); + assert!(originals.len() >= 64); + + for (seq, record) in originals + .iter() + .filter(|(seq, _)| seq.into_inner() % 2 == 1) + { + let _ = receive_events(&mut receiver, now, *seq, record); + } + + let first_ack_time = now + receiver_config.ack_delay; + let first_acks = drain_outbound(&mut receiver, first_ack_time, originals.len()); + assert!(first_acks.len() > 1); + assert!(first_acks + .iter() + .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); + + for (seq, record) in &first_acks { + let _ = receive_events(&mut sender, first_ack_time, *seq, record); + } + + let retransmit_time = now + sender_config.retransmit_timeout + Duration::from_millis(1); + let mut emit = |_| {}; + sender.on_timer(retransmit_time, &mut emit); + let retransmits = drain_outbound(&mut sender, retransmit_time, originals.len()); + assert!(!retransmits.is_empty()); + + for (seq, record) in &retransmits { + let _ = receive_events(&mut receiver, retransmit_time, *seq, record); + } + + let second_ack_time = retransmit_time + receiver_config.ack_delay; + let second_acks = drain_outbound(&mut receiver, second_ack_time, retransmits.len() + 16); + assert!(!second_acks.is_empty()); + assert!(second_acks + .iter() + .all(|(_, frames)| matches!(frames.as_slice(), [SessionFrame::Ack(_)]))); + + for (seq, record) in &second_acks { + let _ = receive_events(&mut sender, second_ack_time, *seq, record); + } + + let final_now = second_ack_time + sender_config.retransmit_timeout + Duration::from_millis(1); + let mut sender_emit = |_| {}; + sender.on_timer(final_now, &mut sender_emit); + let mut receiver_emit = |_| {}; + receiver.on_timer(final_now, &mut receiver_emit); + assert!(next_outbound(&mut sender, final_now).is_none()); + assert!(next_outbound(&mut receiver, final_now).is_none()); +} diff --git a/ql-fsm/src/session/tracked.rs b/ql-fsm/src/session/tracked.rs new file mode 100644 index 0000000..8431795 --- /dev/null +++ b/ql-fsm/src/session/tracked.rs @@ -0,0 +1,29 @@ +//! outbound record tracking state for ack and retransmit handling + +use std::time::Instant; + +use ql_wire::{RecordAck, RecordSeq, StreamClose, StreamId}; + +#[derive(Debug, Clone)] +pub struct TrackedRecord { + pub seq: RecordSeq, + pub frames: Vec, + pub ack: Option, + pub ping_included: bool, + pub window_updates: Vec<(StreamId, u64)>, + pub sent_at: Option, +} + +#[derive(Debug, Clone)] +pub enum TrackedFrame { + StreamData(TrackedStreamData), + StreamClose(StreamClose), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TrackedStreamData { + pub stream_id: StreamId, + pub offset: u64, + pub len: usize, + pub fin: bool, +} diff --git a/ql-fsm/src/state.rs b/ql-fsm/src/state.rs new file mode 100644 index 0000000..8268bc1 --- /dev/null +++ b/ql-fsm/src/state.rs @@ -0,0 +1,139 @@ +use std::time::Instant; + +use ql_wire::{ + ConnectionId, EphemeralPublicKey, HandshakeId, HandshakeMeta, IkHandshake, KkHandshake, + PairingToken, PeerBundle, QlHandshakeRecord, SessionKey, TransportParams, XxHandshake, +}; + +use crate::{session::SessionFsm, NoSessionError, PeerStatus}; + +pub struct QlFsmState { + pub next_control_id: u32, + pub peer: Option, + pub armed_pairing_token: Option, + pub handshake: Option, + pub link: LinkState, + pub now: Instant, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionTransport { + pub tx_key: SessionKey, + pub rx_key: SessionKey, + pub tx_connection_id: ConnectionId, + pub rx_connection_id: ConnectionId, + pub remote_transport_params: TransportParams, +} + +impl SessionTransport { + pub fn from_finalized(finalized: ql_wire::FinalizedHandshake) -> (Self, PeerBundle) { + ( + Self { + tx_key: finalized.tx_key, + rx_key: finalized.rx_key, + tx_connection_id: finalized.tx_connection_id, + rx_connection_id: finalized.rx_connection_id, + remote_transport_params: finalized.remote_transport_params, + }, + finalized.remote_bundle, + ) + } +} + +#[allow(clippy::large_enum_variant)] +pub enum LinkState { + Idle, + IkInitiator(IkInitiatorState), + KkInitiator(KkInitiatorState), + XxInitiator(XxInitiatorState), + XxResponder(XxResponderState), + Connected(ConnectedState), +} + +pub struct ConnectedState { + pub transport: SessionTransport, + pub session: SessionFsm, +} + +#[derive(Debug, Clone)] +pub struct IkInitiatorState { + pub handshake: IkHandshake, + pub handshake_id: HandshakeId, + pub deadline: Instant, + pub initial_ephemeral: EphemeralPublicKey, +} + +#[derive(Debug, Clone)] +pub struct KkInitiatorState { + pub handshake: KkHandshake, + pub handshake_id: HandshakeId, + pub deadline: Instant, + pub initial_ephemeral: EphemeralPublicKey, +} + +#[derive(Debug, Clone)] +pub struct XxInitiatorState { + pub handshake: XxHandshake, + pub handshake_id: HandshakeId, + pub deadline: Instant, + pub initial_ephemeral: EphemeralPublicKey, +} + +#[derive(Debug, Clone)] +pub struct XxResponderState { + pub handshake: XxHandshake, + pub handshake_meta: HandshakeMeta, + pub deadline: Instant, +} + +impl LinkState { + pub fn take(&mut self) -> Self { + std::mem::replace(self, Self::Idle) + } + + pub fn status(&self) -> PeerStatus { + match self { + Self::Idle | Self::XxResponder(_) => PeerStatus::Disconnected, + Self::IkInitiator(_) | Self::KkInitiator(_) | Self::XxInitiator(_) => { + PeerStatus::Initiator + } + Self::Connected(_) => PeerStatus::Connected, + } + } + + #[inline] + pub fn connected(&self) -> Option<&ConnectedState> { + match self { + Self::Connected(state) => Some(state), + _ => None, + } + } + + #[inline] + pub fn connected_mut(&mut self) -> Option<&mut ConnectedState> { + match self { + Self::Connected(state) => Some(state), + _ => None, + } + } + + #[inline] + pub fn connected_mut_or_err(&mut self) -> Result<&mut ConnectedState, NoSessionError> { + self.connected_mut().ok_or(NoSessionError) + } + + pub fn handshake_deadline(&self) -> Option { + match self { + Self::Idle | Self::Connected(_) => None, + Self::IkInitiator(state) => Some(state.deadline), + Self::KkInitiator(state) => Some(state.deadline), + Self::XxInitiator(state) => Some(state.deadline), + Self::XxResponder(state) => Some(state.deadline), + } + } + + #[cfg(test)] + pub fn transport(&self) -> Option<&SessionTransport> { + self.connected().map(|state| &state.transport) + } +} diff --git a/ql-fsm/src/tests/handshake.rs b/ql-fsm/src/tests/handshake.rs new file mode 100644 index 0000000..4de4f06 --- /dev/null +++ b/ql-fsm/src/tests/handshake.rs @@ -0,0 +1,388 @@ +use std::time::Duration; + +use ql_wire::QlHandshakeRecord; + +use super::*; +use crate::{state::LinkState, Event, NoPeerError, PeerStatus, ReceiveError}; + +#[test] +fn ik_connect_round_trip_establishes_transport() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn kk_connect_round_trip_establishes_transport() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_kk(Side::A).unwrap(); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn xx_connect_round_trip_establishes_transport_when_armed() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(1); + + harness.b.fsm.arm_pairing(token); + harness.connect_xx(Side::A, token); + + let xx1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx1); + let xx2 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, xx2); + let xx3 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx3); + + let xx4 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, xx4); + + assert_eq!(harness.a.fsm.peer(), Some(&harness.b.fsm.identity.bundle())); + assert_eq!(harness.b.fsm.peer(), Some(&harness.a.fsm.identity.bundle())); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn ik_connect_learns_remote_initial_stream_receive_window() { + let mut harness = Harness::paired_known_with_configs( + QlFsmConfig { + session_stream_receive_buffer_size: 9, + ..QlFsmConfig::default() + }, + QlFsmConfig { + session_stream_receive_buffer_size: 3, + ..QlFsmConfig::default() + }, + ); + + harness.connect_ik(Side::A).unwrap(); + harness.pump(); + + assert_eq!( + harness + .a + .fsm + .state + .link + .transport() + .unwrap() + .remote_transport_params + .initial_stream_receive_window, + 3 + ); + assert_eq!( + harness + .b + .fsm + .state + .link + .transport() + .unwrap() + .remote_transport_params + .initial_stream_receive_window, + 9 + ); +} + +#[test] +fn connect_methods_require_bound_peer() { + let time = Harness::paired_known(QlFsmConfig::default()).time(); + let identity = generate_identity(&SoftwareCrypto, "identity").unwrap(); + let mut fsm = QlFsm::new(QlFsmConfig::default(), identity, time); + let crypto = SoftwareCrypto; + + assert_eq!(fsm.connect_ik(time, &crypto), Err(NoPeerError)); + assert_eq!(fsm.connect_kk(time, &crypto), Err(NoPeerError)); + + fsm.connect_xx( + time, + PairingInvite { + qid: ql_wire::QID([2; ql_wire::QID::SIZE]), + token: pairing_token(2), + }, + &crypto, + ); +} + +#[test] +fn connect_ik_emits_initiator_status() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + + assert_eq!( + harness.drain_events(Side::A), + vec![Event::PeerStatusChanged(PeerStatus::Initiator)] + ); +} + +#[test] +fn inbound_xx1_rejects_when_not_in_pairing_mode() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(3); + + harness.connect_xx(Side::A, token); + let xx1 = harness.next_outbound(Side::A).unwrap(); + let time = harness.time(); + let Node { fsm, crypto } = &mut harness.b; + let err = fsm.receive(time, xx1, crypto); + + assert_eq!(err, Err(ReceiveError::NotPairingMode)); + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); + assert!(harness.drain_events(Side::B).is_empty()); + assert!(harness.next_outbound(Side::B).is_none()); +} + +#[test] +fn inbound_xx1_rejects_mismatched_pairing_id_with_expected_and_actual() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let expected = pairing_token(4); + let actual = pairing_token(7); + + harness.b.fsm.arm_pairing(expected); + harness.connect_xx(Side::A, actual); + let xx1 = harness.next_outbound(Side::A).unwrap(); + + let time = harness.time(); + let Node { fsm, crypto } = &mut harness.b; + let err = fsm.receive(time, xx1, crypto); + + assert_eq!( + err, + Err(ReceiveError::InvalidPairingId { + expected: expected.id(&SoftwareCrypto), + actual: actual.id(&SoftwareCrypto), + }) + ); +} + +#[test] +fn disarm_pairing_rejects_inflight_inbound_xx_responder() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(5); + + harness.b.fsm.arm_pairing(token); + harness.connect_xx(Side::A, token); + let xx1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, xx1); + let xx2 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, xx2); + let xx3 = harness.next_outbound(Side::A).unwrap(); + harness.b.fsm.disarm_pairing(); + harness.deliver(Side::B, xx3); + + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); + assert!(harness.next_outbound(Side::B).is_none()); +} + +#[test] +fn simultaneous_xx_connect_converges() { + let mut harness = Harness::paired(QlFsmConfig::default(), false, false); + let token = pairing_token(6); + + harness.a.fsm.arm_pairing(token); + harness.b.fsm.arm_pairing(token); + harness.connect_xx(Side::A, token); + harness.connect_xx(Side::B, token); + + for _ in 0..2 { + if let Some(record) = harness.next_outbound(Side::A) { + harness.deliver(Side::B, record); + } + if let Some(record) = harness.next_outbound(Side::B) { + harness.deliver(Side::A, record); + } + } + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn connect_ik_replaces_in_flight_attempt_and_ignores_stale_reply() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.drain_events(Side::A); + let first = harness.next_outbound(Side::A).unwrap(); + let first_id = handshake_id(&first); + + harness.connect_ik(Side::A).unwrap(); + let second = harness.next_outbound(Side::A).unwrap(); + let second_id = handshake_id(&second); + + assert_ne!(first_id, second_id); + + harness.deliver(Side::B, first); + let stale_reply = harness.next_outbound(Side::B).unwrap(); + assert_eq!(handshake_id(&stale_reply), first_id); + + harness.deliver(Side::A, stale_reply); + assert!(matches!( + harness.a.fsm.state.link, + LinkState::IkInitiator(_) + )); + + harness.deliver(Side::B, second); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn connect_kk_replaces_in_flight_attempt_and_ignores_stale_reply() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_kk(Side::A).unwrap(); + let first = harness.next_outbound(Side::A).unwrap(); + let first_id = handshake_id(&first); + + harness.connect_kk(Side::A).unwrap(); + let second = harness.next_outbound(Side::A).unwrap(); + let second_id = handshake_id(&second); + + assert_ne!(first_id, second_id); + + harness.deliver(Side::B, first); + let stale_reply = harness.next_outbound(Side::B).unwrap(); + assert_eq!(handshake_id(&stale_reply), first_id); + + harness.deliver(Side::A, stale_reply); + assert!(matches!( + harness.a.fsm.state.link, + LinkState::KkInitiator(_) + )); + + harness.deliver(Side::B, second); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn inbound_ik1_auto_binds_unbound_responder() { + let mut harness = Harness::paired(QlFsmConfig::default(), true, false); + + harness.connect_ik(Side::A).unwrap(); + harness.pump(); + + let expected_peer = harness.a.fsm.identity.bundle(); + assert_eq!(harness.b.fsm.peer(), Some(&expected_peer)); + assert_eq!( + harness.drain_events(Side::B), + vec![ + Event::NewPeer, + Event::PeerStatusChanged(PeerStatus::Connected), + ] + ); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn handshake_timeout_drops_single_ik_attempt_without_resend() { + let config = QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired_known(config); + + harness.connect_ik(Side::A).unwrap(); + harness.drain_events(Side::A); + let first = harness.next_outbound(Side::A).unwrap(); + let (_, first) = ql_wire::decode_record::(first.as_slice()).unwrap(); + assert!(matches!(first, ql_wire::QlHandshakeRecord::Ik1(_))); + assert!(harness.next_outbound(Side::A).is_none()); + + harness.advance(config.handshake_timeout); + harness.on_timer(Side::A); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) + ); + assert!(harness.next_outbound(Side::A).is_none()); +} + +#[test] +fn handshake_timeout_clears_queued_kk_output() { + let config = QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + ..QlFsmConfig::default() + }; + let mut harness = Harness::paired_known(config); + + harness.connect_kk(Side::A).unwrap(); + + harness.advance(config.handshake_timeout); + harness.on_timer(Side::A); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); + assert!(harness.next_outbound(Side::A).is_none()); +} + +#[test] +fn bind_peer_clears_queued_handshake_output() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.drain_events(Side::A); + harness + .a + .fsm + .bind_peer(generate_identity(&SoftwareCrypto, "peer").unwrap().bundle()); + + assert!(harness.drain_events(Side::A).is_empty()); + assert!(harness.next_outbound(Side::A).is_none()); +} + +#[test] +fn simultaneous_ik_connect_converges() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.connect_ik(Side::B).unwrap(); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +#[test] +fn simultaneous_ik_and_kk_connect_prefers_ik() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.connect_ik(Side::A).unwrap(); + harness.connect_kk(Side::B).unwrap(); + harness.pump(); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!(harness.b.fsm.state.link, LinkState::Connected(_))); +} + +fn handshake_id(record: &[u8]) -> ql_wire::HandshakeId { + let (_, record) = ql_wire::decode_record(record).unwrap(); + match record { + ql_wire::QlHandshakeRecord::Ik1(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Ik2(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Kk1(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Kk2(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx1(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx2(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx3(message) => message.meta.handshake_id, + ql_wire::QlHandshakeRecord::Xx4(message) => message.meta.handshake_id, + } +} diff --git a/ql-fsm/src/tests/mod.rs b/ql-fsm/src/tests/mod.rs new file mode 100644 index 0000000..c4d005e --- /dev/null +++ b/ql-fsm/src/tests/mod.rs @@ -0,0 +1,351 @@ +mod handshake; +mod proptest; +mod session; + +use std::time::{Duration, Instant}; + +use ql_wire::{ + self, generate_identity, test_identities, ConnectionId, PairingToken, QlCrypto, SessionKey, + SoftwareCrypto, TransportParams, QID, +}; + +use crate::{ + session::{SessionConfig, SessionFsm, StreamParity}, + state::{ConnectedState, LinkState, SessionTransport}, + Event, NoPeerError, OutboundWrite, PairingInvite, QlFsm, QlFsmConfig, WriteId, +}; + +type TestCrypto = SoftwareCrypto; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Side { + A, + B, +} + +impl Side { + fn idx(self) -> usize { + match self { + Self::A => 0, + Self::B => 1, + } + } +} + +struct Node { + fsm: QlFsm, + crypto: TestCrypto, +} + +struct Harness { + now: Instant, + a: Node, + b: Node, +} + +struct DecodedSessionWrite { + record: Vec, + write_id: Option, + header: ql_wire::SessionHeader, + frames: Vec>>, +} + +impl Harness { + fn paired_known(config: QlFsmConfig) -> Self { + Self::paired_with_configs(config, config, true, true) + } + + fn paired(config: QlFsmConfig, know_a: bool, know_b: bool) -> Self { + Self::paired_with_configs(config, config, know_a, know_b) + } + + fn paired_known_with_configs(config_a: QlFsmConfig, config_b: QlFsmConfig) -> Self { + Self::paired_with_configs(config_a, config_b, true, true) + } + + fn paired_with_configs( + config_a: QlFsmConfig, + config_b: QlFsmConfig, + know_a: bool, + know_b: bool, + ) -> Self { + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let now = Instant::now(); + + let mut harness = Self { + now, + a: Node { + fsm: QlFsm::new(config_a, identity_a.clone(), now), + crypto: SoftwareCrypto, + }, + b: Node { + fsm: QlFsm::new(config_b, identity_b.clone(), now), + crypto: SoftwareCrypto, + }, + }; + + if know_a { + harness.a.fsm.bind_peer(identity_b.bundle()); + } + if know_b { + harness.b.fsm.bind_peer(identity_a.bundle()); + } + + harness + } + + fn connected(config: QlFsmConfig) -> Self { + let mut harness = Self::paired_known(config); + let a_to_b_key = SessionKey::from_data([7; SessionKey::SIZE]); + let b_to_a_key = SessionKey::from_data([9; SessionKey::SIZE]); + let a_to_b_conn = ConnectionId::from_data([0xA1; ConnectionId::SIZE]); + let b_to_a_conn = ConnectionId::from_data([0xB2; ConnectionId::SIZE]); + + harness.a.fsm.state.link = LinkState::Connected(ConnectedState { + transport: SessionTransport { + tx_key: a_to_b_key.clone(), + rx_key: b_to_a_key.clone(), + tx_connection_id: a_to_b_conn, + rx_connection_id: b_to_a_conn, + remote_transport_params: TransportParams { + initial_stream_receive_window: harness + .b + .fsm + .config + .session_stream_receive_buffer_size, + }, + }, + session: SessionFsm::new(session_config(&harness, true), harness.now), + }); + harness.b.fsm.state.link = LinkState::Connected(ConnectedState { + transport: SessionTransport { + tx_key: b_to_a_key, + rx_key: a_to_b_key, + tx_connection_id: b_to_a_conn, + rx_connection_id: a_to_b_conn, + remote_transport_params: TransportParams { + initial_stream_receive_window: harness + .a + .fsm + .config + .session_stream_receive_buffer_size, + }, + }, + session: SessionFsm::new(session_config(&harness, false), harness.now), + }); + harness + } + + fn time(&self) -> Instant { + self.now + } + + fn advance(&mut self, duration: Duration) { + self.now += duration; + } + + fn node(&self, side: Side) -> &Node { + match side { + Side::A => &self.a, + Side::B => &self.b, + } + } + + fn node_mut(&mut self, side: Side) -> &mut Node { + match side { + Side::A => &mut self.a, + Side::B => &mut self.b, + } + } + + fn next_outbound(&mut self, side: Side) -> Option> { + let write = self.next_write(side)?; + if let Some(id) = write.write_id { + self.confirm_write(side, id); + } + Some(write.record) + } + + fn next_write(&mut self, side: Side) -> Option { + let time = self.time(); + let Node { fsm, crypto } = self.node_mut(side); + fsm.take_next_write(time, crypto) + } + + fn next_decoded_outbound(&mut self, side: Side) -> Option { + let write = self.next_write(side)?; + if let Some(id) = write.write_id { + self.confirm_write(side, id); + } + Some(self.decode_session_write(write, side)) + } + + fn next_decoded_write(&mut self, side: Side) -> Option { + let write = self.next_write(side)?; + Some(self.decode_session_write(write, side)) + } + + fn connect_ik(&mut self, side: Side) -> Result<(), NoPeerError> { + let time = self.time(); + let Node { fsm, crypto } = self.node_mut(side); + fsm.connect_ik(time, crypto) + } + + fn connect_kk(&mut self, side: Side) -> Result<(), NoPeerError> { + let time = self.time(); + let Node { fsm, crypto } = self.node_mut(side); + fsm.connect_kk(time, crypto) + } + + fn connect_xx(&mut self, side: Side, token: PairingToken) { + let time = self.time(); + let remote_qid = self.remote_qid(side); + let Node { fsm, crypto } = self.node_mut(side); + fsm.connect_xx( + time, + PairingInvite { + qid: remote_qid, + token, + }, + crypto, + ); + } + + fn remote_qid(&self, side: Side) -> QID { + match side { + Side::A => self.b.fsm.identity.qid, + Side::B => self.a.fsm.identity.qid, + } + } + + fn deliver(&mut self, side: Side, record: Vec) { + let time = self.time(); + let Node { fsm, crypto } = self.node_mut(side); + fsm.receive(time, record, crypto).unwrap(); + } + + fn confirm_write(&mut self, side: Side, write_id: WriteId) { + let time = self.time(); + self.node_mut(side).fsm.complete_write(time, write_id, true); + } + + fn reject_write(&mut self, side: Side, write_id: WriteId) { + let time = self.time(); + self.node_mut(side) + .fsm + .complete_write(time, write_id, false); + } + + fn decode_session_write(&self, write: OutboundWrite, side: Side) -> DecodedSessionWrite { + let peer = self.node(match side { + Side::A => Side::B, + Side::B => Side::A, + }); + let crypto = &peer.crypto; + let session_key = &peer.fsm.state.link.transport().unwrap().rx_key; + let (header, frames) = decrypt_record(crypto, &write.record, session_key); + DecodedSessionWrite { + record: write.record, + write_id: write.write_id, + header, + frames, + } + } + + fn on_timer(&mut self, side: Side) { + let time = self.time(); + self.node_mut(side).fsm.on_timer(time); + } + + fn take_event(&mut self, side: Side) -> Option { + self.node_mut(side).fsm.poll_event() + } + + fn drain_events(&mut self, side: Side) -> Vec { + let mut events = Vec::new(); + while let Some(event) = self.take_event(side) { + events.push(event); + } + events + } + + fn pump(&mut self) { + for _ in 0..128 { + let mut progressed = false; + + while let Some(record) = self.next_outbound(Side::A) { + progressed = true; + self.deliver(Side::B, record); + } + + while let Some(record) = self.next_outbound(Side::B) { + progressed = true; + self.deliver(Side::A, record); + } + + if !progressed { + return; + } + } + + panic!("pump did not quiesce"); + } +} + +fn pairing_token(byte: u8) -> PairingToken { + PairingToken([byte; PairingToken::SIZE]) +} + +fn session_config(harness: &Harness, a: bool) -> SessionConfig { + let (local, peer, config) = if a { + ( + harness.a.fsm.identity.qid, + harness.a.fsm.state.peer.as_ref().unwrap().qid, + harness.a.fsm.config, + ) + } else { + ( + harness.b.fsm.identity.qid, + harness.b.fsm.state.peer.as_ref().unwrap().qid, + harness.b.fsm.config, + ) + }; + + SessionConfig { + local_parity: StreamParity::for_local(local, peer), + record_max_size: config.session_record_max_size, + ack_delay: config.session_record_ack_delay, + retransmit_timeout: config.session_record_retransmit_timeout, + keepalive_interval: config.session_keepalive_interval, + peer_timeout: config.session_peer_timeout, + stream_send_buffer_size: config.session_stream_send_buffer_size, + stream_receive_buffer_size: config.session_stream_receive_buffer_size, + accepted_record_window: config.session_accepted_record_window, + pending_ack_range_limit: config.session_pending_ack_range_limit, + initial_peer_stream_receive_window: if a { + harness.b.fsm.config.session_stream_receive_buffer_size + } else { + harness.a.fsm.config.session_stream_receive_buffer_size + }, + } +} + +fn decrypt_record( + crypto: &impl QlCrypto, + record: &[u8], + session_key: &SessionKey, +) -> (ql_wire::SessionHeader, Vec>>) { + let (_header, record) = + ql_wire::decode_record::, _>(record).unwrap(); + let plaintext = ql_wire::decrypt_record( + crypto, + &record.header, + record.payload.into_owned(), + session_key, + ) + .unwrap(); + ( + record.header, + ql_wire::decode_session_frames(&plaintext).unwrap(), + ) +} diff --git a/ql-fsm/src/tests/proptest.rs b/ql-fsm/src/tests/proptest.rs new file mode 100644 index 0000000..bc97ca7 --- /dev/null +++ b/ql-fsm/src/tests/proptest.rs @@ -0,0 +1,1001 @@ +use std::{ + collections::{BTreeMap, BTreeSet}, + time::Duration, +}; + +extern crate proptest as proptest_crate; + +use bytes::Bytes; +use proptest_crate::{collection::vec, prelude::*, test_runner::TestCaseResult}; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId, WireError}; + +use super::*; + +fn test_route_id() -> ql_wire::RouteId { + ql_wire::RouteId::from_u32(1) +} +use crate::{state::LinkState, Event, PeerStatus, ReceiveError, WriteId}; + +const SLOT_COUNT: usize = 4; + +#[derive(Clone, Debug)] +enum Action { + ConnectIk(Side), + ConnectKk(Side), + AdvanceMs(u8), + OnTimer(Side), + OnTimerBoth, + Pump, + TakeNext(Side), + ConfirmTaken { + side: Side, + index: usize, + }, + RejectTaken { + side: Side, + index: usize, + }, + CaptureNext(Side), + DeliverNext(Side), + DropNext(Side), + DeliverQueued { + side: Side, + index: usize, + }, + DuplicateQueued { + side: Side, + index: usize, + }, + DropQueued { + side: Side, + index: usize, + }, + OpenStream { + side: Side, + slot: usize, + }, + Write { + side: Side, + slot: usize, + bytes: Vec, + }, + Finish { + side: Side, + slot: usize, + }, + Close { + side: Side, + slot: usize, + }, +} + +impl Action { + fn confirm_taken(side: Side, index: usize) -> Self { + Self::ConfirmTaken { side, index } + } + + fn reject_taken(side: Side, index: usize) -> Self { + Self::RejectTaken { side, index } + } + + fn deliver_queued(side: Side, index: usize) -> Self { + Self::DeliverQueued { side, index } + } + + fn duplicate_queued(side: Side, index: usize) -> Self { + Self::DuplicateQueued { side, index } + } + + fn drop_queued(side: Side, index: usize) -> Self { + Self::DropQueued { side, index } + } + + fn open_stream(side: Side, slot: usize) -> Self { + Self::OpenStream { side, slot } + } + + fn write(side: Side, slot: usize, bytes: Vec) -> Self { + Self::Write { side, slot, bytes } + } + + fn finish(side: Side, slot: usize) -> Self { + Self::Finish { side, slot } + } + + fn close(side: Side, slot: usize) -> Self { + Self::Close { side, slot } + } +} + +#[derive(Clone, Debug)] +struct TakenWrite { + record: Vec, + write_id: Option, +} + +#[derive(Default)] +struct SideEventState { + opened: BTreeSet, + finished: BTreeSet, + outbound_finished: BTreeSet, + writable_closed: BTreeSet, + closed: BTreeSet, + peer_statuses: Vec, + last_peer_status: Option, + session_epoch: usize, + session_closed_epoch: Option, +} + +impl SideEventState { + fn note_peer_status(&mut self, status: PeerStatus) { + if status == PeerStatus::Connected && self.last_peer_status != Some(PeerStatus::Connected) { + self.session_epoch = self.session_epoch.saturating_add(1); + } + self.peer_statuses.push(status); + self.last_peer_status = Some(status); + } +} + +struct Runner { + harness: Harness, + slots: [[Option; SLOT_COUNT]; 2], + taken: [Vec; 2], + pending: [Vec>; 2], + receive_errors: Vec<(Side, ReceiveError)>, + events: [SideEventState; 2], + known_streams: BTreeSet, + expected: [BTreeMap>; 2], + received: [BTreeMap>; 2], + finished_by: [BTreeSet; 2], + closed_by: [BTreeSet; 2], +} + +impl Runner { + fn handshake() -> Self { + let config = QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + session_record_ack_delay: Duration::from_millis(5), + session_record_retransmit_timeout: Duration::from_millis(15), + session_peer_timeout: Duration::from_millis(80), + ..QlFsmConfig::default() + }; + + Self { + harness: Harness::paired_known(config), + slots: [[None; SLOT_COUNT]; 2], + taken: [Vec::new(), Vec::new()], + pending: [Vec::new(), Vec::new()], + receive_errors: Vec::new(), + events: [SideEventState::default(), SideEventState::default()], + known_streams: BTreeSet::new(), + expected: [BTreeMap::new(), BTreeMap::new()], + received: [BTreeMap::new(), BTreeMap::new()], + finished_by: [BTreeSet::new(), BTreeSet::new()], + closed_by: [BTreeSet::new(), BTreeSet::new()], + } + } + + fn connected() -> Self { + let config = QlFsmConfig { + session_record_ack_delay: Duration::from_millis(5), + session_record_retransmit_timeout: Duration::from_millis(15), + session_peer_timeout: Duration::from_secs(5), + ..QlFsmConfig::default() + }; + Self::connected_with_config(config) + } + + fn connected_with_config(config: QlFsmConfig) -> Self { + let connected_events = || SideEventState { + last_peer_status: Some(PeerStatus::Connected), + session_epoch: 1, + ..SideEventState::default() + }; + + Self { + harness: Harness::connected(config), + slots: [[None; SLOT_COUNT]; 2], + taken: [Vec::new(), Vec::new()], + pending: [Vec::new(), Vec::new()], + receive_errors: Vec::new(), + events: [connected_events(), connected_events()], + known_streams: BTreeSet::new(), + expected: [BTreeMap::new(), BTreeMap::new()], + received: [BTreeMap::new(), BTreeMap::new()], + finished_by: [BTreeSet::new(), BTreeSet::new()], + closed_by: [BTreeSet::new(), BTreeSet::new()], + } + } + + fn run(&mut self, actions: &[Action]) -> TestCaseResult { + for action in actions { + self.apply(action); + self.observe_and_assert()?; + } + + self.cleanup()?; + self.observe_and_assert()?; + self.assert_terminal_semantics()?; + self.assert_quiesced() + } + + #[allow(clippy::cognitive_complexity, clippy::too_many_lines)] + fn apply(&mut self, action: &Action) { + match action { + Action::ConnectIk(side) => { + let _ = self.harness.connect_ik(*side); + } + Action::ConnectKk(side) => { + let _ = self.harness.connect_kk(*side); + } + Action::AdvanceMs(ms) => { + self.harness + .advance(Duration::from_millis(u64::from(*ms) + 1)); + } + Action::OnTimer(side) => self.harness.on_timer(*side), + Action::OnTimerBoth => { + self.harness.on_timer(Side::A); + self.harness.on_timer(Side::B); + } + Action::Pump => self.capture_all_outbound(), + Action::TakeNext(side) => { + if let Some(write) = take_unconfirmed_outbound(&mut self.harness, *side) { + self.taken[side.idx()].push(write); + } + } + Action::ConfirmTaken { side, index } => { + if let Some(write) = take_taken(&mut self.taken[side.idx()], *index) { + confirm_taken(&mut self.harness, *side, &write); + self.pending[side.idx()].push(write.record); + } + } + Action::RejectTaken { side, index } => { + if let Some(write) = take_taken(&mut self.taken[side.idx()], *index) { + reject_taken(&mut self.harness, *side, &write); + } + } + Action::CaptureNext(side) => { + if let Some(record) = take_confirmed_outbound(&mut self.harness, *side) { + self.pending[side.idx()].push(record); + } + } + Action::DeliverNext(side) => { + if let Some(record) = take_confirmed_outbound(&mut self.harness, *side) { + self.deliver_to(opposite(*side), record); + } + } + Action::DropNext(side) => { + let _ = take_confirmed_outbound(&mut self.harness, *side); + } + Action::DeliverQueued { side, index } => { + if let Some(record) = take_pending(&mut self.pending[side.idx()], *index) { + self.deliver_to(opposite(*side), record); + } + } + Action::DuplicateQueued { side, index } => { + if let Some(record) = peek_pending(&self.pending[side.idx()], *index) { + self.deliver_to(opposite(*side), record); + } + } + Action::DropQueued { side, index } => { + let _ = take_pending(&mut self.pending[side.idx()], *index); + } + Action::OpenStream { side, slot } => { + let stream_id = self + .harness + .node_mut(*side) + .fsm + .open_stream(test_route_id()) + .ok() + .map(|stream| stream.stream_id()); + if let Some(stream_id) = stream_id { + self.slots[side.idx()][*slot] = Some(stream_id); + self.known_streams.insert(stream_id); + } + } + Action::Write { side, slot, bytes } => { + if let Some(stream_id) = self.slots[side.idx()][*slot] { + let mut chunk = Bytes::copy_from_slice(bytes); + let accepted = self.harness.node_mut(*side).fsm.stream(stream_id).map_or( + 0, + |mut stream| { + stream + .writer() + .map_or(0, |mut writer| writer.write(&mut chunk)) + }, + ); + if accepted != 0 { + self.expected[opposite(*side).idx()] + .entry(stream_id) + .or_default() + .extend_from_slice(&bytes[..accepted]); + } + } + } + Action::Finish { side, slot } => { + if let Some(stream_id) = self.slots[side.idx()][*slot] { + let finished = self + .harness + .node_mut(*side) + .fsm + .stream(stream_id) + .is_ok_and(|mut stream| { + stream.writer().is_some_and(|writer| { + writer.finish(); + true + }) + }); + if finished { + self.finished_by[side.idx()].insert(stream_id); + } + } + } + Action::Close { side, slot } => { + if let Some(stream_id) = self.slots[side.idx()][*slot] { + let closed = self + .harness + .node_mut(*side) + .fsm + .stream(stream_id) + .is_ok_and(|mut stream| { + stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); + true + }); + if closed { + self.closed_by[side.idx()].insert(stream_id); + self.slots[side.idx()][*slot] = None; + } + } + } + } + } + + fn observe_and_assert(&mut self) -> TestCaseResult { + self.drain_reads(Side::A); + self.drain_reads(Side::B); + let events_a = self.harness.drain_events(Side::A); + let events_b = self.harness.drain_events(Side::B); + self.process_events(Side::A, events_a)?; + self.process_events(Side::B, events_b)?; + self.assert_prefix_invariants()?; + self.assert_legal_link_state()?; + self.assert_receive_errors() + } + + fn cleanup(&mut self) -> TestCaseResult { + let tick = self + .harness + .a + .fsm + .config + .session_record_retransmit_timeout + .max(self.harness.a.fsm.config.session_record_ack_delay) + + Duration::from_millis(1); + + self.reject_all_taken(); + + for _ in 0..12 { + self.capture_all_outbound(); + self.flush_pending_in_order(); + self.capture_all_outbound(); + self.flush_pending_in_order(); + self.observe_and_assert()?; + self.harness.advance(tick); + self.harness.on_timer(Side::A); + self.harness.on_timer(Side::B); + self.capture_all_outbound(); + self.flush_pending_in_order(); + self.observe_and_assert()?; + self.reject_all_taken(); + } + + Ok(()) + } + + fn drain_reads(&mut self, side: Side) { + for stream_id in self.known_streams.clone() { + let appended = drain_stream(&mut self.harness.node_mut(side).fsm, stream_id); + if !appended.is_empty() { + self.received[side.idx()] + .entry(stream_id) + .or_default() + .extend_from_slice(&appended); + } + } + } + + fn process_events(&mut self, side: Side, events: Vec) -> TestCaseResult { + for event in events { + match event { + Event::NewPeer => {} + Event::PeerStatusChanged(status) => { + if status == PeerStatus::Unpaired { + let state = &mut self.events[side.idx()]; + prop_assert!( + state.session_epoch > 0, + "side {side:?} emitted Unpaired without a connected session" + ); + prop_assert!( + state.session_closed_epoch != Some(state.session_epoch), + "side {side:?} emitted duplicate terminal event in session epoch {}", + state.session_epoch + ); + state.session_closed_epoch = Some(state.session_epoch); + } + self.events[side.idx()].note_peer_status(status); + } + Event::Opened { stream_id, .. } => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted Opened for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events[side.idx()].opened.insert(stream_id), + "side {side:?} emitted duplicate Opened for {stream_id:?}" + ); + } + Event::Readable(stream_id) | Event::Writable(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted readiness for unknown stream {stream_id:?}" + ); + } + Event::Finished(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted Finished for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events[side.idx()].finished.insert(stream_id), + "side {side:?} emitted duplicate Finished for {stream_id:?}" + ); + prop_assert!( + !self.events[side.idx()].closed.contains(&stream_id), + "side {side:?} emitted Finished after Closed for {stream_id:?}" + ); + } + Event::OutboundFinished(stream_id) => { + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted OutboundFinished for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events[side.idx()].outbound_finished.insert(stream_id), + "side {side:?} emitted duplicate OutboundFinished for {stream_id:?}" + ); + } + Event::Closed(frame) => { + prop_assert!( + self.known_streams.contains(&frame.stream_id), + "side {side:?} emitted Closed for unknown stream {:?}", + frame.stream_id + ); + prop_assert!( + self.events[side.idx()].closed.insert(frame.stream_id), + "side {side:?} emitted duplicate Closed for {:?}", + frame.stream_id + ); + } + Event::WritableClosed(frame) => { + let stream_id = frame.stream_id; + prop_assert!( + self.known_streams.contains(&stream_id), + "side {side:?} emitted WritableClosed for unknown stream {stream_id:?}" + ); + prop_assert!( + self.events[side.idx()].writable_closed.insert(stream_id), + "side {side:?} emitted duplicate WritableClosed for {stream_id:?}" + ); + } + Event::SessionClosed(_) => { + let state = &mut self.events[side.idx()]; + prop_assert!( + state.session_epoch > 0, + "side {side:?} emitted SessionClosed without a connected session" + ); + prop_assert!( + state.session_closed_epoch != Some(state.session_epoch), + "side {side:?} emitted duplicate SessionClosed in session epoch {}", + state.session_epoch + ); + state.session_closed_epoch = Some(state.session_epoch); + } + } + } + + Ok(()) + } + + fn assert_prefix_invariants(&self) -> TestCaseResult { + for side in [Side::A, Side::B] { + for (stream_id, received) in &self.received[side.idx()] { + let expected = self.expected[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + prop_assert!( + expected.starts_with(received), + "side {side:?} observed non-prefix bytes on {stream_id:?}: received={received:?} expected={expected:?}" + ); + } + } + + Ok(()) + } + + fn assert_legal_link_state(&self) -> TestCaseResult { + let a_connected = matches!(self.harness.a.fsm.state.link, LinkState::Connected(_)); + let b_connected = matches!(self.harness.b.fsm.state.link, LinkState::Connected(_)); + + prop_assert!( + !a_connected || self.harness.a.fsm.peer().is_some(), + "side A reached Connected without a bound peer" + ); + prop_assert!( + !b_connected || self.harness.b.fsm.peer().is_some(), + "side B reached Connected without a bound peer" + ); + + Ok(()) + } + + fn assert_receive_errors(&self) -> TestCaseResult { + for (side, error) in &self.receive_errors { + prop_assert!( + matches!( + error, + ReceiveError::NoSession + | ReceiveError::NoPeer + | ReceiveError::InvalidRemoteBundle + | ReceiveError::InvalidSessionPayload(WireError::InvalidPayload) + | ReceiveError::InvalidSessionPayload(WireError::DecryptFailed) + | ReceiveError::InvalidIkHandshake(WireError::InvalidPayload) + | ReceiveError::InvalidIkHandshake(WireError::InvalidState) + | ReceiveError::InvalidKkHandshake(WireError::InvalidPayload) + | ReceiveError::InvalidKkHandshake(WireError::InvalidState) + | ReceiveError::InvalidXxHandshake(WireError::InvalidPayload) + | ReceiveError::InvalidXxHandshake(WireError::InvalidState) + | ReceiveError::InvalidXxHandshake(WireError::DecryptFailed) + ), + "unexpected receive error on side {side:?}: {error:?}" + ); + } + + Ok(()) + } + + fn assert_terminal_semantics(&self) -> TestCaseResult { + let a_connected = matches!(self.harness.a.fsm.state.link, LinkState::Connected(_)); + let b_connected = matches!(self.harness.b.fsm.state.link, LinkState::Connected(_)); + let connected = [a_connected, b_connected]; + + for side in [Side::A, Side::B] { + for stream_id in &self.events[side.idx()].finished { + if self.inbound_aborted(side, *stream_id) { + continue; + } + let expected = self.expected[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + let received = self.received[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + prop_assert_eq!( + received, + expected, + "side {:?} finished {:?} without receiving all expected bytes", + side, + stream_id + ); + } + + for stream_id in &self.finished_by[side.idx()] { + prop_assert!( + self.events[opposite(side).idx()].finished.contains(stream_id) + || self.events[opposite(side).idx()].closed.contains(stream_id) + || !connected[opposite(side).idx()], + "side {side:?} finished {stream_id:?} but side {:?} saw neither Finished nor Closed", + opposite(side) + ); + } + + for stream_id in &self.closed_by[side.idx()] { + prop_assert!( + self.events[opposite(side).idx()].closed.contains(stream_id) + || !connected[opposite(side).idx()], + "side {side:?} closed {stream_id:?} but side {:?} saw no Closed event", + opposite(side) + ); + } + } + + Ok(()) + } + + fn assert_expected_delivered(&self, side: Side) -> TestCaseResult { + for (stream_id, expected) in &self.expected[side.idx()] { + let received = self.received[side.idx()] + .get(stream_id) + .map_or(&[][..], Vec::as_slice); + prop_assert_eq!( + received, + expected, + "side {:?} did not receive full payload for {:?}", + side, + stream_id + ); + } + + Ok(()) + } + + fn assert_no_stream_events(&self) -> TestCaseResult { + prop_assert!( + self.known_streams.is_empty() + && self.events.iter().all(|events| { + events.opened.is_empty() + && events.finished.is_empty() + && events.outbound_finished.is_empty() + && events.closed.is_empty() + && events.writable_closed.is_empty() + }), + "handshake-only property observed stream activity" + ); + Ok(()) + } + + fn assert_no_taken_writes(&self) -> TestCaseResult { + prop_assert!( + self.taken.iter().all(Vec::is_empty), + "cleanup left taken writes queued" + ); + Ok(()) + } + + fn assert_quiesced(&mut self) -> TestCaseResult { + self.reject_all_taken(); + + for _ in 0..8 { + self.capture_all_outbound(); + if self.pending.iter().all(Vec::is_empty) { + break; + } + self.flush_pending_in_order(); + self.observe_and_assert()?; + } + + self.capture_all_outbound(); + prop_assert!( + self.pending.iter().all(Vec::is_empty) && self.taken.iter().all(Vec::is_empty), + "cleanup did not quiesce: taken_a={} taken_b={} pending_a={} pending_b={}", + self.taken[Side::A.idx()].len(), + self.taken[Side::B.idx()].len(), + self.pending[Side::A.idx()].len(), + self.pending[Side::B.idx()].len() + ); + + Ok(()) + } + + fn capture_all_outbound(&mut self) { + for side in [Side::A, Side::B] { + while let Some(record) = take_confirmed_outbound(&mut self.harness, side) { + self.pending[side.idx()].push(record); + } + } + } + + fn flush_pending_in_order(&mut self) { + for side in [Side::A, Side::B] { + while let Some(record) = pop_front_pending(&mut self.pending[side.idx()]) { + self.deliver_to(opposite(side), record); + } + } + } + + fn reject_all_taken(&mut self) { + for side in [Side::A, Side::B] { + while let Some(write) = self.taken[side.idx()].pop() { + reject_taken(&mut self.harness, side, &write); + } + } + } + + fn deliver_to(&mut self, side: Side, record: Vec) { + if let Err(error) = deliver_to(&mut self.harness, side, record) { + self.receive_errors.push((side, error)); + } + } + + fn inbound_aborted(&self, side: Side, stream_id: StreamId) -> bool { + self.events[side.idx()].closed.contains(&stream_id) + || self.closed_by[side.idx()].contains(&stream_id) + } +} + +fn take_unconfirmed_outbound(harness: &mut Harness, side: Side) -> Option { + let write = harness.next_write(side)?; + Some(TakenWrite { + record: write.record, + write_id: write.write_id, + }) +} + +fn take_confirmed_outbound(harness: &mut Harness, side: Side) -> Option> { + let write = take_unconfirmed_outbound(harness, side)?; + confirm_taken(harness, side, &write); + Some(write.record) +} + +fn confirm_taken(harness: &mut Harness, side: Side, write: &TakenWrite) { + if let Some(write_id) = write.write_id { + harness.confirm_write(side, write_id); + } +} + +fn reject_taken(harness: &mut Harness, side: Side, write: &TakenWrite) { + if let Some(write_id) = write.write_id { + harness.reject_write(side, write_id); + } +} + +fn deliver_to(harness: &mut Harness, side: Side, record: Vec) -> Result<(), ReceiveError> { + let time = harness.time(); + let Node { fsm, crypto } = harness.node_mut(side); + fsm.receive(time, record, crypto) +} + +fn take_pending(pending: &mut Vec>, index: usize) -> Option> { + if pending.is_empty() { + return None; + } + + Some(pending.remove(index % pending.len())) +} + +fn peek_pending(pending: &[Vec], index: usize) -> Option> { + if pending.is_empty() { + return None; + } + + Some(pending[index % pending.len()].clone()) +} + +fn pop_front_pending(pending: &mut Vec>) -> Option> { + if pending.is_empty() { + None + } else { + Some(pending.remove(0)) + } +} + +fn take_taken(taken: &mut Vec, index: usize) -> Option { + if taken.is_empty() { + return None; + } + + Some(taken.remove(index % taken.len())) +} + +fn drain_stream(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { + let mut out = Vec::new(); + let Ok(mut stream) = fsm.stream(stream_id) else { + return out; + }; + + loop { + let mut read = 0usize; + for chunk in stream.read() { + out.extend_from_slice(&chunk); + read += chunk.len(); + } + + if read == 0 { + break; + } + + stream.commit_read(read).unwrap(); + } + + out +} + +fn opposite(side: Side) -> Side { + match side { + Side::A => Side::B, + Side::B => Side::A, + } +} + +fn side_strategy() -> impl Strategy { + prop_oneof![Just(Side::A), Just(Side::B)] +} + +fn side_action(f: fn(Side) -> Action) -> impl Strategy { + side_strategy().prop_map(f) +} + +fn side_usize_action( + values: impl Strategy, + f: fn(Side, usize) -> Action, +) -> impl Strategy { + (side_strategy(), values).prop_map(move |(side, value)| f(side, value)) +} + +fn side_usize_vec_action( + values: impl Strategy, + bytes: impl Strategy>, + f: fn(Side, usize, Vec) -> Action, +) -> impl Strategy { + (side_strategy(), values, bytes).prop_map(move |(side, value, bytes)| f(side, value, bytes)) +} + +fn handshake_action_strategy() -> impl Strategy { + let queue_index = 0usize..6; + prop_oneof![ + side_action(Action::ConnectIk), + side_action(Action::ConnectKk), + (0u8..40).prop_map(Action::AdvanceMs), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + Just(Action::Pump), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_action(Action::CaptureNext), + side_action(Action::DeliverNext), + side_action(Action::DropNext), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + ] +} + +fn connected_action_strategy() -> impl Strategy { + let bytes = vec(any::(), 0..24); + let slot = 0usize..SLOT_COUNT; + let queue_index = 0usize..6; + prop_oneof![ + (0u8..30).prop_map(Action::AdvanceMs), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + Just(Action::Pump), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_action(Action::CaptureNext), + side_action(Action::DeliverNext), + side_action(Action::DropNext), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + side_usize_action(slot.clone(), Action::open_stream), + side_usize_vec_action(slot.clone(), bytes, Action::write), + side_usize_action(slot.clone(), Action::finish), + side_usize_action(slot, Action::close), + ] +} + +fn write_tracking_action_strategy() -> impl Strategy { + let bytes = vec(any::(), 0..16); + let slot = 0usize..SLOT_COUNT; + let queue_index = 0usize..6; + prop_oneof![ + side_usize_action(slot.clone(), Action::open_stream), + side_usize_vec_action(slot, bytes, Action::write), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + Just(Action::Pump), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + (0u8..20).prop_map(Action::AdvanceMs), + ] +} + +fn packet_loss_recovery_action_strategy() -> impl Strategy { + let queue_index = 0usize..16; + prop_oneof![ + (0u8..20).prop_map(Action::AdvanceMs), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + Just(Action::Pump), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + ] +} + +fn terminal_action_strategy() -> impl Strategy { + let bytes = vec(any::(), 0..16); + let slot = 0usize..SLOT_COUNT; + let queue_index = 0usize..6; + prop_oneof![ + side_usize_action(slot.clone(), Action::open_stream), + side_usize_vec_action(slot.clone(), bytes, Action::write), + side_usize_action(slot.clone(), Action::finish), + side_usize_action(slot, Action::close), + side_action(Action::TakeNext), + side_usize_action(queue_index.clone(), Action::confirm_taken), + side_usize_action(queue_index.clone(), Action::reject_taken), + side_usize_action(queue_index.clone(), Action::deliver_queued), + side_usize_action(queue_index.clone(), Action::duplicate_queued), + side_usize_action(queue_index, Action::drop_queued), + Just(Action::Pump), + side_action(Action::OnTimer), + Just(Action::OnTimerBoth), + (0u8..20).prop_map(Action::AdvanceMs), + ] +} + +proptest_crate::proptest! { + #![proptest_config(ProptestConfig { + cases: 24, + max_shrink_iters: 10_000, + .. ProptestConfig::default() + })] + + #[test] + fn randomized_handshake_actions_quiesce(actions in vec(handshake_action_strategy(), 1..64)) { + let mut runner = Runner::handshake(); + runner.run(&actions)?; + runner.assert_no_stream_events()?; + } + + #[test] + fn randomized_stream_actions_preserve_integrity(actions in vec(connected_action_strategy(), 1..80)) { + let mut runner = Runner::connected(); + runner.run(&actions)?; + } + + #[test] + fn randomized_write_tracking_actions_quiesce(actions in vec(write_tracking_action_strategy(), 1..80)) { + let mut runner = Runner::connected(); + runner.run(&actions)?; + runner.assert_no_taken_writes()?; + } + + #[test] + fn randomized_session_packet_loss_recovers( + payload in vec(any::(), 512..2048), + actions in vec(packet_loss_recovery_action_strategy(), 1..96), + ) { + let config = QlFsmConfig { + session_record_ack_delay: Duration::from_millis(1), + session_record_retransmit_timeout: Duration::from_millis(10), + session_record_max_size: 96, + session_pending_ack_range_limit: 512, + ..QlFsmConfig::default() + }; + let mut runner = Runner::connected_with_config(config); + + runner.apply(&Action::open_stream(Side::A, 0)); + runner.observe_and_assert()?; + + runner.apply(&Action::write(Side::A, 0, payload)); + runner.observe_and_assert()?; + + runner.apply(&Action::finish(Side::A, 0)); + runner.observe_and_assert()?; + + for action in &actions { + runner.apply(action); + runner.observe_and_assert()?; + } + + runner.cleanup()?; + runner.observe_and_assert()?; + runner.assert_expected_delivered(Side::B)?; + runner.assert_terminal_semantics()?; + runner.assert_quiesced()?; + } + + #[test] + fn randomized_terminal_actions_preserve_terminal_semantics(actions in vec(terminal_action_strategy(), 1..80)) { + let mut runner = Runner::connected(); + runner.run(&actions)?; + runner.assert_terminal_semantics()?; + } +} diff --git a/ql-fsm/src/tests/session.rs b/ql-fsm/src/tests/session.rs new file mode 100644 index 0000000..c55e51c --- /dev/null +++ b/ql-fsm/src/tests/session.rs @@ -0,0 +1,532 @@ +use std::time::Duration; + +use bytes::Bytes; +use ql_wire::{RouteId, SessionClose, StreamId, VarInt}; + +use super::*; +use crate::{state::LinkState, CommitReadError, Event, NoSessionError, PeerStatus, StreamError}; + +fn stream_id(value: u32) -> StreamId { + StreamId(VarInt::from_u32(value)) +} + +fn route_id(value: u32) -> RouteId { + RouteId::from_u32(value) +} + +fn opened(stream_id: StreamId) -> Event { + Event::Opened { + stream_id, + route_id: route_id(1), + } +} + +fn open_stream_id(fsm: &mut QlFsm) -> StreamId { + fsm.open_stream(route_id(1)).unwrap().stream_id() +} + +fn write_stream_bytes( + fsm: &mut QlFsm, + stream_id: StreamId, + bytes: &[u8], +) -> Result { + let mut bytes = Bytes::copy_from_slice(bytes); + let mut stream = fsm.stream(stream_id)?; + let Some(mut writer) = stream.writer() else { + return Err(StreamError::NotWritable); + }; + Ok(writer.write(&mut bytes)) +} + +fn read_stream_all(fsm: &mut QlFsm, stream_id: StreamId) -> Vec { + let mut out = Vec::new(); + let Ok(mut stream) = fsm.stream(stream_id) else { + return out; + }; + loop { + let mut read = 0; + for chunk in stream.read() { + out.extend_from_slice(&chunk); + read += chunk.len(); + } + if read == 0 { + break; + } + stream.commit_read(read).unwrap(); + } + out +} + +#[test] +fn connected_fsms_deliver_stream_data() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), + 5 + ); + harness + .a + .fsm + .stream(stream_id) + .unwrap() + .writer() + .unwrap() + .finish(); + + harness.pump(); + + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Readable(stream_id)) + ); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"hello".to_vec() + ); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Finished(stream_id)) + ); + harness.advance(QlFsmConfig::default().session_record_ack_delay); + harness.on_timer(Side::B); + harness.pump(); + assert_eq!( + harness.take_event(Side::A), + Some(Event::OutboundFinished(stream_id)) + ); +} + +#[test] +fn session_retransmit_uses_new_record_seq() { + let config = QlFsmConfig::default(); + let mut harness = Harness::connected(config); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), + 5 + ); + + let first = harness.next_decoded_outbound(Side::A).unwrap(); + + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.on_timer(Side::A); + + let retried = harness.next_decoded_outbound(Side::A).unwrap(); + + assert_ne!(retried.header.seq, first.header.seq); + assert_eq!(retried.frames, first.frames); + + harness.deliver(Side::B, retried.record); + harness.advance(config.session_record_ack_delay); + harness.on_timer(Side::A); + harness.on_timer(Side::B); + harness.pump(); + + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Readable(stream_id)) + ); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"retry".to_vec() + ); + + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.on_timer(Side::A); + assert!(harness.next_outbound(Side::A).is_none()); +} + +#[test] +fn simultaneous_opens_use_even_and_odd_stream_ids() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id_a = open_stream_id(&mut harness.a.fsm); + let stream_id_b = open_stream_id(&mut harness.b.fsm); + + assert_ne!(stream_id_a, stream_id_b); + assert!( + StreamParity::for_local(harness.a.fsm.identity.qid, harness.b.fsm.identity.qid) + .matches(stream_id_a) + ); + assert!( + StreamParity::for_local(harness.b.fsm.identity.qid, harness.a.fsm.identity.qid) + .matches(stream_id_b) + ); + + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id_a, b"from-a").unwrap(), + 6 + ); + assert_eq!( + write_stream_bytes(&mut harness.b.fsm, stream_id_b, b"from-b").unwrap(), + 6 + ); + + harness.pump(); + + assert_eq!(harness.take_event(Side::A), Some(opened(stream_id_b))); + assert_eq!( + harness.take_event(Side::A), + Some(Event::Readable(stream_id_b)) + ); + assert_eq!( + read_stream_all(&mut harness.a.fsm, stream_id_b), + b"from-b".to_vec() + ); + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id_a))); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Readable(stream_id_a)) + ); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id_a), + b"from-a".to_vec() + ); +} + +#[test] +fn disconnected_stream_operations_fail_with_no_session() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + let missing = stream_id(0); + + assert!(matches!( + harness.a.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, missing, b"queued"), + Err(StreamError::NoSession) + ); + assert_eq!( + harness + .a + .fsm + .stream(missing) + .map(|mut stream| stream.writer().unwrap().finish()), + Err(StreamError::NoSession) + ); + assert_eq!( + harness.a.fsm.stream(missing).map(|mut stream| { + stream.close( + ql_wire::CloseTarget::Both, + ql_wire::StreamCloseCode::CANCELLED, + ); + }), + Err(StreamError::NoSession) + ); + assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); + assert!(matches!( + harness.a.fsm.stream(missing), + Err(StreamError::NoSession) + )); +} + +#[test] +fn disconnected_stream_read_accessors_return_none() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + let missing = stream_id(0); + + assert!(matches!( + harness.a.fsm.stream(missing), + Err(StreamError::NoSession) + )); +} + +#[test] +fn commit_read_rejects_lengths_past_readable_prefix() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"hi").unwrap(), + 2 + ); + harness.pump(); + + let mut stream = harness.b.fsm.stream(stream_id).unwrap(); + assert_eq!(stream.commit_read(3), Err(CommitReadError)); +} + +#[test] +fn returned_session_write_is_reissued_with_new_record_seq() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), + 5 + ); + + let first = harness.next_decoded_write(Side::A).unwrap(); + let id = first.write_id.expect("expected session write"); + + harness.reject_write(Side::A, id); + + let reissued = harness.next_decoded_write(Side::A).unwrap(); + let reissued_id = reissued.write_id.expect("expected reissued write"); + + assert_ne!(reissued_id, id); + assert_ne!(reissued.header.seq, first.header.seq); + assert_eq!(reissued.frames, first.frames); + + harness.confirm_write(Side::A, reissued_id); + harness.deliver(Side::B, reissued.record); + harness.pump(); + + assert_eq!(harness.take_event(Side::B), Some(opened(stream_id))); + assert_eq!( + harness.take_event(Side::B), + Some(Event::Readable(stream_id)) + ); + assert_eq!( + read_stream_all(&mut harness.b.fsm, stream_id), + b"retry".to_vec() + ); +} + +#[test] +fn unconfirmed_session_write_does_not_start_retransmit_timer() { + let config = QlFsmConfig::default(); + let mut harness = Harness::connected(config); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"retry").unwrap(), + 5 + ); + + let first = harness.next_decoded_write(Side::A).unwrap(); + let id = first.write_id.expect("expected session write"); + + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.on_timer(Side::A); + assert!(harness.next_write(Side::A).is_none()); + + harness.confirm_write(Side::A, id); + harness.advance(config.session_record_retransmit_timeout + Duration::from_millis(1)); + harness.on_timer(Side::A); + + let retried = harness.next_decoded_write(Side::A).unwrap(); + + assert_ne!(retried.header.seq, first.header.seq); + assert_eq!(retried.frames, first.frames); +} + +#[test] +fn ack_frame_releases_stream_capacity_and_emits_writable() { + let config = QlFsmConfig { + session_stream_send_buffer_size: 4, + ..QlFsmConfig::default() + }; + let mut harness = Harness::connected(config); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"abcd").unwrap(), + 4 + ); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"z").unwrap(), + 0 + ); + + let record = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, record); + harness.advance(config.session_record_ack_delay); + harness.on_timer(Side::A); + harness.on_timer(Side::B); + harness.pump(); + + assert_eq!( + harness.take_event(Side::A), + Some(Event::Writable(stream_id)) + ); +} + +#[test] +fn close_session_disconnects_locally() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness + .a + .fsm + .close_session(ql_wire::SessionCloseCode::CANCELLED); + + assert!(matches!( + harness.take_event(Side::A), + Some(Event::SessionClosed(SessionClose { + code: ql_wire::SessionCloseCode::CANCELLED, + })) + )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Connected(_))); + assert!(matches!( + harness.a.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); + + let close = harness.next_decoded_outbound(Side::A).unwrap(); + assert!(matches!( + close.frames.as_slice(), + [ql_wire::SessionFrame::Close(_)] + )); + + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) + ); +} + +#[test] +fn unpair_clears_bound_peer_and_emits_unpair_frame() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness.a.fsm.unpair(); + + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Unpaired)) + ); + assert!(harness.a.fsm.peer().is_none()); + assert!(matches!( + harness.a.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert_eq!(harness.a.fsm.queue_ping(), Err(NoSessionError)); + + let unpair = harness.next_decoded_outbound(Side::A).unwrap(); + assert!(matches!( + unpair.frames.as_slice(), + [ql_wire::SessionFrame::Unpair] + )); + assert!(matches!(harness.a.fsm.state.link, LinkState::Idle)); +} + +#[test] +fn inbound_unpair_clears_remote_peer_binding() { + let mut harness = Harness::connected(QlFsmConfig::default()); + + harness.a.fsm.unpair(); + let unpair = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, unpair); + + assert_eq!( + harness.take_event(Side::B), + Some(Event::PeerStatusChanged(PeerStatus::Unpaired)) + ); + assert!(harness.b.fsm.peer().is_none()); + assert!(matches!( + harness.b.fsm.open_stream(route_id(1)), + Err(NoSessionError) + )); + assert!(matches!(harness.connect_ik(Side::B), Err(NoPeerError))); + + let reply_key = harness.b.fsm.state.link.transport().unwrap().tx_key.clone(); + let reply = harness.next_outbound(Side::B).unwrap(); + let (_header, frames) = decrypt_record(&harness.b.crypto, &reply, &reply_key); + assert!(matches!(frames.as_slice(), [ql_wire::SessionFrame::Unpair])); + assert!(matches!(harness.b.fsm.state.link, LinkState::Idle)); +} + +#[test] +fn local_unpair_without_session_emits_unpaired_immediately() { + let mut harness = Harness::paired_known(QlFsmConfig::default()); + + harness.a.fsm.unpair(); + + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Unpaired)) + ); + assert!(harness.a.fsm.peer().is_none()); + assert_eq!(harness.take_event(Side::A), None); +} + +#[test] +fn session_records_contain_ack_frames_after_delivery() { + let config = QlFsmConfig::default(); + let mut harness = Harness::connected(config); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"x").unwrap(), + 1 + ); + + let data = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, data); + harness.advance(config.session_record_ack_delay); + harness.on_timer(Side::B); + + let ack = harness.next_decoded_outbound(Side::B).unwrap(); + assert!(matches!( + ack.frames.as_slice(), + [ql_wire::SessionFrame::Ack(_)] + )); +} + +#[test] +fn first_stream_data_uses_negotiated_initial_peer_credit() { + let mut harness = Harness::paired_known_with_configs( + QlFsmConfig { + session_stream_receive_buffer_size: 8, + ..QlFsmConfig::default() + }, + QlFsmConfig { + session_stream_receive_buffer_size: 3, + ..QlFsmConfig::default() + }, + ); + + harness.connect_ik(Side::A).unwrap(); + let ik1 = harness.next_outbound(Side::A).unwrap(); + harness.deliver(Side::B, ik1); + let ik2 = harness.next_outbound(Side::B).unwrap(); + harness.deliver(Side::A, ik2); + + let stream_id = open_stream_id(&mut harness.a.fsm); + assert_eq!( + write_stream_bytes(&mut harness.a.fsm, stream_id, b"hello").unwrap(), + 5 + ); + + assert!(matches!( + harness.next_decoded_outbound(Side::A).unwrap().frames.as_slice(), + [ql_wire::SessionFrame::StreamData(frame)] if frame.stream_id == stream_id && frame.bytes.as_slice() == b"hel" + )); +} + +#[test] +fn session_timeout_emits_close_before_disconnect() { + let config = QlFsmConfig { + session_peer_timeout: Duration::from_millis(30), + ..QlFsmConfig::default() + }; + let mut harness = Harness::connected(config); + + harness.advance(config.session_peer_timeout); + harness.on_timer(Side::A); + + assert_eq!( + harness.drain_events(Side::A), + vec![Event::SessionClosed(SessionClose { + code: ql_wire::SessionCloseCode::TIMEOUT, + })] + ); + + let close = harness.next_decoded_outbound(Side::A).unwrap(); + assert!(matches!( + close.frames.as_slice(), + [ql_wire::SessionFrame::Close(_)] + )); + assert_eq!( + harness.take_event(Side::A), + Some(Event::PeerStatusChanged(PeerStatus::Disconnected)) + ); +} From ba41308051cb164ce14bca4849182fdaf800c152 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 4 Jun 2026 09:25:14 -0400 Subject: [PATCH 4/6] ql-rpc: add RPC modality layer --- Cargo.lock | 19 ++ Cargo.toml | 2 + ql-rpc/Cargo.toml | 10 + ql-rpc/src/chunk_queue.rs | 252 ++++++++++++++++++ ql-rpc/src/codec.rs | 83 ++++++ ql-rpc/src/error.rs | 112 ++++++++ ql-rpc/src/framed_value.rs | 127 +++++++++ ql-rpc/src/lib.rs | 44 ++++ ql-rpc/src/route_id.rs | 19 ++ ql-rpc/src/router/builder.rs | 354 ++++++++++++++++++++++++++ ql-rpc/src/router/config.rs | 12 + ql-rpc/src/router/mod.rs | 89 +++++++ ql-rpc/src/router/mode.rs | 21 ++ ql-rpc/src/rpc/download/client.rs | 246 ++++++++++++++++++ ql-rpc/src/rpc/download/mod.rs | 31 +++ ql-rpc/src/rpc/download/server.rs | 221 ++++++++++++++++ ql-rpc/src/rpc/duplex/client.rs | 164 ++++++++++++ ql-rpc/src/rpc/duplex/codec.rs | 86 +++++++ ql-rpc/src/rpc/duplex/mod.rs | 24 ++ ql-rpc/src/rpc/duplex/server.rs | 47 ++++ ql-rpc/src/rpc/mod.rs | 32 +++ ql-rpc/src/rpc/notification/client.rs | 10 + ql-rpc/src/rpc/notification/mod.rs | 19 ++ ql-rpc/src/rpc/notification/server.rs | 48 ++++ ql-rpc/src/rpc/parts.rs | 283 ++++++++++++++++++++ ql-rpc/src/rpc/progress/client.rs | 152 +++++++++++ ql-rpc/src/rpc/progress/codec.rs | 145 +++++++++++ ql-rpc/src/rpc/progress/mod.rs | 27 ++ ql-rpc/src/rpc/progress/server.rs | 106 ++++++++ ql-rpc/src/rpc/request/client.rs | 34 +++ ql-rpc/src/rpc/request/mod.rs | 23 ++ ql-rpc/src/rpc/request/server.rs | 96 +++++++ ql-rpc/src/rpc/subscription/client.rs | 99 +++++++ ql-rpc/src/rpc/subscription/codec.rs | 58 +++++ ql-rpc/src/rpc/subscription/mod.rs | 23 ++ ql-rpc/src/rpc/subscription/server.rs | 105 ++++++++ ql-rpc/src/rpc/upload/client.rs | 146 +++++++++++ ql-rpc/src/rpc/upload/mod.rs | 26 ++ ql-rpc/src/rpc/upload/server.rs | 243 ++++++++++++++++++ ql-rpc/src/rpc/utils.rs | 120 +++++++++ ql-rpc/src/stream.rs | 89 +++++++ 41 files changed, 3847 insertions(+) create mode 100644 ql-rpc/Cargo.toml create mode 100644 ql-rpc/src/chunk_queue.rs create mode 100644 ql-rpc/src/codec.rs create mode 100644 ql-rpc/src/error.rs create mode 100644 ql-rpc/src/framed_value.rs create mode 100644 ql-rpc/src/lib.rs create mode 100644 ql-rpc/src/route_id.rs create mode 100644 ql-rpc/src/router/builder.rs create mode 100644 ql-rpc/src/router/config.rs create mode 100644 ql-rpc/src/router/mod.rs create mode 100644 ql-rpc/src/router/mode.rs create mode 100644 ql-rpc/src/rpc/download/client.rs create mode 100644 ql-rpc/src/rpc/download/mod.rs create mode 100644 ql-rpc/src/rpc/download/server.rs create mode 100644 ql-rpc/src/rpc/duplex/client.rs create mode 100644 ql-rpc/src/rpc/duplex/codec.rs create mode 100644 ql-rpc/src/rpc/duplex/mod.rs create mode 100644 ql-rpc/src/rpc/duplex/server.rs create mode 100644 ql-rpc/src/rpc/mod.rs create mode 100644 ql-rpc/src/rpc/notification/client.rs create mode 100644 ql-rpc/src/rpc/notification/mod.rs create mode 100644 ql-rpc/src/rpc/notification/server.rs create mode 100644 ql-rpc/src/rpc/parts.rs create mode 100644 ql-rpc/src/rpc/progress/client.rs create mode 100644 ql-rpc/src/rpc/progress/codec.rs create mode 100644 ql-rpc/src/rpc/progress/mod.rs create mode 100644 ql-rpc/src/rpc/progress/server.rs create mode 100644 ql-rpc/src/rpc/request/client.rs create mode 100644 ql-rpc/src/rpc/request/mod.rs create mode 100644 ql-rpc/src/rpc/request/server.rs create mode 100644 ql-rpc/src/rpc/subscription/client.rs create mode 100644 ql-rpc/src/rpc/subscription/codec.rs create mode 100644 ql-rpc/src/rpc/subscription/mod.rs create mode 100644 ql-rpc/src/rpc/subscription/server.rs create mode 100644 ql-rpc/src/rpc/upload/client.rs create mode 100644 ql-rpc/src/rpc/upload/mod.rs create mode 100644 ql-rpc/src/rpc/upload/server.rs create mode 100644 ql-rpc/src/rpc/utils.rs create mode 100644 ql-rpc/src/stream.rs diff --git a/Cargo.lock b/Cargo.lock index 016c88c..071d36b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2099,6 +2099,14 @@ dependencies = [ "ql-wire", ] +[[package]] +name = "ql-rpc" +version = "0.1.0" +dependencies = [ + "bytes", + "trait-variant", +] + [[package]] name = "ql-wire" version = "0.1.0" @@ -2771,6 +2779,17 @@ dependencies = [ "slab", ] +[[package]] +name = "trait-variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "typenum" version = "1.18.0" diff --git a/Cargo.toml b/Cargo.toml index 84dde3e..de8c80b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "backup-shard", "btp", "ql-fsm", + "ql-rpc", "ql-wire", "quantum-link-macros", ] @@ -33,6 +34,7 @@ btp = { path = "btp" } foundation-api = { path = "api" } quantum-link-macros = { path = "quantum-link-macros" } ql-fsm = { path = "ql-fsm" } +ql-rpc = { path = "ql-rpc" } ql-wire = { path = "ql-wire" } [patch.crates-io] diff --git a/ql-rpc/Cargo.toml b/ql-rpc/Cargo.toml new file mode 100644 index 0000000..836e9d2 --- /dev/null +++ b/ql-rpc/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "ql-rpc" +version = "0.1.0" +edition = "2021" +description = "QuantumLink RPC protocol" +license = "Proprietary" + +[dependencies] +bytes = { workspace = true } +trait-variant = { version = "0.1" } diff --git a/ql-rpc/src/chunk_queue.rs b/ql-rpc/src/chunk_queue.rs new file mode 100644 index 0000000..33f6299 --- /dev/null +++ b/ql-rpc/src/chunk_queue.rs @@ -0,0 +1,252 @@ +use std::collections::VecDeque; + +use bytes::{Buf, Bytes}; + +use crate::{CodecError, Error}; + +const LENGTH_SIZE: usize = 8; + +#[derive(Debug, Default)] +pub struct ChunkQueue { + chunks: VecDeque, + remaining: usize, +} + +impl ChunkQueue { + pub fn push(&mut self, chunk: Bytes) { + if chunk.is_empty() { + return; + } + self.remaining += chunk.len(); + self.chunks.push_back(chunk); + } + + pub fn remaining(&self) -> usize { + self.remaining + } + + pub fn expect_empty(&self) -> Result<(), CodecError> { + if self.remaining > 0 { + Err(CodecError::Rpc(Error::TrailingBytes)) + } else { + Ok(()) + } + } + + pub fn pop_front(&mut self, max_len: usize) -> Option { + let front = self.chunks.front_mut()?; + let chunk = if max_len >= front.len() { + self.chunks.pop_front().expect("buffered chunk is present") + } else { + front.split_to(max_len) + }; + self.remaining -= chunk.len(); + Some(chunk) + } + + pub fn pop_front_chunk(&mut self) -> Option { + self.pop_front(usize::MAX) + } + + pub fn try_take_part(&mut self) -> Result>, Error> { + let Some(len) = self.peek_next_part_len()? else { + return Ok(None); + }; + self.advance(LENGTH_SIZE); + Ok(Some(DrainBuf::new(self, len))) + } + + pub fn try_take_tagged_part(&mut self) -> Result)>, Error> { + let mut bytes = self.peek(); + let Ok(kind) = bytes.try_get_u8() else { + return Ok(None); + }; + let Some(len) = read_next_part_len(&mut bytes)? else { + return Ok(None); + }; + + self.advance(1 + LENGTH_SIZE); + Ok(Some((kind, DrainBuf::new(self, len)))) + } + + pub fn try_take_tagged_part_header(&mut self) -> Result, Error> { + let mut bytes = self.peek(); + let Ok(kind) = bytes.try_get_u8() else { + return Ok(None); + }; + let Some(len) = read_part_len_header(&mut bytes)? else { + return Ok(None); + }; + + self.advance(1 + LENGTH_SIZE); + Ok(Some((kind, len))) + } + + pub fn try_take_body(&mut self, len: usize) -> Option> { + if self.remaining < len { + return None; + } + + Some(DrainBuf::new(self, len)) + } + + fn peek_next_part_len(&self) -> Result, Error> { + let mut bytes = self.peek(); + read_next_part_len(&mut bytes) + } + + fn peek(&self) -> ChunkQueuePeek<'_> { + ChunkQueuePeek { + chunks: &self.chunks, + chunk_index: 0, + chunk_offset: 0, + remaining: self.remaining, + } + } + + fn front_chunk(&self, limit: usize) -> &[u8] { + let Some(chunk) = self.chunks.front() else { + return &[]; + }; + &chunk[..chunk.len().min(limit)] + } + + fn advance_inner(&mut self, mut cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.remaining -= cnt; + while cnt > 0 { + let front = self.chunks.front_mut().expect("buffered data present"); + let consumed = cnt.min(front.len()); + front.advance(consumed); + cnt -= consumed; + if front.is_empty() { + self.chunks.pop_front(); + } + } + } +} + +struct ChunkQueuePeek<'a> { + chunks: &'a VecDeque, + chunk_index: usize, + chunk_offset: usize, + remaining: usize, +} + +impl Buf for ChunkQueuePeek<'_> { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + if self.remaining == 0 { + return &[]; + } + + let Some(chunk) = self.chunks.get(self.chunk_index) else { + return &[]; + }; + &chunk[self.chunk_offset..] + } + + fn advance(&mut self, mut cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.remaining -= cnt; + + while cnt > 0 { + let chunk = self + .chunks + .get(self.chunk_index) + .expect("buffered data present"); + let available = chunk.len() - self.chunk_offset; + let step = cnt.min(available); + self.chunk_offset += step; + cnt -= step; + if self.chunk_offset == chunk.len() { + self.chunk_index += 1; + self.chunk_offset = 0; + } + } + } +} + +impl Buf for ChunkQueue { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.front_chunk(self.remaining) + } + + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining, "advanced past buffered data"); + self.advance_inner(cnt); + } +} + +pub struct DrainBuf<'a> { + bytes: &'a mut ChunkQueue, + remaining: usize, +} + +impl<'a> DrainBuf<'a> { + pub fn new(bytes: &'a mut ChunkQueue, len: usize) -> Self { + debug_assert!(bytes.remaining() >= len); + Self { + bytes, + remaining: len, + } + } + + pub fn expect_empty(&self) -> Result<(), CodecError> { + if self.remaining > 0 { + Err(CodecError::Rpc(Error::TrailingBytes)) + } else { + Ok(()) + } + } +} + +impl Buf for DrainBuf<'_> { + fn remaining(&self) -> usize { + self.remaining + } + + fn chunk(&self) -> &[u8] { + self.bytes.front_chunk(self.remaining) + } + + fn advance(&mut self, cnt: usize) { + assert!(cnt <= self.remaining(), "advanced past payload boundary"); + self.bytes.advance_inner(cnt); + self.remaining -= cnt; + } +} + +impl Drop for DrainBuf<'_> { + fn drop(&mut self) { + if self.remaining > 0 { + self.bytes.advance_inner(self.remaining); + self.remaining = 0; + } + } +} + +fn read_next_part_len(bytes: &mut B) -> Result, Error> { + let Some(len) = read_part_len_header(bytes)? else { + return Ok(None); + }; + if bytes.remaining() < len { + return Ok(None); + } + Ok(Some(len)) +} + +fn read_part_len_header(bytes: &mut B) -> Result, Error> { + let Ok(len) = bytes.try_get_u64_le() else { + return Ok(None); + }; + let len: usize = len.try_into().map_err(|_| Error::LengthOverflow)?; + Ok(Some(len)) +} diff --git a/ql-rpc/src/codec.rs b/ql-rpc/src/codec.rs new file mode 100644 index 0000000..51da527 --- /dev/null +++ b/ql-rpc/src/codec.rs @@ -0,0 +1,83 @@ +use std::{convert::Infallible, str::Utf8Error}; + +use bytes::{Buf, BufMut, Bytes}; + +pub use crate::chunk_queue::ChunkQueue; + +pub trait RpcCodec: Sized { + type Error; + + fn encode_value(&self, out: &mut B); + fn decode_value(bytes: &mut B) -> Result; +} + +impl RpcCodec for String { + type Error = Utf8Error; + + fn encode_value(&self, out: &mut B) { + out.put_slice(self.as_bytes()); + } + + fn decode_value(bytes: &mut B) -> Result { + let len = bytes.remaining(); + if bytes.chunk().len() == len { + let s = std::str::from_utf8(bytes.chunk())?.to_owned(); + bytes.advance(len); + Ok(s) + } else { + let mut buf = vec![0; len]; + bytes.copy_to_slice(&mut buf); + String::from_utf8(buf).map_err(|err| err.utf8_error()) + } + } +} + +impl RpcCodec for Vec { + type Error = Infallible; + + fn encode_value(&self, out: &mut B) { + out.put_slice(self.as_slice()); + } + + fn decode_value(bytes: &mut B) -> Result { + let len = bytes.remaining(); + let mut buf = vec![0; len]; + bytes.copy_to_slice(&mut buf); + Ok(buf) + } +} + +impl RpcCodec for Bytes { + type Error = Infallible; + + fn encode_value(&self, out: &mut B) { + out.put_slice(self.as_ref()); + } + + fn decode_value(bytes: &mut B) -> Result { + Ok(bytes.copy_to_bytes(bytes.remaining())) + } +} + +const LENGTH_SIZE: usize = 8; + +pub fn encode_value_part>(value: &T, out: &mut B) { + let payload_start = reserve_length(out); + value.encode_value(out); + backpatch_length(out, payload_start); +} + +/// reads one length-delimited rpc value from buffered byte chunks +pub fn reserve_length>(out: &mut B) -> usize { + let start = out.as_mut().len(); + out.put_bytes(0, LENGTH_SIZE); + start +} + +pub fn backpatch_length + ?Sized>(out: &mut B, start: usize) { + let out = out.as_mut(); + let payload_start = start + LENGTH_SIZE; + let payload_len = out.len() - payload_start; + let payload_len = u64::try_from(payload_len).expect("rpc payload exceeds u64 length framing"); + out[start..payload_start].copy_from_slice(&payload_len.to_le_bytes()); +} diff --git a/ql-rpc/src/error.rs b/ql-rpc/src/error.rs new file mode 100644 index 0000000..7404a22 --- /dev/null +++ b/ql-rpc/src/error.rs @@ -0,0 +1,112 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Error { + Truncated, + LengthOverflow, + UnexpectedFrameKind(u8), + MissingResponse, + TrailingBytes, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Truncated => f.write_str("truncated rpc payload"), + Self::LengthOverflow => f.write_str("rpc payload length overflow"), + Self::UnexpectedFrameKind(kind) => write!(f, "unexpected rpc frame kind {kind}"), + Self::MissingResponse => f.write_str("missing terminal rpc response"), + Self::TrailingBytes => f.write_str("trailing rpc bytes"), + } + } +} + +impl std::error::Error for Error {} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CodecError { + Rpc(Error), + Codec(E), +} + +impl std::error::Error for CodecError +where + E: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + CodecError::Rpc(e) => Some(e), + CodecError::Codec(e) => Some(e), + } + } + + fn cause(&self) -> Option<&dyn std::error::Error> { + self.source() + } +} + +impl std::fmt::Display for CodecError +where + E: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CodecError::Rpc(e) => write!(f, "{e}"), + CodecError::Codec(e) => write!(f, "{e}"), + } + } +} + +impl From for CodecError { + fn from(error: Error) -> Self { + Self::Rpc(error) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CallError { + Protocol(Error), + Codec(C), + Transport(T), +} + +impl std::fmt::Display for CallError +where + C: std::fmt::Display, + T: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Protocol(error) => write!(f, "{error}"), + Self::Codec(error) => write!(f, "{error}"), + Self::Transport(error) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for CallError +where + C: std::error::Error + 'static, + T: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + CallError::Protocol(error) => Some(error), + CallError::Codec(error) => Some(error), + CallError::Transport(error) => Some(error), + } + } +} + +impl From for CallError { + fn from(error: Error) -> Self { + Self::Protocol(error) + } +} + +impl From> for CallError { + fn from(error: CodecError) -> Self { + match error { + CodecError::Rpc(error) => Self::Protocol(error), + CodecError::Codec(error) => Self::Codec(error), + } + } +} diff --git a/ql-rpc/src/framed_value.rs b/ql-rpc/src/framed_value.rs new file mode 100644 index 0000000..600357d --- /dev/null +++ b/ql-rpc/src/framed_value.rs @@ -0,0 +1,127 @@ +use std::marker::PhantomData; + +use bytes::Bytes; + +use crate::{chunk_queue::ChunkQueue, CodecError, RpcCodec}; + +/// reads one length-delimited rpc value from buffered byte chunks +pub struct FramedReader { + bytes: ChunkQueue, + marker: PhantomData T>, +} + +pub enum FramedReadStep { + NeedMore(FramedReader), + Value(T), +} + +pub enum FramedPrefixStep { + NeedMore(FramedReader), + Value { value: T, bytes: ChunkQueue }, +} + +impl Default for FramedReader { + fn default() -> Self { + Self { + bytes: ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl FramedReader { + pub fn push(mut self, chunk: Bytes) -> Self { + self.bytes.push(chunk); + self + } + + pub fn advance(self) -> Result, CodecError> { + match self.advance_prefix()? { + FramedPrefixStep::NeedMore(next) => Ok(FramedReadStep::NeedMore(next)), + FramedPrefixStep::Value { value, bytes } => { + bytes.expect_empty()?; + Ok(FramedReadStep::Value(value)) + } + } + } + + pub fn advance_prefix(self) -> Result, CodecError> { + let mut this = self; + let Some(mut body) = this.bytes.try_take_part()? else { + return Ok(FramedPrefixStep::NeedMore(this)); + }; + + let value = T::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + Ok(FramedPrefixStep::Value { + value, + bytes: this.bytes, + }) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{FramedPrefixStep, FramedReadStep, FramedReader}; + use crate::codec::encode_value_part; + + #[test] + fn value_reader_round_trips_framed_values() { + let mut encoded = Vec::new(); + encode_value_part(&b"hello".to_vec(), &mut encoded); + + match FramedReader::>::default() + .push(Bytes::from(encoded)) + .advance() + .unwrap() + { + FramedReadStep::Value(value) => assert_eq!(value, b"hello".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn value_reader_waits_for_complete_frame() { + let mut encoded = Vec::new(); + encode_value_part(&b"hello".to_vec(), &mut encoded); + let encoded = Bytes::from(encoded); + + let reader = match FramedReader::>::default() + .push(encoded.slice(..4)) + .advance() + .unwrap() + { + FramedReadStep::NeedMore(next) => next, + _ => unreachable!(), + }; + + match reader.push(encoded.slice(4..)).advance().unwrap() { + FramedReadStep::Value(value) => assert_eq!(value, b"hello".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn value_reader_returns_prefix_remainder() { + let mut encoded = Vec::new(); + encode_value_part(&b"hello".to_vec(), &mut encoded); + encoded.extend_from_slice(b"tail"); + + match FramedReader::>::default() + .push(Bytes::from(encoded)) + .advance_prefix() + .unwrap() + { + FramedPrefixStep::Value { value, mut bytes } => { + assert_eq!(value, b"hello".to_vec()); + assert_eq!( + bytes.pop_front(usize::MAX), + Some(Bytes::from_static(b"tail")) + ); + } + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/lib.rs b/ql-rpc/src/lib.rs new file mode 100644 index 0000000..efea025 --- /dev/null +++ b/ql-rpc/src/lib.rs @@ -0,0 +1,44 @@ +#![allow(clippy::type_complexity)] + +//! QuantumLink RPC protocol + +mod chunk_queue; +pub(crate) mod codec; +mod error; +mod framed_value; +mod route_id; +mod router; +mod rpc; +mod stream; + +pub use chunk_queue::ChunkQueue; +pub use codec::RpcCodec; +pub use error::*; +use framed_value::*; +pub use route_id::RouteId; +pub use router::*; +pub use rpc::*; +pub use stream::*; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct StreamCloseCode(pub u16); + +impl StreamCloseCode { + /// operation was cancelled + pub const CANCELLED: Self = Self(0); + /// local internal error + pub const INTERNAL: Self = Self(1); + /// request was refused + pub const REFUSED: Self = Self(2); + /// operation timed out + pub const TIMEOUT: Self = Self(3); + /// configured limit was exceeded + pub const LIMIT: Self = Self(4); + /// route identifier was unknown + pub const UNKNOWN_ROUTE: Self = Self(5); + + pub const fn into_inner(self) -> u16 { + self.0 + } +} diff --git a/ql-rpc/src/route_id.rs b/ql-rpc/src/route_id.rs new file mode 100644 index 0000000..1b054e7 --- /dev/null +++ b/ql-rpc/src/route_id.rs @@ -0,0 +1,19 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct RouteId(pub u32); + +impl RouteId { + pub const fn from_u32(value: u32) -> Self { + Self(value) + } + + pub const fn into_inner(self) -> u32 { + self.0 + } +} + +impl From for RouteId { + fn from(value: u32) -> Self { + Self::from_u32(value) + } +} diff --git a/ql-rpc/src/router/builder.rs b/ql-rpc/src/router/builder.rs new file mode 100644 index 0000000..b59a84e --- /dev/null +++ b/ql-rpc/src/router/builder.rs @@ -0,0 +1,354 @@ +use std::marker::PhantomData; + +use super::{ + LocalSpawner, RouteEntry, RouteFn, Router, RouterConfig, RpcStream, SendSpawner, Spawner, +}; +use crate::{ + download::{server::*, Download as DownloadRpc}, + duplex::{server::*, Duplex as DuplexRpc}, + notification::{server::*, Notification as NotificationRpc}, + progress::{server::*, Progress as ProgressRpc}, + request::{server::*, Request as RequestRpc}, + subscription::{server::*, Subscription as SubscriptionRpc}, + upload::{server::*, Upload as UploadRpc}, +}; + +pub struct LocalRoutes; +pub struct SendRoutes; + +pub struct RouterBuilder +where + Sp: Spawner, +{ + config: RouterConfig, + spawner: Sp, + routes: Vec>, + marker: PhantomData Mode>, +} + +impl RouterBuilder +where + Sp: Spawner, +{ + pub(crate) fn new(spawner: Sp) -> Self { + Self { + config: RouterConfig::default(), + spawner, + routes: Vec::new(), + marker: PhantomData, + } + } + + pub fn config(mut self, config: RouterConfig) -> Self { + self.config = config; + self + } + + pub fn max_request_bytes(mut self, max_request_bytes: usize) -> Self { + self.config.max_request_bytes = max_request_bytes; + self + } + + pub fn build(mut self, state: S) -> Router { + self.routes.sort_by_key(|entry| entry.route_id); + self.routes.shrink_to_fit(); + Router { + config: self.config, + state, + spawner: self.spawner, + routes: self.routes, + } + } + + fn add_route(mut self, route_id: crate::RouteId, route: RouteFn) -> Self { + if self.routes.iter().any(|entry| entry.route_id == route_id) { + panic!("duplicate rpc route {}", route_id.into_inner()); + } + self.routes.push(RouteEntry::new(route_id, route)); + self + } +} + +impl RouterBuilder +where + Sp: LocalSpawner, + St: RpcStream + 'static, +{ + pub fn request(self) -> Self + where + M: RequestRpc + 'static, + S: RequestHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_request_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn notification(self) -> Self + where + M: NotificationRpc + 'static, + S: NotificationHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_notification_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn duplex(self) -> Self + where + M: DuplexRpc + 'static, + S: DuplexHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_duplex_inner::( + state, + config, + reader, + writer, + S::handle, + )) + }) + } + + pub fn download(self) -> Self + where + M: DownloadRpc + 'static, + S: DownloadHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_download_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn subscription(self) -> Self + where + M: SubscriptionRpc + 'static, + S: SubscriptionHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_subscription_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn progress(self) -> Self + where + M: ProgressRpc + 'static, + S: ProgressHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_progress_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn upload(self) -> Self + where + M: UploadRpc + 'static, + S: UploadHandlerLocal + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_upload_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } +} + +impl RouterBuilder +where + Sp: SendSpawner + Send, + St: RpcStream + 'static, +{ + pub fn request(self) -> Self + where + M: RequestRpc + 'static, + M::Request: Send + 'static, + S: RequestHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_request_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn notification(self) -> Self + where + M: NotificationRpc + 'static, + M::Payload: Send + 'static, + S: NotificationHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_notification_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn duplex(self) -> Self + where + M: DuplexRpc + 'static, + M::InitiatorEvent: Send + 'static, + M::ResponderEvent: Send + 'static, + S: DuplexHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_duplex_inner::( + state, + config, + reader, + writer, + S::handle, + )) + }) + } + + pub fn download(self) -> Self + where + M: DownloadRpc + 'static, + M::Request: Send + 'static, + S: DownloadHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_download_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn subscription(self) -> Self + where + M: SubscriptionRpc + 'static, + M::Request: Send + 'static, + S: SubscriptionHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_subscription_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn progress(self) -> Self + where + M: ProgressRpc + 'static, + M::Request: Send + 'static, + S: ProgressHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_progress_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } + + pub fn upload(self) -> Self + where + M: UploadRpc + 'static, + M::Request: Send + 'static, + S: UploadHandler + Send + 'static, + St::Reader: Send + 'static, + St::Writer: Send + 'static, + { + self.add_route(M::ROUTE, |spawner, state, config, stream| { + let (reader, writer) = stream.split(); + spawner.spawn(handle_upload_inner::( + state, + config, + reader, + writer, + S::handle, + S::handle_transport_error, + )) + }) + } +} diff --git a/ql-rpc/src/router/config.rs b/ql-rpc/src/router/config.rs new file mode 100644 index 0000000..d6fb048 --- /dev/null +++ b/ql-rpc/src/router/config.rs @@ -0,0 +1,12 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RouterConfig { + pub max_request_bytes: usize, +} + +impl Default for RouterConfig { + fn default() -> Self { + Self { + max_request_bytes: usize::MAX, + } + } +} diff --git a/ql-rpc/src/router/mod.rs b/ql-rpc/src/router/mod.rs new file mode 100644 index 0000000..31e973a --- /dev/null +++ b/ql-rpc/src/router/mod.rs @@ -0,0 +1,89 @@ +use crate::{RouteId, StreamCloseCode}; + +mod builder; +mod config; +mod mode; + +pub use self::{ + builder::{LocalRoutes, RouterBuilder, SendRoutes}, + config::RouterConfig, + mode::*, +}; +use crate::{close_stream, RpcStream}; +pub use crate::{ + download::{DownloadHandler, DownloadHandlerLocal, DownloadStart, DownloadWriter}, + duplex::{DuplexHandler, DuplexHandlerLocal, DuplexPeer}, + notification::{NotificationHandler, NotificationHandlerLocal}, + progress::{ProgressHandler, ProgressHandlerLocal, ProgressResponder}, + request::{RequestHandler, RequestHandlerLocal, Response}, + subscription::{SubscriptionHandler, SubscriptionHandlerLocal, SubscriptionResponder}, + upload::{UploadHandler, UploadHandlerLocal, UploadReader, UploadResponder}, +}; + +pub struct Router +where + Sp: Spawner, +{ + config: RouterConfig, + state: S, + spawner: Sp, + routes: Vec>, +} + +struct RouteEntry +where + Sp: Spawner, +{ + route_id: RouteId, + route: RouteFn, +} + +impl RouteEntry +where + Sp: Spawner, +{ + fn new(route_id: RouteId, route: RouteFn) -> Self { + Self { route_id, route } + } +} + +impl Router +where + S: Clone + 'static, + St: RpcStream, + Sp: Spawner, +{ + pub fn builder_local(spawner: Sp) -> RouterBuilder + where + Sp: LocalSpawner, + { + RouterBuilder::::new(spawner) + } + + pub fn builder_send(spawner: Sp) -> RouterBuilder + where + Sp: SendSpawner, + { + RouterBuilder::::new(spawner) + } + + pub fn handle(&self, stream: St) -> Option<(RouteId, Sp::Handle)> { + let route_id = stream.route_id()?; + let Ok(index) = self + .routes + .binary_search_by_key(&route_id, |entry| entry.route_id) + else { + close_stream(stream, StreamCloseCode::UNKNOWN_ROUTE); + return None; + }; + let route = self.routes[index].route; + Some(( + route_id, + route(&self.spawner, self.state.clone(), self.config, stream), + )) + } + + pub fn route_ids(&self) -> impl ExactSizeIterator + '_ { + self.routes.iter().map(|entry| entry.route_id) + } +} diff --git a/ql-rpc/src/router/mode.rs b/ql-rpc/src/router/mode.rs new file mode 100644 index 0000000..33b6c06 --- /dev/null +++ b/ql-rpc/src/router/mode.rs @@ -0,0 +1,21 @@ +use std::future::Future; + +use crate::RouterConfig; + +pub type RouteFn = fn(&Sp, S, RouterConfig, St) -> ::Handle; + +pub trait Spawner: Clone + 'static { + type Handle; +} + +pub trait LocalSpawner: Spawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + 'static; +} + +pub trait SendSpawner: Spawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + Send + 'static; +} diff --git a/ql-rpc/src/rpc/download/client.rs b/ql-rpc/src/rpc/download/client.rs new file mode 100644 index 0000000..9a64818 --- /dev/null +++ b/ql-rpc/src/rpc/download/client.rs @@ -0,0 +1,246 @@ +use std::future::poll_fn; + +use bytes::{BufMut, Bytes}; + +use crate::{ + download::{Download, PartReadStep}, + rpc::parts::FrameKind, + CallError, FramedPrefixStep, FramedReader, RpcCodec, RpcRead, StreamCloseCode, +}; + +pub struct DownloadCall +where + M: Download, + R: RpcRead, +{ + stream: Option, + reader: Option>, +} + +pub struct DownloadPart<'a, M, R> +where + M: Download, + R: RpcRead, +{ + parent: &'a mut DownloadReader, + finished: bool, +} + +pub struct DownloadReader +where + M: Download, + R: RpcRead, +{ + stream: Option, + reader: crate::download::PartFrameReader, +} + +impl DownloadCall +where + M: Download, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream: Some(stream), + reader: Some(FramedReader::default()), + } + } + + pub async fn start( + mut self, + ) -> Result<(M::ResponseHeader, DownloadReader), CallError> { + loop { + let reader = self.reader.take().unwrap(); + let reader = match reader.advance_prefix() { + Ok(FramedPrefixStep::Value { value, bytes }) => { + let stream = self.stream.take().unwrap(); + return Ok(( + value, + DownloadReader { + stream: Some(stream), + reader: crate::download::PartFrameReader::::new(bytes), + }, + )); + } + Ok(FramedPrefixStep::NeedMore(next)) => next, + Err(error) => return Err(error.into()), + }; + + let stream = self.stream.as_mut().unwrap(); + match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { + Ok(Some(chunk)) => { + self.reader = Some(reader.push(chunk)); + } + Ok(None) => return Err(crate::Error::Truncated.into()), + Err(error) => return Err(CallError::Transport(error)), + } + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for DownloadCall +where + M: Download, + R: RpcRead, +{ + fn drop(&mut self) { + self.close_inner(StreamCloseCode::CANCELLED); + } +} + +impl DownloadReader +where + M: Download, + R: RpcRead, +{ + pub async fn next_part( + &mut self, + ) -> Result)>, CallError> + { + if self.stream.is_none() { + return Ok(None); + } + + match self.read_frame().await? { + PartReadStep::PartHeader(value) => Ok(Some(( + value, + DownloadPart { + parent: self, + finished: false, + }, + ))), + PartReadStep::Finish => { + self.stream.take(); + Ok(None) + } + PartReadStep::BodyBytes(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::BodyChunk.tag()).into()) + } + PartReadStep::EndPart => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::EndPart.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + pub async fn complete(mut self) -> Result<(), CallError> { + match self.read_frame().await? { + PartReadStep::Finish => { + self.stream.take(); + Ok(()) + } + PartReadStep::PartHeader(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::PartHeader.tag()).into()) + } + PartReadStep::BodyBytes(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::BodyChunk.tag()).into()) + } + PartReadStep::EndPart => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::EndPart.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + async fn read_frame( + &mut self, + ) -> Result, CallError> { + loop { + match self.reader.advance() { + Ok(PartReadStep::NeedMore) => {} + Ok(step) => return Ok(step), + Err(error) => return Err(error.into()), + } + + let stream = self.stream.as_mut().unwrap(); + match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { + Ok(Some(chunk)) => { + self.reader.push(chunk); + } + Ok(None) => return Err(crate::Error::Truncated.into()), + Err(error) => return Err(CallError::Transport(error)), + } + } + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for DownloadReader +where + M: Download, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +impl DownloadPart<'_, M, R> +where + M: Download, + R: RpcRead, +{ + pub async fn read_chunk(&mut self) -> Result, CallError> { + if self.finished { + return Ok(None); + } + + match self.parent.read_frame().await? { + PartReadStep::BodyBytes(bytes) => Ok(Some(bytes)), + PartReadStep::EndPart => { + self.finished = true; + Ok(None) + } + PartReadStep::PartHeader(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::PartHeader.tag()).into()) + } + PartReadStep::Finish => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::Finish.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.parent.close_inner(code); + self.finished = true; + } +} + +impl Drop for DownloadPart<'_, M, R> +where + M: Download, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.parent.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + request.encode_value(out) +} diff --git a/ql-rpc/src/rpc/download/mod.rs b/ql-rpc/src/rpc/download/mod.rs new file mode 100644 index 0000000..5ed34ae --- /dev/null +++ b/ql-rpc/src/rpc/download/mod.rs @@ -0,0 +1,31 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod server; + +pub use client::{encode_request, DownloadCall, DownloadPart, DownloadReader}; +pub use server::{ + DownloadHandler, DownloadHandlerLocal, DownloadPartWriter, DownloadStart, DownloadWriter, +}; + +pub use crate::rpc::parts::{ + encode_body_chunk, encode_end_part, encode_finish, encode_part_header, PartFrameReader, + PartReadStep, +}; + +/// rpc where the responder returns metadata first and then zero or more byte parts +/// +/// the typed portion of the response ends at [`Self::ResponseHeader`] +/// after the header is decoded, the rest of the stream is exposed as typed +/// part headers followed by raw byte chunks through [`DownloadReader`] +pub trait Download: Route { + /// codec error shared by request and response header values + type Error; + /// typed input needed to start the download + type Request: RpcCodec; + /// typed metadata available before parts arrive + type ResponseHeader: RpcCodec; + /// typed metadata available before each byte part arrives + type PartHeader: RpcCodec; +} diff --git a/ql-rpc/src/rpc/download/server.rs b/ql-rpc/src/rpc/download/server.rs new file mode 100644 index 0000000..fcdcb04 --- /dev/null +++ b/ql-rpc/src/rpc/download/server.rs @@ -0,0 +1,221 @@ +use std::{future::Future, marker::PhantomData}; + +use bytes::Bytes; + +use crate::{ + codec, + download::Download as DownloadRpc, + finish_bytes, + rpc::{ + parts::{encode_body_chunk, encode_end_part, encode_finish, encode_part_header}, + read_eof_request, + }, + write_bytes, RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; + +#[trait_variant::make(DownloadHandler: Send)] +pub trait DownloadHandlerLocal +where + M: DownloadRpc, + St: RpcStream, +{ + async fn handle(self, message: M::Request, download: DownloadStart); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct DownloadStart +where + M: DownloadRpc, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData M>, +} + +pub struct DownloadWriter +where + M: DownloadRpc, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData M>, +} + +pub struct DownloadPartWriter<'a, M, W> +where + M: DownloadRpc, + W: RpcWrite, +{ + parent: &'a mut DownloadWriter, + finished: bool, +} + +impl DownloadStart +where + M: DownloadRpc, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + /// send the response header and begin streaming parts + pub async fn start( + mut self, + response_header: M::ResponseHeader, + ) -> Result, W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + codec::encode_value_part(&response_header, &mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + Ok(DownloadWriter { + writer: Some(writer), + marker: PhantomData, + }) + } + + /// send a header-only response and finish the stream + pub async fn complete(mut self, response_header: M::ResponseHeader) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + codec::encode_value_part(&response_header, &mut encoded); + encode_finish(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await + } + + /// close the stream with a transport code + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for DownloadStart +where + M: DownloadRpc, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +impl DownloadWriter +where + M: DownloadRpc, + W: RpcWrite, +{ + pub async fn start_part( + &mut self, + part_header: M::PartHeader, + ) -> Result, W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_part_header(&part_header, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + Ok(DownloadPartWriter { + parent: self, + finished: false, + }) + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + encode_finish(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for DownloadWriter +where + M: DownloadRpc, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +impl DownloadPartWriter<'_, M, W> +where + M: DownloadRpc, + W: RpcWrite, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_body_chunk(&bytes, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_end_part(&mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + self.finished = true; + Ok(()) + } +} + +impl Drop for DownloadPartWriter<'_, M, W> +where + M: DownloadRpc, + W: RpcWrite, +{ + fn drop(&mut self) { + if !self.finished { + if let Some(writer) = self.parent.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } + } +} + +pub(crate) async fn handle_download_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: DownloadRpc + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Request, DownloadStart) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let request = match read_eof_request::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle(state, request, DownloadStart::new(writer)).await; +} diff --git a/ql-rpc/src/rpc/duplex/client.rs b/ql-rpc/src/rpc/duplex/client.rs new file mode 100644 index 0000000..e76050a --- /dev/null +++ b/ql-rpc/src/rpc/duplex/client.rs @@ -0,0 +1,164 @@ +use std::{ + future::poll_fn, + marker::PhantomData, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use crate::{ + duplex::{codec, Duplex, EventReader, ReadStep}, + finish_bytes, write_bytes, CallError, RpcCodec, RpcRead, RpcWrite, StreamCloseCode, +}; + +pub struct DuplexCall +where + M: Duplex, + W: RpcWrite, + R: RpcRead, +{ + pub sender: DuplexSender, + pub receiver: DuplexReceiver, +} + +pub struct DuplexSender +where + T: RpcCodec, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, +} + +pub struct DuplexReceiver +where + T: RpcCodec, + R: RpcRead, +{ + stream: Option, + reader: EventReader, +} + +impl DuplexSender +where + T: RpcCodec, + W: RpcWrite, +{ + pub fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn send(&mut self, event: &T) -> Result<(), W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + codec::encode_event(event, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for DuplexSender +where + T: RpcCodec, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +impl DuplexReceiver +where + T: RpcCodec, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream: Some(stream), + reader: EventReader::default(), + } + } + + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| self.poll_next_event(cx)).await + } + + pub fn poll_next_event( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + if self.stream.is_none() { + return Poll::Ready(None); + } + + loop { + match self.reader.advance() { + Ok(ReadStep::Event(value)) => return Poll::Ready(Some(Ok(value))), + Ok(ReadStep::NeedMore) => {} + Err(error) => { + self.stream.take(); + return Poll::Ready(Some(Err(error.into()))); + } + } + + let stream = self.stream.as_mut().unwrap(); + match stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + self.reader.push(chunk); + } + Poll::Ready(Ok(None)) => { + if self.reader.is_empty() { + self.stream.take(); + return Poll::Ready(None); + } + self.stream.take(); + return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); + } + Poll::Ready(Err(error)) => { + self.stream.take(); + return Poll::Ready(Some(Err(CallError::Transport(error)))); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for DuplexReceiver +where + T: RpcCodec, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} diff --git a/ql-rpc/src/rpc/duplex/codec.rs b/ql-rpc/src/rpc/duplex/codec.rs new file mode 100644 index 0000000..68bc87c --- /dev/null +++ b/ql-rpc/src/rpc/duplex/codec.rs @@ -0,0 +1,86 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, CodecError, RpcCodec}; + +pub fn encode_event(event: &T, out: &mut (impl BufMut + AsMut<[u8]>)) +where + T: RpcCodec, +{ + codec::encode_value_part(event, out) +} + +pub enum ReadStep { + NeedMore, + Event(T), +} + +pub struct EventReader { + bytes: codec::ChunkQueue, + marker: PhantomData T>, +} + +impl Default for EventReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl EventReader { + pub fn push(&mut self, chunk: Bytes) { + self.bytes.push(chunk); + } + + pub fn is_empty(&self) -> bool { + self.bytes.remaining() == 0 + } + + pub fn advance(&mut self) -> Result, CodecError> { + let Some(mut body) = self.bytes.try_take_part()? else { + return Ok(ReadStep::NeedMore); + }; + + let value = { + let value = T::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + value + }; + Ok(ReadStep::Event(value)) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{encode_event, EventReader, ReadStep}; + + #[test] + fn event_reader_emits_multiple_events() { + let mut encoded = Vec::new(); + encode_event(&b"one".to_vec(), &mut encoded); + encode_event(&b"two".to_vec(), &mut encoded); + + let mut reader = EventReader::>::default(); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + ReadStep::Event(value) => { + assert_eq!(value, b"one".to_vec()); + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + ReadStep::Event(value) => { + assert_eq!(value, b"two".to_vec()); + assert!(reader.is_empty()); + } + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/rpc/duplex/mod.rs b/ql-rpc/src/rpc/duplex/mod.rs new file mode 100644 index 0000000..a962202 --- /dev/null +++ b/ql-rpc/src/rpc/duplex/mod.rs @@ -0,0 +1,24 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod codec; +pub(crate) mod server; + +pub use client::{DuplexCall, DuplexReceiver, DuplexSender}; +pub use codec::{encode_event, EventReader, ReadStep}; +pub use server::{DuplexHandler, DuplexHandlerLocal, DuplexPeer}; + +/// rpc where both sides exchange typed events on the same stream +/// +/// The initiator opens the routed stream. After that, either side may send any +/// number of events of its directional event type until it finishes or closes +/// its write side. +pub trait Duplex: Route { + /// codec error shared by both directional event values + type Error; + /// typed event sent by the side that opened the stream + type InitiatorEvent: RpcCodec; + /// typed event sent by the side handling the route + type ResponderEvent: RpcCodec; +} diff --git a/ql-rpc/src/rpc/duplex/server.rs b/ql-rpc/src/rpc/duplex/server.rs new file mode 100644 index 0000000..bf02433 --- /dev/null +++ b/ql-rpc/src/rpc/duplex/server.rs @@ -0,0 +1,47 @@ +use std::future::Future; + +use crate::{ + duplex::{Duplex, DuplexReceiver, DuplexSender}, + RpcRead, RpcStream, RpcWrite, +}; + +#[trait_variant::make(DuplexHandler: Send)] +pub trait DuplexHandlerLocal +where + M: Duplex, + St: RpcStream, +{ + async fn handle(self, peer: DuplexPeer); +} + +pub struct DuplexPeer +where + M: Duplex, + W: RpcWrite, + R: RpcRead, +{ + pub sender: DuplexSender, + pub receiver: DuplexReceiver, +} + +pub(crate) async fn handle_duplex_inner( + state: S, + _config: crate::RouterConfig, + reader: St::Reader, + writer: St::Writer, + handle: H, +) where + M: Duplex + 'static, + St: RpcStream + 'static, + H: FnOnce(S, DuplexPeer) -> HF, + HF: Future, +{ + handle( + state, + DuplexPeer { + sender: DuplexSender::new(writer), + receiver: DuplexReceiver::new(reader), + }, + ) + .await; +} diff --git a/ql-rpc/src/rpc/mod.rs b/ql-rpc/src/rpc/mod.rs new file mode 100644 index 0000000..2d84f05 --- /dev/null +++ b/ql-rpc/src/rpc/mod.rs @@ -0,0 +1,32 @@ +//! rpc protocol families built on top of one stream per call +//! +//! each trait in this module names one rpc shape and the typed values that +//! travel on that stream +//! route dispatch uses [`crate::RouteId`] and the submodules provide the matching +//! client and server helpers for encoding, decoding, and handler glue + +use crate::RouteId; + +pub mod download; +pub mod duplex; +pub mod notification; +pub(crate) mod parts; +pub mod progress; +pub mod request; +pub mod subscription; +pub mod upload; +mod utils; + +pub trait Route { + /// route used to dispatch this rpc family + const ROUTE: RouteId; +} + +pub use download::Download; +pub use duplex::Duplex; +pub use notification::Notification; +pub use progress::Progress; +pub use request::Request; +pub use subscription::Subscription; +pub use upload::Upload; +use utils::*; diff --git a/ql-rpc/src/rpc/notification/client.rs b/ql-rpc/src/rpc/notification/client.rs new file mode 100644 index 0000000..72b6900 --- /dev/null +++ b/ql-rpc/src/rpc/notification/client.rs @@ -0,0 +1,10 @@ +use bytes::BufMut; + +use crate::{notification::Notification, RpcCodec}; + +pub fn encode_notification( + payload: &M::Payload, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + payload.encode_value(out) +} diff --git a/ql-rpc/src/rpc/notification/mod.rs b/ql-rpc/src/rpc/notification/mod.rs new file mode 100644 index 0000000..4740a64 --- /dev/null +++ b/ql-rpc/src/rpc/notification/mod.rs @@ -0,0 +1,19 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod server; + +pub use client::encode_notification; +pub use server::{NotificationHandler, NotificationHandlerLocal}; + +/// one-way rpc that carries a single typed payload and no typed response +/// +/// the server reads [`Self::Payload`] to eof and then closes the response side +/// of the stream +pub trait Notification: Route { + /// codec error for the notification payload + type Error; + /// typed payload emitted by the caller + type Payload: RpcCodec; +} diff --git a/ql-rpc/src/rpc/notification/server.rs b/ql-rpc/src/rpc/notification/server.rs new file mode 100644 index 0000000..c9a4fdb --- /dev/null +++ b/ql-rpc/src/rpc/notification/server.rs @@ -0,0 +1,48 @@ +use std::future::Future; + +use crate::{ + notification::Notification as NotificationRpc, rpc::read_eof_request, RouterConfig, RpcRead, + RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; + +#[trait_variant::make(NotificationHandler: Send)] +pub trait NotificationHandlerLocal +where + M: NotificationRpc, + St: RpcStream, +{ + async fn handle(self, message: M::Payload); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub(crate) async fn handle_notification_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: NotificationRpc + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Payload) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let notification = match read_eof_request::(&mut reader, config).await { + Ok(notification) => notification, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + writer.close(StreamCloseCode::CANCELLED); + handle(state, notification).await; +} diff --git a/ql-rpc/src/rpc/parts.rs b/ql-rpc/src/rpc/parts.rs new file mode 100644 index 0000000..47ff1e8 --- /dev/null +++ b/ql-rpc/src/rpc/parts.rs @@ -0,0 +1,283 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, ChunkQueue, CodecError, RpcCodec}; + +pub enum PartReadStep { + NeedMore, + PartHeader(H), + BodyBytes(Bytes), + EndPart, + Finish, +} + +pub struct PartFrameReader { + bytes: codec::ChunkQueue, + pending_frame: PendingFrame, + marker: PhantomData H>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PendingFrame { + None, + Control { kind: FrameKind, len: usize }, + Body { remaining: usize }, +} + +impl PendingFrame { + fn take(&mut self) -> Self { + std::mem::replace(self, Self::None) + } +} + +impl PartFrameReader { + pub fn new(bytes: ChunkQueue) -> Self { + Self { + bytes, + pending_frame: PendingFrame::None, + marker: PhantomData, + } + } + + pub fn push(&mut self, chunk: Bytes) { + self.bytes.push(chunk); + } + + pub fn advance(&mut self) -> Result, CodecError> { + loop { + match self.pending_frame.take() { + PendingFrame::Body { remaining } => { + if remaining == 0 { + continue; + } + + let Some(bytes) = self.bytes.pop_front(remaining) else { + self.pending_frame = PendingFrame::Body { remaining }; + return Ok(PartReadStep::NeedMore); + }; + + let remaining = remaining - bytes.len(); + self.pending_frame = if remaining == 0 { + PendingFrame::None + } else { + PendingFrame::Body { remaining } + }; + return Ok(PartReadStep::BodyBytes(bytes)); + } + PendingFrame::Control { kind, len } => { + let Some(mut body) = self.bytes.try_take_body(len) else { + self.pending_frame = PendingFrame::Control { kind, len }; + return Ok(PartReadStep::NeedMore); + }; + + match kind { + FrameKind::PartHeader => { + let value = H::decode_value(&mut body).map_err(CodecError::Codec)?; + return Ok(PartReadStep::PartHeader(value)); + } + FrameKind::BodyChunk => unreachable!("body chunk is not a control frame"), + FrameKind::EndPart => { + body.expect_empty()?; + return Ok(PartReadStep::EndPart); + } + FrameKind::Finish => { + body.expect_empty()?; + drop(body); + self.bytes.expect_empty()?; + return Ok(PartReadStep::Finish); + } + } + } + PendingFrame::None => { + let Some((kind, len)) = self + .bytes + .try_take_tagged_part_header() + .map_err(CodecError::Rpc)? + else { + return Ok(PartReadStep::NeedMore); + }; + + let kind = FrameKind::try_from(kind).map_err(CodecError::Rpc)?; + self.pending_frame = if kind == FrameKind::BodyChunk { + PendingFrame::Body { remaining: len } + } else { + PendingFrame::Control { kind, len } + }; + } + } + } + } +} + +pub fn encode_part_header(part_header: &H, out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_value_part(FrameKind::PartHeader, part_header, out) +} + +pub fn encode_body_chunk(bytes: &Bytes, out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_value_part(FrameKind::BodyChunk, bytes, out) +} + +pub fn encode_end_part(out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_empty_part(FrameKind::EndPart, out) +} + +pub fn encode_finish(out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_empty_part(FrameKind::Finish, out) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub(super) enum FrameKind { + PartHeader = 1, + BodyChunk = 2, + EndPart = 3, + Finish = 4, +} + +impl FrameKind { + pub fn tag(self) -> u8 { + self as u8 + } +} + +impl TryFrom for FrameKind { + type Error = crate::Error; + + fn try_from(value: u8) -> Result { + match value { + x if x == Self::PartHeader.tag() => Ok(Self::PartHeader), + x if x == Self::BodyChunk.tag() => Ok(Self::BodyChunk), + x if x == Self::EndPart.tag() => Ok(Self::EndPart), + x if x == Self::Finish.tag() => Ok(Self::Finish), + other => Err(crate::Error::UnexpectedFrameKind(other)), + } + } +} + +fn encode_tagged_value_part>( + kind: FrameKind, + value: &T, + out: &mut B, +) { + out.put_u8(kind.tag()); + let payload_start = codec::reserve_length(out); + value.encode_value(out); + codec::backpatch_length(out, payload_start); +} + +fn encode_tagged_empty_part>(kind: FrameKind, out: &mut B) { + out.put_u8(kind.tag()); + let payload_start = codec::reserve_length(out); + codec::backpatch_length(out, payload_start); +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{ + encode_body_chunk, encode_end_part, encode_finish, encode_part_header, PartFrameReader, + PartReadStep, + }; + + #[test] + fn part_reader_emits_multipart_sequence() { + let mut encoded = Vec::new(); + encode_part_header(&b"a.txt".to_vec(), &mut encoded); + encode_body_chunk(&Bytes::from_static(b"hel"), &mut encoded); + encode_body_chunk(&Bytes::from_static(b"lo"), &mut encoded); + encode_end_part(&mut encoded); + encode_part_header(&b"b.txt".to_vec(), &mut encoded); + encode_end_part(&mut encoded); + encode_finish(&mut encoded); + + let mut reader = PartFrameReader::>::new(Default::default()); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + PartReadStep::PartHeader(value) => { + assert_eq!(value, b"a.txt".to_vec()); + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"hel")), + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"lo")), + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::EndPart => {} + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::PartHeader(value) => { + assert_eq!(value, b"b.txt".to_vec()); + } + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::EndPart => {} + _ => unreachable!(), + }; + + match reader.advance().unwrap() { + PartReadStep::Finish => {} + _ => unreachable!(), + } + } + + #[test] + fn part_reader_waits_for_complete_header_frame() { + let mut encoded = Vec::new(); + encode_part_header(&b"a.txt".to_vec(), &mut encoded); + let encoded = Bytes::from(encoded); + + let mut reader = PartFrameReader::>::new(Default::default()); + reader.push(encoded.slice(..4)); + match reader.advance().unwrap() { + PartReadStep::NeedMore => {} + _ => unreachable!(), + }; + + reader.push(encoded.slice(4..)); + match reader.advance().unwrap() { + PartReadStep::PartHeader(value) => assert_eq!(value, b"a.txt".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn body_chunk_frame_streams_after_header() { + let mut encoded = Vec::new(); + encode_body_chunk(&Bytes::from_static(b"hello"), &mut encoded); + let encoded = Bytes::from(encoded); + + let mut reader = PartFrameReader::>::new(Default::default()); + reader.push(encoded.slice(..9)); + match reader.advance().unwrap() { + PartReadStep::NeedMore => {} + _ => unreachable!(), + }; + + reader.push(encoded.slice(9..11)); + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"he")), + _ => unreachable!(), + }; + + reader.push(encoded.slice(11..)); + match reader.advance().unwrap() { + PartReadStep::BodyBytes(bytes) => assert_eq!(bytes, Bytes::from_static(b"llo")), + _ => unreachable!(), + }; + } +} diff --git a/ql-rpc/src/rpc/progress/client.rs b/ql-rpc/src/rpc/progress/client.rs new file mode 100644 index 0000000..c2218c9 --- /dev/null +++ b/ql-rpc/src/rpc/progress/client.rs @@ -0,0 +1,152 @@ +use std::{ + future::{poll_fn, Future}, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::{ + progress::{Progress, ReadStep, ResponseReader}, + CallError, Error, RpcRead, StreamCloseCode, +}; + +pub struct ProgressCall +where + M: Progress, + R: RpcRead, +{ + stream: Option, + state: State, +} + +enum State +where + M: Progress, +{ + Invalid, + Reading(ResponseReader), + Terminal(Result>), + Done, +} + +impl Unpin for ProgressCall +where + M: Progress, + R: RpcRead, +{ +} + +impl ProgressCall +where + M: Progress, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream: Some(stream), + state: State::Reading(ResponseReader::default()), + } + } + + pub async fn next_progress(&mut self) -> Option { + poll_fn(|cx| self.poll_next_progress(cx)).await + } + + fn poll_step(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + let reader = match &mut self.state { + State::Reading(reader) => reader, + State::Terminal(_) | State::Done => return Poll::Ready(None), + State::Invalid => panic!("invalid state"), + }; + + match reader.advance() { + Ok(ReadStep::Progress(value)) => return Poll::Ready(Some(value)), + Ok(ReadStep::Response(response)) => { + self.state = State::Terminal(Ok(response)); + return Poll::Ready(None); + } + Ok(ReadStep::NeedMore) => {} + Err(error) => { + self.state = State::Terminal(Err(error.into())); + return Poll::Ready(None); + } + } + + let stream = self.stream.as_mut().unwrap(); + match stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + let State::Reading(reader) = &mut self.state else { + panic!("invalid state"); + }; + reader.push(chunk); + } + Poll::Ready(Ok(None)) => { + self.state = State::Terminal(Err(Error::MissingResponse.into())); + return Poll::Ready(None); + } + Poll::Ready(Err(error)) => { + self.state = State::Terminal(Err(CallError::Transport(error))); + return Poll::Ready(None); + } + Poll::Pending => return Poll::Pending, + } + } + } + + pub fn poll_next_progress(&mut self, cx: &mut Context<'_>) -> Poll> { + self.poll_step(cx) + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + self.state = State::Done; + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for ProgressCall +where + M: Progress, + R: RpcRead, +{ + fn drop(&mut self) { + if matches!(self.state, State::Reading(_)) { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +impl Future for ProgressCall +where + M: Progress, + R: RpcRead, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + match this.poll_step(cx) { + Poll::Ready(Some(_)) => {} + Poll::Ready(None) => match std::mem::replace(&mut this.state, State::Invalid) { + State::Terminal(result) => { + this.state = State::Done; + return Poll::Ready(result); + } + State::Done => panic!("polled after completion"), + State::Invalid => panic!("polled during state transition"), + State::Reading(_) => { + panic!("progress call reached terminal step without result") + } + }, + Poll::Pending => return Poll::Pending, + } + } + } +} diff --git a/ql-rpc/src/rpc/progress/codec.rs b/ql-rpc/src/rpc/progress/codec.rs new file mode 100644 index 0000000..a0dc1b8 --- /dev/null +++ b/ql-rpc/src/rpc/progress/codec.rs @@ -0,0 +1,145 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, progress::Progress, CodecError, Error, RpcCodec}; + +pub enum ReadStep { + NeedMore, + Progress(M::Progress), + Response(M::Response), +} + +pub struct ResponseReader { + bytes: codec::ChunkQueue, + marker: PhantomData M>, +} + +impl Default for ResponseReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl ResponseReader { + pub fn push(&mut self, chunk: Bytes) { + self.bytes.push(chunk); + } + + pub fn advance(&mut self) -> Result, CodecError> { + let Some((kind, mut body)) = self.bytes.try_take_tagged_part().map_err(CodecError::Rpc)? + else { + return Ok(ReadStep::NeedMore); + }; + + match kind { + x if x == FrameKind::Progress as u8 => { + let value = { + let value = M::Progress::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + value + }; + Ok(ReadStep::Progress(value)) + } + x if x == FrameKind::Response as u8 => { + let response = M::Response::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + if self.bytes.remaining() > 0 { + Err(CodecError::Rpc(Error::TrailingBytes)) + } else { + Ok(ReadStep::Response(response)) + } + } + other => Err(CodecError::Rpc(Error::UnexpectedFrameKind(other))), + } + } +} + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + codec::encode_value_part(request, out) +} + +pub fn encode_progress(progress: &M::Progress, out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_value_part(FrameKind::Progress, progress, out) +} + +pub fn encode_response(response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>)) { + encode_tagged_value_part(FrameKind::Response, response, out) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +enum FrameKind { + Progress = 1, + Response = 2, +} + +fn encode_tagged_value_part>( + kind: FrameKind, + value: &T, + out: &mut B, +) { + out.put_u8(kind as u8); + let payload_start = codec::reserve_length(out); + value.encode_value(out); + codec::backpatch_length(out, payload_start); +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::{encode_progress, encode_response, ReadStep, ResponseReader}; + use crate::{progress::Progress, Route, RouteId}; + + struct Watch; + + impl Route for Watch { + const ROUTE: RouteId = RouteId::from_u32(11); + } + + impl Progress for Watch { + type Error = core::convert::Infallible; + type Request = Vec; + type Progress = Vec; + type Response = Vec; + } + + #[test] + fn response_reader_emits_progress_then_response() { + let mut encoded = Vec::new(); + encode_progress::(&b"10%".to_vec(), &mut encoded); + encode_response::(&b"done".to_vec(), &mut encoded); + + let mut reader = ResponseReader::::default(); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + ReadStep::Progress(value) => { + assert_eq!(value, b"10%".to_vec()); + } + _ => unreachable!(), + }; + match reader.advance().unwrap() { + ReadStep::Response(value) => assert_eq!(value, b"done".to_vec()), + _ => unreachable!(), + } + } + + #[test] + fn response_reader_handles_response_only() { + let mut encoded = Vec::new(); + encode_response::(&b"done".to_vec(), &mut encoded); + + let mut reader = ResponseReader::::default(); + reader.push(Bytes::from(encoded)); + + match reader.advance().unwrap() { + ReadStep::Response(value) => assert_eq!(value, b"done".to_vec()), + _ => unreachable!(), + } + } +} diff --git a/ql-rpc/src/rpc/progress/mod.rs b/ql-rpc/src/rpc/progress/mod.rs new file mode 100644 index 0000000..b21c826 --- /dev/null +++ b/ql-rpc/src/rpc/progress/mod.rs @@ -0,0 +1,27 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod codec; +pub(crate) mod server; + +pub use client::ProgressCall; +pub use codec::{encode_progress, encode_request, encode_response, ReadStep, ResponseReader}; +pub use server::{ProgressHandler, ProgressHandlerLocal, ProgressResponder}; + +/// rpc where the responder streams progress values before a final response +/// +/// the request is length-delimited +/// response frames are tagged so the client can distinguish +/// [`Self::Progress`] items from the final [`Self::Response`] +/// reaching eof before the final response is an error +pub trait Progress: Route { + /// codec error shared by request, progress, and response values + type Error; + /// typed input sent by the caller + type Request: RpcCodec; + /// typed progress item emitted before completion + type Progress: RpcCodec; + /// typed terminal response that completes the call + type Response: RpcCodec; +} diff --git a/ql-rpc/src/rpc/progress/server.rs b/ql-rpc/src/rpc/progress/server.rs new file mode 100644 index 0000000..b94421c --- /dev/null +++ b/ql-rpc/src/rpc/progress/server.rs @@ -0,0 +1,106 @@ +use std::{future::Future, marker::PhantomData}; + +use bytes::Bytes; + +use crate::{ + finish_bytes, + progress::{encode_progress, encode_response, Progress}, + rpc::read_framed_request, + write_bytes, RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; + +#[trait_variant::make(ProgressHandler: Send)] +pub trait ProgressHandlerLocal +where + M: Progress, + St: RpcStream, +{ + async fn handle(self, request: M::Request, responder: ProgressResponder); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct ProgressResponder +where + M: Progress, + W: RpcWrite, +{ + writer: Option, + marker: PhantomData M>, +} + +impl ProgressResponder +where + M: Progress, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn send(&mut self, progress: M::Progress) -> Result<(), W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_progress::(&progress, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self, response: M::Response) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + encode_response::(&response, &mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for ProgressResponder +where + M: Progress, + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +pub(crate) async fn handle_progress_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: Progress + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Request, ProgressResponder) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let request = match read_framed_request::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle(state, request, ProgressResponder::new(writer)).await; +} diff --git a/ql-rpc/src/rpc/request/client.rs b/ql-rpc/src/rpc/request/client.rs new file mode 100644 index 0000000..e7ffb84 --- /dev/null +++ b/ql-rpc/src/rpc/request/client.rs @@ -0,0 +1,34 @@ +use bytes::BufMut; + +use crate::{read_bytes, request::Request, CallError, ChunkQueue, RpcCodec, RpcRead}; + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + request.encode_value(out) +} + +pub fn encode_response(response: &M::Response, out: &mut (impl BufMut + AsMut<[u8]>)) { + response.encode_value(out) +} + +pub async fn read_response( + mut reader: R, +) -> Result> +where + M: Request, + R: RpcRead, +{ + let mut bytes = ChunkQueue::default(); + + while let Some(chunk) = read_bytes(&mut reader, usize::MAX) + .await + .map_err(CallError::Transport)? + { + bytes.push(chunk); + } + + let value = M::Response::decode_value(&mut bytes).map_err(CallError::Codec)?; + if bytes.remaining() > 0 { + return Err(crate::Error::TrailingBytes.into()); + } + Ok(value) +} diff --git a/ql-rpc/src/rpc/request/mod.rs b/ql-rpc/src/rpc/request/mod.rs new file mode 100644 index 0000000..adf3259 --- /dev/null +++ b/ql-rpc/src/rpc/request/mod.rs @@ -0,0 +1,23 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod server; + +pub use client::{encode_request, encode_response, read_response}; +pub use server::{RequestHandler, RequestHandlerLocal, Response}; + +/// request-response rpc with exactly one typed value in each direction +/// +/// the request is read to eof on the server side, so callers must finish the +/// request stream after encoding [`Self::Request`] +/// the response is also read to eof and rejects trailing bytes after +/// [`Self::Response`] +pub trait Request: Route { + /// codec error shared by request and response values + type Error; + /// typed input sent by the caller + type Request: RpcCodec; + /// typed output returned by the responder + type Response: RpcCodec; +} diff --git a/ql-rpc/src/rpc/request/server.rs b/ql-rpc/src/rpc/request/server.rs new file mode 100644 index 0000000..5211cce --- /dev/null +++ b/ql-rpc/src/rpc/request/server.rs @@ -0,0 +1,96 @@ +use std::{future::Future, marker::PhantomData}; + +use bytes::Bytes; + +use crate::{ + finish_bytes, request::Request as RequestRpc, rpc::read_eof_request, write_bytes, RouterConfig, + RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, +}; + +#[trait_variant::make(RequestHandler: Send)] +pub trait RequestHandlerLocal +where + M: RequestRpc, + St: RpcStream, +{ + async fn handle(self, message: M::Request, responder: Response); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct Response +where + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, +} + +impl Response +where + T: RpcCodec, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn respond(mut self, response: T) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + response.encode_value(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)).await?; + finish_bytes(&mut writer).await?; + Ok(()) + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for Response +where + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +pub(crate) async fn handle_request_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: RequestRpc + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Request, Response) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let request = match read_eof_request::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle(state, request, Response::new(writer)).await; +} diff --git a/ql-rpc/src/rpc/subscription/client.rs b/ql-rpc/src/rpc/subscription/client.rs new file mode 100644 index 0000000..fe6aa5b --- /dev/null +++ b/ql-rpc/src/rpc/subscription/client.rs @@ -0,0 +1,99 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use crate::{ + subscription::{ReadStep, ResponseReader, Subscription}, + CallError, RpcRead, StreamCloseCode, +}; + +pub struct SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + stream: Option, + reader: ResponseReader, +} + +impl SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + pub fn new(stream: R) -> Self { + Self { + stream: Some(stream), + reader: ResponseReader::default(), + } + } + + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| self.poll_next_event(cx)).await + } + + pub fn poll_next_event( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + if self.stream.is_none() { + return Poll::Ready(None); + } + + loop { + match self.reader.advance() { + Ok(ReadStep::Item(value)) => return Poll::Ready(Some(Ok(value))), + Ok(ReadStep::NeedMore) => {} + Err(error) => { + self.stream.take(); + return Poll::Ready(Some(Err(error.into()))); + } + } + + let stream = self.stream.as_mut().unwrap(); + match stream.poll_read(usize::MAX, cx) { + Poll::Ready(Ok(Some(chunk))) => { + self.reader.push(chunk); + } + Poll::Ready(Ok(None)) => { + if self.reader.is_empty() { + self.stream.take(); + return Poll::Ready(None); + } + self.stream.take(); + return Poll::Ready(Some(Err(crate::Error::Truncated.into()))); + } + Poll::Ready(Err(error)) => { + self.stream.take(); + return Poll::Ready(Some(Err(CallError::Transport(error)))); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for SubscriptionCall +where + M: Subscription, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} diff --git a/ql-rpc/src/rpc/subscription/codec.rs b/ql-rpc/src/rpc/subscription/codec.rs new file mode 100644 index 0000000..bdd1620 --- /dev/null +++ b/ql-rpc/src/rpc/subscription/codec.rs @@ -0,0 +1,58 @@ +use std::marker::PhantomData; + +use bytes::{BufMut, Bytes}; + +use crate::{codec, subscription::Subscription, CodecError, RpcCodec}; + +pub fn encode_request( + request: &M::Request, + out: &mut (impl BufMut + AsMut<[u8]>), +) { + request.encode_value(out) +} + +pub fn encode_item(item: &M::Event, out: &mut (impl BufMut + AsMut<[u8]>)) { + codec::encode_value_part(item, out) +} + +pub enum ReadStep { + NeedMore, + Item(M::Event), +} + +pub struct ResponseReader { + bytes: codec::ChunkQueue, + marker: PhantomData M>, +} + +impl Default for ResponseReader { + fn default() -> Self { + Self { + bytes: codec::ChunkQueue::default(), + marker: PhantomData, + } + } +} + +impl ResponseReader { + pub fn push(&mut self, chunk: Bytes) { + self.bytes.push(chunk); + } + + pub fn is_empty(&self) -> bool { + self.bytes.remaining() == 0 + } + + pub fn advance(&mut self) -> Result, CodecError> { + let Some(mut body) = self.bytes.try_take_part()? else { + return Ok(ReadStep::NeedMore); + }; + + let item = { + let item = M::Event::decode_value(&mut body).map_err(CodecError::Codec)?; + drop(body); + item + }; + Ok(ReadStep::Item(item)) + } +} diff --git a/ql-rpc/src/rpc/subscription/mod.rs b/ql-rpc/src/rpc/subscription/mod.rs new file mode 100644 index 0000000..672eb9b --- /dev/null +++ b/ql-rpc/src/rpc/subscription/mod.rs @@ -0,0 +1,23 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod codec; +pub(crate) mod server; + +pub use client::SubscriptionCall; +pub use codec::{encode_item, encode_request, ReadStep, ResponseReader}; +pub use server::{SubscriptionHandler, SubscriptionHandlerLocal, SubscriptionResponder}; + +/// rpc where one request opens a stream of typed events +/// +/// event frames are length-delimited and the stream ends cleanly at eof +/// any partial trailing frame is reported as truncation on the client side +pub trait Subscription: Route { + /// codec error shared by request and event values + type Error; + /// typed input that starts the subscription + type Request: RpcCodec; + /// typed event yielded by the responder + type Event: RpcCodec; +} diff --git a/ql-rpc/src/rpc/subscription/server.rs b/ql-rpc/src/rpc/subscription/server.rs new file mode 100644 index 0000000..6dfdd4b --- /dev/null +++ b/ql-rpc/src/rpc/subscription/server.rs @@ -0,0 +1,105 @@ +use std::{future::Future, marker::PhantomData}; + +use bytes::Bytes; + +use crate::{ + codec, finish_bytes, rpc::read_eof_request, subscription::Subscription as SubscriptionRpc, + write_bytes, RouterConfig, RpcCodec, RpcRead, RpcStream, RpcWrite, StreamCloseCode, + StreamError, +}; + +#[trait_variant::make(SubscriptionHandler: Send)] +pub trait SubscriptionHandlerLocal +where + M: SubscriptionRpc, + St: RpcStream, +{ + async fn handle( + self, + message: M::Request, + responder: SubscriptionResponder, + ); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct SubscriptionResponder +where + W: RpcWrite, +{ + writer: Option, + marker: PhantomData T>, +} + +impl SubscriptionResponder +where + T: RpcCodec, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + writer: Some(writer), + marker: PhantomData, + } + } + + pub async fn send(&mut self, event: T) -> Result<(), W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + codec::encode_value_part(&event, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + Ok(()) + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let mut writer = self.writer.take().unwrap(); + finish_bytes(&mut writer).await + } + + pub fn close(mut self, code: StreamCloseCode) { + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for SubscriptionResponder +where + W: RpcWrite, +{ + fn drop(&mut self) { + if let Some(writer) = self.writer.take() { + writer.close(StreamCloseCode::CANCELLED); + } + } +} + +pub(crate) async fn handle_subscription_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: SubscriptionRpc + 'static, + St: RpcStream + 'static, + H: FnOnce(S, M::Request, SubscriptionResponder) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let request = match read_eof_request::(&mut reader, config).await { + Ok(request) => request, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle(state, request, SubscriptionResponder::new(writer)).await; +} diff --git a/ql-rpc/src/rpc/upload/client.rs b/ql-rpc/src/rpc/upload/client.rs new file mode 100644 index 0000000..b41dedc --- /dev/null +++ b/ql-rpc/src/rpc/upload/client.rs @@ -0,0 +1,146 @@ +use bytes::{BufMut, Bytes}; + +use crate::{ + finish_bytes, read_bytes, + rpc::parts::{encode_body_chunk, encode_end_part, encode_finish, encode_part_header}, + upload::Upload, + write_bytes, CallError, ChunkQueue, RpcCodec, RpcRead, RpcWrite, StreamCloseCode, +}; + +pub struct UploadCall +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + writer: Option, + reader: Option, + marker: std::marker::PhantomData M>, +} + +pub struct UploadPartWriter<'a, M, W, R> +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + parent: &'a mut UploadCall, + finished: bool, +} + +impl UploadCall +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + pub fn new(writer: W, reader: R) -> Self { + Self { + writer: Some(writer), + reader: Some(reader), + marker: std::marker::PhantomData, + } + } + + pub async fn start_part( + &mut self, + part_header: M::PartHeader, + ) -> Result, W::Error> { + let writer = self.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_part_header(&part_header, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + Ok(UploadPartWriter { + parent: self, + finished: false, + }) + } + + pub async fn finish(mut self) -> Result> { + let mut writer = self.writer.take().unwrap(); + let mut encoded = Vec::new(); + encode_finish(&mut encoded); + write_bytes(&mut writer, Bytes::from(encoded)) + .await + .map_err(CallError::Transport)?; + finish_bytes(&mut writer) + .await + .map_err(CallError::Transport)?; + + let mut reader = self.reader.take().unwrap(); + let mut bytes = ChunkQueue::default(); + + while let Some(chunk) = read_bytes(&mut reader, usize::MAX) + .await + .map_err(CallError::Transport)? + { + bytes.push(chunk); + } + + let value = M::Response::decode_value(&mut bytes).map_err(CallError::Codec)?; + if bytes.remaining() > 0 { + return Err(crate::Error::TrailingBytes.into()); + } + Ok(value) + } + + fn close(&mut self, code: StreamCloseCode) { + if let Some(reader) = self.reader.take() { + reader.close(code); + } + if let Some(writer) = self.writer.take() { + writer.close(code); + } + } +} + +impl Drop for UploadCall +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + fn drop(&mut self) { + self.close(StreamCloseCode::CANCELLED); + } +} + +impl UploadPartWriter<'_, M, W, R> +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_body_chunk(&bytes, &mut encoded); + write_bytes(writer, Bytes::from(encoded)).await + } + + pub async fn finish(mut self) -> Result<(), W::Error> { + let writer = self.parent.writer.as_mut().unwrap(); + let mut encoded = Vec::new(); + encode_end_part(&mut encoded); + write_bytes(writer, Bytes::from(encoded)).await?; + self.finished = true; + Ok(()) + } +} + +impl Drop for UploadPartWriter<'_, M, W, R> +where + M: Upload, + W: RpcWrite, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.parent.close(StreamCloseCode::CANCELLED); + } + } +} + +pub fn encode_request(request: &M::Request, out: &mut (impl BufMut + AsMut<[u8]>)) { + crate::codec::encode_value_part(request, out) +} diff --git a/ql-rpc/src/rpc/upload/mod.rs b/ql-rpc/src/rpc/upload/mod.rs new file mode 100644 index 0000000..9f96a82 --- /dev/null +++ b/ql-rpc/src/rpc/upload/mod.rs @@ -0,0 +1,26 @@ +use super::Route; +use crate::RpcCodec; + +pub(crate) mod client; +pub(crate) mod server; + +pub use client::{encode_request, UploadCall, UploadPartWriter}; +pub use server::{UploadHandler, UploadHandlerLocal, UploadPart, UploadReader, UploadResponder}; + +/// rpc where the caller uploads zero or more byte parts after a typed request +/// +/// the typed request usually describes how the responder should interpret the +/// following parts +/// the request is length-delimited so raw upload bytes can follow immediately +/// once the upload reaches eof, the responder returns one typed +/// [`Self::Response`] +pub trait Upload: Route { + /// codec error shared by request and response values + type Error; + /// typed input needed before request body bytes arrive + type Request: RpcCodec; + /// typed metadata available before each byte part arrives + type PartHeader: RpcCodec; + /// typed terminal result after the upload body is fully read + type Response: RpcCodec; +} diff --git a/ql-rpc/src/rpc/upload/server.rs b/ql-rpc/src/rpc/upload/server.rs new file mode 100644 index 0000000..d2e6765 --- /dev/null +++ b/ql-rpc/src/rpc/upload/server.rs @@ -0,0 +1,243 @@ +use std::future::{poll_fn, Future}; + +use bytes::Bytes; + +use crate::{ + request::Response, + rpc::{ + parts::{FrameKind, PartFrameReader, PartReadStep}, + read_framed_request_prefix, + }, + RouterConfig, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError, Upload, +}; + +#[trait_variant::make(UploadHandler: Send)] +pub trait UploadHandlerLocal +where + M: Upload, + St: RpcStream, +{ + async fn handle( + self, + request: M::Request, + upload: UploadReader, + responder: UploadResponder, + ); + + fn handle_transport_error(&self, _error: &St::Error) {} +} + +pub struct UploadReader +where + M: Upload, + R: RpcRead, +{ + stream: Option, + reader: PartFrameReader, +} + +pub struct UploadPart<'a, M, R> +where + M: Upload, + R: RpcRead, +{ + parent: &'a mut UploadReader, + finished: bool, +} + +pub struct UploadResponder +where + W: RpcWrite, +{ + inner: Response, +} + +impl UploadReader +where + M: Upload, + R: RpcRead, +{ + pub async fn next_part( + &mut self, + ) -> Result)>, crate::CallError> + { + if self.stream.is_none() { + return Ok(None); + } + + match self.read_frame().await? { + PartReadStep::PartHeader(value) => Ok(Some(( + value, + UploadPart { + parent: self, + finished: false, + }, + ))), + PartReadStep::Finish => { + self.stream.take(); + Ok(None) + } + PartReadStep::BodyBytes(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::BodyChunk.tag()).into()) + } + PartReadStep::EndPart => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::EndPart.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + async fn read_frame( + &mut self, + ) -> Result, crate::CallError> { + loop { + match self.reader.advance() { + Ok(PartReadStep::NeedMore) => {} + Ok(step) => return Ok(step), + Err(error) => return Err(error.into()), + } + + let stream = self.stream.as_mut().unwrap(); + match poll_fn(|cx| stream.poll_read(usize::MAX, cx)).await { + Ok(Some(chunk)) => { + self.reader.push(chunk); + } + Ok(None) => return Err(crate::Error::Truncated.into()), + Err(error) => return Err(crate::CallError::Transport(error)), + } + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if let Some(stream) = self.stream.take() { + stream.close(code); + } + } +} + +impl Drop for UploadReader +where + M: Upload, + R: RpcRead, +{ + fn drop(&mut self) { + if self.stream.is_some() { + self.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +impl UploadPart<'_, M, R> +where + M: Upload, + R: RpcRead, +{ + pub async fn read_chunk( + &mut self, + ) -> Result, crate::CallError> { + if self.finished { + return Ok(None); + } + + match self.parent.read_frame().await? { + PartReadStep::BodyBytes(bytes) => Ok(Some(bytes)), + PartReadStep::EndPart => { + self.finished = true; + Ok(None) + } + PartReadStep::PartHeader(_) => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::PartHeader.tag()).into()) + } + PartReadStep::Finish => { + Err(crate::Error::UnexpectedFrameKind(FrameKind::Finish.tag()).into()) + } + PartReadStep::NeedMore => unreachable!("read_frame waits for a complete frame"), + } + } + + pub fn close(mut self, code: StreamCloseCode) { + self.parent.close_inner(code); + self.finished = true; + } +} + +impl Drop for UploadPart<'_, M, R> +where + M: Upload, + R: RpcRead, +{ + fn drop(&mut self) { + if !self.finished { + self.parent.close_inner(StreamCloseCode::CANCELLED); + } + } +} + +impl UploadResponder +where + T: crate::RpcCodec, + W: RpcWrite, +{ + pub(crate) fn new(writer: W) -> Self { + Self { + inner: Response::new(writer), + } + } + + pub async fn respond(self, response: T) -> Result<(), W::Error> { + self.inner.respond(response).await + } + + pub fn close(self, code: StreamCloseCode) { + self.inner.close(code); + } +} + +pub(crate) async fn handle_upload_inner( + state: S, + config: RouterConfig, + mut reader: St::Reader, + writer: St::Writer, + handle: H, + handle_transport_error: E, +) where + M: Upload + 'static, + St: RpcStream + 'static, + H: FnOnce( + S, + M::Request, + UploadReader, + UploadResponder, + ) -> HF, + HF: Future, + E: FnOnce(&S, &St::Error), +{ + let (request, buffered) = + match read_framed_request_prefix::(&mut reader, config).await { + Ok(value) => value, + Err(error) => { + let code = error.close_code(); + handle_transport_error(&state, &error); + if let Some(code) = code { + reader.close(code); + writer.close(code); + } + return; + } + }; + + handle( + state, + request, + UploadReader { + stream: Some(reader), + reader: PartFrameReader::new(buffered), + }, + UploadResponder::new(writer), + ) + .await; +} diff --git a/ql-rpc/src/rpc/utils.rs b/ql-rpc/src/rpc/utils.rs new file mode 100644 index 0000000..bf5f49e --- /dev/null +++ b/ql-rpc/src/rpc/utils.rs @@ -0,0 +1,120 @@ +use crate::{ + read_bytes, ChunkQueue, CodecError, FramedPrefixStep, FramedReadStep, FramedReader, + RouterConfig, RpcCodec, RpcRead, StreamCloseCode, +}; + +/// reads one length-delimited value and rejects trailing bytes +pub(crate) async fn read_framed_request( + reader: &mut R, + config: RouterConfig, +) -> Result +where + T: RpcCodec, + R: RpcRead, +{ + let mut value_reader = FramedReader::::default(); + let mut total_read = 0usize; + + let value = loop { + match value_reader.advance() { + Ok(FramedReadStep::Value(value)) => break value, + Ok(FramedReadStep::NeedMore(next)) => value_reader = next, + Err(CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), + } + + let remaining = config.max_request_bytes.saturating_sub(total_read); + if remaining == 0 { + return Err(StreamCloseCode::LIMIT.into()); + } + + match read_bytes(reader, remaining).await { + Ok(Some(chunk)) => { + total_read += chunk.len(); + value_reader = value_reader.push(chunk); + } + Ok(None) => return Err(StreamCloseCode::REFUSED.into()), + Err(error) => return Err(error), + } + }; + + let remaining = config.max_request_bytes.saturating_sub(total_read); + let probe = remaining.max(1); + match read_bytes(reader, probe).await { + Ok(None) => Ok(value), + Ok(Some(_)) if remaining == 0 => Err(StreamCloseCode::LIMIT.into()), + Ok(Some(_)) => Err(StreamCloseCode::REFUSED.into()), + Err(error) => Err(error), + } +} + +/// reads one length-delimited value and returns any bytes already buffered +pub(crate) async fn read_framed_request_prefix( + reader: &mut R, + config: RouterConfig, +) -> Result<(T, ChunkQueue), R::Error> +where + T: RpcCodec, + R: RpcRead, +{ + let mut value_reader = FramedReader::::default(); + let mut total_read = 0usize; + + loop { + match value_reader.advance_prefix() { + Ok(FramedPrefixStep::Value { value, bytes }) => return Ok((value, bytes)), + Ok(FramedPrefixStep::NeedMore(next)) => value_reader = next, + Err(CodecError::Rpc(_error)) => return Err(StreamCloseCode::REFUSED.into()), + Err(CodecError::Codec(_error)) => return Err(StreamCloseCode::REFUSED.into()), + } + + let remaining = config.max_request_bytes.saturating_sub(total_read); + if remaining == 0 { + return Err(StreamCloseCode::LIMIT.into()); + } + + match read_bytes(reader, remaining).await { + Ok(Some(chunk)) => { + total_read += chunk.len(); + value_reader = value_reader.push(chunk); + } + Ok(None) => return Err(StreamCloseCode::REFUSED.into()), + Err(error) => return Err(error), + } + } +} + +/// reads one eof-delimited value up to the configured request limit +pub(crate) async fn read_eof_request( + reader: &mut R, + config: RouterConfig, +) -> Result +where + T: RpcCodec, + R: RpcRead, +{ + let mut bytes = ChunkQueue::default(); + let mut total_read = 0usize; + + loop { + let remaining = config.max_request_bytes.saturating_sub(total_read); + let probe = remaining.max(1); + match read_bytes(reader, probe).await { + Ok(Some(chunk)) => { + if chunk.len() > remaining { + return Err(StreamCloseCode::LIMIT.into()); + } + total_read += chunk.len(); + bytes.push(chunk); + } + Ok(None) => break, + Err(error) => return Err(error), + } + } + + let value = T::decode_value(&mut bytes).map_err(|_error| StreamCloseCode::REFUSED)?; + if bytes.remaining() > 0 { + return Err(StreamCloseCode::REFUSED.into()); + } + Ok(value) +} diff --git a/ql-rpc/src/stream.rs b/ql-rpc/src/stream.rs new file mode 100644 index 0000000..f6174ef --- /dev/null +++ b/ql-rpc/src/stream.rs @@ -0,0 +1,89 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use bytes::Bytes; + +use crate::{RouteId, StreamCloseCode}; + +pub trait RpcStream { + type Error: StreamError; + type Reader: RpcRead; + type Writer: RpcWrite; + + fn route_id(&self) -> Option; + fn split(self) -> (Self::Reader, Self::Writer); +} + +pub trait RpcRead { + type Error: StreamError; + + /// reads inbound bytes until eof or error + fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>; + + /// aborts the read side + fn close(self, code: StreamCloseCode); +} + +pub trait RpcWrite { + type Error: StreamError; + + /// writes outbound bytes before finish or close + fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll>; + + /// completes the write side and must be polled until ready without further write or close calls + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll>; + + /// aborts the write side before finish + fn close(self, code: StreamCloseCode); +} + +pub trait StreamError: From { + fn close_code(&self) -> Option; +} + +impl StreamError for StreamCloseCode { + fn close_code(&self) -> Option { + Some(*self) + } +} + +pub async fn read_bytes(reader: &mut R, max_len: usize) -> Result, R::Error> +where + R: RpcRead, +{ + poll_fn(|cx| reader.poll_read(max_len, cx)).await +} + +pub async fn write_bytes(writer: &mut W, bytes: Bytes) -> Result<(), W::Error> +where + W: RpcWrite, +{ + let mut bytes = bytes; + poll_fn(|cx| writer.poll_write(&mut bytes, cx)).await +} + +pub async fn finish_bytes(writer: &mut W) -> Result<(), W::Error> +where + W: RpcWrite, +{ + poll_fn(|cx| writer.poll_finish(cx)).await +} + +pub fn close_stream(stream: St, code: StreamCloseCode) +where + St: RpcStream, +{ + let (reader, writer) = stream.split(); + reader.close(code); + writer.close(code); +} From 154a4d9ccb4a5e74bbbbbfd5667e70bdcccf33a0 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 4 Jun 2026 09:26:25 -0400 Subject: [PATCH 5/6] ql-runtime: add async runtime --- Cargo.lock | 447 ++++++++++++++++-- Cargo.toml | 1 + ql-runtime/Cargo.toml | 35 ++ ql-runtime/src/command.rs | 57 +++ ql-runtime/src/driver/mod.rs | 609 +++++++++++++++++++++++++ ql-runtime/src/driver/state.rs | 165 +++++++ ql-runtime/src/driver/test.rs | 207 +++++++++ ql-runtime/src/error.rs | 18 + ql-runtime/src/handle/mod.rs | 96 ++++ ql-runtime/src/io/inner.rs | 643 ++++++++++++++++++++++++++ ql-runtime/src/io/mod.rs | 59 +++ ql-runtime/src/io/reader.rs | 236 ++++++++++ ql-runtime/src/io/slot.rs | 175 +++++++ ql-runtime/src/io/sync.rs | 89 ++++ ql-runtime/src/io/writer.rs | 294 ++++++++++++ ql-runtime/src/lib.rs | 63 +++ ql-runtime/src/log.rs | 54 +++ ql-runtime/src/platform.rs | 43 ++ ql-runtime/src/rpc/adapter.rs | 83 ++++ ql-runtime/src/rpc/download.rs | 67 +++ ql-runtime/src/rpc/duplex.rs | 59 +++ ql-runtime/src/rpc/error.rs | 79 ++++ ql-runtime/src/rpc/mod.rs | 154 +++++++ ql-runtime/src/rpc/progress.rs | 50 ++ ql-runtime/src/rpc/subscription.rs | 43 ++ ql-runtime/src/rpc/upload.rs | 44 ++ ql-runtime/src/tests/handshake.rs | 178 ++++++++ ql-runtime/src/tests/mod.rs | 710 +++++++++++++++++++++++++++++ ql-runtime/src/tests/rpc.rs | 677 +++++++++++++++++++++++++++ ql-runtime/src/tests/session.rs | 213 +++++++++ ql-runtime/src/tests/stream.rs | 673 +++++++++++++++++++++++++++ 31 files changed, 6295 insertions(+), 26 deletions(-) create mode 100644 ql-runtime/Cargo.toml create mode 100644 ql-runtime/src/command.rs create mode 100644 ql-runtime/src/driver/mod.rs create mode 100644 ql-runtime/src/driver/state.rs create mode 100644 ql-runtime/src/driver/test.rs create mode 100644 ql-runtime/src/error.rs create mode 100644 ql-runtime/src/handle/mod.rs create mode 100644 ql-runtime/src/io/inner.rs create mode 100644 ql-runtime/src/io/mod.rs create mode 100644 ql-runtime/src/io/reader.rs create mode 100644 ql-runtime/src/io/slot.rs create mode 100644 ql-runtime/src/io/sync.rs create mode 100644 ql-runtime/src/io/writer.rs create mode 100644 ql-runtime/src/lib.rs create mode 100644 ql-runtime/src/log.rs create mode 100644 ql-runtime/src/platform.rs create mode 100644 ql-runtime/src/rpc/adapter.rs create mode 100644 ql-runtime/src/rpc/download.rs create mode 100644 ql-runtime/src/rpc/duplex.rs create mode 100644 ql-runtime/src/rpc/error.rs create mode 100644 ql-runtime/src/rpc/mod.rs create mode 100644 ql-runtime/src/rpc/progress.rs create mode 100644 ql-runtime/src/rpc/subscription.rs create mode 100644 ql-runtime/src/rpc/upload.rs create mode 100644 ql-runtime/src/tests/handshake.rs create mode 100644 ql-runtime/src/tests/mod.rs create mode 100644 ql-runtime/src/tests/rpc.rs create mode 100644 ql-runtime/src/tests/session.rs create mode 100644 ql-runtime/src/tests/stream.rs diff --git a/Cargo.lock b/Cargo.lock index 071d36b..123d0e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,7 +72,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb4e440d04be07da1f1bf44fb4495ebd58669372fe0cffa6e48595ac5bd88a3" dependencies = [ "android_log-sys", - "env_filter", + "env_filter 0.1.3", "log", ] @@ -85,6 +85,56 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + +[[package]] +name = "anstyle-parse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + [[package]] name = "anyhow" version = "1.0.99" @@ -109,6 +159,18 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "async-channel" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2" +dependencies = [ + "concurrent-queue", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + [[package]] name = "atomic" version = "0.5.3" @@ -341,9 +403,9 @@ dependencies = [ [[package]] name = "bitflags" -version = "2.12.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84d7ced0ae9557296835c32bf1b1e02b44c746701f898460fb000d7eaa84f00a" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" [[package]] name = "blake2" @@ -494,7 +556,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -508,6 +570,22 @@ dependencies = [ "zeroize", ] +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", + "loom", +] + [[package]] name = "console" version = "0.15.11" @@ -517,7 +595,7 @@ dependencies = [ "encode_unicode", "libc", "once_cell", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -591,6 +669,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crunchy" version = "0.2.4" @@ -704,6 +788,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "diatomic-waker" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c" + [[package]] name = "digest" version = "0.10.7" @@ -829,6 +919,29 @@ dependencies = [ "regex", ] +[[package]] +name = "env_filter" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" +dependencies = [ + "anstream", + "anstyle", + "env_filter 1.0.1", + "jiff", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -842,14 +955,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "loom", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", ] [[package]] name = "fastrand" -version = "2.4.1" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "ff" @@ -989,6 +1124,19 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-lite" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "parking", + "pin-project-lite", +] + [[package]] name = "futures-macro" version = "0.3.31" @@ -1030,6 +1178,21 @@ dependencies = [ "slab", ] +[[package]] +name = "generator" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f04ae4152da20c76fe800fa48659201d5cf627c5149ca0b707b69d7eef6cf9" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows-link 0.2.1", + "windows-result", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1380,6 +1543,12 @@ dependencies = [ "libc", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itertools" version = "0.11.0" @@ -1395,6 +1564,30 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jiff" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a3546dc96b6d42c5f24902af9e2538e82e39ad350b0c766eb3fbf2d8f3d8359" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", +] + +[[package]] +name = "jiff-static" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "jobserver" version = "0.1.33" @@ -1549,9 +1742,31 @@ dependencies = [ [[package]] name = "log" -version = "0.4.27" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "loom" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] [[package]] name = "md-5" @@ -1615,7 +1830,7 @@ checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys", + "windows-sys 0.59.0", ] [[package]] @@ -1638,6 +1853,15 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -1719,6 +1943,18 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "oneshot" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" + [[package]] name = "opaque-debug" version = "0.3.1" @@ -1774,6 +2010,15 @@ dependencies = [ "sha2", ] +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" +dependencies = [ + "loom", +] + [[package]] name = "parking_lot_core" version = "0.9.11" @@ -1806,9 +2051,9 @@ checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "pastey" -version = "0.2.3" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ee67f1008b1ba2321834326597b8e186293b049a023cdef258527550b9935b4" +checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" [[package]] name = "pbkdf2" @@ -1927,6 +2172,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a106d1259c23fac8e543272398ae0e3c0b8d33c88ed73d0cc71b0f1d902618" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.2" @@ -2107,6 +2361,25 @@ dependencies = [ "trait-variant", ] +[[package]] +name = "ql-runtime" +version = "0.1.0" +dependencies = [ + "async-channel", + "bytes", + "diatomic-waker", + "env_logger", + "event-listener", + "futures-lite", + "log", + "loom", + "oneshot", + "ql-fsm", + "ql-rpc", + "ql-wire", + "tokio", +] + [[package]] name = "ql-wire" version = "0.1.0" @@ -2245,9 +2518,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" dependencies = [ "aho-corasick", "memchr", @@ -2257,9 +2530,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" dependencies = [ "aho-corasick", "memchr", @@ -2367,7 +2640,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -2403,6 +2676,12 @@ dependencies = [ "cipher", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -2462,18 +2741,28 @@ checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -2514,6 +2803,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -2687,7 +2985,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -2710,6 +3008,15 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "threadpool" version = "1.8.1" @@ -2777,6 +3084,67 @@ dependencies = [ "mio", "pin-project-lite", "slab", + "tokio-macros", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -2857,6 +3225,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.18.1" @@ -2868,6 +3242,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "version_check" version = "0.9.5" @@ -2987,7 +3367,7 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", - "windows-link", + "windows-link 0.1.3", "windows-result", "windows-strings", ] @@ -3020,13 +3400,19 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + [[package]] name = "windows-result" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -3035,7 +3421,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -3047,6 +3433,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link 0.2.1", +] + [[package]] name = "windows-targets" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index de8c80b..b2492c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "btp", "ql-fsm", "ql-rpc", + "ql-runtime", "ql-wire", "quantum-link-macros", ] diff --git a/ql-runtime/Cargo.toml b/ql-runtime/Cargo.toml new file mode 100644 index 0000000..564cee1 --- /dev/null +++ b/ql-runtime/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "ql-runtime" +version = "0.1.0" +edition = "2021" +description = "QuantumLink async runtime" +license = "Proprietary" + +[features] +default = [] +log = ["dep:log"] +rpc = ["dep:ql-rpc"] + +[dependencies] +async-channel = { version = "2.5" } +bytes = { workspace = true } +diatomic-waker = { version = "0.2.3", default-features = false } +futures-lite = { version = "2.5" } +log = { version = "0.4", optional = true } +oneshot = { version = "0.1.11" } +ql-fsm = { workspace = true } +ql-rpc = { workspace = true, optional = true } +ql-wire = { workspace = true } + +[dev-dependencies] +env_logger = "0.11" +log = "0.4" +ql-wire = { workspace = true, features = ["test-utils"] } +tokio = { version = "1.44", features = ["macros", "rt", "time", "sync"] } + +[target.'cfg(loom)'.dev-dependencies] +event-listener = { version = "5.4", features = ["loom"] } +loom = "0.7" + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(loom)'] } diff --git a/ql-runtime/src/command.rs b/ql-runtime/src/command.rs new file mode 100644 index 0000000..4a47a45 --- /dev/null +++ b/ql-runtime/src/command.rs @@ -0,0 +1,57 @@ +use ql_fsm::{NoSessionError, PairingInvite}; +use ql_wire::{ + CloseTarget, PairingToken, PeerBundle, RouteId, SessionCloseCode, StreamCloseCode, StreamId, +}; + +use crate::{StreamReader, StreamWriter}; + +pub enum Command { + BindPeer { + peer: PeerBundle, + }, + Connect, + ArmPairing { + token: PairingToken, + }, + DisarmPairing, + StartPairing { + invite: PairingInvite, + }, + OpenStream { + route_id: RouteId, + start: oneshot::Sender>, + }, + PollInbound { + stream_id: StreamId, + }, + PollStream { + stream_id: StreamId, + }, + CloseSession { + code: SessionCloseCode, + }, + Unpair, + CloseStream { + stream_id: StreamId, + target: CloseTarget, + code: StreamCloseCode, + }, +} + +impl Command { + pub fn kind(&self) -> &'static str { + match self { + Self::BindPeer { .. } => "BindPeer", + Self::Connect => "Connect", + Self::ArmPairing { .. } => "ArmPairing", + Self::DisarmPairing => "DisarmPairing", + Self::StartPairing { .. } => "StartPairing", + Self::OpenStream { .. } => "OpenStream", + Self::PollInbound { .. } => "PollInbound", + Self::PollStream { .. } => "PollStream", + Self::CloseSession { .. } => "CloseSession", + Self::Unpair => "Unpair", + Self::CloseStream { .. } => "CloseStream", + } + } +} diff --git a/ql-runtime/src/driver/mod.rs b/ql-runtime/src/driver/mod.rs new file mode 100644 index 0000000..35de1bf --- /dev/null +++ b/ql-runtime/src/driver/mod.rs @@ -0,0 +1,609 @@ +mod state; +#[cfg(test)] +mod test; + +use std::{ + collections::{ + hash_map::{Entry, OccupiedEntry}, + HashMap, + }, + future::Future, + pin::{pin, Pin}, + task::{Context, Poll}, + time::Instant, +}; + +use async_channel::Recv; +use futures_lite::future::{poll_fn, yield_now}; +use ql_fsm::{Event, QlFsm, WriteId}; +use ql_wire::{CloseTarget, StreamCloseCode, StreamId}; + +use self::state::{DriverState, DriverStreamIo, InboundIo, InboundWriteResult, OutboundIo}; +use crate::{ + command::Command, + handle::QlStream, + io, log, + platform::{QlInbound, QlPlatform, QlTimer}, + QlStreamError, Runtime, RuntimeHandle, +}; + +impl Runtime

{ + #[allow(clippy::future_not_send)] + pub async fn run(self) { + let Self { + identity, + mut platform, + config, + rx, + tx, + } = self; + + let mut fsm = QlFsm::new(config.fsm, identity, Instant::now()); + + let mut state = DriverState { + streams: HashMap::new(), + runtime_tx: tx, + max_concurrent_message_writes: config.max_concurrent_message_writes, + }; + + let mut in_flight = Vec::new(); + let timer = platform.timer(); + let mut timer = pin!(timer); + let inbound = platform.inbound(); + let mut inbound = pin!(inbound); + let recv_future = rx.recv(); + let mut recv_future = Some(pin!(recv_future)); + let mut poll_cursor = 0usize; + + loop { + state.drain_fsm_events(&mut fsm, &platform); + if state.fill_write_slots(&mut fsm, &platform, &mut in_flight) { + state.drain_fsm_events(&mut fsm, &platform); + } + timer.as_mut().set_deadline(fsm.next_deadline()); + + let step = poll_fn(|cx| { + next_step( + cx, + recv_future.as_mut().map(|future| future.as_mut()), + inbound.as_mut(), + timer.as_mut(), + &mut in_flight, + poll_cursor, + ) + }) + .await; + poll_cursor = (poll_cursor + 1) % STEP_COUNT; + + match step { + DriverStep::Command(command) => { + log::trace!("processing command: kind={}", command.kind()); + state.drive_command(&mut fsm, command, &platform); + } + DriverStep::Inbound(bytes) => { + log::trace!("received transport frame: len={}", bytes.len()); + if let Err(e) = fsm.receive(Instant::now(), bytes, &platform) { + log::info!("receive rejected frame: error={e:?}"); + platform.handle_recv_error(e); + } + } + DriverStep::WriteCompleted { index, success } => { + let write = in_flight.swap_remove(index); + let write_id = write.write_id; + log::trace!( + "write completed: success={success} index={index} write_id={write_id:?}", + ); + DriverState::drive_write_completed(&mut fsm, write_id, success); + yield_now().await; + } + DriverStep::TimerExpired => { + log::trace!("timer expired"); + fsm.on_timer(Instant::now()); + } + DriverStep::Closed => { + log::debug!( + "command channel closed: in_flight_writes={}", + in_flight.len() + ); + recv_future = None; + if in_flight.is_empty() && !fsm.has_shutdown_work() { + break; + } + } + } + } + log::info!("runtime stopped"); + } +} + +struct InFlightWrite { + write_id: Option, + future: F, +} + +enum DriverStep { + Command(Command), + Inbound(Vec), + WriteCompleted { index: usize, success: bool }, + TimerExpired, + Closed, +} + +const STEP_COUNT: usize = 4; + +fn next_step( + cx: &mut Context<'_>, + mut recv_future: Option>>, + mut inbound: Pin<&mut I>, + mut timer: Pin<&mut T>, + in_flight: &mut [InFlightWrite], + start: usize, +) -> Poll +where + T: QlTimer, + F: Future + Unpin, + I: QlInbound, +{ + for offset in 0..STEP_COUNT { + let step = (start + offset) % STEP_COUNT; + let poll = match step { + 0 => recv_future.as_mut().map_or(Poll::Pending, |recv_future| { + recv_future + .as_mut() + .poll(cx) + .map(|res| res.map_or(DriverStep::Closed, DriverStep::Command)) + }), + 1 => inbound.as_mut().poll_recv(cx).map(DriverStep::Inbound), + 2 => { + for (index, write) in in_flight.iter_mut().enumerate() { + if let Poll::Ready(success) = Pin::new(&mut write.future).poll(cx) { + return Poll::Ready(DriverStep::WriteCompleted { index, success }); + } + } + Poll::Pending + } + 3 => timer + .as_mut() + .poll_wait(cx) + .map(|()| DriverStep::TimerExpired), + _ => unreachable!(), + }; + if poll.is_ready() { + return poll; + } + } + + Poll::Pending +} + +impl DriverState { + #[allow(clippy::too_many_lines)] + fn drive_command(&mut self, fsm: &mut QlFsm, command: Command, platform: &P) { + match command { + Command::BindPeer { peer } => { + log::info!("binding peer"); + fsm.bind_peer(peer); + } + Command::Connect => { + log::info!("starting IK connect"); + if fsm.connect_ik(Instant::now(), platform).is_err() { + log::warn!("IK connect ignored: no bound peer"); + } + } + Command::ArmPairing { token } => { + log::info!("arming inbound pairing"); + fsm.arm_pairing(token); + } + Command::DisarmPairing => { + log::info!("disarming inbound pairing"); + fsm.disarm_pairing(); + } + Command::StartPairing { invite } => { + log::info!(" starting XX pairing"); + fsm.connect_xx(Instant::now(), invite, platform); + } + Command::CloseSession { code } => { + log::info!("closing session: code={code:?}"); + fsm.close_session(code); + } + Command::Unpair => { + log::info!("unpairing peer"); + fsm.unpair(); + } + Command::OpenStream { route_id, start } => { + log::info!("open stream requested: route_id={route_id}"); + let Some(runtime_tx) = self.runtime_tx.upgrade() else { + log::warn!("open stream aborted: runtime channel unavailable"); + let _ = start.send(Err(ql_fsm::NoSessionError)); + return; + }; + + let mut stream_ops = match fsm.open_stream(route_id) { + Ok(stream_ops) => stream_ops, + Err(error) => { + log::warn!("open stream failed: route_id={route_id}"); + let _ = start.send(Err(error)); + return; + } + }; + let stream_id = stream_ops.stream_id(); + log::info!("open stream allocated: route_id={route_id} stream_id={stream_id}"); + let (reader, writer, reader_io, writer_io) = io::new_stream( + stream_id, + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + self.streams.insert( + stream_id, + DriverStreamIo::new( + true, + Some(OutboundIo::new(writer_io)), + Some(InboundIo::new(reader_io)), + ), + ); + if start.send(Ok((stream_id, reader, writer))).is_err() { + log::warn!("open stream cancelled before delivery: stream_id={stream_id}"); + if let Some(stream) = self.streams.get_mut(&stream_id) { + stream.inbound_close(); + stream.outbound_close(); + } + stream_ops.close(CloseTarget::Both, StreamCloseCode::CANCELLED); + drop(stream_ops); + return; + } + drop(stream_ops); + self.poll_stream(fsm, stream_id); + } + Command::PollInbound { stream_id } => { + log::trace!("poll inbound requested: stream_id={stream_id}"); + self.handle_inbound_readable(fsm, stream_id); + } + Command::PollStream { stream_id } => { + log::trace!("poll stream requested: stream_id={stream_id}"); + self.poll_stream(fsm, stream_id); + } + Command::CloseStream { + stream_id, + target, + code, + } => { + log::debug!( + "close stream command: stream_id={stream_id} target={target:?} code={code:?}" + ); + if let Entry::Occupied(mut entry) = self.streams.entry(stream_id) { + let stream = entry.get_mut(); + if target == CloseTarget::Both || target == stream.inbound_target() { + stream.inbound_close(); + } + if target == CloseTarget::Both || target == stream.outbound_target() { + stream.outbound_close(); + } + Self::try_reap_stream(entry); + } + if let Ok(mut stream) = fsm.stream(stream_id) { + stream.close(target, code); + } + } + } + } + + fn drive_write_completed(fsm: &mut QlFsm, session_write_id: Option, success: bool) { + if let Some(write_id) = session_write_id { + fsm.complete_write(Instant::now(), write_id, success); + } + } + + fn drain_fsm_events(&mut self, fsm: &mut QlFsm, platform: &P) { + while let Some(event) = fsm.poll_event() { + log::trace!("polled FSM event: event={event:?}"); + match event { + Event::NewPeer => { + log::info!("new ql peer"); + if let Some(peer) = fsm.peer().cloned() { + platform.persist_peer(peer); + } + } + Event::PeerStatusChanged(status) => { + let peer = fsm.peer().map(|peer| peer.qid); + log::info!("peer status changed: peer={peer:?} status={status:?}"); + if status == ql_fsm::PeerStatus::Unpaired { + for (_, mut stream) in self.streams.drain() { + stream.fail_all(); + } + } + platform.handle_peer_status(peer, status); + } + Event::Opened { + stream_id, + route_id, + } => { + log::info!("inbound stream opened: stream_id={stream_id} route_id={route_id}"); + self.handle_opened_stream(fsm, platform, stream_id, route_id); + } + Event::Readable(stream_id) => { + log::trace!("stream readable: stream_id={stream_id}"); + self.handle_inbound_readable(fsm, stream_id); + } + Event::Writable(stream_id) => { + log::trace!("stream writable: stream_id={stream_id}"); + self.poll_stream(fsm, stream_id); + } + Event::Finished(stream_id) => { + log::info!("peer finished stream writes: stream_id={stream_id}"); + self.handle_inbound_finished(stream_id); + } + Event::OutboundFinished(stream_id) => { + log::info!("outbound finish acknowledged: stream_id={stream_id}"); + self.handle_outbound_finished(stream_id); + } + Event::Closed(frame) => { + self.handle_closed_stream(&frame); + } + Event::WritableClosed(frame) => { + self.handle_writable_closed(&frame); + } + Event::SessionClosed(close) => { + log::info!("session closed: frame={close:?}"); + for (_, mut stream) in self.streams.drain() { + stream.fail_all(); + } + } + } + } + } + + fn handle_opened_stream( + &mut self, + fsm: &mut QlFsm, + platform: &P, + stream_id: StreamId, + route_id: ql_wire::RouteId, + ) { + let Some(runtime_tx) = self.runtime_tx.upgrade() else { + log::warn!( + "dropping inbound stream because handle channel is unavailable: stream_id={stream_id}" + ); + if let Ok(mut stream) = fsm.stream(stream_id) { + stream.close(CloseTarget::Both, StreamCloseCode::CANCELLED); + } + return; + }; + + let (reader, writer, reader_io, writer_io) = io::new_stream( + stream_id, + CloseTarget::Origin, + CloseTarget::Return, + RuntimeHandle::new(runtime_tx), + ); + + self.streams.insert( + stream_id, + DriverStreamIo::new( + false, + Some(OutboundIo::new(writer_io)), + Some(InboundIo::new(reader_io)), + ), + ); + + log::info!( + "delivering inbound stream to platform: stream_id={stream_id} route_id={route_id}" + ); + platform.handle_inbound(QlStream { + stream_id, + route_id, + writer, + reader, + }); + } + + fn handle_inbound_readable(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + let Ok(mut stream_ops) = fsm.stream(stream_id) else { + log::info!("inbound readable for unknown stream: stream_id={stream_id}"); + return; + }; + let readable = stream_ops.readable_bytes(); + if readable == 0 { + return; + } + log::trace!("draining inbound bytes: stream_id={stream_id} readable={readable}"); + let mut accepted = 0usize; + let mut peer_closed = false; + let target; + { + let Some(stream) = self.streams.get_mut(&stream_id) else { + return; + }; + target = stream.inbound_target(); + for chunk in stream_ops.read() { + if chunk.is_empty() { + continue; + } + match stream.inbound_try_write(chunk) { + InboundWriteResult::Accepted(n) => { + accepted += n; + } + InboundWriteResult::Full => { + log::debug!( + "inbound backpressure: stream_id={stream_id} accepted={accepted}" + ); + break; + } + InboundWriteResult::Closed => { + log::warn!( + "inbound consumer closed; sending CANCELLED: stream_id={stream_id} target={target:?}" + ); + peer_closed = true; + break; + } + } + } + } + + if accepted > 0 { + log::trace!("committed inbound bytes: stream_id={stream_id:?} accepted={accepted}"); + stream_ops.commit_read(accepted).unwrap(); + } + if peer_closed { + stream_ops.close(target, StreamCloseCode::CANCELLED); + if let Entry::Occupied(entry) = self.streams.entry(stream_id) { + Self::try_reap_stream(entry); + } + } + + drop(stream_ops); + } + + fn handle_inbound_finished(&mut self, stream_id: StreamId) { + log::info!("inbound finished event: stream_id={stream_id}"); + let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { + return; + }; + log::info!("delivering clean inbound finish: stream_id={stream_id}"); + entry.get_mut().inbound_finish(); + Self::try_reap_stream(entry); + } + + fn handle_closed_stream(&mut self, frame: &ql_wire::StreamClose) { + log::info!( + "inbound close frame: stream_id={} target={:?} code={}", + frame.stream_id, + frame.target, + frame.code + ); + let Entry::Occupied(mut entry) = self.streams.entry(frame.stream_id) else { + return; + }; + let stream = entry.get_mut(); + + if frame.target == CloseTarget::Both || frame.target == stream.inbound_target() { + stream.inbound_fail(QlStreamError::StreamClosed { code: frame.code }); + } + if frame.target == CloseTarget::Both || frame.target == stream.outbound_target() { + stream.outbound_fail(QlStreamError::StreamClosed { code: frame.code }); + } + Self::try_reap_stream(entry); + } + + fn handle_writable_closed(&mut self, frame: &ql_wire::StreamClose) { + log::info!( + "writable close frame: stream_id={} target={:?} code={}", + frame.stream_id, + frame.target, + frame.code + ); + let Entry::Occupied(mut entry) = self.streams.entry(frame.stream_id) else { + return; + }; + let stream = entry.get_mut(); + stream.outbound_fail(QlStreamError::StreamClosed { code: frame.code }); + Self::try_reap_stream(entry); + } + + fn handle_outbound_finished(&mut self, stream_id: StreamId) { + log::info!("outbound finish acknowledged: stream_id={stream_id}"); + let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { + return; + }; + let stream = entry.get_mut(); + if !stream.outbound_finish_pending() { + return; + } + stream.outbound_finish(); + Self::try_reap_stream(entry); + } + + fn fill_write_slots<'a, P: QlPlatform + 'a>( + &self, + fsm: &mut QlFsm, + platform: &'a P, + in_flight: &mut Vec>>, + ) -> bool { + let mut filled = false; + while in_flight.len() < self.max_concurrent_message_writes { + let Some(write) = fsm.take_next_write(Instant::now(), platform) else { + break; + }; + filled = true; + log::trace!( + "queueing transport write: bytes={} write_id={:?}", + write.record.len(), + write.write_id + ); + in_flight.push(InFlightWrite { + write_id: write.write_id, + future: platform.write_message(write.record), + }); + } + filled + } + + fn poll_stream(&mut self, fsm: &mut QlFsm, stream_id: StreamId) { + let Entry::Occupied(mut entry) = self.streams.entry(stream_id) else { + return; + }; + let stream = entry.get_mut(); + let Some(writer_io) = stream.outbound_writer_mut() else { + log::trace!("poll stream skipped without outbound writer: stream_id={stream_id}"); + return; + }; + + if writer_io.is_finished() { + log::info!("observed outbound writer finished before write: stream_id={stream_id}"); + if let Ok(mut stream_ops) = fsm.stream(stream_id) { + if let Some(writer) = stream_ops.writer() { + writer.finish(); + } + } + stream.outbound_queue_finish(); + if stream.is_closed() { + entry.remove(); + } + return; + } + + let Ok(mut stream_ops) = fsm.stream(stream_id) else { + return; + }; + let Some(mut writer) = stream_ops.writer() else { + log::trace!("poll stream skipped without session writer: stream_id={stream_id}"); + return; + }; + + loop { + let capacity = writer.capacity(); + log::trace!("stream write capacity: stream_id={stream_id} capacity={capacity}"); + if capacity == 0 { + break; + } + + let Ok(mut bytes) = writer_io.try_read(capacity) else { + break; + }; + if bytes.is_empty() { + break; + } + + log::trace!( + "writing stream bytes: stream_id={stream_id} len={}", + bytes.len() + ); + let _ = writer.write(&mut bytes); + } + + if writer_io.is_finished() { + log::info!("observed outbound writer finished after write: stream_id={stream_id}"); + writer.finish(); + stream.outbound_queue_finish(); + if stream.is_closed() { + entry.remove(); + } + } + } + + fn try_reap_stream(entry: OccupiedEntry<'_, StreamId, DriverStreamIo>) { + if entry.get().is_closed() { + entry.remove(); + } + } +} diff --git a/ql-runtime/src/driver/state.rs b/ql-runtime/src/driver/state.rs new file mode 100644 index 0000000..0ff8eca --- /dev/null +++ b/ql-runtime/src/driver/state.rs @@ -0,0 +1,165 @@ +use std::collections::HashMap; + +use bytes::Bytes; +use ql_wire::{CloseTarget, StreamId}; + +use crate::{ + command::Command, + io::{PushError, Rx, Tx}, + QlStreamError, +}; + +pub struct DriverState { + pub streams: HashMap, + pub runtime_tx: async_channel::WeakSender, + pub max_concurrent_message_writes: usize, +} + +pub struct DriverStreamIo { + is_initiator: bool, + outbound: Option, + inbound: Option, +} + +impl DriverStreamIo { + pub fn new( + is_initiator: bool, + outbound: Option, + inbound: Option, + ) -> Self { + Self { + is_initiator, + outbound, + inbound, + } + } + + pub fn inbound_target(&self) -> CloseTarget { + if self.is_initiator { + CloseTarget::Return + } else { + CloseTarget::Origin + } + } + + pub fn outbound_target(&self) -> CloseTarget { + if self.is_initiator { + CloseTarget::Origin + } else { + CloseTarget::Return + } + } + + pub fn fail_all(&mut self) { + self.inbound_fail(QlStreamError::NoSession); + self.outbound_fail(QlStreamError::NoSession); + } + + pub fn is_closed(&self) -> bool { + self.outbound.is_none() && self.inbound.is_none() + } + + pub fn outbound_close(&mut self) { + self.outbound = None; + } + + pub fn outbound_finish(&mut self) { + if let Some(outbound) = self.outbound.take() { + outbound.tx.finish(); + } + } + + pub fn outbound_fail(&mut self, error: QlStreamError) { + if let Some(outbound) = self.outbound.take() { + let _ = outbound.tx.fail(error); + } + } + + pub fn outbound_writer_mut(&mut self) -> Option<&mut OutboundIo> { + self.outbound.as_mut() + } + + pub fn outbound_queue_finish(&mut self) { + if let Some(outbound) = self.outbound.as_mut() { + outbound.finish_pending = true; + } + } + + pub fn outbound_finish_pending(&self) -> bool { + self.outbound + .as_ref() + .is_some_and(|outbound| outbound.finish_pending) + } + + pub fn inbound_close(&mut self) { + self.inbound = None; + } + + pub fn inbound_try_write(&mut self, bytes: Bytes) -> InboundWriteResult { + let Some(inbound) = self.inbound.as_mut() else { + return InboundWriteResult::Closed; + }; + + let len = bytes.len(); + match inbound.rx.try_write(bytes) { + Ok(()) => InboundWriteResult::Accepted(len), + Err(PushError::Full(_)) => InboundWriteResult::Full, + Err(PushError::Closed(_)) => { + self.inbound = None; + InboundWriteResult::Closed + } + } + } + + pub fn inbound_finish(&mut self) { + if let Some(inbound) = self.inbound.take() { + inbound.rx.finish(); + } + } + + pub fn inbound_fail(&mut self, error: QlStreamError) { + if let Some(inbound) = self.inbound.take() { + inbound.rx.fail(error); + } + } +} + +pub struct OutboundIo { + tx: Tx, + pending: Bytes, + finish_pending: bool, +} + +impl OutboundIo { + pub fn new(tx: Tx) -> Self { + Self { + tx, + pending: Bytes::new(), + finish_pending: false, + } + } + + pub fn is_finished(&self) -> bool { + self.pending.is_empty() && self.tx.is_finished() + } + + pub fn try_read(&mut self, max_len: usize) -> Result { + self.tx.try_read(&mut self.pending, max_len) + } +} + +pub struct InboundIo { + rx: Rx, +} + +pub enum InboundWriteResult { + Accepted(usize), + Full, + Closed, +} + +impl InboundIo { + pub fn new(rx: Rx) -> Self { + Self { rx } + } +} diff --git a/ql-runtime/src/driver/test.rs b/ql-runtime/src/driver/test.rs new file mode 100644 index 0000000..af4ab63 --- /dev/null +++ b/ql-runtime/src/driver/test.rs @@ -0,0 +1,207 @@ +use ql_wire::{generate_identity, NoopCrypto, PeerBundle, SoftwareCrypto, StreamClose, QID}; + +use super::*; +use crate::{ + driver::state::{InboundIo, OutboundIo}, + io, + platform::QlInbound, +}; + +pub struct NoopTimer; +pub struct NoopInbound; + +impl crate::platform::QlTimer for NoopTimer { + fn set_deadline(self: Pin<&mut Self>, _deadline: Option) {} + + fn poll_wait(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> { + Poll::Pending + } +} + +impl QlPlatform for NoopCrypto { + type Timer = NoopTimer; + type WriteMessageFut<'a> = std::future::Ready; + type Inbound = NoopInbound; + + fn write_message(&self, _message: Vec) -> Self::WriteMessageFut<'_> { + std::future::ready(true) + } + + fn inbound(&mut self) -> Self::Inbound { + NoopInbound + } + + fn timer(&self) -> Self::Timer { + NoopTimer + } + + fn persist_peer(&self, _peer: PeerBundle) {} + + fn handle_peer_status(&self, _peer: Option, _status: ql_fsm::PeerStatus) {} + + fn handle_inbound(&self, _event: QlStream) {} +} + +impl QlInbound for NoopInbound { + fn poll_recv(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } +} + +fn new_driver_state() -> (DriverState, QlFsm) { + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + ( + DriverState { + streams: HashMap::new(), + runtime_tx: runtime_tx.downgrade(), + max_concurrent_message_writes: 1, + }, + QlFsm::new( + ql_fsm::QlFsmConfig::default(), + generate_identity(&SoftwareCrypto, "driver").unwrap(), + Instant::now(), + ), + ) +} + +fn new_inbound_io(capacity: usize) -> InboundIo { + let _ = capacity; + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let stream = io::new_stream( + StreamId(99u32.into()), + CloseTarget::Origin, + CloseTarget::Return, + RuntimeHandle::new(runtime_tx), + ); + let (_, _, reader_io, _) = stream; + InboundIo::new(reader_io) +} + +fn new_outbound_io() -> OutboundIo { + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let stream = io::new_stream( + StreamId(100u32.into()), + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + let (_, _, _, writer_io) = stream; + OutboundIo::new(writer_io) +} + +#[test] +fn handle_inbound_finished_reaps_closed_initiator_stream() { + let (mut state, _fsm) = new_driver_state(); + let stream_id = StreamId(1u32.into()); + + state.streams.insert( + stream_id, + DriverStreamIo::new(true, None, Some(new_inbound_io(1))), + ); + + state.handle_inbound_finished(stream_id); + + assert!(!state.streams.contains_key(&stream_id)); +} + +#[test] +fn handle_closed_stream_reaps_when_both_halves_close() { + let (mut state, _fsm) = new_driver_state(); + let stream_id = StreamId(1u32.into()); + + state.streams.insert( + stream_id, + DriverStreamIo::new(false, Some(new_outbound_io()), Some(new_inbound_io(1))), + ); + + state.handle_closed_stream(&StreamClose { + stream_id, + target: CloseTarget::Both, + code: StreamCloseCode::CANCELLED, + }); + + assert!(!state.streams.contains_key(&stream_id)); +} + +#[test] +fn poll_stream_keeps_outbound_pending_after_local_finish_when_inbound_is_closed() { + let (mut state, mut fsm) = new_driver_state(); + let stream_id = StreamId(1u32.into()); + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let (_, mut writer, _, writer_io) = io::new_stream( + stream_id, + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + writer.queue_finish(); + state.streams.insert( + stream_id, + DriverStreamIo::new(true, Some(OutboundIo::new(writer_io)), None), + ); + + state.poll_stream(&mut fsm, stream_id); + + let stream = state.streams.get(&stream_id).unwrap(); + assert!(stream.outbound_finish_pending()); + assert!(!stream.is_closed()); +} + +#[test] +fn local_close_command_reaps_when_other_half_is_already_closed() { + let (mut state, mut fsm) = new_driver_state(); + let stream_id = StreamId(1u32.into()); + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let (_, _, _, writer_io) = io::new_stream( + stream_id, + CloseTarget::Return, + CloseTarget::Origin, + RuntimeHandle::new(runtime_tx), + ); + + state.streams.insert( + stream_id, + DriverStreamIo::new(true, Some(OutboundIo::new(writer_io)), None), + ); + + state.drive_command( + &mut fsm, + Command::CloseStream { + stream_id, + target: CloseTarget::Origin, + code: StreamCloseCode::CANCELLED, + }, + &NoopCrypto, + ); + + assert!(!state.streams.contains_key(&stream_id)); +} + +#[test] +fn unpaired_status_fails_and_reaps_all_streams() { + let (mut state, mut fsm) = new_driver_state(); + let peer = generate_identity(&SoftwareCrypto, "peer").unwrap().bundle(); + let stream_id = StreamId(1u32.into()); + let (runtime_tx, _runtime_rx) = async_channel::unbounded(); + let (_, _, reader_io, writer_io) = io::new_stream( + stream_id, + CloseTarget::Origin, + CloseTarget::Return, + RuntimeHandle::new(runtime_tx), + ); + + state.streams.insert( + stream_id, + DriverStreamIo::new( + false, + Some(OutboundIo::new(writer_io)), + Some(InboundIo::new(reader_io)), + ), + ); + fsm.bind_peer(peer); + fsm.unpair(); + + state.drain_fsm_events(&mut fsm, &NoopCrypto); + + assert!(state.streams.is_empty()); +} diff --git a/ql-runtime/src/error.rs b/ql-runtime/src/error.rs new file mode 100644 index 0000000..5b74bcf --- /dev/null +++ b/ql-runtime/src/error.rs @@ -0,0 +1,18 @@ +use ql_wire::StreamCloseCode; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum QlStreamError { + StreamClosed { code: StreamCloseCode }, + NoSession, +} + +impl std::fmt::Display for QlStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::StreamClosed { code } => write!(f, "stream closed {code:?}"), + Self::NoSession => f.write_str("no session"), + } + } +} + +impl std::error::Error for QlStreamError {} diff --git a/ql-runtime/src/handle/mod.rs b/ql-runtime/src/handle/mod.rs new file mode 100644 index 0000000..1782c17 --- /dev/null +++ b/ql-runtime/src/handle/mod.rs @@ -0,0 +1,96 @@ +use ql_fsm::{NoSessionError, PairingInvite}; +use ql_wire::{PairingToken, PeerBundle, RouteId, SessionCloseCode, StreamId}; + +use crate::command::Command; +pub use crate::io::{StreamReader, StreamWriter}; + +#[derive(Debug)] +pub struct QlStream { + pub stream_id: StreamId, + pub route_id: RouteId, + pub writer: StreamWriter, + pub reader: StreamReader, +} + +#[derive(Clone)] +pub struct RuntimeHandle { + tx: async_channel::Sender, +} + +impl RuntimeHandle { + /// binds the remote peer + pub fn bind_peer(&self, peer: PeerBundle) { + self.send(Command::BindPeer { peer }); + } + + /// starts an IK handshake with the bound peer + pub fn connect(&self) { + self.send(Command::Connect); + } + + /// arms acceptance of inbound xx pairings for a single token + pub fn arm_pairing(&self, token: PairingToken) { + self.send(Command::ArmPairing { token }); + } + + /// disarms inbound xx pairing + pub fn disarm_pairing(&self) { + self.send(Command::DisarmPairing); + } + + /// starts an outbound xx handshake using an out-of-band pairing invite + pub fn start_pairing(&self, invite: PairingInvite) { + self.send(Command::StartPairing { invite }); + } + + /// closes the current encrypted session + pub fn close_session(&self, code: SessionCloseCode) { + self.send(Command::CloseSession { code }); + } + + /// forgets the currently bound peer and initiates session unpairing if connected + pub fn unpair(&self) { + self.send(Command::Unpair); + } + + /// opens a new stream on the active encrypted session + pub async fn open_stream(&self, route_id: RouteId) -> Result { + let (start_tx, start_rx) = oneshot::channel(); + + self.send(Command::OpenStream { + route_id, + start: start_tx, + }); + + // runtime cannot be shutdown while we have a handle + let (stream_id, reader, writer) = start_rx.await.unwrap()?; + + Ok(QlStream { + stream_id, + route_id, + writer, + reader, + }) + } + + #[cfg(feature = "rpc")] + pub fn rpc(&self) -> crate::rpc::RpcHandle { + crate::rpc::RpcHandle::new(self.clone()) + } +} + +impl RuntimeHandle { + pub(crate) fn new(tx: async_channel::Sender) -> Self { + Self { tx } + } + + #[inline] + #[track_caller] + pub(crate) fn send(&self, cmd: Command) { + self.tx.try_send(cmd).expect("runtime is alive"); + } + + pub(crate) fn try_send(&self, cmd: Command) -> bool { + self.tx.try_send(cmd).is_ok() + } +} diff --git a/ql-runtime/src/io/inner.rs b/ql-runtime/src/io/inner.rs new file mode 100644 index 0000000..64df6ce --- /dev/null +++ b/ql-runtime/src/io/inner.rs @@ -0,0 +1,643 @@ +//! per-stream shared io state +//! each lane has one slot and one waker +//! the low slot bits belong to `slot.rs` and the higher bits here carry lane-specific flags + +use std::task::Waker; + +use bytes::Bytes; +use diatomic_waker::DiatomicWaker; +use ql_wire::StreamId; + +use super::{ + slot::{PopError, PushError, Slot}, + sync::Arc, +}; +use crate::QlStreamError; + +pub(super) fn new(stream_id: StreamId) -> Arc { + Arc::new(Inner { + stream_id, + rx: RxInner::new(), + tx: TxInner::new(), + }) +} + +pub(super) struct Inner { + pub(super) stream_id: StreamId, + pub(super) rx: RxInner, + pub(super) tx: TxInner, +} + +pub enum Item { + Chunk(Bytes), + Error(QlStreamError), +} + +#[derive(Debug, PartialEq, Eq)] +pub struct ForcePushError(pub T); + +/// reader-lane shared state +pub struct RxInner { + slot: Slot, + changed: DiatomicWaker, +} + +impl RxInner { + const FINISHED: usize = 1 << 2; + + fn new() -> Self { + Self { + slot: Slot::new(), + changed: DiatomicWaker::new(), + } + } + + pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { + try_write_chunk(&self.slot, &self.changed, bytes, Self::FINISHED) + } + + /// marks clean reader eof + pub fn finish(&self) { + if self.slot.fetch_or(Self::FINISHED) & Self::FINISHED == 0 { + self.changed.notify(); + } + } + + /// stores a terminal reader error + pub fn fail(&self, error: QlStreamError) -> Option { + let displaced = self.slot.force_push(Item::Error(error)); + self.changed.notify(); + displaced_bytes(displaced) + } + + pub fn load_state(&self) -> usize { + self.slot.load_state() + } + + pub fn is_finished(state: usize) -> bool { + state & Self::FINISHED != 0 + } + + pub fn pop(&self) -> Result { + pop_item(&self.slot, &self.changed) + } + + /// registers the sole reader-lane waiter + pub fn register_waiter(&self, waker: &Waker) { + // Safety: StreamReader is the only reader-lane registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.register(waker) }; + } + + /// unregisters the sole reader-lane waiter + pub fn unregister_waiter(&self) { + // Safety: StreamReader is the only reader-lane registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.unregister() }; + } +} + +/// writer-lane shared state +/// +/// finish and fail race to establish the terminal result +/// terminal errors are stored in the slot +pub struct TxInner { + slot: Slot, + changed: DiatomicWaker, +} + +impl TxInner { + const FINISH_REQUESTED: usize = 1 << 2; + const TERMINAL_READY: usize = 1 << 3; + const TERMINAL_OK: usize = 1 << 4; + + fn new() -> Self { + Self { + slot: Slot::new(), + changed: DiatomicWaker::new(), + } + } + + pub fn load_state(&self) -> usize { + self.slot.load_state() + } + + pub fn finish_requested(state: usize) -> bool { + state & Self::FINISH_REQUESTED != 0 + } + + pub fn terminal_ready(state: usize) -> bool { + state & Self::TERMINAL_READY != 0 + } + + pub fn terminal_ok(state: usize) -> bool { + state & Self::TERMINAL_OK != 0 + } + + pub fn try_write(&self, bytes: Bytes) -> Result<(), PushError> { + try_write_chunk( + &self.slot, + &self.changed, + bytes, + Self::FINISH_REQUESTED | Self::TERMINAL_READY, + ) + } + + /// prevents future chunk writes once observed + pub fn request_finish(&self) { + if self.slot.fetch_or(Self::FINISH_REQUESTED) & Self::FINISH_REQUESTED == 0 { + self.changed.notify(); + } + } + + /// commits a clean writer eof + pub fn finish(&self) { + let mut state = self.slot.load_state(); + loop { + if Self::terminal_ready(state) { + return; + } + + let new_state = state | Self::TERMINAL_READY | Self::TERMINAL_OK; + match self.slot.compare_exchange(state, new_state) { + Ok(()) => { + self.changed.notify(); + return; + } + Err(actual) => state = actual, + } + } + } + + /// stores a terminal writer error + /// futures calls will have no effect + pub fn fail( + &self, + error: QlStreamError, + ) -> Result, ForcePushError> { + let mut state = self.slot.load_state(); + loop { + if Self::terminal_ready(state) { + return Err(ForcePushError(error)); + } + + let new_state = state | Self::TERMINAL_READY; + match self.slot.compare_exchange(state, new_state) { + Ok(()) => break, + Err(actual) => state = actual, + } + } + + let displaced = self.slot.force_push(Item::Error(error)); + self.changed.notify(); + Ok(displaced_bytes(displaced)) + } + + pub fn pop(&self) -> Result { + pop_item(&self.slot, &self.changed) + } + + /// registers the sole writer-lane waiter + pub fn register_waiter(&self, waker: &Waker) { + // Safety: StreamWriter is the only writer-lane registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.register(waker) }; + } + + /// unregisters the sole writer-lane waiter + pub fn unregister_waiter(&self) { + // Safety: StreamWriter is the only writer-lane registrar for this + // shared state, so register/unregister never run concurrently. + unsafe { self.changed.unregister() }; + } + + /// returns true once finish was requested and buffered data is drained + pub fn is_finished(&self) -> bool { + let state = self.load_state(); + Self::finish_requested(state) && Slot::::is_empty_state(state) + } + + pub fn try_read(&self, pending: &mut Bytes, max_len: usize) -> Result { + if !pending.is_empty() { + return Ok(if pending.len() <= max_len { + std::mem::take(pending) + } else { + pending.split_to(max_len) + }); + } + + let state = self.load_state(); + if Self::terminal_ready(state) { + return Err(()); + } + + match self.pop() { + Ok(Item::Chunk(mut bytes)) => { + if bytes.len() <= max_len { + Ok(bytes) + } else { + let head = bytes.split_to(max_len); + *pending = bytes; + Ok(head) + } + } + Ok(Item::Error(_)) => Err(()), + Err(PopError) => Ok(Bytes::new()), + } + } +} + +#[inline] +fn try_write_chunk( + slot: &Slot, + changed: &DiatomicWaker, + bytes: Bytes, + closed_mask: usize, +) -> Result<(), PushError> { + match slot.try_push(Item::Chunk(bytes), closed_mask) { + Ok(()) => { + changed.notify(); + Ok(()) + } + Err(PushError::Closed(Item::Chunk(bytes))) => Err(PushError::Closed(bytes)), + Err(PushError::Full(Item::Chunk(bytes))) => Err(PushError::Full(bytes)), + Err(PushError::Closed(Item::Error(_)) | PushError::Full(Item::Error(_))) => { + unreachable!("chunk write cannot recover an error payload") + } + } +} + +#[inline] +fn displaced_bytes(displaced: Option) -> Option { + match displaced { + Some(Item::Chunk(bytes)) => Some(bytes), + Some(Item::Error(_)) | None => None, + } +} + +#[inline] +fn pop_item(slot: &Slot, changed: &DiatomicWaker) -> Result { + match slot.pop() { + item @ Ok(Item::Chunk(_)) => { + changed.notify(); + item + } + item @ (Ok(Item::Error(_)) | Err(_)) => item, + } +} + +#[cfg(all(test, loom))] +mod loom_tests { + use std::task::Waker; + + use bytes::Bytes; + use loom::thread; + use ql_wire::StreamCloseCode; + + use super::*; + use crate::{ + io::{sync::loom::*, Tx}, + QlStreamError, + }; + + #[test] + fn reader_waiter_registration_survives_finish() { + check_model(|| { + let shared = shared(); + shared.rx.register_waiter(Waker::noop()); + + let finisher = { + let shared = shared.clone(); + thread::spawn(move || { + shared.rx.finish(); + }) + }; + + finisher.join().unwrap(); + assert!(RxInner::is_finished(shared.rx.load_state())); + + shared.rx.unregister_waiter(); + }); + } + + #[test] + fn reader_chunk_remains_available_after_finish() { + check_model(|| { + let shared = shared(); + + let producer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.rx.finish(); + }) + }; + + producer.join().unwrap(); + + match shared.rx.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + } + assert!(RxInner::is_finished(shared.rx.load_state())); + assert!(matches!(shared.rx.pop(), Err(PopError))); + }); + } + + #[test] + fn reader_rejects_write_after_finish() { + check_model(|| { + let shared = shared(); + + shared.rx.finish(); + + assert_eq!( + shared.rx.try_write(Bytes::from_static(b"abc")), + Err(PushError::Closed(Bytes::from_static(b"abc"))) + ); + assert!(RxInner::is_finished(shared.rx.load_state())); + assert!(matches!(shared.rx.pop(), Err(PopError))); + }); + } + + #[test] + fn reader_write_races_with_finish_has_coherent_outcome() { + check_model(|| { + let shared = shared(); + + let writer = { + let shared = shared.clone(); + thread::spawn(move || shared.rx.try_write(Bytes::from_static(b"abc"))) + }; + let finisher = { + let shared = shared.clone(); + thread::spawn(move || shared.rx.finish()) + }; + + let write_result = writer.join().unwrap(); + finisher.join().unwrap(); + + assert!(RxInner::is_finished(shared.rx.load_state())); + match write_result { + Ok(()) => match shared.rx.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + }, + Err(PushError::Closed(bytes)) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + assert!(matches!(shared.rx.pop(), Err(PopError))); + return; + } + Err(PushError::Full(_)) => panic!("empty reader slot must not report full"), + } + assert!(matches!(shared.rx.pop(), Err(PopError))); + }); + } + + #[test] + fn reader_fail_racing_with_pop_preserves_terminal_outcome() { + check_model(|| { + let shared = shared(); + shared.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + + let popper = { + let shared = shared.clone(); + thread::spawn(move || shared.rx.pop()) + }; + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.rx.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }) + }) + }; + + let pop_result = popper.join().unwrap(); + let fail_result = failer.join().unwrap(); + + match (pop_result, fail_result) { + (Ok(Item::Chunk(bytes)), None) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + match shared.rx.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal reader error"), + } + } + (Ok(Item::Error(QlStreamError::StreamClosed { code })), Some(bytes)) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + assert_eq!(bytes, Bytes::from_static(b"abc")); + assert!(matches!(shared.rx.pop(), Err(PopError))); + } + _ => panic!("unexpected reader fail/pop race outcome"), + } + }); + } + + #[test] + fn writer_is_finished_only_after_drain() { + check_model(|| { + let shared = shared(); + let tx = Tx(shared.clone()); + let mut pending = Bytes::new(); + + shared.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.tx.request_finish(); + + assert!(!(pending.is_empty() && tx.is_finished())); + assert_eq!(tx.try_read(&mut pending, 2), Ok(Bytes::from_static(b"ab"))); + assert!(!(pending.is_empty() && tx.is_finished())); + assert_eq!(tx.try_read(&mut pending, 8), Ok(Bytes::from_static(b"c"))); + assert!(pending.is_empty() && tx.is_finished()); + }); + } + + #[test] + fn writer_write_races_with_request_finish() { + check_model(|| { + let shared = shared(); + let tx = Tx(shared.clone()); + let mut pending = Bytes::new(); + + let writer = { + let shared = shared.clone(); + thread::spawn(move || shared.tx.try_write(Bytes::from_static(b"abc"))) + }; + let finisher = { + let shared = shared.clone(); + thread::spawn(move || shared.tx.request_finish()) + }; + + let write_result = writer.join().unwrap(); + finisher.join().unwrap(); + + assert!(TxInner::finish_requested(shared.tx.load_state())); + match write_result { + Ok(()) => { + assert_eq!(tx.try_read(&mut pending, 8), Ok(Bytes::from_static(b"abc"))); + assert!(pending.is_empty() && tx.is_finished()); + } + Err(PushError::Closed(bytes)) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + assert!(pending.is_empty() && tx.is_finished()); + } + Err(PushError::Full(_)) => panic!("empty writer slot must not report full"), + } + }); + } + + #[test] + fn writer_fail_overwrites_buffered_chunk_and_keeps_terminal_state_observable() { + check_model(|| { + let shared = shared(); + shared.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + shared.tx.register_waiter(Waker::noop()); + + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + let displaced = shared.tx.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }); + assert_eq!(displaced.unwrap(), Some(Bytes::from_static(b"abc"))); + }) + }; + + failer.join().unwrap(); + + assert!(TxInner::terminal_ready(shared.tx.load_state())); + shared.tx.unregister_waiter(); + match shared.tx.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal writer error"), + } + }); + } + + #[test] + fn reader_waiter_registration_can_be_reused_after_notification() { + check_model(|| { + let shared = shared(); + + shared.rx.register_waiter(Waker::noop()); + shared.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + match shared.rx.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered reader chunk"), + } + + shared.rx.register_waiter(Waker::noop()); + shared.rx.finish(); + assert!(RxInner::is_finished(shared.rx.load_state())); + shared.rx.unregister_waiter(); + }); + } + + #[test] + fn writer_waiter_registration_can_be_reused_after_notification() { + check_model(|| { + let shared = shared(); + + shared.tx.register_waiter(Waker::noop()); + shared.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + match shared.tx.pop() { + Ok(Item::Chunk(bytes)) => assert_eq!(bytes, Bytes::from_static(b"abc")), + _ => panic!("expected buffered writer chunk"), + } + + shared.tx.register_waiter(Waker::noop()); + shared.tx.finish(); + assert!(TxInner::terminal_ready(shared.tx.load_state())); + shared.tx.unregister_waiter(); + }); + } + + #[test] + fn writer_write_races_with_fail() { + check_model(|| { + let shared = shared(); + + let writer = { + let shared = shared.clone(); + thread::spawn(move || shared.tx.try_write(Bytes::from_static(b"abc"))) + }; + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.tx.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }) + }) + }; + + let write_result = writer.join().unwrap(); + let fail_result = failer.join().unwrap(); + + assert!(TxInner::terminal_ready(shared.tx.load_state())); + match (&write_result, &fail_result) { + (Ok(()), Ok(Some(bytes))) => { + assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); + } + (Err(PushError::Closed(bytes)), Ok(None)) => { + assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); + } + (Err(PushError::Full(bytes)), Ok(None)) => { + assert_eq!(Bytes::from_static(b"abc"), bytes.clone()); + } + _ => panic!( + "unexpected writer fail/write race outcome: write={write_result:?} fail={fail_result:?}" + ), + } + + match shared.tx.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal writer error"), + } + }); + } + + #[test] + fn writer_finish_races_with_fail_without_masking_error() { + check_model(|| { + let shared = shared(); + + let finisher = { + let shared = shared.clone(); + thread::spawn(move || shared.tx.finish()) + }; + let failer = { + let shared = shared.clone(); + thread::spawn(move || { + shared.tx.fail(QlStreamError::StreamClosed { + code: StreamCloseCode::CANCELLED, + }) + }) + }; + + finisher.join().unwrap(); + let fail_result = failer.join().unwrap(); + + assert!(TxInner::terminal_ready(shared.tx.load_state())); + match fail_result { + Err(_) => { + assert!(TxInner::terminal_ok(shared.tx.load_state())); + } + Ok(_) => { + assert!(!TxInner::terminal_ok(shared.tx.load_state())); + match shared.tx.pop() { + Ok(Item::Error(QlStreamError::StreamClosed { code })) => { + assert_eq!(code, StreamCloseCode::CANCELLED); + } + _ => panic!("expected terminal writer error"), + } + } + } + }); + } +} diff --git a/ql-runtime/src/io/mod.rs b/ql-runtime/src/io/mod.rs new file mode 100644 index 0000000..2eb7f0f --- /dev/null +++ b/ql-runtime/src/io/mod.rs @@ -0,0 +1,59 @@ +mod inner; +mod reader; +mod slot; +mod sync; +mod writer; + +use std::ops::Deref; + +use ql_wire::{CloseTarget, StreamId}; + +pub use self::{reader::StreamReader, slot::PushError, writer::StreamWriter}; +use crate::RuntimeHandle; + +pub struct Rx(sync::Arc); + +impl Deref for Rx { + type Target = inner::RxInner; + + fn deref(&self) -> &Self::Target { + &self.0.rx + } +} + +impl Rx { + pub fn stream_id(&self) -> StreamId { + self.0.stream_id + } +} + +pub struct Tx(sync::Arc); + +impl Deref for Tx { + type Target = inner::TxInner; + + fn deref(&self) -> &Self::Target { + &self.0.tx + } +} + +impl Tx { + pub fn stream_id(&self) -> StreamId { + self.0.stream_id + } +} + +pub fn new_stream( + stream_id: StreamId, + reader_target: CloseTarget, + writer_target: CloseTarget, + handle: RuntimeHandle, +) -> (StreamReader, StreamWriter, Rx, Tx) { + let shared = inner::new(stream_id); + ( + StreamReader::new(Rx(shared.clone()), reader_target, handle.clone()), + StreamWriter::new(Tx(shared.clone()), writer_target, handle), + Rx(shared.clone()), + Tx(shared), + ) +} diff --git a/ql-runtime/src/io/reader.rs b/ql-runtime/src/io/reader.rs new file mode 100644 index 0000000..8c40ccd --- /dev/null +++ b/ql-runtime/src/io/reader.rs @@ -0,0 +1,236 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use ql_wire::{CloseTarget, StreamCloseCode}; + +use super::{ + inner::{Item, RxInner}, + slot::PopError, + Rx, +}; +use crate::{command::Command, log, QlStreamError, RuntimeHandle}; + +pub struct StreamReader { + rx: Rx, + target: CloseTarget, + pending: Bytes, + terminal: ReaderTerminalState, + handle: RuntimeHandle, +} + +enum ReaderTerminalState { + Open, + Delivered, +} + +unsafe impl Sync for StreamReader {} + +impl std::fmt::Debug for StreamReader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamReader") + .field("stream_id", &self.rx.stream_id()) + .field("target", &self.target) + .field( + "terminal", + &matches!(self.terminal, ReaderTerminalState::Delivered), + ) + .finish_non_exhaustive() + } +} + +impl StreamReader { + pub(crate) fn new(shared: Rx, target: CloseTarget, handle: RuntimeHandle) -> Self { + Self { + rx: shared, + target, + pending: Bytes::new(), + terminal: ReaderTerminalState::Open, + handle, + } + } + + pub fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, QlStreamError>> { + if matches!(self.terminal, ReaderTerminalState::Delivered) { + return Poll::Ready(Ok(None)); + } + + match self.try_read_ready(max_len) { + Poll::Ready(result) => return Poll::Ready(result), + Poll::Pending => {} + } + + self.rx.register_waiter(cx.waker()); + + match self.try_read_ready(max_len) { + Poll::Ready(result) => { + self.rx.unregister_waiter(); + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } + + fn try_read_ready(&mut self, max_len: usize) -> Poll, QlStreamError>> { + if !self.pending.is_empty() { + let pending = &mut self.pending; + let bytes = if pending.len() <= max_len { + std::mem::take(pending) + } else { + pending.split_to(max_len) + }; + self.handle.try_send(Command::PollInbound { + stream_id: self.rx.stream_id(), + }); + return Poll::Ready(Ok(Some(bytes))); + } + + match self.rx.pop() { + Ok(Item::Chunk(mut bytes)) => { + log::trace!( + "byte reader received chunk: stream_id={} target={:?} len={}", + self.rx.stream_id(), + self.target, + bytes.len() + ); + self.handle.try_send(Command::PollInbound { + stream_id: self.rx.stream_id(), + }); + if bytes.len() <= max_len { + return Poll::Ready(Ok(Some(bytes))); + } + let head = bytes.split_to(max_len); + self.pending = bytes; + Poll::Ready(Ok(Some(head))) + } + Ok(Item::Error(error)) => { + log::debug!( + "byte reader delivered terminal error: stream_id={} target={:?} error={:?}", + self.rx.stream_id(), + self.target, + error + ); + self.terminal = ReaderTerminalState::Delivered; + Poll::Ready(Err(error)) + } + Err(PopError) => { + if RxInner::is_finished(self.rx.load_state()) { + log::debug!( + "byte reader delivered clean eof: stream_id={} target={:?}", + self.rx.stream_id(), + self.target + ); + self.terminal = ReaderTerminalState::Delivered; + return Poll::Ready(Ok(None)); + } + Poll::Pending + } + } + } + + pub fn poll_read_chunk( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, QlStreamError>> { + self.poll_read(usize::MAX, cx) + } + + pub async fn read(&mut self, max_len: usize) -> Result, QlStreamError> { + poll_fn(|cx| self.poll_read(max_len, cx)).await + } + + pub async fn read_chunk(&mut self) -> Result, QlStreamError> { + self.read(usize::MAX).await + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if matches!(self.terminal, ReaderTerminalState::Delivered) { + return; + } + log::debug!( + "byte reader explicit close: stream_id={:?} target={:?} code={:?}", + self.rx.stream_id(), + self.target, + code + ); + self.terminal = ReaderTerminalState::Delivered; + self.handle.try_send(Command::CloseStream { + stream_id: self.rx.stream_id(), + target: self.target, + code, + }); + } +} + +impl Drop for StreamReader { + fn drop(&mut self) { + if matches!(self.terminal, ReaderTerminalState::Delivered) { + return; + } + log::debug!( + "byte reader drop close: stream_id={:?} target={:?} code={:?}", + self.rx.stream_id(), + self.target, + StreamCloseCode::CANCELLED + ); + self.handle.try_send(Command::CloseStream { + stream_id: self.rx.stream_id(), + target: self.target, + code: StreamCloseCode::CANCELLED, + }); + } +} + +#[cfg(all(test, loom))] +mod loom_tests { + use std::task::{Context, Poll, Waker}; + + use bytes::Bytes; + use loom::thread; + use ql_wire::CloseTarget; + + use super::*; + use crate::io::sync::loom::*; + + #[test] + fn poll_read_observes_chunk_racing_with_registration() { + check_model(|| { + let inner = shared(); + let mut reader = StreamReader::new(Rx(inner.clone()), CloseTarget::Origin, handle()); + let mut cx = Context::from_waker(Waker::noop()); + + let producer = { + let inner = inner.clone(); + thread::spawn(move || { + inner.rx.try_write(Bytes::from_static(b"abc")).unwrap(); + }) + }; + + let first = reader.poll_read(usize::MAX, &mut cx); + producer.join().unwrap(); + + match first { + Poll::Ready(Ok(Some(bytes))) => { + assert_eq!(bytes, Bytes::from_static(b"abc")); + } + Poll::Pending => { + assert_eq!( + reader.poll_read(usize::MAX, &mut cx), + Poll::Ready(Ok(Some(Bytes::from_static(b"abc")))) + ); + } + other => panic!("unexpected first poll result: {other:?}"), + } + }); + } +} diff --git a/ql-runtime/src/io/slot.rs b/ql-runtime/src/io/slot.rs new file mode 100644 index 0000000..f71f1b0 --- /dev/null +++ b/ql-runtime/src/io/slot.rs @@ -0,0 +1,175 @@ +//! local single-slot queue for stream io +//! copied from `concurrent_queue::single::Single` in `concurrent-queue` + +use core::mem::MaybeUninit; + +#[allow(clippy::wildcard_imports)] +use super::sync::*; + +const LOCKED: usize = 1 << 0; +const PUSHED: usize = 1 << 1; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PopError; + +#[derive(Debug, PartialEq, Eq)] +pub enum PushError { + Full(T), + Closed(T), +} + +/// A single-element queue. +pub struct Slot { + state: AtomicUsize, + value: UnsafeCell>, +} + +unsafe impl Send for Slot {} +unsafe impl Sync for Slot {} + +impl Slot { + /// Creates a new single-element queue. + pub fn new() -> Self { + Self { + state: AtomicUsize::new(0), + value: UnsafeCell::new(MaybeUninit::uninit()), + } + } + + #[inline] + pub fn load_state(&self) -> usize { + self.state.load(Ordering::Acquire) + } + + #[inline] + pub fn fetch_or(&self, bits: usize) -> usize { + self.state.fetch_or(bits, Ordering::Release) + } + + #[inline] + pub fn compare_exchange(&self, current: usize, new: usize) -> Result<(), usize> { + self.state + .compare_exchange(current, new, Ordering::AcqRel, Ordering::Acquire) + .map(|_| ()) + } + + /// Attempts to push an item into the queue. + pub fn try_push(&self, value: T, closed_mask: usize) -> Result<(), PushError> { + let mut state = self.load_state(); + loop { + if state & closed_mask != 0 { + return Err(PushError::Closed(value)); + } + if state & LOCKED != 0 { + busy_wait(); + state = self.load_state(); + continue; + } + if state & PUSHED != 0 { + return Err(PushError::Full(value)); + } + + // Lock and fill the slot. + let new_state = state | LOCKED | PUSHED; + match self.compare_exchange(state, new_state) { + Ok(()) => { + // Write the value and unlock. + self.value.with_mut(|slot| unsafe { + slot.write(MaybeUninit::new(value)); + }); + self.state.fetch_and(!LOCKED, Ordering::Release); + return Ok(()); + } + Err(actual) => state = actual, + } + } + } + + /// Attempts to push an item into the queue, displacing another if necessary. + pub fn force_push(&self, value: T) -> Option { + // Attempt to lock the slot. + let mut state = self.load_state(); + + loop { + if state & LOCKED != 0 { + busy_wait(); + state = self.load_state(); + continue; + } + + // Lock the slot. + let new_state = state | LOCKED | PUSHED; + match self.compare_exchange(state, new_state) { + Ok(()) => { + // If the value was pushed, swap out the value. + let displaced = if state & PUSHED == 0 { + // SAFETY: write is safe because we have locked the state. + self.value.with_mut(|slot| unsafe { + slot.write(MaybeUninit::new(value)); + }); + None + } else { + // SAFETY: replace is safe because we have locked the state, and + // assume_init is safe because we have checked that the value was pushed. + self.value.with_mut(move |slot| unsafe { + Some(std::ptr::replace(slot, MaybeUninit::new(value)).assume_init()) + }) + }; + + // We can unlock the slot now. + self.state.fetch_and(!LOCKED, Ordering::Release); + return displaced; + } + Err(actual) => state = actual, + } + } + } + + /// Attempts to pop an item from the queue. + pub fn pop(&self) -> Result { + let mut state = PUSHED; + loop { + if state & LOCKED != 0 { + busy_wait(); + state = self.load_state(); + continue; + } + if state & PUSHED == 0 { + return Err(PopError); + } + + // Lock and empty the slot. + let new_state = (state | LOCKED) & !PUSHED; + match self.compare_exchange(state, new_state) { + Ok(()) => { + // Read the value and unlock. + let value = self + .value + .with_mut(|slot| unsafe { slot.read().assume_init() }); + self.state.fetch_and(!LOCKED, Ordering::Release); + return Ok(value); + } + Err(actual) => state = actual, + } + } + } + + #[inline] + pub fn is_empty_state(state: usize) -> bool { + state & PUSHED == 0 + } +} + +impl Drop for Slot { + fn drop(&mut self) { + // Drop the value in the slot. + self.state.with_mut(|state| { + if *state & PUSHED != 0 { + self.value.with_mut(|slot| unsafe { + let value = &mut *slot; + value.as_mut_ptr().drop_in_place(); + }); + } + }); + } +} diff --git a/ql-runtime/src/io/sync.rs b/ql-runtime/src/io/sync.rs new file mode 100644 index 0000000..c503407 --- /dev/null +++ b/ql-runtime/src/io/sync.rs @@ -0,0 +1,89 @@ +#[cfg(not(all(test, loom)))] +mod inner { + pub use std::{ + cell::UnsafeCell, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + }; + + pub fn busy_wait() { + std::thread::yield_now(); + } + + pub trait UnsafeCellExt { + type Value; + + fn with_mut(&self, f: F) -> R + where + F: FnOnce(*mut Self::Value) -> R; + } + + impl UnsafeCellExt for UnsafeCell { + type Value = T; + + fn with_mut(&self, f: F) -> R + where + F: FnOnce(*mut Self::Value) -> R, + { + f(self.get()) + } + } + + pub trait AtomicExt { + type Value; + + fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut Self::Value) -> R; + } + + impl AtomicExt for AtomicUsize { + type Value = usize; + + fn with_mut(&mut self, f: F) -> R + where + F: FnOnce(&mut Self::Value) -> R, + { + f(self.get_mut()) + } + } +} + +#[cfg(all(test, loom))] +mod inner { + pub use loom::{ + cell::UnsafeCell, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + thread::yield_now as busy_wait, + }; +} + +pub use inner::*; + +#[cfg(all(test, loom))] +pub(crate) mod loom { + use loom::model; + use ql_wire::StreamId; + + use super::Arc; + use crate::{io::inner::Inner, RuntimeHandle}; + + pub(crate) fn check_model(f: impl Fn() + Sync + Send + 'static) { + let builder = model::Builder::new(); + builder.check(f); + } + + pub(crate) fn shared() -> Arc { + crate::io::inner::new(StreamId(1u32.into())) + } + + pub(crate) fn handle() -> RuntimeHandle { + let (tx, _rx) = async_channel::unbounded(); + RuntimeHandle::new(tx) + } +} diff --git a/ql-runtime/src/io/writer.rs b/ql-runtime/src/io/writer.rs new file mode 100644 index 0000000..cfad319 --- /dev/null +++ b/ql-runtime/src/io/writer.rs @@ -0,0 +1,294 @@ +use std::{ + future::poll_fn, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use ql_wire::{CloseTarget, StreamCloseCode}; + +use super::{ + inner::{Item, TxInner}, + slot::PopError, + PushError, Tx, +}; +use crate::{command::Command, log, QlStreamError, RuntimeHandle}; + +pub struct StreamWriter { + tx: Tx, + target: CloseTarget, + open: bool, + terminal: WriterTerminalState, + handle: RuntimeHandle, +} + +enum WriterTerminalState { + Pending, + Terminal(Result<(), QlStreamError>), +} + +unsafe impl Sync for StreamWriter {} + +impl std::fmt::Debug for StreamWriter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamWriter") + .field("stream_id", &self.tx.stream_id()) + .field("target", &self.target) + .field("closed", &!self.open) + .finish_non_exhaustive() + } +} + +impl StreamWriter { + pub(crate) fn new(shared: Tx, target: CloseTarget, handle: RuntimeHandle) -> Self { + Self { + tx: shared, + target, + open: true, + terminal: WriterTerminalState::Pending, + handle, + } + } + + pub fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll> { + if bytes.is_empty() { + return Poll::Ready(Ok(())); + } + + if !self.open { + return self.poll_terminal(cx); + } + + match self.tx.try_write(std::mem::take(bytes)) { + Ok(()) => { + log::trace!( + "byte writer accepted chunk: stream_id={} target={:?}", + self.tx.stream_id(), + self.target + ); + self.poll_runtime(); + return Poll::Ready(Ok(())); + } + Err(PushError::Closed(chunk)) => { + *bytes = chunk; + self.open = false; + return self.poll_terminal(cx); + } + Err(PushError::Full(chunk)) => { + *bytes = chunk; + } + } + + self.tx.register_waiter(cx.waker()); + + match self.tx.try_write(std::mem::take(bytes)) { + Ok(()) => { + self.tx.unregister_waiter(); + log::trace!( + "byte writer accepted chunk: stream_id={} target={:?}", + self.tx.stream_id(), + self.target + ); + self.poll_runtime(); + Poll::Ready(Ok(())) + } + Err(PushError::Closed(chunk)) => { + self.tx.unregister_waiter(); + *bytes = chunk; + self.open = false; + self.poll_terminal(cx) + } + Err(PushError::Full(chunk)) => { + *bytes = chunk; + Poll::Pending + } + } + } + + pub async fn write(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { + let mut bytes = bytes; + poll_fn(|cx| self.poll_write(&mut bytes, cx)).await + } + + pub fn queue_finish(&mut self) { + if !self.open { + return; + } + log::debug!( + "byte writer finish: stream_id={} target={:?}", + self.tx.stream_id(), + self.target + ); + self.open = false; + self.tx.request_finish(); + self.poll_runtime(); + } + + pub async fn finish(mut self) -> Result<(), QlStreamError> { + self.queue_finish(); + poll_fn(|cx| self.poll_terminal(cx)).await + } + + pub fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.open { + self.queue_finish(); + } + self.poll_terminal(cx) + } + + pub fn close(mut self, code: StreamCloseCode) { + self.close_inner(code); + } + + fn poll_runtime(&self) { + self.handle.try_send(Command::PollStream { + stream_id: self.tx.stream_id(), + }); + } + + fn poll_terminal(&mut self, cx: &Context<'_>) -> Poll> { + match &self.terminal { + WriterTerminalState::Terminal(result) => return Poll::Ready(result.clone()), + WriterTerminalState::Pending => {} + } + + match self.try_poll_terminal_ready() { + Poll::Ready(result) => return Poll::Ready(result), + Poll::Pending => {} + } + + self.tx.register_waiter(cx.waker()); + + match self.try_poll_terminal_ready() { + Poll::Ready(result) => { + self.tx.unregister_waiter(); + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } + + fn try_poll_terminal_ready(&mut self) -> Poll> { + let state = self.tx.load_state(); + if TxInner::terminal_ready(state) { + if TxInner::terminal_ok(state) { + self.terminal = WriterTerminalState::Terminal(Ok(())); + return Poll::Ready(Ok(())); + } + + match self.tx.pop() { + Ok(Item::Error(error)) => { + self.terminal = WriterTerminalState::Terminal(Err(error.clone())); + return Poll::Ready(Err(error)); + } + Ok(Item::Chunk(_)) => { + panic!("writer terminal phase contained chunk data") + } + Err(PopError) => {} + } + } + + Poll::Pending + } + + fn close_inner(&mut self, code: StreamCloseCode) { + if !self.open { + return; + } + self.open = false; + log::debug!( + "byte writer close: stream_id={:?} target={:?} code={:?}", + self.tx.stream_id(), + self.target, + code + ); + self.handle.try_send(Command::CloseStream { + stream_id: self.tx.stream_id(), + target: self.target, + code, + }); + } +} + +impl Drop for StreamWriter { + fn drop(&mut self) { + self.close_inner(StreamCloseCode::CANCELLED); + } +} + +#[cfg(all(test, loom))] +mod loom_tests { + use std::task::{Context, Poll, Waker}; + + use bytes::Bytes; + use loom::thread; + use ql_wire::CloseTarget; + + use super::*; + use crate::io::sync::loom::*; + + #[test] + fn poll_write_observes_capacity_racing_with_registration() { + check_model(|| { + let inner = shared(); + inner.tx.try_write(Bytes::from_static(b"abc")).unwrap(); + + let mut writer = StreamWriter::new(Tx(inner.clone()), CloseTarget::Origin, handle()); + let mut bytes = Bytes::from_static(b"xyz"); + let mut cx = Context::from_waker(Waker::noop()); + + let drainer = { + let inner = inner.clone(); + thread::spawn(move || { + assert!(matches!(inner.tx.pop(), Ok(Item::Chunk(_)))); + }) + }; + + let first = writer.poll_write(&mut bytes, &mut cx); + drainer.join().unwrap(); + + match first { + Poll::Ready(Ok(())) => { + assert!(bytes.is_empty()); + } + Poll::Pending => { + assert_eq!(writer.poll_write(&mut bytes, &mut cx), Poll::Ready(Ok(()))); + assert!(bytes.is_empty()); + } + other => panic!("unexpected first poll result: {other:?}"), + } + }); + } + + #[test] + fn poll_finish_observes_terminal_racing_with_registration() { + check_model(|| { + let inner = shared(); + let mut writer = StreamWriter::new(Tx(inner.clone()), CloseTarget::Origin, handle()); + let mut cx = Context::from_waker(Waker::noop()); + + writer.queue_finish(); + + let finisher = { + let inner = inner.clone(); + thread::spawn(move || { + inner.tx.finish(); + }) + }; + + let first = writer.poll_finish(&mut cx); + finisher.join().unwrap(); + + match first { + Poll::Ready(Ok(())) => {} + Poll::Pending => { + assert_eq!(writer.poll_finish(&mut cx), Poll::Ready(Ok(()))); + } + other => panic!("unexpected first poll result: {other:?}"), + } + }); + } +} diff --git a/ql-runtime/src/lib.rs b/ql-runtime/src/lib.rs new file mode 100644 index 0000000..3378345 --- /dev/null +++ b/ql-runtime/src/lib.rs @@ -0,0 +1,63 @@ +pub use ql_fsm::{NoSessionError, PairingInvite}; + +pub use self::{error::QlStreamError, handle::*, platform::*}; + +pub(crate) mod command; +pub(crate) mod driver; +mod error; +pub mod handle; +pub(crate) mod io; +pub mod log; +pub mod platform; +#[cfg(feature = "rpc")] +pub mod rpc; + +#[cfg(test)] +mod tests; + +use ql_fsm::QlFsmConfig; +use ql_wire::QlIdentity; + +#[derive(Debug, Clone, Copy)] +pub struct RuntimeConfig { + pub fsm: QlFsmConfig, + pub max_concurrent_message_writes: usize, +} + +impl Default for RuntimeConfig { + fn default() -> Self { + Self { + fsm: QlFsmConfig::default(), + max_concurrent_message_writes: 4, + } + } +} + +pub struct Runtime

{ + identity: QlIdentity, + platform: P, + config: RuntimeConfig, + rx: async_channel::Receiver, + tx: async_channel::WeakSender, +} + +pub fn new_runtime

( + identity: QlIdentity, + platform: P, + config: RuntimeConfig, +) -> (Runtime

, RuntimeHandle) +where + P: QlPlatform, +{ + let (tx, rx) = async_channel::unbounded(); + ( + Runtime { + identity, + platform, + config, + rx, + tx: tx.downgrade(), + }, + RuntimeHandle::new(tx), + ) +} diff --git a/ql-runtime/src/log.rs b/ql-runtime/src/log.rs new file mode 100644 index 0000000..a0908f7 --- /dev/null +++ b/ql-runtime/src/log.rs @@ -0,0 +1,54 @@ +#![allow(unused_imports, unused_macros)] + +#[cfg(any(feature = "log", test))] +macro_rules! log { + ($level:ident, $($arg:tt)*) => { + ::log::log!(::log::Level::$level, $($arg)*) + }; +} + +#[cfg(not(any(feature = "log", test)))] +macro_rules! log { + ($level:ident, $($arg:tt)*) => { + if false { + let _ = format_args!($($arg)*); + } + }; +} + +macro_rules! trace { + ($($arg:tt)*) => { + $crate::log::log!(Trace, $($arg)*) + }; +} + +macro_rules! debug { + ($($arg:tt)*) => { + $crate::log::log!(Debug, $($arg)*) + }; +} + +macro_rules! info { + ($($arg:tt)*) => { + $crate::log::log!(Info, $($arg)*) + }; +} + +macro_rules! warn_ { + ($($arg:tt)*) => { + $crate::log::log!(Warn, $($arg)*) + }; +} + +macro_rules! error { + ($($arg:tt)*) => { + $crate::log::log!(Error, $($arg)*) + }; +} + +pub(crate) use debug; +pub(crate) use error; +pub(crate) use info; +pub(crate) use log; +pub(crate) use trace; +pub(crate) use warn_ as warn; diff --git a/ql-runtime/src/platform.rs b/ql-runtime/src/platform.rs new file mode 100644 index 0000000..331bfe7 --- /dev/null +++ b/ql-runtime/src/platform.rs @@ -0,0 +1,43 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Instant, +}; + +use ql_fsm::{PeerStatus, ReceiveError}; +use ql_wire::{PeerBundle, QlCrypto, QID}; + +use crate::QlStream; + +pub trait QlTimer { + fn set_deadline(self: Pin<&mut Self>, deadline: Option); + fn poll_wait(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()>; +} + +pub trait QlInbound { + fn poll_recv(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; +} + +pub trait QlPlatform: QlCrypto { + type Timer: QlTimer; + type WriteMessageFut<'a>: Future + Unpin + 'a + where + Self: 'a; + type Inbound: QlInbound; + + fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_>; + /// Returns the platform's inbound transport poller. + /// + /// The runtime calls this once while starting the driver loop and retains the returned + /// poller for the lifetime of the runtime. Platform implementations may panic if this is + /// called more than once. + fn inbound(&mut self) -> Self::Inbound; + fn timer(&self) -> Self::Timer; + + fn persist_peer(&self, peer: PeerBundle); + + fn handle_peer_status(&self, peer: Option, status: PeerStatus); + fn handle_inbound(&self, event: QlStream); + fn handle_recv_error(&self, _error: ReceiveError) {} +} diff --git a/ql-runtime/src/rpc/adapter.rs b/ql-runtime/src/rpc/adapter.rs new file mode 100644 index 0000000..a734760 --- /dev/null +++ b/ql-runtime/src/rpc/adapter.rs @@ -0,0 +1,83 @@ +use std::task::{Context, Poll}; + +use bytes::Bytes; +use ql_rpc::{RouteId, RpcRead, RpcStream, RpcWrite, StreamCloseCode, StreamError}; +use ql_wire::{RouteId as WireRouteId, StreamCloseCode as WireStreamCloseCode}; + +use crate::{QlStream, QlStreamError, StreamReader, StreamWriter}; + +impl RpcStream for QlStream { + type Error = QlStreamError; + type Reader = StreamReader; + type Writer = StreamWriter; + + fn route_id(&self) -> Option { + let route_id = u32::try_from(self.route_id.into_inner()).ok()?; + Some(RouteId::from_u32(route_id)) + } + + fn split(self) -> (Self::Reader, Self::Writer) { + (self.reader, self.writer) + } +} + +impl RpcRead for StreamReader { + type Error = QlStreamError; + + fn poll_read( + &mut self, + max_len: usize, + cx: &mut Context<'_>, + ) -> Poll, QlStreamError>> { + StreamReader::poll_read(self, max_len, cx) + } + + fn close(self, code: StreamCloseCode) { + StreamReader::close(self, to_wire_close_code(code)); + } +} + +impl RpcWrite for StreamWriter { + type Error = QlStreamError; + + fn poll_write( + &mut self, + bytes: &mut Bytes, + cx: &mut Context<'_>, + ) -> Poll> { + StreamWriter::poll_write(self, bytes, cx) + } + + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + StreamWriter::poll_finish(self, cx) + } + + fn close(self, code: StreamCloseCode) { + StreamWriter::close(self, to_wire_close_code(code)); + } +} + +pub(super) fn to_wire_route_id(route_id: RouteId) -> WireRouteId { + WireRouteId::from_u32(route_id.into_inner()) +} + +pub(super) fn to_wire_close_code(code: StreamCloseCode) -> WireStreamCloseCode { + WireStreamCloseCode(code.into_inner()) +} + +impl From for QlStreamError { + fn from(code: StreamCloseCode) -> Self { + Self::StreamClosed { + code: WireStreamCloseCode(code.into_inner()), + } + } +} + +impl StreamError for QlStreamError { + fn close_code(&self) -> Option { + match self { + QlStreamError::StreamClosed { code } => Some(StreamCloseCode(code.0)), + QlStreamError::NoSession => None, + } + } +} diff --git a/ql-runtime/src/rpc/download.rs b/ql-runtime/src/rpc/download.rs new file mode 100644 index 0000000..d3b6358 --- /dev/null +++ b/ql-runtime/src/rpc/download.rs @@ -0,0 +1,67 @@ +use bytes::Bytes; +use ql_rpc::download::Download as DownloadRpc; + +use super::RpcError; +use crate::StreamReader; + +pub struct DownloadCall { + pub(super) inner: ql_rpc::download::DownloadCall, +} + +pub struct DownloadReader { + pub(super) inner: ql_rpc::download::DownloadReader, +} + +pub struct DownloadPart<'a, M: DownloadRpc> { + inner: ql_rpc::download::DownloadPart<'a, M, StreamReader>, +} + +impl DownloadCall +where + M: DownloadRpc, +{ + pub async fn start(self) -> Result<(M::ResponseHeader, DownloadReader), RpcError> { + let (header, inner) = self.inner.start().await?; + Ok((header, DownloadReader { inner })) + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl DownloadReader +where + M: DownloadRpc, +{ + pub async fn next_part( + &mut self, + ) -> Result)>, RpcError> { + Ok(self + .inner + .next_part() + .await? + .map(|(header, inner)| (header, DownloadPart { inner }))) + } + + pub async fn complete(self) -> Result<(), RpcError> { + self.inner.complete().await.map_err(RpcError::from) + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl DownloadPart<'_, M> +where + M: DownloadRpc, +{ + pub async fn read_chunk(&mut self) -> Result, RpcError> { + Ok(self.inner.read_chunk().await?) + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} diff --git a/ql-runtime/src/rpc/duplex.rs b/ql-runtime/src/rpc/duplex.rs new file mode 100644 index 0000000..cdad667 --- /dev/null +++ b/ql-runtime/src/rpc/duplex.rs @@ -0,0 +1,59 @@ +use futures_lite::future::poll_fn; +use ql_rpc::duplex::Duplex as DuplexRpc; + +use super::RpcError; +use crate::{QlStreamError, StreamReader, StreamWriter}; + +pub struct DuplexCall { + pub sender: DuplexSender, + pub receiver: DuplexReceiver, +} + +pub struct DuplexSender +where + T: ql_rpc::RpcCodec, +{ + pub(super) inner: ql_rpc::duplex::DuplexSender, +} + +pub struct DuplexReceiver +where + T: ql_rpc::RpcCodec, +{ + pub(super) inner: ql_rpc::duplex::DuplexReceiver, +} + +impl DuplexSender +where + T: ql_rpc::RpcCodec, +{ + pub async fn send(&mut self, event: &T) -> Result<(), QlStreamError> { + self.inner.send(event).await + } + + pub async fn finish(self) -> Result<(), QlStreamError> { + self.inner.finish().await + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl DuplexReceiver +where + T: ql_rpc::RpcCodec, +{ + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| { + self.inner + .poll_next_event(cx) + .map(|item| item.map(|result| Ok(result?))) + }) + .await + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} diff --git a/ql-runtime/src/rpc/error.rs b/ql-runtime/src/rpc/error.rs new file mode 100644 index 0000000..4cc9e17 --- /dev/null +++ b/ql-runtime/src/rpc/error.rs @@ -0,0 +1,79 @@ +use ql_fsm::NoSessionError; + +use crate::QlStreamError; + +#[derive(Debug)] +pub enum RpcError { + NoSession, + Closed(ql_rpc::StreamCloseCode), + Protocol(ql_rpc::Error), + Codec(E), +} + +impl From for RpcError { + fn from(_: NoSessionError) -> Self { + Self::NoSession + } +} + +impl From for RpcError { + fn from(error: QlStreamError) -> Self { + match error { + QlStreamError::StreamClosed { code } => Self::Closed(ql_rpc::StreamCloseCode(code.0)), + QlStreamError::NoSession => Self::NoSession, + } + } +} + +impl From for RpcError { + fn from(error: ql_rpc::Error) -> Self { + Self::Protocol(error) + } +} + +impl From> for RpcError { + fn from(error: ql_rpc::CodecError) -> Self { + match error { + ql_rpc::CodecError::Rpc(error) => Self::Protocol(error), + ql_rpc::CodecError::Codec(error) => Self::Codec(error), + } + } +} + +impl From> for RpcError { + fn from(error: ql_rpc::CallError) -> Self { + match error { + ql_rpc::CallError::Protocol(error) => Self::Protocol(error), + ql_rpc::CallError::Codec(error) => Self::Codec(error), + ql_rpc::CallError::Transport(error) => error.into(), + } + } +} + +impl std::fmt::Display for RpcError +where + E: std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NoSession => write!(f, "no session"), + Self::Closed(code) => write!(f, "stream closed {code:?}"), + Self::Protocol(error) => write!(f, "{error}"), + Self::Codec(error) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for RpcError +where + E: std::error::Error + 'static, +{ + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Protocol(error) => Some(error), + Self::Codec(error) => Some(error), + RpcError::NoSession => None, + RpcError::Closed(_) => None, + } + } +} diff --git a/ql-runtime/src/rpc/mod.rs b/ql-runtime/src/rpc/mod.rs new file mode 100644 index 0000000..d8be02c --- /dev/null +++ b/ql-runtime/src/rpc/mod.rs @@ -0,0 +1,154 @@ +pub use self::{download::*, duplex::*, error::*, progress::*, subscription::*, upload::*}; + +mod adapter; +mod download; +mod duplex; +mod error; +mod progress; +mod subscription; +mod upload; + +use bytes::Bytes; +use ql_rpc::{ + download::{self as rpc_download, Download as DownloadRpc}, + duplex::{self as rpc_duplex, Duplex as DuplexRpc}, + notification::{self, Notification}, + progress::{self as rpc_progress, Progress}, + request::{self, Request as RequestRpc}, + subscription::{self as rpc_subscription, Subscription as SubscriptionRpc}, + upload::{self as rpc_upload, Upload as UploadRpc}, +}; + +use crate::{RuntimeHandle, StreamReader}; + +#[derive(Clone)] +pub struct RpcHandle { + inner: RuntimeHandle, +} + +impl RpcHandle { + pub async fn notification(&self, event: &M::Payload) -> Result<(), RpcError> + where + M: Notification, + { + let mut payload = Vec::new(); + notification::encode_notification::(event, &mut payload); + let mut stream = self + .inner + .open_stream(adapter::to_wire_route_id(M::ROUTE)) + .await?; + stream.reader.close(ql_wire::StreamCloseCode::CANCELLED); + stream.writer.write(Bytes::from(payload)).await?; + stream.writer.finish().await?; + Ok(()) + } + + pub async fn request(&self, request: &M::Request) -> Result> + where + M: RequestRpc, + { + let mut payload = Vec::new(); + request::encode_request::(request, &mut payload); + let response = self.start_request(M::ROUTE, payload).await?; + Ok(request::read_response::(response).await?) + } + + pub async fn subscribe( + &self, + request: &M::Request, + ) -> Result, RpcError> + where + M: SubscriptionRpc, + { + let mut payload = Vec::new(); + rpc_subscription::encode_request::(request, &mut payload); + let response = self.start_request(M::ROUTE, payload).await?; + Ok(Subscription { + inner: rpc_subscription::SubscriptionCall::new(response), + }) + } + + pub async fn download( + &self, + request: &M::Request, + ) -> Result, RpcError> + where + M: DownloadRpc, + { + let mut payload = Vec::new(); + rpc_download::encode_request::(request, &mut payload); + let response = self.start_request(M::ROUTE, payload).await?; + Ok(DownloadCall { + inner: rpc_download::DownloadCall::new(response), + }) + } + + pub async fn progress( + &self, + request: &M::Request, + ) -> Result, RpcError> + where + M: Progress, + { + let mut payload = Vec::new(); + rpc_progress::encode_request::(request, &mut payload); + let response = self.start_request(M::ROUTE, payload).await?; + Ok(ProgressCall { + inner: rpc_progress::ProgressCall::new(response), + }) + } + + pub async fn upload(&self, request: &M::Request) -> Result, RpcError> + where + M: UploadRpc, + { + let mut payload = Vec::new(); + rpc_upload::encode_request::(request, &mut payload); + let mut stream = self + .inner + .open_stream(adapter::to_wire_route_id(M::ROUTE)) + .await?; + stream.writer.write(Bytes::from(payload)).await?; + Ok(UploadCall { + inner: rpc_upload::UploadCall::new(stream.writer, stream.reader), + }) + } + + pub async fn duplex(&self) -> Result, RpcError> + where + M: DuplexRpc, + { + let stream = self + .inner + .open_stream(adapter::to_wire_route_id(M::ROUTE)) + .await?; + Ok(DuplexCall { + sender: DuplexSender { + inner: rpc_duplex::DuplexSender::new(stream.writer), + }, + receiver: DuplexReceiver { + inner: rpc_duplex::DuplexReceiver::new(stream.reader), + }, + }) + } +} + +impl RpcHandle { + pub(super) fn new(inner: RuntimeHandle) -> Self { + Self { inner } + } + + async fn start_request( + &self, + route_id: ql_rpc::RouteId, + payload: Vec, + ) -> Result> { + let mut stream = self + .inner + .open_stream(adapter::to_wire_route_id(route_id)) + .await?; + stream.writer.write(Bytes::from(payload)).await?; + stream.writer.finish().await?; + Ok(stream.reader) + } +} diff --git a/ql-runtime/src/rpc/progress.rs b/ql-runtime/src/rpc/progress.rs new file mode 100644 index 0000000..a22da20 --- /dev/null +++ b/ql-runtime/src/rpc/progress.rs @@ -0,0 +1,50 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_lite::Stream; +use ql_rpc::progress::Progress; + +use super::RpcError; +use crate::StreamReader; + +pub struct ProgressCall { + pub(super) inner: ql_rpc::progress::ProgressCall, +} + +impl Unpin for ProgressCall where M: Progress {} + +impl ProgressCall +where + M: Progress, +{ + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl Stream for ProgressCall +where + M: Progress, +{ + type Item = M::Progress; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.poll_next_progress(cx) + } +} + +impl Future for ProgressCall +where + M: Progress, +{ + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.get_mut().inner) + .poll(cx) + .map(|result| result.map_err(RpcError::from)) + } +} diff --git a/ql-runtime/src/rpc/subscription.rs b/ql-runtime/src/rpc/subscription.rs new file mode 100644 index 0000000..45a08a6 --- /dev/null +++ b/ql-runtime/src/rpc/subscription.rs @@ -0,0 +1,43 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures_lite::{future::poll_fn, Stream}; +use ql_rpc::subscription::Subscription as SubscriptionRpc; + +use super::RpcError; +use crate::StreamReader; + +pub struct Subscription { + pub(super) inner: ql_rpc::subscription::SubscriptionCall, +} + +impl Unpin for Subscription where M: SubscriptionRpc {} + +impl Subscription +where + M: SubscriptionRpc, +{ + pub async fn next_event(&mut self) -> Option>> { + poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await + } + + pub fn close(self, code: ql_wire::StreamCloseCode) { + self.inner.close(ql_rpc::StreamCloseCode(code.0)); + } +} + +impl Stream for Subscription +where + M: SubscriptionRpc, +{ + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut() + .inner + .poll_next_event(cx) + .map(|item| item.map(|result| Ok(result?))) + } +} diff --git a/ql-runtime/src/rpc/upload.rs b/ql-runtime/src/rpc/upload.rs new file mode 100644 index 0000000..33ee366 --- /dev/null +++ b/ql-runtime/src/rpc/upload.rs @@ -0,0 +1,44 @@ +use bytes::Bytes; +use ql_rpc::upload::Upload as UploadRpc; + +use super::RpcError; +use crate::QlStreamError; + +pub struct UploadCall { + pub(super) inner: ql_rpc::upload::UploadCall, +} + +pub struct UploadPartWriter<'a, M: UploadRpc> { + inner: ql_rpc::upload::UploadPartWriter<'a, M, crate::StreamWriter, crate::StreamReader>, +} + +impl UploadCall +where + M: UploadRpc, +{ + pub async fn start_part( + &mut self, + part_header: M::PartHeader, + ) -> Result, QlStreamError> { + Ok(UploadPartWriter { + inner: self.inner.start_part(part_header).await?, + }) + } + + pub async fn finish(self) -> Result> { + self.inner.finish().await.map_err(RpcError::from) + } +} + +impl UploadPartWriter<'_, M> +where + M: UploadRpc, +{ + pub async fn send(&mut self, bytes: Bytes) -> Result<(), QlStreamError> { + self.inner.send(bytes).await + } + + pub async fn finish(self) -> Result<(), QlStreamError> { + self.inner.finish().await + } +} diff --git a/ql-runtime/src/tests/handshake.rs b/ql-runtime/src/tests/handshake.rs new file mode 100644 index 0000000..65731bb --- /dev/null +++ b/ql-runtime/src/tests/handshake.rs @@ -0,0 +1,178 @@ +use std::time::Duration; + +use bytes::Bytes; + +use super::*; + +#[tokio::test(flavor = "current_thread")] +async fn connect_round_trip_changes_peer_status() { + run_local_test(async { + let pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn opening_stream_requires_connection() { + run_local_test(async { + let pair = TestPair::new(default_runtime_config()); + assert!(matches!( + pair.side(Side::A).handle.open_stream(test_route_id()).await, + Err(NoSessionError) + )); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn handshake_timeout_disconnects() { + run_local_test(async { + let config = RuntimeConfig { + fsm: QlFsmConfig { + handshake_timeout: Duration::from_millis(60), + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let (platform_a, _outbound_a, _inbound_a, status_a) = TestPlatform::new(); + let (platform_b, _outbound_b, _inbound_b, _status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Disconnected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rejected_session_write_is_reissued() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = + TestPlatform::new_with_session_write_failure(1); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let request = read_all(stream.reader).await.unwrap(); + stream.writer.finish().await.unwrap(); + request + }); + + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); + stream + .writer + .write(Bytes::from_static(b"retry")) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + + assert_eq!( + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(), + b"retry".to_vec() + ); + + assert_no_status_for( + &status_a, + Some(identity_b.qid), + PeerStatus::Disconnected, + Duration::from_millis(150), + ) + .await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn start_pairing_round_trip_connects_when_armed() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let token = pairing_token(7); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); + + handle_b.arm_pairing(token); + handle_a.start_pairing(PairingInvite { + qid: identity_b.qid, + token, + }); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn start_pairing_does_not_connect_when_unarmed() { + run_local_test(async { + let config = default_runtime_config(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, _status_b) = TestPlatform::new(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let token = pairing_token(8); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, _handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); + + handle_a.start_pairing(PairingInvite { + qid: identity_b.qid, + token, + }); + + assert_no_status_for( + &status_a, + Some(identity_b.qid), + PeerStatus::Connected, + Duration::from_millis(150), + ) + .await; + }) + .await; +} diff --git a/ql-runtime/src/tests/mod.rs b/ql-runtime/src/tests/mod.rs new file mode 100644 index 0000000..af36873 --- /dev/null +++ b/ql-runtime/src/tests/mod.rs @@ -0,0 +1,710 @@ +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, Once, + }, + task::{Context, Poll}, + time::Duration, +}; + +use async_channel::{Receiver, Sender}; +use futures_lite::Stream; +use ql_fsm::PeerStatus; +use ql_wire::{ + generate_identity, test_identities, MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, + MlKemPublicKey, Nonce, PairingToken, PeerBundle, QlAead, QlHash, QlIdentity, QlKem, QlRandom, + RecordHeader, RecordType, RouteId, SessionKey, SoftwareCrypto, WireDecode, QID, +}; +use tokio::{task::LocalSet, time::Sleep}; + +use crate::{ + new_runtime, platform::QlTimer, NoSessionError, PairingInvite, QlFsmConfig, QlStream, + QlStreamError, RuntimeConfig, RuntimeHandle, +}; + +mod handshake; +#[cfg(feature = "rpc")] +mod rpc; +mod session; +mod stream; + +fn init_test_logger() { + static INIT: Once = Once::new(); + + INIT.call_once(|| { + let env = env_logger::Env::default().default_filter_or("ql_runtime=info"); + let mut builder = env_logger::Builder::from_env(env); + builder.is_test(true); + let _ = builder.try_init(); + }); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct StatusEvent { + peer: Option, + status: PeerStatus, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Side { + A, + B, +} + +impl Side { + fn opposite(self) -> Self { + match self { + Self::A => Self::B, + Self::B => Self::A, + } + } +} + +fn test_route_id() -> RouteId { + RouteId::from_u32(1) +} + +#[derive(Debug, Clone)] +struct WriteStats { + active: Arc, + max_active: Arc, +} + +impl WriteStats { + fn new() -> Self { + Self { + active: Arc::new(AtomicUsize::new(0)), + max_active: Arc::new(AtomicUsize::new(0)), + } + } + + fn max_active(&self) -> usize { + self.max_active.load(Ordering::Relaxed) + } +} + +struct TestPlatform { + outbound: Sender>, + _inbound_messages_tx: Sender>, + inbound_messages: Option>>, + status: Sender, + inbound: Option>, + crypto: SoftwareCrypto, + encrypted_write_counter: AtomicUsize, + fail_encrypted_write_at: Option, + write_delay: Duration, + write_stats: Option, +} + +struct TestInbound { + receiver: Receiver>, +} + +type TestPlatformParts = ( + TestPlatform, + Receiver>, + Sender>, + Receiver, +); + +type TestPlatformPartsWithInbound = ( + TestPlatform, + Receiver>, + Sender>, + Receiver, + Receiver, +); + +impl TestPlatform { + fn new() -> TestPlatformParts { + Self::new_inner(None, None, Duration::ZERO, None) + } + + fn new_with_inbound() -> TestPlatformPartsWithInbound { + let (inbound_tx, inbound_rx) = async_channel::unbounded(); + let (platform, outbound_rx, inbound_messages_tx, status_rx) = + Self::new_inner(Some(inbound_tx), None, Duration::ZERO, None); + ( + platform, + outbound_rx, + inbound_messages_tx, + status_rx, + inbound_rx, + ) + } + + fn new_with_session_write_failure(fail_encrypted_write_at: usize) -> TestPlatformParts { + Self::new_inner(None, Some(fail_encrypted_write_at), Duration::ZERO, None) + } + + fn new_with_delayed_writes(delay: Duration, write_stats: WriteStats) -> TestPlatformParts { + Self::new_inner(None, None, delay, Some(write_stats)) + } + + fn new_inner( + inbound: Option>, + fail_encrypted_write_at: Option, + write_delay: Duration, + write_stats: Option, + ) -> TestPlatformParts { + let (outbound, outbound_rx) = async_channel::unbounded(); + let (inbound_messages_tx, inbound_messages_rx) = async_channel::unbounded(); + let (status, status_rx) = async_channel::unbounded(); + ( + Self { + outbound, + _inbound_messages_tx: inbound_messages_tx.clone(), + inbound_messages: Some(inbound_messages_rx), + status, + inbound, + crypto: SoftwareCrypto, + encrypted_write_counter: AtomicUsize::new(0), + fail_encrypted_write_at, + write_delay, + write_stats, + }, + outbound_rx, + inbound_messages_tx, + status_rx, + ) + } +} + +struct TestSide { + handle: RuntimeHandle, + status: Receiver, + peer: QID, + inbound: Receiver, +} + +struct TestPair { + a: TestSide, + b: TestSide, +} + +#[derive(Debug, Clone, Copy, Default)] +struct LinkBehavior { + base_delay: Duration, + drop_encrypted_every: Option, + duplicate_encrypted_every: Option, + delay_encrypted_every: Option<(usize, Duration)>, +} + +#[derive(Clone, Default)] +struct LinkController { + behavior: Arc>, +} + +impl LinkController { + fn new(behavior: LinkBehavior) -> Self { + Self { + behavior: Arc::new(Mutex::new(behavior)), + } + } + + fn load(&self) -> LinkBehavior { + *self.behavior.lock().unwrap() + } + + fn store(&self, behavior: LinkBehavior) { + *self.behavior.lock().unwrap() = behavior; + } +} + +#[derive(Clone)] +struct ControlledLinks { + a_to_b: LinkController, + b_to_a: LinkController, +} + +impl TestPair { + fn new(config: RuntimeConfig) -> Self { + Self::new_with_links(config, LinkBehavior::default(), LinkBehavior::default()) + } + + fn new_with_links(config: RuntimeConfig, a_to_b: LinkBehavior, b_to_a: LinkBehavior) -> Self { + let (pair, _links) = Self::new_with_controlled_links(config, a_to_b, b_to_a); + pair + } + + fn new_with_controlled_links( + config: RuntimeConfig, + a_to_b: LinkBehavior, + b_to_a: LinkBehavior, + ) -> (Self, ControlledLinks) { + let (platform_a, outbound_a, inbound_a_tx, status_a, inbound_a) = + TestPlatform::new_with_inbound(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + let links = ControlledLinks { + a_to_b: LinkController::new(a_to_b), + b_to_a: LinkController::new(b_to_a), + }; + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_simulated_forwarder(outbound_a, inbound_b_tx, links.a_to_b.clone()); + spawn_simulated_forwarder(outbound_b, inbound_a_tx, links.b_to_a.clone()); + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + + ( + Self { + a: TestSide { + handle: handle_a, + status: status_a, + peer: identity_a.qid, + inbound: inbound_a, + }, + b: TestSide { + handle: handle_b, + status: status_b, + peer: identity_b.qid, + inbound: inbound_b, + }, + }, + links, + ) + } + + fn side(&self, side: Side) -> &TestSide { + match side { + Side::A => &self.a, + Side::B => &self.b, + } + } + + fn side_mut(&mut self, side: Side) -> &mut TestSide { + match side { + Side::A => &mut self.a, + Side::B => &mut self.b, + } + } + + async fn connect_and_wait(&self, initiator: Side) { + self.side(initiator).handle.connect(); + await_status( + &self.side(initiator).status, + Some(self.side(initiator.opposite()).peer), + PeerStatus::Connected, + ) + .await; + await_status( + &self.side(initiator.opposite()).status, + Some(self.side(initiator).peer), + PeerStatus::Connected, + ) + .await; + } + + fn take_inbound(&mut self, side: Side) -> Receiver { + let replacement = async_channel::unbounded().1; + std::mem::replace(&mut self.side_mut(side).inbound, replacement) + } +} + +struct TokioTimer { + sleep: Pin>, +} + +impl TokioTimer { + fn new() -> Self { + Self { + sleep: Box::pin(tokio::time::sleep_until(parked_deadline())), + } + } +} + +impl QlTimer for TokioTimer { + fn set_deadline(mut self: Pin<&mut Self>, deadline: Option) { + let deadline = deadline.map_or_else(parked_deadline, tokio::time::Instant::from_std); + self.as_mut().get_mut().sleep.as_mut().reset(deadline); + } + + fn poll_wait(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + self.as_mut().get_mut().sleep.as_mut().poll(cx) + } +} + +impl QlRandom for TestPlatform { + fn fill_random_bytes(&self, data: &mut [u8]) { + self.crypto.fill_random_bytes(data); + } +} + +impl QlHash for TestPlatform { + fn sha256(&self, parts: &[&[u8]]) -> [u8; 32] { + self.crypto.sha256(parts) + } +} + +impl QlAead for TestPlatform { + fn aes256_gcm_encrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + ) -> [u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE] { + self.crypto.aes256_gcm_encrypt(key, nonce, aad, buffer) + } + + fn aes256_gcm_decrypt( + &self, + key: &SessionKey, + nonce: &Nonce, + aad: &[u8], + buffer: &mut [u8], + auth_tag: &[u8; ql_wire::ENCRYPTED_MESSAGE_AUTH_SIZE], + ) -> bool { + self.crypto + .aes256_gcm_decrypt(key, nonce, aad, buffer, auth_tag) + } +} + +impl QlKem for TestPlatform { + fn mlkem_generate_keypair(&self) -> MlKemKeyPair { + self.crypto.mlkem_generate_keypair() + } + + fn mlkem_encapsulate(&self, public_key: &MlKemPublicKey) -> (MlKemCiphertext, SessionKey) { + self.crypto.mlkem_encapsulate(public_key) + } + + fn mlkem_decapsulate(&self, pk: &MlKemPrivateKey, cipher: &MlKemCiphertext) -> SessionKey { + self.crypto.mlkem_decapsulate(pk, cipher) + } +} + +impl crate::platform::QlPlatform for TestPlatform { + type Timer = TokioTimer; + type WriteMessageFut<'a> = Pin + Send + 'a>>; + type Inbound = TestInbound; + + fn write_message(&self, message: Vec) -> Self::WriteMessageFut<'_> { + let outbound = self.outbound.clone(); + let write_delay = self.write_delay; + let fail_encrypted_write_at = self.fail_encrypted_write_at; + let write_stats = self.write_stats.clone(); + + Box::pin(async move { + if let Some(stats) = write_stats.as_ref() { + let active = stats.active.fetch_add(1, Ordering::Relaxed) + 1; + stats.max_active.fetch_max(active, Ordering::Relaxed); + } + + if !write_delay.is_zero() { + tokio::time::sleep(write_delay).await; + } + + let should_fail = if is_encrypted_payload(&message) { + let count = self.encrypted_write_counter.fetch_add(1, Ordering::Relaxed) + 1; + fail_encrypted_write_at == Some(count) + } else { + false + }; + + let success = if should_fail { + false + } else { + outbound.send(message).await.is_ok() + }; + + if let Some(stats) = write_stats.as_ref() { + stats.active.fetch_sub(1, Ordering::Relaxed); + } + + success + }) + } + + fn inbound(&mut self) -> Self::Inbound { + TestInbound { + receiver: self + .inbound_messages + .take() + .expect("TestPlatform::inbound may only be called once"), + } + } + + fn timer(&self) -> Self::Timer { + TokioTimer::new() + } + + fn persist_peer(&self, _peer: PeerBundle) {} + + fn handle_peer_status(&self, peer: Option, status: PeerStatus) { + let _ = self.status.try_send(StatusEvent { peer, status }); + } + + fn handle_inbound(&self, event: QlStream) { + if let Some(tx) = &self.inbound { + let _ = tx.try_send(event); + } + } +} + +impl crate::platform::QlInbound for TestInbound { + fn poll_recv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match unsafe { self.as_mut().map_unchecked_mut(|this| &mut this.receiver) }.poll_next(cx) { + Poll::Ready(Some(bytes)) => Poll::Ready(bytes), + Poll::Ready(None) => panic!("TestInbound channel closed"), + Poll::Pending => Poll::Pending, + } + } +} + +fn parked_deadline() -> tokio::time::Instant { + tokio::time::Instant::now() + Duration::from_secs(60 * 60 * 24 * 365 * 100) +} + +fn is_encrypted_payload(bytes: &[u8]) -> bool { + RecordHeader::decode_bytes(bytes) + .ok() + .is_some_and(|header| header.record_type == RecordType::Session) +} + +fn pairing_token(byte: u8) -> PairingToken { + PairingToken([byte; PairingToken::SIZE]) +} + +fn register_peers( + handle_a: &RuntimeHandle, + handle_b: &RuntimeHandle, + id_a: &QlIdentity, + id_b: &QlIdentity, +) { + handle_a.bind_peer(id_b.bundle()); + handle_b.bind_peer(id_a.bundle()); +} + +fn spawn_forwarder(outbound: Receiver>, inbound: Sender>) { + spawn_simulated_forwarder( + outbound, + inbound, + LinkController::new(LinkBehavior::default()), + ); +} + +fn spawn_simulated_forwarder( + outbound: Receiver>, + inbound: Sender>, + controller: LinkController, +) { + tokio::task::spawn_local(async move { + let mut encrypted_count = 0usize; + while let Ok(bytes) = outbound.recv().await { + let behavior = controller.load(); + let encrypted = is_encrypted_payload(&bytes); + let ordinal = if encrypted { + encrypted_count = encrypted_count.saturating_add(1); + Some(encrypted_count) + } else { + None + }; + + if ordinal.is_some_and(|count| { + behavior + .drop_encrypted_every + .is_some_and(|nth| nth != 0 && count % nth == 0) + }) { + continue; + } + + let mut delay = behavior.base_delay; + if let Some(count) = ordinal { + if let Some((nth, extra_delay)) = behavior.delay_encrypted_every { + if nth != 0 && count % nth == 0 { + delay += extra_delay; + } + } + } + + let primary = bytes.clone(); + let primary_inbound = inbound.clone(); + tokio::task::spawn_local(async move { + if !delay.is_zero() { + tokio::time::sleep(delay).await; + } + let _ = primary_inbound.try_send(primary); + }); + + if ordinal.is_some_and(|count| { + behavior + .duplicate_encrypted_every + .is_some_and(|nth| nth != 0 && count % nth == 0) + }) { + let duplicate_inbound = inbound.clone(); + tokio::task::spawn_local(async move { + let duplicate_delay = delay + Duration::from_millis(1); + if !duplicate_delay.is_zero() { + tokio::time::sleep(duplicate_delay).await; + } + let _ = duplicate_inbound.try_send(bytes); + }); + } + } + }); +} + +fn spawn_drop_every_nth_encrypted_forwarder( + outbound: Receiver>, + inbound: Sender>, + nth: usize, +) { + tokio::task::spawn_local(async move { + let mut encrypted_count = 0usize; + while let Ok(bytes) = outbound.recv().await { + if nth > 0 && is_encrypted_payload(&bytes) { + encrypted_count = encrypted_count.saturating_add(1); + if encrypted_count % nth == 0 { + continue; + } + } + let _ = inbound.try_send(bytes); + } + }); +} + +fn spawn_gated_forwarder( + outbound: Receiver>, + inbound: Sender>, + drop_flag: Arc, +) { + tokio::task::spawn_local(async move { + while let Ok(bytes) = outbound.recv().await { + if drop_flag.load(Ordering::Relaxed) { + continue; + } + let _ = inbound.try_send(bytes); + } + }); +} + +#[allow(clippy::future_not_send)] +async fn run_local_test(future: F) +where + F: Future, +{ + run_local_test_timeout(Duration::from_secs(5), future).await; +} + +#[allow(clippy::future_not_send)] +async fn run_local_test_timeout(duration: Duration, future: F) +where + F: Future, +{ + init_test_logger(); + let local = LocalSet::new(); + let future = local.run_until(future); + tokio::time::timeout(duration, future) + .await + .unwrap_or_else(|_| panic!("local runtime test exceeded {duration:?}")); +} + +async fn await_status(receiver: &Receiver, peer: Option, stage: PeerStatus) { + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if let Ok(event) = receiver.recv().await { + if event.peer == peer && event.status == stage { + return; + } + } + } + }) + .await + .unwrap(); +} + +async fn assert_no_status_for( + receiver: &Receiver, + peer: Option, + status: PeerStatus, + window: Duration, +) { + let res = tokio::time::timeout(window, async { + loop { + let event = receiver.recv().await.unwrap(); + if event.peer == peer && event.status == status { + return; + } + } + }) + .await; + assert!(res.is_err(), "unexpected status event: {status:?}"); +} + +async fn read_all(mut stream: crate::StreamReader) -> Result, QlStreamError> { + let mut data = Vec::new(); + while let Some(chunk) = next_chunk(&mut stream).await? { + data.extend_from_slice(&chunk); + } + Ok(data) +} + +async fn next_chunk_max( + stream: &mut crate::StreamReader, + max_len: usize, +) -> Result>, crate::QlStreamError> { + stream + .read(max_len) + .await + .map(|chunk| chunk.map(|bytes| bytes.to_vec())) +} + +async fn next_chunk(stream: &mut crate::StreamReader) -> Result>, QlStreamError> { + next_chunk_max(stream, usize::MAX).await +} + +fn default_runtime_config() -> RuntimeConfig { + RuntimeConfig { + fsm: QlFsmConfig { + handshake_timeout: Duration::from_millis(300), + session_record_retransmit_timeout: Duration::from_millis(30), + session_keepalive_interval: Duration::ZERO, + session_peer_timeout: Duration::ZERO, + ..Default::default() + }, + ..Default::default() + } +} + +// runtime is send, if platform is send +#[test] +fn runtime_is_send() { + let config = default_runtime_config(); + let identity = generate_identity(&SoftwareCrypto, "runtime").unwrap(); + let (platform, _, _, _) = TestPlatform::new(); + let (runtime, _handle) = new_runtime(identity, platform, config); + let _run: Box + Send> = Box::new(runtime.run()); +} + +#[test] +fn runtime_exits_when_last_handle_drops() { + let config = default_runtime_config(); + let identity = generate_identity(&SoftwareCrypto, "runtime").unwrap(); + let (platform, _, _, _) = TestPlatform::new(); + let (runtime, handle) = new_runtime(identity, platform, config); + let (done_tx, done_rx) = oneshot::channel(); + + std::thread::spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap() + .block_on(runtime.run()); + done_tx.send(()).unwrap(); + }); + + drop(handle); + + done_rx + .recv_timeout(Duration::from_secs(1)) + .expect("runtime should stop once the last sender is dropped"); +} diff --git a/ql-runtime/src/tests/rpc.rs b/ql-runtime/src/tests/rpc.rs new file mode 100644 index 0000000..3244c58 --- /dev/null +++ b/ql-runtime/src/tests/rpc.rs @@ -0,0 +1,677 @@ +use std::{ + cell::RefCell, + future::Future, + rc::Rc, + str::Utf8Error, + sync::{Arc, Mutex}, + time::Duration, +}; + +use bytes::Bytes; +use futures_lite::StreamExt; +use ql_rpc::{ + DownloadHandlerLocal, DownloadStart, DuplexHandlerLocal, DuplexPeer, LocalSpawner, + NotificationHandlerLocal, ProgressHandlerLocal, ProgressResponder, RequestHandler, + RequestHandlerLocal, Response, RouteId, SendSpawner, Spawner, StreamCloseCode, + SubscriptionHandlerLocal, SubscriptionResponder, UploadHandlerLocal, UploadReader, + UploadResponder, +}; + +use super::*; +use crate::{rpc::RpcError, QlStream, StreamWriter}; + +#[derive(Debug, Clone, Copy)] +struct TokioLocalSpawner; + +impl Spawner for TokioLocalSpawner { + type Handle = tokio::task::JoinHandle<()>; +} + +impl LocalSpawner for TokioLocalSpawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + 'static, + { + tokio::task::spawn_local(fut) + } +} + +#[derive(Debug, Clone, Copy)] +struct TokioSendSpawner; + +impl Spawner for TokioSendSpawner { + type Handle = tokio::task::JoinHandle<()>; +} + +impl SendSpawner for TokioSendSpawner { + fn spawn(&self, fut: F) -> Self::Handle + where + F: Future + Send + 'static, + { + tokio::task::spawn(fut) + } +} + +struct Echo; + +impl ql_rpc::Route for Echo { + const ROUTE: RouteId = RouteId::from_u32(51); +} + +impl ql_rpc::request::Request for Echo { + type Error = Utf8Error; + + type Request = String; + type Response = String; +} + +struct Feed; + +impl ql_rpc::Route for Feed { + const ROUTE: RouteId = RouteId::from_u32(52); +} + +impl ql_rpc::subscription::Subscription for Feed { + type Error = core::convert::Infallible; + type Request = Vec; + type Event = Vec; +} + +struct Notice; + +impl ql_rpc::Route for Notice { + const ROUTE: RouteId = RouteId::from_u32(521); +} + +impl ql_rpc::notification::Notification for Notice { + type Error = core::convert::Infallible; + type Payload = Vec; +} + +struct Download; + +impl ql_rpc::Route for Download { + const ROUTE: RouteId = RouteId::from_u32(53); +} + +impl ql_rpc::progress::Progress for Download { + type Error = core::convert::Infallible; + type Request = Vec; + type Progress = Vec; + type Response = Vec; +} + +struct BlobDownload; + +impl ql_rpc::Route for BlobDownload { + const ROUTE: RouteId = RouteId::from_u32(54); +} + +impl ql_rpc::download::Download for BlobDownload { + type Error = core::convert::Infallible; + type Request = Vec; + type ResponseHeader = Vec; + type PartHeader = Vec; +} + +struct BlobUpload; + +impl ql_rpc::Route for BlobUpload { + const ROUTE: RouteId = RouteId::from_u32(55); +} + +impl ql_rpc::upload::Upload for BlobUpload { + type Error = core::convert::Infallible; + type Request = Vec; + type PartHeader = Vec; + type Response = Vec; +} + +struct Chat; + +impl ql_rpc::Route for Chat { + const ROUTE: RouteId = RouteId::from_u32(56); +} + +impl ql_rpc::duplex::Duplex for Chat { + type Error = core::convert::Infallible; + type InitiatorEvent = Vec; + type ResponderEvent = Vec; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_request() { + #[derive(Clone)] + struct RouterState { + seen: Arc>>, + } + + impl RequestHandler for RouterState { + async fn handle(self, request: String, response: Response) { + let seen = self.seen.clone(); + seen.lock().unwrap().push(request); + let _ = response.respond("world".into()).await; + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Arc::new(Mutex::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioSendSpawner>::builder_send(TokioSendSpawner) + .request::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + let fut = assert_send(fut); + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let response = rpc.request::(&"hello".into()).await.unwrap(); + assert_eq!(response, "world"); + assert_eq!(&*seen.lock().unwrap(), &["hello".to_string()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +fn assert_send(value: T) -> T { + value +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_notification() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl NotificationHandlerLocal for RouterState { + async fn handle(self, payload: Vec) { + self.seen.borrow_mut().push(payload); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .notification::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + rpc.notification::(&b"hello".to_vec()) + .await + .unwrap(); + assert_eq!(seen.borrow().as_slice(), &[b"hello".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_subscrption() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl SubscriptionHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + mut response: SubscriptionResponder, StreamWriter>, + ) { + let seen = self.seen.clone(); + seen.borrow_mut().push(request); + let _ = response.send(b"one".to_vec()).await; + let _ = response.send(b"two".to_vec()).await; + let _ = response.finish().await; + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let seen = Rc::new(RefCell::new(Vec::new())); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .subscription::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut subscription = rpc.subscribe::(&b"watch".to_vec()).await.unwrap(); + assert_eq!(subscription.next().await.unwrap().unwrap(), b"one".to_vec()); + assert_eq!(subscription.next().await.unwrap().unwrap(), b"two".to_vec()); + assert!(subscription.next().await.is_none()); + assert_eq!(seen.borrow().as_slice(), &[b"watch".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_router_enforces_max_request_bytes() { + #[derive(Clone)] + struct LimitedState; + + impl RequestHandlerLocal for LimitedState { + async fn handle(self, request: String, response: Response) { + let _ = response.respond(request).await; + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .max_request_bytes(4) + .request::() + .build(LimitedState); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let response = rpc.request::(&"hello".to_string()).await; + assert!(matches!( + response, + Err(RpcError::Closed(code)) if code == StreamCloseCode::LIMIT + )); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_progress() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl ProgressHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + mut responder: ProgressResponder, + ) { + let seen = self.seen.clone(); + seen.borrow_mut().push(request); + responder.send(b"10".to_vec()).await.unwrap(); + responder.send(b"90".to_vec()).await.unwrap(); + responder.finish(b"done".to_vec()).await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .progress::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut download = rpc.progress::(&b"logo".to_vec()).await.unwrap(); + + assert_eq!(download.next().await, Some(b"10".to_vec())); + assert_eq!(download.next().await, Some(b"90".to_vec())); + assert_eq!(download.next().await, None); + assert_eq!(download.await.unwrap(), b"done".to_vec()); + assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_download() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl DownloadHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + download: DownloadStart, + ) { + let seen = self.seen.clone(); + seen.borrow_mut().push(request); + let mut writer = download.start(b"image/png".to_vec()).await.unwrap(); + let mut part = writer.start_part(b"icon".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"abc")).await.unwrap(); + part.send(Bytes::from_static(b"def")).await.unwrap(); + part.finish().await.unwrap(); + let mut part = writer.start_part(b"manifest".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"{}")).await.unwrap(); + part.finish().await.unwrap(); + writer.finish().await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .download::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let download = rpc + .download::(&b"logo".to_vec()) + .await + .unwrap(); + let (header, mut reader) = download.start().await.unwrap(); + assert_eq!(header, b"image/png".to_vec()); + { + let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); + assert_eq!(part_header, b"icon".to_vec()); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"abc")) + ); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"def")) + ); + assert_eq!(part.read_chunk().await.unwrap(), None); + } + { + let (part_header, mut part) = reader.next_part().await.unwrap().unwrap(); + assert_eq!(part_header, b"manifest".to_vec()); + assert_eq!( + part.read_chunk().await.unwrap(), + Some(Bytes::from_static(b"{}")) + ); + assert_eq!(part.read_chunk().await.unwrap(), None); + } + assert!(reader.next_part().await.unwrap().is_none()); + assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_download_complete() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl DownloadHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + download: DownloadStart, + ) { + self.seen.borrow_mut().push(request); + download.complete(b"not found".to_vec()).await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .download::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let download = rpc + .download::(&b"logo".to_vec()) + .await + .unwrap(); + let (header, reader) = download.start().await.unwrap(); + assert_eq!(header, b"not found".to_vec()); + reader.complete().await.unwrap(); + assert_eq!(seen.borrow().as_slice(), &[b"logo".to_vec()]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_upload() { + #[derive(Clone)] + struct RouterState { + requests: Rc>>>, + uploads: Rc>>>, + } + + impl UploadHandlerLocal for RouterState { + async fn handle( + self, + request: Vec, + mut upload: UploadReader, + responder: UploadResponder, StreamWriter>, + ) { + let requests = self.requests.clone(); + let uploads = self.uploads.clone(); + requests.borrow_mut().push(request); + + let mut body = Vec::new(); + while let Some((part_header, mut part)) = upload.next_part().await.unwrap() { + body.extend_from_slice(&part_header); + body.push(b':'); + while let Some(chunk) = part.read_chunk().await.unwrap() { + body.extend_from_slice(&chunk); + } + body.push(b';'); + } + uploads.borrow_mut().push(body.clone()); + + responder.respond(body).await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let requests = Rc::new(RefCell::new(Vec::new())); + let uploads = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .upload::() + .build(RouterState { + requests: requests.clone(), + uploads: uploads.clone(), + }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut upload = rpc.upload::(&b"logo".to_vec()).await.unwrap(); + let mut part = upload.start_part(b"icon".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"abc")).await.unwrap(); + part.send(Bytes::from_static(b"def")).await.unwrap(); + part.finish().await.unwrap(); + let mut part = upload.start_part(b"manifest".to_vec()).await.unwrap(); + part.send(Bytes::from_static(b"{}")).await.unwrap(); + part.finish().await.unwrap(); + let response = upload.finish().await.unwrap(); + + assert_eq!(response, b"icon:abcdef;manifest:{};".to_vec()); + assert_eq!(requests.borrow().as_slice(), &[b"logo".to_vec()]); + assert_eq!( + uploads.borrow().as_slice(), + &[b"icon:abcdef;manifest:{};".to_vec()] + ); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn rpc_duplex() { + #[derive(Clone)] + struct RouterState { + seen: Rc>>>, + } + + impl DuplexHandlerLocal for RouterState { + async fn handle(self, mut peer: DuplexPeer) { + let seen = self.seen.clone(); + let first = peer.receiver.next_event().await.unwrap().unwrap(); + seen.borrow_mut().push(first); + + peer.sender + .send(&b"challenge-response".to_vec()) + .await + .unwrap(); + + let second = peer.receiver.next_event().await.unwrap().unwrap(); + seen.borrow_mut().push(second); + + peer.sender.finish().await.unwrap(); + } + } + + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let seen = Rc::new(RefCell::new(Vec::new())); + + let router = + ql_rpc::Router::<_, QlStream, TokioLocalSpawner>::builder_local(TokioLocalSpawner) + .duplex::() + .build(RouterState { seen: seen.clone() }); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + if let Some((_, fut)) = router.handle(inbound) { + fut.await.unwrap(); + } + }); + + let rpc = pair.side_mut(Side::A).handle.rpc(); + let mut chat = rpc.duplex::().await.unwrap(); + chat.sender.send(&b"challenge".to_vec()).await.unwrap(); + assert_eq!( + chat.receiver.next_event().await.unwrap().unwrap(), + b"challenge-response".to_vec() + ); + chat.sender.send(&b"verification".to_vec()).await.unwrap(); + chat.sender.finish().await.unwrap(); + assert!(chat.receiver.next_event().await.is_none()); + + assert_eq!( + seen.borrow().as_slice(), + &[b"challenge".to_vec(), b"verification".to_vec()] + ); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} diff --git a/ql-runtime/src/tests/session.rs b/ql-runtime/src/tests/session.rs new file mode 100644 index 0000000..ec351e3 --- /dev/null +++ b/ql-runtime/src/tests/session.rs @@ -0,0 +1,213 @@ +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; + +use bytes::Bytes; +use ql_wire::SessionCloseCode; + +use super::*; +use crate::QlStreamError; + +#[tokio::test(flavor = "current_thread")] +async fn close_session_aborts_active_streams_and_allows_reconnect() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + let inbound_b = pair.take_inbound(Side::B); + let (received_tx, received_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut reader = stream.reader; + + assert_eq!( + next_chunk(&mut reader).await.unwrap(), + Some(vec![1, 2, 3, 4]) + ); + received_tx.send(()).await.unwrap(); + + let err = next_chunk(&mut reader).await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[1, 2, 3, 4])) + .await + .unwrap(); + received_rx.recv().await.unwrap(); + + pair.side(Side::A) + .handle + .close_session(SessionCloseCode::CANCELLED); + + let err = stream.writer.finish().await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + + await_status( + &pair.side(Side::A).status, + Some(pair.side(Side::B).peer), + PeerStatus::Disconnected, + ) + .await; + await_status( + &pair.side(Side::B).status, + Some(pair.side(Side::A).peer), + PeerStatus::Disconnected, + ) + .await; + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + + pair.connect_and_wait(Side::A).await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn unpair_aborts_active_streams_and_prevents_reconnect() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + let inbound_b = pair.take_inbound(Side::B); + let (received_tx, received_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut reader = stream.reader; + + assert_eq!( + next_chunk(&mut reader).await.unwrap(), + Some(vec![5, 6, 7, 8]) + ); + received_tx.send(()).await.unwrap(); + + let err = next_chunk(&mut reader).await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[5, 6, 7, 8])) + .await + .unwrap(); + received_rx.recv().await.unwrap(); + + pair.side(Side::A).handle.unpair(); + + let err = stream.writer.finish().await.unwrap_err(); + assert_eq!(err, QlStreamError::NoSession); + + await_status(&pair.side(Side::A).status, None, PeerStatus::Unpaired).await; + await_status(&pair.side(Side::B).status, None, PeerStatus::Unpaired).await; + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + + assert!(matches!( + pair.side(Side::A).handle.open_stream(test_route_id()).await, + Err(NoSessionError) + )); + assert!(matches!( + pair.side(Side::B).handle.open_stream(test_route_id()).await, + Err(NoSessionError) + )); + + pair.side(Side::B).handle.connect(); + assert_no_status_for( + &pair.side(Side::B).status, + None, + PeerStatus::Initiator, + Duration::from_millis(150), + ) + .await; + assert_no_status_for( + &pair.side(Side::B).status, + None, + PeerStatus::Connected, + Duration::from_millis(150), + ) + .await; + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn session_timeout_disconnects_and_fails_pending_open() { + run_local_test(async { + let config_a = RuntimeConfig { + fsm: QlFsmConfig { + session_keepalive_interval: Duration::from_millis(40), + session_peer_timeout: Duration::from_millis(60), + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let config_b = default_runtime_config(); + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config_a); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config_b); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + let drop_flag = Arc::new(AtomicBool::new(false)); + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_gated_forwarder(outbound_b, inbound_a_tx, drop_flag.clone()); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + + let responder_task = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let _ = read_all(stream.reader).await; + let err = stream.writer.finish().await.unwrap_err(); + assert!(matches!(err, QlStreamError::NoSession)); + }); + + drop_flag.store(true, Ordering::Relaxed); + + let mut pending = handle_a.open_stream(test_route_id()).await.unwrap(); + let err = pending.writer.finish().await.unwrap_err(); + assert!(matches!(err, QlStreamError::NoSession)); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Disconnected).await; + + let result = + tokio::time::timeout(Duration::from_millis(300), next_chunk(&mut pending.reader)) + .await + .unwrap(); + assert!(matches!(result, Err(QlStreamError::NoSession))); + + responder_task.abort(); + }) + .await; +} diff --git a/ql-runtime/src/tests/stream.rs b/ql-runtime/src/tests/stream.rs new file mode 100644 index 0000000..176711c --- /dev/null +++ b/ql-runtime/src/tests/stream.rs @@ -0,0 +1,673 @@ +use std::time::Duration; + +use bytes::Bytes; +use ql_wire::StreamCloseCode; + +use super::*; +use crate::QlStreamError; + +#[tokio::test(flavor = "current_thread")] +async fn open_stream_duplex_happy_path() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + + let mut writer = inbound.writer; + let mut reader = inbound.reader; + + assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![1, 2])); + writer.write(Bytes::from_static(&[9])).await.unwrap(); + assert_eq!(next_chunk(&mut reader).await.unwrap(), Some(vec![3, 4])); + writer.write(Bytes::from_static(&[8, 7])).await.unwrap(); + assert_eq!(next_chunk(&mut reader).await.unwrap(), None); + writer.finish().await.unwrap(); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[1, 2])) + .await + .unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), Some(vec![9])); + stream + .writer + .write(Bytes::from_static(&[3, 4])) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!( + next_chunk(&mut stream.reader).await.unwrap(), + Some(vec![8, 7]) + ); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn reader_respects_max_len() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + let mut reader = inbound.reader; + + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![1, 2]) + ); + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![3, 4]) + ); + assert_eq!( + next_chunk_max(&mut reader, 2).await.unwrap(), + Some(vec![5, 6]) + ); + assert_eq!(next_chunk(&mut reader).await.unwrap(), None); + + inbound.writer.finish().await.unwrap(); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(&[1, 2, 3, 4, 5, 6])) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn large_stream_payload_round_trips() { + run_local_test(async { + let payload: Vec = (0..40).collect(); + let mut pair = TestPair::new(default_runtime_config()); + let (done_tx, done_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let request_data = read_all(stream.reader).await.unwrap(); + stream.writer.finish().await.unwrap(); + done_tx.send(request_data).await.unwrap(); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from(payload.clone())) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + + let received = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(received, payload); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_responder_closes_initiator_response() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + drop(stream.reader); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + let err = stream.writer.finish().await.unwrap_err(); + assert!(matches!( + err, + QlStreamError::StreamClosed { code } if code == StreamCloseCode::CANCELLED + )); + + let err = next_chunk(&mut stream.reader).await.unwrap_err(); + assert!(matches!( + err, + QlStreamError::StreamClosed { code } if code == StreamCloseCode::CANCELLED + )); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn dropping_inbound_reader_cancels_remote_writer() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + let inbound_b = pair.take_inbound(Side::B); + let (go_tx, go_rx) = async_channel::bounded(1); + pair.connect_and_wait(Side::A).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut writer = stream.writer; + let mut reader = stream.reader; + assert_eq!(next_chunk(&mut reader).await.unwrap(), None); + writer + .write(Bytes::from_static(&[1, 2, 3, 4])) + .await + .unwrap(); + go_rx.recv().await.unwrap(); + let _ = writer.write(Bytes::from(vec![5; 64])).await; + let err = writer.finish().await.unwrap_err(); + assert!(matches!( + err, + QlStreamError::StreamClosed { code } if code == StreamCloseCode::CANCELLED + )); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!( + next_chunk(&mut stream.reader).await.unwrap(), + Some(vec![1, 2, 3, 4]) + ); + drop(stream.reader); + go_tx.send(()).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn closing_initiator_reader_preserves_initiator_writer() { + run_local_test(async { + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + let (done_tx, done_rx) = async_channel::bounded(1); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let request = read_all(stream.reader).await.unwrap(); + done_tx.send(request).await.unwrap(); + }); + + let stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + let mut writer = stream.writer; + stream.reader.close(StreamCloseCode::CANCELLED); + + writer.write(Bytes::from_static(&[1, 2])).await.unwrap(); + writer.write(Bytes::from_static(&[3, 4])).await.unwrap(); + writer.finish().await.unwrap(); + + let request = tokio::time::timeout(Duration::from_secs(2), done_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(request, vec![1, 2, 3, 4]); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn max_concurrent_message_writes_is_respected() { + run_local_test(async { + let stats = WriteStats::new(); + let config = RuntimeConfig { + max_concurrent_message_writes: 2, + ..default_runtime_config() + }; + let (platform_a, outbound_a, inbound_a_tx, status_a) = + TestPlatform::new_with_delayed_writes(Duration::from_millis(40), stats.clone()); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_forwarder(outbound_a, inbound_b_tx); + spawn_forwarder(outbound_b, inbound_a_tx); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + + let responder = tokio::task::spawn_local(async move { + for _ in 0..4 { + let stream = inbound_b.recv().await.unwrap(); + let _ = read_all(stream.reader).await; + let mut writer = stream.writer; + writer.queue_finish(); + } + }); + + let mut tasks = Vec::new(); + for i in 0..4u8 { + let handle = handle_a.clone(); + tasks.push(tokio::task::spawn_local(async move { + let mut stream = handle.open_stream(test_route_id()).await.unwrap(); + stream.writer.write(Bytes::from(vec![i; 8])).await.unwrap(); + stream.writer.finish().await.unwrap(); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + })); + } + + for task in tasks { + tokio::time::timeout(Duration::from_secs(4), task) + .await + .unwrap() + .unwrap(); + } + + tokio::time::timeout(Duration::from_secs(4), responder) + .await + .unwrap() + .unwrap(); + + assert!( + stats.max_active() <= 2, + "max active writes exceeded: {}", + stats.max_active() + ); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn stream_round_trip_survives_encrypted_packet_drops() { + run_local_test(async { + let config = RuntimeConfig { + fsm: QlFsmConfig { + session_record_retransmit_timeout: Duration::from_millis(20), + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let (platform_a, outbound_a, inbound_a_tx, status_a) = TestPlatform::new(); + let (platform_b, outbound_b, inbound_b_tx, status_b, inbound_b) = + TestPlatform::new_with_inbound(); + let (identity_a, identity_b) = test_identities(&SoftwareCrypto); + + let request_payload: Vec = (0..32).collect(); + let response_payload: Vec = (100..132).collect(); + let expected_response = response_payload.clone(); + + let (runtime_a, handle_a) = new_runtime(identity_a.clone(), platform_a, config); + let (runtime_b, handle_b) = new_runtime(identity_b.clone(), platform_b, config); + + tokio::task::spawn_local(async move { runtime_a.run().await }); + tokio::task::spawn_local(async move { runtime_b.run().await }); + + spawn_drop_every_nth_encrypted_forwarder(outbound_a, inbound_b_tx, 3); + spawn_drop_every_nth_encrypted_forwarder(outbound_b, inbound_a_tx, 3); + + register_peers(&handle_a, &handle_b, &identity_a, &identity_b); + handle_a.connect(); + + await_status(&status_a, Some(identity_b.qid), PeerStatus::Connected).await; + await_status(&status_b, Some(identity_a.qid), PeerStatus::Connected).await; + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let received_request = read_all(stream.reader).await.unwrap(); + let mut writer = stream.writer; + writer + .write(Bytes::from(response_payload.clone())) + .await + .unwrap(); + writer.finish().await.unwrap(); + received_request + }); + + let mut stream = handle_a.open_stream(test_route_id()).await.unwrap(); + stream + .writer + .write(Bytes::from(request_payload.clone())) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + + let mut received_response = Vec::new(); + while let Some(chunk) = next_chunk(&mut stream.reader).await.unwrap() { + received_response.extend_from_slice(&chunk); + } + assert_eq!(received_response, expected_response); + + let received_request = tokio::time::timeout(Duration::from_secs(4), responder) + .await + .unwrap() + .unwrap(); + assert_eq!(received_request, request_payload); + }) + .await; +} + +#[allow(clippy::too_many_lines)] +#[tokio::test(flavor = "current_thread")] +async fn multi_megabyte_stream_survives_asymmetric_loss_and_delay() { + run_local_test_timeout(Duration::from_secs(10), async { + let payload_len = 2 * 1024 * 1024; + let chunk_len = 16 * 1024; + let payload: Vec = (0..payload_len) + .map(|i| u8::try_from(i % 251).unwrap()) + .collect(); + let expected = payload.clone(); + let config = RuntimeConfig { + fsm: QlFsmConfig { + session_record_max_size: 16 * 1024, + session_record_ack_delay: Duration::from_millis(2), + session_record_retransmit_timeout: Duration::from_millis(25), + session_stream_send_buffer_size: 4 * 1024 * 1024, + session_stream_receive_buffer_size: 4 * 1024 * 1024, + session_accepted_record_window: 16 * 1024, + session_pending_ack_range_limit: 4 * 1024, + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let (mut pair, links) = TestPair::new_with_controlled_links( + config, + LinkBehavior { + base_delay: Duration::from_millis(1), + drop_encrypted_every: Some(41), + delay_encrypted_every: Some((13, Duration::from_millis(12))), + ..LinkBehavior::default() + }, + LinkBehavior { + base_delay: Duration::from_millis(1), + ..LinkBehavior::default() + }, + ); + pair.connect_and_wait(Side::A).await; + links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(3), + drop_encrypted_every: Some(7), + duplicate_encrypted_every: Some(19), + delay_encrypted_every: Some((3, Duration::from_millis(25))), + }); + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + eprintln!("responder accepted inbound stream"); + let mut reader = stream.reader; + let mut received = Vec::new(); + while let Some(chunk) = next_chunk(&mut reader).await.unwrap() { + if received.len() >= 36 * chunk_len { + eprintln!("responder received chunk of {} bytes", chunk.len()); + } + received.extend_from_slice(&chunk); + if received.len() % (256 * 1024) == 0 { + eprintln!("responder received {} bytes", received.len()); + } + } + stream.writer.finish().await.unwrap(); + received + }); + + let recovery_links = links.clone(); + let recovery = tokio::task::spawn_local(async move { + tokio::time::sleep(Duration::from_millis(300)).await; + eprintln!("restoring reverse path"); + recovery_links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(1), + delay_encrypted_every: Some((17, Duration::from_millis(8))), + ..LinkBehavior::default() + }); + }); + + let writer = tokio::task::spawn_local(async move { + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + for (index, chunk) in payload.chunks(chunk_len).enumerate() { + if index + 1 >= 40 { + eprintln!("writer attempting chunk {}", index + 1); + } + stream + .writer + .write(Bytes::copy_from_slice(chunk)) + .await + .unwrap(); + if index + 1 >= 40 { + eprintln!("writer queued chunk {}", index + 1); + } + if index % 16 == 15 { + eprintln!("writer queued {} chunks", index + 1); + } + } + eprintln!("writer finished queueing"); + stream.writer.finish().await.unwrap(); + eprintln!("writer waiting for eof"); + assert_eq!(next_chunk(&mut stream.reader).await.unwrap(), None); + eprintln!("writer observed eof"); + }); + + tokio::time::timeout(Duration::from_secs(30), writer) + .await + .unwrap() + .unwrap(); + tokio::time::timeout(Duration::from_secs(2), recovery) + .await + .unwrap() + .unwrap(); + let received = tokio::time::timeout(Duration::from_secs(30), responder) + .await + .unwrap() + .unwrap(); + assert_eq!(received, expected); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn reproducer_writer_stalls_after_reverse_path_impairment() { + run_local_test_timeout(Duration::from_secs(10), async { + let payload_len = 2 * 1024 * 1024; + let chunk_len = 16 * 1024; + let payload: Vec = (0..payload_len) + .map(|i| u8::try_from(i % 251).unwrap()) + .collect(); + let config = RuntimeConfig { + fsm: QlFsmConfig { + session_record_max_size: 16 * 1024, + session_record_ack_delay: Duration::from_millis(2), + session_record_retransmit_timeout: Duration::from_millis(25), + session_stream_send_buffer_size: 4 * 1024 * 1024, + session_stream_receive_buffer_size: 4 * 1024 * 1024, + session_accepted_record_window: 16 * 1024, + session_pending_ack_range_limit: 4 * 1024, + ..default_runtime_config().fsm + }, + ..default_runtime_config() + }; + let (mut pair, links) = TestPair::new_with_controlled_links( + config, + LinkBehavior { + base_delay: Duration::from_millis(1), + drop_encrypted_every: Some(41), + delay_encrypted_every: Some((13, Duration::from_millis(12))), + ..LinkBehavior::default() + }, + LinkBehavior { + base_delay: Duration::from_millis(1), + ..LinkBehavior::default() + }, + ); + pair.connect_and_wait(Side::A).await; + links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(3), + drop_encrypted_every: Some(7), + duplicate_encrypted_every: Some(19), + delay_encrypted_every: Some((3, Duration::from_millis(25))), + }); + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let stream = inbound_b.recv().await.unwrap(); + let mut reader = stream.reader; + while next_chunk(&mut reader).await.unwrap().is_some() {} + }); + + let recovery_links = links.clone(); + let recovery = tokio::task::spawn_local(async move { + tokio::time::sleep(Duration::from_millis(300)).await; + recovery_links.b_to_a.store(LinkBehavior { + base_delay: Duration::from_millis(1), + delay_encrypted_every: Some((17, Duration::from_millis(8))), + ..LinkBehavior::default() + }); + }); + + let writer = tokio::task::spawn_local(async move { + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + for chunk in payload.chunks(chunk_len) { + stream + .writer + .write(Bytes::copy_from_slice(chunk)) + .await + .unwrap(); + } + stream.writer.queue_finish(); + let _ = next_chunk(&mut stream.reader).await; + }); + + let _ = tokio::time::timeout(Duration::from_secs(15), writer).await; + recovery.abort(); + responder.abort(); + }) + .await; +} + +#[tokio::test(flavor = "current_thread")] +async fn responder_drains_multiple_local_chunks_per_writable_wake() { + run_local_test(async { + let chunk_len = 4104usize; + let chunk_count = 5usize; + let expected = vec![0x5a; chunk_len * chunk_count]; + let mut pair = TestPair::new(default_runtime_config()); + pair.connect_and_wait(Side::A).await; + let inbound_b = pair.take_inbound(Side::B); + + let responder = tokio::task::spawn_local(async move { + let inbound = inbound_b.recv().await.unwrap(); + let _ = read_all(inbound.reader).await.unwrap(); + + let mut writer = inbound.writer; + for _ in 0..chunk_count { + writer + .write(Bytes::from(vec![0x5a; chunk_len])) + .await + .unwrap(); + } + writer.finish().await.unwrap(); + }); + + let mut stream = pair + .side(Side::A) + .handle + .open_stream(test_route_id()) + .await + .unwrap(); + stream + .writer + .write(Bytes::from_static(b"request")) + .await + .unwrap(); + stream.writer.finish().await.unwrap(); + + let received = read_all(stream.reader).await.unwrap(); + assert_eq!(received, expected); + + tokio::time::timeout(Duration::from_secs(2), responder) + .await + .unwrap() + .unwrap(); + }) + .await; +} From 6c0bb2f7e89d181adc5b911a8d213e710897fbc4 Mon Sep 17 00:00:00 2001 From: Nico Burniske Date: Thu, 4 Jun 2026 09:38:37 -0400 Subject: [PATCH 6/6] workspace: remove legacy QLV1 crates --- Cargo.lock | 1845 +---------------- Cargo.toml | 17 - Justfile | 12 - README.md | 11 +- api/.gitignore | 2 - api/Cargo.toml | 31 - api/src/api/backup.rs | 264 --- api/src/api/bitcoin.rs | 32 - api/src/api/firmware.rs | 116 -- api/src/api/fx.rs | 34 - api/src/api/message.rs | 149 -- api/src/api/mod.rs | 13 - api/src/api/onboarding.rs | 47 - api/src/api/pairing.rs | 39 - api/src/api/passport.rs | 25 - api/src/api/quantum_link.rs | 431 ---- api/src/api/scv.rs | 52 - api/src/api/status.rs | 35 - api/src/api/tests.rs | 414 ---- api/src/lib.rs | 11 - api/tests/golden_tests.rs | 550 ----- .../golden_tests__golden_account_update.snap | 5 - ...n_tests__golden_apply_passphrase_none.snap | 5 - ...n_tests__golden_apply_passphrase_some.snap | 5 - ...en_tests__golden_backup_shard_request.snap | 5 - ...s__golden_backup_shard_response_error.snap | 5 - ..._golden_backup_shard_response_success.snap | 5 - ...n_tests__golden_broadcast_transaction.snap | 5 - ...olden_create_magic_backup_event_chunk.snap | 5 - ...olden_create_magic_backup_event_start.snap | 5 - ...lden_create_magic_backup_result_error.snap | 5 - ...en_create_magic_backup_result_success.snap | 5 - .../golden_tests__golden_device_status.snap | 5 - ..._tests__golden_device_status_updating.snap | 5 - ...en_envoy_magic_backup_enabled_request.snap | 5 - ...n_envoy_magic_backup_enabled_response.snap | 5 - .../golden_tests__golden_envoy_status.snap | 5 - .../golden_tests__golden_exchange_rate.snap | 5 - ...n_tests__golden_exchange_rate_history.snap | 5 - ...ts__golden_firmware_fetch_event_chunk.snap | 5 - ...lden_firmware_fetch_event_downloading.snap | 5 - ...ts__golden_firmware_fetch_event_error.snap | 5 - ...en_firmware_fetch_event_not_available.snap | 5 - ..._golden_firmware_fetch_event_starting.snap | 5 - ..._tests__golden_firmware_fetch_request.snap | 5 - ..._golden_firmware_update_check_request.snap | 5 - ...mware_update_check_response_available.snap | 5 - ...e_update_check_response_not_available.snap | 5 - ...__golden_firmware_update_result_error.snap | 5 - ..._firmware_update_result_error_install.snap | 5 - ...n_firmware_update_result_error_verify.snap | 5 - ...den_firmware_update_result_installing.snap | 5 - ...lden_firmware_update_result_rebooting.snap | 5 - ...golden_firmware_update_result_success.snap | 5 - ...irmware_update_result_update_verified.snap | 5 - .../golden_tests__golden_heartbeat.snap | 5 - ...ts__golden_onboarding_state_completed.snap | 5 - ...boarding_state_firmware_update_screen.snap | 5 - .../golden_tests__golden_pairing_request.snap | 5 - ...golden_tests__golden_pairing_response.snap | 6 - ...ts__golden_prime_magic_backup_enabled.snap | 5 - ...den_prime_magic_backup_status_request.snap | 5 - ...en_prime_magic_backup_status_response.snap | 5 - .../golden_tests__golden_raw_data.snap | 5 - ...lden_restore_magic_backup_event_chunk.snap | 5 - ...lden_restore_magic_backup_event_error.snap | 5 - ..._restore_magic_backup_event_no_backup.snap | 5 - ...n_restore_magic_backup_event_starting.snap | 5 - ...__golden_restore_magic_backup_request.snap | 5 - ...den_restore_magic_backup_result_error.snap | 5 - ...n_restore_magic_backup_result_success.snap | 5 - ...n_tests__golden_restore_shard_request.snap | 5 - ...__golden_restore_shard_response_error.snap | 5 - ...lden_restore_shard_response_not_found.snap | 5 - ...golden_restore_shard_response_success.snap | 5 - ...lden_security_check_challenge_request.snap | 5 - ...curity_check_challenge_response_error.snap | 5 - ...rity_check_challenge_response_success.snap | 5 - ...den_security_check_verification_error.snap | 5 - ...n_security_check_verification_success.snap | 5 - .../golden_tests__golden_sign_psbt.snap | 5 - quantum-link-macros/Cargo.toml | 14 - quantum-link-macros/src/lib.rs | 632 ------ 83 files changed, 45 insertions(+), 5032 deletions(-) delete mode 100644 api/.gitignore delete mode 100644 api/Cargo.toml delete mode 100644 api/src/api/backup.rs delete mode 100644 api/src/api/bitcoin.rs delete mode 100644 api/src/api/firmware.rs delete mode 100644 api/src/api/fx.rs delete mode 100644 api/src/api/message.rs delete mode 100644 api/src/api/mod.rs delete mode 100644 api/src/api/onboarding.rs delete mode 100644 api/src/api/pairing.rs delete mode 100644 api/src/api/passport.rs delete mode 100644 api/src/api/quantum_link.rs delete mode 100644 api/src/api/scv.rs delete mode 100644 api/src/api/status.rs delete mode 100644 api/src/api/tests.rs delete mode 100644 api/src/lib.rs delete mode 100644 api/tests/golden_tests.rs delete mode 100644 api/tests/snapshots/golden_tests__golden_account_update.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_apply_passphrase_none.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_apply_passphrase_some.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_backup_shard_request.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_backup_shard_response_error.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_backup_shard_response_success.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_broadcast_transaction.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_create_magic_backup_event_chunk.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_create_magic_backup_event_start.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_create_magic_backup_result_error.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_create_magic_backup_result_success.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_device_status.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_device_status_updating.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_envoy_magic_backup_enabled_request.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_envoy_magic_backup_enabled_response.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_envoy_status.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_exchange_rate.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_exchange_rate_history.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_fetch_event_chunk.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_fetch_event_downloading.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_fetch_event_error.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_fetch_event_not_available.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_fetch_event_starting.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_fetch_request.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_check_request.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_check_response_available.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_check_response_not_available.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_result_error.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_result_error_install.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_result_error_verify.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_result_installing.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_result_rebooting.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_result_success.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_firmware_update_result_update_verified.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_heartbeat.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_onboarding_state_completed.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_onboarding_state_firmware_update_screen.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_pairing_request.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_pairing_response.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_prime_magic_backup_enabled.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_prime_magic_backup_status_request.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_prime_magic_backup_status_response.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_raw_data.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_chunk.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_error.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_no_backup.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_starting.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_magic_backup_request.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_magic_backup_result_error.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_magic_backup_result_success.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_shard_request.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_shard_response_error.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_shard_response_not_found.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_restore_shard_response_success.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_security_check_challenge_request.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_security_check_challenge_response_error.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_security_check_challenge_response_success.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_security_check_verification_error.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_security_check_verification_success.snap delete mode 100644 api/tests/snapshots/golden_tests__golden_sign_psbt.snap delete mode 100644 quantum-link-macros/Cargo.toml delete mode 100644 quantum-link-macros/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 123d0e5..c1e30d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,28 +11,12 @@ dependencies = [ "gimli", ] -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - [[package]] name = "adler2" version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" -[[package]] -name = "aead" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" -dependencies = [ - "crypto-common", - "generic-array", -] - [[package]] name = "aho-corasick" version = "1.1.3" @@ -42,40 +26,12 @@ dependencies = [ "memchr", ] -[[package]] -name = "allo-isolate" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "449e356a4864c017286dbbec0e12767ea07efba29e3b7d984194c2a7ff3c4550" -dependencies = [ - "anyhow", - "atomic", - "backtrace", -] - [[package]] name = "android-tzdata" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" -[[package]] -name = "android_log-sys" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84521a3cf562bc62942e294181d9eef17eb38ceb8c68677bc49f144e4c3d4f8d" - -[[package]] -name = "android_logger" -version = "0.15.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb4e440d04be07da1f1bf44fb4495ebd58669372fe0cffa6e48595ac5bd88a3" -dependencies = [ - "android_log-sys", - "env_filter 0.1.3", - "log", -] - [[package]] name = "android_system_properties" version = "0.1.5" @@ -135,30 +91,6 @@ dependencies = [ "windows-sys 0.61.2", ] -[[package]] -name = "anyhow" -version = "1.0.99" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" - -[[package]] -name = "argon2" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" -dependencies = [ - "base64ct", - "blake2", - "cpufeatures", - "password-hash", -] - -[[package]] -name = "arrayvec" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" - [[package]] name = "async-channel" version = "2.5.0" @@ -171,12 +103,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "atomic" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba" - [[package]] name = "autocfg" version = "1.5.0" @@ -192,7 +118,7 @@ dependencies = [ "addr2line", "cfg-if", "libc", - "miniz_oxide 0.8.9", + "miniz_oxide", "object", "rustc-demangle", "windows-targets", @@ -208,153 +134,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "base16ct" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" - -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - -[[package]] -name = "base64ct" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" - -[[package]] -name = "bc-components" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64fc6326f9838e1332cb767fba7ce6a31fa8f14912f13dd42427125263722b9f" -dependencies = [ - "bc-crypto", - "bc-rand", - "bc-tags", - "bc-ur", - "dcbor", - "hex", - "miniz_oxide 0.7.4", - "pqcrypto-mldsa", - "pqcrypto-mlkem", - "pqcrypto-traits", - "rand_core 0.6.4", - "ssh-key", - "sskr", - "thiserror", - "url", - "zeroize", -] - -[[package]] -name = "bc-crypto" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9644245d48f4ab1bfa8c7eebfbd20d2bea7895e220de766f66876c5a71b14712" -dependencies = [ - "argon2", - "bc-rand", - "chacha20poly1305", - "crc32fast", - "ed25519-dalek", - "hex", - "hkdf", - "hmac", - "pbkdf2", - "rand 0.8.5", - "scrypt", - "secp256k1", - "sha2", - "thiserror", - "x25519-dalek", -] - -[[package]] -name = "bc-envelope" -version = "0.37.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "515acbccbbbc35f5ac024b890fdeec084607c73f4f39c0fb231a356823ef272c" -dependencies = [ - "bc-components", - "bc-crypto", - "bc-rand", - "bc-ur", - "bytes", - "dcbor", - "hex", - "itertools", - "known-values", - "paste", - "ssh-key", - "thiserror", -] - -[[package]] -name = "bc-rand" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdade83e92b8dfb9acbccd68e09f9e9555dbaf64c7bc2e5fbb894fcc9b53b413" -dependencies = [ - "getrandom 0.2.16", - "lazy_static", - "num-traits", - "rand 0.8.5", - "rand_core 0.6.4", - "rand_xoshiro", -] - -[[package]] -name = "bc-shamir" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc9e00fbb348a889d951b0a57b04cb609ebd5b123231a6a7c18b4d057825823" -dependencies = [ - "bc-crypto", - "bc-rand", - "thiserror", -] - -[[package]] -name = "bc-tags" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "947dee941635701788b56a6557cf9f89b9750bcc6aae76cf10863759b9964c4a" -dependencies = [ - "dcbor", - "paste", -] - -[[package]] -name = "bc-ur" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac0af650d34ec93be355e81f22df87b108a419d7bc775a5ee1fb00f5daeb9376" -dependencies = [ - "dcbor", - "thiserror", - "ur", -] - -[[package]] -name = "bc-xid" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2206b65d39a6057ad75301cb691a6ea0e2f09556551936640b436c6f269f260e" -dependencies = [ - "bc-components", - "bc-envelope", - "bc-rand", - "bc-ur", - "dcbor", - "hex", - "provenance-mark", - "thiserror", -] - [[package]] name = "bit-set" version = "0.8.0" @@ -370,52 +149,12 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" -[[package]] -name = "bitcoin-io" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b47c4ab7a93edb0c7198c5535ed9b52b63095f4e9b45279c6736cec4b856baf" - -[[package]] -name = "bitcoin-private" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73290177011694f38ec25e165d0387ab7ea749a4b81cd4c80dae5988229f7a57" - -[[package]] -name = "bitcoin_hashes" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d7066118b13d4b20b23645932dfb3a81ce7e29f95726c2036fa33cd7b092501" -dependencies = [ - "bitcoin-private", -] - -[[package]] -name = "bitcoin_hashes" -version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb18c03d0db0247e147a21a6faafd5a7eb851c743db062de72018b6b7e8e4d16" -dependencies = [ - "bitcoin-io", - "hex-conservative", -] - [[package]] name = "bitflags" version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" -[[package]] -name = "blake2" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" -dependencies = [ - "digest", -] - [[package]] name = "block-buffer" version = "0.10.4" @@ -432,16 +171,10 @@ dependencies = [ "bytemuck", "consts", "getrandom 0.2.16", - "rand 0.9.2", + "rand", "thiserror", ] -[[package]] -name = "build-target" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "832133bbabbbaa9fbdba793456a2827627a7d2b8fb96032fa1e7666d7895832b" - [[package]] name = "bumpalo" version = "3.19.0" @@ -468,7 +201,7 @@ checksum = "89385e82b5d1821d2219e0b095efa2cc1f246cbf99080f3be46a1a85c0d392d9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -488,15 +221,9 @@ checksum = "4f154e572231cb6ba2bd1176980827e3d5dc04cc183a75dea38109fbdd672d29" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - [[package]] name = "bytes" version = "1.10.1" @@ -509,8 +236,6 @@ version = "1.2.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42bc4aea80032b7bf409b0bc7ccad88853858911b7713a8062fdc0623867bedc" dependencies = [ - "jobserver", - "libc", "shlex", ] @@ -520,30 +245,6 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" -[[package]] -name = "chacha20" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", -] - -[[package]] -name = "chacha20poly1305" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" -dependencies = [ - "aead", - "chacha20", - "cipher", - "poly1305", - "zeroize", -] - [[package]] name = "chrono" version = "0.4.41" @@ -554,22 +255,10 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", - "serde", "wasm-bindgen", "windows-link 0.1.3", ] -[[package]] -name = "cipher" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" -dependencies = [ - "crypto-common", - "inout", - "zeroize", -] - [[package]] name = "colorchoice" version = "1.0.5" @@ -598,22 +287,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "console_error_panic_hook" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" -dependencies = [ - "cfg-if", - "wasm-bindgen", -] - -[[package]] -name = "const-oid" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" - [[package]] name = "consts" version = "1.0.0" @@ -633,7 +306,7 @@ checksum = "657f625ff361906f779745d08375ae3cc9fef87a35fba5f22874cf773010daf4" dependencies = [ "hax-lib", "pastey", - "rand 0.9.2", + "rand", ] [[package]] @@ -645,30 +318,6 @@ dependencies = [ "libc", ] -[[package]] -name = "crc" -version = "3.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9710d3b3739c2e349eb44fe848ad0b7c8cb1e42bd87ee49371df2f7acaf3e675" -dependencies = [ - "crc-catalog", -] - -[[package]] -name = "crc-catalog" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" - -[[package]] -name = "crc32fast" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" -dependencies = [ - "cfg-if", -] - [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -681,18 +330,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" -[[package]] -name = "crypto-bigint" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" -dependencies = [ - "generic-array", - "rand_core 0.6.4", - "subtle", - "zeroize", -] - [[package]] name = "crypto-common" version = "0.1.6" @@ -700,59 +337,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", - "rand_core 0.6.4", "typenum", ] -[[package]] -name = "curve25519-dalek" -version = "4.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" -dependencies = [ - "cfg-if", - "cpufeatures", - "curve25519-dalek-derive", - "digest", - "fiat-crypto", - "rustc_version", - "subtle", - "zeroize", -] - -[[package]] -name = "curve25519-dalek-derive" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - -[[package]] -name = "dart-sys" -version = "4.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57967e4b200d767d091b961d6ab42cc7d0cc14fe9e052e75d0d3cf9eb732d895" -dependencies = [ - "cc", -] - -[[package]] -name = "dashmap" -version = "5.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" -dependencies = [ - "cfg-if", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - [[package]] name = "dcbor" version = "0.23.3" @@ -767,27 +354,6 @@ dependencies = [ "unicode-normalization", ] -[[package]] -name = "delegate-attr" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51aac4c99b2e6775164b412ea33ae8441b2fde2dbf05a20bc0052a63d08c475b" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - -[[package]] -name = "der" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" -dependencies = [ - "const-oid", - "zeroize", -] - [[package]] name = "diatomic-waker" version = "0.2.3" @@ -801,124 +367,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", - "const-oid", "crypto-common", - "subtle", ] [[package]] -name = "displaydoc" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - -[[package]] -name = "dsa" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48bc224a9084ad760195584ce5abb3c2c34a225fa312a128ad245a6b412b7689" -dependencies = [ - "digest", - "num-bigint-dig", - "num-traits", - "pkcs8", - "rfc6979", - "sha2", - "signature", - "zeroize", -] - -[[package]] -name = "dunce" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" - -[[package]] -name = "ecdsa" -version = "0.16.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" -dependencies = [ - "der", - "digest", - "elliptic-curve", - "rfc6979", - "signature", - "spki", -] - -[[package]] -name = "ed25519" -version = "2.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "115531babc129696a58c64a4fef0a8bf9e9698629fb97e9e40767d235cfbcd53" -dependencies = [ - "pkcs8", - "signature", -] - -[[package]] -name = "ed25519-dalek" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70e796c081cee67dc755e1a36a0a172b897fab85fc3f6bc48307991f64e4eca9" -dependencies = [ - "curve25519-dalek", - "ed25519", - "rand_core 0.6.4", - "serde", - "sha2", - "subtle", - "zeroize", -] - -[[package]] -name = "either" -version = "1.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" - -[[package]] -name = "elliptic-curve" -version = "0.13.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" -dependencies = [ - "base16ct", - "crypto-bigint", - "digest", - "ff", - "generic-array", - "group", - "pkcs8", - "rand_core 0.6.4", - "sec1", - "subtle", - "zeroize", -] - -[[package]] -name = "encode_unicode" -version = "1.0.0" +name = "encode_unicode" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" -[[package]] -name = "env_filter" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" -dependencies = [ - "log", - "regex", -] - [[package]] name = "env_filter" version = "1.0.1" @@ -937,7 +394,7 @@ checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" dependencies = [ "anstream", "anstyle", - "env_filter 1.0.1", + "env_filter", "jiff", "log", ] @@ -986,138 +443,18 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" -[[package]] -name = "ff" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" -dependencies = [ - "rand_core 0.6.4", - "subtle", -] - -[[package]] -name = "fiat-crypto" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" - -[[package]] -name = "flutter_rust_bridge" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dde126295b2acc5f0a712e265e91b6fdc0ed38767496483e592ae7134db83725" -dependencies = [ - "allo-isolate", - "android_logger", - "anyhow", - "build-target", - "bytemuck", - "byteorder", - "console_error_panic_hook", - "dart-sys", - "delegate-attr", - "flutter_rust_bridge_macros", - "futures", - "js-sys", - "lazy_static", - "log", - "oslog", - "portable-atomic", - "threadpool", - "tokio", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - -[[package]] -name = "flutter_rust_bridge_macros" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f0420326b13675321b194928bb7830043b68cf8b810e1c651285c747abb080" -dependencies = [ - "hex", - "md-5", - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "form_urlencoded" -version = "1.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" -dependencies = [ - "percent-encoding", -] - -[[package]] -name = "foundation-api" -version = "2.0.0" -dependencies = [ - "bc-components", - "bc-envelope", - "bc-xid", - "chrono", - "dcbor", - "flutter_rust_bridge", - "gstp", - "insta", - "quantum-link-macros", - "rkyv", - "thiserror", -] - -[[package]] -name = "futures" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-channel" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" -dependencies = [ - "futures-core", - "futures-sink", -] - [[package]] name = "futures-core" version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" -[[package]] -name = "futures-executor" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - [[package]] name = "futures-io" version = "0.3.31" @@ -1137,47 +474,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "futures-macro" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - -[[package]] -name = "futures-sink" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" - -[[package]] -name = "futures-task" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" - -[[package]] -name = "futures-util" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-macro", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "pin-utils", - "slab", -] - [[package]] name = "generator" version = "0.8.8" @@ -1201,7 +497,6 @@ checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", - "zeroize", ] [[package]] @@ -1233,37 +528,6 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" -[[package]] -name = "glob" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" - -[[package]] -name = "group" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" -dependencies = [ - "ff", - "rand_core 0.6.4", - "subtle", -] - -[[package]] -name = "gstp" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd8214e6a70abd783f45565cba634b58e8afca35dd374a155517a7f4c437774" -dependencies = [ - "bc-components", - "bc-envelope", - "bc-rand", - "bc-xid", - "dcbor", - "thiserror", -] - [[package]] name = "half" version = "2.6.0" @@ -1274,12 +538,6 @@ dependencies = [ "crunchy", ] -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" - [[package]] name = "hashbrown" version = "0.15.5" @@ -1313,7 +571,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -1329,47 +587,11 @@ dependencies = [ "uuid", ] -[[package]] -name = "hermit-abi" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" - [[package]] name = "hex" version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" -dependencies = [ - "serde", -] - -[[package]] -name = "hex-conservative" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5313b072ce3c597065a808dbf612c4c8e8590bdbf8b579508bf7a762c5eae6cd" -dependencies = [ - "arrayvec", -] - -[[package]] -name = "hkdf" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" -dependencies = [ - "hmac", -] - -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] [[package]] name = "iana-time-zone" @@ -1395,113 +617,6 @@ dependencies = [ "cc", ] -[[package]] -name = "icu_collections" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" -dependencies = [ - "displaydoc", - "potential_utf", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_locale_core" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" -dependencies = [ - "displaydoc", - "litemap", - "tinystr", - "writeable", - "zerovec", -] - -[[package]] -name = "icu_normalizer" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_normalizer_data", - "icu_properties", - "icu_provider", - "smallvec", - "zerovec", -] - -[[package]] -name = "icu_normalizer_data" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" - -[[package]] -name = "icu_properties" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_locale_core", - "icu_properties_data", - "icu_provider", - "potential_utf", - "zerotrie", - "zerovec", -] - -[[package]] -name = "icu_properties_data" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" - -[[package]] -name = "icu_provider" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" -dependencies = [ - "displaydoc", - "icu_locale_core", - "stable_deref_trait", - "tinystr", - "writeable", - "yoke", - "zerofrom", - "zerotrie", - "zerovec", -] - -[[package]] -name = "idna" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" -dependencies = [ - "idna_adapter", - "smallvec", - "utf8_iter", -] - -[[package]] -name = "idna_adapter" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" -dependencies = [ - "icu_normalizer", - "icu_properties", -] - [[package]] name = "indexmap" version = "2.12.1" @@ -1512,15 +627,6 @@ dependencies = [ "hashbrown 0.16.1", ] -[[package]] -name = "inout" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" -dependencies = [ - "generic-array", -] - [[package]] name = "insta" version = "1.44.1" @@ -1549,15 +655,6 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" -[[package]] -name = "itertools" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" -dependencies = [ - "either", -] - [[package]] name = "itoa" version = "1.0.15" @@ -1583,19 +680,9 @@ version = "0.2.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a8c8b344124222efd714b73bb41f8b5120b27a7cc1c75593a6ff768d9d05aa4" dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - -[[package]] -name = "jobserver" -version = "0.1.33" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" -dependencies = [ - "getrandom 0.3.3", - "libc", + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -1608,25 +695,11 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "known-values" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efadaa833480ac053954ea1bf019eee3b3ed0123417434158d1d9ca46162dfed" -dependencies = [ - "bc-components", - "dcbor", - "paste", -] - [[package]] name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" -dependencies = [ - "spin", -] [[package]] name = "libc" @@ -1668,7 +741,7 @@ dependencies = [ "libcrux-secrets", "libcrux-sha3", "libcrux-traits", - "rand 0.9.2", + "rand", "tls_codec", ] @@ -1709,37 +782,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e4fa89f3f5e34b47f928b22b1b78395a0d4ec23b1f583db635f128159d65f" dependencies = [ "libcrux-secrets", - "rand 0.9.2", + "rand", ] -[[package]] -name = "libm" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" - [[package]] name = "linux-raw-sys" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" -[[package]] -name = "litemap" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" - -[[package]] -name = "lock_api" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" -dependencies = [ - "autocfg", - "scopeguard", -] - [[package]] name = "log" version = "0.4.29" @@ -1768,51 +819,12 @@ dependencies = [ "regex-automata", ] -[[package]] -name = "md-5" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" -dependencies = [ - "cfg-if", - "digest", -] - [[package]] name = "memchr" version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" -[[package]] -name = "minicbor" -version = "0.19.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7005aaf257a59ff4de471a9d5538ec868a21586534fff7f85dd97d4043a6139" -dependencies = [ - "minicbor-derive", -] - -[[package]] -name = "minicbor-derive" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1154809406efdb7982841adb6311b3d095b46f78342dd646736122fe6b19e267" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "miniz_oxide" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" -dependencies = [ - "adler", -] - [[package]] name = "miniz_oxide" version = "0.8.9" @@ -1850,7 +862,7 @@ checksum = "4568f25ccbd45ab5d5603dc34318c1ec56b117531781260002151b8530a9f931" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -1872,22 +884,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-bigint-dig" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e661dda6640fad38e827a6d4a310ff4763082116fe217f279885c97f511bb0b7" -dependencies = [ - "lazy_static", - "libm", - "num-integer", - "num-iter", - "num-traits", - "rand 0.8.5", - "smallvec", - "zeroize", -] - [[package]] name = "num-integer" version = "0.1.46" @@ -1897,17 +893,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-iter" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - [[package]] name = "num-traits" version = "0.2.19" @@ -1915,17 +900,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", - "libm", -] - -[[package]] -name = "num_cpus" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" -dependencies = [ - "hermit-abi", - "libc", ] [[package]] @@ -1955,61 +929,6 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" -[[package]] -name = "opaque-debug" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" - -[[package]] -name = "oslog" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d2043d1f61d77cb2f4b1f7b7b2295f40507f5f8e9d1c8bf10a1ca5f97a3969" -dependencies = [ - "cc", - "dashmap", - "log", -] - -[[package]] -name = "p256" -version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" -dependencies = [ - "ecdsa", - "elliptic-curve", - "primeorder", - "sha2", -] - -[[package]] -name = "p384" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe42f1670a52a47d448f14b6a5c61dd78fce51856e68edaa38f7ae3a46b8d6b6" -dependencies = [ - "ecdsa", - "elliptic-curve", - "primeorder", - "sha2", -] - -[[package]] -name = "p521" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fc9e2161f1f215afdfce23677034ae137bbd45016a880c2eb3ba8eb95f085b2" -dependencies = [ - "base16ct", - "ecdsa", - "elliptic-curve", - "primeorder", - "rand_core 0.6.4", - "sha2", -] - [[package]] name = "parking" version = "2.2.1" @@ -2019,30 +938,6 @@ dependencies = [ "loom", ] -[[package]] -name = "parking_lot_core" -version = "0.9.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", -] - -[[package]] -name = "password-hash" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" -dependencies = [ - "base64ct", - "rand_core 0.6.4", - "subtle", -] - [[package]] name = "paste" version = "1.0.15" @@ -2055,117 +950,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" -[[package]] -name = "pbkdf2" -version = "0.12.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" -dependencies = [ - "digest", - "hmac", -] - -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - -[[package]] -name = "percent-encoding" -version = "2.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" - -[[package]] -name = "phf" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" -dependencies = [ - "phf_macros", - "phf_shared", -] - -[[package]] -name = "phf_generator" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" -dependencies = [ - "phf_shared", - "rand 0.8.5", -] - -[[package]] -name = "phf_macros" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" -dependencies = [ - "phf_generator", - "phf_shared", - "proc-macro2", - "quote", - "syn 2.0.106", -] - -[[package]] -name = "phf_shared" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" -dependencies = [ - "siphasher", -] - [[package]] name = "pin-project-lite" version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "pkcs1" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" -dependencies = [ - "der", - "pkcs8", - "spki", -] - -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der", - "spki", -] - -[[package]] -name = "poly1305" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" -dependencies = [ - "cpufeatures", - "opaque-debug", - "universal-hash", -] - [[package]] name = "portable-atomic" version = "1.11.1" @@ -2181,15 +971,6 @@ dependencies = [ "portable-atomic", ] -[[package]] -name = "potential_utf" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" -dependencies = [ - "zerovec", -] - [[package]] name = "ppv-lite86" version = "0.2.21" @@ -2199,56 +980,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "pqcrypto-internals" -version = "0.2.10" -source = "git+https://github.com/Foundation-Devices/pqcrypto?rev=ebadf71214f67cb970242fa1053b4acb65767737#ebadf71214f67cb970242fa1053b4acb65767737" -dependencies = [ - "cc", - "dunce", - "getrandom 0.2.16", - "libc", -] - -[[package]] -name = "pqcrypto-mldsa" -version = "0.1.1" -source = "git+https://github.com/Foundation-Devices/pqcrypto?rev=ebadf71214f67cb970242fa1053b4acb65767737#ebadf71214f67cb970242fa1053b4acb65767737" -dependencies = [ - "cc", - "glob", - "libc", - "paste", - "pqcrypto-internals", - "pqcrypto-traits", -] - -[[package]] -name = "pqcrypto-mlkem" -version = "0.1.0" -source = "git+https://github.com/Foundation-Devices/pqcrypto?rev=ebadf71214f67cb970242fa1053b4acb65767737#ebadf71214f67cb970242fa1053b4acb65767737" -dependencies = [ - "cc", - "glob", - "libc", - "pqcrypto-internals", - "pqcrypto-traits", -] - -[[package]] -name = "pqcrypto-traits" -version = "0.3.5" -source = "git+https://github.com/Foundation-Devices/pqcrypto?rev=ebadf71214f67cb970242fa1053b4acb65767737#ebadf71214f67cb970242fa1053b4acb65767737" - -[[package]] -name = "primeorder" -version = "0.13.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" -dependencies = [ - "elliptic-curve", -] - [[package]] name = "proc-macro-error-attr2" version = "2.0.0" @@ -2268,7 +999,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -2290,8 +1021,8 @@ dependencies = [ "bit-vec", "bitflags", "num-traits", - "rand 0.9.2", - "rand_chacha 0.9.0", + "rand", + "rand_chacha", "rand_xorshift", "regex-syntax", "rusty-fork", @@ -2299,30 +1030,6 @@ dependencies = [ "unarray", ] -[[package]] -name = "provenance-mark" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a2078ef3c515d873099557bdbcb4b1ad8e67af1d2b398b5a7aa224baca970d" -dependencies = [ - "base64", - "bc-envelope", - "bc-rand", - "bc-tags", - "bc-ur", - "chacha20", - "chrono", - "dcbor", - "hex", - "hkdf", - "rand_core 0.6.4", - "serde", - "serde_json", - "sha2", - "thiserror", - "url", -] - [[package]] name = "ptr_meta" version = "0.3.1" @@ -2340,7 +1047,7 @@ checksum = "7347867d0a7e1208d93b46767be83e2b8f978c3dad35f775ac8d8847551d6fe1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -2391,15 +1098,6 @@ dependencies = [ "sha2", ] -[[package]] -name = "quantum-link-macros" -version = "0.1.0" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "quick-error" version = "1.2.3" @@ -2430,54 +1128,24 @@ dependencies = [ "ptr_meta", ] -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - [[package]] name = "rand" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha 0.9.0", - "rand_core 0.9.3", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_chacha" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" -dependencies = [ - "ppv-lite86", - "rand_core 0.9.3", + "rand_chacha", + "rand_core", ] [[package]] -name = "rand_core" -version = "0.6.4" +name = "rand_chacha" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ - "getrandom 0.2.16", + "ppv-lite86", + "rand_core", ] [[package]] @@ -2495,25 +1163,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" dependencies = [ - "rand_core 0.9.3", -] - -[[package]] -name = "rand_xoshiro" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" -dependencies = [ - "rand_core 0.6.4", -] - -[[package]] -name = "redox_syscall" -version = "0.5.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5407465600fb0548f1442edf71dd20683c6ed326200ace4b1ef0763521bb3b77" -dependencies = [ - "bitflags", + "rand_core", ] [[package]] @@ -2554,16 +1204,6 @@ dependencies = [ "bytecheck", ] -[[package]] -name = "rfc6979" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" -dependencies = [ - "hmac", - "subtle", -] - [[package]] name = "rkyv" version = "0.8.12" @@ -2591,28 +1231,7 @@ checksum = "bd83f5f173ff41e00337d97f6572e416d022ef8a19f371817259ae960324c482" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", -] - -[[package]] -name = "rsa" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8573f03f5883dcaebdfcf4725caa1ecb9c15b2ef50c43a07b816e06799bb12d" -dependencies = [ - "const-oid", - "digest", - "num-bigint-dig", - "num-integer", - "num-traits", - "pkcs1", - "pkcs8", - "rand_core 0.6.4", - "sha2", - "signature", - "spki", - "subtle", - "zeroize", + "syn", ] [[package]] @@ -2621,15 +1240,6 @@ version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" -[[package]] -name = "rustc_version" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" -dependencies = [ - "semver", -] - [[package]] name = "rustix" version = "1.1.2" @@ -2667,78 +1277,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" -[[package]] -name = "salsa20" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a22f5af31f73a954c10289c93e8a50cc23d971e80ee446f1f6f7137a088213" -dependencies = [ - "cipher", -] - [[package]] name = "scoped-tls" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "scrypt" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0516a385866c09368f0b5bcd1caff3366aace790fcd46e2bb032697bb172fd1f" -dependencies = [ - "pbkdf2", - "salsa20", - "sha2", -] - -[[package]] -name = "sec1" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" -dependencies = [ - "base16ct", - "der", - "generic-array", - "pkcs8", - "subtle", - "zeroize", -] - -[[package]] -name = "secp256k1" -version = "0.30.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b50c5943d326858130af85e049f2661ba3c78b26589b8ab98e65e80ae44a1252" -dependencies = [ - "bitcoin_hashes 0.14.0", - "rand 0.8.5", - "secp256k1-sys", -] - -[[package]] -name = "secp256k1-sys" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4387882333d3aa8cb20530a17c69a3752e97837832f34f6dccc760e715001d9" -dependencies = [ - "cc", -] - -[[package]] -name = "semver" -version = "1.0.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" - [[package]] name = "serde" version = "1.0.228" @@ -2766,7 +1310,7 @@ checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -2781,17 +1325,6 @@ dependencies = [ "serde", ] -[[package]] -name = "sha1" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha2" version = "0.10.9" @@ -2818,16 +1351,6 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "signature" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" -dependencies = [ - "digest", - "rand_core 0.6.4", -] - [[package]] name = "simdutf8" version = "0.1.5" @@ -2840,12 +1363,6 @@ version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" -[[package]] -name = "siphasher" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" - [[package]] name = "slab" version = "0.4.11" @@ -2858,101 +1375,6 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - -[[package]] -name = "spki" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" -dependencies = [ - "base64ct", - "der", -] - -[[package]] -name = "ssh-cipher" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "caac132742f0d33c3af65bfcde7f6aa8f62f0e991d80db99149eb9d44708784f" -dependencies = [ - "cipher", - "ssh-encoding", -] - -[[package]] -name = "ssh-encoding" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb9242b9ef4108a78e8cd1a2c98e193ef372437f8c22be363075233321dd4a15" -dependencies = [ - "base64ct", - "pem-rfc7468", - "sha2", -] - -[[package]] -name = "ssh-key" -version = "0.6.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b86f5297f0f04d08cabaa0f6bff7cb6aec4d9c3b49d87990d63da9d9156a8c3" -dependencies = [ - "dsa", - "ed25519-dalek", - "num-bigint-dig", - "p256", - "p384", - "p521", - "rand_core 0.6.4", - "rsa", - "sec1", - "sha1", - "sha2", - "signature", - "ssh-cipher", - "ssh-encoding", - "subtle", - "zeroize", -] - -[[package]] -name = "sskr" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7228e0234fae61785706c7f2b2bc5e47b6b34397e8c1052e1cfcba8030536234" -dependencies = [ - "bc-rand", - "bc-shamir", - "thiserror", -] - -[[package]] -name = "stable_deref_trait" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" - -[[package]] -name = "subtle" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - [[package]] name = "syn" version = "2.0.106" @@ -2964,17 +1386,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "synstructure" -version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "tempfile" version = "3.23.0" @@ -3005,7 +1416,7 @@ checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -3017,25 +1428,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "threadpool" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" -dependencies = [ - "num_cpus", -] - -[[package]] -name = "tinystr" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" -dependencies = [ - "displaydoc", - "zerovec", -] - [[package]] name = "tinyvec" version = "1.10.0" @@ -3069,7 +1461,7 @@ checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -3095,7 +1487,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -3155,7 +1547,7 @@ checksum = "70977707304198400eb4835a78f6a9f928bf41bba420deb8fdb175cd965d77a7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -3185,46 +1577,6 @@ dependencies = [ "tinyvec", ] -[[package]] -name = "universal-hash" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" -dependencies = [ - "crypto-common", - "subtle", -] - -[[package]] -name = "ur" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "010f24a953db5d22d0010969ca3bbf40b3857b89f47c0f7be0da4c2d7ded0760" -dependencies = [ - "bitcoin_hashes 0.12.0", - "crc", - "minicbor", - "phf", - "rand_xoshiro", -] - -[[package]] -name = "url" -version = "2.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "137a3c834eaf7139b73688502f3f1141a0337c5d8e4d9b536f9b8c796e26a7c4" -dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", -] - -[[package]] -name = "utf8_iter" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" - [[package]] name = "utf8parse" version = "0.2.2" @@ -3300,23 +1652,10 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.106", + "syn", "wasm-bindgen-shared", ] -[[package]] -name = "wasm-bindgen-futures" -version = "0.4.50" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" -dependencies = [ - "cfg-if", - "js-sys", - "once_cell", - "wasm-bindgen", - "web-sys", -] - [[package]] name = "wasm-bindgen-macro" version = "0.2.100" @@ -3335,7 +1674,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3349,16 +1688,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "web-sys" -version = "0.3.77" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - [[package]] name = "windows-core" version = "0.61.2" @@ -3380,7 +1709,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -3391,7 +1720,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn", ] [[package]] @@ -3515,48 +1844,6 @@ dependencies = [ "bitflags", ] -[[package]] -name = "writeable" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" - -[[package]] -name = "x25519-dalek" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" -dependencies = [ - "curve25519-dalek", - "rand_core 0.6.4", - "serde", - "zeroize", -] - -[[package]] -name = "yoke" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" -dependencies = [ - "serde", - "stable_deref_trait", - "yoke-derive", - "zerofrom", -] - -[[package]] -name = "yoke-derive" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", - "synstructure", -] - [[package]] name = "zerocopy" version = "0.8.26" @@ -3574,28 +1861,7 @@ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", -] - -[[package]] -name = "zerofrom" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" -dependencies = [ - "zerofrom-derive", -] - -[[package]] -name = "zerofrom-derive" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", - "synstructure", + "syn", ] [[package]] @@ -3615,38 +1881,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", -] - -[[package]] -name = "zerotrie" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" -dependencies = [ - "displaydoc", - "yoke", - "zerofrom", -] - -[[package]] -name = "zerovec" -version = "0.11.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b" -dependencies = [ - "yoke", - "zerofrom", - "zerovec-derive", -] - -[[package]] -name = "zerovec-derive" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", + "syn", ] diff --git a/Cargo.toml b/Cargo.toml index b2492c4..83ac135 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,44 +1,27 @@ [workspace] resolver = "2" members = [ - "api", "backup-shard", "btp", "ql-fsm", "ql-rpc", "ql-runtime", "ql-wire", - "quantum-link-macros", ] [workspace.package] homepage = "https://github.com/Foundation-Devices/foundation-api" [workspace.dependencies] -# blockchain commons -bc-components = { version = "0.28.0" } -bc-envelope = { version = "0.37.0" } -bc-xid = { version = "0.16.0" } dcbor = { version = "0.23.3" } -gstp = { version = "0.11.0" } - -chrono = "0.4" bytes = "1" getrandom = { version = "0.2" } insta = { version = "1.43.2" } -thiserror = { version = "2" } rkyv = { version = "0.8" } # workspace crates backup-shard = { path = "backup-shard" } btp = { path = "btp" } -foundation-api = { path = "api" } -quantum-link-macros = { path = "quantum-link-macros" } ql-fsm = { path = "ql-fsm" } ql-rpc = { path = "ql-rpc" } ql-wire = { path = "ql-wire" } - -[patch.crates-io] -pqcrypto-traits = { git = "https://github.com/Foundation-Devices/pqcrypto", rev = "ebadf71214f67cb970242fa1053b4acb65767737" } -pqcrypto-mldsa = { git = "https://github.com/Foundation-Devices/pqcrypto", rev = "ebadf71214f67cb970242fa1053b4acb65767737" } -pqcrypto-mlkem = { git = "https://github.com/Foundation-Devices/pqcrypto", rev = "ebadf71214f67cb970242fa1053b4acb65767737" } diff --git a/Justfile b/Justfile index 45492c2..71889c0 100644 --- a/Justfile +++ b/Justfile @@ -1,15 +1,3 @@ # Run clippy on all targets and features, treating warnings as errors clippy: cargo clippy --all-targets --all-features -- -D warnings - -# Run golden/snapshot tests -golden: - cargo test -p foundation-api --test golden_tests - -# Update golden/snapshot tests (accept all new snapshots) -golden-update: - INSTA_UPDATE=always cargo test -p foundation-api --test golden_tests - -# Review pending golden/snapshot changes interactively (requires cargo-insta) -golden-review: - cargo insta review \ No newline at end of file diff --git a/README.md b/README.md index 2e78690..0d280b2 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,15 @@ # Foundation API -This monorepo contains the core crates for a device-to-device API using Blockchain Commons' GSTP +This monorepo contains the core crates for Foundation device-to-device protocols. ## Crates -- **abstracted**: Abstractions of the BLE and SE chips -- **api**: The API - contains predefined QL messages -- **api-demo**: Tokio-based demo of device-to-device communication - **btp**: Beefcake Transfer Protocol for splitting messages into MTU sized chunks -- **quantum-link-macros**: Macros to easily turn Rust Structs and Enums into valid QL messages +- **backup-shard**: Magic backup shard encoding +- **ql-wire**: QuantumLink wire-format definitions +- **ql-fsm**: QuantumLink Sans-IO protocol finite state machine +- **ql-runtime**: QuantumLink async runtime +- **ql-rpc**: RPC modality layer over QuantumLink streams ## Development diff --git a/api/.gitignore b/api/.gitignore deleted file mode 100644 index 96ef6c0..0000000 --- a/api/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -/target -Cargo.lock diff --git a/api/Cargo.toml b/api/Cargo.toml deleted file mode 100644 index 13e7ec3..0000000 --- a/api/Cargo.toml +++ /dev/null @@ -1,31 +0,0 @@ -[package] -name = "foundation-api" -version = "2.0.0" -edition = "2021" -description = "Foundation API using Gordian Sealed Transaction Protocol (GSTP)." -authors = ["Wolf McNally, Blockchain Commons, Foundation Devices"] -repository = "https://github.com/Foundation-Devices/foundation-api" -readme = "README.md" -license = "Proprietary" - -[dependencies] -bc-envelope = { workspace = true } -bc-xid = { workspace = true } -rkyv = { workspace = true, optional = true } -flutter_rust_bridge = { version = "=2.11.1", optional = true } -quantum-link-macros = { workspace = true } -gstp = { workspace = true } -bc-components = { workspace = true } -dcbor = { workspace = true } -chrono = { workspace = true } -thiserror = { workspace = true } - -[dev-dependencies] -insta = { workspace = true } - -[features] -keyos = ["rkyv"] -envoy = ["flutter_rust_bridge"] - -[lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ['cfg(frb_expand)'] } diff --git a/api/src/api/backup.rs b/api/src/api/backup.rs deleted file mode 100644 index ec27c35..0000000 --- a/api/src/api/backup.rs +++ /dev/null @@ -1,264 +0,0 @@ -use quantum_link_macros::quantum_link; - -#[quantum_link] -#[repr(transparent)] -pub struct Shard(pub Vec); - -#[quantum_link] -#[repr(transparent)] -pub struct SeedFingerprint(pub [u8; 32]); - -#[quantum_link] -pub struct BackupShardRequest { - #[n(0)] - pub shard: Shard, -} - -#[quantum_link] -pub enum BackupShardResponse { - #[n(0)] - Success, - #[n(1)] - Error { - #[n(0)] - error: String, - }, -} - -#[quantum_link] -pub struct RestoreShardRequest { - #[n(0)] - pub seed_fingerprint: SeedFingerprint, - #[n(1)] - pub timestamp: Option, -} - -#[quantum_link] -pub enum RestoreShardResponse { - #[n(0)] - Success { - #[n(0)] - shard: Shard, - }, - #[n(1)] - Error { - #[n(0)] - error: String, - }, - #[n(2)] - NotFound, -} - -#[quantum_link] -pub struct EnvoyMagicBackupEnabledRequest {} - -#[quantum_link] -pub struct EnvoyMagicBackupEnabledResponse { - #[n(0)] - pub enabled: bool, -} - -#[quantum_link] -pub struct PrimeMagicBackupEnabled { - #[n(0)] - pub enabled: bool, - #[n(1)] - pub seed_fingerprint: SeedFingerprint, -} - -#[quantum_link] -pub struct PrimeMagicBackupStatusRequest { - #[n(0)] - pub seed_fingerprint: SeedFingerprint, - #[n(1)] - pub timestamp: Option, -} - -#[quantum_link] -pub struct PrimeMagicBackupStatusResponse { - #[n(0)] - pub shard_backup_found: bool, -} - -// -// MAGIC BACKUPS -// - -#[quantum_link] -#[derive(Eq)] -pub struct BackupChunk { - #[n(0)] - pub chunk_index: u32, - #[n(1)] - pub total_chunks: u32, - #[n(2)] - pub data: Vec, -} - -impl BackupChunk { - pub fn is_last(&self) -> bool { - self.chunk_index == self.total_chunks - 1 - } -} - -// -// CREATING BACKUP -// - -// from prime -> envoy -#[quantum_link] -pub enum CreateMagicBackupEvent { - #[n(0)] - Start(StartMagicBackup), - #[n(1)] - Chunk(BackupChunk), -} - -#[quantum_link] -pub struct StartMagicBackup { - #[n(0)] - pub seed_fingerprint: SeedFingerprint, - #[n(1)] - pub total_chunks: u32, - #[n(2)] - pub hash: [u8; 32], -} - -// envoy -> prime -// error can be sent at any time -// success is expected at the end of the flow -#[quantum_link] -pub enum CreateMagicBackupResult { - #[n(0)] - Success, - #[n(1)] - Error { - #[n(0)] - error: String, - }, -} - -// -// RESTORING BACKUP -// - -#[quantum_link] -pub struct RestoreMagicBackupRequest { - #[n(0)] - pub seed_fingerprint: SeedFingerprint, - /// if 0, then go from start - #[n(1)] - pub resume_from_chunk: u32, -} - -#[quantum_link] -pub enum RestoreMagicBackupEvent { - // there is no backup found from the provided fingerprint - #[n(0)] - NotFound, - // envoy found a backup and is beginning transmission - #[n(1)] - Starting(BackupMetadata), - // a backup chunk - #[n(2)] - Chunk(BackupChunk), - // envoy failed - #[n(3)] - Error { - #[n(0)] - error: String, - }, -} - -#[quantum_link] -#[derive(Eq)] -pub struct BackupMetadata { - #[n(0)] - pub total_chunks: u32, -} - -// sent from prime -> envoy -#[quantum_link] -pub enum RestoreMagicBackupResult { - #[n(0)] - Success, - #[n(1)] - Error { - #[n(0)] - error: String, - }, -} - -// -// MAGIC BACKUPS V2 -// - -#[quantum_link] -pub struct CreateMagicBackupV2 { - #[n(0)] - pub timestamp: u64, - /// Backup identifier (SHA-256 hash). - #[n(1)] - pub hash: Vec, - /// ML-DSA-44 public key. - #[n(2)] - pub pubkey: Vec, - /// Encrypted backup payload. - #[n(3)] - pub data: Vec, - /// ML-DSA-44 client signature. - #[n(4)] - pub client_signature: Vec, -} - -#[quantum_link] -pub struct GetMagicBackupV2 { - #[n(0)] - pub key: Vec, - #[n(1)] - pub timestamp: u64, - /// ML-DSA-44 signature. - #[n(2)] - pub signature: Vec, -} - -#[quantum_link] -pub struct DeleteMagicBackupV2 { - #[n(0)] - pub key: Vec, - #[n(1)] - pub timestamp: u64, - /// ML-DSA-44 signature. - #[n(2)] - pub signature: Vec, -} - -// prime -> envoy -#[quantum_link] -pub enum MagicBackupRequestV2 { - #[n(0)] - Create(CreateMagicBackupV2), - #[n(1)] - Get(GetMagicBackupV2), - #[n(2)] - Delete(DeleteMagicBackupV2), -} - -// envoy -> prime -#[quantum_link] -pub enum MagicBackupResponseV2 { - #[n(0)] - Created, - #[n(1)] - Backup { - #[n(0)] - data: Vec, - }, - #[n(2)] - Deleted, - #[n(3)] - Error { - #[n(0)] - error: String, - }, -} diff --git a/api/src/api/bitcoin.rs b/api/src/api/bitcoin.rs deleted file mode 100644 index 821f280..0000000 --- a/api/src/api/bitcoin.rs +++ /dev/null @@ -1,32 +0,0 @@ -use quantum_link_macros::quantum_link; - -#[quantum_link] -pub struct SignPsbt { - #[n(0)] - pub account_id: String, - #[n(1)] - pub psbt: Vec, -} - -#[quantum_link] -pub struct AccountUpdate { - #[n(0)] - pub account_id: String, - #[n(1)] - pub update: Vec, -} - -#[quantum_link] -pub struct BroadcastTransaction { - #[n(0)] - pub account_id: String, - #[n(1)] - pub psbt: Vec, -} - -// If None, there's no passphrase, hide passphrased accounts -#[quantum_link] -pub struct ApplyPassphrase { - #[n(0)] - pub fingerprint: Option, -} diff --git a/api/src/api/firmware.rs b/api/src/api/firmware.rs deleted file mode 100644 index e3ad928..0000000 --- a/api/src/api/firmware.rs +++ /dev/null @@ -1,116 +0,0 @@ -use quantum_link_macros::quantum_link; - -// From Prime to Envoy -#[quantum_link] -pub struct FirmwareUpdateCheckRequest { - #[n(0)] - pub current_version: String, -} - -// From Envoy to Prime -#[quantum_link] -pub enum FirmwareUpdateCheckResponse { - #[n(0)] - Available(FirmwareUpdateAvailable), - #[n(1)] - NotAvailable, -} - -#[quantum_link] -pub struct FirmwareUpdateAvailable { - #[n(0)] - pub version: String, - #[n(1)] - pub changelog: String, - #[n(2)] - pub timestamp: u32, - #[n(3)] - pub total_size: u32, - #[n(4)] - pub patch_count: u8, -} - -// From Prime to Envoy -#[quantum_link] -pub struct FirmwareFetchRequest { - #[n(0)] - pub current_version: String, - #[n(1)] - pub chunk_offset: Option, -} - -// From Envoy to Prime -#[quantum_link] -pub enum FirmwareFetchEvent { - // there is no update available from the provided prime version - #[n(0)] - UpdateNotAvailable, - // envoy has found an update, and will begin transmission - #[n(1)] - Starting(FirmwareUpdateAvailable), - // envoy is downloading the update - #[n(2)] - Downloading, - // envoy is sending a chunk for an update patch - #[n(3)] - Chunk(FirmwareChunk), - // envoy failed - #[n(5)] - Error { - #[n(0)] - error: String, - }, -} - -#[quantum_link] -#[derive(Eq)] -pub struct FirmwareChunk { - #[n(0)] - pub patch_index: u8, - #[n(1)] - pub total_patches: u8, - #[n(2)] - pub chunk_index: u16, - #[n(3)] - pub total_chunks: u16, - #[n(4)] - pub data: Vec, -} - -impl FirmwareChunk { - pub fn is_last(&self) -> bool { - self.patch_index == self.total_patches - 1 && self.chunk_index == self.total_chunks - 1 - } -} - -#[quantum_link] -pub enum FirmwareInstallEvent { - #[n(0)] - UpdateVerified, - #[n(1)] - Installing, - #[n(2)] - Rebooting, - #[n(3)] - Success { - #[n(0)] - installed_version: String, - }, - #[n(4)] - Error { - #[n(0)] - error: String, - #[n(1)] - stage: InstallErrorStage, - }, -} - -#[quantum_link] -pub enum InstallErrorStage { - #[n(0)] - Download, - #[n(1)] - Verify, - #[n(2)] - Install, -} diff --git a/api/src/api/fx.rs b/api/src/api/fx.rs deleted file mode 100644 index 3f2eb81..0000000 --- a/api/src/api/fx.rs +++ /dev/null @@ -1,34 +0,0 @@ -use quantum_link_macros::quantum_link; - -#[quantum_link] -pub struct ExchangeRate { - #[n(0)] - pub currency_code: String, - #[n(1)] - pub rate: f32, - #[n(2)] - pub timestamp: u64, -} - -#[quantum_link] -pub struct ExchangeRateHistory { - #[n(0)] - pub history: Vec, - #[n(1)] - pub currency_code: String, -} - -#[quantum_link] -pub struct PricePoint { - #[n(0)] - pub rate: f32, - #[n(1)] - pub timestamp: u64, -} - -/// Prime → Envoy. ISO-4217 code; sent on settings change and on every reconnect. -#[quantum_link] -pub struct PrimeFiatPreference { - #[n(0)] - pub currency_code: String, -} diff --git a/api/src/api/message.rs b/api/src/api/message.rs deleted file mode 100644 index cfe8fd7..0000000 --- a/api/src/api/message.rs +++ /dev/null @@ -1,149 +0,0 @@ -use quantum_link_macros::quantum_link; - -use super::onboarding::OnboardingState; -use crate::{ - backup::{ - BackupShardRequest, BackupShardResponse, CreateMagicBackupEvent, CreateMagicBackupResult, - EnvoyMagicBackupEnabledRequest, EnvoyMagicBackupEnabledResponse, MagicBackupRequestV2, - MagicBackupResponseV2, PrimeMagicBackupEnabled, PrimeMagicBackupStatusRequest, - PrimeMagicBackupStatusResponse, RestoreMagicBackupEvent, RestoreMagicBackupRequest, - RestoreMagicBackupResult, RestoreShardRequest, RestoreShardResponse, - }, - bitcoin::*, - firmware::{ - FirmwareFetchEvent, FirmwareFetchRequest, FirmwareInstallEvent, FirmwareUpdateCheckRequest, - FirmwareUpdateCheckResponse, - }, - fx::{ExchangeRate, ExchangeRateHistory, PrimeFiatPreference}, - pairing::{PairingRequest, PairingResponse, UnpairingRequest, UnpairingResponse}, - scv::SecurityCheck, - status::{ - DeviceNameUpdate, DeviceStatus, EnvoyStatus, Heartbeat, TimezoneRequest, TimezoneResponse, - }, -}; - -// Bump this every time there is a significant change -pub const PROTOCOL_VERSION: u8 = 1; - -#[quantum_link] -pub struct EnvoyMessage { - #[n(0)] - pub message: QuantumLinkMessage, - #[n(1)] - pub timestamp: u32, - #[n(2)] - pub protocol_version: Option, // This being None is implicit v0 -} - -#[quantum_link] -pub struct PassportMessage { - #[n(0)] - pub message: QuantumLinkMessage, - #[n(1)] - pub status: DeviceStatus, - #[n(2)] - pub protocol_version: Option, -} - -#[quantum_link] -pub enum QuantumLinkMessage { - #[n(0)] - ExchangeRate(ExchangeRate), - #[n(1)] - ExchangeRateHistory(ExchangeRateHistory), - - #[n(2)] - FirmwareUpdateCheckRequest(FirmwareUpdateCheckRequest), - #[n(3)] - FirmwareUpdateCheckResponse(FirmwareUpdateCheckResponse), - #[n(4)] - FirmwareFetchRequest(FirmwareFetchRequest), - #[n(5)] - FirmwareFetchEvent(FirmwareFetchEvent), - #[n(6)] - FirmwareInstallEvent(FirmwareInstallEvent), - - #[n(7)] - DeviceStatus(DeviceStatus), - #[n(8)] - EnvoyStatus(EnvoyStatus), - - #[n(9)] - PairingRequest(PairingRequest), - #[n(10)] - PairingResponse(PairingResponse), - - #[n(11)] - SecurityCheck(SecurityCheck), - #[n(12)] - OnboardingState(OnboardingState), - - #[n(13)] - SignPsbt(SignPsbt), - #[n(14)] - BroadcastTransaction(BroadcastTransaction), - #[n(15)] - AccountUpdate(AccountUpdate), - #[n(16)] - ApplyPassphrase(ApplyPassphrase), - - #[n(17)] - EnvoyMagicBackupEnabledRequest(EnvoyMagicBackupEnabledRequest), - #[n(18)] - EnvoyMagicBackupEnabledResponse(EnvoyMagicBackupEnabledResponse), - - #[n(19)] - PrimeMagicBackupEnabled(PrimeMagicBackupEnabled), - - #[n(20)] - PrimeMagicBackupStatusRequest(PrimeMagicBackupStatusRequest), - #[n(21)] - PrimeMagicBackupStatusResponse(PrimeMagicBackupStatusResponse), - - #[n(22)] - BackupShardRequest(BackupShardRequest), - #[n(23)] - BackupShardResponse(BackupShardResponse), - - #[n(24)] - RestoreShardRequest(RestoreShardRequest), - #[n(25)] - RestoreShardResponse(RestoreShardResponse), - - #[n(26)] - CreateMagicBackupEvent(CreateMagicBackupEvent), - #[n(27)] - CreateMagicBackupResult(CreateMagicBackupResult), - - #[n(28)] - RestoreMagicBackupRequest(RestoreMagicBackupRequest), - #[n(29)] - RestoreMagicBackupEvent(RestoreMagicBackupEvent), - #[n(30)] - RestoreMagicBackupResult(RestoreMagicBackupResult), - - #[n(31)] - Heartbeat(Heartbeat), - - #[n(33)] - TimezoneRequest(TimezoneRequest), - #[n(34)] - TimezoneResponse(TimezoneResponse), - - #[n(35)] - UnpairingRequest(UnpairingRequest), - #[n(36)] - UnpairingResponse(UnpairingResponse), - - #[n(37)] - DeviceNameUpdate(DeviceNameUpdate), - - #[n(38)] - MagicBackupRequestV2(MagicBackupRequestV2), - #[n(39)] - MagicBackupResponseV2(MagicBackupResponseV2), - - // Skipped tags (e.g. #[n(32)]) are intentional and must not be reused. - #[n(40)] - PrimeFiatPreference(PrimeFiatPreference), -} diff --git a/api/src/api/mod.rs b/api/src/api/mod.rs deleted file mode 100644 index a362499..0000000 --- a/api/src/api/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -pub mod backup; -pub mod bitcoin; -pub mod firmware; -pub mod fx; -pub mod message; -pub mod onboarding; -pub mod pairing; -pub mod passport; -pub mod quantum_link; -pub mod scv; -pub mod status; -#[cfg(test)] -pub mod tests; diff --git a/api/src/api/onboarding.rs b/api/src/api/onboarding.rs deleted file mode 100644 index 8e85fb0..0000000 --- a/api/src/api/onboarding.rs +++ /dev/null @@ -1,47 +0,0 @@ -use quantum_link_macros::quantum_link; - -#[quantum_link] -pub enum OnboardingState { - #[n(0)] - SecurityChecked, - #[n(1)] - SecurityCheckFailed, - - #[n(2)] - FirmwareUpdateScreen, - - /// pin - #[n(3)] - SecuringDevice, - /// pin - #[n(4)] - DeviceSecured, - - #[n(5)] - WalletCreationScreen, - #[n(6)] - CreatingWallet, - #[n(7)] - WalletCreated, - - #[n(8)] - MagicBackupScreen, - #[n(9)] - CreatingMagicBackup, - #[n(10)] - MagicBackupCreated, - - #[n(11)] - CreatingManualBackup, - #[n(12)] - CreatingKeycardBackup, - - #[n(13)] - WritingDownSeedWords, - #[n(14)] - ConnectingWallet, - #[n(15)] - WalletConected, - #[n(16)] - Completed, -} diff --git a/api/src/api/pairing.rs b/api/src/api/pairing.rs deleted file mode 100644 index ccc94ea..0000000 --- a/api/src/api/pairing.rs +++ /dev/null @@ -1,39 +0,0 @@ -use quantum_link_macros::quantum_link; - -use crate::{ - api::passport::{PassportFirmwareVersion, PassportModel, PassportSerial}, - passport::PassportColor, -}; - -#[quantum_link] -pub struct PairingResponse { - #[n(0)] - pub passport_model: PassportModel, - #[n(1)] - pub passport_firmware_version: PassportFirmwareVersion, - #[n(2)] - pub passport_serial: PassportSerial, - #[n(3)] - pub passport_color: PassportColor, - #[n(4)] - pub onboarding_complete: bool, - #[n(5)] - pub device_name: Option, -} - -#[quantum_link] -pub struct PairingRequest { - #[n(0)] - pub xid_document: Vec, - #[n(1)] - pub device_name: String, -} - -#[quantum_link] -pub struct UnpairingRequest {} - -#[quantum_link] -pub struct UnpairingResponse { - #[n(0)] - pub success: bool, -} diff --git a/api/src/api/passport.rs b/api/src/api/passport.rs deleted file mode 100644 index 84c2a78..0000000 --- a/api/src/api/passport.rs +++ /dev/null @@ -1,25 +0,0 @@ -use quantum_link_macros::quantum_link; - -#[quantum_link] -pub enum PassportModel { - #[n(0)] - Gen1, - #[n(1)] - Gen2, - #[n(2)] - Prime, -} - -#[quantum_link] -pub struct PassportFirmwareVersion(pub String); - -#[quantum_link] -pub struct PassportSerial(pub String); - -#[quantum_link] -pub enum PassportColor { - #[n(0)] - Light, - #[n(1)] - Dark, -} diff --git a/api/src/api/quantum_link.rs b/api/src/api/quantum_link.rs deleted file mode 100644 index 21b187f..0000000 --- a/api/src/api/quantum_link.rs +++ /dev/null @@ -1,431 +0,0 @@ -use std::time::Duration; - -use bc_components::{EncapsulationScheme, PrivateKeys, PublicKeys, SignatureScheme, ARID}; -use bc_envelope::{ - prelude::{CBORCase, CBOR}, - Envelope, EventBehavior, Expression, ExpressionBehavior, Function, -}; -use bc_xid::XIDDocument; -use chrono::{DateTime, Utc}; -use dcbor::Date; -use gstp::{SealedEvent, SealedEventBehavior}; - -use crate::message::{EnvoyMessage, PassportMessage}; - -pub const QUANTUM_LINK: Function = Function::new_static_named("quantumLink"); -pub const EXPIRATION_DURATION: Duration = Duration::from_secs(60); - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum ReplayCheck { - Fresh, - Replay, - Expired, -} - -/// Storage for tracking received ARIDs to prevent replay attacks -#[derive(Debug, Default, Clone)] -pub struct ARIDCache { - cache: Vec<(ARID, DateTime)>, -} - -impl ARIDCache { - pub fn new() -> Self { - Self { cache: Vec::new() } - } - - /// Check if ARID has been seen before and store it until the event expires. - pub fn check_and_store( - &mut self, - arid: &ARID, - expires_at: DateTime, - now: DateTime, - ) -> ReplayCheck { - // Clean up expired entries first - self.cache.retain(|(_, expires_at)| now < *expires_at); - - if now >= expires_at { - return ReplayCheck::Expired; - } - - // Check if ARID already exists (replay attack) - if self.cache.iter().any(|(id, _)| id == arid) { - return ReplayCheck::Replay; - } - - self.cache.push((*arid, expires_at)); - ReplayCheck::Fresh - } - - /// Get the number of stored ARIDs - pub fn len(&self) -> usize { - self.cache.len() - } - - pub fn is_empty(&self) -> bool { - self.cache.len() == 0 - } - - /// Clear all stored ARIDs - pub fn clear(&mut self) { - self.cache.clear(); - } -} - -#[derive(Debug, thiserror::Error)] -pub enum QlError { - #[error(transparent)] - Cbor(#[from] dcbor::Error), - #[error(transparent)] - Envelope(#[from] bc_envelope::Error), - #[error(transparent)] - Gstp(#[from] gstp::Error), - - #[error("envelope did not contain leaf")] - NotLeaf, - #[error("missing date")] - MissingDate, - #[error("invalid function")] - InvalidFunction, - #[error("replay attack")] - ReplayAttack, - #[error("expired")] - Expired, - #[error("date too far in the future")] - FutureDated, -} - -pub trait QuantumLink: Into + TryFrom { - fn encode(self) -> Expression { - let cbor: CBOR = self.into(); - let envelope = Envelope::new(cbor); - Expression::new(QUANTUM_LINK).with_parameter("ql", envelope) - } - - fn decode(expression: &Expression) -> Result { - if expression.function() != &QUANTUM_LINK { - return Err(QlError::InvalidFunction); - } - let envelope = expression.object_for_parameter("ql")?; - let cbor = envelope.as_leaf().ok_or(QlError::NotLeaf)?; - - let message = Self::try_from(cbor)?; - Ok(message) - } - - fn seal( - self, - (sender_pk, sender_xid): (&PrivateKeys, &XIDDocument), - recipient: &XIDDocument, - ) -> Envelope { - let valid_until = Date::with_duration_from_now(EXPIRATION_DURATION); - - let event: SealedEvent = - SealedEvent::new(QuantumLink::encode(self), ARID::new(), sender_xid) - .with_date(&valid_until); - event - .to_envelope(Some(&valid_until), Some(sender_pk), Some(recipient)) - .unwrap() - } - - fn unseal( - envelope: &Envelope, - private_keys: &PrivateKeys, - ) -> Result<(Expression, XIDDocument), QlError> { - let now = Utc::now(); - let event: SealedEvent = - SealedEvent::try_from_envelope(envelope, None, Some(&Date::from(now)), private_keys)?; - let expires_at = event.date().ok_or(QlError::MissingDate)?.datetime(); - validate_expires_at(expires_at, now)?; - - let expression = event.content().clone(); - Ok((expression, event.sender().clone())) - } - - fn unseal_with_replay_check( - envelope: &Envelope, - private_keys: &PrivateKeys, - arid_cache: &mut ARIDCache, - ) -> Result<(Expression, XIDDocument), QlError> { - let now = Utc::now(); - let event: SealedEvent = - SealedEvent::try_from_envelope(envelope, None, Some(&Date::from(now)), private_keys)?; - - let arid = event.id(); - let expires_at = event.date().ok_or(QlError::MissingDate)?.datetime(); - validate_expires_at(expires_at, now)?; - - match arid_cache.check_and_store(&arid, expires_at, now) { - ReplayCheck::Fresh => {} - ReplayCheck::Replay => return Err(QlError::ReplayAttack), - ReplayCheck::Expired => return Err(QlError::Expired), - } - - let expression = event.content().clone(); - Ok((expression, event.sender().clone())) - } - - fn unseal_passport_message_with_replay_check( - envelope: &Envelope, - private_keys: &PrivateKeys, - arid_cache: &mut ARIDCache, - ) -> Result<(PassportMessage, XIDDocument), QlError> { - let (expression, sender) = - PassportMessage::unseal_with_replay_check(envelope, private_keys, arid_cache)?; - Ok((PassportMessage::decode(&expression)?, sender)) - } - - fn unseal_envoy_message_with_replay_check( - envelope: &Envelope, - private_keys: &PrivateKeys, - arid_cache: &mut ARIDCache, - ) -> Result<(EnvoyMessage, XIDDocument), QlError> { - let (expression, sender) = - EnvoyMessage::unseal_with_replay_check(envelope, private_keys, arid_cache)?; - Ok((EnvoyMessage::decode(&expression)?, sender)) - } -} - -impl QuantumLink for T where T: Into + TryFrom {} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "envoy", flutter_rust_bridge::frb(opaque))] -pub struct QuantumLinkIdentity { - pub private_keys: Option, - pub xid_document: XIDDocument, -} - -impl QuantumLinkIdentity { - pub fn generate() -> Self { - let (signing_private_key, signing_public_key) = SignatureScheme::MLDSA44.keypair(); - let (encapsulation_private_key, encapsulation_public_key) = - EncapsulationScheme::MLKEM512.keypair(); - - let private_keys = PrivateKeys::with_keys(signing_private_key, encapsulation_private_key); - let public_keys = PublicKeys::new(signing_public_key, encapsulation_public_key); - - let xid_document = XIDDocument::from(public_keys); - - QuantumLinkIdentity { - private_keys: Some(private_keys), - xid_document, - } - } - - pub fn to_bytes(&self) -> Vec { - let mut map = bc_envelope::prelude::Map::new(); - map.insert(CBOR::from("xid_document"), self.clone().xid_document); - if self.private_keys.is_some() { - map.insert( - CBOR::from("private_keys"), - self.clone().private_keys.unwrap(), - ); - } - - CBOR::from(map).to_cbor_data() - } - - pub fn from_bytes(bytes: &[u8]) -> dcbor::Result { - let cbor = CBOR::try_from_data(bytes)?; - let case = cbor.into_case(); - - let CBORCase::Map(map) = case else { - return Err(dcbor::Error::WrongType); - }; - - Ok(QuantumLinkIdentity { - xid_document: map.get("xid_document").ok_or(dcbor::Error::MissingMapKey)?, - private_keys: map.get("private_keys"), - }) - } -} - -fn expiration_duration() -> chrono::Duration { - chrono::Duration::from_std(EXPIRATION_DURATION).expect("expiration duration must fit chrono") -} - -fn validate_expires_at(expires_at: DateTime, now: DateTime) -> Result<(), QlError> { - if now >= expires_at { - return Err(QlError::Expired); - } - - if expires_at > now + expiration_duration() { - return Err(QlError::FutureDated); - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - api::{ - message::{QuantumLinkMessage, PROTOCOL_VERSION}, - quantum_link::QuantumLink, - }, - fx::ExchangeRate, - message::EnvoyMessage, - quantum_link::{ARIDCache, QlError, QuantumLinkIdentity}, - }; - - #[test] - fn accepts_fresh_and_rejects_immediate_replay() { - let envoy = QuantumLinkIdentity::generate(); - let passport = QuantumLinkIdentity::generate(); - let mut arid_cache = ARIDCache::new(); - - let original_message = exchange_rate_envoy_message(); - let envelope = QuantumLink::seal( - original_message.clone(), - (envoy.private_keys.as_ref().unwrap(), &envoy.xid_document), - &passport.xid_document, - ); - - let (decoded, _sender) = EnvoyMessage::unseal_envoy_message_with_replay_check( - &envelope, - &passport.private_keys.clone().unwrap(), - &mut arid_cache, - ) - .unwrap(); - assert_exchange_rate_matches(&original_message, &decoded); - - let result2 = EnvoyMessage::unseal_envoy_message_with_replay_check( - &envelope, - &passport.private_keys.unwrap(), - &mut arid_cache, - ); - assert!(matches!(result2, Err(QlError::ReplayAttack))); - } - - #[test] - fn rejects_expired_envelope() { - let envoy = QuantumLinkIdentity::generate(); - let passport = QuantumLinkIdentity::generate(); - let mut arid_cache = ARIDCache::new(); - let expired_at = Utc::now() - chrono::Duration::seconds(1); - - let envelope = seal_envoy_message_with_expiration( - exchange_rate_envoy_message(), - &envoy, - &passport, - expired_at, - ); - - let replay_checked_result = EnvoyMessage::unseal_envoy_message_with_replay_check( - &envelope, - &passport.private_keys.unwrap(), - &mut arid_cache, - ); - assert!(matches!(replay_checked_result, Err(QlError::Expired))); - } - - #[test] - fn rejects_future_dated_envelope() { - let envoy = QuantumLinkIdentity::generate(); - let passport = QuantumLinkIdentity::generate(); - let mut arid_cache = ARIDCache::new(); - let expires_at = Utc::now() + expiration_duration() + chrono::Duration::seconds(1); - - let envelope = seal_envoy_message_with_expiration( - exchange_rate_envoy_message(), - &envoy, - &passport, - expires_at, - ); - - let result = EnvoyMessage::unseal_envoy_message_with_replay_check( - &envelope, - &passport.private_keys.unwrap(), - &mut arid_cache, - ); - assert!(matches!(result, Err(QlError::FutureDated))); - } - - #[test] - fn arid_cache_reports_replay_and_expiration() { - let mut cache = ARIDCache::new(); - let arid1 = ARID::new(); - let arid2 = ARID::new(); - - let start = chrono::Utc::now(); - let expires_at = start + expiration_duration(); - - assert_eq!( - cache.check_and_store(&arid1, expires_at, start), - ReplayCheck::Fresh - ); - assert_eq!( - cache.check_and_store(&arid1, expires_at, start), - ReplayCheck::Replay - ); - - let after_expiration = expires_at + chrono::Duration::seconds(1); - assert_eq!( - cache.check_and_store(&arid1, expires_at, after_expiration), - ReplayCheck::Expired - ); - - assert_eq!( - cache.check_and_store( - &arid2, - after_expiration + expiration_duration(), - after_expiration, - ), - ReplayCheck::Fresh - ); - assert_eq!(cache.len(), 1); - assert!(!cache.cache.iter().any(|(id, _)| id == &arid1)); - } - - fn exchange_rate_envoy_message() -> EnvoyMessage { - let fx_rate = ExchangeRate { - currency_code: String::from("USD"), - rate: 0.85, - timestamp: 0, - }; - - EnvoyMessage { - message: QuantumLinkMessage::ExchangeRate(fx_rate), - timestamp: 123456, - protocol_version: Some(PROTOCOL_VERSION), - } - } - - fn assert_exchange_rate_matches(expected: &EnvoyMessage, actual: &EnvoyMessage) { - let expected_rate = match &expected.message { - QuantumLinkMessage::ExchangeRate(rate) => rate, - _ => panic!("Expected ExchangeRate message"), - }; - let actual_rate = match &actual.message { - QuantumLinkMessage::ExchangeRate(rate) => rate, - _ => panic!("Expected ExchangeRate message"), - }; - - assert_eq!(actual.timestamp, expected.timestamp); - assert_eq!(actual.protocol_version, expected.protocol_version); - assert_eq!(actual_rate.rate, expected_rate.rate); - } - - fn seal_envoy_message_with_expiration( - message: EnvoyMessage, - sender: &QuantumLinkIdentity, - recipient: &QuantumLinkIdentity, - expires_at: DateTime, - ) -> Envelope { - let valid_until = Date::from(expires_at); - - let event: SealedEvent = SealedEvent::new( - QuantumLink::encode(message), - ARID::new(), - &sender.xid_document, - ) - .with_date(&valid_until); - event - .to_envelope( - Some(&valid_until), - Some(sender.private_keys.as_ref().unwrap()), - Some(&recipient.xid_document), - ) - .unwrap() - } -} diff --git a/api/src/api/scv.rs b/api/src/api/scv.rs deleted file mode 100644 index 0cec662..0000000 --- a/api/src/api/scv.rs +++ /dev/null @@ -1,52 +0,0 @@ -use quantum_link_macros::quantum_link; - -#[quantum_link] -pub enum SecurityCheck { - // Envoy to Prime: Initial challenge - #[n(0)] - ChallengeRequest(ChallengeRequest), - - // Prime to Envoy: Response to the challenge - #[n(1)] - ChallengeResponse(ChallengeResponseResult), - - // Envoy to Prime: Verification result - // only send if ChallengeResponse was successful - #[n(2)] - VerificationResult(VerificationResult), -} - -#[quantum_link] -pub struct ChallengeRequest { - #[n(0)] - pub data: Vec, -} - -#[quantum_link] -pub enum ChallengeResponseResult { - #[n(0)] - Success { - #[n(0)] - data: Vec, - }, - #[n(1)] - Error { - #[n(0)] - error: String, - }, -} - -#[quantum_link] -pub enum VerificationResult { - #[n(0)] - Success, - // Error due to Envoy not being able to perform the verification - #[n(1)] - Error { - #[n(0)] - error: String, - }, - // Actual failure indicating device has been tampered with - #[n(2)] - Failure, -} diff --git a/api/src/api/status.rs b/api/src/api/status.rs deleted file mode 100644 index f564cec..0000000 --- a/api/src/api/status.rs +++ /dev/null @@ -1,35 +0,0 @@ -use quantum_link_macros::quantum_link; - -#[quantum_link] -pub struct DeviceStatus { - #[n(0)] - pub version: String, - #[n(1)] - pub battery_level: u8, -} - -#[quantum_link] -pub struct EnvoyStatus { - #[n(0)] - pub version: String, -} - -#[quantum_link] -pub struct Heartbeat {} - -#[quantum_link] -pub struct TimezoneRequest {} - -#[quantum_link] -pub struct TimezoneResponse { - #[n(0)] - pub offset_minutes: i32, - #[n(1)] - pub zone: String, -} - -#[quantum_link] -pub struct DeviceNameUpdate { - #[n(0)] - pub device_name: String, -} diff --git a/api/src/api/tests.rs b/api/src/api/tests.rs deleted file mode 100644 index b8e635b..0000000 --- a/api/src/api/tests.rs +++ /dev/null @@ -1,414 +0,0 @@ -use dcbor::{CBORCase, CBOR}; -use quantum_link_macros::Cbor; - -#[derive(Debug, Clone, PartialEq, Cbor)] -pub struct TestStruct { - #[n(0)] - pub name: String, - #[n(1)] - pub value: u64, - #[n(2)] - pub enabled: bool, -} - -#[derive(Debug, Clone, PartialEq, Cbor)] -pub struct TestWithVec { - #[n(0)] - pub items: Vec, - #[n(1)] - pub label: String, -} - -#[derive(Debug, Clone, PartialEq, Cbor)] -pub struct TestWithArray { - #[n(0)] - pub hash: [u8; 32], - #[n(1)] - pub id: u64, -} - -#[derive(Debug, Clone, PartialEq, Cbor)] -pub enum TestEnumTuple { - #[n(0)] - First(TestStruct), - #[n(1)] - Second(TestWithVec), -} - -#[derive(Debug, Clone, PartialEq, Cbor)] -pub enum TestEnumStruct { - #[n(0)] - VariantA { - #[n(0)] - count: u64, - #[n(1)] - active: bool, - }, - #[n(1)] - VariantB { - #[n(0)] - message: String, - }, -} - -#[derive(Debug, Clone, PartialEq, Cbor)] -pub enum TestEnumUnit { - #[n(0)] - Empty, - #[n(1)] - WithData(TestStruct), -} - -#[derive(Debug, Clone, PartialEq, Cbor)] -pub enum TestEnumMixed { - #[n(0)] - Unit, - #[n(1)] - Tuple(TestStruct), - #[n(2)] - Struct { - #[n(0)] - field1: String, - #[n(1)] - field2: u64, - }, -} - -#[test] -fn struct_roundtrip() { - let original = TestStruct { - name: "test".to_string(), - value: 42, - enabled: true, - }; - - let cbor: CBOR = original.clone().into(); - let recovered: TestStruct = cbor.try_into().unwrap(); - - assert_eq!(original, recovered); -} - -#[test] -fn struct_with_vec_roundtrip() { - let original = TestWithVec { - items: vec![1, 2, 3, 4, 5], - label: "data".to_string(), - }; - - let cbor: CBOR = original.clone().into(); - let recovered: TestWithVec = cbor.try_into().unwrap(); - - assert_eq!(original, recovered); -} - -#[test] -fn struct_with_array_roundtrip() { - let original = TestWithArray { - hash: [ - 0xde, 0xad, 0xbe, 0xef, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, - 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, - 0x09, 0x0a, 0x0b, 0x0c, - ], - id: 12345, - }; - - let cbor: CBOR = original.clone().into(); - let recovered: TestWithArray = cbor.try_into().unwrap(); - - assert_eq!(original, recovered); -} - -#[test] -fn byte_string_encoding() { - let test = TestWithVec { - items: vec![1, 2, 3], - label: "test".to_string(), - }; - - let cbor: CBOR = test.into(); - let case = cbor.into_case(); - - match case { - CBORCase::Map(map) => { - let items_cbor: CBOR = map.get(0).unwrap(); - let items_case = items_cbor.into_case(); - assert!( - matches!(items_case, CBORCase::ByteString(_)), - "Vec should be encoded as byte string" - ); - } - _ => panic!("Expected CBOR map"), - } -} - -#[test] -fn array_byte_string_encoding() { - let test = TestWithArray { - hash: [0u8; 32], - id: 1, - }; - - let cbor: CBOR = test.into(); - let case = cbor.into_case(); - - match case { - CBORCase::Map(map) => { - let hash_cbor: CBOR = map.get(0).unwrap(); - let hash_case = hash_cbor.into_case(); - assert!( - matches!(hash_case, CBORCase::ByteString(_)), - "[u8; N] should be encoded as byte string" - ); - } - _ => panic!("Expected CBOR map"), - } -} - -#[test] -fn enum_tuple_roundtrip() { - let test_struct = TestStruct { - name: "inner".to_string(), - value: 100, - enabled: false, - }; - - let original = TestEnumTuple::First(test_struct); - let cbor: CBOR = original.clone().into(); - let recovered: TestEnumTuple = cbor.try_into().unwrap(); - - assert_eq!(original, recovered); - - let test_vec = TestWithVec { - items: vec![255, 128, 0], - label: "bytes".to_string(), - }; - - let original2 = TestEnumTuple::Second(test_vec); - let cbor2: CBOR = original2.clone().into(); - let recovered2: TestEnumTuple = cbor2.try_into().unwrap(); - - assert_eq!(original2, recovered2); -} - -#[test] -fn enum_struct_roundtrip() { - let original_a = TestEnumStruct::VariantA { - count: 999, - active: true, - }; - - let cbor_a: CBOR = original_a.clone().into(); - let recovered_a: TestEnumStruct = cbor_a.try_into().unwrap(); - - assert_eq!(original_a, recovered_a); - - let original_b = TestEnumStruct::VariantB { - message: "hello world".to_string(), - }; - - let cbor_b: CBOR = original_b.clone().into(); - let recovered_b: TestEnumStruct = cbor_b.try_into().unwrap(); - - assert_eq!(original_b, recovered_b); -} - -#[test] -fn enum_unit_roundtrip() { - let original_empty = TestEnumUnit::Empty; - let cbor: CBOR = original_empty.clone().into(); - let recovered: TestEnumUnit = cbor.try_into().unwrap(); - - assert_eq!(original_empty, recovered); - - let test_struct = TestStruct { - name: "with data".to_string(), - value: 123, - enabled: true, - }; - - let original_with_data = TestEnumUnit::WithData(test_struct); - let cbor2: CBOR = original_with_data.clone().into(); - let recovered2: TestEnumUnit = cbor2.try_into().unwrap(); - - assert_eq!(original_with_data, recovered2); -} - -#[test] -fn enum_mixed_roundtrip() { - let unit = TestEnumMixed::Unit; - let cbor: CBOR = unit.clone().into(); - let recovered: TestEnumMixed = cbor.try_into().unwrap(); - assert_eq!(unit, recovered); - - let tuple = TestEnumMixed::Tuple(TestStruct { - name: "tuple".to_string(), - value: 50, - enabled: false, - }); - let cbor: CBOR = tuple.clone().into(); - let recovered: TestEnumMixed = cbor.try_into().unwrap(); - assert_eq!(tuple, recovered); - - let struct_var = TestEnumMixed::Struct { - field1: "struct variant".to_string(), - field2: 9999, - }; - let cbor: CBOR = struct_var.clone().into(); - let recovered: TestEnumMixed = cbor.try_into().unwrap(); - assert_eq!(struct_var, recovered); -} - -#[test] -fn cbor_structure() { - let test = TestStruct { - name: "check".to_string(), - value: 7, - enabled: true, - }; - - let cbor: CBOR = test.into(); - let case = cbor.into_case(); - - match case { - CBORCase::Map(map) => { - assert_eq!(map.len(), 3); - - assert!(map.get::(0).is_some()); - assert!(map.get::(1).is_some()); - assert!(map.get::(2).is_some()); - } - _ => panic!("Expected CBOR map"), - } -} - -#[test] -fn enum_cbor_structure() { - let test_struct = TestStruct { - name: "test".to_string(), - value: 1, - enabled: true, - }; - - let variant = TestEnumTuple::First(test_struct); - let cbor: CBOR = variant.into(); - let case = cbor.into_case(); - - match case { - CBORCase::Array(arr) => { - assert_eq!(arr.len(), 2); - - let index: u64 = arr.first().unwrap().clone().try_into().unwrap(); - assert_eq!(index, 0); - } - _ => panic!("Expected CBOR array for enum"), - } -} - -#[test] -fn enum_tuple_vs_struct_encoding() { - #[derive(Debug, Clone, PartialEq, Cbor)] - pub struct InnerData { - #[n(0)] - pub count: u64, - #[n(1)] - pub active: bool, - } - - #[derive(Debug, Clone, PartialEq, Cbor)] - pub enum EnumWithTupleStruct { - #[n(0)] - Variant(InnerData), - } - - #[derive(Debug, Clone, PartialEq, Cbor)] - pub enum EnumWithStructFields { - #[n(0)] - Variant { - #[n(0)] - count: u64, - #[n(1)] - active: bool, - }, - } - - let tuple_enum = EnumWithTupleStruct::Variant(InnerData { - count: 42, - active: true, - }); - - let struct_enum = EnumWithStructFields::Variant { - count: 42, - active: true, - }; - - let tuple_cbor: CBOR = tuple_enum.into(); - let struct_cbor: CBOR = struct_enum.into(); - - let tuple_bytes = tuple_cbor.to_cbor_data(); - let struct_bytes = struct_cbor.to_cbor_data(); - - assert_eq!( - tuple_bytes, struct_bytes, - "Enum with tuple(struct) should serialize the same as enum with struct fields" - ); -} - -#[test] -fn newtype() { - #[derive(Debug, Clone, PartialEq, Cbor)] - struct NewType(String); - - let value = NewType(String::from("yes")); - let cbor: CBOR = value.clone().into(); - let case = cbor.clone().into_case(); - - match case { - CBORCase::Text(_) => {} - _ => panic!("invalid case"), - } - - assert_eq!(value, NewType::try_from(cbor).unwrap()) -} - -#[test] -fn option_array() { - #[derive(Debug, Clone, PartialEq, Cbor)] - struct OptionArray { - #[n(0)] - arr: Option<[u8; 10]>, - #[n(1)] - vec: Option>, - } - - let a = [10; 10]; - let b = vec![12; 4]; - let value = OptionArray { - arr: Some(a), - vec: Some(b.clone()), - }; - let cbor: CBOR = value.clone().into(); - let case = cbor.clone().into_case(); - - match case { - CBORCase::Map(map) => { - assert_eq!(map.len(), 2); - let arr: CBOR = map.get(0).unwrap(); - match arr.into_case() { - CBORCase::ByteString(bytes) => { - assert_eq!(bytes.data(), &a) - } - _ => panic!("expected bytestring"), - } - let vec: CBOR = map.get(1).unwrap(); - match vec.into_case() { - CBORCase::ByteString(bytes) => { - assert_eq!(bytes.data(), &b) - } - _ => panic!("expected bytestring"), - } - } - _ => panic!("Expected CBOR array for enum"), - } - - assert_eq!(value, OptionArray::try_from(cbor).unwrap()) -} diff --git a/api/src/lib.rs b/api/src/lib.rs deleted file mode 100644 index e6ecd81..0000000 --- a/api/src/lib.rs +++ /dev/null @@ -1,11 +0,0 @@ -pub mod api; -pub use api::*; - -/// Marker trait for types that have a Cbor derive (structs and enums, not primitives). -/// This is used to enforce that enum tuple variants wrap Cbor-derived types. -pub(crate) trait CborMarker {} - -pub use bc_components; -pub use bc_envelope; -pub use bc_xid; -pub use dcbor; diff --git a/api/tests/golden_tests.rs b/api/tests/golden_tests.rs deleted file mode 100644 index e240c9e..0000000 --- a/api/tests/golden_tests.rs +++ /dev/null @@ -1,550 +0,0 @@ -//! golden/snapshot tests for QuantumLinkMessage codec -//! -//! to update snapshots when serialization intentionally changes: -//! ``` -//! INSTA_UPDATE=always cargo test -//! ``` - -use dcbor::CBOR; -use foundation_api::{ - backup::*, bitcoin::*, firmware::*, fx::*, message::*, onboarding::*, pairing::*, passport::*, - scv::*, status::*, -}; - -/// convert a message to hex-encoded CBOR bytes -fn to_hex(message: &QuantumLinkMessage) -> String { - let cbor: CBOR = message.clone().into(); - let bytes = cbor.to_cbor_data(); - bytes.iter().map(|b| format!("{b:02x}")).collect::() -} - -/// decode hex-encoded CBOR bytes back to a message -fn from_hex(hex: &str) -> QuantumLinkMessage { - let bytes: Vec = (0..hex.len()) - .step_by(2) - .map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap()) - .collect(); - let cbor = CBOR::try_from_data(&bytes).unwrap(); - QuantumLinkMessage::try_from(cbor).unwrap() -} - -macro_rules! assert_golden { - ($message:expr) => {{ - let message = $message; - let hex = to_hex(&message); - insta::assert_snapshot!(hex.clone()); - - let decoded = from_hex(&hex); - assert_eq!(message, decoded, "roundtrip decode failed"); - }}; -} - -#[test] -fn golden_exchange_rate() { - assert_golden!(QuantumLinkMessage::ExchangeRate(ExchangeRate { - currency_code: "USD".to_string(), - rate: 42_000.5, - timestamp: 1700000000, - })); -} - -#[test] -fn golden_exchange_rate_history() { - assert_golden!(QuantumLinkMessage::ExchangeRateHistory( - ExchangeRateHistory { - history: vec![ - PricePoint { - rate: 41000.0, - timestamp: 1699999900, - }, - PricePoint { - rate: 42000.0, - timestamp: 1700000000, - }, - ], - currency_code: "EUR".to_string(), - } - )); -} - -#[test] -fn golden_firmware_update_check_request() { - assert_golden!(QuantumLinkMessage::FirmwareUpdateCheckRequest( - FirmwareUpdateCheckRequest { - current_version: "2.4.0".to_string(), - }, - )); -} - -#[test] -fn golden_firmware_update_check_response_available() { - assert_golden!(QuantumLinkMessage::FirmwareUpdateCheckResponse( - FirmwareUpdateCheckResponse::Available(FirmwareUpdateAvailable { - version: "2.5.0".to_string(), - changelog: "Bug fixes".to_string(), - timestamp: 1700000000, - total_size: 1024000, - patch_count: 3, - }), - )); -} - -#[test] -fn golden_firmware_update_check_response_not_available() { - assert_golden!(QuantumLinkMessage::FirmwareUpdateCheckResponse( - FirmwareUpdateCheckResponse::NotAvailable, - )); -} - -#[test] -fn golden_firmware_fetch_request() { - assert_golden!(QuantumLinkMessage::FirmwareFetchRequest( - FirmwareFetchRequest { - current_version: "2.4.0".to_string(), - chunk_offset: None - }, - )); -} - -#[test] -fn golden_firmware_fetch_event_not_available() { - assert_golden!(QuantumLinkMessage::FirmwareFetchEvent( - FirmwareFetchEvent::UpdateNotAvailable, - )); -} - -#[test] -fn golden_firmware_fetch_event_starting() { - assert_golden!(QuantumLinkMessage::FirmwareFetchEvent( - FirmwareFetchEvent::Starting(FirmwareUpdateAvailable { - version: "2.5.0".to_string(), - changelog: "New features".to_string(), - timestamp: 1700000000, - total_size: 2048000, - patch_count: 5, - }), - )); -} - -#[test] -fn golden_firmware_fetch_event_downloading() { - assert_golden!(QuantumLinkMessage::FirmwareFetchEvent( - FirmwareFetchEvent::Downloading, - )); -} - -#[test] -fn golden_firmware_fetch_event_chunk() { - assert_golden!(QuantumLinkMessage::FirmwareFetchEvent( - FirmwareFetchEvent::Chunk(FirmwareChunk { - patch_index: 0, - total_patches: 3, - chunk_index: 5, - total_chunks: 100, - data: vec![0xde, 0xad, 0xbe, 0xef], - }), - )); -} - -#[test] -fn golden_firmware_fetch_event_error() { - assert_golden!(QuantumLinkMessage::FirmwareFetchEvent( - FirmwareFetchEvent::Error { - error: "Download failed".to_string(), - }, - )); -} - -#[test] -fn golden_firmware_update_result_update_verified() { - assert_golden!(QuantumLinkMessage::FirmwareInstallEvent( - FirmwareInstallEvent::UpdateVerified, - )); -} - -#[test] -fn golden_firmware_update_result_installing() { - assert_golden!(QuantumLinkMessage::FirmwareInstallEvent( - FirmwareInstallEvent::Installing, - )); -} - -#[test] -fn golden_firmware_update_result_rebooting() { - assert_golden!(QuantumLinkMessage::FirmwareInstallEvent( - FirmwareInstallEvent::Rebooting, - )); -} - -#[test] -fn golden_firmware_update_result_success() { - assert_golden!(QuantumLinkMessage::FirmwareInstallEvent( - FirmwareInstallEvent::Success { - installed_version: "2.5.0".to_string(), - }, - )); -} - -#[test] -fn golden_firmware_update_result_error_verify() { - assert_golden!(QuantumLinkMessage::FirmwareInstallEvent( - FirmwareInstallEvent::Error { - error: "Signature verification failed".to_string(), - stage: InstallErrorStage::Verify, - }, - )); -} - -#[test] -fn golden_firmware_update_result_error_install() { - assert_golden!(QuantumLinkMessage::FirmwareInstallEvent( - FirmwareInstallEvent::Error { - error: "Installation failed".to_string(), - stage: InstallErrorStage::Install, - }, - )); -} - -#[test] -fn golden_device_status() { - assert_golden!(QuantumLinkMessage::DeviceStatus(DeviceStatus { - battery_level: 85, - version: "2.4.0".to_string(), - })); -} - -#[test] -fn golden_device_status_updating() { - assert_golden!(QuantumLinkMessage::DeviceStatus(DeviceStatus { - battery_level: 90, - version: "2.4.0".to_string(), - })); -} - -#[test] -fn golden_envoy_status() { - assert_golden!(QuantumLinkMessage::EnvoyStatus(EnvoyStatus { - version: "1.0.0".to_string(), - })); -} - -#[test] -fn golden_pairing_request() { - assert_golden!(QuantumLinkMessage::PairingRequest(PairingRequest { - xid_document: vec![0x01, 0x02, 0x03, 0x04], - device_name: "My iPhone".to_string(), - })); -} - -#[test] -fn golden_pairing_response() { - assert_golden!(QuantumLinkMessage::PairingResponse(PairingResponse { - passport_model: PassportModel::Prime, - passport_firmware_version: PassportFirmwareVersion("2.4.0".to_string()), - passport_serial: PassportSerial("ABC123".to_string()), - passport_color: PassportColor::Dark, - onboarding_complete: true, - device_name: Some("Passport Prime".to_string()), - })); -} - -#[test] -fn golden_onboarding_state_firmware_update_screen() { - assert_golden!(QuantumLinkMessage::OnboardingState( - OnboardingState::FirmwareUpdateScreen, - )); -} - -#[test] -fn golden_onboarding_state_completed() { - assert_golden!(QuantumLinkMessage::OnboardingState( - OnboardingState::Completed, - )); -} - -#[test] -fn golden_sign_psbt() { - assert_golden!(QuantumLinkMessage::SignPsbt(SignPsbt { - account_id: "account-1".to_string(), - psbt: vec![0x70, 0x73, 0x62, 0x74, 0xff], - })); -} - -#[test] -fn golden_broadcast_transaction() { - assert_golden!(QuantumLinkMessage::BroadcastTransaction( - BroadcastTransaction { - account_id: "account-1".to_string(), - psbt: vec![0x70, 0x73, 0x62, 0x74, 0xff], - }, - )); -} - -#[test] -fn golden_account_update() { - assert_golden!(QuantumLinkMessage::AccountUpdate(AccountUpdate { - account_id: "account-1".to_string(), - update: vec![0x01, 0x02, 0x03], - })); -} - -#[test] -fn golden_apply_passphrase_some() { - assert_golden!(QuantumLinkMessage::ApplyPassphrase(ApplyPassphrase { - fingerprint: Some("abc123".to_string()), - })); -} - -#[test] -fn golden_apply_passphrase_none() { - assert_golden!(QuantumLinkMessage::ApplyPassphrase(ApplyPassphrase { - fingerprint: None, - })); -} - -#[test] -fn golden_security_check_challenge_request() { - assert_golden!(QuantumLinkMessage::SecurityCheck( - SecurityCheck::ChallengeRequest(ChallengeRequest { - data: vec![0xca, 0xfe, 0xba, 0xbe], - }), - )); -} - -#[test] -fn golden_security_check_challenge_response_success() { - assert_golden!(QuantumLinkMessage::SecurityCheck( - SecurityCheck::ChallengeResponse(ChallengeResponseResult::Success { - data: vec![0xde, 0xad, 0xbe, 0xef], - }), - )); -} - -#[test] -fn golden_security_check_challenge_response_error() { - assert_golden!(QuantumLinkMessage::SecurityCheck( - SecurityCheck::ChallengeResponse(ChallengeResponseResult::Error { - error: "Invalid signature".to_string(), - }), - )); -} - -#[test] -fn golden_security_check_verification_success() { - assert_golden!(QuantumLinkMessage::SecurityCheck( - SecurityCheck::VerificationResult(VerificationResult::Success), - )); -} - -#[test] -fn golden_security_check_verification_error() { - assert_golden!(QuantumLinkMessage::SecurityCheck( - SecurityCheck::VerificationResult(VerificationResult::Error { - error: "Verification failed".to_string(), - }), - )); -} - -#[test] -fn golden_envoy_magic_backup_enabled_request() { - assert_golden!(QuantumLinkMessage::EnvoyMagicBackupEnabledRequest( - EnvoyMagicBackupEnabledRequest {}, - )); -} - -#[test] -fn golden_envoy_magic_backup_enabled_response() { - assert_golden!(QuantumLinkMessage::EnvoyMagicBackupEnabledResponse( - EnvoyMagicBackupEnabledResponse { enabled: true }, - )); -} - -#[test] -fn golden_prime_magic_backup_enabled() { - assert_golden!(QuantumLinkMessage::PrimeMagicBackupEnabled( - PrimeMagicBackupEnabled { - enabled: true, - seed_fingerprint: SeedFingerprint([0x42; 32]), - }, - )); -} - -#[test] -fn golden_prime_magic_backup_status_request() { - assert_golden!(QuantumLinkMessage::PrimeMagicBackupStatusRequest( - PrimeMagicBackupStatusRequest { - seed_fingerprint: SeedFingerprint([0xab; 32]), - timestamp: None, - }, - )); -} - -#[test] -fn golden_prime_magic_backup_status_response() { - assert_golden!(QuantumLinkMessage::PrimeMagicBackupStatusResponse( - PrimeMagicBackupStatusResponse { - shard_backup_found: true, - }, - )); -} - -#[test] -fn golden_backup_shard_request() { - assert_golden!(QuantumLinkMessage::BackupShardRequest(BackupShardRequest { - shard: Shard(vec![0x01, 0x02, 0x03, 0x04, 0x05]), - })); -} - -#[test] -fn golden_backup_shard_response_success() { - assert_golden!(QuantumLinkMessage::BackupShardResponse( - BackupShardResponse::Success, - )); -} - -#[test] -fn golden_backup_shard_response_error() { - assert_golden!(QuantumLinkMessage::BackupShardResponse( - BackupShardResponse::Error { - error: "Storage full".to_string(), - }, - )); -} - -#[test] -fn golden_restore_shard_request() { - assert_golden!(QuantumLinkMessage::RestoreShardRequest( - RestoreShardRequest { - seed_fingerprint: SeedFingerprint([0xcd; 32]), - timestamp: None, - }, - )); -} - -#[test] -fn golden_restore_shard_response_success() { - assert_golden!(QuantumLinkMessage::RestoreShardResponse( - RestoreShardResponse::Success { - shard: Shard(vec![0x0a, 0x0b, 0x0c]), - }, - )); -} - -#[test] -fn golden_restore_shard_response_error() { - assert_golden!(QuantumLinkMessage::RestoreShardResponse( - RestoreShardResponse::Error { - error: "Not found".to_string(), - }, - )); -} - -#[test] -fn golden_restore_shard_response_not_found() { - assert_golden!(QuantumLinkMessage::RestoreShardResponse( - RestoreShardResponse::NotFound, - )); -} - -#[test] -fn golden_create_magic_backup_event_start() { - assert_golden!(QuantumLinkMessage::CreateMagicBackupEvent( - CreateMagicBackupEvent::Start(StartMagicBackup { - seed_fingerprint: SeedFingerprint([0xef; 32]), - total_chunks: 100, - hash: [0xaa; 32], - }), - )); -} - -#[test] -fn golden_create_magic_backup_event_chunk() { - assert_golden!(QuantumLinkMessage::CreateMagicBackupEvent( - CreateMagicBackupEvent::Chunk(BackupChunk { - chunk_index: 5, - total_chunks: 100, - data: vec![0x11, 0x22, 0x33], - }), - )); -} - -#[test] -fn golden_create_magic_backup_result_success() { - assert_golden!(QuantumLinkMessage::CreateMagicBackupResult( - CreateMagicBackupResult::Success, - )); -} - -#[test] -fn golden_create_magic_backup_result_error() { - assert_golden!(QuantumLinkMessage::CreateMagicBackupResult( - CreateMagicBackupResult::Error { - error: "Upload failed".to_string(), - }, - )); -} - -#[test] -fn golden_restore_magic_backup_request() { - assert_golden!(QuantumLinkMessage::RestoreMagicBackupRequest( - RestoreMagicBackupRequest { - seed_fingerprint: SeedFingerprint([0xbb; 32]), - resume_from_chunk: 50, - }, - )); -} - -#[test] -fn golden_restore_magic_backup_event_no_backup() { - assert_golden!(QuantumLinkMessage::RestoreMagicBackupEvent( - RestoreMagicBackupEvent::NotFound, - )); -} - -#[test] -fn golden_restore_magic_backup_event_starting() { - assert_golden!(QuantumLinkMessage::RestoreMagicBackupEvent( - RestoreMagicBackupEvent::Starting(BackupMetadata { total_chunks: 200 }), - )); -} - -#[test] -fn golden_restore_magic_backup_event_chunk() { - assert_golden!(QuantumLinkMessage::RestoreMagicBackupEvent( - RestoreMagicBackupEvent::Chunk(BackupChunk { - chunk_index: 10, - total_chunks: 50, - data: vec![0xaa, 0xbb, 0xcc, 0xdd], - }), - )); -} - -#[test] -fn golden_restore_magic_backup_event_error() { - assert_golden!(QuantumLinkMessage::RestoreMagicBackupEvent( - RestoreMagicBackupEvent::Error { - error: "Network error".to_string(), - }, - )); -} - -#[test] -fn golden_restore_magic_backup_result_success() { - assert_golden!(QuantumLinkMessage::RestoreMagicBackupResult( - RestoreMagicBackupResult::Success, - )); -} - -#[test] -fn golden_restore_magic_backup_result_error() { - assert_golden!(QuantumLinkMessage::RestoreMagicBackupResult( - RestoreMagicBackupResult::Error { - error: "Checksum mismatch".to_string(), - }, - )); -} - -#[test] -fn golden_heartbeat() { - assert_golden!(QuantumLinkMessage::Heartbeat(Heartbeat {})) -} diff --git a/api/tests/snapshots/golden_tests__golden_account_update.snap b/api/tests/snapshots/golden_tests__golden_account_update.snap deleted file mode 100644 index c4158ee..0000000 --- a/api/tests/snapshots/golden_tests__golden_account_update.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820fa200696163636f756e742d310143010203 diff --git a/api/tests/snapshots/golden_tests__golden_apply_passphrase_none.snap b/api/tests/snapshots/golden_tests__golden_apply_passphrase_none.snap deleted file mode 100644 index d18af85..0000000 --- a/api/tests/snapshots/golden_tests__golden_apply_passphrase_none.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8210a0 diff --git a/api/tests/snapshots/golden_tests__golden_apply_passphrase_some.snap b/api/tests/snapshots/golden_tests__golden_apply_passphrase_some.snap deleted file mode 100644 index fe18e39..0000000 --- a/api/tests/snapshots/golden_tests__golden_apply_passphrase_some.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8210a10066616263313233 diff --git a/api/tests/snapshots/golden_tests__golden_backup_shard_request.snap b/api/tests/snapshots/golden_tests__golden_backup_shard_request.snap deleted file mode 100644 index f77f942..0000000 --- a/api/tests/snapshots/golden_tests__golden_backup_shard_request.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8216a100450102030405 diff --git a/api/tests/snapshots/golden_tests__golden_backup_shard_response_error.snap b/api/tests/snapshots/golden_tests__golden_backup_shard_response_error.snap deleted file mode 100644 index 736377b..0000000 --- a/api/tests/snapshots/golden_tests__golden_backup_shard_response_error.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82178201a1006c53746f726167652066756c6c diff --git a/api/tests/snapshots/golden_tests__golden_backup_shard_response_success.snap b/api/tests/snapshots/golden_tests__golden_backup_shard_response_success.snap deleted file mode 100644 index 18b44be..0000000 --- a/api/tests/snapshots/golden_tests__golden_backup_shard_response_success.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82178100 diff --git a/api/tests/snapshots/golden_tests__golden_broadcast_transaction.snap b/api/tests/snapshots/golden_tests__golden_broadcast_transaction.snap deleted file mode 100644 index f240d2d..0000000 --- a/api/tests/snapshots/golden_tests__golden_broadcast_transaction.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820ea200696163636f756e742d31014570736274ff diff --git a/api/tests/snapshots/golden_tests__golden_create_magic_backup_event_chunk.snap b/api/tests/snapshots/golden_tests__golden_create_magic_backup_event_chunk.snap deleted file mode 100644 index bfe41e5..0000000 --- a/api/tests/snapshots/golden_tests__golden_create_magic_backup_event_chunk.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181a8201a300050118640243112233 diff --git a/api/tests/snapshots/golden_tests__golden_create_magic_backup_event_start.snap b/api/tests/snapshots/golden_tests__golden_create_magic_backup_event_start.snap deleted file mode 100644 index adc08fb..0000000 --- a/api/tests/snapshots/golden_tests__golden_create_magic_backup_event_start.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181a8200a3005820efefefefefefefefefefefefefefefefefefefefefefefefefefefefefefefef011864025820aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa diff --git a/api/tests/snapshots/golden_tests__golden_create_magic_backup_result_error.snap b/api/tests/snapshots/golden_tests__golden_create_magic_backup_result_error.snap deleted file mode 100644 index 7aebdbc..0000000 --- a/api/tests/snapshots/golden_tests__golden_create_magic_backup_result_error.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181b8201a1006d55706c6f6164206661696c6564 diff --git a/api/tests/snapshots/golden_tests__golden_create_magic_backup_result_success.snap b/api/tests/snapshots/golden_tests__golden_create_magic_backup_result_success.snap deleted file mode 100644 index 0c21e2e..0000000 --- a/api/tests/snapshots/golden_tests__golden_create_magic_backup_result_success.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181b8100 diff --git a/api/tests/snapshots/golden_tests__golden_device_status.snap b/api/tests/snapshots/golden_tests__golden_device_status.snap deleted file mode 100644 index 733cceb..0000000 --- a/api/tests/snapshots/golden_tests__golden_device_status.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8207a20065322e342e30011855 diff --git a/api/tests/snapshots/golden_tests__golden_device_status_updating.snap b/api/tests/snapshots/golden_tests__golden_device_status_updating.snap deleted file mode 100644 index 0e69f1d..0000000 --- a/api/tests/snapshots/golden_tests__golden_device_status_updating.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8207a20065322e342e3001185a diff --git a/api/tests/snapshots/golden_tests__golden_envoy_magic_backup_enabled_request.snap b/api/tests/snapshots/golden_tests__golden_envoy_magic_backup_enabled_request.snap deleted file mode 100644 index 153c331..0000000 --- a/api/tests/snapshots/golden_tests__golden_envoy_magic_backup_enabled_request.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8211a0 diff --git a/api/tests/snapshots/golden_tests__golden_envoy_magic_backup_enabled_response.snap b/api/tests/snapshots/golden_tests__golden_envoy_magic_backup_enabled_response.snap deleted file mode 100644 index 2d62aef..0000000 --- a/api/tests/snapshots/golden_tests__golden_envoy_magic_backup_enabled_response.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8212a100f5 diff --git a/api/tests/snapshots/golden_tests__golden_envoy_status.snap b/api/tests/snapshots/golden_tests__golden_envoy_status.snap deleted file mode 100644 index df6e905..0000000 --- a/api/tests/snapshots/golden_tests__golden_envoy_status.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8208a10065312e302e30 diff --git a/api/tests/snapshots/golden_tests__golden_exchange_rate.snap b/api/tests/snapshots/golden_tests__golden_exchange_rate.snap deleted file mode 100644 index 40d1065..0000000 --- a/api/tests/snapshots/golden_tests__golden_exchange_rate.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8200a3006355534401fa47241080021a6553f100 diff --git a/api/tests/snapshots/golden_tests__golden_exchange_rate_history.snap b/api/tests/snapshots/golden_tests__golden_exchange_rate_history.snap deleted file mode 100644 index c727066..0000000 --- a/api/tests/snapshots/golden_tests__golden_exchange_rate_history.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8201a20082a20019a028011a6553f09ca20019a410011a6553f1000163455552 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_chunk.snap b/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_chunk.snap deleted file mode 100644 index 3c50383..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_chunk.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82058203a50000010302050318640444deadbeef diff --git a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_downloading.snap b/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_downloading.snap deleted file mode 100644 index 8759321..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_downloading.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82058102 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_error.snap b/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_error.snap deleted file mode 100644 index 12c4912..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_error.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82058205a1006f446f776e6c6f6164206661696c6564 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_not_available.snap b/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_not_available.snap deleted file mode 100644 index 1035adb..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_not_available.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82058100 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_starting.snap b/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_starting.snap deleted file mode 100644 index 7538c4f..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_fetch_event_starting.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82058201a50065322e352e30016c4e6577206665617475726573021a6553f100031a001f40000405 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_fetch_request.snap b/api/tests/snapshots/golden_tests__golden_firmware_fetch_request.snap deleted file mode 100644 index 0cddae6..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_fetch_request.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8204a10065322e342e30 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_check_request.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_check_request.snap deleted file mode 100644 index 6d75041..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_check_request.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8202a10065322e342e30 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_check_response_available.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_check_response_available.snap deleted file mode 100644 index b6c888f..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_check_response_available.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82038200a50065322e352e300169427567206669786573021a6553f100031a000fa0000403 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_check_response_not_available.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_check_response_not_available.snap deleted file mode 100644 index 34404d7..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_check_response_not_available.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82038101 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_result_error.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_result_error.snap deleted file mode 100644 index c9c7137..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_result_error.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82068201a10073496e7374616c6c6174696f6e206661696c6564 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_result_error_install.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_result_error_install.snap deleted file mode 100644 index a2accd3..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_result_error_install.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82068204a20073496e7374616c6c6174696f6e206661696c6564018102 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_result_error_verify.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_result_error_verify.snap deleted file mode 100644 index 5a9de85..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_result_error_verify.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82068204a200781d5369676e617475726520766572696669636174696f6e206661696c6564018101 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_result_installing.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_result_installing.snap deleted file mode 100644 index 2465e2d..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_result_installing.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82068101 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_result_rebooting.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_result_rebooting.snap deleted file mode 100644 index efe79a6..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_result_rebooting.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82068102 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_result_success.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_result_success.snap deleted file mode 100644 index 8343235..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_result_success.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82068203a10065322e352e30 diff --git a/api/tests/snapshots/golden_tests__golden_firmware_update_result_update_verified.snap b/api/tests/snapshots/golden_tests__golden_firmware_update_result_update_verified.snap deleted file mode 100644 index 0d16e03..0000000 --- a/api/tests/snapshots/golden_tests__golden_firmware_update_result_update_verified.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82068100 diff --git a/api/tests/snapshots/golden_tests__golden_heartbeat.snap b/api/tests/snapshots/golden_tests__golden_heartbeat.snap deleted file mode 100644 index d4da48e..0000000 --- a/api/tests/snapshots/golden_tests__golden_heartbeat.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181fa0 diff --git a/api/tests/snapshots/golden_tests__golden_onboarding_state_completed.snap b/api/tests/snapshots/golden_tests__golden_onboarding_state_completed.snap deleted file mode 100644 index 0ded2ec..0000000 --- a/api/tests/snapshots/golden_tests__golden_onboarding_state_completed.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820c8110 diff --git a/api/tests/snapshots/golden_tests__golden_onboarding_state_firmware_update_screen.snap b/api/tests/snapshots/golden_tests__golden_onboarding_state_firmware_update_screen.snap deleted file mode 100644 index fce20b1..0000000 --- a/api/tests/snapshots/golden_tests__golden_onboarding_state_firmware_update_screen.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820c8102 diff --git a/api/tests/snapshots/golden_tests__golden_pairing_request.snap b/api/tests/snapshots/golden_tests__golden_pairing_request.snap deleted file mode 100644 index b1773d3..0000000 --- a/api/tests/snapshots/golden_tests__golden_pairing_request.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8209a200440102030401694d79206950686f6e65 diff --git a/api/tests/snapshots/golden_tests__golden_pairing_response.snap b/api/tests/snapshots/golden_tests__golden_pairing_response.snap deleted file mode 100644 index 7bd3089..0000000 --- a/api/tests/snapshots/golden_tests__golden_pairing_response.snap +++ /dev/null @@ -1,6 +0,0 @@ ---- -source: api/tests/golden_tests.rs -assertion_line: 241 -expression: hex.clone() ---- -820aa60081020165322e342e30026641424331323303810104f5056e50617373706f7274205072696d65 diff --git a/api/tests/snapshots/golden_tests__golden_prime_magic_backup_enabled.snap b/api/tests/snapshots/golden_tests__golden_prime_magic_backup_enabled.snap deleted file mode 100644 index fdd857e..0000000 --- a/api/tests/snapshots/golden_tests__golden_prime_magic_backup_enabled.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8213a200f50158204242424242424242424242424242424242424242424242424242424242424242 diff --git a/api/tests/snapshots/golden_tests__golden_prime_magic_backup_status_request.snap b/api/tests/snapshots/golden_tests__golden_prime_magic_backup_status_request.snap deleted file mode 100644 index 4085bf0..0000000 --- a/api/tests/snapshots/golden_tests__golden_prime_magic_backup_status_request.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8214a1005820abababababababababababababababababababababababababababababababab diff --git a/api/tests/snapshots/golden_tests__golden_prime_magic_backup_status_response.snap b/api/tests/snapshots/golden_tests__golden_prime_magic_backup_status_response.snap deleted file mode 100644 index 55f5557..0000000 --- a/api/tests/snapshots/golden_tests__golden_prime_magic_backup_status_response.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8215a100f5 diff --git a/api/tests/snapshots/golden_tests__golden_raw_data.snap b/api/tests/snapshots/golden_tests__golden_raw_data.snap deleted file mode 100644 index 8c311b8..0000000 --- a/api/tests/snapshots/golden_tests__golden_raw_data.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -821864a10044feedface diff --git a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_chunk.snap b/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_chunk.snap deleted file mode 100644 index c71f6b9..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_chunk.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181d8202a3000a0118320244aabbccdd diff --git a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_error.snap b/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_error.snap deleted file mode 100644 index 2e846ec..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_error.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181d8203a1006d4e6574776f726b206572726f72 diff --git a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_no_backup.snap b/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_no_backup.snap deleted file mode 100644 index 945badd..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_no_backup.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181d8100 diff --git a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_starting.snap b/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_starting.snap deleted file mode 100644 index e645a69..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_event_starting.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181d8201a10018c8 diff --git a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_request.snap b/api/tests/snapshots/golden_tests__golden_restore_magic_backup_request.snap deleted file mode 100644 index abc49fb..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_request.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181ca2005820bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb011832 diff --git a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_result_error.snap b/api/tests/snapshots/golden_tests__golden_restore_magic_backup_result_error.snap deleted file mode 100644 index 1e47f70..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_result_error.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181e8201a10071436865636b73756d206d69736d61746368 diff --git a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_result_success.snap b/api/tests/snapshots/golden_tests__golden_restore_magic_backup_result_success.snap deleted file mode 100644 index a2580d8..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_magic_backup_result_success.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -82181e8100 diff --git a/api/tests/snapshots/golden_tests__golden_restore_shard_request.snap b/api/tests/snapshots/golden_tests__golden_restore_shard_request.snap deleted file mode 100644 index 566178e..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_shard_request.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -821818a1005820cdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcd diff --git a/api/tests/snapshots/golden_tests__golden_restore_shard_response_error.snap b/api/tests/snapshots/golden_tests__golden_restore_shard_response_error.snap deleted file mode 100644 index b9595b4..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_shard_response_error.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8218198201a100694e6f7420666f756e64 diff --git a/api/tests/snapshots/golden_tests__golden_restore_shard_response_not_found.snap b/api/tests/snapshots/golden_tests__golden_restore_shard_response_not_found.snap deleted file mode 100644 index 04619ea..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_shard_response_not_found.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8218198102 diff --git a/api/tests/snapshots/golden_tests__golden_restore_shard_response_success.snap b/api/tests/snapshots/golden_tests__golden_restore_shard_response_success.snap deleted file mode 100644 index af0ffb0..0000000 --- a/api/tests/snapshots/golden_tests__golden_restore_shard_response_success.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -8218198200a100430a0b0c diff --git a/api/tests/snapshots/golden_tests__golden_security_check_challenge_request.snap b/api/tests/snapshots/golden_tests__golden_security_check_challenge_request.snap deleted file mode 100644 index d7b98fd..0000000 --- a/api/tests/snapshots/golden_tests__golden_security_check_challenge_request.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820b8200a10044cafebabe diff --git a/api/tests/snapshots/golden_tests__golden_security_check_challenge_response_error.snap b/api/tests/snapshots/golden_tests__golden_security_check_challenge_response_error.snap deleted file mode 100644 index 5945e5e..0000000 --- a/api/tests/snapshots/golden_tests__golden_security_check_challenge_response_error.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820b82018201a10071496e76616c6964207369676e6174757265 diff --git a/api/tests/snapshots/golden_tests__golden_security_check_challenge_response_success.snap b/api/tests/snapshots/golden_tests__golden_security_check_challenge_response_success.snap deleted file mode 100644 index 1df2e59..0000000 --- a/api/tests/snapshots/golden_tests__golden_security_check_challenge_response_success.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820b82018200a10044deadbeef diff --git a/api/tests/snapshots/golden_tests__golden_security_check_verification_error.snap b/api/tests/snapshots/golden_tests__golden_security_check_verification_error.snap deleted file mode 100644 index 83b577a..0000000 --- a/api/tests/snapshots/golden_tests__golden_security_check_verification_error.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820b82028201a10073566572696669636174696f6e206661696c6564 diff --git a/api/tests/snapshots/golden_tests__golden_security_check_verification_success.snap b/api/tests/snapshots/golden_tests__golden_security_check_verification_success.snap deleted file mode 100644 index 98d43c1..0000000 --- a/api/tests/snapshots/golden_tests__golden_security_check_verification_success.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820b82028100 diff --git a/api/tests/snapshots/golden_tests__golden_sign_psbt.snap b/api/tests/snapshots/golden_tests__golden_sign_psbt.snap deleted file mode 100644 index 5379861..0000000 --- a/api/tests/snapshots/golden_tests__golden_sign_psbt.snap +++ /dev/null @@ -1,5 +0,0 @@ ---- -source: api/tests/golden_tests.rs -expression: hex.clone() ---- -820da200696163636f756e742d31014570736274ff diff --git a/quantum-link-macros/Cargo.toml b/quantum-link-macros/Cargo.toml deleted file mode 100644 index 6debf69..0000000 --- a/quantum-link-macros/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "quantum-link-macros" -version = "0.1.0" -edition = "2021" -homepage.workspace = true - -[dependencies] -#foundation-api = { workspace = true } -quote = "^1" -syn = { version = "^2.0.5", features = ["full", "extra-traits"] } -proc-macro2 = "1" - -[lib] -proc-macro = true \ No newline at end of file diff --git a/quantum-link-macros/src/lib.rs b/quantum-link-macros/src/lib.rs deleted file mode 100644 index 1552432..0000000 --- a/quantum-link-macros/src/lib.rs +++ /dev/null @@ -1,632 +0,0 @@ -use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; -use quote::quote; -use syn::{ - parse_macro_input, spanned::Spanned, Attribute, Data, DataStruct, DeriveInput, Fields, Lit, - Meta, Type, Visibility, -}; - -#[proc_macro_attribute] -pub fn quantum_link(_metadata: TokenStream, input: TokenStream) -> TokenStream { - let input: DeriveInput = syn::parse(input).unwrap(); - - if let Err(e) = validate_visibility(&input) { - return e.to_compile_error().into(); - } - - let expanded = quote! { - #[derive(Clone, Debug, PartialEq, quantum_link_macros::Cbor)] - #[cfg_attr(feature = "keyos", derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize))] - #[cfg_attr(feature = "envoy", flutter_rust_bridge::frb(non_opaque))] - #input - }; - - TokenStream::from(expanded) -} - -fn validate_visibility(input: &DeriveInput) -> syn::Result<()> { - if !matches!(input.vis, Visibility::Public(_)) { - return Err(syn::Error::new( - input.ident.span(), - "quantum link types must be public", - )); - } - if let Data::Struct(data_struct) = &input.data { - match &data_struct.fields { - Fields::Named(fields) => { - for field in &fields.named { - if !matches!(field.vis, Visibility::Public(_)) { - return Err(syn::Error::new( - field.ident.as_ref().unwrap().span(), - "fields in quantum link structs must be public", - )); - } - } - } - Fields::Unnamed(fields) => { - for field in fields.unnamed.iter() { - if !matches!(field.vis, Visibility::Public(_)) { - return Err(syn::Error::new( - field.span(), - "fields in quantum link structs must be public", - )); - } - } - } - Fields::Unit => {} - } - } - - Ok(()) -} - -/// derive macro generates -/// - From for CBOR -/// - TryFrom for T -#[proc_macro_derive(Cbor, attributes(n))] -pub fn derive_cbor(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as DeriveInput); - - derive_cbor_impl(input) - .unwrap_or_else(|e| e.to_compile_error()) - .into() -} - -fn derive_cbor_impl(input: DeriveInput) -> syn::Result { - let name = &input.ident; - let generics = &input.generics; - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let (into_impl, try_from_impl) = match &input.data { - Data::Struct(data_struct) => generate_struct_impls(&data_struct.fields, name)?, - Data::Enum(data_enum) => { - let into_body = generate_enum_into_cbor(name, &data_enum.variants)?; - let try_from_body = generate_enum_try_from_cbor(&data_enum.variants)?; - (into_body, try_from_body) - } - Data::Union(data_union) => { - return Err(syn::Error::new( - data_union.union_token.span(), - "unions not supported", - )); - } - }; - - // only auto-impl for non-tuple structs - let cbor_marker_impl = match &input.data { - Data::Struct(DataStruct { - fields: Fields::Unnamed(_), - .. - }) => quote! {}, - _ => { - quote! { - impl #impl_generics crate::CborMarker for #name #ty_generics #where_clause {} - } - } - }; - - Ok(quote! { - impl #impl_generics From<#name #ty_generics> for dcbor::CBOR #where_clause { - fn from(value: #name #ty_generics) -> dcbor::CBOR { - #into_impl - } - } - - impl #impl_generics TryFrom for #name #ty_generics #where_clause { - type Error = dcbor::Error; - - fn try_from(cbor: dcbor::CBOR) -> dcbor::Result { - #try_from_impl - } - } - - #cbor_marker_impl - }) -} - -// -// struct -// - -fn generate_struct_impls( - fields: &Fields, - name: &syn::Ident, -) -> syn::Result<(TokenStream2, TokenStream2)> { - match fields { - Fields::Named(fields) => { - let into_body = generate_named_struct_into_cbor(&fields.named)?; - let try_from_body = generate_named_struct_try_from_cbor(&fields.named)?; - Ok((into_body, try_from_body)) - } - Fields::Unnamed(fields) => { - if fields.unnamed.len() != 1 { - return Err(syn::Error::new( - fields.span(), - "only single-field tuple structs (newtypes) are supported", - )); - } - let (into_body, try_from_body) = - generate_newtype_struct_impls(fields.unnamed.first().unwrap())?; - Ok((into_body, try_from_body)) - } - Fields::Unit => Err(syn::Error::new(name.span(), "unit structs not supported")), - } -} - -fn generate_named_struct_into_cbor( - fields: &syn::punctuated::Punctuated, -) -> syn::Result { - check_duplicate_indices(fields)?; - - let mut field_insertions = Vec::new(); - - for field in fields { - let field_name = field.ident.as_ref().unwrap(); - let field_type = &field.ty; - let index = get_field_index(&field.attrs) - .ok_or_else(|| syn::Error::new(field.span(), "missing #[n(x)] attribute"))?; - - if let Some(inner) = get_option_inner(field_type) { - let cbor_value = gen_to_cbor(&inner, quote! { val }); - field_insertions.push(quote! { - if let Some(val) = value.#field_name { - map.insert(dcbor::CBOR::from(#index), #cbor_value); - } - }); - } else { - let insertion = gen_map_insert(index, field_type, quote! { value.#field_name }); - field_insertions.push(insertion); - } - } - - Ok(quote! { - let mut map = dcbor::Map::new(); - #(#field_insertions)* - dcbor::CBOR::from(map) - }) -} - -fn generate_named_struct_try_from_cbor( - fields: &syn::punctuated::Punctuated, -) -> syn::Result { - let mut field_extractions = Vec::new(); - let mut field_names = Vec::new(); - - for field in fields { - let field_name = field.ident.as_ref().unwrap(); - let field_type = &field.ty; - let index = get_field_index(&field.attrs) - .ok_or_else(|| syn::Error::new(field.span(), "missing #[n(x)] attribute"))?; - - let extraction = if let Some(inner) = get_option_inner(field_type) { - let value = gen_map_get_optional(index, &inner, quote! { map }); - quote! { let #field_name: #field_type = #value; } - } else { - let value = gen_map_get_required(index, field_type, quote! { map }); - quote! { let #field_name: #field_type = #value; } - }; - - field_extractions.push(extraction); - field_names.push(field_name); - } - - Ok(quote! { - let case = cbor.into_case(); - let dcbor::CBORCase::Map(map) = case else { - return Err(dcbor::Error::WrongType); - }; - - #(#field_extractions)* - - Ok(Self { - #(#field_names),* - }) - }) -} - -fn generate_newtype_struct_impls(field: &syn::Field) -> syn::Result<(TokenStream2, TokenStream2)> { - let field_type = &field.ty; - - if get_field_index(&field.attrs).is_some() { - return Err(syn::Error::new( - field.span(), - "newtype structs cannot have #[n(x)] attribute; use a named struct instead", - )); - } - - let into_body = gen_to_cbor(field_type, quote! { value.0 }); - let from_value = gen_from_cbor(field_type, quote! { cbor }); - let try_from_body = quote! { Ok(Self(#from_value)) }; - - Ok((into_body, try_from_body)) -} - -// -// enum -// - -fn generate_enum_into_cbor( - enum_name: &syn::Ident, - variants: &syn::punctuated::Punctuated, -) -> syn::Result { - check_duplicate_indices(variants)?; - - let mut variant_arms = Vec::new(); - - for variant in variants { - let variant_name = &variant.ident; - let variant_index = get_field_index(&variant.attrs) - .ok_or_else(|| syn::Error::new(variant.span(), "missing #[n(x)] attribute"))?; - - let arm = match &variant.fields { - Fields::Unit => { - quote! { - #enum_name::#variant_name => { - dcbor::CBOR::from(vec![dcbor::CBOR::from(#variant_index)]) - } - } - } - Fields::Unnamed(fields) => { - generate_tuple_variant_into_cbor(enum_name, variant_name, variant_index, fields)? - } - Fields::Named(fields) => { - generate_struct_variant_into_cbor(enum_name, variant_name, variant_index, fields)? - } - }; - - variant_arms.push(arm); - } - - Ok(quote! { - match value { - #(#variant_arms)* - } - }) -} - -fn generate_tuple_variant_into_cbor( - enum_name: &syn::Ident, - variant_name: &syn::Ident, - variant_index: u64, - fields: &syn::FieldsUnnamed, -) -> syn::Result { - if fields.unnamed.len() != 1 { - return Err(syn::Error::new( - fields.span(), - "tuple variants must have exactly one field", - )); - } - - let field = fields.unnamed.first().unwrap(); - let field_type = &field.ty; - - if get_field_index(&field.attrs).is_some() { - return Err(syn::Error::new( - field.span(), - "tuple variant fields cannot have #[n(x)] attribute; use a struct variant instead", - )); - } - - Ok(quote! { - #enum_name::#variant_name(inner) => { - const _: fn() = || { - fn assert_cbor_marker() {} - assert_cbor_marker::<#field_type>(); - }; - dcbor::CBOR::from(vec![ - dcbor::CBOR::from(#variant_index), - dcbor::CBOR::from(inner), - ]) - } - }) -} - -fn generate_struct_variant_into_cbor( - enum_name: &syn::Ident, - variant_name: &syn::Ident, - variant_index: u64, - fields: &syn::FieldsNamed, -) -> syn::Result { - check_duplicate_indices(&fields.named)?; - - let mut field_names = Vec::new(); - let mut field_insertions = Vec::new(); - - for field in &fields.named { - let field_name = field.ident.as_ref().unwrap(); - let field_type = &field.ty; - let field_index = get_field_index(&field.attrs) - .ok_or_else(|| syn::Error::new(field.span(), "missing #[n(x)] attribute"))?; - - field_names.push(field_name); - - let cbor_value = gen_to_cbor(field_type, quote! { #field_name }); - field_insertions.push(quote! { - inner_map.insert(dcbor::CBOR::from(#field_index), #cbor_value); - }); - } - - Ok(quote! { - #enum_name::#variant_name { #(#field_names),* } => { - let mut inner_map = dcbor::Map::new(); - #(#field_insertions)* - - dcbor::CBOR::from(vec![ - dcbor::CBOR::from(#variant_index), - dcbor::CBOR::from(inner_map), - ]) - } - }) -} - -fn generate_enum_try_from_cbor( - variants: &syn::punctuated::Punctuated, -) -> syn::Result { - let mut variant_arms = Vec::new(); - - for variant in variants { - let variant_name = &variant.ident; - let variant_index = get_field_index(&variant.attrs) - .ok_or_else(|| syn::Error::new(variant.span(), "missing #[n(x)] attribute"))?; - - let arm = match &variant.fields { - Fields::Unit => { - quote! { #variant_index => Ok(Self::#variant_name), } - } - Fields::Unnamed(fields) => { - generate_tuple_variant_try_from_cbor(variant_name, variant_index, fields)? - } - Fields::Named(fields) => { - generate_struct_variant_try_from_cbor(variant_name, variant_index, fields)? - } - }; - - variant_arms.push(arm); - } - - Ok(quote! { - let case = cbor.into_case(); - let dcbor::CBORCase::Array(arr) = case else { - return Err(dcbor::Error::WrongType); - }; - - let variant_index: u64 = >::try_from( - arr.get(0).ok_or(dcbor::Error::WrongType)?.clone() - )?; - - match variant_index { - #(#variant_arms)* - _ => Err(dcbor::Error::WrongType), - } - }) -} - -fn generate_tuple_variant_try_from_cbor( - variant_name: &syn::Ident, - variant_index: u64, - fields: &syn::FieldsUnnamed, -) -> syn::Result { - if fields.unnamed.len() != 1 { - return Err(syn::Error::new( - fields.span(), - "tuple variants must have exactly one field", - )); - } - - let field = fields.unnamed.first().unwrap(); - let field_type = &field.ty; - - if get_field_index(&field.attrs).is_some() { - return Err(syn::Error::new( - field.span(), - "tuple variant fields cannot have #[n(x)] attribute; use a struct variant instead", - )); - } - - Ok(quote! { - #variant_index => { - let variant_data = arr.get(1).ok_or(dcbor::Error::WrongType)?; - let inner: #field_type = variant_data.clone().try_into()?; - Ok(Self::#variant_name(inner)) - } - }) -} - -fn generate_struct_variant_try_from_cbor( - variant_name: &syn::Ident, - variant_index: u64, - fields: &syn::FieldsNamed, -) -> syn::Result { - let mut field_extractions = Vec::new(); - let mut field_names = Vec::new(); - - for field in &fields.named { - let field_name = field.ident.as_ref().unwrap(); - let field_type = &field.ty; - let field_index = get_field_index(&field.attrs) - .ok_or_else(|| syn::Error::new(field.span(), "missing #[n(x)] attribute"))?; - - let extraction = if let Some(inner) = get_option_inner(field_type) { - let value = gen_map_get_optional(field_index, &inner, quote! { inner_map }); - quote! { let #field_name: #field_type = #value; } - } else { - let value = gen_map_get_required(field_index, field_type, quote! { inner_map }); - quote! { let #field_name: #field_type = #value; } - }; - - field_extractions.push(extraction); - field_names.push(field_name); - } - - Ok(quote! { - #variant_index => { - let variant_data = arr.get(1).ok_or(dcbor::Error::WrongType)?; - let inner_case = variant_data.clone().into_case(); - let dcbor::CBORCase::Map(inner_map) = inner_case else { - return Err(dcbor::Error::WrongType); - }; - - #(#field_extractions)* - - Ok(Self::#variant_name { - #(#field_names),* - }) - } - }) -} - -// -// helpers -// - -fn gen_to_cbor(field_type: &Type, value: TokenStream2) -> TokenStream2 { - if is_vec_u8(field_type) || is_u8_array(field_type) { - quote! { dcbor::CBOR::to_byte_string(#value) } - } else { - quote! { dcbor::CBOR::from(#value) } - } -} - -fn gen_from_cbor(field_type: &Type, cbor: TokenStream2) -> TokenStream2 { - if is_vec_u8(field_type) { - quote! { #cbor.try_into_byte_string()?.to_vec() } - } else if is_u8_array(field_type) { - gen_byte_array_from_cbor(field_type, cbor) - } else { - quote! { #cbor.try_into()? } - } -} - -fn gen_byte_array_from_cbor(field_type: &Type, cbor: TokenStream2) -> TokenStream2 { - quote! {{ - let bytes = #cbor.try_into_byte_string()?; - <#field_type>::try_from(bytes.as_ref()) - .map_err(|_| dcbor::Error::OutOfRange)? - }} -} - -fn gen_map_insert(index: u64, field_type: &Type, value: TokenStream2) -> TokenStream2 { - let cbor_value = gen_to_cbor(field_type, value); - quote! { - map.insert(dcbor::CBOR::from(#index), #cbor_value); - } -} - -fn gen_map_get_required(index: u64, field_type: &Type, map: TokenStream2) -> TokenStream2 { - let cbor_expr = quote! { - #map.get::(#index) - .ok_or(dcbor::Error::MissingMapKey)? - }; - gen_from_cbor(field_type, cbor_expr) -} - -fn gen_map_get_optional(index: u64, inner_type: &Type, map: TokenStream2) -> TokenStream2 { - let value_expr = if is_vec_u8(inner_type) { - quote! { field_cbor.try_into_byte_string()?.to_vec() } - } else if is_u8_array(inner_type) { - gen_byte_array_from_cbor(inner_type, quote! { field_cbor }) - } else { - quote! { field_cbor.try_into()? } - }; - - quote! { - match #map.get::(#index) { - Some(field_cbor) => Some(#value_expr), - None => None, - } - } -} - -fn get_field_index(attrs: &[Attribute]) -> Option { - for attr in attrs { - if attr.path().is_ident("n") { - if let Meta::List(meta_list) = &attr.meta { - let tokens = meta_list.tokens.clone(); - if let Ok(Lit::Int(lit_int)) = syn::parse2::(tokens) { - return lit_int.base10_parse().ok(); - } - } - } - } - None -} - -trait Indexed: Spanned { - fn index_attrs(&self) -> &[Attribute]; -} - -impl Indexed for syn::Field { - fn index_attrs(&self) -> &[Attribute] { - &self.attrs - } -} - -impl Indexed for syn::Variant { - fn index_attrs(&self) -> &[Attribute] { - &self.attrs - } -} - -fn check_duplicate_indices<'a>( - items: impl IntoIterator, -) -> syn::Result<()> { - let mut seen: std::collections::HashMap = - std::collections::HashMap::new(); - - for item in items { - if let Some(index) = get_field_index(item.index_attrs()) { - if let Some(&prev_span) = seen.get(&index) { - let mut err = - syn::Error::new(item.span(), format!("duplicate #[n({index})] attribute")); - err.combine(syn::Error::new(prev_span, "first use of this index")); - return Err(err); - } - seen.insert(index, item.span()); - } - } - - Ok(()) -} - -fn get_option_inner(ty: &Type) -> Option { - let Type::Path(p) = ty else { return None }; - let seg = p.path.segments.last().filter(|s| s.ident == "Option")?; - let syn::PathArguments::AngleBracketed(args) = &seg.arguments else { - return None; - }; - let syn::GenericArgument::Type(inner) = args.args.first()? else { - return None; - }; - Some(inner.clone()) -} - -fn is_vec_u8(ty: &Type) -> bool { - let Type::Path(p) = ty else { return false }; - let Some(seg) = p.path.segments.last().filter(|s| s.ident == "Vec") else { - return false; - }; - let syn::PathArguments::AngleBracketed(args) = &seg.arguments else { - return false; - }; - let Some(syn::GenericArgument::Type(Type::Path(inner))) = args.args.first() else { - return false; - }; - inner - .path - .segments - .last() - .map(|s| s.ident == "u8") - .unwrap_or(false) -} - -fn is_u8_array(ty: &Type) -> bool { - let Type::Array(array) = ty else { return false }; - let Type::Path(p) = &*array.elem else { - return false; - }; - p.path - .segments - .last() - .map(|s| s.ident == "u8") - .unwrap_or(false) -}