Skip to content

Commit

Permalink
WIP: implement RFC 8305 for TcpClient
Browse files Browse the repository at this point in the history
This fixes #795.

Changelog: added
  • Loading branch information
yorickpeterse committed Jan 2, 2025
1 parent 888832e commit 18bae54
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 42 deletions.
244 changes: 213 additions & 31 deletions std/src/std/net/socket.inko
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,22 @@
#
# For more information about timeouts versus deadlines, consider reading [this
# article](https://vorpus.org/blog/timeouts-and-cancellation-for-humans/).
import std.cmp (Equal)
import std.cmp (Equal, max, min)
import std.drop (Drop)
import std.fmt (Format, Formatter)
import std.fs.path (Path)
import std.io (Error, Read, Write, WriteInternal)
import std.iter (Stream)
import std.libc
import std.net.ip (IpAddress)
import std.string (ToString)
import std.sync (Channel, Future, Promise)
import std.sys.net
import std.sys.unix.net (self as sys) if unix
import std.time (Duration, ToInstant)
import std.time (Duration, Instant, ToInstant)

# TODO: remove
import std.stdio (Stdout)

# The maximum value valid for a listen() call.
#
Expand Down Expand Up @@ -101,12 +106,103 @@ fn pub try_ips[R, E](
Result.Error(last.or_panic('at least one IP address must be specified'))
}

# Returns an iterator that yields IP addresses in alternating order, starting
# with an IPv6 address.
fn interleave_ips(ips: ref Array[IpAddress]) -> Stream[IpAddress] {
let mut v6_idx = 0
let mut v4_idx = 0
let mut v6 = true

Stream.new(fn move {
if v6 := v6.false? {
loop {
match ips.opt(v6_idx := v6_idx + 1) {
case Some(V6(ip)) -> return Option.Some(IpAddress.V6(ip))
case Some(_) -> {}
case _ -> break
}
}
}

loop {
match ips.opt(v4_idx := v4_idx + 1) {
case Some(V4(ip)) -> return Option.Some(IpAddress.V4(ip))
case Some(_) -> {}
case _ -> return Option.None
}
}
})
}

trait RawSocketOperations {
fn mut raw_socket -> Pointer[net.RawSocket]

fn raw_deadline -> Int
}

type async Connector {
let @ip: IpAddress
let @port: Int
let @deadline: Instant
let @output: Channel[Result[TcpClient, Error]]
let @run: Bool

fn static new(
ip: IpAddress,
port: Int,
deadline: Instant,
output: uni Channel[Result[TcpClient, Error]],
) -> Connector {
Connector(ip: ip, port: port, deadline: deadline, output: output, run: true)
}

fn async mut cancel {
@run = false
}

fn async retry {
if @run.false? {
@output.send(recover Result.Error(Error.TimedOut))
return
}

match connect {
case Some(v) -> @output.send(v)
case _ -> retry
}
}

fn async initial(promise: uni Promise[Result[TcpClient, Error]]) {
match connect {
case Some(v) -> promise.set(v)
case _ -> retry
}
}

fn connect -> Option[uni Result[TcpClient, Error]] {
# It's not unlikely that `deadline` is set to something large, e.g. 60
# seconds into the future. This means that if one socket successfully
# connects, others would just sit around for up to 60 seconds.
#
# To prevent this from happening we enforce an internal deadline that's
# smaller than the one provided, rescheduling the current process if there's
# still time left.
let wait = min(Duration.from_secs(1).to_instant, @deadline)

let res = recover {
match TcpClient.connect(@ip, @port, wait) {
case Ok(v) -> Result.Ok(v)
case Error(TimedOut) if @deadline.remaining.to_nanos > 0 -> {
return Option.None
}
case Error(e) -> Result.Error(e)
}
}

Option.Some(res)
}
}

# An IPv4 or IPv6 socket address.
type pub copy SocketAddress {
# The IPv4/IPv6 address of this socket address.
Expand Down Expand Up @@ -811,24 +907,114 @@ type pub TcpClient {
Result.Ok(TcpClient(socket))
}

# Creates a new `TcpClient` that's connected to an IP address and port number.
fn static connect(
ip: IpAddress,
port: Int,
timeout_after: Instant,
) -> Result[TcpClient, Error] {
let socket = try Socket.stream(ip.v6?)

socket.timeout_after = timeout_after
try socket.connect(ip, port)
socket.reset_deadline
from(socket)
}

fn static connect_parallel(
ips: ref Array[IpAddress],
port: Int,
timeout_after: Instant,
) -> Result[TcpClient, Error] {
let size = ips.size

match size {
case 0 -> panic('at least one IP address must be provided')
case 1 -> return connect(ips.get(0), port, timeout_after)
case _ -> {}
}

# TODO: sort IPs.
# let ips = interleave_ips(ips)
let ips = ips.iter
let chan = Channel.new
let mut last_err = Error.ConnectionRefused
let procs = []
let out = Stdout.new

loop {
let fut = match ips.next {
case Some(ip) -> {
let con = Connector.new(ip, port, timeout_after, recover chan.clone)

match Future.new {
case (fut, prom) -> {
con.initial(prom)
procs.push(con)
fut
}
}
}
case _ -> break
}

# For the initial connection we use a Future instead of the Channel,
# ensuring that previously scheduled sockets producing a result doesn't
# mess with us waiting for the result of the current socket.
#
# TODO: it's possible to run into the following: socket A takes 260 msec
# to connect, so we move to socket B. Using the current approach, if all
# other sockets also fail we won't observe socket A until the very end.
#
# What we need to do is separate the OK and error streams, and in this
# loop we only care about the OK streams.
match fut.get_until(Duration.from_millis(250)) {
case Ok(Ok(v)) -> {
if procs.size > 1 { procs.into_iter.each(fn (c) { c.cancel }) }

return Result.Ok(v)
}
case Ok(Error(e)) -> last_err = e
case _ -> {}
}
}

loop {
out.print('waiting for remaining connection...')

match chan.receive_until(timeout_after) {
case Some(Ok(v)) -> {
if procs.size > 1 { procs.into_iter.each(fn (c) { c.cancel }) }

return Result.Ok(v)
}
case Some(Error(e)) -> {
out.print('received error: ${e}')
last_err = e
}
case _ -> {
procs.into_iter.each(fn (c) { c.cancel })
break
}
}
}

Result.Error(last_err)
}

# Creates a new `TcpClient` that's connected to an IP address and port number,
# using a default timeout.
#
# If multiple IP addresses are given, this method attempts to connect to them
# in order, returning upon the first successful connection. If no connection
# can be established, the error of the last attempt is returned.
# This method uses a default timeout of 60 seconds. If you wish to use a
# custom timeout/deadline, use `TcpClient.with_timeout` instead.
#
# This method doesn't enforce a deadline on establishing the connection. If
# you need to limit the amount of time spent waiting to establish the
# connection, use `TcpClient.with_timeout` instead.
# For more details, refer to the documentation of `TcpClient.with_timeout`.
#
# # Panics
#
# This method panics if `ips` is empty.
#
# # Examples
#
# Connecting a `TcpClient`:
#
# ```inko
# import std.net.socket (TcpClient)
# import std.net.ip (IpAddress)
Expand All @@ -839,25 +1025,28 @@ type pub TcpClient {
ips: ref Array[IpAddress],
port: Int,
) -> Result[TcpClient, Error] {
try_ips(ips, fn (ip) {
let socket = try Socket.stream(ip.v6?)

try socket.connect(ip, port)
from(socket)
})
with_timeout(ips, port, Duration.from_secs(60))
}

# Creates a new `TcpClient` but limits the amount of time spent waiting for
# the connection to be established.
#
# If multiple IP addresses are given, this method attempts to connect to them
# in order, returning upon the first successful connection. If no connection
# can be established, the error of the last attempt is returned.
#
# The `timeout_after` argument specifies the deadline after which the
# `connect()` times out. The deadline is cleared once connected.
# `connect()` system call times out. This deadline is _not_ inherited by the
# returned `TcpClient`.
#
# # Connecting to multiple IP addresses
#
# If multiple IP addresses are given, this method attempts to connect to them
# in accordance with [RFC 8305](https://datatracker.ietf.org/doc/html/rfc8305)
# (also known as "Happy Eyeballs version 2"), with the following differences:
#
# See `TcpClient.new` for more information.
# - DNS requests are performed separately and thus not subject to the Happy
# Eyeballs algorithm.
# - We always interleave IPv6 and IPv4 addresses, starting with an IPv6
# address (so `IPv6, IPv4, IPv6, IPv4, ...`).
# - There's no way to configure this behavior, nor is it planned to add the
# ability to do so.
#
# # Panics
#
Expand All @@ -883,14 +1072,7 @@ type pub TcpClient {
port: Int,
timeout_after: ref T,
) -> Result[TcpClient, Error] {
try_ips(ips, fn (ip) {
let socket = try Socket.stream(ip.v6?)

socket.timeout_after = timeout_after
try socket.connect(ip, port)
socket.reset_deadline
from(socket)
})
connect_parallel(ips, port, timeout_after.to_instant)
}

# Returns the local address of this socket.
Expand Down
28 changes: 17 additions & 11 deletions std/src/std/time.inko
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import std.fmt (Format, Formatter)
import std.int (ToInt)
import std.locale (Locale)
import std.locale.en (Locale as English)
import std.ops (Add, Multiply, Subtract)
import std.ops (Add, Divide, Multiply, Subtract)
import std.string (Bytes)
import std.sys.unix.time (self as sys) if unix

Expand Down Expand Up @@ -270,31 +270,37 @@ impl Clone[Duration] for Duration {
}

impl Add[Duration, Duration] for Duration {
fn pub inline +(other: ref Duration) -> Duration {
fn pub inline +(other: Duration) -> Duration {
Duration(@nanos + other.nanos)
}
}

impl Subtract[Duration, Duration] for Duration {
fn pub inline -(other: ref Duration) -> Duration {
fn pub inline -(other: Duration) -> Duration {
Duration(@nanos - other.nanos)
}
}

impl Multiply[Int, Duration] for Duration {
fn pub inline *(other: ref Int) -> Duration {
fn pub inline *(other: Int) -> Duration {
Duration(@nanos * other)
}
}

impl Divide[Int, Duration] for Duration {
fn pub inline /(other: Int) -> Duration {
Duration(@nanos / other)
}
}

impl Compare[Duration] for Duration {
fn pub inline cmp(other: ref Duration) -> Ordering {
fn pub inline cmp(other: Duration) -> Ordering {
@nanos.cmp(other.nanos)
}
}

impl Equal[ref Duration] for Duration {
fn pub inline ==(other: ref Duration) -> Bool {
impl Equal[Duration] for Duration {
fn pub inline ==(other: Duration) -> Bool {
@nanos == other.nanos
}
}
Expand Down Expand Up @@ -1325,7 +1331,7 @@ impl Add[Duration, DateTime] for DateTime {
#
# This method may panic if the result can't be expressed as a `DateTime` (e.g.
# the year is too great).
fn pub +(other: ref Duration) -> DateTime {
fn pub +(other: Duration) -> DateTime {
let timestamp = to_float + other.to_secs

DateTime.from_timestamp(timestamp, utc_offset: @utc_offset).get
Expand All @@ -1340,7 +1346,7 @@ impl Subtract[Duration, DateTime] for DateTime {
#
# This method may panic if the result can't be expressed as a `DateTime` (e.g.
# the year is too great).
fn pub -(other: ref Duration) -> DateTime {
fn pub -(other: Duration) -> DateTime {
let timestamp = to_float - other.to_secs

DateTime.from_timestamp(timestamp, utc_offset: @utc_offset).get
Expand Down Expand Up @@ -1459,7 +1465,7 @@ impl ToFloat for Instant {
}

impl Add[Duration, Instant] for Instant {
fn pub inline +(other: ref Duration) -> Instant {
fn pub inline +(other: Duration) -> Instant {
let nanos = @nanos + other.nanos

if nanos < 0 { negative_time_error(nanos) }
Expand All @@ -1469,7 +1475,7 @@ impl Add[Duration, Instant] for Instant {
}

impl Subtract[Duration, Instant] for Instant {
fn pub inline -(other: ref Duration) -> Instant {
fn pub inline -(other: Duration) -> Instant {
let nanos = @nanos - other.nanos

if nanos < 0 { negative_time_error(nanos) }
Expand Down

0 comments on commit 18bae54

Please sign in to comment.