diff --git a/src/tools.rs b/src/tools.rs index 186e6a9e..e1909364 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -16,6 +16,7 @@ use serde_json::Value; use proxmox::tools::vec; pub mod acl; +pub mod async_io; pub mod async_mutex; pub mod borrow; pub mod daemon; diff --git a/src/tools/async_io.rs b/src/tools/async_io.rs new file mode 100644 index 00000000..2ce01a68 --- /dev/null +++ b/src/tools/async_io.rs @@ -0,0 +1,111 @@ +//! Generic AsyncRead/AsyncWrite utilities. + +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::io::{AsyncRead, AsyncWrite}; + +pub enum EitherStream { + Left(L), + Right(R), +} + +impl AsyncRead for EitherStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut [u8], + ) -> Poll> { + match unsafe { self.get_unchecked_mut() } { + EitherStream::Left(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf) + } + EitherStream::Right(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf) + } + } + } + + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + match *self { + EitherStream::Left(ref s) => s.prepare_uninitialized_buffer(buf), + EitherStream::Right(ref s) => s.prepare_uninitialized_buffer(buf), + } + } + + fn poll_read_buf( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut B, + ) -> Poll> + where + B: bytes::BufMut, + { + match unsafe { self.get_unchecked_mut() } { + EitherStream::Left(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_read_buf(cx, buf) + } + EitherStream::Right(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_read_buf(cx, buf) + } + } + } +} + +impl AsyncWrite for EitherStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + match unsafe { self.get_unchecked_mut() } { + EitherStream::Left(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf) + } + EitherStream::Right(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf) + } + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match unsafe { self.get_unchecked_mut() } { + EitherStream::Left(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_flush(cx) + } + EitherStream::Right(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_flush(cx) + } + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match unsafe { self.get_unchecked_mut() } { + EitherStream::Left(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx) + } + EitherStream::Right(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx) + } + } + } + + fn poll_write_buf( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut B, + ) -> Poll> + where + B: bytes::Buf, + { + match unsafe { self.get_unchecked_mut() } { + EitherStream::Left(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_write_buf(cx, buf) + } + EitherStream::Right(ref mut s) => { + unsafe { Pin::new_unchecked(s) }.poll_write_buf(cx, buf) + } + } + } +}