diff --git a/src/sys/windows/named_pipe.rs b/src/sys/windows/named_pipe.rs index eb35d3797..39c42ef8c 100644 --- a/src/sys/windows/named_pipe.rs +++ b/src/sys/windows/named_pipe.rs @@ -7,14 +7,14 @@ use std::sync::{Arc, Mutex}; use std::{fmt, mem, slice}; use windows_sys::Win32::Foundation::{ - ERROR_BROKEN_PIPE, ERROR_IO_INCOMPLETE, ERROR_IO_PENDING, ERROR_NO_DATA, ERROR_PIPE_CONNECTED, - ERROR_PIPE_LISTENING, HANDLE, INVALID_HANDLE_VALUE, + ERROR_BROKEN_PIPE, ERROR_IO_INCOMPLETE, ERROR_IO_PENDING, ERROR_MORE_DATA, ERROR_NO_DATA, + ERROR_PIPE_CONNECTED, ERROR_PIPE_LISTENING, HANDLE, INVALID_HANDLE_VALUE, }; use windows_sys::Win32::Storage::FileSystem::{ ReadFile, WriteFile, FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, PIPE_ACCESS_DUPLEX, }; use windows_sys::Win32::System::Pipes::{ - ConnectNamedPipe, CreateNamedPipeW, DisconnectNamedPipe, PIPE_TYPE_BYTE, + ConnectNamedPipe, CreateNamedPipeW, DisconnectNamedPipe, PeekNamedPipe, PIPE_TYPE_BYTE, PIPE_UNLIMITED_INSTANCES, }; use windows_sys::Win32::System::IO::{ @@ -27,6 +27,8 @@ use crate::sys::windows::{Event, Handle, Overlapped}; use crate::Registry; use crate::{Interest, Token}; +const MAX_BUFFER_SZ: usize = 65536; + /// Non-blocking windows named pipe. /// /// This structure internally contains a `HANDLE` which represents the named @@ -307,6 +309,25 @@ impl Inner { Ok(transferred as usize) } } + + /// Calls the `PeekNamedPipe` function to get the remaining size of message in NamedPipe + #[inline] + unsafe fn remaining_size(&self) -> io::Result { + let mut remaining = 0; + let r = PeekNamedPipe( + self.handle.raw(), + std::ptr::null_mut(), + 0, + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut remaining, + ); + if r == 0 { + Err(io::Error::last_os_error()) + } else { + Ok(remaining as usize) + } + } } #[test] @@ -349,6 +370,7 @@ enum State { Pending(Vec, usize), Ok(Vec, usize), Err(io::Error), + InsufficientBufferSize(Vec, usize), } // Odd tokens are for named pipes @@ -535,7 +557,7 @@ impl<'a> Read for &'a NamedPipe { } // We previously read something into `data`, try to copy out some - // data. If we copy out all the data schedule a new read and + // data. If we copy out all the data, schedule a new read // otherwise store the buffer to get read later. State::Ok(data, cur) => { let n = { @@ -552,6 +574,10 @@ impl<'a> Read for &'a NamedPipe { Ok(n) } + // We scheduled another read with a bigger buffer after the first read (see `read_done`) + // This is not possible in theory, just like `State::None` case, but return would block for now. + State::InsufficientBufferSize(..) => Err(would_block()), + // Looks like an in-flight read hit an error, return that here while // we schedule a new one. State::Err(e) => { @@ -703,19 +729,26 @@ impl Inner { /// scheduled. fn schedule_read(me: &Arc, io: &mut Io, events: Option<&mut Vec>) -> bool { // Check to see if a read is already scheduled/completed - match io.read { - State::None => {} - _ => return true, - } + + let mut buf = match mem::replace(&mut io.read, State::None) { + State::None => me.get_buffer(), + State::InsufficientBufferSize(mut buf, rem) => { + let sz_rem = std::cmp::min(rem, MAX_BUFFER_SZ); + buf.reserve_exact(sz_rem); + buf + } + e @ _ => { + io.read = e; + return true; + } + }; // Allocate a buffer and schedule the read. - let mut buf = me.get_buffer(); let e = unsafe { let overlapped = me.read.as_ptr() as *mut _; let slice = slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.capacity()); - me.read_overlapped(slice, overlapped) + me.read_overlapped(&mut slice[buf.len()..], overlapped) }; - match e { // See `NamedPipe::connect` above for the rationale behind `forget` Ok(_) => { @@ -874,9 +907,29 @@ fn read_done(status: &OVERLAPPED_ENTRY, events: Option<&mut Vec>) { match me.result(status.overlapped()) { Ok(n) => { debug_assert_eq!(status.bytes_transferred() as usize, n); - buf.set_len(status.bytes_transferred() as usize); + // Extend the len depending on the initial len is necessary + // when we call `ReadFile` again after resizing + // our internal buffer + buf.set_len(buf.len() + status.bytes_transferred() as usize); io.read = State::Ok(buf, 0); } + Err(e) if e.raw_os_error() == Some(ERROR_MORE_DATA as i32) => { + buf.set_len(status.bytes_transferred() as usize); + match me.remaining_size() { + Ok(rem) if rem == 0 => { + io.read = State::Ok(buf, 0); + } + Ok(rem) => { + io.read = State::InsufficientBufferSize(buf, rem); + Inner::schedule_read(&me, &mut io, None); + return; + } + Err(_e) => { + // When `PeekNamedPipe` encountered an error, truncate and return whatever is recoverable from the bytes + io.read = State::Ok(buf, 0); + } + } + } Err(e) => { debug_assert_eq!(status.bytes_transferred(), 0); io.read = State::Err(e); diff --git a/tests/win_named_pipe.rs b/tests/win_named_pipe.rs index e79d2fba4..a2c07e829 100644 --- a/tests/win_named_pipe.rs +++ b/tests/win_named_pipe.rs @@ -1,15 +1,25 @@ #![cfg(all(windows, feature = "os-poll", feature = "os-ext"))] +use std::ffi::OsStr; use std::fs::OpenOptions; -use std::io::{self, Read, Write}; +use std::io::{self, ErrorKind, Read, Write}; +use std::iter; +use std::os::windows::ffi::OsStrExt; use std::os::windows::fs::OpenOptionsExt; -use std::os::windows::io::{FromRawHandle, IntoRawHandle}; +use std::os::windows::io::{FromRawHandle, IntoRawHandle, RawHandle}; use std::time::Duration; use mio::windows::NamedPipe; use mio::{Events, Interest, Poll, Token}; use rand::Rng; -use windows_sys::Win32::{Foundation::ERROR_NO_DATA, Storage::FileSystem::FILE_FLAG_OVERLAPPED}; +use windows_sys::Win32::Foundation::ERROR_NO_DATA; +use windows_sys::Win32::Storage::FileSystem::{ + CreateFileW, FILE_FLAG_FIRST_PIPE_INSTANCE, FILE_FLAG_OVERLAPPED, OPEN_EXISTING, + PIPE_ACCESS_DUPLEX, +}; +use windows_sys::Win32::System::Pipes::{ + CreateNamedPipeW, PIPE_READMODE_MESSAGE, PIPE_TYPE_MESSAGE, PIPE_UNLIMITED_INSTANCES, +}; fn _assert_kinds() { fn _assert_send() {} @@ -43,6 +53,38 @@ fn client(name: &str) -> NamedPipe { unsafe { NamedPipe::from_raw_handle(file.into_raw_handle()) } } +fn pipe_msg_mode() -> (NamedPipe, NamedPipe) { + let num: u64 = rand::thread_rng().gen(); + let name = format!(r"\\.\pipe\my-pipe-{}", num); + let name: Vec<_> = OsStr::new(&name).encode_wide().chain(Some(0)).collect(); + unsafe { + let h = CreateNamedPipeW( + name.as_ptr(), + PIPE_ACCESS_DUPLEX | FILE_FLAG_FIRST_PIPE_INSTANCE | FILE_FLAG_OVERLAPPED, + PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE, + PIPE_UNLIMITED_INSTANCES, + 65536, + 65536, + 0, + std::ptr::null_mut(), + ); + + let server = NamedPipe::from_raw_handle(h as RawHandle); + + let h = CreateFileW( + name.as_ptr(), + PIPE_ACCESS_DUPLEX, + 0, + std::ptr::null_mut(), + OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, + 0, + ); + let client = NamedPipe::from_raw_handle(h as RawHandle); + (server, client) + } +} + fn pipe() -> (NamedPipe, NamedPipe) { let (pipe, name) = server(); (pipe, client(&name)) @@ -108,6 +150,91 @@ fn write_then_read() { assert_eq!(&buf[..4], b"1234"); } +#[test] +fn read_sz_greater_than_default_buf_size() { + let (mut server, mut client) = pipe_msg_mode(); + let mut poll = t!(Poll::new()); + t!(poll.registry().register( + &mut server, + Token(0), + Interest::READABLE | Interest::WRITABLE, + )); + t!(poll.registry().register( + &mut client, + Token(1), + Interest::READABLE | Interest::WRITABLE, + )); + + let mut events = Events::with_capacity(128); + let msg = (0..4106) + .map(|e| e.to_string()) + .collect::>() + .join(""); + + t!(poll.poll(&mut events, None)); + assert_eq!(t!(client.write(msg.as_bytes())), 15314); + + loop { + t!(poll.poll(&mut events, None)); + let events = events.iter().collect::>(); + if let Some(event) = events.iter().find(|e| e.token() == Token(0)) { + if event.is_readable() { + break; + } + } + } + + let mut buf = [0; 15314]; + assert_eq!(t!(server.read(&mut buf)), 15314); + assert_eq!(&buf[..15314], msg.as_bytes()); +} + +#[test] +fn multi_read_sz_greater_than_default_buf_size() { + let (mut server, mut client) = pipe_msg_mode(); + let mut poll = t!(Poll::new()); + t!(poll.registry().register( + &mut server, + Token(0), + Interest::READABLE | Interest::WRITABLE, + )); + + std::thread::spawn(move || { + let msgs = vec!["hello".repeat(10), "world".repeat(100), "mio".repeat(1000)]; + + let mut poll = t!(Poll::new()); + t!(poll.registry().register( + &mut client, + Token(1), + Interest::READABLE | Interest::WRITABLE, + )); + let mut events = Events::with_capacity(128); + for msg in msgs.iter() { + t!(poll.poll(&mut events, None)); + t!(client.write(msg.as_bytes())); + } + }); + + let mut events = Events::with_capacity(128); + let msgs = vec!["hello".repeat(10), "world".repeat(100), "mio".repeat(1000)]; + for m in msgs.into_iter() { + let m = m.as_bytes(); + loop { + t!(poll.poll(&mut events, None)); + let events = events.iter().collect::>(); + if let Some(event) = events.iter().find(|e| e.token() == Token(0)) { + let mut buf = [0; 3000]; + let Ok(read) = server.read(&mut buf) else { + continue; + }; + assert_eq!(read, m.len()); + assert_eq!(buf[..read], *m); + break; + } + } + } +} + #[test] fn connect_before_client() { let (mut server, name) = server();