Skip to content

Commit

Permalink
WIP: TLS sockets
Browse files Browse the repository at this point in the history
This fixes #329.

Changelog: added
  • Loading branch information
yorickpeterse committed Jul 11, 2024
1 parent 131c0e1 commit 81210b4
Show file tree
Hide file tree
Showing 21 changed files with 802 additions and 212 deletions.
422 changes: 328 additions & 94 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion compiler/src/linker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,13 @@ pub(crate) fn link(
cmd.arg("-lm");
cmd.arg("-lpthread");
}
_ => {}
OperatingSystem::Mac => {
// This is needed for TLS support.
for name in ["Security", "CoreFoundation"] {
cmd.arg("-framework");
cmd.arg(name);
}
}
}

let mut static_linking = state.config.static_linking;
Expand Down
13 changes: 13 additions & 0 deletions rt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ unicode-segmentation = "^1.10"
backtrace = "^0.3"
rustix = { version = "^0.38", features = ["fs", "mm", "param", "process", "net", "std", "time", "event"], default-features = false }

# The dependencies needed for TLS support.
#
# We use ring instead of the default aws-lc-sys because:
#
# 1. aws-lc-sys requires cmake to be installed when building on FreeBSD (and
# potentially other platforms), as aws-lc-sys only provides generated
# bindings for a limited set of platforms
# 2. aws-lc-sys increases compile times quite a bit
# 3. We don't care about FIPS compliance at the time of writing
rustls = { version = "^0.23", features = ["ring", "tls12", "std"], default-features = false }
rustls-platform-verifier = "^0.3"
rustls-pemfile = "^2.1"

[dependencies.socket2]
version = "^0.5"
features = ["all"]
11 changes: 11 additions & 0 deletions rt/src/network_poller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ pub(crate) type NetworkPoller = sys::Poller;
pub(crate) enum Interest {
Read,
Write,
ReadWrite,
}

impl Interest {
pub(crate) fn new(read: bool, write: bool) -> Interest {
match (read, write) {
(true, true) => Interest::ReadWrite,
(false, true) => Interest::Write,
_ => Interest::Read,
}
}
}

/// A thread that polls a poller and reschedules processes.
Expand Down
1 change: 1 addition & 0 deletions rt/src/network_poller/epoll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ fn flags_for(interest: Interest) -> EventFlags {
let flags = match interest {
Interest::Read => EventFlags::IN,
Interest::Write => EventFlags::OUT,
Interest::ReadWrite => EventFlags::IN | EventFlags::OUT,
};

flags | EventFlags::ET | EventFlags::ONESHOT
Expand Down
24 changes: 15 additions & 9 deletions rt/src/network_poller/kqueue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,24 @@ impl Poller {
source: impl AsFd,
interest: Interest,
) {
let fd = source.as_fd().as_raw_fd();
let (add, del) = match interest {
Interest::Read => (EventFilter::Read(fd), EventFilter::Write(fd)),
Interest::Write => (EventFilter::Write(fd), EventFilter::Read(fd)),
};
let id = process.identifier() as isize;
let fd = source.as_fd().as_raw_fd();
let flags =
EventFlags::CLEAR | EventFlags::ONESHOT | EventFlags::RECEIPT;
let events = [
Event::new(add, EventFlags::ADD | flags, id),
Event::new(del, EventFlags::DELETE, 0),
];
let events = match interest {
Interest::Read => [
Event::new(EventFilter::Read(fd), EventFlags::ADD | flags, id),
Event::new(EventFilter::Write(fd), EventFlags::DELETE, 0),
],
Interest::Write => [
Event::new(EventFilter::Write(fd), EventFlags::ADD | flags, id),
Event::new(EventFilter::Read(fd), EventFlags::DELETE, 0),
],
Interest::ReadWrite => [
Event::new(EventFilter::Write(fd), EventFlags::ADD | flags, id),
Event::new(EventFilter::Read(fd), EventFlags::ADD | flags, id),
],
};

self.apply(&events);
}
Expand Down
2 changes: 1 addition & 1 deletion rt/src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl Result {
}

pub(crate) fn io_error(error: io::Error) -> Result {
Self::error({ error_to_int(error) } as _)
Self::error(error_to_int(error) as _)
}
}

Expand Down
6 changes: 6 additions & 0 deletions rt/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ pub unsafe extern "system" fn inko_runtime_new(
// does for us when compiling an executable.
signal_sched::block_all();

// Configure the TLS provider. This must be done once before we start the
// program.
rustls::crypto::ring::default_provider()
.install_default()
.expect("failed to set up the default TLS cryptography provider");

Box::into_raw(Box::new(Runtime::new(&*counts, args)))
}

Expand Down
20 changes: 0 additions & 20 deletions rt/src/runtime/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,6 @@ pub unsafe extern "system" fn inko_env_size(state: *const State) -> i64 {
(*state).environment.len() as _
}

#[no_mangle]
pub unsafe extern "system" fn inko_env_home_directory(
state: *const State,
) -> InkoResult {
let state = &*state;

// Rather than performing all sorts of magical incantations to get the home
// directory, we're just going to require that HOME is set.
//
// If the home is explicitly set to an empty string we still ignore it,
// because there's no scenario in which Some("") is useful.
state
.environment
.get("HOME")
.filter(|&path| !path.is_empty())
.cloned()
.map(|v| InkoResult::ok(InkoString::alloc(state.string_class, v) as _))
.unwrap_or_else(InkoResult::none)
}

#[no_mangle]
pub unsafe extern "system" fn inko_env_temp_directory(
state: *const State,
Expand Down
200 changes: 188 additions & 12 deletions rt/src/runtime/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@ use crate::network_poller::Interest;
use crate::process::ProcessPointer;
use crate::result::{error_to_int, Result};
use crate::scheduler::timeouts::Timeout;
use crate::socket::Socket;
use crate::socket::{read_from, Socket};
use crate::state::State;
use std::io::{self, Write};
use rustls::pki_types::ServerName;
use rustls::{ClientConfig, ClientConnection, RootCertStore, Stream};
use rustls_pemfile::certs;
use rustls_platform_verifier::tls_config;
use std::fs::File;
use std::io::{self, BufReader, Write};
use std::ptr::{drop_in_place, write};
use std::sync::Arc;

#[repr(C)]
pub struct RawAddress {
Expand All @@ -24,19 +30,13 @@ impl RawAddress {
}
}

fn blocking<T>(
fn poll(
state: &State,
mut process: ProcessPointer,
socket: &mut Socket,
interest: Interest,
deadline: i64,
mut func: impl FnMut(&mut Socket) -> io::Result<T>,
) -> io::Result<T> {
match func(socket) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
val => return val,
}

) -> io::Result<()> {
let poll_id = unsafe { process.thread() }.network_poller;

// We must keep the process' state lock open until everything is registered,
Expand Down Expand Up @@ -72,7 +72,24 @@ fn blocking<T>(
return Err(io::Error::from(io::ErrorKind::TimedOut));
}

func(socket)
Ok(())
}

fn blocking<T>(
state: &State,
process: ProcessPointer,
socket: &mut Socket,
interest: Interest,
deadline: i64,
mut func: impl FnMut(&mut Socket) -> io::Result<T>,
) -> io::Result<T> {
match func(socket) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
poll(state, process, socket, interest, deadline)
.and_then(|_| func(socket))
}
val => val,
}
}

#[no_mangle]
Expand Down Expand Up @@ -127,7 +144,7 @@ pub unsafe extern "system" fn inko_socket_read(
let state = &*state;

blocking(state, process, &mut *socket, Interest::Read, deadline, |sock| {
sock.read(&mut (*buffer).value, amount as usize)
read_from(sock, &mut (*buffer).value, amount as usize)
})
.map(|size| Result::ok(size as _))
.unwrap_or_else(Result::io_error)
Expand Down Expand Up @@ -349,3 +366,162 @@ pub unsafe extern "system" fn inko_socket_try_clone(
pub unsafe extern "system" fn inko_socket_drop(socket: *mut Socket) {
drop_in_place(socket);
}

#[no_mangle]
pub unsafe extern "system" fn inko_tls_client_config_new() -> Result {
Result::ok(Arc::into_raw(Arc::new(tls_config())) as *mut _)
}

#[no_mangle]
pub unsafe extern "system" fn inko_tls_client_config_with_certificate(
path: *const InkoString,
) -> Result {
let mut store = RootCertStore::empty();
let mut reader = match File::open(InkoString::read(path)) {
Ok(f) => BufReader::new(f),
Err(e) => return Result::io_error(e),
};

for res in certs(&mut reader) {
match res {
// We don't want to expose a bunch of error messages/cases for the
// different reasons for a certificate being invalid, as it's not
// clear users actually care about that. As such, at least for the
// time being we just use a single opaque error for invalid
// certificates.
Ok(cert) => {
if store.add(cert).is_err() {
return Result::none();
}
}
Err(e) => return Result::io_error(e),
}
}

let conf = Arc::new(
ClientConfig::builder()
.with_root_certificates(store)
.with_no_client_auth(),
);

Result::ok(Arc::into_raw(conf) as *mut _)
}

#[no_mangle]
pub unsafe extern "system" fn inko_tls_client_config_clone(
config: *const ClientConfig,
) -> *const ClientConfig {
Arc::increment_strong_count(config);
config
}

#[no_mangle]
pub unsafe extern "system" fn inko_tls_client_config_drop(
config: *const ClientConfig,
) {
drop(Arc::from_raw(config));
}

#[no_mangle]
pub unsafe extern "system" fn inko_tls_client_connection_new(
config: *const ClientConfig,
server: *const InkoString,
) -> Result {
let name = match ServerName::try_from(InkoString::read(server)) {
Ok(v) => v,
Err(_) => return Result::error(0 as _),
};

Arc::increment_strong_count(config);

// TODO: under what circumstance does this fail?
let con = match ClientConnection::new(Arc::from_raw(config), name) {
Ok(v) => v,
Err(_) => return Result::error(1 as _),
};

Result::ok_boxed(con)
}

#[no_mangle]
pub unsafe extern "system" fn inko_tls_client_connection_drop(
connection: *mut ClientConnection,
) {
drop(Box::from_raw(connection));
}

#[no_mangle]
pub unsafe extern "system" fn inko_tls_socket_write(
state: *const State,
process: ProcessPointer,
socket: *mut Socket,
connection: *mut ClientConnection,
data: *mut u8,
size: i64,
deadline: i64,
) -> Result {
let state = &*state;
let slice = std::slice::from_raw_parts(data, size as _);
let mut stream = Stream::new(&mut *connection, &mut *socket);

loop {
match stream.write(slice) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
let interest = Interest::new(
stream.conn.wants_read(),
stream.conn.wants_write(),
);

if let Err(e) =
poll(state, process, stream.sock, interest, deadline)
{
return Result::io_error(e);
}
}
val => {
return val
.map(|v| Result::ok(v as _))
.unwrap_or_else(Result::io_error);
}
}
}
}

#[no_mangle]
pub unsafe extern "system" fn inko_tls_socket_read(
state: *const State,
process: ProcessPointer,
socket: *mut Socket,
connection: *mut ClientConnection,
buffer: *mut ByteArray,
amount: i64,
deadline: i64,
) -> Result {
let state = &*state;
let buf = &mut (*buffer).value;
let mut stream = Stream::new(&mut *connection, &mut *socket);

loop {
match read_from(&mut stream, buf, amount as usize) {
Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
let interest = Interest::new(
stream.conn.wants_read(),
stream.conn.wants_write(),
);

if let Err(e) =
poll(state, process, stream.sock, interest, deadline)
{
return Result::io_error(e);
}

continue;
}
val => {
return val
.map(|v| Result::ok(v as _))
.unwrap_or_else(Result::io_error);
}
};
}
}
Loading

0 comments on commit 81210b4

Please sign in to comment.