From 1217d06e651f729db081453ed4cb99b3fc86f306 Mon Sep 17 00:00:00 2001 From: Yorick Peterse Date: Fri, 26 Jul 2024 22:50:39 +0200 Subject: [PATCH] Add support for TLS sockets This adds the module std.net.tls and refactors std.net.socket in various places, such that we can provide support for TLS 1.2 and TLS 1.3. The TLS stack is backed by Rustls (https://github.com/rustls/rustls). My original plan was to write the stack in Inko, but I deemed this far too time consuming and not beneficial for users (compared to using an existing mature stack). I also experimented with OpenSSL, but using OpenSSL is like walking through a minefield, and its API is a pain to use (in part due to its use of global and thread-local state). Rustls is compiled such that it uses the "ring" backend instead of aws-lc. This is done because aws-lc requires additional dependencies on FreeBSD, and increases compile times significantly (about 30 seconds or so). While performance of TLS 1.3 is less ideal when using ring compared to using aws-lc (https://github.com/rustls/rustls/issues/1751), it should still be good enough (and still be much faster compared to using OpenSSL). A downside of using Rustls is that the executable sizes increase by about 6 MiB (or 2 MiB when stripping them), due to the extra code introduced by Rustls and its dependencies. Sadly we can't avoid this unless we use OpenSSL, which introduces far more pressing issues. For certificate validation we use a patched version of the rustls-platform-verifier crate. The patched version strips the code we don't need (mostly so we don't get tons of "this code is unused" warnings and what not), and patches the macOS code to account for the system verification process being (potentially) slow by using the `Process::blocking` method. This fixes https://github.com/inko-lang/inko/issues/329. Changelog: added --- Cargo.lock | 309 ++++++--- compiler/src/linker.rs | 8 +- rt/Cargo.toml | 29 + rt/src/lib.rs | 29 +- rt/src/network_poller.rs | 1 + rt/src/network_poller/kqueue.rs | 20 +- rt/src/process.rs | 1 - rt/src/result.rs | 8 +- rt/src/runtime.rs | 7 + rt/src/runtime/env.rs | 20 - rt/src/runtime/helpers.rs | 51 ++ rt/src/runtime/socket.rs | 96 +-- rt/src/runtime/tls.rs | 309 +++++++++ rt/src/rustls_platform_verifier/LICENSE | 21 + rt/src/rustls_platform_verifier/mod.rs | 41 ++ .../verification/apple.rs | 238 +++++++ .../verification/mod.rs | 59 ++ .../verification/others.rs | 166 +++++ rt/src/scheduler/process.rs | 18 + rt/src/socket.rs | 113 ++-- std/fixtures/tls/README.md | 12 + std/fixtures/tls/empty.key | 0 std/fixtures/tls/empty.pem | 0 std/fixtures/tls/invalid.key | 24 + std/fixtures/tls/invalid.pem | 10 + std/fixtures/tls/test.cnf | 17 + std/fixtures/tls/test.key | 28 + std/fixtures/tls/test.pem | 22 + std/src/std/crypto/pem.inko | 254 +++++++ std/src/std/crypto/x509.inko | 38 ++ std/src/std/env.inko | 15 +- std/src/std/fs/path.inko | 46 +- std/src/std/io.inko | 67 +- std/src/std/net/socket.inko | 610 +++++++++-------- std/src/std/net/tls.inko | 639 ++++++++++++++++++ std/src/std/option.inko | 36 + std/src/std/string.inko | 26 +- std/test/compiler/test_diagnostics.inko | 4 +- std/test/std/fs/test_path.inko | 33 +- std/test/std/net/test_socket.inko | 401 ++++++----- std/test/std/net/test_tls.inko | 437 ++++++++++++ std/test/std/test_env.inko | 14 + std/test/std/test_io.inko | 32 +- std/test/std/test_option.inko | 13 + std/test/std/test_optparse.inko | 6 +- std/test/std/test_string.inko | 36 +- 46 files changed, 3569 insertions(+), 795 deletions(-) create mode 100644 rt/src/runtime/tls.rs create mode 100644 rt/src/rustls_platform_verifier/LICENSE create mode 100644 rt/src/rustls_platform_verifier/mod.rs create mode 100644 rt/src/rustls_platform_verifier/verification/apple.rs create mode 100644 rt/src/rustls_platform_verifier/verification/mod.rs create mode 100644 rt/src/rustls_platform_verifier/verification/others.rs create mode 100644 std/fixtures/tls/README.md create mode 100644 std/fixtures/tls/empty.key create mode 100644 std/fixtures/tls/empty.pem create mode 100644 std/fixtures/tls/invalid.key create mode 100644 std/fixtures/tls/invalid.pem create mode 100644 std/fixtures/tls/test.cnf create mode 100644 std/fixtures/tls/test.key create mode 100644 std/fixtures/tls/test.pem create mode 100644 std/src/std/crypto/pem.inko create mode 100644 std/src/std/crypto/x509.inko create mode 100644 std/src/std/net/tls.inko create mode 100644 std/test/std/net/test_tls.inko diff --git a/Cargo.lock b/Cargo.lock index c3b6f89fa..f6fe87dbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.21.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678" dependencies = [ "gimli", ] @@ -44,11 +44,17 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + [[package]] name = "backtrace" -version = "0.3.71" +version = "0.3.73" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a" dependencies = [ "addr2line", "cc", @@ -61,9 +67,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bitflags" @@ -73,9 +79,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "blake3" @@ -103,9 +109,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.95" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d32a725bc159af97c3e629873bb9f88fb8cf8a4867175f76dc987815ea07c83b" +checksum = "eaff6f8ce506b9773fa786672d63fc7a191ffea1be33f72bbd4aeacefca9ffc8" [[package]] name = "cfg-if" @@ -146,11 +152,27 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + [[package]] name = "crc32fast" -version = "1.4.0" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" dependencies = [ "cfg-if", ] @@ -166,15 +188,15 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.19" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "either" -version = "1.11.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "encode_unicode" @@ -184,9 +206,9 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "errno" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" dependencies = [ "libc", "windows-sys", @@ -206,9 +228,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.29" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4556222738635b7a3417ae6130d8f52201e45a0c4d1a907f0826383adb5f85e7" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" dependencies = [ "crc32fast", "miniz_oxide", @@ -240,9 +262,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", @@ -251,9 +273,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.1" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "idna" @@ -304,21 +326,21 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.153" +version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" [[package]] name = "linux-raw-sys" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "llvm-sys" @@ -336,30 +358,58 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "miniz_oxide" -version = "0.7.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" dependencies = [ "adler", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "object" -version = "0.32.2" +version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +checksum = "081b846d1d56ddfc18fdf1a922e4f6e07a11768ea1b92dec44e42b72712ccfce" dependencies = [ "memchr", ] @@ -370,6 +420,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + [[package]] name = "percent-encoding" version = "2.3.1" @@ -384,9 +440,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.81" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -471,20 +527,29 @@ name = "rt" version = "0.15.0" dependencies = [ "backtrace", + "core-foundation", + "core-foundation-sys", "crossbeam-queue", "crossbeam-utils", "libc", + "once_cell", "rand", "rustix", + "rustls", + "rustls-native-certs", + "rustls-pemfile", + "rustls-webpki", + "security-framework", + "security-framework-sys", "socket2", "unicode-segmentation", ] [[package]] name = "rustc-demangle" -version = "0.1.23" +version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustix" @@ -492,7 +557,7 @@ version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -501,11 +566,12 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.4" +version = "0.23.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0" dependencies = [ "log", + "once_cell", "ring", "rustls-pki-types", "rustls-webpki", @@ -513,28 +579,84 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +dependencies = [ + "base64", + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" -version = "1.5.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" +checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.3" +version = "0.102.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3bce581c0dd41bce533ce695a1437fa16a7ab5ac3ccfa99fe1a620a7885eabf" +checksum = "f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78" dependencies = [ "ring", "rustls-pki-types", "untrusted", ] +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "security-framework" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "num-bigint", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" -version = "1.0.22" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" [[package]] name = "similar" @@ -558,9 +680,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.6" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" dependencies = [ "libc", "windows-sys", @@ -574,15 +696,15 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.60" +version = "2.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" dependencies = [ "proc-macro2", "quote", @@ -591,9 +713,9 @@ dependencies = [ [[package]] name = "tar" -version = "0.4.40" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" +checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909" dependencies = [ "filetime", "libc", @@ -601,18 +723,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.59" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.59" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", @@ -621,9 +743,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" dependencies = [ "tinyvec_macros", ] @@ -667,9 +789,9 @@ checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" [[package]] name = "unicode-width" -version = "0.1.12" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" +checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" [[package]] name = "untrusted" @@ -679,25 +801,24 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.9.7" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" +checksum = "72139d247e5f97a3eff96229a7ae85ead5328a39efe76f8bf5a06313d505b6ea" dependencies = [ "base64", "log", "once_cell", "rustls", "rustls-pki-types", - "rustls-webpki", "url", "webpki-roots", ] [[package]] name = "url" -version = "2.5.0" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" dependencies = [ "form_urlencoded", "idna", @@ -712,9 +833,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "webpki-roots" -version = "0.26.1" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd" dependencies = [ "rustls-pki-types", ] @@ -730,9 +851,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -746,54 +867,54 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "zeroize" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/compiler/src/linker.rs b/compiler/src/linker.rs index 39d20ceca..6ee66e134 100644 --- a/compiler/src/linker.rs +++ b/compiler/src/linker.rs @@ -256,7 +256,13 @@ pub(crate) fn link( cmd.arg("-lm"); cmd.arg("-lpthread"); } - _ => {} + OperatingSystem::Mac => { + // This is needed for TLS support. + for name in ["Security", "CoreFoundation"] { + cmd.arg("-framework"); + cmd.arg(name); + } + } } let mut static_linking = state.config.static_linking; diff --git a/rt/Cargo.toml b/rt/Cargo.toml index 9853cb5aa..e5e72df94 100644 --- a/rt/Cargo.toml +++ b/rt/Cargo.toml @@ -23,6 +23,35 @@ unicode-segmentation = "^1.10" backtrace = "^0.3" rustix = { version = "^0.38", features = ["fs", "mm", "param", "process", "net", "std", "time", "event"], default-features = false } +# The dependencies needed for TLS support. +# +# We use ring instead of the default aws-lc-sys because: +# +# 1. aws-lc-sys requires cmake to be installed when building on FreeBSD (and +# potentially other platforms), as aws-lc-sys only provides generated +# bindings for a limited set of platforms +# 2. aws-lc-sys increases compile times quite a bit +# 3. We don't care about FIPS compliance at the time of writing +rustls = { version = "^0.23", features = ["ring", "tls12", "std"], default-features = false } +rustls-pemfile = "^2.1" + +# These dependencies are used by the customized version of +# rustls-platform-modifier. We include a custom version so we can deal with the +# platform verification process being potentially slow. See +# https://github.com/rustls/rustls/issues/850 and +# https://github.com/inko-lang/inko/issues/329 for more details. +once_cell = "1.9" + +[target.'cfg(all(unix, not(target_os = "macos"), not(target_os = "ios"), not(target_os = "tvos")))'.dependencies] +rustls-native-certs = "0.7" +webpki = { package = "rustls-webpki", version = "0.102", default-features = false } + +[target.'cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))'.dependencies] +core-foundation = "0.9" +core-foundation-sys = "0.8" +security-framework = { version = "2.10", features = ["OSX_10_14"] } +security-framework-sys = { version = "2.10", features = ["OSX_10_14"] } + [dependencies.socket2] version = "^0.5" features = ["all"] diff --git a/rt/src/lib.rs b/rt/src/lib.rs index 800f8c47c..ca3dd4c15 100644 --- a/rt/src/lib.rs +++ b/rt/src/lib.rs @@ -3,21 +3,22 @@ #![allow(clippy::missing_safety_doc)] #![allow(clippy::too_many_arguments)] -pub mod macros; +mod macros; -pub mod arc_without_weak; -pub mod config; -pub mod context; -pub mod mem; -pub mod memory_map; -pub mod network_poller; -pub mod process; -pub mod result; -pub mod runtime; -pub mod scheduler; -pub mod socket; -pub mod stack; -pub mod state; +mod arc_without_weak; +mod config; +mod context; +mod mem; +mod memory_map; +mod network_poller; +mod process; +mod result; +mod runtime; +mod rustls_platform_verifier; +mod scheduler; +mod socket; +mod stack; +mod state; #[cfg(test)] pub mod test; diff --git a/rt/src/network_poller.rs b/rt/src/network_poller.rs index e28c278d1..68c61aa7e 100644 --- a/rt/src/network_poller.rs +++ b/rt/src/network_poller.rs @@ -25,6 +25,7 @@ const CAPACITY: usize = 1024; pub(crate) type NetworkPoller = sys::Poller; /// The type of event a poller should wait for. +#[derive(Debug)] pub(crate) enum Interest { Read, Write, diff --git a/rt/src/network_poller/kqueue.rs b/rt/src/network_poller/kqueue.rs index 4d3790277..d71a08eb5 100644 --- a/rt/src/network_poller/kqueue.rs +++ b/rt/src/network_poller/kqueue.rs @@ -40,18 +40,20 @@ impl Poller { source: impl AsFd, interest: Interest, ) { - let fd = source.as_fd().as_raw_fd(); - let (add, del) = match interest { - Interest::Read => (EventFilter::Read(fd), EventFilter::Write(fd)), - Interest::Write => (EventFilter::Write(fd), EventFilter::Read(fd)), - }; let id = process.identifier() as isize; + let fd = source.as_fd().as_raw_fd(); let flags = EventFlags::CLEAR | EventFlags::ONESHOT | EventFlags::RECEIPT; - let events = [ - Event::new(add, EventFlags::ADD | flags, id), - Event::new(del, EventFlags::DELETE, 0), - ]; + let events = match interest { + Interest::Read => [ + Event::new(EventFilter::Read(fd), EventFlags::ADD | flags, id), + Event::new(EventFilter::Write(fd), EventFlags::DELETE, 0), + ], + Interest::Write => [ + Event::new(EventFilter::Write(fd), EventFlags::ADD | flags, id), + Event::new(EventFilter::Read(fd), EventFlags::DELETE, 0), + ], + }; self.apply(&events); } diff --git a/rt/src/process.rs b/rt/src/process.rs index 20d91659d..c6c8dd9ec 100644 --- a/rt/src/process.rs +++ b/rt/src/process.rs @@ -4,7 +4,6 @@ use crate::scheduler::process::Thread; use crate::scheduler::timeouts::Timeout; use crate::stack::Stack; use crate::state::State; -use backtrace; use std::alloc::{alloc, dealloc, handle_alloc_error, Layout}; use std::cell::UnsafeCell; use std::collections::VecDeque; diff --git a/rt/src/result.rs b/rt/src/result.rs index d72a8b71d..836b7d6a4 100644 --- a/rt/src/result.rs +++ b/rt/src/result.rs @@ -13,7 +13,11 @@ pub(crate) fn error_to_int(error: io::Error) -> i64 { // raw_os_error() above returns a None. Errno::TIMEDOUT.raw_os_error() } else { - -1 + match error.kind() { + io::ErrorKind::InvalidData => -2, + io::ErrorKind::UnexpectedEof => -3, + _ => -1, + } }; code as i64 @@ -60,7 +64,7 @@ impl Result { } pub(crate) fn io_error(error: io::Error) -> Result { - Self::error({ error_to_int(error) } as _) + Self::error(error_to_int(error) as _) } } diff --git a/rt/src/runtime.rs b/rt/src/runtime.rs index f027d53eb..1ae7b84fc 100644 --- a/rt/src/runtime.rs +++ b/rt/src/runtime.rs @@ -14,6 +14,7 @@ mod stdio; mod string; mod sys; mod time; +mod tls; use crate::config::Config; use crate::mem::ClassPointer; @@ -67,6 +68,12 @@ pub unsafe extern "system" fn inko_runtime_new( // does for us when compiling an executable. signal_sched::block_all(); + // Configure the TLS provider. This must be done once before we start the + // program. + rustls::crypto::ring::default_provider() + .install_default() + .expect("failed to set up the default TLS cryptography provider"); + Box::into_raw(Box::new(Runtime::new(&*counts, args))) } diff --git a/rt/src/runtime/env.rs b/rt/src/runtime/env.rs index dfb9dbd51..1f8bbb43f 100644 --- a/rt/src/runtime/env.rs +++ b/rt/src/runtime/env.rs @@ -40,26 +40,6 @@ pub unsafe extern "system" fn inko_env_size(state: *const State) -> i64 { (*state).environment.len() as _ } -#[no_mangle] -pub unsafe extern "system" fn inko_env_home_directory( - state: *const State, -) -> InkoResult { - let state = &*state; - - // Rather than performing all sorts of magical incantations to get the home - // directory, we're just going to require that HOME is set. - // - // If the home is explicitly set to an empty string we still ignore it, - // because there's no scenario in which Some("") is useful. - state - .environment - .get("HOME") - .filter(|&path| !path.is_empty()) - .cloned() - .map(|v| InkoResult::ok(InkoString::alloc(state.string_class, v) as _)) - .unwrap_or_else(InkoResult::none) -} - #[no_mangle] pub unsafe extern "system" fn inko_env_temp_directory( state: *const State, diff --git a/rt/src/runtime/helpers.rs b/rt/src/runtime/helpers.rs index 178fcbce5..38b4f924e 100644 --- a/rt/src/runtime/helpers.rs +++ b/rt/src/runtime/helpers.rs @@ -1,3 +1,9 @@ +use crate::context; +use crate::network_poller::Interest; +use crate::process::ProcessPointer; +use crate::scheduler::timeouts::Timeout; +use crate::socket::Socket; +use crate::state::State; use std::io::{self, Read}; /// Reads a number of bytes from a buffer into a Vec. @@ -14,3 +20,48 @@ pub(crate) fn read_into( Ok(read as i64) } + +pub(crate) fn poll( + state: &State, + mut process: ProcessPointer, + socket: &mut Socket, + interest: Interest, + deadline: i64, +) -> io::Result<()> { + let poll_id = unsafe { process.thread() }.network_poller; + + // We must keep the process' state lock open until everything is registered, + // otherwise a timeout thread may reschedule the process (i.e. the timeout + // is very short) before we finish registering the socket with a poller. + { + let mut proc_state = process.state(); + + // A deadline of -1 signals that we should wait indefinitely. + if deadline >= 0 { + let time = Timeout::until(deadline as u64); + + proc_state.waiting_for_io(Some(time.clone())); + state.timeout_worker.suspend(process, time); + } else { + proc_state.waiting_for_io(None); + } + + socket.register(state, process, poll_id, interest); + } + + // Safety: the current thread is holding on to the process' run lock, so if + // the process gets rescheduled onto a different thread, said thread won't + // be able to use it until we finish this context switch. + unsafe { context::switch(process) }; + + if process.timeout_expired() { + // The socket is still registered at this point, so we have to + // deregister first. If we don't and suspend for another IO operation, + // the poller could end up rescheduling the process multiple times (as + // there are multiple events still in flight for the process). + socket.deregister(state); + return Err(io::Error::from(io::ErrorKind::TimedOut)); + } + + Ok(()) +} diff --git a/rt/src/runtime/socket.rs b/rt/src/runtime/socket.rs index 496b51200..51ee49b41 100644 --- a/rt/src/runtime/socket.rs +++ b/rt/src/runtime/socket.rs @@ -1,10 +1,9 @@ -use crate::context; use crate::mem::{ByteArray, String as InkoString}; use crate::network_poller::Interest; use crate::process::ProcessPointer; use crate::result::{error_to_int, Result}; -use crate::scheduler::timeouts::Timeout; -use crate::socket::Socket; +use crate::runtime::helpers::poll; +use crate::socket::{read_from, Socket}; use crate::state::State; use std::io::{self, Write}; use std::ptr::{drop_in_place, write}; @@ -24,66 +23,33 @@ impl RawAddress { } } -fn blocking( +fn run( state: &State, - mut process: ProcessPointer, + process: ProcessPointer, socket: &mut Socket, interest: Interest, deadline: i64, mut func: impl FnMut(&mut Socket) -> io::Result, ) -> io::Result { match func(socket) { - Err(err) if err.kind() == io::ErrorKind::WouldBlock => {} - val => return val, - } - - let poll_id = unsafe { process.thread() }.network_poller; - - // We must keep the process' state lock open until everything is registered, - // otherwise a timeout thread may reschedule the process (i.e. the timeout - // is very short) before we finish registering the socket with a poller. - { - let mut proc_state = process.state(); - - // A deadline of -1 signals that we should wait indefinitely. - if deadline >= 0 { - let time = Timeout::until(deadline as u64); - - proc_state.waiting_for_io(Some(time.clone())); - state.timeout_worker.suspend(process, time); - } else { - proc_state.waiting_for_io(None); + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + poll(state, process, socket, interest, deadline) + .and_then(|_| func(socket)) } - - socket.register(state, process, poll_id, interest); + val => val, } - - // Safety: the current thread is holding on to the process' run lock, so if - // the process gets rescheduled onto a different thread, said thread won't - // be able to use it until we finish this context switch. - unsafe { context::switch(process) }; - - if process.timeout_expired() { - // The socket is still registered at this point, so we have to - // deregister first. If we don't and suspend for another IO operation, - // the poller could end up rescheduling the process multiple times (as - // there are multiple events still in flight for the process). - socket.deregister(state); - return Err(io::Error::from(io::ErrorKind::TimedOut)); - } - - func(socket) } #[no_mangle] pub(crate) unsafe extern "system" fn inko_socket_new( - proto: i64, + domain: i64, kind: i64, + proto: i64, out: *mut Socket, ) -> i64 { - let sock = match proto { - 0 => Socket::ipv4(kind), - 1 => Socket::ipv6(kind), + let sock = match domain { + 0 => Socket::ipv4(kind, proto), + 1 => Socket::ipv6(kind, proto), _ => Socket::unix(kind), }; @@ -108,7 +74,7 @@ pub(crate) unsafe extern "system" fn inko_socket_write( let state = &*state; let slice = std::slice::from_raw_parts(data, size as _); - blocking(state, process, &mut *socket, Interest::Write, deadline, |sock| { + run(state, process, &mut *socket, Interest::Write, deadline, |sock| { sock.write(slice) }) .map(|v| Result::ok(v as _)) @@ -126,8 +92,8 @@ pub unsafe extern "system" fn inko_socket_read( ) -> Result { let state = &*state; - blocking(state, process, &mut *socket, Interest::Read, deadline, |sock| { - sock.read(&mut (*buffer).value, amount as usize) + run(state, process, &mut *socket, Interest::Read, deadline, |sock| { + read_from(sock, &mut (*buffer).value, amount as usize) }) .map(|size| Result::ok(size as _)) .unwrap_or_else(Result::io_error) @@ -169,7 +135,7 @@ pub unsafe extern "system" fn inko_socket_connect( ) -> Result { let state = &*state; - blocking(state, process, &mut *socket, Interest::Write, deadline, |sock| { + run(state, process, &mut *socket, Interest::Write, deadline, |sock| { sock.connect(InkoString::read(address), port as u16) }) .map(|_| Result::none()) @@ -184,14 +150,10 @@ pub unsafe extern "system" fn inko_socket_accept( deadline: i64, out: *mut Socket, ) -> i64 { - let res = blocking( - &*state, - process, - &mut *socket, - Interest::Read, - deadline, - |sock| sock.accept(), - ); + let res = + run(&*state, process, &mut *socket, Interest::Read, deadline, |sock| { + sock.accept() + }); match res { Ok(val) => { @@ -213,14 +175,10 @@ pub unsafe extern "system" fn inko_socket_receive_from( out: *mut RawAddress, ) -> i64 { let state = &*state; - let res = blocking( - state, - process, - &mut *socket, - Interest::Read, - deadline, - |sock| sock.recv_from(&mut (*buffer).value, amount as _), - ); + let res = + run(state, process, &mut *socket, Interest::Read, deadline, |sock| { + sock.recv_from(&mut (*buffer).value, amount as _) + }); match res { Ok((addr, port)) => { @@ -244,7 +202,7 @@ pub unsafe extern "system" fn inko_socket_send_bytes_to( let state = &*state; let addr = InkoString::read(address); - blocking(state, process, &mut *socket, Interest::Write, deadline, |sock| { + run(state, process, &mut *socket, Interest::Write, deadline, |sock| { sock.send_to(&(*buffer).value, addr, port as _) }) .map(|size| Result::ok(size as _)) @@ -264,7 +222,7 @@ pub unsafe extern "system" fn inko_socket_send_string_to( let state = &*state; let addr = InkoString::read(address); - blocking(state, process, &mut *socket, Interest::Write, deadline, |sock| { + run(state, process, &mut *socket, Interest::Write, deadline, |sock| { sock.send_to(InkoString::read(buffer).as_bytes(), addr, port as _) }) .map(|size| Result::ok(size as _)) diff --git a/rt/src/runtime/tls.rs b/rt/src/runtime/tls.rs new file mode 100644 index 000000000..a6163db94 --- /dev/null +++ b/rt/src/runtime/tls.rs @@ -0,0 +1,309 @@ +use crate::mem::{ByteArray, String as InkoString}; +use crate::network_poller::Interest; +use crate::process::ProcessPointer; +use crate::result::Result; +use crate::runtime::helpers::poll; +use crate::rustls_platform_verifier::tls_config; +use crate::socket::{read_from, Socket}; +use crate::state::State; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use rustls::{ + ClientConfig, ClientConnection, Error as TlsError, RootCertStore, + ServerConfig, ServerConnection, SideData, Stream, +}; +use std::io::{self, Write}; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; + +/// The error code produced when a TLS certificate is invalid. +const INVALID_CERT: isize = -1; + +/// The error code produced when a TLS private key is invalid. +const INVALID_KEY: isize = -2; + +unsafe fn run< + C: Deref> + DerefMut, + R, + S: SideData, +>( + state: *const State, + process: ProcessPointer, + socket: *mut Socket, + con: *mut C, + deadline: i64, + mut func: impl FnMut(&mut Stream) -> io::Result, +) -> io::Result { + let state = &*state; + let mut stream = Stream::new(&mut *con, &mut *socket); + + loop { + match func(&mut stream) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + let interest = if stream.conn.wants_write() { + Interest::Write + } else { + Interest::Read + }; + + poll(state, process, stream.sock, interest, deadline)?; + } + val => return val, + } + } +} + +unsafe fn tls_close< + C: Deref> + DerefMut, + S: SideData, +>( + state: *const State, + proc: ProcessPointer, + sock: *mut Socket, + con: *mut C, + deadline: i64, +) -> io::Result<()> { + (*con).send_close_notify(); + + while (*con).wants_write() { + run(state, proc, sock, con, deadline, |s| s.conn.write_tls(s.sock))?; + } + + Ok(()) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_client_config_new() -> *mut ClientConfig +{ + Arc::into_raw(Arc::new(tls_config())) as *mut _ +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_client_config_with_certificate( + cert: *const ByteArray, +) -> Result { + let mut store = RootCertStore::empty(); + let cert = CertificateDer::from((*cert).value.clone()); + + if store.add(cert).is_err() { + return Result::error(INVALID_CERT as _); + } + + let conf = Arc::new( + ClientConfig::builder() + .with_root_certificates(store) + .with_no_client_auth(), + ); + + Result::ok(Arc::into_raw(conf) as *mut _) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_client_config_clone( + config: *const ClientConfig, +) -> *const ClientConfig { + Arc::increment_strong_count(config); + config +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_client_config_drop( + config: *const ClientConfig, +) { + drop(Arc::from_raw(config)); +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_client_connection_new( + config: *const ClientConfig, + server: *const InkoString, +) -> Result { + let name = match ServerName::try_from(InkoString::read(server)) { + Ok(v) => v, + Err(_) => return Result::none(), + }; + + Arc::increment_strong_count(config); + + // ClientConnection::new() _can_ in theory fail, but based on the source + // code it seems this only happens when certain settings are adjusted, which + // we don't allow at this time. + let con = ClientConnection::new(Arc::from_raw(config), name) + .expect("failed to set up the TLS client connection"); + + Result::ok_boxed(con) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_client_connection_drop( + state: *mut ClientConnection, +) { + drop(Box::from_raw(state)); +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_server_config_new( + cert: *const ByteArray, + key: *const ByteArray, +) -> Result { + // CertificateDer/PrivateKeyDer either borrow a value or take an owned + // value. We can't use borrows because we don't know if the Inko values + // outlive the configuration, so we have to clone the bytes here. + let chain = vec![CertificateDer::from((*cert).value.clone())]; + let Ok(key) = PrivateKeyDer::try_from((*key).value.clone()) else { + return Result::error(INVALID_KEY as _); + }; + let conf = match ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(chain, key) + { + Ok(v) => v, + Err( + TlsError::NoCertificatesPresented + | TlsError::InvalidCertificate(_) + | TlsError::UnsupportedNameType + | TlsError::InvalidCertRevocationList(_), + ) => return Result::error(INVALID_CERT as _), + // For private key errors (and potentially others), rustls produces a + // `Error::General`, and in the future possibly other errors. The "one + // error type to rule them all" approach of rustls makes handling + // specific cases painful, so we just treat all remaining errors as + // private key errors. Given we already handle invalid certificates + // above, this should be correct (enough). + Err(_) => return Result::error(INVALID_KEY as _), + }; + + Result::ok(Arc::into_raw(Arc::new(conf)) as *mut _) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_server_config_clone( + config: *const ServerConfig, +) -> *const ServerConfig { + Arc::increment_strong_count(config); + config +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_server_config_drop( + config: *const ServerConfig, +) { + drop(Arc::from_raw(config)); +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_server_connection_new( + config: *const ServerConfig, +) -> *mut ServerConnection { + Arc::increment_strong_count(config); + + // ServerConnection::new() _can_ in theory fail, but based on the source + // code it seems this only happens when certain settings are adjusted, which + // we don't allow at this time. + let con = ServerConnection::new(Arc::from_raw(config)) + .expect("failed to set up the TLS server connection"); + + Box::into_raw(Box::new(con)) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_server_connection_drop( + state: *mut ServerConnection, +) { + drop(Box::from_raw(state)); +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_client_write( + state: *const State, + proc: ProcessPointer, + sock: *mut Socket, + con: *mut ClientConnection, + data: *mut u8, + size: i64, + deadline: i64, +) -> Result { + let buf = std::slice::from_raw_parts(data, size as _); + + run(state, proc, sock, con, deadline, |s| s.write(buf)) + .map(|v| Result::ok(v as _)) + .unwrap_or_else(Result::io_error) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_client_read( + state: *const State, + proc: ProcessPointer, + sock: *mut Socket, + con: *mut ClientConnection, + buffer: *mut ByteArray, + amount: i64, + deadline: i64, +) -> Result { + let buf = &mut (*buffer).value; + let len = amount as usize; + + run(state, proc, sock, con, deadline, |s| read_from(s, buf, len)) + .map(|v| Result::ok(v as _)) + .unwrap_or_else(Result::io_error) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_client_close( + state: *const State, + proc: ProcessPointer, + sock: *mut Socket, + con: *mut ClientConnection, + deadline: i64, +) -> Result { + tls_close(state, proc, sock, con, deadline) + .map(|_| Result::none()) + .unwrap_or_else(Result::io_error) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_server_write( + state: *const State, + proc: ProcessPointer, + sock: *mut Socket, + con: *mut ServerConnection, + data: *mut u8, + size: i64, + deadline: i64, +) -> Result { + let buf = std::slice::from_raw_parts(data, size as _); + + run(state, proc, sock, con, deadline, |s| s.write(buf)) + .map(|v| Result::ok(v as _)) + .unwrap_or_else(Result::io_error) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_server_read( + state: *const State, + proc: ProcessPointer, + sock: *mut Socket, + con: *mut ServerConnection, + buffer: *mut ByteArray, + amount: i64, + deadline: i64, +) -> Result { + let buf = &mut (*buffer).value; + let len = amount as usize; + + run(state, proc, sock, con, deadline, |s| read_from(s, buf, len)) + .map(|v| Result::ok(v as _)) + .unwrap_or_else(Result::io_error) +} + +#[no_mangle] +pub unsafe extern "system" fn inko_tls_server_close( + state: *const State, + proc: ProcessPointer, + sock: *mut Socket, + con: *mut ServerConnection, + deadline: i64, +) -> Result { + tls_close(state, proc, sock, con, deadline) + .map(|_| Result::none()) + .unwrap_or_else(Result::io_error) +} diff --git a/rt/src/rustls_platform_verifier/LICENSE b/rt/src/rustls_platform_verifier/LICENSE new file mode 100644 index 000000000..996412264 --- /dev/null +++ b/rt/src/rustls_platform_verifier/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 1Password + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/rt/src/rustls_platform_verifier/mod.rs b/rt/src/rustls_platform_verifier/mod.rs new file mode 100644 index 000000000..dcd8f7bdb --- /dev/null +++ b/rt/src/rustls_platform_verifier/mod.rs @@ -0,0 +1,41 @@ +use rustls::ClientConfig; +use std::sync::Arc; + +mod verification; +pub use verification::Verifier; + +/// Creates and returns a `rustls` configuration that verifies TLS +/// certificates in the best way for the underlying OS platform, using +/// safe defaults for the `rustls` configuration. +/// +/// # Example +/// +/// This example shows how to use the custom verifier with the `reqwest` crate: +/// ```ignore +/// # use reqwest::ClientBuilder; +/// #[tokio::main] +/// async fn main() { +/// let client = ClientBuilder::new() +/// .use_preconfigured_tls(rustls_platform_verifier::tls_config()) +/// .build() +/// .expect("nothing should fail"); +/// +/// let _response = client.get("https://example.com").send().await; +/// } +/// ``` +/// +/// **Important:** You must ensure that your `reqwest` version is using the same Rustls +/// version as this crate or it will panic when downcasting the `&dyn Any` verifier. +/// +/// If you require more control over the rustls `ClientConfig`, you can +/// instantiate a [Verifier] with [Verifier::default] and then use it +/// with [`DangerousClientConfigBuilder::with_custom_certificate_verifier`][rustls::client::danger::DangerousClientConfigBuilder::with_custom_certificate_verifier]. +/// +/// Refer to the crate level documentation to see what platforms +/// are currently supported. +pub fn tls_config() -> ClientConfig { + ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(Verifier::new())) + .with_no_client_auth() +} diff --git a/rt/src/rustls_platform_verifier/verification/apple.rs b/rt/src/rustls_platform_verifier/verification/apple.rs new file mode 100644 index 000000000..933c0f590 --- /dev/null +++ b/rt/src/rustls_platform_verifier/verification/apple.rs @@ -0,0 +1,238 @@ +use super::log_server_cert; +use crate::rustls_platform_verifier::verification::invalid_certificate; +use core_foundation::date::CFDate; +use core_foundation_sys::date::kCFAbsoluteTimeIntervalSince1970; +use once_cell::sync::OnceCell; +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerifier}; +use rustls::crypto::{ + verify_tls12_signature, verify_tls13_signature, CryptoProvider, +}; +use rustls::pki_types; +use rustls::{ + CertificateError, DigitallySignedStruct, Error as TlsError, OtherError, + SignatureScheme, +}; +use security_framework::{ + certificate::SecCertificate, policy::SecPolicy, + secure_transport::SslProtocolSide, trust::SecTrust, +}; +use std::sync::Arc; + +use crate::process::ProcessPointer; +use crate::scheduler::process::CURRENT_PROCESS; + +mod errors { + pub(super) use security_framework_sys::base::{ + errSecCertificateRevoked, errSecCreateChainFailed, + errSecHostNameMismatch, errSecInvalidExtendedKeyUsage, + }; +} + +#[allow(clippy::as_conversions)] +fn system_time_to_cfdate( + time: pki_types::UnixTime, +) -> Result { + // SAFETY: The interval is defined by macOS externally, but is always present and never modified at runtime + // since its a global variable. + // + // See https://developer.apple.com/documentation/corefoundation/kcfabsolutetimeintervalsince1970. + let unix_adjustment = unsafe { kCFAbsoluteTimeIntervalSince1970 as u64 }; + + // Convert a system timestamp based off the UNIX epoch into the + // Apple epoch used by all `CFAbsoluteTime` values. + // Subtracting Durations with sub() will panic on overflow + time.as_secs() + .checked_sub(unix_adjustment) + .ok_or(TlsError::FailedToGetCurrentTime) + .map(|epoch| CFDate::new(epoch as f64)) +} + +/// A TLS certificate verifier that utilizes the Apple platform certificate facilities. +#[derive(Debug)] +pub struct Verifier { + pub(super) crypto_provider: OnceCell>, +} + +impl Verifier { + /// Creates a new instance of a TLS certificate verifier that utilizes the macOS certificate + /// facilities. + /// + /// A [`CryptoProvider`] must be set with + /// [`set_provider`][Verifier::set_provider]/[`with_provider`][Verifier::with_provider] or + /// [`CryptoProvider::install_default`] before the verifier can be used. + pub fn new() -> Self { + Self { crypto_provider: OnceCell::new() } + } + + fn verify_certificate( + &self, + end_entity: &pki_types::CertificateDer<'_>, + intermediates: &[pki_types::CertificateDer<'_>], + server_name: &str, + ocsp_response: Option<&[u8]>, + now: pki_types::UnixTime, + ) -> Result<(), TlsError> { + let certificates: Vec = + std::iter::once(end_entity.as_ref()) + .chain(intermediates.iter().map(|cert| cert.as_ref())) + .map(|cert| { + SecCertificate::from_der(cert).map_err(|_| { + TlsError::InvalidCertificate( + CertificateError::BadEncoding, + ) + }) + }) + .collect::, _>>()?; + + // Create our verification policy suitable for verifying TLS chains. + // This uses the "default" verification engine and parameters, the same as Windows. + // + // The protocol side should be set to `server` for a client to verify server TLS + // certificates. + // + // The server name will be required to match what the end-entity certificate reports + // + // Ref: https://developer.apple.com/documentation/security/1392592-secpolicycreatessl + let policy = + SecPolicy::create_ssl(SslProtocolSide::SERVER, Some(server_name)); + + // Create our trust evaluation context/chain. + // + // Apple requires that the certificate to be verified is always first in the array, and we + // always place the end-entity certificate at the start. + // + // Ref: https://developer.apple.com/documentation/security/1401555-sectrustcreatewithcertificates + let mut trust_evaluation = + SecTrust::create_with_certificates(&certificates, &[policy]) + .map_err(|e| TlsError::General(e.to_string()))?; + + // Tell the system that we want to consider the certificates evaluation at the point + // in time that `rustls` provided. + let now = system_time_to_cfdate(now)?; + trust_evaluation + .set_trust_verify_date(&now) + .map_err(|e| invalid_certificate(e.to_string()))?; + + // If we have OCSP response data, make sure the system makes use of it. + if let Some(ocsp_response) = ocsp_response { + trust_evaluation + .set_trust_ocsp_response(std::iter::once(ocsp_response)) + .map_err(|e| invalid_certificate(e.to_string()))?; + } + + // Safety: well, technically none, but due to the way the runtime uses + // the verifier this should never misbehave. + let process = unsafe { ProcessPointer::new(CURRENT_PROCESS.get()) }; + let trust_error = + match process.blocking(|| trust_evaluation.evaluate_with_error()) { + Ok(()) => return Ok(()), + Err(e) => e, + }; + + let err_code = trust_error.code(); + + let err = err_code + .try_into() + .map_err(|_| ()) + .and_then(|code| { + // Only map the errors we need for tests. + match code { + errors::errSecHostNameMismatch => Ok(TlsError::InvalidCertificate( + CertificateError::NotValidForName, + )), + errors::errSecCreateChainFailed => Ok(TlsError::InvalidCertificate( + CertificateError::UnknownIssuer, + )), + errors::errSecInvalidExtendedKeyUsage => Ok(TlsError::InvalidCertificate( + CertificateError::Other(OtherError(std::sync::Arc::new(super::EkuError))), + )), + errors::errSecCertificateRevoked => { + Ok(TlsError::InvalidCertificate(CertificateError::Revoked)) + } + _ => Err(()), + } + }) + // Fallback to an error containing the description and specific error code so that + // the exact error cause can be looked up easily. + .unwrap_or_else(|_| invalid_certificate(format!("{}: {}", trust_error, err_code))); + + Err(err) + } +} + +impl ServerCertVerifier for Verifier { + fn verify_server_cert( + &self, + end_entity: &pki_types::CertificateDer<'_>, + intermediates: &[pki_types::CertificateDer<'_>], + server_name: &pki_types::ServerName, + ocsp_response: &[u8], + now: pki_types::UnixTime, + ) -> Result { + log_server_cert(end_entity); + + // Convert IP addresses to name strings to ensure match check on leaf certificate. + // Ref: https://developer.apple.com/documentation/security/1392592-secpolicycreatessl + let server = server_name.to_str(); + + let ocsp_data = + if !ocsp_response.is_empty() { Some(ocsp_response) } else { None }; + + match self.verify_certificate( + end_entity, + intermediates, + &server, + ocsp_data, + now, + ) { + Ok(()) => { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + Err(e) => { + // This error only tells us what the system errored with, so it doesn't leak anything + // sensitive. + Err(e) + } + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &pki_types::CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + verify_tls12_signature( + message, + cert, + dss, + &self.get_provider().signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &pki_types::CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + verify_tls13_signature( + message, + cert, + dss, + &self.get_provider().signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + self.get_provider() + .signature_verification_algorithms + .supported_schemes() + } +} + +impl Default for Verifier { + fn default() -> Self { + Self::new() + } +} diff --git a/rt/src/rustls_platform_verifier/verification/mod.rs b/rt/src/rustls_platform_verifier/verification/mod.rs new file mode 100644 index 000000000..4469a3453 --- /dev/null +++ b/rt/src/rustls_platform_verifier/verification/mod.rs @@ -0,0 +1,59 @@ +use rustls::crypto::CryptoProvider; +use std::sync::Arc; + +#[cfg(all( + not(target_os = "macos"), + not(target_os = "ios"), + not(target_os = "tvos") +))] +mod others; + +#[cfg(all( + not(target_os = "macos"), + not(target_os = "ios"), + not(target_os = "tvos") +))] +pub use others::Verifier; + +#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))] +mod apple; + +#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))] +pub use apple::Verifier; + +/// An EKU was invalid for the use case of verifying a server certificate. +/// +/// This error is used primarily for tests. +#[derive(Debug, PartialEq)] +pub(crate) struct EkuError; + +impl std::fmt::Display for EkuError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("certificate had invalid extensions") + } +} + +impl std::error::Error for EkuError {} + +// Log the certificate we are verifying so that we can try and find what may be wrong with it +// if we need to debug a user's situation. +fn log_server_cert(_end_entity: &rustls::pki_types::CertificateDer<'_>) {} + +// Unknown certificate error shorthand. Used when we need to construct an "Other" certificate +// error with a platform specific error message. +#[cfg(any(target_os = "macos", target_os = "ios", target_os = "tvos"))] +fn invalid_certificate(reason: impl Into) -> rustls::Error { + rustls::Error::InvalidCertificate(rustls::CertificateError::Other( + rustls::OtherError(Arc::from(Box::from(reason.into()))), + )) +} + +impl Verifier { + fn get_provider(&self) -> &Arc { + self.crypto_provider.get_or_init(|| { + CryptoProvider::get_default() + .expect("rustls default CryptoProvider not set") + .clone() + }) + } +} diff --git a/rt/src/rustls_platform_verifier/verification/others.rs b/rt/src/rustls_platform_verifier/verification/others.rs new file mode 100644 index 000000000..3b247d936 --- /dev/null +++ b/rt/src/rustls_platform_verifier/verification/others.rs @@ -0,0 +1,166 @@ +use super::log_server_cert; +use once_cell::sync::OnceCell; +use rustls::client::danger::{ + HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier, +}; +use rustls::client::WebPkiServerVerifier; +use rustls::pki_types; +use rustls::{ + crypto::CryptoProvider, CertificateError, DigitallySignedStruct, + Error as TlsError, OtherError, SignatureScheme, +}; +use std::fmt::Debug; +use std::sync::{Arc, Mutex}; + +/// A TLS certificate verifier that uses the system's root store and WebPKI. +#[derive(Debug)] +pub struct Verifier { + // We use a `OnceCell` so we only need + // to try loading native root certs once per verifier. + // + // We currently keep one set of certificates per-verifier so that + // locking and unlocking the application will pull fresh root + // certificates from disk, picking up on any changes + // that might have been made since. + inner: OnceCell>, + + // Extra trust anchors to add to the verifier above and beyond those provided by the + // platform via rustls-native-certs. + extra_roots: Mutex>>, + + pub(super) crypto_provider: OnceCell>, +} + +impl Verifier { + /// Creates a new verifier whose certificate validation is provided by + /// WebPKI, using root certificates provided by the platform. + /// + /// A [`CryptoProvider`] must be set with + /// [`set_provider`][Verifier::set_provider]/[`with_provider`][Verifier::with_provider] or + /// [`CryptoProvider::install_default`] before the verifier can be used. + pub fn new() -> Self { + Self { + inner: OnceCell::new(), + extra_roots: Vec::new().into(), + crypto_provider: OnceCell::new(), + } + } + + fn get_or_init_verifier( + &self, + ) -> Result<&Arc, TlsError> { + self.inner.get_or_try_init(|| self.init_verifier()) + } + + // Attempt to load CA root certificates present on system, fallback to WebPKI roots if error + fn init_verifier(&self) -> Result, TlsError> { + let mut root_store = rustls::RootCertStore::empty(); + + // Safety: There's no way for the mutex to be locked multiple times, so this is + // an infallible operation. + let mut extra_roots = self.extra_roots.try_lock().unwrap(); + if !extra_roots.is_empty() { + root_store.extend(extra_roots.drain(..)); + } + + #[cfg(all( + unix, + not(target_os = "macos"), + not(target_os = "ios"), + not(target_os = "tvos"), + ))] + match rustls_native_certs::load_native_certs() { + Ok(certs) => { + root_store.add_parsable_certificates(certs); + } + Err(err) => { + // This only contains a path to a system directory: + // https://github.com/rustls/rustls-native-certs/blob/bc13b9a6bfc2e1eec881597055ca49accddd972a/src/lib.rs#L91-L94 + const MSG: &str = "failed to load system root certificates: "; + + // Don't return an error if this fails when other roots have already been loaded via + // `new_with_extra_roots`. It leads to extra failure cases where connections would otherwise still work. + if root_store.is_empty() { + return Err(rustls::Error::General(format!("{MSG}{err}"))); + } + } + }; + + WebPkiServerVerifier::builder_with_provider( + root_store.into(), + Arc::clone(self.get_provider()), + ) + .build() + .map_err(|e| TlsError::Other(OtherError(Arc::new(e)))) + } +} + +impl ServerCertVerifier for Verifier { + fn verify_server_cert( + &self, + end_entity: &pki_types::CertificateDer<'_>, + intermediates: &[pki_types::CertificateDer<'_>], + server_name: &pki_types::ServerName, + ocsp_response: &[u8], + now: pki_types::UnixTime, + ) -> Result { + log_server_cert(end_entity); + + self.get_or_init_verifier()? + .verify_server_cert( + end_entity, + intermediates, + server_name, + ocsp_response, + now, + ) + .map_err(map_webpki_errors) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &pki_types::CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.get_or_init_verifier()?.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &pki_types::CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.get_or_init_verifier()?.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + match self.get_or_init_verifier() { + Ok(v) => v.supported_verify_schemes(), + Err(_) => Vec::default(), + } + } +} + +impl Default for Verifier { + fn default() -> Self { + Self::new() + } +} + +fn map_webpki_errors(err: TlsError) -> TlsError { + if let TlsError::InvalidCertificate(CertificateError::Other(other_err)) = + &err + { + if let Some(webpki::Error::RequiredEkuNotFound) = + other_err.0.downcast_ref::() + { + return TlsError::InvalidCertificate(CertificateError::Other( + OtherError(Arc::new(super::EkuError)), + )); + } + } + + err +} diff --git a/rt/src/scheduler/process.rs b/rt/src/scheduler/process.rs index 5a2cdc4ec..31394f386 100644 --- a/rt/src/scheduler/process.rs +++ b/rt/src/scheduler/process.rs @@ -10,10 +10,12 @@ use crossbeam_utils::atomic::AtomicCell; use crossbeam_utils::thread::scope; use rand::rngs::ThreadRng; use rand::thread_rng; +use std::cell::Cell; use std::cmp::min; use std::collections::VecDeque; use std::mem::{size_of, swap}; use std::ops::Drop; +use std::ptr::null_mut; use std::sync::atomic::{AtomicBool, AtomicU16, AtomicU64, Ordering}; use std::sync::{Condvar, Mutex}; use std::thread::sleep; @@ -86,6 +88,18 @@ const MONITOR_INTERVAL: u64 = 100; /// we perform a number of regular cycles before entering a deep sleep. const MAX_IDLE_CYCLES: u64 = 1_000_000 / MONITOR_INTERVAL; +thread_local! { + /// The process that's currently running. + /// + /// This threat-local should only be used when access to the current process + /// is needed, but the process can't be passed in as an argument. An example + /// is the patched version of rustls-platform-verifier: it needs access to + /// the current process, but the rustls API doesn't make this possible. + pub(crate) static CURRENT_PROCESS: Cell<*mut Process> = const { + Cell::new(null_mut()) + }; +} + pub(crate) fn epoch_loop(state: &State) { while state.scheduler.pool.is_alive() { sleep(Duration::from_millis(EPOCH_INTERVAL)); @@ -533,15 +547,19 @@ impl Thread { match process.next_task() { Task::Resume => { + CURRENT_PROCESS.set(process.as_ptr()); process.resume(state, self); unsafe { context::switch(process) } } Task::Start(func, args) => { + CURRENT_PROCESS.set(process.as_ptr()); process.resume(state, self); unsafe { context::start(process, func, args) } } Task::Wait => return, } + + CURRENT_PROCESS.set(null_mut()); } match self.action.take() { diff --git a/rt/src/socket.rs b/rt/src/socket.rs index 5f881cf23..801bfc7a9 100644 --- a/rt/src/socket.rs +++ b/rt/src/socket.rs @@ -5,7 +5,7 @@ use crate::process::ProcessPointer; use crate::socket::socket_address::SocketAddress; use crate::state::State; use rustix::io::Errno; -use socket2::{Domain, SockAddr, Socket as RawSocket, Type}; +use socket2::{Domain, Protocol, SockAddr, Socket as RawSocket, Type}; use std::io::{self, Read}; use std::mem::transmute; use std::net::Shutdown; @@ -62,42 +62,22 @@ fn encode_sockaddr( /// The slice has enough space to store up to `bytes` of data. fn socket_output_slice(buffer: &mut Vec, bytes: usize) -> &mut [u8] { let len = buffer.len(); - let available = buffer.capacity() - len; - - if bytes > available { - let to_reserve = bytes - available; - - if to_reserve > 0 { - // Only increasing capacity when needed is done for two reasons: - // - // 1. It saves us from increasing capacity when there is enough - // space. - // - // 2. Due to sockets being non-blocking, a socket operation may - // fail. This will result in this code being called multiple - // times. If we were to simply increase capacity every time we'd - // end up growing the buffer much more than necessary. - buffer.reserve_exact(to_reserve); - } - } + buffer.reserve_exact(bytes); unsafe { slice::from_raw_parts_mut(buffer.as_mut_ptr().add(len), bytes) } } -fn update_buffer_length_and_capacity(buffer: &mut Vec, read: usize) { +fn update_buffer_length(buffer: &mut Vec, read: usize) { unsafe { buffer.set_len(buffer.len() + read); } - - buffer.shrink_to_fit(); } fn socket_type(kind: i64) -> io::Result { match kind { 0 => Ok(Type::STREAM), 1 => Ok(Type::DGRAM), - 2 => Ok(Type::SEQPACKET), - 3 => Ok(Type::RAW), + 2 => Ok(Type::RAW), _ => Err(io::Error::new( io::ErrorKind::Other, format!("{} is not a valid socket type", kind), @@ -105,6 +85,37 @@ fn socket_type(kind: i64) -> io::Result { } } +fn socket_protocol(value: i64) -> Option { + if value == 0 { + None + } else { + Some(Protocol::from(value as i32)) + } +} + +pub(crate) fn read_from( + reader: &mut R, + into: &mut Vec, + amount: usize, +) -> io::Result { + if amount > 0 { + // We don't use take(), because that only terminates if: + // + // 1. We hit EOF, or + // 2. We have read the desired number of bytes + // + // For files this is fine, but for sockets EOF is not triggered + // until the socket is closed; which is almost always too late. + let slice = socket_output_slice(into, amount); + let read = reader.read(slice)?; + + update_buffer_length(into, read); + Ok(read) + } else { + Ok(reader.read_to_end(into)?) + } +} + /// A nonblocking socket that can be registered with a `NetworkPoller`. /// /// When changing the layout of this type, don't forget to also update its @@ -133,12 +144,12 @@ impl Socket { pub(crate) fn new( domain: Domain, kind: Type, + protocol: Option, unix: bool, ) -> io::Result { - let socket = RawSocket::new(domain, kind, None)?; + let socket = RawSocket::new(domain, kind, protocol)?; socket.set_nonblocking(true)?; - Ok(Socket { inner: socket, registered: AtomicI8::new(NOT_REGISTERED), @@ -146,17 +157,27 @@ impl Socket { }) } - pub(crate) fn ipv4(kind_int: i64) -> io::Result { - Self::new(Domain::IPV4, socket_type(kind_int)?, false) + pub(crate) fn ipv4(kind_int: i64, protocol: i64) -> io::Result { + Self::new( + Domain::IPV4, + socket_type(kind_int)?, + socket_protocol(protocol), + false, + ) } - pub(crate) fn ipv6(kind_int: i64) -> io::Result { - Self::new(Domain::IPV6, socket_type(kind_int)?, false) + pub(crate) fn ipv6(kind_int: i64, protocol: i64) -> io::Result { + Self::new( + Domain::IPV6, + socket_type(kind_int)?, + socket_protocol(protocol), + false, + ) } #[cfg(unix)] pub(crate) fn unix(kind_int: i64) -> io::Result { - Self::new(Domain::UNIX, socket_type(kind_int)?, true) + Self::new(Domain::UNIX, socket_type(kind_int)?, None, true) } #[cfg(not(unix))] @@ -260,29 +281,6 @@ impl Socket { }) } - pub(crate) fn read( - &self, - buffer: &mut Vec, - amount: usize, - ) -> io::Result { - if amount > 0 { - // We don't use take(), because that only terminates if: - // - // 1. We hit EOF, or - // 2. We have read the desired number of bytes - // - // For files this is fine, but for sockets EOF is not triggered - // until the socket is closed; which is almost always too late. - let slice = socket_output_slice(buffer, amount); - let read = self.inner.recv(unsafe { transmute(slice) })?; - - update_buffer_length_and_capacity(buffer, read); - Ok(read) - } else { - Ok((&self.inner).read_to_end(buffer)?) - } - } - pub(crate) fn recv_from( &self, buffer: &mut Vec, @@ -292,8 +290,7 @@ impl Socket { let (read, sockaddr) = self.inner.recv_from(unsafe { transmute(slice) })?; - update_buffer_length_and_capacity(buffer, read); - + update_buffer_length(buffer, read); decode_sockaddr(sockaddr, self.unix) .map_err(|err| io::Error::new(io::ErrorKind::Other, err)) } @@ -358,7 +355,7 @@ impl io::Write for Socket { impl io::Read for Socket { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.inner.recv(unsafe { transmute(buf) }) + self.inner.read(buf) } } @@ -369,7 +366,7 @@ mod tests { #[test] fn test_try_clone() { - let socket1 = Socket::ipv4(0).unwrap(); + let socket1 = Socket::ipv4(0, 0).unwrap(); socket1.registered.store(2, Ordering::Release); diff --git a/std/fixtures/tls/README.md b/std/fixtures/tls/README.md new file mode 100644 index 000000000..8164aff0d --- /dev/null +++ b/std/fixtures/tls/README.md @@ -0,0 +1,12 @@ +This directory contains an X.509 certificate and private key, and a +configuration file used to generate them. These files are used for running the +tests of std.net.tls. Under no circumstance should you use these certificates +anywhere else. I repeat: + +DO NOT USE THESE CERTIFICATES, THEY ARE FOR TESTING PURPOSES ONLY. + +The certificate and private key are generated as follows: + +```bash +openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout test.key -out test.pem -config test.cnf -sha25 +``` diff --git a/std/fixtures/tls/empty.key b/std/fixtures/tls/empty.key new file mode 100644 index 000000000..e69de29bb diff --git a/std/fixtures/tls/empty.pem b/std/fixtures/tls/empty.pem new file mode 100644 index 000000000..e69de29bb diff --git a/std/fixtures/tls/invalid.key b/std/fixtures/tls/invalid.key new file mode 100644 index 000000000..d45f833be --- /dev/null +++ b/std/fixtures/tls/invalid.key @@ -0,0 +1,24 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCyU87LWHj+r1Dj +dZQtA0EWy5TmBjf8Rr5tv4eiSolFYAD1kfyelutxj6mMzgk1cvePpC8npPzZRhwO +/8vGZshRw2D0Vdgq8XN8B6G35S4PhhdgHsvGMSX6s5EBk4Xj730YKR9NaSpyHSIa +LEYcmM3Ryv+iKEwnEPw+/oFHeW56J197dWiLIHEDM8oC82lX5MT/cYm2QG4CsC/7 +TCDDT1sEVhkdPB60Scks/Ln5ACsWYInH67AvXWDOaBSebBRR+wFGckgu2kldahIc +e0ejawKa5xgvOX00JmQyigDwqm3NMxjdGOg7y27aNWw8Vbo/oRDif775DvtNIn9C +1nJB8wNjAgMBAAECggEAFuAJo2u7yE6HQvOmvdLF4IgzQgHjNaEs4AQqgHGCRAbJ +fgwrinu+mQh+OI8yKYvlYM+FXaxcOPzgMDZflpmXBxICgmVD/6zjDQfSQWDh3zZA +EmGbmdayfK3YeIogKeSN40cHJRV2pJZtyktf9Ql5ls4CVnPyjNewxoiRidsfBlvc +IoCRjiTD6+MHOQjp4AzwVvbXH1Sr7OsngA4glQJjFlXllyVYQNXBr1sWTVl3TS2L +OQfqYzHWlRtty88z7ExK3D03Jz0PD7qWtTwgJkq1ON+PjgCf7rMvZbVmCoUgX759 +LLxY1NE6ogdoRlPyZG3fvSfmxo2PWOqawboZMvirAQKBgQDVwkYwlbY00XuHA4UF +bGwTyk5Yp0/DVSR4jZyAJN6J5xln0JUWYKeps8kxLfDlPgmn6qIaine9Ewf1D9qt +DGr19mzEHzP07OBRo0l34XE5WUMvV4ter3swwmI5w/ysgds7Mz7xA9a69ukF/7SW +C+2RiVW7hJs1pydQw7+NY9YIYwKBgQDVkRv4tvmITtfPpVjAkw0gPIQ2WLw4uIvk +PIX/A58dg952ga+C4MZ7OFtcKI7CF7anr0gCNdGQS6I3SA18YS49U/zycuzPh9v7 +lcIMV+R0Wvo2B6QIGJpt7FzfZBXGdv/ft5l+MII3jpoGqGu1K3Ifj/zUrDlUJDQq +ivrkH+CJAQKBgQDFcbCRugfWW9TlDhw1uUNPOGQLwWeMvr10WSHAv82KxZsS6Hh9 +dgQIXZeuRIgpx5b1smXPbC1TyRtlgiJ0C29VCCzJLyU3zAEbh18aS3PhDBFhzlRe +vmpkzHgccWqYEU5mLVyrFOeoRN9S+jFdE2F6N8en8MHI2kAXeugZeqk9jwKBgCmV +pMWsEzCIcZs8DekJeR/SyMewRY4h2RNq+YhrUxszJykaHWu1itBJa/io6QtABM/n +4HSVuCWJpJ9xBzc10QQeC33GBPhv8tStF2jB4HkLkfbdTAJLkB5hTMAuw9KuLyqH +-----END PRIVATE KEY----- diff --git a/std/fixtures/tls/invalid.pem b/std/fixtures/tls/invalid.pem new file mode 100644 index 000000000..208614514 --- /dev/null +++ b/std/fixtures/tls/invalid.pem @@ -0,0 +1,10 @@ +-----BEGIN CERTIFICATE----- +MIIDvDCCAqSgAwIBAgIUBwbzzsn/P9HmSv1o2tAoIH6ZYE0wDQYJKoZIhvcNAQEL +BQAwaTELMAkGA1UEBhMCTkwxEDAOBgNVBAgMB0V4YW1wbGUxEDAOBgNVBAcMB0V4 +YW1wbGUxEDAOBgNVBAoMB0V4YW1wbGUxEDAOBgNVBAsMB2V4YW1wbGUxEjAQBgNV +BAMMCWxvY2FsaG9zdDAeFw0yNDA3MTcxMzA3MDVaFw0yNTA3MTcxMzA3MDVaMGkx +CzAJBgNVBAYTAk5MMRAwDgYDVQQIDAdFeGFtcGxlMRAwDgYDVQQHDAdFeGFtcGxl +MRAwDgYDVQQKDAdFeGFtcGxlMRAwDgYDVQQLDAdleGFtcGxlMRIwEAYDVQQDDAls +b2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCyU87LWHj+ +r1DjdZQtA0EWy5TmBjf8Rr5tv4eiSolFYAD1kfyelutxj6mMzgk1cvePpC8npPzZ +-----END CERTIFICATE----- diff --git a/std/fixtures/tls/test.cnf b/std/fixtures/tls/test.cnf new file mode 100644 index 000000000..554df61c3 --- /dev/null +++ b/std/fixtures/tls/test.cnf @@ -0,0 +1,17 @@ +[req] +distinguished_name = req_distinguished_name +x509_extensions = v3_req +prompt = no +[req_distinguished_name] +C = NL +ST = Example +L = Example +O = Example +OU = example +CN = localhost +[v3_req] +keyUsage = critical, digitalSignature, keyAgreement +extendedKeyUsage = serverAuth +subjectAltName = @alt_names +[alt_names] +DNS.1 = localhost diff --git a/std/fixtures/tls/test.key b/std/fixtures/tls/test.key new file mode 100644 index 000000000..4dfecb289 --- /dev/null +++ b/std/fixtures/tls/test.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCyU87LWHj+r1Dj +dZQtA0EWy5TmBjf8Rr5tv4eiSolFYAD1kfyelutxj6mMzgk1cvePpC8npPzZRhwO +/8vGZshRw2D0Vdgq8XN8B6G35S4PhhdgHsvGMSX6s5EBk4Xj730YKR9NaSpyHSIa +LEYcmM3Ryv+iKEwnEPw+/oFHeW56J197dWiLIHEDM8oC82lX5MT/cYm2QG4CsC/7 +TCDDT1sEVhkdPB60Scks/Ln5ACsWYInH67AvXWDOaBSebBRR+wFGckgu2kldahIc +e0ejawKa5xgvOX00JmQyigDwqm3NMxjdGOg7y27aNWw8Vbo/oRDif775DvtNIn9C +1nJB8wNjAgMBAAECggEAFuAJo2u7yE6HQvOmvdLF4IgzQgHjNaEs4AQqgHGCRAbJ +fgwrinu+mQh+OI8yKYvlYM+FXaxcOPzgMDZflpmXBxICgmVD/6zjDQfSQWDh3zZA +EmGbmdayfK3YeIogKeSN40cHJRV2pJZtyktf9Ql5ls4CVnPyjNewxoiRidsfBlvc +IoCRjiTD6+MHOQjp4AzwVvbXH1Sr7OsngA4glQJjFlXllyVYQNXBr1sWTVl3TS2L +OQfqYzHWlRtty88z7ExK3D03Jz0PD7qWtTwgJkq1ON+PjgCf7rMvZbVmCoUgX759 +LLxY1NE6ogdoRlPyZG3fvSfmxo2PWOqawboZMvirAQKBgQDVwkYwlbY00XuHA4UF +bGwTyk5Yp0/DVSR4jZyAJN6J5xln0JUWYKeps8kxLfDlPgmn6qIaine9Ewf1D9qt +DGr19mzEHzP07OBRo0l34XE5WUMvV4ter3swwmI5w/ysgds7Mz7xA9a69ukF/7SW +C+2RiVW7hJs1pydQw7+NY9YIYwKBgQDVkRv4tvmITtfPpVjAkw0gPIQ2WLw4uIvk +PIX/A58dg952ga+C4MZ7OFtcKI7CF7anr0gCNdGQS6I3SA18YS49U/zycuzPh9v7 +lcIMV+R0Wvo2B6QIGJpt7FzfZBXGdv/ft5l+MII3jpoGqGu1K3Ifj/zUrDlUJDQq +ivrkH+CJAQKBgQDFcbCRugfWW9TlDhw1uUNPOGQLwWeMvr10WSHAv82KxZsS6Hh9 +dgQIXZeuRIgpx5b1smXPbC1TyRtlgiJ0C29VCCzJLyU3zAEbh18aS3PhDBFhzlRe +vmpkzHgccWqYEU5mLVyrFOeoRN9S+jFdE2F6N8en8MHI2kAXeugZeqk9jwKBgCmV +pMWsEzCIcZs8DekJeR/SyMewRY4h2RNq+YhrUxszJykaHWu1itBJa/io6QtABM/n +4HSVuCWJpJ9xBzc10QQeC33GBPhv8tStF2jB4HkLkfbdTAJLkB5hTMAuw9KuLyqH +nHqmxWQ9/x3Ww4o2WHVu2wMqOct5dTLnduzejCEBAoGATsOUSv3+Gm5TdhyB64Y9 +eCk+GwSiZuZUsWKLs68wmF4fmKM53rgJK2qZzW8gEpl6hhehhr/XJbJc7jY1Hmmk +567RIPm2hyj7npLo5sCL2moo21j2XJfqvVikHaXPg782e3nqqdnNZmV7+D32tl6M +AwI2G5eWOxC5PQckr7blIpM= +-----END PRIVATE KEY----- diff --git a/std/fixtures/tls/test.pem b/std/fixtures/tls/test.pem new file mode 100644 index 000000000..b1a3a96fb --- /dev/null +++ b/std/fixtures/tls/test.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDvDCCAqSgAwIBAgIUBwbzzsn/P9HmSv1o2tAoIH6ZYE0wDQYJKoZIhvcNAQEL +BQAwaTELMAkGA1UEBhMCTkwxEDAOBgNVBAgMB0V4YW1wbGUxEDAOBgNVBAcMB0V4 +YW1wbGUxEDAOBgNVBAoMB0V4YW1wbGUxEDAOBgNVBAsMB2V4YW1wbGUxEjAQBgNV +BAMMCWxvY2FsaG9zdDAeFw0yNDA3MTcxMzA3MDVaFw0yNTA3MTcxMzA3MDVaMGkx +CzAJBgNVBAYTAk5MMRAwDgYDVQQIDAdFeGFtcGxlMRAwDgYDVQQHDAdFeGFtcGxl +MRAwDgYDVQQKDAdFeGFtcGxlMRAwDgYDVQQLDAdleGFtcGxlMRIwEAYDVQQDDAls +b2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCyU87LWHj+ +r1DjdZQtA0EWy5TmBjf8Rr5tv4eiSolFYAD1kfyelutxj6mMzgk1cvePpC8npPzZ +RhwO/8vGZshRw2D0Vdgq8XN8B6G35S4PhhdgHsvGMSX6s5EBk4Xj730YKR9NaSpy +HSIaLEYcmM3Ryv+iKEwnEPw+/oFHeW56J197dWiLIHEDM8oC82lX5MT/cYm2QG4C +sC/7TCDDT1sEVhkdPB60Scks/Ln5ACsWYInH67AvXWDOaBSebBRR+wFGckgu2kld +ahIce0ejawKa5xgvOX00JmQyigDwqm3NMxjdGOg7y27aNWw8Vbo/oRDif775DvtN +In9C1nJB8wNjAgMBAAGjXDBaMA4GA1UdDwEB/wQEAwIDiDATBgNVHSUEDDAKBggr +BgEFBQcDATAUBgNVHREEDTALgglsb2NhbGhvc3QwHQYDVR0OBBYEFAMp7neOF3LS +LEM8S3tQfXZeaxbvMA0GCSqGSIb3DQEBCwUAA4IBAQAAlLAzA+qDahYG8TvC1GYY +GO61UWwom/PTNHkr7M0ByNjz5XlPEq5zIWJZpcvHbr29ayh70xiZ2lr0a3xQstQF +8lNo0QgV0rQDqWkTXujMW5qos0NkLHhz/wNa5CLeGLHOxez4Yb3lmHg5n071bqQU +F/AhrLg17klvLB+I9QRpJ5RuGMeml+pJtrdQOXKjttZ+eX6vfI8iHG+5dVn50wYB +dSDVknR4MhaqJfLeXAz1JS1a6OJbSD4J8JglGiXvEMQdeXsmkYMRwIZOoiGFiJpu +OzLLhs1TsfEwTGLiapfCFXO610FVPsVynY4Ylr6LRodiFCzIlx2k9O3p6GRAkKy7 +-----END CERTIFICATE----- diff --git a/std/src/std/crypto/pem.inko b/std/src/std/crypto/pem.inko new file mode 100644 index 000000000..bf7b74671 --- /dev/null +++ b/std/src/std/crypto/pem.inko @@ -0,0 +1,254 @@ +# Parsing of Privacy-Enhanced Main (PEM) files. +# +# This module provides types and methods for parsing PEM files as defined in +# [RFC 7468](https://www.rfc-editor.org/rfc/rfc7468). +# +# # Constant-time parsing +# +# The current implementation does _not_ make use of constant-time operations (in +# the context of cryptography) for parsing, including the base64 encoded data +# found in PEM files. It's not clear if this matters either, as through timing +# attacks one should (in the worst case) only be able to derive the size of the +# base64 encoded data, not the actual data itself. +import std.base64 (Decoder) +import std.crypto.x509 (Certificate, PrivateKey) +import std.fmt (Format, Formatter) +import std.io (BufferedReader, Error, Read) +import std.iter (Iter) + +let LF = 0xA +let DASH = 0x2D + +# A table indexed using bytes, with a value of `true` indicating the byte is a +# whitespace byte. +let WHITESPACE = [ + false, false, false, false, false, false, false, false, false, true, true, + true, true, true, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, true, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, + false, false, +] + +class pub enum Item { + case Certificate(Certificate) + case PrivateKey(PrivateKey) +} + +impl Format for Item { + fn pub fmt(formatter: mut Formatter) { + match self { + case Certificate(v) -> formatter.tuple('Certificate').field(v).finish + case PrivateKey(v) -> formatter.tuple('PrivateKey').field(v).finish + } + } +} + +class pub enum ParseError { + case Read(Error) + case InvalidSectionStart(Int) + case InvalidSectionEnd(Int) + case InvalidBase64(Int) +} + +impl Format for ParseError { + fn pub fmt(formatter: mut Formatter) { + match self { + case Read(e) -> formatter.tuple('Read').field(e).finish + case InvalidSectionStart(v) -> { + formatter.tuple('InvalidSectionStart').field(v).finish + } + case InvalidSectionEnd(v) -> { + formatter.tuple('InvalidSectionEnd').field(v).finish + } + case InvalidBase64(a) -> formatter.tuple('InvalidBase64').field(a).finish + } + } +} + +# A parser/iterator over the sections in a PEM file. +# +# # Examples +# +# ```inko +# import std.crypto.pem (Parser) +# import std.io (Buffer) +# +# let input = ' +# -----BEGIN PRIVATE KEY----- +# aGVsbG8= +# -----END PRIVATE KEY----- +# ' +# let parser = Parser.new(Buffer.new(input)) +# +# parser.next # => Option.Some(Result.Ok(Item::PrivateKey(...))) +# ``` +class pub Parser[I: mut + Read] { + let @input: BufferedReader[I] + let @line: Int + let @buffer: ByteArray + let @decoder: Decoder + + # Returns a new parser that parses data from `input`. + fn pub static new(input: I) -> Parser[I] { + Parser( + input: BufferedReader.new(input), + line: 0, + buffer: ByteArray.new, + decoder: Decoder.new, + ) + } + + fn mut parse -> Result[Option[Item], ParseError] { + # Ensure we always start with a clean buffer (e.g. after a previous call + # produced an error the buffer isn't reset). + @buffer.clear + + let mut start = true + + loop { + match @input.read_byte { + case Ok(Some(LF)) -> { + @line += 1 + start = true + } + case Ok(Some(v)) if WHITESPACE.get(v) -> start = false + case Ok(Some(DASH)) if start -> break + case Ok(Some(_)) -> throw ParseError.InvalidSectionStart(@line) + case Ok(_) -> return Result.Ok(Option.None) + case Error(e) -> throw ParseError.Read(e) + } + } + + try @input.read_exact(into: @buffer, size: 10).map_error(fn (e) { + ParseError.Read(e) + }) + + if @buffer.equals_string?('----BEGIN ').false? { + throw ParseError.InvalidSectionStart(@line) + } + + @buffer.clear + + match @input.read_until(byte: DASH, into: @buffer, inclusive: false) { + case Ok(0) -> throw ParseError.Read(Error.EndOfInput) + case Ok(_) -> {} + case Error(e) -> throw ParseError.Read(e) + } + + if @buffer.empty? { throw ParseError.InvalidSectionStart(@line) } + + let name = @buffer.drain_to_string + + match @input.read_line(into: @buffer, inclusive: false) { + case Ok(0) -> throw ParseError.Read(Error.EndOfInput) + case Ok(_) -> {} + case Error(e) -> throw ParseError.Read(e) + } + + if @buffer.equals_string?('----').false? { + throw ParseError.InvalidSectionStart(@line) + } + + @buffer.clear + + let res = match name { + case 'CERTIFICATE' -> { + Item.Certificate(Certificate.new(try read_base64(name))) + } + case 'PRIVATE KEY' or 'RSA PRIVATE KEY' or 'DSA PRIVATE KEY' -> { + Item.PrivateKey(PrivateKey.new(try read_base64(name))) + } + case _ -> throw ParseError.InvalidSectionStart(@line) + } + + Result.Ok(Option.Some(res)) + } + + fn mut read_base64(name: String) -> Result[ByteArray, ParseError] { + let decoded = ByteArray.new + + loop { + loop { + # Per the RFC, parsers should ignore leading whitespace and newlines on + # each base64 line. + match @input.read_byte { + case Ok(Some(LF)) -> @line += 1 + case Ok(Some(v)) if WHITESPACE.get(v) -> {} + case Ok(Some(v)) -> { + @buffer.push(v) + break + } + case Ok(_) -> throw ParseError.Read(Error.EndOfInput) + case Error(e) -> throw ParseError.Read(e) + } + } + + match @input.read_line(into: @buffer, inclusive: false) { + case Ok(0) -> throw ParseError.Read(Error.EndOfInput) + case Ok(_) -> @line += 1 + case Error(e) -> throw ParseError.Read(e) + } + + if @buffer.get(0) == DASH { + if + (@buffer.starts_with?('-----END ') and @buffer.ends_with?('-----')) + .false? + { + throw ParseError.InvalidSectionEnd(@line) + } + + # Chop off the trailing dashes, such that we can compare the remaining + # tail with the expected section name. + @buffer.resize(@buffer.size - 5, value: 0) + + if @buffer.size != (name.size + 9) or @buffer.ends_with?(name).false? { + throw ParseError.InvalidSectionEnd(@line) + } + + @buffer.clear + break + } + + # The decoder ensures the size is a multiple of 4. In case the line is + # wrapper per RFC 2045, the maximum size is 76 which is still a multiple + # of 4. This means we can decode in chunks without having to worry about + # padding being expected in the middle. + try @decoder.decode(@buffer, into: decoded).map_error(fn (_) { + ParseError.InvalidBase64(@line) + }) + + @buffer.clear + } + + Result.Ok(decoded) + } +} + +impl Iter[Result[Item, ParseError]] for Parser { + fn pub mut next -> Option[Result[Item, ParseError]] { + match parse { + case Ok(Some(v)) -> Option.Some(Result.Ok(v)) + case Ok(_) -> Option.None + case Error(e) -> Option.Some(Result.Error(e)) + } + } +} diff --git a/std/src/std/crypto/x509.inko b/std/src/std/crypto/x509.inko new file mode 100644 index 000000000..09c90c666 --- /dev/null +++ b/std/src/std/crypto/x509.inko @@ -0,0 +1,38 @@ +# X.509 private keys and certificates +import std.fmt (Format, Formatter) + +# An X.509 certificate. +# +# This is currently just an opaque wrapper around a `ByteArray`. +class pub Certificate { + let @bytes: ByteArray + + # Returns a new `Certificate` that wraps the given `ByteArray`. + fn pub static new(bytes: ByteArray) -> Certificate { + Certificate(bytes) + } +} + +impl Format for Certificate { + fn pub fmt(formatter: mut Formatter) { + formatter.write('Certificate(${@bytes.size} bytes)') + } +} + +# An X.509 private key. +# +# This is currently just an opaque wrapper around a `ByteArray`. +class pub PrivateKey { + let @bytes: ByteArray + + # Returns a new `PrivateKey` that wraps the given `ByteArray`. + fn pub static new(bytes: ByteArray) -> PrivateKey { + PrivateKey(bytes) + } +} + +impl Format for PrivateKey { + fn pub fmt(formatter: mut Formatter) { + formatter.write('PrivateKey(${@bytes.size} bytes)') + } +} diff --git a/std/src/std/env.inko b/std/src/std/env.inko index a6b85c756..e8d486783 100644 --- a/std/src/std/env.inko +++ b/std/src/std/env.inko @@ -37,8 +37,6 @@ fn extern inko_env_get(state: Pointer[UInt8], name: String) -> AnyResult fn extern inko_env_get_working_directory(state: Pointer[UInt8]) -> AnyResult -fn extern inko_env_home_directory(state: Pointer[UInt8]) -> AnyResult - fn extern inko_env_set_working_directory(path: String) -> AnyResult fn extern inko_env_temp_directory(state: Pointer[UInt8]) -> String @@ -119,10 +117,17 @@ fn pub variables -> Map[String, String] { # env.home_directory # => Option.Some('/home/alice') # ``` fn pub home_directory -> Option[Path] { - match inko_env_home_directory(_INKO.state) { - case { @tag = 0, @value = val } -> Option.Some(Path.new(val as String)) - case _ -> Option.None + # Rather than performing all sorts of magical incantations to get the home + # directory, we're just going to require that HOME is set. + # + # If the home is explicitly set to an empty string we still ignore it, because + # there's no scenario in which Some("") is useful. + let val = match inko_env_get(_INKO.state, 'HOME') { + case { @tag = 0, @value = val } -> val as String + case _ -> return Option.None } + + if val.size > 0 { Option.Some(Path.new(val)) } else { Option.None } } # Returns the path to the temporary directory. diff --git a/std/src/std/fs/path.inko b/std/src/std/fs/path.inko index 2ba9f9916..da38eca26 100644 --- a/std/src/std/fs/path.inko +++ b/std/src/std/fs/path.inko @@ -1,6 +1,7 @@ # Cross-platform path manipulation. import std.clone (Clone) import std.cmp (Equal) +import std.env (home_directory) import std.fmt (Format, Formatter) import std.fs (DirectoryEntry) import std.hash (Hash, Hasher) @@ -85,6 +86,13 @@ let pub SEPARATOR = '/' # The byte used to represent the path separator. let SEPARATOR_BYTE = 47 +# The character used to signal the user's home directory. +let HOME = '~' + +# The prefix of a path that indicates a path relative to the user's home +# directory. +let HOME_WITH_SEPARATOR = HOME + SEPARATOR + # Returns the number of bytes leading up to the last path separator. # # If no separator could be found, `-1` is returned. @@ -623,11 +631,20 @@ class pub Path { # Returns the canonical, absolute version of `self`. # + # # Resolving home directories + # + # If `self` is equal to `~`, this method returns the path to the user's home + # directory. If `self` starts with `~/`, this prefix is replaced with the path + # to the user's home directory (e.g. `~/foo` becomes `/var/home/alice/foo`). + # # # Errors # - # This method may return an `Error` for cases such as when `self` doesn't - # exist, or when a component that isn't the last component is _not_ a - # directory. + # This method may return an `Error` for cases such as: + # + # - `self` doesn't exist + # - a component that isn't the last component is _not_ a directory + # - `self` is equal to `~` or starts with `~/`, but the home directory can't + # be found (e.g. it doesn't exist) # # # Examples # @@ -635,9 +652,30 @@ class pub Path { # import std.fs.path (Path) # # Path.new('/foo/../bar').expand.get # => Path.new('/bar') + # Path.new('~').expand.get # => '/var/home/...' + # Path.new('~/').expand.get # => '/var/home/...' # ``` fn pub expand -> Result[Path, Error] { - match inko_path_expand(_INKO.state, @path) { + if @path == HOME { + return match home_directory { + case Some(v) -> Result.Ok(v) + case _ -> Result.Error(Error.NotFound) + } + } + + let mut target = @path + + match @path.strip_prefix(HOME_WITH_SEPARATOR) { + case Some(tail) -> { + target = match home_directory { + case Some(v) -> join_strings(v.path, tail) + case _ -> throw Error.NotFound + } + } + case _ -> {} + } + + match inko_path_expand(_INKO.state, target) { case { @tag = 0, @value = v } -> Result.Ok(Path.new(v as String)) case { @tag = _, @value = e } -> { Result.Error(Error.from_os_error(e as Int)) diff --git a/std/src/std/io.inko b/std/src/std/io.inko index ceb54caef..9d32000d2 100644 --- a/std/src/std/io.inko +++ b/std/src/std/io.inko @@ -19,6 +19,15 @@ let MAX_READ_ALL_SIZE = 1024 * 1024 # The default size of the buffer maintained by `BufferedReader`. let DEFAULT_BUFFER_SIZE = 8 * 1024 +# The error code used to signal invalid data. +# +# This error code isn't produced by libc, instead it's specific to the runtime +# library. +let INVALID_DATA = -2 + +# The error code used when encountering an unexpected end of the input. +let UNEXPECTED_EOF = -3 + fn extern inko_last_error -> Int32 # An error type for I/O operations. @@ -85,9 +94,6 @@ class pub enum Error { # The network is down. case NetworkDown - # The network is unreachable. - case NetworkUnreachable - # The resource isn't a directory. case NotADirectory @@ -121,6 +127,18 @@ class pub enum Error { # A memory address used (e.g. as an argument) is in an invalid range. case BadAddress + # The data provided for the operation is invalid, such as when using an + # invalid TLS certificate or when a TLS socket encountered invalid TLS data + # (e.g. an invalid handshake message). + case InvalidData + + # The operation encountered the end of the input stream, but more input is + # required. + # + # An example of where this error is encountered is when reading from a TLS + # socket that was closed without sending the `close_notify` message. + case EndOfInput + # An error not covered by the other variants. # # The wrapped `Int` is the raw error code. @@ -160,7 +178,7 @@ class pub enum Error { case errors.EADDRINUSE -> Error.AddressInUse case errors.EADDRNOTAVAIL -> Error.AddressUnavailable case errors.ENETDOWN -> Error.NetworkDown - case errors.ENETUNREACH -> Error.NetworkUnreachable + case errors.ENETUNREACH -> Error.NetworkDown case errors.ECONNABORTED -> Error.ConnectionAborted case errors.ECONNRESET -> Error.ConnectionReset case errors.EISCONN -> Error.AlreadyConnected @@ -170,6 +188,8 @@ class pub enum Error { case errors.EHOSTUNREACH -> Error.HostUnreachable case errors.EINPROGRESS -> Error.InProgress case errors.EFAULT -> Error.BadAddress + case INVALID_DATA -> Error.InvalidData + case UNEXPECTED_EOF -> Error.EndOfInput case val -> Error.Other(val) } } @@ -202,7 +222,6 @@ impl ToString for Error { case InvalidSeek -> 'the seek operation is invalid' case IsADirectory -> 'the resource is a directory' case NetworkDown -> 'the network is down' - case NetworkUnreachable -> 'the network is unreachable' case NotADirectory -> "the resource isn't a directory" case NotConnected -> "a connection isn't established" case NotFound -> "the resource isn't found" @@ -214,6 +233,10 @@ impl ToString for Error { case TimedOut -> 'the operation timed out' case WouldBlock -> 'the operation would block' case BadAddress -> 'a memory address is in an invalid range' + case InvalidData -> "the data provided isn't valid for the operation" + case EndOfInput -> { + 'the end of the input stream is reached, but more input is required' + } case Other(code) -> 'an other error with code ${code} occurred' } } @@ -241,7 +264,6 @@ impl Format for Error { case InvalidSeek -> 'InvalidSeek' case IsADirectory -> 'IsADirectory' case NetworkDown -> 'NetworkDown' - case NetworkUnreachable -> 'NetworkUnreachable' case NotADirectory -> 'NotADirectory' case NotConnected -> 'NotConnected' case NotFound -> 'NotFound' @@ -253,6 +275,8 @@ impl Format for Error { case TimedOut -> 'TimedOut' case WouldBlock -> 'WouldBlock' case BadAddress -> 'BadAddress' + case InvalidData -> 'InvalidData' + case EndOfInput -> 'EndOfInput' case Other(code) -> { formatter.tuple('Other').field(code).finish return @@ -285,7 +309,6 @@ impl Equal[ref Error] for Error { case (InvalidSeek, InvalidSeek) -> true case (IsADirectory, IsADirectory) -> true case (NetworkDown, NetworkDown) -> true - case (NetworkUnreachable, NetworkUnreachable) -> true case (NotADirectory, NotADirectory) -> true case (NotConnected, NotConnected) -> true case (NotFound, NotFound) -> true @@ -297,6 +320,8 @@ impl Equal[ref Error] for Error { case (TimedOut, TimedOut) -> true case (WouldBlock, WouldBlock) -> true case (Other(a), Other(b)) -> a == b + case (InvalidData, InvalidData) -> true + case (EndOfInput, EndOfInput) -> true case _ -> false } } @@ -324,6 +349,34 @@ trait pub Read { # available (yet). fn pub mut read(into: mut ByteArray, size: Int) -> Result[Int, Error] + # Reads exactly `size` bytes into `into`. + # + # Whereas `Read.read` might return early if fewer bytes are available in the + # input stream, `Read.read_exact` continues reading until the desired amount + # of bytes is read. + # + # # Errors + # + # If the end of the input stream is encountered before filling the buffer, an + # `Error.EndOfInput` error is returned. + # + # If an error is returned, no assumption can be made about the state of the + # `into` buffer, i.e. there's no guarantee data read so far is in the buffer + # in the event of an error. + fn pub mut read_exact(into: mut ByteArray, size: Int) -> Result[Nil, Error] { + let mut pending = size + + while pending > 0 { + match read(into, pending) { + case Ok(0) if pending > 0 -> throw Error.EndOfInput + case Ok(n) -> pending -= n + case Error(e) -> throw e + } + } + + Result.Ok(nil) + } + # Reads from `self` into the given `ByteArray`, returning when all input is # consumed. # diff --git a/std/src/std/net/socket.inko b/std/src/std/net/socket.inko index 42171d2c5..81807b0a3 100644 --- a/std/src/std/net/socket.inko +++ b/std/src/std/net/socket.inko @@ -71,31 +71,22 @@ import std.net.ip (IpAddress) import std.string (ToString) import std.time (Duration, ToInstant) -class extern Linger { - let @l_onoff: Int32 - let @l_linger: Int32 -} - -class extern RawSocket { - let @inner: Int32 - let @registered: UInt8 - let @unix: UInt8 -} - -class extern RawAddress { - let @address: String - let @port: Int -} - -class extern AnyResult { - let @tag: Int - let @value: UInt64 -} +# The maximum value valid for a listen() call. +# +# Linux and FreeBSD do not allow for values greater than this as they internally +# use an u16, so we'll limit the backlog to this value. We don't use SOMAXCONN +# because it might be hardcoded. This means that setting `net.core.somaxconn` on +# Linux (for example) would have no effect. +let MAXIMUM_LISTEN_BACKLOG = 65_535 -class extern IntResult { - let @tag: Int - let @value: Int -} +# A value that signals the lack of a socket deadline. +let NO_DEADLINE = -1 +let IPV4 = 0 +let IPV6 = 1 +let UNIX = 2 +let STREAM = 0 +let DGRAM = 1 +let RAW = 2 fn extern setsockopt( socket: Int32, @@ -114,8 +105,9 @@ fn extern getsockopt( ) -> Int32 fn extern inko_socket_new( - proto: Int, + domain: Int, kind: Int, + protocol: Int, out: Pointer[RawSocket], ) -> Int64 @@ -222,130 +214,146 @@ fn extern inko_socket_shutdown_read_write( socket: Pointer[RawSocket], ) -> IntResult -# The maximum value valid for a listen() call. -# -# Linux and FreeBSD do not allow for values greater than this as they internally -# use an u16, so we'll limit the backlog to this value. We don't use SOMAXCONN -# because it might be hardcoded. This means that setting `net.core.somaxconn` on -# Linux (for example) would have no effect. -let MAXIMUM_LISTEN_BACKLOG = 65_535 - -# A value that signals the lack of a socket deadline. -let NO_DEADLINE = -1 -let IPV4 = 0 -let IPV6 = 1 -let UNIX = 2 +class extern Linger { + let @l_onoff: Int32 + let @l_linger: Int32 +} -# The type of a socket. -class pub enum Type { - # The type corresponding to `SOCK_STREAM`. - case STREAM +class extern RawSocket { + let @inner: Int32 + let @registered: UInt8 + let @unix: UInt8 +} - # The type corresponding to `SOCK_DGRAM`. - case DGRAM +class extern RawAddress { + let @address: String + let @port: Int +} - # The type corresponding to `SOCK_SEQPACKET`. - case SEQPACKET +class extern IntResult { + let @tag: Int + let @value: Int +} - # The type corresponding to `SOCK_RAW`. - case RAW +trait RawSocketOperations { + fn mut raw_socket -> Pointer[RawSocket] - # Converts a `Type` into the underlying `SOCK_*` integer. - fn pub move into_int -> Int { - match self { - case STREAM -> 0 - case DGRAM -> 1 - case SEQPACKET -> 2 - case RAW -> 3 - } - } + fn raw_deadline -> Int } # An IPv4 or IPv6 socket address. class pub SocketAddress { # The IPv4/IPv6 address of this socket address. - # - # This is stored as a `String` so we don't need to parse the address every - # time a `SocketAddress` is created. - let pub @address: String + let pub @ip: IpAddress # The port number of this socket address. let pub @port: Int - fn pub static new(address: String, port: Int) -> SocketAddress { - SocketAddress(address: address, port: port) - } + fn static from_raw(raw: ref RawAddress) -> SocketAddress { + # The address passed to this method is one filled in by the system. Assuming + # our IP address parsing logic is complete (which it should be), + # encountering an address we can't parse is a bug and should terminate the + # program. + let ip = match IpAddress.parse(raw.address) { + case Some(v) -> v + case _ -> panic("IpAddress.parse doesn't support '${raw.address}'") + } - # Returns the IPv4/IPv6 address associated with `self`. - fn pub ip -> Option[IpAddress] { - IpAddress.parse(@address) + SocketAddress(ip, raw.port) } } impl Equal[ref SocketAddress] for SocketAddress { # Returns `true` if `self` and `other` are the same. fn pub ==(other: ref SocketAddress) -> Bool { - @address == other.address and @port == other.port + @ip == other.ip and @port == other.port } } impl Format for SocketAddress { fn pub fmt(formatter: mut Formatter) { - formatter.write('${@address}:${@port}') + formatter.write('${@ip}:${@port}') } } # A low-level, non-blocking IPv4 or IPv6 socket. class pub Socket { - let @raw: RawSocket + let @socket: RawSocket # A point in time after which socket operations time out. # # We use an `Int` to remove the need for using `Option[Instant]`. let @deadline: Int - # Creates a new IPv4 socket. - # - # # Examples - # - # ```inko - # import std.net.socket (Type, Socket) - # - # Socket.ipv4(Type.DGRAM).get - # ``` - fn pub static ipv4(type: Type) -> Result[Socket, Error] { + fn static new( + domain: Int, + type: Int, + protocol: Int, + ) -> Result[Socket, Error] { let sock = RawSocket( inner: 0 as Int32, registered: 0 as UInt8, unix: 0 as UInt8, ) - match inko_socket_new(IPV4, type.into_int, mut sock) as Int { - case 0 -> Result.Ok(Socket(raw: sock, deadline: NO_DEADLINE)) + match inko_socket_new(domain, type, protocol, mut sock) as Int { + case 0 -> Result.Ok(Socket(socket: sock, deadline: NO_DEADLINE)) case e -> Result.Error(Error.from_os_error(e)) } } - # Creates a new IPv6 socket. + # Returns a new `Socket` configured as a stream socket. + # + # The `ipv6` argument specifies if the socket is an IPv4 socket (`false`) or + # an IPv6 socket (`true`). # # # Examples # # ```inko - # import std.net.socket (Type, Socket) + # import std.net.socket (Socket) # - # Socket.ipv6(Type.DGRAM).get + # Socket.stream(ipv6: false) # ``` - fn pub static ipv6(type: Type) -> Result[Socket, Error] { - let sock = RawSocket( - inner: 0 as Int32, - registered: 0 as UInt8, - unix: 0 as UInt8, - ) + fn pub static stream(ipv6: Bool) -> Result[Socket, Error] { + Socket.new(ipv6.to_int, STREAM, protocol: 0) + } - match inko_socket_new(IPV6, type.into_int, mut sock) as Int { - case 0 -> Result.Ok(Socket(raw: sock, deadline: NO_DEADLINE)) - case e -> Result.Error(Error.from_os_error(e)) - } + # Returns a new `Socket` configured as a datagram socket. + # + # The `ipv6` argument specifies if the socket is an IPv4 socket (`false`) or + # an IPv6 socket (`true`). + # + # # Examples + # + # ```inko + # import std.net.socket (Socket) + # + # Socket.datagram(ipv6: false) + # ``` + fn pub static datagram(ipv6: Bool) -> Result[Socket, Error] { + Socket.new(ipv6.to_int, DGRAM, protocol: 0) + } + + # Returns a new `Socket` configured as a raw socket. + # + # The `ipv6` argument specifies if the socket is an IPv4 socket (`false`) or + # an IPv6 socket (`true`). + # + # The `protocol` argument must specify a valid IANA IP protocol as defined in + # RFC 1700. + # + # Note that on certain platforms (e.g. Linux, and probably most other Unix + # systems) you'll need root privileges in order to create a raw socket. + # + # # Examples + # + # ```inko + # import std.net.socket (Socket) + # + # Socket.raw(ipv6: false, protocol: 1) + # ``` + fn pub static raw(ipv6: Bool, protocol: Int) -> Result[Socket, Error] { + Socket.new(ipv6.to_int, RAW, protocol) } # Sets the point in time after which socket operations must time out, known as @@ -357,10 +365,10 @@ class pub Socket { # after which operations time out: # # ```inko - # import std.net.socket (Socket, Type) + # import std.net.socket (Socket) # import std.time (Duration) # - # let socket = Socket.ipv4(Type.DGRAM) + # let socket = Socket.datagram(ipv6: false) # # socket.timeout_after = Duration.from_secs(5) # ``` @@ -368,10 +376,10 @@ class pub Socket { # We can also use an `Instant`: # # ```inko - # import std.net.socket (Socket, Type) + # import std.net.socket (Socket) # import std.time (Duration, Instant) # - # let socket = Socket.ipv4(Type.DGRAM) + # let socket = Socket.datagram(ipv6: false) # # socket.timeout_after = Instant.new + Duration.from_secs(5) # ``` @@ -391,15 +399,15 @@ class pub Socket { # Binding a socket: # # ```inko - # import std.net.socket (Socket, Type) + # import std.net.socket (Socket) # import std.net.ip (IpAddress) # - # let socket = Socket.ipv4(Type.DGRAM).get + # let socket = Socket.datagram(ipv6: false).get # # socket.bind(ip: IpAddress.v4(0, 0, 0, 0), port: 9999).get # ``` fn pub mut bind(ip: ref IpAddress, port: Int) -> Result[Nil, Error] { - match inko_socket_bind(@raw, ip.to_string, port) { + match inko_socket_bind(@socket, ip.to_string, port) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -412,11 +420,11 @@ class pub Socket { # Connecting a socket: # # ```inko - # import std.net.socket (Socket, Type) + # import std.net.socket (Socket) # import std.net.ip (IpAddress) # - # let listener = Socket.ipv4(Type.STREAM).get - # let client = Socket.ipv4(Type.STREAM).get + # let listener = Socket.stream(ipv6: false).get + # let client = Socket.stream(ipv6: false).get # # socket.bind(ip: IpAddress.v4(0, 0, 0, 0), port: 9999).get # socket.listen.get @@ -427,7 +435,7 @@ class pub Socket { inko_socket_connect( _INKO.state, _INKO.process, - @raw, + @socket, ip.to_string, port, @deadline, @@ -446,16 +454,16 @@ class pub Socket { # Marking a socket as a listener: # # ```inko - # import std.net.socket (Socket, Type) + # import std.net.socket (Socket) # import std.net.ip (IpAddress) # - # let socket = Socket.ipv4(Type.STREAM).get + # let socket = Socket.stream(ipv6: false).get # # socket.bind(ip: IpAddress.v4(0, 0, 0, 0), port: 9999).get # socket.listen.get # ``` fn pub mut listen -> Result[Nil, Error] { - match inko_socket_listen(@raw, MAXIMUM_LISTEN_BACKLOG) { + match inko_socket_listen(@socket, MAXIMUM_LISTEN_BACKLOG) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -470,11 +478,11 @@ class pub Socket { # Accepting a connection and reading data from the connection: # # ```inko - # import std.net.socket (Socket, Type) + # import std.net.socket (Socket) # import std.net.ip (IpAddress) # - # let listener = Socket.ipv4(Type.STREAM).get - # let stream = Socket.ipv4(Type.STREAM).get + # let listener = Socket.stream(ipv6: false).get + # let stream = Socket.stream(ipv6: false).get # # listener.bind(ip: IpAddress.v4(0, 0, 0, 0), port: 9999).get # listener.listen.get @@ -497,10 +505,16 @@ class pub Socket { ) match - inko_socket_accept(_INKO.state, _INKO.process, @raw, @deadline, mut sock) + inko_socket_accept( + _INKO.state, + _INKO.process, + @socket, + @deadline, + mut sock, + ) as Int { - case 0 -> Result.Ok(Socket(raw: sock, deadline: NO_DEADLINE)) + case 0 -> Result.Ok(Socket(socket: sock, deadline: NO_DEADLINE)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -512,10 +526,10 @@ class pub Socket { # # Examples # # ```inko - # import std.net.socket (Socket, Type) + # import std.net.socket (Socket) # import std.net.ip (IpAddress) # - # let socket = Socket.ipv4(Type.DGRAM).get + # let socket = Socket.datagram(ipv6: false).get # # socket.bind(ip: IpAddress.v4(0, 0, 0, 0), port: 9999).get # socket @@ -535,7 +549,7 @@ class pub Socket { inko_socket_send_string_to( _INKO.state, _INKO.process, - @raw, + @socket, string, ip.to_string, port, @@ -554,10 +568,10 @@ class pub Socket { # # Examples # # ```inko - # import std.net.socket (Socket, Type) + # import std.net.socket (Socket) # import std.net.ip (IpAddress) # - # let socket = Socket.ipv4(Type.DGRAM).get + # let socket = Socket.datagram(ipv6: false).get # let bytes = 'hello'.to_byte_array # # socket.bind(ip: IpAddress.v4(0, 0, 0, 0), port: 9999).get @@ -578,7 +592,7 @@ class pub Socket { inko_socket_send_bytes_to( _INKO.state, _INKO.process, - @raw, + @socket, bytes, ip.to_string, port, @@ -601,10 +615,10 @@ class pub Socket { # Sending a message to ourselves and receiving it: # # ```inko - # import std.net.socket (Socket, Type) + # import std.net.socket (Socket) # import std.net.ip (IpAddress) # - # let socket = Socket.ipv4(Type.DGRAM).get + # let socket = Socket.datagram(ipv6: false).get # let bytes = ByteArray.new # # socket @@ -631,7 +645,7 @@ class pub Socket { inko_socket_receive_from( _INKO.state, _INKO.process, - @raw, + @socket, bytes, size, @deadline, @@ -639,7 +653,7 @@ class pub Socket { ) as Int { - case 0 -> Result.Ok(SocketAddress.new(raw.address, raw.port)) + case 0 -> Result.Ok(SocketAddress.from_raw(raw)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -648,8 +662,8 @@ class pub Socket { fn pub local_address -> Result[SocketAddress, Error] { let raw = RawAddress(address: '', port: 0) - match inko_socket_local_address(_INKO.state, @raw, mut raw) as Int { - case 0 -> Result.Ok(SocketAddress.new(raw.address, raw.port)) + match inko_socket_local_address(_INKO.state, @socket, mut raw) as Int { + case 0 -> Result.Ok(SocketAddress.from_raw(raw)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -658,8 +672,8 @@ class pub Socket { fn pub peer_address -> Result[SocketAddress, Error] { let raw = RawAddress(address: '', port: 0) - match inko_socket_peer_address(_INKO.state, @raw, mut raw) as Int { - case 0 -> Result.Ok(SocketAddress.new(raw.address, raw.port)) + match inko_socket_peer_address(_INKO.state, @socket, mut raw) as Int { + case 0 -> Result.Ok(SocketAddress.from_raw(raw)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -710,7 +724,7 @@ class pub Socket { } let res = setsockopt( - @raw.inner, + @socket.inner, const.SOL_SOCKET as Int32, const.SO_LINGER as Int32, (mut linger) as Pointer[UInt8], @@ -726,7 +740,7 @@ class pub Socket { let linger = Linger(l_onoff: 0 as Int32, l_linger: 0 as Int32) let size = 8 as Int32 let res = getsockopt( - @raw.inner, + @socket.inner, const.SOL_SOCKET as Int32, const.SO_LINGER as Int32, (mut linger) as Pointer[UInt8], @@ -775,7 +789,7 @@ class pub Socket { # Shuts down the reading half of this socket. fn pub mut shutdown_read -> Result[Nil, Error] { - match inko_socket_shutdown_read(@raw) { + match inko_socket_shutdown_read(@socket) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -783,7 +797,7 @@ class pub Socket { # Shuts down the writing half of this socket. fn pub mut shutdown_write -> Result[Nil, Error] { - match inko_socket_shutdown_write(@raw) { + match inko_socket_shutdown_write(@socket) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -791,7 +805,7 @@ class pub Socket { # Shuts down both the reading and writing half of this socket. fn pub mut shutdown -> Result[Nil, Error] { - match inko_socket_shutdown_read_write(@raw) { + match inko_socket_shutdown_read_write(@socket) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -808,8 +822,8 @@ class pub Socket { unix: 0 as UInt8, ) - match inko_socket_try_clone(@raw, mut sock) as Int { - case 0 -> Result.Ok(Socket(raw: sock, deadline: NO_DEADLINE)) + match inko_socket_try_clone(@socket, mut sock) as Int { + case 0 -> Result.Ok(Socket(socket: sock, deadline: NO_DEADLINE)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -817,7 +831,7 @@ class pub Socket { fn mut set_option(level: Int, option: Int, value: Int) -> Result[Nil, Error] { let val = value as Int32 let res = setsockopt( - @raw.inner, + @socket.inner, level as Int32, option as Int32, (mut val) as Pointer[UInt8], @@ -832,7 +846,7 @@ class pub Socket { let size = 4 as Int32 let val = 0 as Int32 let res = getsockopt( - @raw.inner, + @socket.inner, level as Int32, option as Int32, (mut val) as Pointer[UInt8], @@ -852,16 +866,33 @@ class pub Socket { } } +impl RawSocketOperations for Socket { + fn mut raw_socket -> Pointer[RawSocket] { + @socket + } + + fn raw_deadline -> Int { + @deadline + } +} + impl Drop for Socket { fn mut drop { - inko_socket_drop(@raw) + inko_socket_drop(@socket) } } impl Read for Socket { fn pub mut read(into: mut ByteArray, size: Int) -> Result[Int, Error] { match - inko_socket_read(_INKO.state, _INKO.process, @raw, into, size, @deadline) + inko_socket_read( + _INKO.state, + _INKO.process, + @socket, + into, + size, + @deadline, + ) { case { @tag = 0, @value = v } -> Result.Ok(v) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) @@ -874,7 +905,7 @@ impl WriteInternal for Socket { let state = _INKO.state let proc = _INKO.process - match inko_socket_write(state, proc, @raw, data, size, @deadline) { + match inko_socket_write(state, proc, @socket, data, size, @deadline) { case { @tag = 0, @value = n } -> Result.Ok(n) case { @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -913,16 +944,10 @@ class pub UdpSocket { # import std.net.socket (UdpSocket) # import std.net.ip (IpAddress) # - # let ip = IpAddress.parse('0.0.0.0').get - # - # UdpSocket.new(ip, port: 0).get + # UdpSocket.new(IpAddress.v4(0, 0, 0, 0), port: 0).get # ``` fn pub static new(ip: ref IpAddress, port: Int) -> Result[UdpSocket, Error] { - let socket = if ip.v6? { - try Socket.ipv6(Type.DGRAM) - } else { - try Socket.ipv4(Type.DGRAM) - } + let socket = try Socket.datagram(ip.v6?) try socket.bind(ip, port) Result.Ok(UdpSocket(socket)) @@ -1039,6 +1064,16 @@ class pub UdpSocket { } } +impl RawSocketOperations for UdpSocket { + fn mut raw_socket -> Pointer[RawSocket] { + @socket.socket + } + + fn raw_deadline -> Int { + @socket.deadline + } +} + impl Read for UdpSocket { fn pub mut read(into: mut ByteArray, size: Int) -> Result[Int, Error] { @socket.read(into, size) @@ -1084,16 +1119,10 @@ class pub TcpClient { # import std.net.socket (TcpClient) # import std.net.ip (IpAddress) # - # let ip = IpAddress.parse('127.0.0.1').get - # - # TcpClient.new(ip, port: 40_000).get + # TcpClient.new(IpAddress.v4(127, 0, 0, 1), port: 40_000).get # ``` fn pub static new(ip: ref IpAddress, port: Int) -> Result[TcpClient, Error] { - let socket = if ip.v6? { - try Socket.ipv6(Type.STREAM) - } else { - try Socket.ipv4(Type.STREAM) - } + let socket = try Socket.stream(ip.v6?) try socket.connect(ip, port) from(socket) @@ -1125,13 +1154,9 @@ class pub TcpClient { fn pub static with_timeout[T: ToInstant]( ip: ref IpAddress, port: Int, - timeout_after: T, + timeout_after: ref T, ) -> Result[TcpClient, Error] { - let socket = if ip.v6? { - try Socket.ipv6(Type.STREAM) - } else { - try Socket.ipv4(Type.STREAM) - } + let socket = try Socket.stream(ip.v6?) socket.timeout_after = timeout_after try socket.connect(ip, port) @@ -1177,6 +1202,16 @@ class pub TcpClient { } } +impl RawSocketOperations for TcpClient { + fn mut raw_socket -> Pointer[RawSocket] { + @socket.socket + } + + fn raw_deadline -> Int { + @socket.deadline + } +} + impl Read for TcpClient { fn pub mut read(into: mut ByteArray, size: Int) -> Result[Int, Error] { @socket.read(into, size) @@ -1209,26 +1244,16 @@ class pub TcpServer { # rebinding of sockets. `SO_REUSEPORT` is only used on platforms that support # it. # - # The `only_ipv6` argument is ignored when binding to an IPv4 address. - # # # Examples # - # Creating a `TcpServer`: - # # ```inko - # import std.net.socket (TcpServer) # import std.net.ip (IpAddress) + # import std.net.socket (TcpServer) # - # let ip = IpAddress.parse('0.0.0.0').get - # - # TcpServer.new(ip, port: 40_000).get + # TcpServer.new(IpAddress.v4(0, 0, 0, 0), port: 40_000).get # ``` fn pub static new(ip: ref IpAddress, port: Int) -> Result[TcpServer, Error] { - let socket = if ip.v6? { - try Socket.ipv6(Type.STREAM) - } else { - try Socket.ipv4(Type.STREAM) - } + let socket = try Socket.stream(ip.v6?) try socket.no_delay = true try socket.reuse_address = true @@ -1284,6 +1309,16 @@ class pub TcpServer { } } +impl RawSocketOperations for TcpServer { + fn mut raw_socket -> Pointer[RawSocket] { + @socket.socket + } + + fn raw_deadline -> Int { + @socket.deadline + } +} + # A Unix domain socket address. class pub UnixAddress { # The path or name of the address. @@ -1292,37 +1327,6 @@ class pub UnixAddress { # and unnamed addresses. let pub @address: String - # Creates a new `UnixAddress` from the given path or name. - # - # # Examples - # - # Creating a `UnixAddress` that uses a path: - # - # ```inko - # import std.net.socket (UnixAddress) - # - # UnixAddress.new('/tmp/test.sock'.to_path) - # ``` - # - # Creating a `UnixAddress` that uses an unnamed address: - # - # ```inko - # import std.net.socket (UnixAddress) - # - # UnixAddress.new(''.to_path) - # ``` - # - # Creating a `UnixAddress` that uses an abstract address: - # - # ```inko - # import std.net.socket (UnixAddress) - # - # UnixAddress.new("\0example".to_path) - # ``` - fn pub static new(address: ref Path) -> UnixAddress { - UnixAddress(address.to_string) - } - # Returns the path of this address. # # If the address is unnamed or an abstract address, None is returned. @@ -1419,37 +1423,52 @@ impl ToString for UnixAddress { # A low-level, non-blocking Unix domain socket. class pub UnixSocket { - let @raw: RawSocket + let @socket: RawSocket # A point in time after which socket operations time out. # # We use an `Int` to remove the need for using `Option[Instant]`. let @deadline: Int - # Creates a new Unix domain socket. - # - # # Examples - # - # Creating a new socket: - # - # ```inko - # import std.net.socket (Type, UnixSocket) - # - # UnixSocket.new(Type.DGRAM).get - # ``` - fn pub static new(type: Type) -> Result[UnixSocket, Error] { + fn static new(type: Int) -> Result[UnixSocket, Error] { let sock = RawSocket( inner: 0 as Int32, registered: 0 as UInt8, unix: 1 as UInt8, ) - match inko_socket_new(UNIX, type.into_int, mut sock) as Int { - case 0 -> Result.Ok(UnixSocket(raw: sock, deadline: NO_DEADLINE)) + match inko_socket_new(UNIX, type, 0, mut sock) as Int { + case 0 -> Result.Ok(UnixSocket(socket: sock, deadline: NO_DEADLINE)) case e -> Result.Error(Error.from_os_error(e)) } } + # Returns a new `UnixSocket` configured as a stream socket. + # + # # Examples + # + # ```inko + # import std.net.socket (UnixSocket) + # + # UnixSocket.stream + # ``` + fn pub static stream -> Result[UnixSocket, Error] { + UnixSocket.new(STREAM) + } + + # Returns a new `UnixSocket` configured as a datagram socket. + # + # # Examples + # + # ```inko + # import std.net.socket (UnixSocket) + # + # UnixSocket.datagram + # ``` + fn pub static datagram -> Result[UnixSocket, Error] { + UnixSocket.new(DGRAM) + } + # Sets the point in time after which socket operations must time out, known as # a "deadline". # @@ -1459,10 +1478,10 @@ class pub UnixSocket { # after which operations time out: # # ```inko - # import std.net.socket (UnixSocket, Type) + # import std.net.socket (UnixSocket) # import std.time (Duration) # - # let socket = UnixSocket.new(Type.DGRAM) + # let socket = UnixSocket.datagram # # socket.timeout_after = Duration.from_secs(5) # ``` @@ -1470,10 +1489,10 @@ class pub UnixSocket { # We can also use an `Instant`: # # ```inko - # import std.net.socket (UnixSocket, Type) + # import std.net.socket (UnixSocket) # import std.time (Duration, Instant) # - # let socket = UnixSocket.new(Type.DGRAM) + # let socket = UnixSocket.datagram # # socket.timeout_after = Instant.new + Duration.from_secs(5) # ``` @@ -1493,14 +1512,14 @@ class pub UnixSocket { # Binding a Unix socket to a path: # # ```inko - # import std.net.socket (Type, UnixSocket) + # import std.net.socket (UnixSocket) # - # let socket = UnixSocket.new(Type.DGRAM).get + # let socket = UnixSocket.datagram.get # # socket.bind('/tmp/test.sock'.to_path).get # ``` fn pub mut bind(path: ref Path) -> Result[Nil, Error] { - match inko_socket_bind(@raw, path.to_string, 0) { + match inko_socket_bind(@socket, path.to_string, 0) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -1513,10 +1532,10 @@ class pub UnixSocket { # Connecting a Unix socket: # # ```inko - # import std.net.socket (Type, UnixSocket) + # import std.net.socket (UnixSocket) # - # let listener = UnixSocket.new(Type.STREAM).get - # let stream = UnixSocket.new(Type.STREAM).get + # let listener = UnixSocket.stream.get + # let stream = UnixSocket.stream.get # # listener.bind('/tmp/test.sock'.to_path).get # listener.listen.get @@ -1528,7 +1547,7 @@ class pub UnixSocket { inko_socket_connect( _INKO.state, _INKO.process, - @raw, + @socket, path.to_string, 0, @deadline, @@ -1547,15 +1566,15 @@ class pub UnixSocket { # Marking a socket as a listener: # # ```inko - # import std.net.socket (Type, UnixSocket) + # import std.net.socket (UnixSocket) # - # let socket = UnixSocket.new(Type.STREAM).get + # let socket = UnixSocket.stream.get # # socket.bind('/tmp/test.sock'.to_path).get # socket.listen.get # ``` fn pub mut listen -> Result[Nil, Error] { - match inko_socket_listen(@raw, MAXIMUM_LISTEN_BACKLOG) { + match inko_socket_listen(@socket, MAXIMUM_LISTEN_BACKLOG) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -1570,10 +1589,10 @@ class pub UnixSocket { # Accepting a connection and reading data from the connection: # # ```inko - # import std.net.socket (Type, UnixSocket) + # import std.net.socket (UnixSocket) # - # let listener = UnixSocket.new(Type.STREAM).get - # let stream = UnixSocket.new(Type.STREAM).get + # let listener = UnixSocket.stream.get + # let stream = UnixSocket.stream.get # # listener.bind('/tmp/test.sock'.to_path).get # listener.listen.get @@ -1596,10 +1615,16 @@ class pub UnixSocket { ) match - inko_socket_accept(_INKO.state, _INKO.process, @raw, @deadline, mut sock) + inko_socket_accept( + _INKO.state, + _INKO.process, + @socket, + @deadline, + mut sock, + ) as Int { - case 0 -> Result.Ok(UnixSocket(raw: sock, deadline: NO_DEADLINE)) + case 0 -> Result.Ok(UnixSocket(socket: sock, deadline: NO_DEADLINE)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -1611,9 +1636,9 @@ class pub UnixSocket { # # Examples # # ```inko - # import std.net.socket (Type, UnixSocket) + # import std.net.socket (UnixSocket) # - # let socket = UnixSocket.new(Type.DGRAM).get + # let socket = UnixSocket.datagram.get # # socket.bind('/tmp/test.sock'.to_path).get # socket @@ -1630,7 +1655,7 @@ class pub UnixSocket { inko_socket_send_string_to( _INKO.state, _INKO.process, - @raw, + @socket, string, addr, 0, @@ -1649,9 +1674,9 @@ class pub UnixSocket { # # Examples # # ```inko - # import std.net.socket (Type, UnixSocket) + # import std.net.socket (UnixSocket) # - # let socket = UnixSocket.new(Type.DGRAM).get + # let socket = UnixSocket.datagram.get # let bytes = 'hello'.to_byte_array # # socket.bind('/tmp/test.sock'.to_path).get @@ -1669,7 +1694,7 @@ class pub UnixSocket { inko_socket_send_bytes_to( _INKO.state, _INKO.process, - @raw, + @socket, bytes, addr, 0, @@ -1692,9 +1717,9 @@ class pub UnixSocket { # Sending a message to ourselves and receiving it: # # ```inko - # import std.net.socket (Type, UnixSocket) + # import std.net.socket (UnixSocket) # - # let socket = UnixSocket.new(Type.DGRAM).get + # let socket = UnixSocket.datagram.get # let bytes = ByteArray.new # # socket.send_string_to('hello', address: '/tmp/test.sock'.to_path).get @@ -1714,7 +1739,7 @@ class pub UnixSocket { inko_socket_receive_from( _INKO.state, _INKO.process, - @raw, + @socket, bytes, size, @deadline, @@ -1722,7 +1747,7 @@ class pub UnixSocket { ) as Int { - case 0 -> Result.Ok(UnixAddress.new(raw.address.to_path)) + case 0 -> Result.Ok(UnixAddress(raw.address)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -1731,8 +1756,8 @@ class pub UnixSocket { fn pub local_address -> Result[UnixAddress, Error] { let raw = RawAddress(address: '', port: 0) - match inko_socket_local_address(_INKO.state, @raw, mut raw) as Int { - case 0 -> Result.Ok(UnixAddress.new(raw.address.to_path)) + match inko_socket_local_address(_INKO.state, @socket, mut raw) as Int { + case 0 -> Result.Ok(UnixAddress(raw.address)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -1741,8 +1766,8 @@ class pub UnixSocket { fn pub peer_address -> Result[UnixAddress, Error] { let raw = RawAddress(address: '', port: 0) - match inko_socket_peer_address(_INKO.state, @raw, mut raw) as Int { - case 0 -> Result.Ok(UnixAddress.new(raw.address.to_path)) + match inko_socket_peer_address(_INKO.state, @socket, mut raw) as Int { + case 0 -> Result.Ok(UnixAddress(raw.address)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -1759,7 +1784,7 @@ class pub UnixSocket { # Shuts down the reading half of this socket. fn pub mut shutdown_read -> Result[Nil, Error] { - match inko_socket_shutdown_read(@raw) { + match inko_socket_shutdown_read(@socket) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -1767,7 +1792,7 @@ class pub UnixSocket { # Shuts down the writing half of this socket. fn pub mut shutdown_write -> Result[Nil, Error] { - match inko_socket_shutdown_write(@raw) { + match inko_socket_shutdown_write(@socket) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -1775,7 +1800,7 @@ class pub UnixSocket { # Shuts down both the reading and writing half of this socket. fn pub mut shutdown -> Result[Nil, Error] { - match inko_socket_shutdown_read_write(@raw) { + match inko_socket_shutdown_read_write(@socket) { case { @tag = 1, @value = _ } -> Result.Ok(nil) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -1792,8 +1817,8 @@ class pub UnixSocket { unix: 1 as UInt8, ) - match inko_socket_try_clone(@raw, mut sock) as Int { - case 0 -> Result.Ok(UnixSocket(raw: sock, deadline: NO_DEADLINE)) + match inko_socket_try_clone(@socket, mut sock) as Int { + case 0 -> Result.Ok(UnixSocket(socket: sock, deadline: NO_DEADLINE)) case e -> Result.Error(Error.from_os_error(e)) } } @@ -1801,7 +1826,7 @@ class pub UnixSocket { fn mut set_option(level: Int, option: Int, value: Int) -> Result[Nil, Error] { let val = value as Int32 let res = setsockopt( - @raw.inner, + @socket.inner, level as Int32, option as Int32, (mut val) as Pointer[UInt8], @@ -1813,16 +1838,33 @@ class pub UnixSocket { } } +impl RawSocketOperations for UnixSocket { + fn mut raw_socket -> Pointer[RawSocket] { + @socket + } + + fn raw_deadline -> Int { + @deadline + } +} + impl Drop for UnixSocket { fn mut drop { - inko_socket_drop(@raw) + inko_socket_drop(@socket) } } impl Read for UnixSocket { fn pub mut read(into: mut ByteArray, size: Int) -> Result[Int, Error] { match - inko_socket_read(_INKO.state, _INKO.process, @raw, into, size, @deadline) + inko_socket_read( + _INKO.state, + _INKO.process, + @socket, + into, + size, + @deadline, + ) { case { @tag = 0, @value = v } -> Result.Ok(v) case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) @@ -1835,7 +1877,7 @@ impl WriteInternal for UnixSocket { let state = _INKO.state let proc = _INKO.process - match inko_socket_write(state, proc, @raw, data, size, @deadline) { + match inko_socket_write(state, proc, @socket, data, size, @deadline) { case { @tag = 0, @value = n } -> Result.Ok(n) case { @value = e } -> Result.Error(Error.from_os_error(e)) } @@ -1873,7 +1915,7 @@ class pub UnixDatagram { # UnixDatagram.new('/tmp/test.sock'.to_path).get # ``` fn pub static new(address: ref Path) -> Result[UnixDatagram, Error] { - let socket = try UnixSocket.new(Type.DGRAM) + let socket = try UnixSocket.datagram try socket.bind(address) Result.Ok(UnixDatagram(socket)) @@ -1973,6 +2015,16 @@ class pub UnixDatagram { } } +impl RawSocketOperations for UnixDatagram { + fn mut raw_socket -> Pointer[RawSocket] { + @socket.socket + } + + fn raw_deadline -> Int { + @socket.deadline + } +} + impl Read for UnixDatagram { fn pub mut read(into: mut ByteArray, size: Int) -> Result[Int, Error] { @socket.read(into, size) @@ -2016,7 +2068,7 @@ class pub UnixClient { # UnixClient.new('/tmp/test.sock'.to_path).get # ``` fn pub static new(address: ref Path) -> Result[UnixClient, Error] { - let socket = try UnixSocket.new(Type.STREAM) + let socket = try UnixSocket.stream try socket.connect(address) Result.Ok(UnixClient(socket)) @@ -2043,11 +2095,11 @@ class pub UnixClient { # ) # .get # ``` - fn pub static with_timeout[I: ToInstant]( + fn pub static with_timeout[T: ToInstant]( address: ref Path, - timeout_after: I, + timeout_after: ref T, ) -> Result[UnixClient, Error] { - let socket = try UnixSocket.new(Type.STREAM) + let socket = try UnixSocket.stream socket.timeout_after = timeout_after try socket.connect(address) @@ -2093,6 +2145,16 @@ class pub UnixClient { } } +impl RawSocketOperations for UnixClient { + fn mut raw_socket -> Pointer[RawSocket] { + @socket.socket + } + + fn raw_deadline -> Int { + @socket.deadline + } +} + impl Read for UnixClient { fn pub mut read(into: mut ByteArray, size: Int) -> Result[Int, Error] { @socket.read(into, size) @@ -2132,7 +2194,7 @@ class pub UnixServer { # UnixServer.new('/tmp/test.sock'.to_path).get # ``` fn pub static new(address: ref Path) -> Result[UnixServer, Error] { - let socket = try UnixSocket.new(Type.STREAM) + let socket = try UnixSocket.stream try socket.bind(address) try socket.listen @@ -2181,3 +2243,13 @@ class pub UnixServer { @socket.try_clone.map(fn (sock) { UnixServer(sock) }) } } + +impl RawSocketOperations for UnixServer { + fn mut raw_socket -> Pointer[RawSocket] { + @socket.socket + } + + fn raw_deadline -> Int { + @socket.deadline + } +} diff --git a/std/src/std/net/tls.inko b/std/src/std/net/tls.inko new file mode 100644 index 000000000..fb9bf9a39 --- /dev/null +++ b/std/src/std/net/tls.inko @@ -0,0 +1,639 @@ +# TLS support for sockets. +# +# This module provides socket support for TLS 1.2 and TLS 1.3. +# +# The two main socket types are `Client` and `Server`, both acting as wrappers +# around existing socket types (e.g. `std.net.socket.TcpClient`) that +# transparently handle TLS encryption and decryption. +# +# For more details on how to set up a client and/or server socket, refer to the +# documentation of `Client.new` and `Server.new`. +# +# # Handling closing of connections +# +# The TLS specification states that clients _should_ send the `close_notify` +# message when they disconnect, but not every TLS implementation/user sends it. +# The `Client` and `Server` types provided by this module automatically send the +# `close_notify` message when they're dropped. +# +# When performing an IO operation on a socket closed without an explicit +# `close_notify` message being sent first, an `Error.EndOfInput` error is +# produced. +# +# When receiving a `close_notify` message during or after an IO operation (e.g. +# a write), a `Error.InvalidData` or `Error.BrokenPipe` error may be produced. +import std.clone (Clone) +import std.cmp (Equal) +import std.crypto.x509 (Certificate, PrivateKey) +import std.drop (Drop) +import std.fmt (Format, Formatter) +import std.io (Error, Read, Write, WriteInternal) +import std.net.socket (RawSocket, RawSocketOperations) +import std.string (ToString) + +# The error code produced when a TLS certificate is invalid. +let INVALID_CERT = -1 + +fn extern inko_tls_client_config_new -> Pointer[UInt8] + +fn extern inko_tls_client_config_clone(config: Pointer[UInt8]) -> Pointer[UInt8] + +fn extern inko_tls_client_config_drop(config: Pointer[UInt8]) + +fn extern inko_tls_client_config_with_certificate( + certificate: ref ByteArray, +) -> AnyResult + +fn extern inko_tls_client_connection_new( + config: Pointer[UInt8], + name: String, +) -> AnyResult + +fn extern inko_tls_client_connection_drop(connection: Pointer[UInt8]) + +fn extern inko_tls_server_config_new( + certificate: ref ByteArray, + key: ref ByteArray, +) -> AnyResult + +fn extern inko_tls_server_config_clone(config: Pointer[UInt8]) -> Pointer[UInt8] + +fn extern inko_tls_server_config_drop(config: Pointer[UInt8]) + +fn extern inko_tls_server_connection_new( + config: Pointer[UInt8], +) -> Pointer[UInt8] + +fn extern inko_tls_server_connection_drop(connection: Pointer[UInt8]) + +fn extern inko_tls_client_read( + state: Pointer[UInt8], + process: Pointer[UInt8], + socket: Pointer[RawSocket], + connection: Pointer[UInt8], + buffer: mut ByteArray, + amount: Int, + deadline: Int, +) -> IntResult + +fn extern inko_tls_client_write( + state: Pointer[UInt8], + process: Pointer[UInt8], + socket: Pointer[RawSocket], + connection: Pointer[UInt8], + data: Pointer[UInt8], + size: Int, + deadline: Int, +) -> IntResult + +fn extern inko_tls_client_flush( + state: Pointer[UInt8], + process: Pointer[UInt8], + socket: Pointer[RawSocket], + connection: Pointer[UInt8], +) -> IntResult + +fn extern inko_tls_client_close( + state: Pointer[UInt8], + process: Pointer[UInt8], + socket: Pointer[RawSocket], + connection: Pointer[UInt8], + deadline: Int, +) -> IntResult + +fn extern inko_tls_server_read( + state: Pointer[UInt8], + process: Pointer[UInt8], + socket: Pointer[RawSocket], + connection: Pointer[UInt8], + buffer: mut ByteArray, + amount: Int, + deadline: Int, +) -> IntResult + +fn extern inko_tls_server_write( + state: Pointer[UInt8], + process: Pointer[UInt8], + socket: Pointer[RawSocket], + connection: Pointer[UInt8], + data: Pointer[UInt8], + size: Int, + deadline: Int, +) -> IntResult + +fn extern inko_tls_server_flush( + state: Pointer[UInt8], + process: Pointer[UInt8], + socket: Pointer[RawSocket], + connection: Pointer[UInt8], +) -> IntResult + +fn extern inko_tls_server_close( + state: Pointer[UInt8], + process: Pointer[UInt8], + socket: Pointer[RawSocket], + connection: Pointer[UInt8], + deadline: Int, +) -> IntResult + +class extern AnyResult { + let @tag: Int + let @value: UInt64 +} + +class extern IntResult { + let @tag: Int + let @value: Int +} + +# An error produced when creating a `ServerConfig`. +class pub enum ServerConfigError { + # The certificate exists but is invalid, such as when it's revoked or not + # encoded correctly. + case InvalidCertificate + + # The private key exists but is invalid. + case InvalidPrivateKey +} + +impl Equal[ref ServerConfigError] for ServerConfigError { + fn pub ==(other: ref ServerConfigError) -> Bool { + match (self, other) { + case (InvalidCertificate, InvalidCertificate) -> true + case (InvalidPrivateKey, InvalidPrivateKey) -> true + case _ -> false + } + } +} + +impl ToString for ServerConfigError { + fn pub to_string -> String { + match self { + case InvalidCertificate -> 'the certificate is invalid' + case InvalidPrivateKey -> 'the private key is invalid' + } + } +} + +impl Format for ServerConfigError { + fn pub fmt(formatter: mut Formatter) { + match self { + case InvalidCertificate -> formatter.tuple('InvalidCertificate').finish + case InvalidPrivateKey -> formatter.tuple('InvalidPrivateKey').finish + } + } +} + +# A type storing the configuration details for TLS clients. +# +# To configure a `Server`, use `ServerConfig` instead. +# +# Creating a `ClientConfig` is potentially expensive, depending on the amount of +# certificates that need to be processed. As such, it's recommended to only +# create a `ClientConfig` once and use `ClientConfig.clone` to clone it whenever +# necessary (e.g. when sharing a `ClientConfig` between processes), as cloning a +# `ClientConfig` is cheap. +class pub ClientConfig { + let @raw: Pointer[UInt8] + + # Returns a new `ClientConfig` that uses the system's certificate store. + # + # # Examples + # + # ```inko + # import std.net.tls (ClientConfig) + # + # ClientConfig.new + # ``` + fn pub static new -> ClientConfig { + ClientConfig(inko_tls_client_config_new) + } + + # Returns a new `ClientConfig` using the specified PEM encoded X.509 + # certificate. + # + # # Errors + # + # If the certificate isn't valid, a `None` is returned. + # + # # Examples + # + # ```inko + # import std.net.tls (ClientConfig) + # import std.crypto.x509 (Certificate) + # + # # In a real program you'd load the certificate from a file or a database. + # let cert = Certificate.new(ByteArray.from_array[1, 2, 3, 4]) + # + # ClientConfig + # .with_certificate(cert) + # .or_panic('the client configuration is invalid') + # ``` + fn pub static with_certificate( + certificate: ref Certificate, + ) -> Option[ClientConfig] { + match inko_tls_client_config_with_certificate(certificate.bytes) { + case { @tag = 0, @value = v } -> { + Option.Some(ClientConfig(v as Pointer[UInt8])) + } + case _ -> Option.None + } + } +} + +impl Drop for ClientConfig { + fn mut drop { + inko_tls_client_config_drop(@raw) + } +} + +impl Clone[ClientConfig] for ClientConfig { + fn pub clone -> ClientConfig { + ClientConfig(inko_tls_client_config_clone(@raw)) + } +} + +# A type that acts as the client in a TLS session. +# +# `Client` values wrap existing sockets such as `std.net.socket.TcpClient` and +# apply TLS encryption/decryption to IO operations. +# +# # Closing TLS connections +# +# When a `Client` is dropped, the TLS connection is closed by sending the TLS +# `close_notify` message. +# +# # Examples +# +# ```inko +# import std.net.ip (IpAddress) +# import std.net.socket (TcpClient) +# import std.net.tls (Client, ClientConfig) +# +# let conf = ClientConfig.new +# let sock = TcpClient +# .new(ip: IpAddress.v4(127, 0, 0, 1), port: 9000) +# .or_panic('failed to connect to the server') +# let client = Client +# .new(socket, conf, name: 'localhost') +# .or_panic('the server name is invalid') +# +# client.write_string('ping').or_panic('failed to write the message') +# +# let response = ByteArray.new +# +# client.read_all(response).or_panic('failed to read the response') +# ``` +class pub Client[T: mut + RawSocketOperations] { + # The socket wrapped by this `Client`. + let pub @socket: T + + # The TLS connection state. + let @state: Pointer[UInt8] + + # Returns a `Client` acting as the client in a TLS session. + # + # The `socket` argument is the socket (e.g. `std.net.socket.TcpClient`) to + # wrap. This can be either an owned socket or a mutable borrow of a socket. + # + # The `name` argument is the DNS name to use for Server Name Indication (SNI). + # Setting this to an IP address disables the use of SNI. In most cases you'll + # want to set this to the DNS name of the server the socket is connecting to. + # + # The `config` argument is a `ClientConfig` instance to use for configuring + # the TLS connection. + # + # # Errors + # + # This method returns an `Option.None` if the `name` argument contains an + # invalid value. + # + # # Examples + # + # ```inko + # import std.net.ip (IpAddress) + # import std.net.socket (TcpClient) + # import std.net.tls (Client, ClientConfig) + # + # let conf = ClientConfig.new + # let sock = TcpClient + # .new(ip: IpAddress.v4(127, 0, 0, 1), port: 9000) + # .or_panic('failed to connect to the server') + # + # Client + # .new(sock, conf, name: 'localhost') + # .or_panic('the server name is invalid') + # ``` + fn pub static new( + socket: T, + config: ref ClientConfig, + name: String, + ) -> Option[Client[T]] { + let state = match inko_tls_client_connection_new(config.raw, name) { + case { @tag = 0, @value = v } -> v as Pointer[UInt8] + case _ -> return Option.None + } + + Option.Some(Client(socket, state)) + } + + # Sends the TLS `close_notify` message to the socket, informing the peer that + # the connection is being closed. + fn mut close -> Result[Nil, Error] { + match + inko_tls_client_close( + _INKO.state, + _INKO.process, + @socket.raw_socket, + @state, + @socket.raw_deadline, + ) + { + case { @tag = 1 } -> Result.Ok(nil) + case { @value = e } -> Result.Error(Error.from_os_error(e)) + } + } +} + +impl Drop for Client { + fn mut drop { + # Per the TLS specification, the connection _should_ be closed explicitly + # when discarding the socket. + let _ = close + + inko_tls_client_connection_drop(@state) + } +} + +impl Read for Client { + fn pub mut read(into: mut ByteArray, size: Int) -> Result[Int, Error] { + match + inko_tls_client_read( + _INKO.state, + _INKO.process, + @socket.raw_socket, + @state, + into, + size, + @socket.raw_deadline, + ) + { + case { @tag = 0, @value = v } -> Result.Ok(v) + case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) + } + } +} + +impl WriteInternal for Client { + fn mut write_internal(data: Pointer[UInt8], size: Int) -> Result[Int, Error] { + match + inko_tls_client_write( + _INKO.state, + _INKO.process, + @socket.raw_socket, + @state, + data, + size, + @socket.raw_deadline, + ) + { + case { @tag = 0, @value = v } -> Result.Ok(v) + case { @value = e } -> Result.Error(Error.from_os_error(e)) + } + } +} + +impl Write for Client { + fn pub mut write_bytes(bytes: ref ByteArray) -> Result[Nil, Error] { + write_all_internal(bytes.to_pointer, bytes.size) + } + + fn pub mut write_string(string: String) -> Result[Nil, Error] { + write_all_internal(string.to_pointer, string.size) + } + + fn pub mut flush -> Result[Nil, Never] { + Result.Ok(nil) + } +} + +# A type storing the configuration details for TLS servers. +# +# To configure a `Client`, use `ClientConfig` instead. +# +# Creating a `ServerConfig` is potentially expensive, depending on the +# certificate and private key that are used. As such, it's recommended to only +# create a `ServerConfig` once and use `ServerConfig.clone` to clone it whenever +# necessary, as cloning a `ServerConfig` is cheap. +class pub ServerConfig { + let @raw: Pointer[UInt8] + + # Returns a new `ClientConfig` using the specified PEM encoded X.509 + # certificate and private key. + # + # # Errors + # + # A `ServerConfigError` is returned if any of the following is true: + # + # - The certificate is invalid + # - The private key is invalid + # + # # Examples + # + # ```inko + # import std.net.tls (ServerConfig) + # import std.crypto.x509 (Certificate, PrivateKey) + # + # let cert = Certificate.new(ByteArray.from_array([1, 2, 3])) + # let key = PrivateKey.new(ByteArray.from_array([4, 5, 6])) + # + # ServerConfig.new(cert, key).or_panic('failed to create the configuration') + # ``` + fn pub static new( + certificate: ref Certificate, + key: ref PrivateKey, + ) -> Result[ServerConfig, ServerConfigError] { + match inko_tls_server_config_new(certificate.bytes, key.bytes) { + case { @tag = 0, @value = v } -> { + Result.Ok(ServerConfig(v as Pointer[UInt8])) + } + case { @value = e } if e as Int == INVALID_CERT -> { + Result.Error(ServerConfigError.InvalidCertificate) + } + case _ -> Result.Error(ServerConfigError.InvalidPrivateKey) + } + } +} + +impl Drop for ServerConfig { + fn mut drop { + inko_tls_server_config_drop(@raw) + } +} + +impl Clone[ServerConfig] for ServerConfig { + fn pub clone -> ServerConfig { + ServerConfig(inko_tls_server_config_clone(@raw)) + } +} + +# A type that acts as the server in a TLS session. +# +# `Server` values wrap existing sockets such as `std.net.socket.TcpClient` and +# apply TLS encryption/decryption to IO operations. +# +# # Closing TLS connections +# +# When a `Client` is dropped the TLS connection is closed by sending the TLS +# `close_notify` message. +# +# # Examples +# +# ```inko +# import std.crypto.x509 (Certificate, PrivateKey) +# import std.net.ip (IpAddress) +# import std.net.socket (TcpServer) +# import std.net.tls (Server, ServerConfig) +# +# let cert = Certificate.new(ByteArray.from_array([1, 2, 3])) +# let key = PrivateKey.new(ByteArray.from_array([4, 5, 6])) +# +# let conf = ServerConfig +# .new(cert, key) +# .or_panic('failed to create the server configuration') +# +# let server = TcpServer +# .new(ip: IpAddress.v4(0, 0, 0, 0), port: 9000) +# .or_panic('failed to start the server') +# +# let con = server +# .accept +# .map(fn (sock) { Server.new(sock, conf) }) +# .or_panic('failed to accept the new connection') +# +# let bytes = ByteArray.new +# +# con.read(into: bytes, size: 32).or_panic('failed to read the data') +# ``` +class pub Server[T: mut + RawSocketOperations] { + # The socket wrapped by this `Server`. + let pub @socket: T + + # The TLS connection state. + let @state: Pointer[UInt8] + + # Returns a `Server` acting as the server in a TLS session. + # + # The `socket` argument is the socket (e.g. `std.net.socket.TcpClient`) to + # wrap. This can be either an owned socket or a mutable borrow of a socket. + # + # The `config` argument is a `ServerConfig` instance to use for configuring + # the TLS connection. + # + # # Examples + # + # ```inko + # import std.crypto.x509 (Certificate, PrivateKey) + # import std.net.ip (IpAddress) + # import std.net.socket (TcpServer) + # import std.net.tls (Server, ServerConfig) + # + # let cert = Certificate.new(ByteArray.from_array([1, 2, 3])) + # let key = PrivateKey.new(ByteArray.from_array([4, 5, 6])) + # + # let conf = ServerConfig + # .new(cert, key) + # .or_panic('failed to create the server configuration') + # + # let server = TcpServer + # .new(ip: IpAddress.v4(0, 0, 0, 0), port: 9000) + # .or_panic('failed to start the server') + # + # server + # .accept + # .map(fn (sock) { Server.new(sock, conf) }) + # .or_panic('failed to accept the new connection') + # ``` + fn pub static new(socket: T, config: ref ServerConfig) -> Server[T] { + Server(socket, inko_tls_server_connection_new(config.raw)) + } + + # Sends the TLS `close_notify` message to the socket, informing the peer that + # the connection is being closed. + fn mut close -> Result[Nil, Error] { + match + inko_tls_server_close( + _INKO.state, + _INKO.process, + @socket.raw_socket, + @state, + @socket.raw_deadline, + ) + { + case { @tag = 1 } -> Result.Ok(nil) + case { @value = e } -> Result.Error(Error.from_os_error(e)) + } + } +} + +impl Drop for Server { + fn mut drop { + # Per the TLS specification, the connection _should_ be closed explicitly + # when discarding the socket. + let _ = close + + inko_tls_server_connection_drop(@state) + } +} + +impl Read for Server { + fn pub mut read(into: mut ByteArray, size: Int) -> Result[Int, Error] { + match + inko_tls_server_read( + _INKO.state, + _INKO.process, + @socket.raw_socket, + @state, + into, + size, + @socket.raw_deadline, + ) + { + case { @tag = 0, @value = v } -> Result.Ok(v) + case { @tag = _, @value = e } -> Result.Error(Error.from_os_error(e)) + } + } +} + +impl WriteInternal for Server { + fn mut write_internal(data: Pointer[UInt8], size: Int) -> Result[Int, Error] { + match + inko_tls_server_write( + _INKO.state, + _INKO.process, + @socket.raw_socket, + @state, + data, + size, + @socket.raw_deadline, + ) + { + case { @tag = 0, @value = v } -> Result.Ok(v) + case { @value = e } -> Result.Error(Error.from_os_error(e)) + } + } +} + +impl Write for Server { + fn pub mut write_bytes(bytes: ref ByteArray) -> Result[Nil, Error] { + write_all_internal(bytes.to_pointer, bytes.size) + } + + fn pub mut write_string(string: String) -> Result[Nil, Error] { + write_all_internal(string.to_pointer, string.size) + } + + fn pub mut flush -> Result[Nil, Never] { + Result.Ok(nil) + } +} diff --git a/std/src/std/option.inko b/std/src/std/option.inko index adea8615b..e78e96d87 100644 --- a/std/src/std/option.inko +++ b/std/src/std/option.inko @@ -191,6 +191,42 @@ class pub enum Option[T] { case None -> true } } + + # Transforms `self` into a `Result[T, E]`, mapping an `Option.Some(T)` to + # `Result.Ok(T)` and a `Option.None` to `Result.Error(E)`. + # + # The argument is eagerly evaluated. If this isn't desired, use + # `Option.ok_or_else` instead. + # + # # Examples + # + # ```inko + # Option.Some(10).ok_or('oops!') # => Result.Ok(10) + # Option.None.ok_or('oops!') # => Result.Error('oops!') + # ``` + fn pub move ok_or[E](error: E) -> Result[T, E] { + match self { + case Some(v) -> Result.Ok(v) + case _ -> Result.Error(error) + } + } + + # Transforms `self` into a `Result[T, E]`, mapping an `Option.Some(T)` to + # `Result.Ok(T)` and a `Option.None` to `Result.Error(E)` where `E` is the + # return value of the given closure. + # + # # Examples + # + # ```inko + # Option.Some(10).ok_or_else(fn { 'oops!' }) # => Result.Ok(10) + # Option.None.ok_or_else(fn { 'oops!' }) # => Result.Error('oops!') + # ``` + fn pub move ok_or_else[E](error: fn -> E) -> Result[T, E] { + match self { + case Some(v) -> Result.Ok(v) + case _ -> Result.Error(error.call) + } + } } impl Option if T: mut { diff --git a/std/src/std/string.inko b/std/src/std/string.inko index 60b5c63ed..11c45c738 100644 --- a/std/src/std/string.inko +++ b/std/src/std/string.inko @@ -480,28 +480,38 @@ class builtin String { # Returns a new `String` without the given prefix. # + # If `self` starts with the prefix, a `Option.Some` is returned containing the + # substring after the prefix. If `self` doesn't start with the prefix, an + # `Option.None` is returned. + # # # Examples # # ```inko - # 'xhellox'.strip_prefix('x') # => 'hellox' + # 'xhellox'.strip_prefix('x') # => Option.Some('hellox') + # 'xhellox'.strip_prefix('y') # => Option.None # ``` - fn pub strip_prefix(prefix: String) -> String { - if starts_with?(prefix).false? { return clone } + fn pub strip_prefix(prefix: String) -> Option[String] { + if starts_with?(prefix).false? { return Option.None } - slice(start: prefix.size, size: size - prefix.size).into_string + Option.Some(slice(start: prefix.size, size: size - prefix.size).into_string) } # Returns a new `String` without the given suffix. # + # If `self` ends with the suffix, a `Option.Some` is returned containing the + # substring before the prefix. If `self` doesn't end with the suffix, an + # `Option.None` is returned. + # # # Examples # # ```inko - # 'xhellox'.strip_suffix('x') # => 'xhello' + # 'xhellox'.strip_suffix('x') # => Option.Some('xhello') + # 'xhellox'.strip_suffix('y') # => Option.None # ``` - fn pub strip_suffix(suffix: String) -> String { - if ends_with?(suffix).false? { return clone } + fn pub strip_suffix(suffix: String) -> Option[String] { + if ends_with?(suffix).false? { return Option.None } - slice(start: 0, size: size - suffix.size).into_string + Option.Some(slice(start: 0, size: size - suffix.size).into_string) } # Returns a new `String` without any leading whitespace. diff --git a/std/test/compiler/test_diagnostics.inko b/std/test/compiler/test_diagnostics.inko index 2588ed4bd..1e43e4725 100644 --- a/std/test/compiler/test_diagnostics.inko +++ b/std/test/compiler/test_diagnostics.inko @@ -215,7 +215,7 @@ class Diagnostic { # We remove the directory leading up to the file, that way the diagnostic # lines in the test file don't need to specify the full file paths, and # debugging failing tests is a little less annoying due to noisy output. - let file = (try string(map, 'file')).strip_prefix('${directory}/') + let file = (try string(map, 'file')).strip_prefix('${directory}/').get let line = try location(map, 'lines') let column = try location(map, 'columns') let message = try string(map, 'message') @@ -277,7 +277,7 @@ fn pub tests(t: mut Tests) { case Error(e) -> panic('failed to read the diagnostics directory: ${e}') } - let name = test_file.tail.strip_suffix('.inko') + let name = test_file.tail.strip_suffix('.inko').get t.test('inko check ${name}', fn move (t) { let file = ReadOnlyFile.new(test_file.clone).or_panic( diff --git a/std/test/std/fs/test_path.inko b/std/test/std/fs/test_path.inko index bb0bcbdf0..6e31649e9 100644 --- a/std/test/std/fs/test_path.inko +++ b/std/test/std/fs/test_path.inko @@ -5,6 +5,7 @@ import std.fs (DirectoryEntry, FileType) import std.fs.file (self, ReadOnlyFile, WriteOnlyFile) import std.fs.path (self, Path) import std.io (Error) +import std.stdio (STDOUT) import std.sys import std.test (Tests) import std.time (DateTime) @@ -153,17 +154,37 @@ fn pub tests(t: mut Tests) { t.test('Path.fmt', fn (t) { t.equal(fmt(Path.new('foo')), '"foo"') }) t.test('Path.expand', fn (t) { - let temp = env.temporary_directory - let bar = temp.join('foo').join('bar') + with_directory(t.id, fn (temp) { + let bar = temp.join('foo').join('bar') - bar.create_directory_all.get + bar.create_directory_all.get - let expanded = bar.join('..').join('..').expand + let expanded = bar.join('..').join('..').expand - t.equal(expanded, Result.Ok(temp)) - bar.remove_directory_all + t.equal(expanded, Result.Ok(temp.clone)) + }) + + t.equal(Path.new('~').expand.ok, env.home_directory) + t.equal(Path.new('~/').expand.ok, env.home_directory) + t.true(Path.new('~foo').expand.error?) + t.true(Path.new('/~').expand.error?) + t.true(Path.new('~/this-directory-should-not-exist').expand.error?) }) + t.fork( + 'Path.expand with a missing home directory', + child: fn { + let out = STDOUT.new + let res = Path.new('~').expand.map(fn (v) { v.to_string }).or('ERROR') + + out.write_string(res) + }, + test: fn (test, proc) { + proc.variable('HOME', '') + test.equal(proc.spawn.stdout, 'ERROR') + }, + ) + t.test('Path.tail', fn (t) { t.equal(Path.new('foo').tail, 'foo') t.equal(Path.new('foo/').tail, 'foo') diff --git a/std/test/std/net/test_socket.inko b/std/test/std/net/test_socket.inko index d1c212940..e07e0f00e 100644 --- a/std/test/std/net/test_socket.inko +++ b/std/test/std/net/test_socket.inko @@ -6,7 +6,7 @@ import std.fs.path (Path) import std.io (Error) import std.net.ip (IpAddress, Ipv4Address, Ipv6Address) import std.net.socket ( - NO_DEADLINE, Socket, SocketAddress, TcpClient, TcpServer, Type, UdpSocket, + NO_DEADLINE, Socket, SocketAddress, TcpClient, TcpServer, UdpSocket, UnixAddress, UnixClient, UnixDatagram, UnixServer, UnixSocket, ) import std.string (ToString) @@ -47,15 +47,15 @@ impl Drop for SocketPath { fn pub tests(t: mut Tests) { t.test('SocketAddress.new', fn (t) { - let addr = SocketAddress.new(address: '127.0.0.1', port: 1234) + let addr = SocketAddress(ip: IpAddress.v4(127, 0, 0, 1), port: 1234) - t.equal(addr.ip, Option.Some(IpAddress.V4(Ipv4Address.new(127, 0, 0, 1)))) + t.equal(addr.ip, IpAddress.V4(Ipv4Address.new(127, 0, 0, 1))) t.equal(addr.port, 1234) }) t.test('SocketAddress.==', fn (t) { - let addr1 = SocketAddress.new(address: '127.0.0.1', port: 1234) - let addr2 = SocketAddress.new(address: '127.0.0.1', port: 4567) + let addr1 = SocketAddress(ip: IpAddress.v4(127, 0, 0, 1), port: 1234) + let addr2 = SocketAddress(ip: IpAddress.v4(127, 0, 0, 1), port: 4567) t.equal(addr1, addr1) t.not_equal(addr1, addr2) @@ -63,47 +63,47 @@ fn pub tests(t: mut Tests) { t.test('SocketAddress.fmt', fn (t) { t.equal( - fmt(SocketAddress.new(address: '127.0.0.1', port: 1234)), + fmt(SocketAddress(ip: IpAddress.v4(127, 0, 0, 1), port: 1234)), '127.0.0.1:1234', ) }) - t.test('Socket.ipv4', fn (t) { - t.true(Socket.ipv4(Type.STREAM).ok?) - t.true(Socket.ipv4(Type.DGRAM).ok?) + t.test('Socket.stream', fn (t) { + t.true(Socket.stream(ipv6: false).ok?) + t.true(Socket.stream(ipv6: true).ok?) }) - t.test('Socket.ipv6', fn (t) { - t.true(Socket.ipv6(Type.STREAM).ok?) - t.true(Socket.ipv6(Type.DGRAM).ok?) + t.test('Socket.datagram', fn (t) { + t.true(Socket.datagram(ipv6: false).ok?) + t.true(Socket.datagram(ipv6: true).ok?) }) t.test('Socket.bind', fn (t) { { - let sock = Socket.ipv4(Type.STREAM).get + let sock = Socket.stream(ipv6: false).get t.true(sock.bind(ip: IpAddress.v4(-1, -1, -1, -1), port: 0).error?) } { - let sock = Socket.ipv4(Type.STREAM).get + let sock = Socket.stream(ipv6: false).get t.true(sock.bind(ip: IpAddress.v4(0, 0, 0, 0), port: 0).ok?) } }) t.test('Socket.connect', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get - let stream1 = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get + let stream1 = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get let addr = listener.local_address.get - t.true(stream1.connect(addr.ip.get, addr.port).ok?) + t.true(stream1.connect(addr.ip, addr.port).ok?) - let stream2 = Socket.ipv4(Type.STREAM).get + let stream2 = Socket.stream(ipv6: false).get # connect() may not immediately raise a "connection refused" error, due to # connect() being non-blocking. In this case the "connection refused" error @@ -118,22 +118,22 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.listen', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get socket.bind(ip: IpAddress.v4(0, 0, 0, 0), port: 0).get t.true(socket.listen.ok?) }) t.test('Socket.accept', fn (t) { - let server = Socket.ipv4(Type.STREAM).get - let client = Socket.ipv4(Type.STREAM).get + let server = Socket.stream(ipv6: false).get + let client = Socket.stream(ipv6: false).get server.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get server.listen.get let addr = server.local_address.get - t.equal(client.connect(addr.ip.get, addr.port), Result.Ok(nil)) + t.equal(client.connect(addr.ip, addr.port), Result.Ok(nil)) let connection = server.accept.get @@ -141,7 +141,7 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.send_string_to', fn (t) { - let socket = Socket.ipv4(Type.DGRAM).get + let socket = Socket.datagram(ipv6: false).get socket.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get @@ -149,7 +149,7 @@ fn pub tests(t: mut Tests) { let buffer = ByteArray.new t.equal( - socket.send_string_to('ping', send_to.ip.get, send_to.port), + socket.send_string_to('ping', send_to.ip, send_to.port), Result.Ok(4), ) t.equal(socket.read(into: buffer, size: 4), Result.Ok(4)) @@ -157,7 +157,7 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.send_bytes_to', fn (t) { - let socket = Socket.ipv4(Type.DGRAM).get + let socket = Socket.datagram(ipv6: false).get socket.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get @@ -165,11 +165,7 @@ fn pub tests(t: mut Tests) { let buffer = ByteArray.new t.equal( - socket.send_bytes_to( - 'ping'.to_byte_array, - send_to.ip.get, - send_to.port.clone, - ), + socket.send_bytes_to('ping'.to_byte_array, send_to.ip, send_to.port.clone), Result.Ok(4), ) t.equal(socket.read(into: buffer, size: 4), Result.Ok(4)) @@ -177,8 +173,8 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.receive_from', fn (t) { - let listener = Socket.ipv4(Type.DGRAM).get - let client = Socket.ipv4(Type.DGRAM).get + let listener = Socket.datagram(ipv6: false).get + let client = Socket.datagram(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get client.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get @@ -186,7 +182,7 @@ fn pub tests(t: mut Tests) { let send_to = listener.local_address.get t.equal( - client.send_string_to('ping', send_to.ip.get, send_to.port.clone), + client.send_string_to('ping', send_to.ip, send_to.port.clone), Result.Ok(4), ) @@ -197,70 +193,70 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.local_address with an unbound socket', fn (t) { - let socket = Socket.ipv4(Type.DGRAM).get + let socket = Socket.datagram(ipv6: false).get let address = socket.local_address.get - t.equal(address.address, '0.0.0.0') + t.equal(address.ip, IpAddress.v4(0, 0, 0, 0)) t.equal(address.port, 0) }) t.test('Socket.local_address with a bound socket', fn (t) { - let socket = Socket.ipv4(Type.DGRAM).get + let socket = Socket.datagram(ipv6: false).get socket.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get let local_address = socket.local_address.get - t.equal(local_address.address, '127.0.0.1') + t.equal(local_address.ip, IpAddress.v4(127, 0, 0, 1)) t.true(local_address.port > 0) }) t.test('Socket.peer_address with a disconnected socket', fn (t) { - let socket = Socket.ipv4(Type.DGRAM).get + let socket = Socket.datagram(ipv6: false).get t.true(socket.peer_address.error?) }) t.test('Socket.peer_address with a connected socket', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get - let client = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get + let client = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get let addr = listener.local_address.get - t.equal(client.connect(addr.ip.get, addr.port), Result.Ok(nil)) + t.equal(client.connect(addr.ip, addr.port), Result.Ok(nil)) t.equal(client.peer_address, Result.Ok(addr)) }) t.test('Socket.ttl=', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get t.true((socket.ttl = 10).ok?) }) t.test('Socket.only_ipv6=', fn (t) { - let socket = Socket.ipv6(Type.STREAM).get + let socket = Socket.stream(ipv6: true).get t.true((socket.only_ipv6 = true).ok?) }) t.test('Socket.no_delay=', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get t.true((socket.no_delay = true).ok?) t.true(socket.no_delay?) }) t.test('Socket.broadcast=', fn (t) { - let socket = Socket.ipv4(Type.DGRAM).get + let socket = Socket.datagram(ipv6: false).get t.true((socket.broadcast = true).ok?) }) t.test('Socket.linger=', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get let duration = Duration.from_secs(5) t.true((socket.linger = Option.Some(duration)).ok?) @@ -271,45 +267,45 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.receive_buffer_size=', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get t.true((socket.receive_buffer_size = 256).ok?) }) t.test('Socket.send_buffer_size=', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get t.true((socket.send_buffer_size = 256).ok?) }) t.test('Socket.keepalive=', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get t.true((socket.keepalive = true).ok?) }) t.test('Socket.reuse_adress=', fn (t) { - let socket = Socket.ipv6(Type.DGRAM).get + let socket = Socket.datagram(ipv6: true).get t.true((socket.reuse_address = true).ok?) }) t.test('Socket.reuse_port=', fn (t) { - let socket = Socket.ipv6(Type.DGRAM).get + let socket = Socket.datagram(ipv6: true).get t.true((socket.reuse_port = true).ok?) }) t.test('Socket.shutdown_read', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get - let stream = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get + let stream = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get let addr = listener.local_address.get - t.equal(stream.connect(addr.ip.get, addr.port), Result.Ok(nil)) + t.equal(stream.connect(addr.ip, addr.port), Result.Ok(nil)) stream.shutdown_read.get let bytes = ByteArray.new @@ -319,45 +315,45 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.shutdown_write', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get - let stream = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get + let stream = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get let addr = listener.local_address.get - t.equal(stream.connect(addr.ip.get, addr.port), Result.Ok(nil)) + t.equal(stream.connect(addr.ip, addr.port), Result.Ok(nil)) stream.shutdown_write.get t.true(stream.write_string('ping').error?) }) t.test('Socket.shutdown shuts down the writing half', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get - let stream = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get + let stream = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get let addr = listener.local_address.get - t.equal(stream.connect(addr.ip.get, addr.port), Result.Ok(nil)) + t.equal(stream.connect(addr.ip, addr.port), Result.Ok(nil)) stream.shutdown.get t.true(stream.write_string('ping').error?) }) t.test('Socket.shutdown shuts down the reading half', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get - let stream = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get + let stream = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get let addr = listener.local_address.get - t.equal(stream.connect(addr.ip.get, addr.port), Result.Ok(nil)) + t.equal(stream.connect(addr.ip, addr.port), Result.Ok(nil)) stream.shutdown.get let bytes = ByteArray.new @@ -367,34 +363,34 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.try_clone', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get t.true(socket.try_clone.ok?) }) t.test('Socket.read', fn (t) { - let socket = Socket.ipv4(Type.DGRAM).get + let socket = Socket.datagram(ipv6: false).get socket.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get let addr = socket.local_address.get let bytes = ByteArray.new - t.equal(socket.send_string_to('ping', addr.ip.get, addr.port), Result.Ok(4)) + t.equal(socket.send_string_to('ping', addr.ip, addr.port), Result.Ok(4)) t.equal(socket.read(into: bytes, size: 4), Result.Ok(4)) t.equal(bytes.into_string, 'ping') }) t.test('Socket.write_bytes', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get - let stream = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get + let stream = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get let addr = listener.local_address.get - t.equal(stream.connect(addr.ip.get, addr.port), Result.Ok(nil)) + t.equal(stream.connect(addr.ip, addr.port), Result.Ok(nil)) stream.write_bytes('ping'.to_byte_array).get let connection = listener.accept.get @@ -405,15 +401,15 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.write_string', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get - let stream = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get + let stream = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get let addr = listener.local_address.get - t.equal(stream.connect(addr.ip.get, addr.port), Result.Ok(nil)) + t.equal(stream.connect(addr.ip, addr.port), Result.Ok(nil)) stream.write_string('ping').get let connection = listener.accept.get @@ -424,13 +420,13 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.flush', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get t.equal(socket.flush, Result.Ok(nil)) }) t.test('Socket.timeout_after=', fn (t) { - let server = Socket.ipv4(Type.STREAM).get + let server = Socket.stream(ipv6: false).get server.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get server.listen.get @@ -440,7 +436,7 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.timeout_after= with a read after a timeout', fn (t) { - let server = Socket.ipv4(Type.STREAM).get + let server = Socket.stream(ipv6: false).get let bytes = ByteArray.new server.timeout_after = Duration.from_secs(0) @@ -450,7 +446,7 @@ fn pub tests(t: mut Tests) { }) t.test('Socket.reset_deadline', fn (t) { - let socket = Socket.ipv4(Type.STREAM).get + let socket = Socket.stream(ipv6: false).get socket.timeout_after = Duration.from_secs(10) t.true(socket.deadline > 0) @@ -477,7 +473,7 @@ fn pub tests(t: mut Tests) { let socket2 = UdpSocket.new(ip: ip, port: 0).get let addr = socket2.local_address.get - t.true(socket1.connect(addr.ip.get, addr.port).ok?) + t.true(socket1.connect(addr.ip, addr.port).ok?) }) t.test('UdpSocket.send_string_to', fn (t) { @@ -487,7 +483,7 @@ fn pub tests(t: mut Tests) { let addr = socket.local_address.get - t.equal(socket.send_string_to('ping', addr.ip.get, addr.port), Result.Ok(4)) + t.equal(socket.send_string_to('ping', addr.ip, addr.port), Result.Ok(4)) let bytes = ByteArray.new @@ -500,7 +496,7 @@ fn pub tests(t: mut Tests) { let socket = UdpSocket.new(ip: ip, port: 0).get let addr = socket.local_address.get - socket.send_bytes_to('ping'.to_byte_array, addr.ip.get, addr.port).get + socket.send_bytes_to('ping'.to_byte_array, addr.ip, addr.port).get let bytes = ByteArray.new @@ -514,7 +510,7 @@ fn pub tests(t: mut Tests) { let client = UdpSocket.new(ip: ip, port: 0).get let addr = listener.local_address.get - client.send_string_to('ping', addr.ip.get, addr.port).get + client.send_string_to('ping', addr.ip, addr.port).get let bytes = ByteArray.new @@ -527,7 +523,7 @@ fn pub tests(t: mut Tests) { let socket = UdpSocket.new(ip: ip, port: 0).get let local_address = socket.local_address.get - t.equal(local_address.address, '127.0.0.1') + t.equal(local_address.ip, IpAddress.v4(127, 0, 0, 1)) t.true(local_address.port > 0) }) @@ -543,7 +539,7 @@ fn pub tests(t: mut Tests) { let socket = UdpSocket.new(ip: ip, port: 0).get let addr = socket.local_address.get - t.equal(socket.send_string_to('ping', addr.ip.get, addr.port), Result.Ok(4)) + t.equal(socket.send_string_to('ping', addr.ip, addr.port), Result.Ok(4)) let bytes = ByteArray.new @@ -557,7 +553,7 @@ fn pub tests(t: mut Tests) { let client_socket = UdpSocket.new(ip: ip, port: 0).get let addr = server_socket.local_address.get - t.equal(client_socket.connect(addr.ip.get, addr.port), Result.Ok(nil)) + t.equal(client_socket.connect(addr.ip, addr.port), Result.Ok(nil)) client_socket.write_bytes('ping'.to_byte_array).get let bytes = ByteArray.new @@ -574,20 +570,20 @@ fn pub tests(t: mut Tests) { }) t.ok('TcpClient.new', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get let addr = listener.local_address.get - let client = try TcpClient.new(addr.ip.get, addr.port) + let client = try TcpClient.new(addr.ip, addr.port) t.true(client.socket.no_delay?) Result.Ok(nil) }) t.ok('TcpClient.with_timeout', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0).get listener.listen.get @@ -596,7 +592,7 @@ fn pub tests(t: mut Tests) { { let client = try TcpClient.with_timeout( - addr.ip.get, + addr.ip, addr.port, timeout_after: Duration.from_secs(2), ) @@ -611,48 +607,56 @@ fn pub tests(t: mut Tests) { timeout_after: Duration.from_micros(500), ) - t.equal(timed_out.error, Option.Some(Error.TimedOut)) + # If no internet connection is available the error is NetworkDown + # instead, so we have to account for that. + t.true( + match timed_out { + case Error(TimedOut or NetworkDown) -> true + case _ -> false + }, + ) + Result.Ok(nil) }) t.ok('TcpClient.local_address', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) let local_addr = try stream.local_address - t.equal(local_addr.address, '127.0.0.1') + t.equal(local_addr.ip, IpAddress.v4(127, 0, 0, 1)) t.true(local_addr.port > 0) Result.Ok(nil) }) t.ok('TcpClient.peer_address', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) let peer_addr = try stream.peer_address - t.equal(peer_addr.address, addr.address) + t.equal(peer_addr.ip, addr.ip) t.equal(peer_addr.port, addr.port) Result.Ok(nil) }) t.ok('TcpClient.read', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) let bytes = ByteArray.new let client = try listener.accept @@ -664,13 +668,13 @@ fn pub tests(t: mut Tests) { }) t.ok('TcpClient.write_bytes', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) let connection = try listener.accept let bytes = ByteArray.new @@ -681,13 +685,13 @@ fn pub tests(t: mut Tests) { }) t.ok('TcpClient.write_string', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) let connection = try listener.accept let bytes = ByteArray.new @@ -698,25 +702,25 @@ fn pub tests(t: mut Tests) { }) t.ok('TcpClient.flush', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) stream.flush }) t.ok('TcpClient.shutdown_read', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) t.equal(stream.shutdown_read, Result.Ok(nil)) @@ -728,13 +732,13 @@ fn pub tests(t: mut Tests) { }) t.ok('TcpClient.shutdown_write', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) t.equal(stream.shutdown_write, Result.Ok(nil)) t.true(stream.write_string('ping').error?) @@ -742,13 +746,13 @@ fn pub tests(t: mut Tests) { }) t.ok('TcpClient.shutdown shuts down the writing half', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) t.equal(stream.shutdown, Result.Ok(nil)) t.true(stream.write_string('ping').error?) @@ -756,13 +760,13 @@ fn pub tests(t: mut Tests) { }) t.ok('TcpClient.shutdown shuts down the reading half', fn (t) { - let listener = try Socket.ipv4(Type.STREAM) + let listener = try Socket.stream(ipv6: false) try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) t.equal(stream.shutdown, Result.Ok(nil)) @@ -774,13 +778,13 @@ fn pub tests(t: mut Tests) { }) t.ok('TcpClient.try_clone', fn (t) { - let listener = Socket.ipv4(Type.STREAM).get + let listener = Socket.stream(ipv6: false).get try listener.bind(ip: IpAddress.v4(127, 0, 0, 1), port: 0) try listener.listen let addr = try listener.local_address - let client = try TcpClient.new(addr.ip.get, addr.port) + let client = try TcpClient.new(addr.ip, addr.port) let clone = try client.try_clone t.true(clone.socket.no_delay?) @@ -795,14 +799,14 @@ fn pub tests(t: mut Tests) { let listener = TcpServer.new(ip: ip, port: 0).get let addr = listener.local_address.get - t.true(TcpServer.new(addr.ip.get, addr.port).ok?) + t.true(TcpServer.new(addr.ip, addr.port).ok?) }) t.ok('TcpServer.accept', fn (t) { let ip = IpAddress.V4(Ipv4Address.new(127, 0, 0, 1)) let listener = try TcpServer.new(ip: ip, port: 0) let addr = try listener.local_address - let stream = try TcpClient.new(addr.ip.get, addr.port) + let stream = try TcpClient.new(addr.ip, addr.port) let connection = try listener.accept t.equal(connection.local_address, stream.peer_address) @@ -815,7 +819,7 @@ fn pub tests(t: mut Tests) { let listener = TcpServer.new(ip: ip, port: 0).get let addr = listener.local_address.get - t.equal(addr.address, '127.0.0.1') + t.equal(addr.ip, IpAddress.v4(127, 0, 0, 1)) t.true(addr.port > 0) }) @@ -830,78 +834,69 @@ fn pub tests(t: mut Tests) { }) t.test('UnixAddress.to_path', fn (t) { - t.equal( - UnixAddress.new('foo.sock'.to_path).to_path, - Option.Some('foo.sock'.to_path), - ) - t.true(UnixAddress.new('\0foo'.to_path).to_path.none?) - t.true(UnixAddress.new(''.to_path).to_path.none?) + t.equal(UnixAddress('foo.sock').to_path, Option.Some('foo.sock'.to_path)) + t.true(UnixAddress('\0foo').to_path.none?) + t.true(UnixAddress('').to_path.none?) }) t.test('UnixAddress.to_string', fn (t) { - t.equal(UnixAddress.new('foo.sock'.to_path).to_string, 'foo.sock') - t.equal(UnixAddress.new('\0foo'.to_path).to_string, '\0foo') - t.equal(UnixAddress.new(''.to_path).to_string, '') + t.equal(UnixAddress('foo.sock').to_string, 'foo.sock') + t.equal(UnixAddress('\0foo').to_string, '\0foo') + t.equal(UnixAddress('').to_string, '') }) t.test('UnixAddress.abstract?', fn (t) { - t.false(UnixAddress.new(''.to_path).abstract?) - t.false(UnixAddress.new('foo.sock'.to_path).abstract?) - t.true(UnixAddress.new('\0foo'.to_path).abstract?) + t.false(UnixAddress('').abstract?) + t.false(UnixAddress('foo.sock').abstract?) + t.true(UnixAddress('\0foo').abstract?) }) t.test('UnixAddress.unnamed?', fn (t) { - t.false(UnixAddress.new('foo.sock'.to_path).unnamed?) - t.false(UnixAddress.new('\0foo'.to_path).unnamed?) - t.true(UnixAddress.new(''.to_path).unnamed?) + t.false(UnixAddress('foo.sock').unnamed?) + t.false(UnixAddress('\0foo').unnamed?) + t.true(UnixAddress('').unnamed?) }) t.test('UnixAddress.fmt', fn (t) { - t.equal(fmt(UnixAddress.new('foo.sock'.to_path)), 'foo.sock') - t.equal(fmt(UnixAddress.new('\0foo'.to_path)), '@foo') - t.equal(fmt(UnixAddress.new(''.to_path)), 'unnamed') + t.equal(fmt(UnixAddress('foo.sock')), 'foo.sock') + t.equal(fmt(UnixAddress('\0foo')), '@foo') + t.equal(fmt(UnixAddress('')), 'unnamed') }) t.test('UnixAddress.==', fn (t) { - t.equal( - UnixAddress.new('a.sock'.to_path), - UnixAddress.new('a.sock'.to_path), - ) - t.not_equal( - UnixAddress.new('a.sock'.to_path), - UnixAddress.new('b.sock'.to_path), - ) + t.equal(UnixAddress('a.sock'), UnixAddress('a.sock')) + t.not_equal(UnixAddress('a.sock'), UnixAddress('b.sock')) }) t.test('UnixAddress.to_string', fn (t) { - t.equal(UnixAddress.new('foo.sock'.to_path).to_string, 'foo.sock') - t.equal(UnixAddress.new('\0foo'.to_path).to_string, '\0foo') - t.equal(UnixAddress.new(''.to_path).to_string, '') + t.equal(UnixAddress('foo.sock').to_string, 'foo.sock') + t.equal(UnixAddress('\0foo').to_string, '\0foo') + t.equal(UnixAddress('').to_string, '') }) t.test('UnixSocket.new', fn (t) { - t.true(UnixSocket.new(Type.DGRAM).ok?) - t.true(UnixSocket.new(Type.STREAM).ok?) + t.true(UnixSocket.datagram.ok?) + t.true(UnixSocket.stream.ok?) }) t.test('UnixSocket.bind', fn (t) { - let socket1 = UnixSocket.new(Type.STREAM).get - let socket2 = UnixSocket.new(Type.STREAM).get + let socket1 = UnixSocket.stream.get + let socket2 = UnixSocket.stream.get let path = SocketPath.new(t.id) t.true(socket1.bind(path.path).ok?) t.true(socket2.bind(path.path).error?) if env.OS == 'linux' { - let socket = UnixSocket.new(Type.STREAM).get + let socket = UnixSocket.stream.get t.true(socket.bind('\0inko-test-${t.id}'.to_path).ok?) } }) t.test('UnixSocket.connect', fn (t) { - let listener = UnixSocket.new(Type.STREAM).get - let stream = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let stream = UnixSocket.stream.get let path = SocketPath.new(t.id) t.true(stream.connect(path.path).error?) @@ -912,8 +907,8 @@ fn pub tests(t: mut Tests) { t.true(stream.connect(path.path).ok?) if env.OS == 'linux' { - let listener = UnixSocket.new(Type.STREAM).get - let stream = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let stream = UnixSocket.stream.get let addr = '\0inko-test-${t.id}' listener.bind(addr.to_path).get @@ -925,7 +920,7 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.listen', fn (t) { let path = SocketPath.new(t.id) - let socket = UnixSocket.new(Type.STREAM).get + let socket = UnixSocket.stream.get socket.bind(path.path).get @@ -934,8 +929,8 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.accept', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get - let stream = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let stream = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -948,7 +943,7 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.send_string_to', fn (t) { let path = SocketPath.new(t.id) - let socket = UnixSocket.new(Type.DGRAM).get + let socket = UnixSocket.datagram.get t.equal(socket.bind(path.path), Result.Ok(nil)) t.equal(socket.send_string_to('ping', path.path), Result.Ok(4)) @@ -961,7 +956,7 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.send_bytes_to', fn (t) { let path = SocketPath.new(t.id) - let socket = UnixSocket.new(Type.DGRAM).get + let socket = UnixSocket.datagram.get t.equal(socket.bind(path.path), Result.Ok(nil)) t.equal(socket.send_bytes_to('ping'.to_byte_array, path.path), Result.Ok(4)) @@ -974,8 +969,8 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.receive_from', fn (t) { let pair = SocketPath.pair(t.id) - let listener = UnixSocket.new(Type.DGRAM).get - let client = UnixSocket.new(Type.DGRAM).get + let listener = UnixSocket.datagram.get + let client = UnixSocket.datagram.get t.equal(listener.bind(pair.0.path), Result.Ok(nil)) t.equal(client.bind(pair.1.path), Result.Ok(nil)) @@ -988,31 +983,31 @@ fn pub tests(t: mut Tests) { }) t.test('UnixSocket.local_address with an unbound socket', fn (t) { - let socket = UnixSocket.new(Type.DGRAM).get + let socket = UnixSocket.datagram.get let address = socket.local_address.get - t.equal(address, UnixAddress.new(''.to_path)) + t.equal(address, UnixAddress('')) }) t.test('UnixSocket.local_address with a bound socket', fn (t) { let path = SocketPath.new(t.id) - let socket = UnixSocket.new(Type.DGRAM).get + let socket = UnixSocket.datagram.get socket.bind(path.path).get - t.equal(socket.local_address, Result.Ok(UnixAddress.new(path.path))) + t.equal(socket.local_address, Result.Ok(UnixAddress(path.to_string))) }) t.test('UnixSocket.peer_address with a disconnected socket', fn (t) { - let socket = UnixSocket.new(Type.DGRAM).get + let socket = UnixSocket.datagram.get t.true(socket.peer_address.error?) }) t.test('UnixSocket.peer_address with a connected socket', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get - let client = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let client = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -1023,7 +1018,7 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.read', fn (t) { let path = SocketPath.new(t.id) - let socket = UnixSocket.new(Type.DGRAM).get + let socket = UnixSocket.datagram.get t.equal(socket.bind(path.path), Result.Ok(nil)) t.equal(socket.send_string_to('ping', path.path), Result.Ok(4)) @@ -1036,8 +1031,8 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.write_bytes', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get - let stream = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let stream = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -1053,8 +1048,8 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.write_string', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get - let stream = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let stream = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -1069,27 +1064,27 @@ fn pub tests(t: mut Tests) { }) t.test('UnixSocket.flush', fn (t) { - let socket = UnixSocket.new(Type.STREAM).get + let socket = UnixSocket.stream.get t.equal(socket.flush, Result.Ok(nil)) }) t.test('UnixSocket.receive_buffer_size', fn (t) { - let socket = UnixSocket.new(Type.STREAM).get + let socket = UnixSocket.stream.get t.true((socket.receive_buffer_size = 256).ok?) }) t.test('UnixSocket.send_buffer_size', fn (t) { - let socket = UnixSocket.new(Type.STREAM).get + let socket = UnixSocket.stream.get t.true((socket.send_buffer_size = 256).ok?) }) t.test('UnixSocket.shutdown_read', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get - let stream = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let stream = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -1104,8 +1099,8 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.shutdown_write', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get - let stream = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let stream = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -1117,8 +1112,8 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.shutdown shuts down the writing half', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get - let stream = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let stream = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -1130,8 +1125,8 @@ fn pub tests(t: mut Tests) { t.test('UnixSocket.shutdown shuts down the reading half', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get - let stream = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get + let stream = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -1145,14 +1140,14 @@ fn pub tests(t: mut Tests) { }) t.test('UnixSocket.try_clone', fn (t) { - let socket = UnixSocket.new(Type.STREAM).get + let socket = UnixSocket.stream.get t.true(socket.try_clone.ok?) }) t.test('UnixSocket.timeout_after=', fn (t) { let path = SocketPath.new(t.id) - let server = UnixSocket.new(Type.STREAM).get + let server = UnixSocket.stream.get server.bind(path.path).get server.listen @@ -1162,7 +1157,7 @@ fn pub tests(t: mut Tests) { }) t.test('UnixSocket.reset_deadline', fn (t) { - let socket = UnixSocket.new(Type.STREAM).get + let socket = UnixSocket.stream.get socket.timeout_after = Duration.from_secs(10) t.true(socket.deadline > 0) @@ -1215,7 +1210,7 @@ fn pub tests(t: mut Tests) { let path = SocketPath.new(t.id) let socket = UnixDatagram.new(path.path).get - t.equal(socket.local_address, Result.Ok(UnixAddress.new(path.path))) + t.equal(socket.local_address, Result.Ok(UnixAddress(path.to_string))) }) t.test('UnixDatagram.try_clone', fn (t) { @@ -1260,7 +1255,7 @@ fn pub tests(t: mut Tests) { t.test('UnixClient.new', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -1270,7 +1265,7 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.with_timeout', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get listener.bind(path.path).get listener.listen.get @@ -1286,33 +1281,33 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.local_address', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen let stream = try UnixClient.new(path.path) - t.equal(stream.local_address, Result.Ok(UnixAddress.new(''.to_path))) + t.equal(stream.local_address, Result.Ok(UnixAddress(''))) Result.Ok(nil) }) t.ok('UnixClient.peer_address', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen let stream = try UnixClient.new(path.path) - t.equal(stream.peer_address, Result.Ok(UnixAddress.new(path.path))) + t.equal(stream.peer_address, Result.Ok(UnixAddress(path.to_string))) Result.Ok(nil) }) t.ok('UnixClient.read', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen @@ -1331,7 +1326,7 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.write_bytes', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen @@ -1348,7 +1343,7 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.write_string', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen @@ -1365,7 +1360,7 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.flush', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen @@ -1375,7 +1370,7 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.shutdown_read', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen @@ -1393,7 +1388,7 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.shutdown_write', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen @@ -1407,7 +1402,7 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.shutdown shuts down the writing half', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen @@ -1421,7 +1416,7 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.shutdown shuts down the reading half', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen @@ -1439,7 +1434,7 @@ fn pub tests(t: mut Tests) { t.ok('UnixClient.try_clone', fn (t) { let path = SocketPath.new(t.id) - let listener = UnixSocket.new(Type.STREAM).get + let listener = UnixSocket.stream.get try listener.bind(path.path) try listener.listen @@ -1468,7 +1463,7 @@ fn pub tests(t: mut Tests) { let listener = UnixServer.new(path.path).get let addr = listener.local_address.get - t.equal(addr, UnixAddress.new(path.path)) + t.equal(addr, UnixAddress(path.to_string)) }) t.test('UnixServer.try_clone', fn (t) { diff --git a/std/test/std/net/test_tls.inko b/std/test/std/net/test_tls.inko new file mode 100644 index 000000000..fec0eb728 --- /dev/null +++ b/std/test/std/net/test_tls.inko @@ -0,0 +1,437 @@ +import std.crypto.pem +import std.crypto.x509 (Certificate, PrivateKey) +import std.env +import std.fmt (fmt) +import std.fs.file (ReadOnlyFile) +import std.fs.path (Path) +import std.io (Error) +import std.net.ip (IpAddress) +import std.net.socket (Socket, TcpClient, TcpServer) +import std.net.tls ( + Client, ClientConfig, Server, ServerConfig, ServerConfigError, +) +import std.test (Tests) + +class async DummyServer { + let @socket: TcpServer + let @config: ServerConfig + + fn static new -> Result[(DummyServer, Int), String] { + let conf = recover { + try ServerConfig + .new(certificate('test.pem'), private_key('test.key')) + .map_error(fn (e) { 'failed to create the server config: ${e}' }) + } + + # We bind to port zero so we get a random port number, such that we don't + # accidentally pick one that's already in use. + let sock = recover { + try TcpServer.new(ip: IpAddress.v4(0, 0, 0, 0), port: 0).map_error( + fn (e) { 'failed to start the server: ${e}' }, + ) + } + let addr = try sock.local_address.map_error(fn (e) { + 'failed to get the server address: ${e}' + }) + + Result.Ok((DummyServer(sock, conf), addr.port)) + } + + fn async hello { + # We may encounter an error such as when the client closes the connection. + # We want to ignore those so we don't terminate the test suite. + let _ = @socket.accept.then(fn (sock) { + Server.new(sock, @config).write_string('hello') + }) + } + + fn async reply { + # We may encounter an error such as when the client closes the connection. + # We want to ignore those so we don't terminate the test suite. + let _ = @socket.accept.then(fn (sock) { + let con = Server.new(sock, @config) + let bytes = ByteArray.new + + try con.read(into: bytes, size: 32) + con.write_bytes(bytes) + }) + } +} + +class async DummyClient { + fn async connect( + port: Int, + output: Channel[uni Result[Client[TcpClient], String]], + ) { + output.send(recover client(port)) + } + + fn async write( + port: Int, + output: Channel[uni Result[Client[TcpClient], String]], + ) { + let res = recover { + client(port).then(fn (s) { + s.write_string('hello').map_error(fn (e) { 'write failed: ${e}' }) + Result.Ok(s) + }) + } + + output.send(res) + } + + fn async reply(port: Int, output: Channel[uni Result[String, String]]) { + let res = recover { + client(port).then(fn (s) { + let buf = ByteArray.new + + try s.read(into: buf, size: 5).map_error(fn (e) { 'read failed: ${e}' }) + Result.Ok(buf.into_string) + }) + } + + output.send(res) + } + + fn client(port: Int) -> Result[Client[TcpClient], String] { + let conf = try client_config + let sock = try TcpClient.new(IpAddress.v4(127, 0, 0, 1), port).map_error( + fn (e) { 'failed to connect: ${e}' }, + ) + + Client.new(sock, conf, name: 'localhost').ok_or('invalid server name') + } +} + +fn connect(port: Int) -> Result[TcpClient, String] { + TcpClient.new(IpAddress.v4(127, 0, 0, 1), port).map_error(fn (e) { + 'failed to connect to the server: ${e}' + }) +} + +fn fixture(name: String) -> Path { + env + .working_directory + .or_else(fn (_) { '.'.to_path }) + .join('fixtures') + .join('tls') + .join(name) +} + +fn certificate(name: String) -> Certificate { + let path = fixture(name) + let reader = pem.Parser.new( + ReadOnlyFile.new(path.clone).or_panic("${path} doesn't exist"), + ) + + match reader.next { + case Some(Ok(Certificate(cert))) -> cert + case _ -> panic('expected ${path} to contain a certificate') + } +} + +fn private_key(name: String) -> PrivateKey { + let path = fixture(name) + let reader = pem.Parser.new( + ReadOnlyFile.new(path.clone).or_panic("${path} doesn't exist"), + ) + + match reader.next { + case Some(Ok(PrivateKey(cert))) -> cert + case _ -> panic('expected ${path} to contain a private key') + } +} + +fn client_config -> Result[ClientConfig, String] { + ClientConfig.with_certificate(certificate('test.pem')).ok_or( + 'failed to create the client configuration', + ) +} + +fn server_config -> Result[ServerConfig, String] { + ServerConfig.new(certificate('test.pem'), private_key('test.key')).map_error( + fn (e) { e.to_string }, + ) +} + +fn dummy_socket -> Result[Socket, String] { + Socket.stream(ipv6: false).map_error(fn (e) { e.to_string }) +} + +fn accept(socket: mut TcpServer) -> Result[Server[TcpClient], String] { + let conf = try server_config + let server = Server.new( + try socket.accept.map_error(fn (e) { 'accept failed: ${e}' }), + conf, + ) + + Result.Ok(server) +} + +fn listener -> Result[(TcpServer, Int), String] { + let sock = try TcpServer.new(ip: IpAddress.v4(0, 0, 0, 0), port: 0).map_error( + fn (e) { 'failed to start the server: ${e}' }, + ) + let addr = try sock.local_address.map_error(fn (e) { + 'failed to get the server address: ${e}' + }) + + Result.Ok((sock, addr.port)) +} + +fn pub tests(t: mut Tests) { + t.test('ServerConfigError.to_string', fn (t) { + t.false(ServerConfigError.InvalidCertificate.to_string.empty?) + t.false(ServerConfigError.InvalidPrivateKey.to_string.empty?) + }) + + t.test('ServerConfigError.fmt', fn (t) { + t.equal(fmt(ServerConfigError.InvalidCertificate), 'InvalidCertificate') + t.equal(fmt(ServerConfigError.InvalidPrivateKey), 'InvalidPrivateKey') + }) + + t.test('ServerConfigError.==', fn (t) { + t.equal( + ServerConfigError.InvalidCertificate, + ServerConfigError.InvalidCertificate, + ) + t.equal( + ServerConfigError.InvalidPrivateKey, + ServerConfigError.InvalidPrivateKey, + ) + }) + + t.test('ClientConfig.with_certificate with a valid certificate', fn (t) { + t.true(ClientConfig.with_certificate(certificate('test.pem')).some?) + }) + + t.test('ClientConfig.with_certificate with an invalid certificate', fn (t) { + t.true(ClientConfig.with_certificate(certificate('invalid.pem')).none?) + }) + + t.ok('ClientConfig.clone', fn (t) { + let a = try client_config + let b = a.clone + + t.equal(a.raw as Int, b.raw as Int) + Result.Ok(nil) + }) + + t.test('ServerConfig.new with a valid certificate and private key', fn (t) { + let cert = certificate('test.pem') + let key = private_key('test.key') + + t.true(ServerConfig.new(cert, key).ok?) + }) + + t.test('ServerConfig with an invalid certificate', fn (t) { + let key = private_key('test.key') + + t.equal( + ServerConfig.new(certificate('invalid.pem'), key).error, + Option.Some(ServerConfigError.InvalidCertificate), + ) + }) + + t.test('ServerConfig with an invalid private key', fn (t) { + let cert = certificate('test.pem') + + t.equal( + ServerConfig.new(cert, private_key('invalid.key')).error, + Option.Some(ServerConfigError.InvalidPrivateKey), + ) + }) + + t.ok('ServerConfig.clone', fn (t) { + let a = try server_config + let b = a.clone + + t.equal(b.raw as Int, a.raw as Int) + Result.Ok(nil) + }) + + t.ok('Client.new with a valid DNS name as the server name', fn (t) { + let conf = try client_config + let sock = try dummy_socket + + t.true(Client.new(sock, conf, name: 'localhost').some?) + Result.Ok(nil) + }) + + t.ok('Client.new with a valid IP address as the server name', fn (t) { + let conf = try client_config + let sock = try dummy_socket + + t.true(Client.new(sock, conf, name: '127.0.0.1').some?) + Result.Ok(nil) + }) + + t.ok('Client.new with an invalid DNS name as the server name', fn (t) { + let conf = try client_config + let sock = try dummy_socket + + t.true(Client.new(sock, conf, name: 'what?!').none?) + Result.Ok(nil) + }) + + t.ok('Client.new with an invalid IP address as the server name', fn (t) { + let conf = try client_config + let sock = try dummy_socket + + t.true(Client.new(sock, conf, name: '1.2.3.4.5').none?) + Result.Ok(nil) + }) + + t.ok('Client.read', fn (t) { + let port = match DummyServer.new { + case Ok((server, port)) -> { + server.hello + port + } + case Error(e) -> throw e + } + + let conf = try client_config + let sock = try connect(port) + let client = Client.new(sock, conf, name: 'localhost').get + let bytes = ByteArray.new + + t.equal(client.read(into: bytes, size: 5), Result.Ok(5)) + t.equal(bytes.into_string, 'hello') + Result.Ok(nil) + }) + + t.ok('Client.write_bytes', fn (t) { + let port = match DummyServer.new { + case Ok((server, port)) -> { + server.reply + port + } + case Error(e) -> throw e + } + + let conf = try client_config + let sock = try connect(port) + let client = Client.new(sock, conf, name: 'localhost').get + let bytes = ByteArray.new + + t.equal(client.write_bytes('ping'.to_byte_array), Result.Ok(nil)) + t.equal(client.read(into: bytes, size: 4), Result.Ok(4)) + t.equal(bytes.into_string, 'ping') + Result.Ok(nil) + }) + + t.ok('Client.write_string', fn (t) { + let port = match DummyServer.new { + case Ok((server, port)) -> { + server.reply + port + } + case Error(e) -> throw e + } + + let conf = try client_config + let sock = try connect(port) + let client = Client.new(sock, conf, name: 'localhost').get + let bytes = ByteArray.new + + t.equal(client.write_string('ping'), Result.Ok(nil)) + t.equal(client.read(into: bytes, size: 4), Result.Ok(4)) + t.equal(bytes.into_string, 'ping') + Result.Ok(nil) + }) + + t.ok('Client.close', fn (t) { + let port = match DummyServer.new { + case Ok((server, port)) -> { + server.reply + port + } + case Error(e) -> throw e + } + + let conf = try client_config + let sock = try connect(port) + let client = Client.new(sock, conf, name: 'localhost').get + + t.equal(client.close, Result.Ok(nil)) + Result.Ok(nil) + }) + + # Not much to test here, so this test mostly exists to make sure the + # underlying code doesn't blow up outright. + t.ok('Server.new', fn (t) { + let conf = try server_config + let sock = try dummy_socket + let server = Server.new(sock, conf) + + t.not_equal(server.state as Int, 0) + Result.Ok(nil) + }) + + t.ok('Server.close', fn (t) { + let out = Channel.new(size: 1) + let sock = match try listener { + case (sock, port) -> { + DummyClient().connect(port, out) + sock + } + } + let server = try accept(sock) + + # There's not really a sensible way we can test close_notify handling as + # it's timing/ordering sensitive, so we just test that the initial call + # doesn't fail outright. + t.equal(server.close, Result.Ok(nil)) + t.true(out.receive.ok?) + Result.Ok(nil) + }) + + t.ok('Server.read', fn (t) { + let out = Channel.new(size: 1) + let sock = match try listener { + case (sock, port) -> { + DummyClient().write(port, out) + sock + } + } + let server = try accept(sock) + let bytes = ByteArray.new + + t.equal(server.read(into: bytes, size: 5), Result.Ok(5)) + t.equal(bytes.into_string, 'hello') + t.true(out.receive.ok?) + + Result.Ok(nil) + }) + + t.ok('Server.write_string', fn (t) { + let out = Channel.new(size: 1) + let sock = match try listener { + case (sock, port) -> { + DummyClient().reply(port, out) + sock + } + } + let server = try accept(sock) + + t.equal(server.write_string('hello'), Result.Ok(nil)) + t.equal(recover out.receive, Result.Ok('hello')) + Result.Ok(nil) + }) + + t.ok('Server.write_bytes', fn (t) { + let out = Channel.new(size: 1) + let sock = match try listener { + case (sock, port) -> { + DummyClient().reply(port, out) + sock + } + } + let server = try accept(sock) + + t.equal(server.write_bytes('hello'.to_byte_array), Result.Ok(nil)) + t.equal(recover out.receive, Result.Ok('hello')) + Result.Ok(nil) + }) +} diff --git a/std/test/std/test_env.inko b/std/test/std/test_env.inko index f704802b0..0afadfb3e 100644 --- a/std/test/std/test_env.inko +++ b/std/test/std/test_env.inko @@ -44,6 +44,20 @@ fn pub tests(t: mut Tests) { } }) + t.fork( + 'env.home_directory with a missing home directory', + child: fn { + let out = STDOUT.new + let res = env.home_directory.map(fn (v) { v.to_string }).or('ERROR') + + out.write_string(res) + }, + test: fn (test, proc) { + proc.variable('HOME', '') + test.equal(proc.spawn.stdout, 'ERROR') + }, + ) + t.test('env.working_directory', fn (t) { let path = env.working_directory.get diff --git a/std/test/std/test_io.inko b/std/test/std/test_io.inko index ebcb77548..503f81709 100644 --- a/std/test/std/test_io.inko +++ b/std/test/std/test_io.inko @@ -1,5 +1,8 @@ import std.fmt (fmt) -import std.io (Buffer, BufferedReader, DEFAULT_BUFFER_SIZE, Error, Read, Write) +import std.io ( + Buffer, BufferedReader, DEFAULT_BUFFER_SIZE, Error, INVALID_DATA, Read, + UNEXPECTED_EOF, Write, +) import std.libc.bsd.errors if bsd import std.libc.linux.errors if linux import std.libc.mac.errors if mac @@ -114,7 +117,7 @@ fn pub tests(t: mut Tests) { t.equal(Error.from_os_error(errors.EADDRINUSE), Error.AddressInUse) t.equal(Error.from_os_error(errors.EADDRNOTAVAIL), Error.AddressUnavailable) t.equal(Error.from_os_error(errors.ENETDOWN), Error.NetworkDown) - t.equal(Error.from_os_error(errors.ENETUNREACH), Error.NetworkUnreachable) + t.equal(Error.from_os_error(errors.ENETUNREACH), Error.NetworkDown) t.equal(Error.from_os_error(errors.ECONNABORTED), Error.ConnectionAborted) t.equal(Error.from_os_error(errors.ECONNRESET), Error.ConnectionReset) t.equal(Error.from_os_error(errors.EISCONN), Error.AlreadyConnected) @@ -123,6 +126,8 @@ fn pub tests(t: mut Tests) { t.equal(Error.from_os_error(errors.ECONNREFUSED), Error.ConnectionRefused) t.equal(Error.from_os_error(errors.EHOSTUNREACH), Error.HostUnreachable) t.equal(Error.from_os_error(errors.EINPROGRESS), Error.InProgress) + t.equal(Error.from_os_error(INVALID_DATA), Error.InvalidData) + t.equal(Error.from_os_error(UNEXPECTED_EOF), Error.EndOfInput) t.equal(Error.from_os_error(999), Error.Other(999)) }) @@ -146,7 +151,6 @@ fn pub tests(t: mut Tests) { t.equal(fmt(Error.InvalidSeek), 'InvalidSeek') t.equal(fmt(Error.IsADirectory), 'IsADirectory') t.equal(fmt(Error.NetworkDown), 'NetworkDown') - t.equal(fmt(Error.NetworkUnreachable), 'NetworkUnreachable') t.equal(fmt(Error.NotADirectory), 'NotADirectory') t.equal(fmt(Error.NotConnected), 'NotConnected') t.equal(fmt(Error.NotFound), 'NotFound') @@ -157,6 +161,8 @@ fn pub tests(t: mut Tests) { t.equal(fmt(Error.StorageFull), 'StorageFull') t.equal(fmt(Error.TimedOut), 'TimedOut') t.equal(fmt(Error.WouldBlock), 'WouldBlock') + t.equal(fmt(Error.InvalidData), 'InvalidData') + t.equal(fmt(Error.EndOfInput), 'EndOfInput') t.equal(fmt(Error.Other(999)), 'Other(999)') }) @@ -169,6 +175,26 @@ fn pub tests(t: mut Tests) { t.equal(bytes, ByteArray.from_array([1, 2, 3])) }) + t.test('Read.read_exact', fn (t) { + let reader = Reader.new + let bytes = ByteArray.new + + t.equal(reader.read_exact(into: bytes, size: 3), Result.Ok(nil)) + t.equal(bytes, ByteArray.from_array([1, 2, 3])) + + t.equal( + reader.read_exact(into: bytes, size: 3), + Result.Error(Error.EndOfInput), + ) + t.equal(bytes, ByteArray.from_array([1, 2, 3])) + + reader.index = 0 + t.equal( + reader.read_exact(into: bytes, size: 6), + Result.Error(Error.EndOfInput), + ) + }) + t.test('Write.print', fn (t) { let writer = Writer.new diff --git a/std/test/std/test_option.inko b/std/test/std/test_option.inko index ddf09c4a7..c6374118f 100644 --- a/std/test/std/test_option.inko +++ b/std/test/std/test_option.inko @@ -97,4 +97,17 @@ fn pub tests(t: mut Tests) { t.equal(fmt(Option.Some(42)), 'Some(42)') t.equal(fmt(Option.None as Option[Int]), 'None') }) + + t.test('Option.ok_or', fn (t) { + t.equal(Option.Some(10).ok_or('oops'), Result.Ok(10)) + t.equal((Option.None as Option[Int]).ok_or('oops'), Result.Error('oops')) + }) + + t.test('Option.ok_or_else', fn (t) { + t.equal(Option.Some(10).ok_or_else(fn { 'oops' }), Result.Ok(10)) + t.equal( + (Option.None as Option[Int]).ok_or_else(fn { 'oops' }), + Result.Error('oops'), + ) + }) } diff --git a/std/test/std/test_optparse.inko b/std/test/std/test_optparse.inko index d918a35f5..2bdafd9f1 100644 --- a/std/test/std/test_optparse.inko +++ b/std/test/std/test_optparse.inko @@ -535,7 +535,8 @@ fn pub tests(t: mut Tests) { lorem ipsum --example -x Foo' - .strip_prefix('\n'), + .strip_prefix('\n') + .get, ) opts.single('o', 'option-with-a-much-longer-name', '', 'Example') @@ -552,7 +553,8 @@ fn pub tests(t: mut Tests) { --example -x Foo -o, --option-with-a-much-longer-name Example' - .strip_prefix('\n'), + .strip_prefix('\n') + .get, ) }) diff --git a/std/test/std/test_string.inko b/std/test/std/test_string.inko index ae852cc0e..a09095ca2 100644 --- a/std/test/std/test_string.inko +++ b/std/test/std/test_string.inko @@ -299,27 +299,27 @@ fn pub tests(t: mut Tests) { }) t.test('String.strip_prefix', fn (t) { - t.equal('hello'.strip_prefix('xxxxxxxxx'), 'hello') - t.equal('hello'.strip_prefix('x'), 'hello') - t.equal('hello'.strip_prefix(''), 'hello') - t.equal('XhelloX'.strip_prefix('x'), 'XhelloX') - t.equal('xhellox'.strip_prefix('xy'), 'xhellox') - t.equal('xhellox'.strip_prefix('y'), 'xhellox') - t.equal('xhellox'.strip_prefix('x'), 'hellox') - t.equal('xxhelloxx'.strip_prefix('xx'), 'helloxx') - t.equal('๐Ÿ˜ƒhello๐Ÿ˜ƒ'.strip_prefix('๐Ÿ˜ƒ'), 'hello๐Ÿ˜ƒ') + t.equal('hello'.strip_prefix('xxxxxxxxx'), Option.None) + t.equal('hello'.strip_prefix('x'), Option.None) + t.equal('hello'.strip_prefix(''), Option.None) + t.equal('XhelloX'.strip_prefix('x'), Option.None) + t.equal('xhellox'.strip_prefix('xy'), Option.None) + t.equal('xhellox'.strip_prefix('y'), Option.None) + t.equal('xhellox'.strip_prefix('x'), Option.Some('hellox')) + t.equal('xxhelloxx'.strip_prefix('xx'), Option.Some('helloxx')) + t.equal('๐Ÿ˜ƒhello๐Ÿ˜ƒ'.strip_prefix('๐Ÿ˜ƒ'), Option.Some('hello๐Ÿ˜ƒ')) }) t.test('String.strip_suffix', fn (t) { - t.equal('hello'.strip_suffix('xxxxxxxxx'), 'hello') - t.equal('hello'.strip_suffix('x'), 'hello') - t.equal('hello'.strip_suffix(''), 'hello') - t.equal('XhelloX'.strip_suffix('x'), 'XhelloX') - t.equal('xhellox'.strip_suffix('xy'), 'xhellox') - t.equal('xhellox'.strip_suffix('y'), 'xhellox') - t.equal('xhellox'.strip_suffix('x'), 'xhello') - t.equal('xxhelloxx'.strip_suffix('xx'), 'xxhello') - t.equal('๐Ÿ˜ƒhello๐Ÿ˜ƒ'.strip_suffix('๐Ÿ˜ƒ'), '๐Ÿ˜ƒhello') + t.equal('hello'.strip_suffix('xxxxxxxxx'), Option.None) + t.equal('hello'.strip_suffix('x'), Option.None) + t.equal('hello'.strip_suffix(''), Option.None) + t.equal('XhelloX'.strip_suffix('x'), Option.None) + t.equal('xhellox'.strip_suffix('xy'), Option.None) + t.equal('xhellox'.strip_suffix('y'), Option.None) + t.equal('xhellox'.strip_suffix('x'), Option.Some('xhello')) + t.equal('xxhelloxx'.strip_suffix('xx'), Option.Some('xxhello')) + t.equal('๐Ÿ˜ƒhello๐Ÿ˜ƒ'.strip_suffix('๐Ÿ˜ƒ'), Option.Some('๐Ÿ˜ƒhello')) }) t.test('String.trim_start', fn (t) {