diff --git a/Cargo.toml b/Cargo.toml index 784c14d86..6c806de24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,14 +45,14 @@ members = [ futures-core = { version = "0.3", default-features = false } futures-sink = { version = "0.3", default-features = false } futures-util = { version = "0.3", default-features = false } -tokio-util = { version = "0.3.1", features = ["codec"] } -tokio = { version = "0.2", features = ["io-util"] } +tokio-util = { version = "0.4.0", features = ["codec"] } +tokio = { version = "0.3", features = ["io-util"] } bytes = "0.5.2" http = "0.2" tracing = { version = "0.1.13", default-features = false, features = ["std", "log"] } tracing-futures = { version = "0.2", default-features = false, features = ["std-future"]} fnv = "1.0.5" -slab = "0.4.0" +slab = "0.4.2" indexmap = "1.0" [dev-dependencies] @@ -68,9 +68,9 @@ serde = "1.0.0" serde_json = "1.0.0" # Examples -tokio = { version = "0.2", features = ["dns", "macros", "rt-core", "sync", "tcp"] } +tokio = { version = "0.3", features = ["rt-multi-thread", "macros", "sync", "net"] } env_logger = { version = "0.5.3", default-features = false } -rustls = "0.16" -tokio-rustls = "0.12.0" +rustls = "0.18" +tokio-rustls = "0.20.0" webpki = "0.21" webpki-roots = "0.17" diff --git a/examples/server.rs b/examples/server.rs index 1753b7a2e..777f4ea14 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -8,7 +8,7 @@ use tokio::net::{TcpListener, TcpStream}; async fn main() -> Result<(), Box> { let _ = env_logger::try_init(); - let mut listener = TcpListener::bind("127.0.0.1:5928").await?; + let listener = TcpListener::bind("127.0.0.1:5928").await?; println!("listening on {:?}", listener.local_addr()); diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs index 201bba26c..53032ce23 100644 --- a/src/codec/framed_write.rs +++ b/src/codec/framed_write.rs @@ -3,13 +3,10 @@ use crate::codec::UserError::*; use crate::frame::{self, Frame, FrameSize}; use crate::hpack; -use bytes::{ - buf::{BufExt, BufMutExt}, - Buf, BufMut, BytesMut, -}; +use bytes::{buf::BufMutExt, Buf, BufMut, BytesMut}; use std::pin::Pin; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use std::io::{self, Cursor}; @@ -193,12 +190,26 @@ where match self.next { Some(Next::Data(ref mut frame)) => { tracing::trace!(queued_data_frame = true); - let mut buf = (&mut self.buf).chain(frame.payload_mut()); - ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut buf))?; + + if self.buf.has_remaining() { + let n = + ready!(Pin::new(&mut self.inner).poll_write(cx, self.buf.bytes()))?; + self.buf.advance(n); + } + + let buf = frame.payload_mut(); + + if !self.buf.has_remaining() && buf.has_remaining() { + let n = ready!(Pin::new(&mut self.inner).poll_write(cx, buf.bytes()))?; + buf.advance(n); + } } _ => { tracing::trace!(queued_data_frame = false); - ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut self.buf))?; + let n = ready!( + Pin::new(&mut self.inner).poll_write(cx, &mut self.buf.bytes()) + )?; + self.buf.advance(n); } } } @@ -290,25 +301,13 @@ impl FramedWrite { } impl AsyncRead for FramedWrite { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) - } - fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { Pin::new(&mut self.inner).poll_read(cx, buf) } - - fn poll_read_buf( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut Buf, - ) -> Poll> { - Pin::new(&mut self.inner).poll_read_buf(cx, buf) - } } // We never project the Pin to `B`. diff --git a/src/server.rs b/src/server.rs index 3c093f7ee..32433121a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -127,7 +127,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; use std::{convert, fmt, io, mem}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing_futures::{Instrument, Instrumented}; /// In progress HTTP/2.0 connection handshake future. @@ -1158,8 +1158,10 @@ where let mut rem = PREFACE.len() - self.pos; while rem > 0 { - let n = ready!(Pin::new(self.inner_mut()).poll_read(cx, &mut buf[..rem])) + let mut buf = ReadBuf::new(&mut buf[..rem]); + ready!(Pin::new(self.inner_mut()).poll_read(cx, &mut buf)) .map_err(crate::Error::from_io)?; + let n = buf.filled().len(); if n == 0 { return Poll::Ready(Err(crate::Error::from_io(io::Error::new( io::ErrorKind::UnexpectedEof, @@ -1167,7 +1169,7 @@ where )))); } - if PREFACE[self.pos..self.pos + n] != buf[..n] { + if &PREFACE[self.pos..self.pos + n] != buf.filled() { proto_err!(conn: "read_preface: invalid preface"); // TODO: Should this just write the GO_AWAY frame directly? return Poll::Ready(Err(Reason::PROTOCOL_ERROR.into())); diff --git a/tests/h2-fuzz/Cargo.toml b/tests/h2-fuzz/Cargo.toml index 8bb121959..40e985de6 100644 --- a/tests/h2-fuzz/Cargo.toml +++ b/tests/h2-fuzz/Cargo.toml @@ -12,4 +12,4 @@ env_logger = { version = "0.5.3", default-features = false } futures = { version = "0.3", default-features = false, features = ["std"] } honggfuzz = "0.5" http = "0.2" -tokio = { version = "0.2", features = [] } +tokio = { version = "0.3", features = [] } diff --git a/tests/h2-support/Cargo.toml b/tests/h2-support/Cargo.toml index c4e68b1ee..183013f14 100644 --- a/tests/h2-support/Cargo.toml +++ b/tests/h2-support/Cargo.toml @@ -12,5 +12,5 @@ tracing = "0.1" tracing-subscriber = { version = "0.2", default-features = false, features = ["fmt", "chrono", "ansi"] } futures = { version = "0.3", default-features = false } http = "0.2" -tokio = { version = "0.2", features = ["time"] } -tokio-test = "0.2" +tokio = { version = "0.3", features = ["time"] } +tokio-test = "0.3" diff --git a/tests/h2-support/src/mock.rs b/tests/h2-support/src/mock.rs index 08837fa56..ebfc094c1 100644 --- a/tests/h2-support/src/mock.rs +++ b/tests/h2-support/src/mock.rs @@ -6,7 +6,7 @@ use h2::{self, RecvError, SendError}; use futures::future::poll_fn; use futures::{ready, Stream, StreamExt}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use super::assert::assert_frame_eq; use std::pin::Pin; @@ -147,10 +147,11 @@ impl Handle { poll_fn(move |cx| { while buf.has_remaining() { let res = Pin::new(self.codec.get_mut()) - .poll_write_buf(cx, &mut buf) + .poll_write(cx, &mut buf.bytes()) .map_err(|e| panic!("write err={:?}", e)); - ready!(res).unwrap(); + let n = ready!(res).unwrap(); + buf.advance(n); } Poll::Ready(()) @@ -294,8 +295,8 @@ impl AsyncRead for Handle { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { Pin::new(self.codec.get_mut()).poll_read(cx, buf) } } @@ -344,10 +345,10 @@ impl AsyncRead for Mock { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { assert!( - buf.len() > 0, + buf.remaining() > 0, "attempted read with zero length buffer... wut?" ); @@ -355,18 +356,18 @@ impl AsyncRead for Mock { if me.rx.is_empty() { if me.closed { - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } me.rx_task = Some(cx.waker().clone()); return Poll::Pending; } - let n = cmp::min(buf.len(), me.rx.len()); - buf[..n].copy_from_slice(&me.rx[..n]); + let n = cmp::min(buf.remaining(), me.rx.len()); + buf.put_slice(&me.rx[..n]); me.rx.drain(..n); - Poll::Ready(Ok(n)) + Poll::Ready(Ok(())) } } @@ -427,10 +428,10 @@ impl AsyncRead for Pipe { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { + buf: &mut ReadBuf, + ) -> Poll> { assert!( - buf.len() > 0, + buf.remaining() > 0, "attempted read with zero length buffer... wut?" ); @@ -438,18 +439,18 @@ impl AsyncRead for Pipe { if me.tx.is_empty() { if me.closed { - return Poll::Ready(Ok(0)); + return Poll::Ready(Ok(())); } me.tx_task = Some(cx.waker().clone()); return Poll::Pending; } - let n = cmp::min(buf.len(), me.tx.len()); - buf[..n].copy_from_slice(&me.tx[..n]); + let n = cmp::min(buf.remaining(), me.tx.len()); + buf.put_slice(&me.tx[..n]); me.tx.drain(..n); - Poll::Ready(Ok(n)) + Poll::Ready(Ok(())) } } @@ -479,5 +480,5 @@ impl AsyncWrite for Pipe { } pub async fn idle_ms(ms: u64) { - tokio::time::delay_for(Duration::from_millis(ms)).await + tokio::time::sleep(Duration::from_millis(ms)).await } diff --git a/tests/h2-tests/Cargo.toml b/tests/h2-tests/Cargo.toml index 4c711fe24..b5f3c6eeb 100644 --- a/tests/h2-tests/Cargo.toml +++ b/tests/h2-tests/Cargo.toml @@ -11,4 +11,4 @@ edition = "2018" h2-support = { path = "../h2-support" } tracing = "0.1.13" futures = { version = "0.3", default-features = false, features = ["alloc"] } -tokio = { version = "0.2", features = ["macros", "tcp"] } +tokio = { version = "0.3", features = ["macros", "net", "rt", "io-util"] } diff --git a/tests/h2-tests/tests/codec_read.rs b/tests/h2-tests/tests/codec_read.rs index 95e895ddd..fe3cfea97 100644 --- a/tests/h2-tests/tests/codec_read.rs +++ b/tests/h2-tests/tests/codec_read.rs @@ -190,6 +190,7 @@ async fn read_continuation_frames() { #[tokio::test] async fn update_max_frame_len_at_rest() { use futures::StreamExt; + use tokio::io::AsyncReadExt; h2_support::trace_init!(); // TODO: add test for updating max frame length in flight as well? @@ -211,6 +212,10 @@ async fn update_max_frame_len_at_rest() { codec.next().await.unwrap().unwrap_err().to_string(), "frame with invalid size" ); + + // drain codec buffer + let mut buf = Vec::new(); + codec.get_mut().read_to_end(&mut buf).await.unwrap(); } #[tokio::test] diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs index 4b6fe7a85..08019bbae 100644 --- a/tests/h2-tests/tests/flow_control.rs +++ b/tests/h2-tests/tests/flow_control.rs @@ -972,7 +972,7 @@ async fn settings_lowered_capacity_returns_capacity_to_connection() { // // A timeout is used here to avoid blocking forever if there is a // failure - let result = select(rx2, tokio::time::delay_for(Duration::from_secs(5))).await; + let result = select(rx2, tokio::time::sleep(Duration::from_secs(5))).await; if let Either::Right((_, _)) = result { panic!("Timed out"); } @@ -1004,7 +1004,7 @@ async fn settings_lowered_capacity_returns_capacity_to_connection() { }); // Wait for server handshake to complete. - let result = select(rx1, tokio::time::delay_for(Duration::from_secs(5))).await; + let result = select(rx1, tokio::time::sleep(Duration::from_secs(5))).await; if let Either::Right((_, _)) = result { panic!("Timed out"); } diff --git a/tests/h2-tests/tests/hammer.rs b/tests/h2-tests/tests/hammer.rs index cf7051814..9a200537a 100644 --- a/tests/h2-tests/tests/hammer.rs +++ b/tests/h2-tests/tests/hammer.rs @@ -26,8 +26,8 @@ impl Server { { let mk_data = Arc::new(mk_data); - let mut rt = tokio::runtime::Runtime::new().unwrap(); - let mut listener = rt + let rt = tokio::runtime::Runtime::new().unwrap(); + let listener = rt .block_on(TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))) .unwrap(); let addr = listener.local_addr().unwrap(); @@ -140,7 +140,7 @@ fn hammer_client_concurrency() { }) }); - let mut rt = tokio::runtime::Runtime::new().unwrap(); + let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(tcp); println!("...done"); }