<Ct: rust> [ 2023-05-07 ] Rust Advanced - Async Trait 에 관하여
Rust
TOKIO-RS 의 AsyncRead 부분을 구현하는 쪽을 이용하여 STABLE 채널 [ 현재 1.65 ] 에서
async trait 을 구현해보도록 한다.
소스코드 출처 : https://github.com/tokio-rs/tokio
이하 소스코드는 바로 빌드 & 테스트 가능하도록 짜집기 해두었다.
의존성으로,
pin-project-lite
를 필요로하다.
src/lib.rs
use std::{
io::{ self, IoSlice },
ops::DerefMut,
pin::Pin,
task::{ Context, Poll },
future::Future,
marker::{ PhantomPinned, Unpin },
};
use pin_project_lite::pin_project;
use tokio::io::ReadBuf;
// ? : MACRO Poll:Ready(t) => t
macro_rules! ready {
($e:expr $(,)?) => {
match $e {
std::task::Poll::Ready(t) => t,
std::task::Poll::Pending => return std::task::Poll::Pending,
}
};
}
// ? : 이친구가 read.await 이런식으로 쓰일 애인가봄
// LIFETIME 'a
// R : AsyncRead + Unpin + ?Sized
// Return : Read<'a, R>
pub(crate) fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> Read<'a, R>
where
R: AsyncRead + Unpin + ?Sized,
{
Read {
reader,
buf,
_pin: PhantomPinned,
}
}
// TODO : 여기 봐줘야됨~
pin_project! {
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Read<'a, R: ?Sized> {
reader: &'a mut R,
buf: &'a mut [u8],
#[pin]
_pin: PhantomPinned,
}
}
impl<R> Future for Read<'_, R>
where
R: AsyncRead + Unpin + ?Sized,
{
type Output = io::Result<usize>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
let me = self.project();
let mut buf = ReadBuf::new(me.buf);
ready!(Pin::new(me.reader).poll_read(cx, &mut buf))?;
Poll::Ready(Ok(buf.filled().len()))
}
}
pub trait AsyncRead {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>>;
}
macro_rules! deref_async_read {
() => {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut **self).poll_read(cx, buf)
}
};
}
impl<T: ?Sized + AsyncRead + Unpin> AsyncRead for Box<T>
{
deref_async_read!();
}
impl<T: ?Sized + AsyncRead + Unpin> AsyncRead for &mut T
{
deref_async_read!();
}
impl<P> AsyncRead for Pin<P>
where
P: DerefMut + Unpin,
P::Target: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_read(cx, buf)
}
}
impl AsyncRead for &[u8]
{
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let amt = std::cmp::min(self.len(), buf.remaining());
let (a, b) = self.split_at(amt);
buf.put_slice(a);
*self = b;
Poll::Ready(Ok(()))
}
}
impl<T: AsRef<[u8]> + Unpin> AsyncRead for io::Cursor<T>
{
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let pos = self.position();
let slice: &[u8] = (*self).get_ref().as_ref();
if pos > slice.len() as u64 {
return Poll::Ready(Ok(()));
}
let start = pos as usize;
let amt = std::cmp::min(slice.len() - start, buf.remaining());
let end = start + amt;
buf.put_slice(&slice[start..end]);
self.set_position(end as u64);
Poll::Ready(Ok(()))
}
}
#[cfg(test)]
mod tests {
#[tokio::test]
async fn test_fn() {
println!("ASDFASDFASDF");
let mut buf = [0; 1024];
let mut reader = std::io::Cursor::new(&b"hello world"[..]);
let n = super::read(&mut reader, &mut buf).await.unwrap();
assert_eq!(&buf[..n], b"hello world");
}
}
ㅂ..분석
결국 read 라는 함수 하나를 만들기 위해 이런 짓을 해야된다는것,
use std::{
io,
future::Future,
pin::Pin,
marker::{ PhantomPinned, Unpin },
ops::DerefMut,
task::{ Context, Poll },
};
use pin_project_lite::pin_project;
use tokio::io::ReadBuf;
use util::ready;
// NOTE : ASYNC 구현체 시작
// TYPE : async 함수 본체
pub(crate) fn read<'a, R>(reader: &'a mut R, buf: &'a mut [u8]) -> MyAsync<'a, R>
where
R: AsyncTrait + Unpin + ?Sized,
{
MyAsync {
reader,
buf,
_pin: PhantomPinned,
}
}
// TYPE : async trait 을 impl 당한 구조체
pin_project! {
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct MyAsync<'a, R: ?Sized> {
// TYPE : .project() 를 통해 이쪽으로 접근 가능
reader: &'a mut R,
buf: &'a mut [u8],
#[pin]
_pin: PhantomPinned,
}
}
// TYPE : async function을 가지고있는 trait
pub trait AsyncTrait {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>>;
}
// NOTE : 여기까지가 ASYNC 구현체
// ======================================================================================================
// TYPE : 여기는 .await 를 통해 호출되는 함수를 구현하는 곳인듯함
impl<R> Future for MyAsync<'_, R>
where
R: AsyncTrait + Unpin + ?Sized,
{
type Output = io::Result<usize>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
let me = self.project();
let mut buf = ReadBuf::new(me.buf);
ready!(Pin::new(me.reader).poll_read(cx, &mut buf))?;
Poll::Ready(Ok(buf.filled().len()))
}
}
// TYPE : 여기는 매크로를 이용해, async_trait 을 구현하는 곳인듯함
macro_rules! deref_async_read {
() => {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut **self).poll_read(cx, buf)
}
};
}
// TYPE : 여기는 Box<T> 타입을 다룰때에 대해 ?
impl<T: ?Sized + AsyncTrait + Unpin> AsyncTrait for Box<T>
{
deref_async_read!();
}
// TYPE : 여기는.. &mut T 타입을 다룰때에 ?
impl<T: ?Sized + AsyncTrait + Unpin> AsyncTrait for &mut T
{
deref_async_read!();
}
// TYPE : 여기는 Pin<P> 타입을 다룰때에 ?
impl<P> AsyncTrait for Pin<P>
where
P: DerefMut + Unpin,
P::Target: AsyncTrait,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.get_mut().as_mut().poll_read(cx, buf)
}
}
// TYPE : 여기는 반환 타입이 &[u8] 이고 Cursor<T> 의 input 일때에 대한 impl
impl<T: AsRef<[u8]> + Unpin> AsyncTrait for io::Cursor<T>
{
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let pos = self.position();
let slice: &[u8] = (*self).get_ref().as_ref();
if pos > slice.len() as u64 {
return Poll::Ready(Ok(()));
}
let start = pos as usize;
let amt = std::cmp::min(slice.len() - start, buf.remaining());
let end = start + amt;
buf.put_slice(&slice[start..end]);
self.set_position(end as u64);
Poll::Ready(Ok(()))
}
}
// NOTE : 아마 여기서 사용할 수 있게 해주는것같은데...?
#[cfg(test)]
mod tests {
#[tokio::test]
async fn test_fn() {
println!("ASDFASDFASDF");
let mut buf = [0; 1024];
let mut reader = std::io::Cursor::new(&b"hello world"[..]);
let n = super::read(&mut reader, &mut buf).await.unwrap();
assert_eq!(&buf[..n], b"hello world");
}
}