I've found that one of the best ways to understand a new concept is to start from the very beginning. Start from a place where it doesn't exist yet and recreate it yourself, learning in the process not just how it works, but why it was designed the way it was.
This isn't a practical guide to async, but hopefully some of the background knowledge it covers will help you think about asynchronous problems, or at least satisfy your curiosity, without boring you with too many details.
...It's still really long.
A Simple Web Server
We'll start our journey into async programming with the simplest of web servers, using nothing but the standard networking types in std::net. Our server just needs to accept HTTP requests and reply with a basic response. For the entirety of this post we'll ignore most of the HTTP specification and won't write any useful application code, focusing instead on the basic flow of the server.
HTTP is a text-based protocol built on top of TCP, so to start we have to accept TCP connections. We can do that by creating a TcpListener.
use std::net::TcpListener;
fn main() {
let listener = TcpListener::bind("localhost:3000").unwrap();
}
Then we listen for incoming connections, handling them one by one.
use std::net::{TcpListener, TcpStream};
use std::io;
fn main() {
// ...
loop {
let (connection, _) = listener.accept().unwrap();
if let Err(e) = handle_connection(connection) {
println!("failed to handle connection: {e}")
}
}
}
fn handle_connection(connection: TcpStream) -> io::Result<()> {
// ...
}
TCP connections are represented by the TcpStream type, a bidirectional stream of data between us and the client. It implements the Read and Write traits, abstracting away the internal details of TCP and allowing us to read or write plain old bytes.
As a server, we need to receive the HTTP request. We'll initialize a small buffer to hold the request.
fn handle_connection(connection: TcpStream) -> io::Result<()> {
let mut request = [0u8; 1024];
// ...
Ok(())
}
Then we call read on the connection, reading the request bytes into the buffer. A call to read fills the buffer with an arbitrary number of bytes, not necessarily the entire request at once, so we have to keep track of the total number of bytes we've read and call it in a loop, reading the rest of the request into the unfilled part of the buffer.
use std::io::Read;
fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
let mut read = 0;
let mut request = [0u8; 1024];
loop {
// try reading from the stream
let num_bytes = connection.read(&mut request[read..])?;
// keep track of how many bytes we've read
read += num_bytes;
}
Ok(())
}
Finally, we have to check for the byte sequence \r\n\r\n, which indicates the end of the request.
fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
let mut read = 0;
let mut request = [0u8; 1024];
loop {
// try reading from the stream
let num_bytes = connection.read(&mut request[read..])?;
// keep track of how many bytes we've read
read += num_bytes;
// have we reached the end of the request?
// (`ends_with` is simply false while fewer than 4 bytes have been read)
if request[..read].ends_with(b"\r\n\r\n") {
break;
}
}
Ok(())
}
It's also possible for read to return zero bytes, which can happen when the client disconnects. If the client disconnects without sending an entire request, we can simply return and move on to the next connection.
Again, don't worry too much about sticking to the HTTP specification, we're just trying to get something working.
fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
let mut read = 0;
let mut request = [0u8; 1024];
loop {
// try reading from the stream
let num_bytes = connection.read(&mut request[read..])?;
// the client disconnected
if num_bytes == 0 { // 👈
println!("client disconnected unexpectedly");
return Ok(());
}
// keep track of how many bytes we've read
read += num_bytes;
// have we reached the end of the request?
// (`ends_with` is simply false while fewer than 4 bytes have been read)
if request[..read].ends_with(b"\r\n\r\n") {
break;
}
}
Ok(())
}
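Because TcpStream is just something that implements Read, the loop above can be sketched as a function generic over any reader, which makes the partial-read behavior easy to try without a network. This is a minimal sketch, not the post's code: the name read_request and the Option return (None meaning read returned 0) are our own choices, and it checks for the terminator with ends_with so that a first read shorter than four bytes cannot underflow.

```rust
use std::io::{self, Read};

/// Reads into `request` until it ends with the blank line ("\r\n\r\n")
/// that terminates the HTTP headers.
///
/// Returns `Ok(Some(n))` with the total number of bytes read, or `Ok(None)`
/// if `read` returned 0 first: the client disconnected, or the buffer
/// filled up without a terminator (reading into an empty slice yields 0).
pub fn read_request<R: Read>(reader: &mut R, request: &mut [u8]) -> io::Result<Option<usize>> {
    let mut read = 0;
    loop {
        // each call may return any number of bytes, so append after what we have
        let num_bytes = reader.read(&mut request[read..])?;
        if num_bytes == 0 {
            return Ok(None);
        }
        read += num_bytes;
        // `ends_with` is false for fewer than 4 bytes, so no `read - 4` underflow
        if request[..read].ends_with(b"\r\n\r\n") {
            return Ok(Some(read));
        }
    }
}
```

Read::chain makes a convenient fake client: each read on a chain returns bytes from only one of its parts, so a request built as `(&b"GET / HTTP/1.1\r\nHost: a\r\n\r"[..]).chain(&b"\n"[..])` arrives in two pieces, with the terminator split across them, and read_request still reports all 27 bytes.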
Once we've read in the entire request, we can convert it to a string and log it to the console.
fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
let mut read = 0;
let mut request = [0u8; 1024];
loop {
// ...
}
let request = String::from_utf8_lossy(&request[..read]);
println!("{request}");
Ok(())
}
Now we have to write our response.
Just like read, a call to write may not write the entire buffer at once. We need a second loop to ensure the entire response is written to the client, with each call to write continuing from where the previous one left off.
use std::io::Write;
fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
// ...
// "Hello World!" in HTTP
let response = concat!(
"HTTP/1.1 200 OK\r\n",
"Content-Length: 12\r\n",
"Connection: close\r\n\r\n",
"Hello world!"
);
let mut written = 0;
loop {
// write the remaining response bytes
let num_bytes = connection.write(response[written..].as_bytes())?;
// the client disconnected
if num_bytes == 0 {
println!("client disconnected unexpectedly");
return Ok(());
}
written += num_bytes;
// have we written the whole response yet?
if written == response.len() {
break;
}
}
Ok(())
}
And finally, we'll call flush to ensure that the response is written to the client [1].
fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
let mut written = 0;
loop {
// ...
}
// flush the response
connection.flush()
}
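The write loop and the final flush can likewise be sketched generically over any Write implementation. Treat this as an illustration with our own names (write_response and RESPONSE are hypothetical); in real code, Write::write_all already implements this retry loop, reporting an ErrorKind::WriteZero error where this sketch returns false. Every header line here ends in \r\n, as HTTP requires, and Content-Length: 12 matches the 12-byte body.

```rust
use std::io::{self, Write};

/// "Hello world!" as a complete HTTP/1.1 response.
pub const RESPONSE: &str = concat!(
    "HTTP/1.1 200 OK\r\n",
    "Content-Length: 12\r\n",
    "Connection: close\r\n\r\n",
    "Hello world!"
);

/// Writes all of `response`, calling `write` as many times as it takes,
/// then flushes. Returns `Ok(false)` if the writer stops accepting bytes
/// (a write of 0), which the post treats as a disconnected client.
pub fn write_response<W: Write>(writer: &mut W, response: &[u8]) -> io::Result<bool> {
    let mut written = 0;
    while written < response.len() {
        // continue from where the previous call left off
        let num_bytes = writer.write(&response[written..])?;
        if num_bytes == 0 {
            return Ok(false);
        }
        written += num_bytes;
    }
    writer.flush()?;
    Ok(true)
}
```

A Vec<u8> accepts the whole response in one call, while a four-byte &mut [u8] shows the zero-write case: it takes "hell" from "hello", then reports 0.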
There you go, a working HTTP server!
$ curl localhost:3000
# => Hello world!
A Multithreaded Server
Alright, so our server works. But there's a problem.
Take a look at our accept loop.
let listener = TcpListener::bind("localhost:3000").unwrap();
loop {
let (connection, _) = listener.accept().unwrap();
if let Err(e) = handle_connection(connection) {
// ...
}
}
See the problem?
Our server can only serve a single request at a time.
Reading from and writing to a network connection isn't instantaneous; there's a lot of infrastructure between us and the user. What would happen if two users made a request to our server at the same time, or ten, or ten thousand? Obviously this isn't going to scale, so what do we do?
We have a couple of options, but by far the simplest one is to spawn some threads. Spawn a thread for each request and our server becomes infinitely faster, right?
fn main() {
// ...
loop {
let (connection, _) = listener.accept().unwrap();
// spawn a thread to handle each connection
std::thread::spawn(move || {
if let Err(e) = handle_connection(connection) {
// ...
}
});
}
}
In fact, it probably does! Maybe not infinitely, but with each request being handled in a separate thread, the potential throughput of our server increases significantly.
How exactly does that work?
On Linux, as well as most other modern operating systems, every program is run as a separate process. While it seems like every active program is running at the same time, a single CPU core can physically execute only a single task at a time, or possibly two with hyperthreading. To allow all programs to make progress, the kernel constantly switches between them, pausing one program to give another a chance to run. These context switches happen on the order of milliseconds, providing the illusion of parallelism.
The kernel scheduler can take advantage of multiple cores by distributing its workload across them. Each core manages a subset of processes [2], meaning that some programs do in fact get to run in parallel.
 cpu1 cpu2 cpu3 cpu4
|----|----|----|----|
| p1 | p3 | p5 | p7 |
|    |____|    |____|
|    |    |____|    |
|____| p4 |    | p8 |
|    |    | p6 |____|
| p2 |____|    |    |
|    | p3 |    | p7 |
|    |    |    |    |
This type of scheduling is known as preemptive multitasking. The kernel decides when one process has been running for too long and preempts it, switching to another.
Processes work well for distinct programs because the kernel ensures that separate processes aren't allowed to access each other's memory. However, this makes context switching more expensive because the kernel has to flush certain parts of memory before performing a context switch to ensure that memory is properly isolated [3].
Threads are similar to processes [4], except they are allowed to share memory with other threads in the parent process, which is what allows us to share state between threads within the same program. Scheduling works much the same way.
The key insight regarding thread-per-request is that our server is I/O bound. Most of the time inside handle_connection is not spent doing compute work; it's spent waiting to send or receive data across the network. Functions like read, write, and flush perform blocking I/O. We submit an I/O request, yielding control to the kernel, and it returns control to us when the operation completes. In the meantime, the kernel can execute other runnable threads, which is exactly what we want!
In general, most of the time it takes to serve a web request is spent waiting for other tasks to complete, like database queries or HTTP requests. Multithreading works great because we can utilize that time to handle other requests.
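A rough way to see why thread-per-request helps I/O-bound work is to stand in for blocking I/O with thread::sleep, which, like a blocking read, parks the thread and lets the kernel run something else. This is a toy model under that assumption (the function names are ours), not a benchmark of the server: ten 50 ms "requests" take at least 500 ms one after another, but close to a single 50 ms delay when each gets its own thread.

```rust
use std::thread;
use std::time::{Duration, Instant};

/// Assumption: sleeping approximates waiting on blocking I/O, since both
/// park the calling thread until the kernel wakes it up.
fn fake_blocking_io(delay: Duration) {
    thread::sleep(delay);
}

/// Handles `n` requests one after another, like the original accept loop.
pub fn one_at_a_time(n: u32, delay: Duration) -> Duration {
    let start = Instant::now();
    for _ in 0..n {
        fake_blocking_io(delay);
    }
    start.elapsed()
}

/// Spawns a thread per request, then waits for all of them to finish.
pub fn thread_per_request(n: u32, delay: Duration) -> Duration {
    let start = Instant::now();
    let handles: Vec<_> = (0..n)
        .map(|_| thread::spawn(move || fake_blocking_io(delay)))
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    start.elapsed()
}
```

The threads overlap their waiting, which is exactly the time a single-threaded server would spend idle.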
A Non-Blocking Server
It seems like threads do exactly what we need, and they're easy to use, so why not stop here?
You may have heard that threads are too heavyweight and context switching is very expensive. Nowadays, that's not really true. Modern servers can manage tens of thousands of threads without breaking a sweat.
The issue is that blocking I/O yields complete control of our program to the kernel until the requested operation completes. We have no say in when we get to run again. This is problematic because it makes two operations very difficult to model: cancellation and selection.
Imagine we wanted to implement graceful shutdown for our server. When someone hits ctrl+c, instead of killing the program abruptly, we should stop accepting new connections but still wait for any active requests to complete. Any requests that take more than thirty seconds to finish are killed as the server exits.
This poses a problem when dealing with blocking I/O. Our accept loop blocks until the next connection comes in. We can check for the ctrl+c signal before or after a new connection comes in, but if the signal is triggered during a call to accept, we have no choice but to wait until the next connection is accepted. The kernel has complete control over the execution of our program.
loop {
// check before we call `accept`
if got_ctrl_c() {
break;
}
// **what if ctrl+c happens here?**
let (connection, _) = listener.accept().unwrap();
// this won't be checked until *after* we accept a new connection
if got_ctrl_c() {
break;
}
std::thread::spawn(|| /* ... */);
}
What we want is to listen for both the incoming connection and the ctrl+c signal at the same time. Like a match statement, but for I/O operations.
loop {
// if only...
match {
ctrl_c() => {
break;
},
Ok((connection, _)) = listener.accept() => {
std::thread::spawn(|| ...);
}
}
}
And what about timing out long running requests after thirty seconds? We could set a flag that tells threads to stop, but how often would they check it? We again run into the problem that we lose control of our program during I/O and have no choice but to wait until it completes. There's really no good way of force cancelling a thread.
Problems like these are where threads and blocking I/O fall apart. Expressing event-based logic becomes very difficult when the kernel holds so much control over the execution of our program.
There are ways to accomplish this using platform-specific interfaces such as Unix signal handlers. While this approach is simple and can work well, signal handlers often become quite cumbersome to work with outside of simple use cases. By the end of the post you'll see another way of expressing complex control flow and can decide which is better for your use case.
But what if there was a way to perform I/O without yielding to the kernel?
It turns out there is a second way to perform I/O, known as non-blocking I/O. As the name suggests, a non-blocking operation never blocks the calling thread. Instead, it returns immediately, with an error if the given resource is not available.
We can switch to using non-blocking I/O by putting our TCP listener and streams into non-blocking mode.
let listener = TcpListener::bind("localhost:3000").unwrap();
listener.set_nonblocking(true).unwrap();
loop {
let (connection, _) = listener.accept().unwrap();
connection.set_nonblocking(true).unwrap();
// ...
}
Non-blocking I/O works a little differently. If the I/O request cannot be fulfilled immediately, instead of blocking, the kernel simply returns the WouldBlock error code. Despite being represented as an error, WouldBlock isn't really an error condition. It just means the operation could not be performed immediately, giving us the chance to decide what to do instead of blocking.
use std::io;
// ...
listener.set_nonblocking(true).unwrap();
loop {
let connection = match listener.accept() {
Ok((connection, _)) => connection,
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
// the operation was not performed
// ...
}
Err(e) => panic!("{e}"),
};
connection.set_nonblocking(true).unwrap();
// ...
}
Imagine we called accept when there were no incoming connections. With blocking I/O, the call to accept would have blocked until a new connection came in. Now, instead of yielding control to the kernel, WouldBlock puts the control back in our hands.
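Here is a small sketch of that pattern as a helper, try_accept (our own name), that maps WouldBlock to Ok(None). The accepted stream is switched to non-blocking mode separately because, on Linux, a socket returned by accept does not inherit O_NONBLOCK from the listener. The second function shows the control we regain: since accept returns immediately, the loop can re-check a shutdown flag (an AtomicBool standing in for ctrl+c, an assumption of this sketch) between attempts. The 1 ms sleep is our own crude stand-in for a real scheduler; without it, the loop would spin.

```rust
use std::io;
use std::net::{TcpListener, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::Duration;

/// Non-blocking accept: `Ok(None)` means "no connection right now".
pub fn try_accept(listener: &TcpListener) -> io::Result<Option<TcpStream>> {
    match listener.accept() {
        Ok((connection, _addr)) => {
            // accepted sockets don't inherit the listener's non-blocking mode
            connection.set_nonblocking(true)?;
            Ok(Some(connection))
        }
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
        Err(e) => Err(e),
    }
}

/// Accepts connections until `shutdown` is set. Assumes `listener` is already
/// non-blocking; a blocking listener would park inside `accept` and never
/// get a chance to look at the flag.
pub fn accept_until_shutdown(
    listener: &TcpListener,
    shutdown: &AtomicBool,
) -> io::Result<Vec<TcpStream>> {
    let mut accepted = Vec::new();
    while !shutdown.load(Ordering::Relaxed) {
        match try_accept(listener)? {
            Some(connection) => accepted.push(connection),
            // nothing ready: back off briefly instead of spinning
            None => thread::sleep(Duration::from_millis(1)),
        }
    }
    Ok(accepted)
}
```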
Our I/O doesn't block, great! But what do we actually do when something isn't ready?
WouldBlock is a temporary state, meaning at some point in the future the socket should become ready to read from or write to. So technically, we could just spin until the socket becomes ready.
loop {
let connection = match listener.accept() {
Ok((connection, _)) => connection,
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
continue; // 👈
}
Err(e) => panic!("{e}"),
};
}
But spinning is really just worse than blocking directly. When we block for I/O, the OS gives other threads a chance to run. So what we really need is to build some sort of scheduler for all of our tasks, doing what the operating system used to handle for us.
Let's walk things through from the beginning again.
We create a TCP listener.
let listener = TcpListener::bind("localhost:3000").unwrap();
Set it to non-blocking mode.
listener.set_nonblocking(true).unwrap();
And then start our main loop. The first thing we'll do is try accepting a new TCP connection.
// ...
loop {
match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
// ...
},
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => panic!("{e}"),
}
}
Now, we can't just continue serving that connection directly and forget about everyone else. Instead, we have to keep track of all our active connections.
// ...
let mut connections = Vec::new(); // 👈
loop {
match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
connections.push(connection); // 👈
},
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => panic!("{e}"),
}
}
But we can't keep accepting connections forever. We don't have the luxury of OS scheduling anymore, so we need to handle running a little bit of everything in every iteration of the main loop. After trying to accept once, we need to deal with the active connections.
For each connection, we have to perform whatever operation is needed to move the processing of the request forward, whether that means reading the request, or writing the response.
// ...
loop {
// try accepting a new connection
match listener.accept() {
// ...
}
// attempt to make progress on active connections
for connection in connections.iter_mut() {
// 🤔
}
}
Uhh...
If you remember the handle_connection function from before:
fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
let mut request = [0u8; 1024];
let mut read = 0;
loop {
let num_bytes = connection.read(&mut request[read..])?; // 👈
// ...
}
let request = String::from_utf8_lossy(&request[..read]);
println!("{request}");
let response = /* ... */;
let mut written = 0;
loop {
let num_bytes = connection.write(response[written..].as_bytes())?; // 👈
// ...
}
connection.flush() // 👈
}
We perform three different I/O operations: read, write, and flush. With blocking I/O we could write our code sequentially, but now we have to deal with the fact that at any point when performing I/O, we could face WouldBlock and be unable to make progress.
We can't simply drop everything and move on to the next active connection, we need to keep track of its current state in order to resume from the correct point when we come back.
We can represent the three possible states of handle_connection in an enum.
enum ConnectionState {
Read,
Write,
Flush
}
Remember, we don't need separate states for things like converting the request to a string; we only need states for places where we might encounter WouldBlock.
The Read and Write states also need to hold on to some local state for the request/response buffers and the number of bytes that have already been read/written. These used to be local variables in our function, but now we need them to persist across iterations of our main loop.
enum ConnectionState {
Read {
request: [u8; 1024],
read: usize
},
Write {
response: &'static [u8],
written: usize,
},
Flush,
}
Connections start in the Read state with an empty buffer and zero bytes read, the same variables we used to initialize at the very start of handle_connection.
// ...
let mut connections = Vec::new();
loop {
match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
let state = ConnectionState::Read { // 👈
request: [0u8; 1024],
read: 0,
};
connections.push((connection, state));
},
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => panic!("{e}"),
}
}
Now we can try to drive each connection forward from its current state.
// ...
loop {
match listener.accept() {
// ...
}
for (connection, state) in connections.iter_mut() {
if let ConnectionState::Read { request, read } = state {
// ...
}
if let ConnectionState::Write { response, written } = state {
// ...
}
if let ConnectionState::Flush = state {
// ...
}
}
}
If the connection is still in the read state, we can continue reading the request the same way we did before. The only difference is that when we receive WouldBlock, we have to move on to the next connection.
// ...
'next: for (connection, state) in connections.iter_mut() {
if let ConnectionState::Read { request, read } = state {
loop {
// try reading from the stream
match connection.read(&mut request[*read..]) {
Ok(n) => {
// keep track of how many bytes we've read
*read += n
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
// not ready yet, move on to the next connection
continue 'next; // 👈
}
Err(e) => panic!("{e}"),
}
// did we reach the end of the request?
// (`ends_with` is simply false while fewer than 4 bytes have been read)
if request[..*read].ends_with(b"\r\n\r\n") {
break;
}
}
// we're done, print the request
let request = String::from_utf8_lossy(&request[..*read]);
println!("{request}");
}
// ...
}
We also have to deal with the case where we read zero bytes. Before, we could simply return from the connection handler and the state would be cleaned up for us, but now we have to remove the connection ourselves. Because we're currently iterating through the connections list, we'll store a separate list of indices to remove after we finish.
let mut completed = Vec::new(); // 👈
'next: for (i, (connection, state)) in connections.iter_mut().enumerate() {
if let ConnectionState::Read { request, read } = state {
loop {
// try reading from the stream
match connection.read(&mut request[*read..]) {
Ok(0) => {
println!("client disconnected unexpectedly");
completed.push(i); // 👈
continue 'next;
}
Ok(n) => *read += n,
// not ready yet, move on to the next connection
Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue 'next,
Err(e) => panic!("{e}"),
}
// ...
}
// ...
}
}
// iterate in reverse order to preserve the indices
for i in completed.into_iter().rev() {
connections.remove(i); // 👈
}
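Why reverse order? Vec::remove shifts every later element down by one, so removing a lower index first would leave the higher indices in completed pointing at the wrong connections. A quick check with plain values, where remove_completed is our own name for the loop above:

```rust
/// Removes the items at `completed`, which holds indices in ascending
/// order (the order the loop above pushes them).
pub fn remove_completed<T>(items: &mut Vec<T>, completed: Vec<usize>) {
    // going from the back means each removal leaves the smaller,
    // not-yet-removed indices pointing at the same items
    for i in completed.into_iter().rev() {
        items.remove(i);
    }
}
```

Removing indices 1 and 3 from `['a', 'b', 'c', 'd', 'e']` this way leaves `['a', 'c', 'e']`, while going front to back would leave `['a', 'c', 'd']`. Since connection order doesn't matter here, Vec::swap_remove in the same descending order would also work and avoids shifting elements, at the cost of reordering the list.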
Once we finish reading the request, we have to transition into the Write state and attempt to write the response. The control flow around writing the response is very similar to reading, transitioning to the Flush state once we finish.
if let ConnectionState::Read { request, read } = state {
// ...
// move into the write state
let response = concat!(
"HTTP/1.1 200 OK\r\n",
"Content-Length: 12\r\n",
"Connection: close\r\n\r\n",
"Hello world!"
);
*state = ConnectionState::Write { // 👈
response: response.as_bytes(),
written: 0,
};
}
if let ConnectionState::Write { response, written } = state {
loop {
match connection.write(&response[*written..]) {
Ok(0) => {
println!("client disconnected unexpectedly");
completed.push(i);
continue 'next;
}
Ok(n) => {
*written += n;
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
// not ready yet, move on to the next connection
continue 'next;
}
Err(e) => panic!("{e}"),
}
// did we write the whole response yet?
if *written == response.len() {
break;
}
}
// successfully wrote the response, try flushing next
*state = ConnectionState::Flush;
}
And after we successfully flush the response, we can mark the connection as completed and have it removed from the list.
if let ConnectionState::Flush = state {
match connection.flush() {
Ok(_) => {
completed.push(i); // 👈
},
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
// not ready yet, move on to the next connection
continue 'next;
}
Err(e) => panic!("{e}"),
}
}
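To see the state machine in one place, here is a sketch of the three branches above as a single function, generic over Read + Write so it can run against an in-memory stream instead of a socket. The names drive, Progress, and FlakyStream are hypothetical: FlakyStream is a test double that fails with WouldBlock on every other call, and the returned Progress plays the role of the post's continue 'next and completed list. Each call picks up exactly where the saved state says the previous call stopped.

```rust
use std::io::{self, Read, Write};

pub const RESPONSE: &[u8] =
    b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\nConnection: close\r\n\r\nHello world!";

pub enum ConnectionState {
    Read { request: [u8; 1024], read: usize },
    Write { response: &'static [u8], written: usize },
    Flush,
}

/// Hypothetical: what one attempt to drive a connection achieved.
#[derive(Debug, PartialEq)]
pub enum Progress {
    Pending,      // hit WouldBlock; `state` records where to resume
    Done,         // response written and flushed
    Disconnected, // read or write returned 0
}

/// One pass of the main loop's body for a single connection.
pub fn drive<S: Read + Write>(conn: &mut S, state: &mut ConnectionState) -> io::Result<Progress> {
    if let ConnectionState::Read { request, read } = state {
        loop {
            match conn.read(&mut request[*read..]) {
                Ok(0) => return Ok(Progress::Disconnected),
                Ok(n) => *read += n,
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Progress::Pending),
                Err(e) => return Err(e),
            }
            if request[..*read].ends_with(b"\r\n\r\n") {
                break;
            }
        }
        *state = ConnectionState::Write { response: RESPONSE, written: 0 };
    }
    if let ConnectionState::Write { response, written } = state {
        while *written < response.len() {
            match conn.write(&response[*written..]) {
                Ok(0) => return Ok(Progress::Disconnected),
                Ok(n) => *written += n,
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(Progress::Pending),
                Err(e) => return Err(e),
            }
        }
        *state = ConnectionState::Flush;
    }
    // any path reaching this point is in the Flush state
    match conn.flush() {
        Ok(()) => Ok(Progress::Done),
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Progress::Pending),
        Err(e) => Err(e),
    }
}

/// Hypothetical in-memory stand-in for a non-blocking socket: every other
/// call fails with WouldBlock, and each successful call moves at most `chunk` bytes.
pub struct FlakyStream {
    pub input: Vec<u8>,
    pub output: Vec<u8>,
    chunk: usize,
    ready: bool,
}

impl FlakyStream {
    pub fn new(input: &[u8], chunk: usize) -> Self {
        FlakyStream { input: input.to_vec(), output: Vec::new(), chunk, ready: true }
    }
    fn tick(&mut self) -> io::Result<()> {
        self.ready = !self.ready;
        if self.ready { Ok(()) } else { Err(io::ErrorKind::WouldBlock.into()) }
    }
}

impl Read for FlakyStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.tick()?;
        let n = self.chunk.min(buf.len()).min(self.input.len());
        buf[..n].copy_from_slice(&self.input[..n]);
        self.input.drain(..n);
        Ok(n)
    }
}

impl Write for FlakyStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.tick()?;
        let n = self.chunk.min(buf.len());
        self.output.extend_from_slice(&buf[..n]);
        Ok(n)
    }
    fn flush(&mut self) -> io::Result<()> {
        self.tick()
    }
}
```

Calling drive in a loop on a FlakyStream yields several Pending results before Done, and the output still matches the full response, because no progress is lost between calls.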
That's it! Here's the new high-level flow of the server:
// bind the listener
let listener = TcpListener::bind("localhost:3000").unwrap();
listener.set_nonblocking(true).unwrap();
let mut connections = Vec::new();
loop {
// try accepting a connection
match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
// keep track of connection state
let state = ConnectionState::Read {
request: [0u8; 1024],
read: 0,
};
connections.push((connection, state));
},
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => panic!("{e}"),
}
let mut completed = Vec::new();
// try to drive connections forward
'next: for (i, (connection, state)) in connections.iter_mut().enumerate() {
if let ConnectionState::Read { request, read } = state {
// ...
*state = ConnectionState::Write { response, written };
}
if let ConnectionState::Write { response, written } = state {
// ...
*state = ConnectionState::Flush;
}
if let ConnectionState::Flush = state {
// ...
}
}
// remove any connections that completed, iterating in reverse order
// to preserve the indices
for i in completed.into_iter().rev() {
connections.remove(i);
}
}
Now that we have to manage scheduling ourselves, things are a lot more complicated.
And now for the moment of truth...
$ curl localhost:3000
# => Hello world!
It works!
A Multiplexed Server
Our server can now handle running multiple requests concurrently on a single thread. Nothing ever blocks. If some operation would have blocked, it remembers the current state and moves on to run something else, much like the kernel scheduler was doing for us. However, our new design introduces two new problems.
The first problem is that everything runs on the main thread, utilizing only a single CPU core. We're doing the best we can to use that one core efficiently, but we're still only running a single thing at a time. With threads spread across multiple cores, we could be doing much more.
There's a bigger problem though.
Our main loop isn't actually very efficient.
We're making an I/O request to the kernel for every single active connection, on every single iteration of the loop, to check if it's ready. A call to read or write, even if it returns WouldBlock and doesn't actually perform any I/O, is still a syscall. Syscalls aren't cheap. We might have 10k active connections but only 500 of them are ready. Calling read or write 10k times when only 500 of them will actually do anything is a massive waste of CPU cycles.
As the number of connections scales, our loop becomes less and less efficient, wasting more time doing useless work.
How do we fix this? With blocking I/O the kernel was able to schedule things efficiently because it knows when resources become ready. With non-blocking I/O, we don't know without checking. But checking is expensive.
What we need is an efficient way to keep track of all of our active connections, and somehow get notified when they become ready.
It turns out we aren't the first to run into this problem. Every operating system provides a solution for exactly this. On Linux, it's called epoll.
epoll(7) - I/O event notification facility
The epoll API performs a similar task to poll(2): monitoring multiple file descriptors to see if I/O is possible on any of them. The epoll API can be used either as an edge-triggered or a level-triggered interface and scales well to large numbers of watched file descriptors.
Sounds perfect! Let's try using it.
epoll is a family of Linux system calls that let us work with a set of non-blocking sockets. It isn't terribly ergonomic to use directly, so we'll be using the epoll crate, a thin wrapper around the C interface.
To start, we'll initialize an epoll instance using the create function.
// ```toml
// [dependencies]
// epoll = "4.3"
// ```
fn main() {
let epoll = epoll::create(false).unwrap(); // 👈
}
epoll::create returns a file descriptor that represents the newly created epoll instance. You can think of it as a set of file descriptors that we can add to or remove from.
In Linux/Unix, everything is considered a file. Actual files on the file system, TCP sockets, and external devices are all files that you can read from and write to. A file descriptor is an integer that represents an open "file" in the system. We'll be working with file descriptors a lot throughout the rest of the article.
The first file descriptor we have to add is the TCP listener. We can modify the epoll set with the epoll::ctl command. To add to it, we'll use the EPOLL_CTL_ADD flag.
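use epoll::{Event, Events, ControlOptions::*};
use std::net::TcpListener;
use std::os::fd::AsRawFd;
fn main() {
// create the epoll instance, as above
let epoll = epoll::create(false).unwrap();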
let listener = TcpListener::bind("localhost:3000").unwrap();
listener.set_nonblocking(true).unwrap();
// add the listener to epoll
let event = Event::new(Events::EPOLLIN, listener.as_raw_fd() as _);
epoll::ctl(epoll, EPOLL_CTL_ADD, listener.as_raw_fd(), event).unwrap(); // 👈
}
We pass in the file descriptor of the resource we are registering, the TCP listener, along with an Event. An event has two parts: the interest flags and the data field. The interest flags tell epoll which I/O events we are interested in. In the case of the TCP listener, we want to be notified when new connections come in, so we pass the EPOLLIN flag.
The data field lets us store an ID that will uniquely identify each resource. Remember, a file descriptor is a unique integer for a given file, so we can just use that. You'll see why this is important in the next step.
Now for the main loop. This time, no spinning. Instead we can call epoll::wait.
epoll_wait(2) - wait for an I/O event on an epoll file descriptor
The epoll_wait() system call waits for events on the epoll(7) instance referred to by the file descriptor epfd. The buffer pointed to by events is used to return information from the ready list about file descriptors in the interest list that have some events available.
A call to epoll_wait() will block until either:
- a file descriptor delivers an event;
- the call is interrupted by a signal handler; or
- the timeout expires.
epoll::wait is the magical part of epoll. It lets us block until any of the events we registered become ready, and tells us which ones did. Right now that's just until a new connection comes in, but soon we'll use this same call to block for read, write, and flush events that we were previously spinning for.
The fact that epoll::wait is "blocking" might put you off, but remember, it only blocks if there is nothing else to do, where previously we would have been spinning and making pointless syscalls. This idea of blocking on multiple operations simultaneously is known as I/O multiplexing.
epoll::wait accepts a list of events that it will populate with information about the file descriptors that became ready. It then returns the number of events that were added.
// ...
loop {
let mut events = [Event::new(Events::empty(), 0); 1024];
let timeout = -1; // block forever, until something happens
let num_events = epoll::wait(epoll, timeout, &mut events).unwrap(); // 👈
for event in &events[..num_events] {
// ...
}
}
Each event contains the data field associated with the resource that became ready.
for event in &events[..num_events] {
let fd = event.data as i32;
// ...
}
Remember when we used the file descriptor for the data field? We can use it to check whether the event is for the TCP listener, which means there's an incoming connection ready to accept:
for event in &events[..num_events] {
let fd = event.data as i32;
// is the listener ready?
if fd == listener.as_raw_fd() {
// try accepting a connection
match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
// ...
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => panic!("{e}"),
}
}
}
If the call still returns WouldBlock for whatever reason, we can just move on and wait for the next event.
Now we have to register the new connection in epoll, just like we did the listener.
for event in &events[..num_events] {
let fd = event.data as i32;
// is the listener ready?
if fd == listener.as_raw_fd() {
// try accepting a connection
match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
let fd = connection.as_raw_fd();
// register the connection with epoll
let event = Event::new(Events::EPOLLIN | Events::EPOLLOUT, fd as _);
epoll::ctl(epoll, EPOLL_CTL_ADD, fd, event).unwrap(); // 👈
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => panic!("{e}"),
}
}
}
This time we set both EPOLLIN and EPOLLOUT, because we are interested in both read and write events, depending on the state of the connection.
Now that we register connections, we'll get events for both the TCP listener and individual connections. We need to store connections and their states in a way that we can look up by file descriptor.
Instead of a list, we can use a HashMap.
let mut connections = HashMap::new();
loop {
// ...
'next: for event in &events[..num_events] {
let fd = event.data as i32;
// is the listener ready?
if fd == listener.as_raw_fd() {
match listener.accept() {
Ok((connection, _)) => {
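connection.set_nonblocking(true).unwrap();
// from here on, `fd` is the connection's fd, not the listener's
let fd = connection.as_raw_fd();
// ... register `fd` with epoll, as above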
let state = ConnectionState::Read {
request: [0u8; 1024],
read: 0,
};
connections.insert(fd, (connection, state)); // 👈
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => panic!("{e}"),
}
continue 'next;
}
// otherwise, a connection must be ready
let (connection, state) = connections.get_mut(&fd).unwrap(); // 👈
}
}
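Keying the map by file descriptor works because each open socket in the process gets its own descriptor number, so the listener and every accepted connection have different keys. A minimal std-only check of that, assuming a Unix target (it uses std::os::fd) and using our own helper name accept_into_map; it also makes explicit that the key must be the accepted connection's fd, not the listener's fd that arrived in the event.

```rust
use std::collections::HashMap;
use std::io;
use std::net::{TcpListener, TcpStream};
use std::os::fd::{AsRawFd, RawFd};

/// Accepts one connection and stores it under its own file descriptor,
/// returning that fd so a later event can look the connection up.
pub fn accept_into_map(
    listener: &TcpListener,
    connections: &mut HashMap<RawFd, TcpStream>,
) -> io::Result<RawFd> {
    let (connection, _addr) = listener.accept()?;
    // the accepted socket's fd, which differs from the listener's
    let fd = connection.as_raw_fd();
    connections.insert(fd, connection);
    Ok(fd)
}
```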
Once we have the ready connection and its state, we can try to make progress the exact same way we did last time. Nothing changes about the way we read from or write to the stream; the only difference is that we only ever do I/O when epoll tells us to.
Before, we had to check every single connection to see if something became ready, but now epoll handles that for us, so we avoid any useless syscalls.
// ...
// epoll told us this connection is ready
let (connection, state) = connections.get_mut(&fd).unwrap();
if let ConnectionState::Read { request, read } = state {
// connection.read...
*state = ConnectionState::Write { response, written };
}
if let ConnectionState::Write { response, written } = state {
// connection.write...
*state = ConnectionState::Flush;
}
if let ConnectionState::Flush = state {
// connection.flush...
}
Once we've finished reading, writing, and flushing the response, we remove the connection from our map and drop it, which automatically unregisters it from epoll.
for fd in completed {
let (connection, _state) = connections.remove(&fd).unwrap();
// unregister from epoll
drop(connection);
}
And that's it! Here's the new high-level flow of our server:
// create epoll
let epoll = epoll::create(false).unwrap();
// bind the listener
let listener = /* ... */;
// add the listener to epoll
let event = Event::new(Events::EPOLLIN, listener.as_raw_fd() as _);
epoll::ctl(epoll, EPOLL_CTL_ADD, listener.as_raw_fd(), event).unwrap();
let mut connections = HashMap::new();
loop {
let mut events = [Event::new(Events::empty(), 0); 1024];
// block until epoll wakes us up (a timeout of -1 means wait forever)
let num_events = epoll::wait(epoll, -1, &mut events).unwrap();
let mut completed = Vec::new();
'next: for event in &events[..num_events] {
let fd = event.data as i32;
// is the listener ready?
if fd == listener.as_raw_fd() {
match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
let fd = connection.as_raw_fd();
// add the connection to epoll
let event = Event::new(Events::EPOLLIN | Events::EPOLLOUT, fd as _);
epoll::ctl(epoll, EPOLL_CTL_ADD, fd, event).unwrap();
// keep track of connection state
let state = ConnectionState::Read {
request: [0u8; 1024],
read: 0,
};
connections.insert(fd, (connection, state));
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => panic!("{e}"),
}
continue 'next;
}
// otherwise, a connection is ready
let (connection, state) = connections.get_mut(&fd).unwrap();
// try to drive it forward based on its state
if let ConnectionState::Read { request, read } = state {
// connection.read...
*state = ConnectionState::Write {
response: response.as_bytes(),
written: 0,
};
}
if let ConnectionState::Write { response, written } = state {
// connection.write...
*state = ConnectionState::Flush;
}
if let ConnectionState::Flush = state {
// connection.flush...
// the response is fully sent, mark the connection as done
completed.push(fd);
}
}
for fd in completed {
let (connection, _state) = connections.remove(&fd).unwrap();
// unregister from epoll
drop(connection);
}
}
And...
$ curl localhost:3000
# => Hello world!
It works!
Futures
Alright, our server can now process multiple requests concurrently on a single thread. And thanks to epoll, it's pretty efficient at doing so. But there's still a problem.
We got so caught up in gaining control over the execution of our tasks, and then scheduling them efficiently ourselves, that in the process the complexity of our code has increased dramatically.
What went from a simple, sequential accept loop has become a massive event loop managing multiple state machines.
And it's not pretty.
Making our original server multi-threaded was as simple as adding a single line of code: thread::spawn. If you think about it, our server is still a set of concurrent tasks; we just manage them all messily in a giant loop.
This doesn't seem very scalable. The more features we add to our program, the more complex the loop becomes, because everything is so tightly coupled together.
What if we could write an abstraction like thread::spawn that let us write our tasks as individual units, and handle the scheduling and event handling for all tasks in a single place, regaining some of that sequential control flow?
This idea is generally referred to as asynchronous programming.
Let's take a look at the signature of thread::spawn.
pub fn spawn<F, T>(f: F) -> JoinHandle<T>
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static;
thread::spawn takes a closure. Our version of thread::spawn, however, can't take a closure: we aren't the operating system, so we can't preempt a closure halfway through and resume it later. We need some way to represent a non-blocking, resumable task.
// fn spawn<T: Task>(task: T);
trait Task {}
Handling a request is a task. Reading from or writing to the connection is also a task. A task is really just a piece of code that needs to be run, representing a value that will resolve sometime in the future.
Future. That's a nice name, isn't it?
trait Future {
type Output;
fn run(self) -> Self::Output;
}
Hmm... that signature doesn't really work. Having run return the output directly means it must block until the output is available, which is what we're trying so hard to avoid. We instead want a way to attempt to drive the future forward without blocking, like we've been doing with all our state machines in the event loop.
What we're really doing when we try to run a future is asking it if the value is ready yet, polling it, and giving it a chance to make progress.
trait Future {
type Output;
fn poll(self) -> Option<Self::Output>;
}
That looks more like it.
Except wait, poll can't take self if we want to call it multiple times; it should probably take a reference. A mutable reference, if we want to mutate the internal state of the task as it makes progress, like ConnectionState.
trait Future {
type Output;
fn poll(&mut self) -> Option<Self::Output>;
}
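A tiny example of why `&mut self` matters: the hypothetical `Countdown` future below keeps its progress in a field, so each call to `poll` picks up where the last one left off. It uses the intermediate trait from above (no waker yet) and is only an illustration.

```rust
// The intermediate trait from above: poll takes &mut self and returns None
// until the value is ready (no waker yet).
trait Future {
    type Output;
    fn poll(&mut self) -> Option<Self::Output>;
}

// Hypothetical example future: ready after `remaining` unsuccessful polls.
struct Countdown {
    remaining: u32,
}

impl Future for Countdown {
    type Output = &'static str;

    fn poll(&mut self) -> Option<Self::Output> {
        if self.remaining == 0 {
            return Some("liftoff");
        }
        // record progress in our own state, then report "not ready yet"
        self.remaining -= 1;
        None
    }
}
```

Because `poll` only borrows the future, the caller can keep calling it until it returns `Some`.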
Alright, imagine a scheduler that runs these new futures.
impl Scheduler {
fn run(&self) {
loop {
for future in &self.tasks {
future.poll();
}
}
}
}
That doesn't look right.
After initiating the future, the scheduler should only call poll when the given future is able to make progress, like when epoll returns an event. But how do we know when that happens?
If the future represents an I/O operation, we know it's able to make progress when epoll tells us it is. The problem is the scheduler won't know which epoll event corresponds to which future, because the future handles everything internally in poll.
What we need is for the scheduler to pass each future an ID, so that the future can register any I/O resources with epoll using the same ID, instead of their file descriptors. That way the scheduler has a way of mapping epoll events to runnable futures.
impl Scheduler {
fn spawn<T>(&self, mut future: T) {
let id = rand();
// poll the future once to get it started, passing in its ID
future.poll(id);
// store the future
self.tasks.insert(id, future);
}
fn run(self) {
// ...
for event in epoll_events {
// poll the future associated with this event
let future = self.tasks.get(&event.id).unwrap();
future.poll(event.id);
}
}
}
You know, it would be nice if there was a more generic way to tell the scheduler about progress than tying every future to epoll. We might have different types of futures that make progress in other ways, like a timer running on a background thread, or a channel that needs to notify tasks that a message is available.
What if we gave the futures themselves more control? Instead of just an ID, what if we give every future a way to wake itself up, notifying the scheduler that it's ready to make progress?
A simple callback should do the trick.
#[derive(Clone)]
struct Waker(Arc<dyn Fn() + Send + Sync>);
impl Waker {
fn wake(&self) {
(self.0)()
}
}
trait Future {
type Output;
fn poll(&mut self, waker: Waker) -> Option<Self::Output>;
}
The scheduler can provide each future with a callback that, when called, updates the scheduler's state for that future, marking it as ready. That way our scheduler is completely decoupled from epoll, or any other individual notification system.
Waker is thread-safe, allowing us to use background threads to wake futures. Right now all of our tasks are connected to epoll anyway, but this will come in handy later.
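As a quick check of the thread-safety claim, this std-only sketch clones a `Waker` into several threads and calls `wake` from each. `counting_waker` and `wake_from_threads` are hypothetical helpers; the counter stands in for a scheduler marking a task as ready.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

#[derive(Clone)]
struct Waker(Arc<dyn Fn() + Send + Sync>);

impl Waker {
    fn wake(&self) {
        (self.0)()
    }
}

// Hypothetical helper: a waker whose callback just counts how often it was called.
fn counting_waker() -> (Waker, Arc<AtomicUsize>) {
    let count = Arc::new(AtomicUsize::new(0));
    let c = Arc::clone(&count);
    let waker = Waker(Arc::new(move || {
        c.fetch_add(1, Ordering::SeqCst);
    }));
    (waker, count)
}

// thread::spawn only accepts these closures because Waker is Send,
// which in turn requires the callback to be Send + Sync.
fn wake_from_threads(waker: &Waker, threads: usize) {
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let w = waker.clone();
            thread::spawn(move || w.wake())
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
}
```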
A Reactor
Consider a future that reads from a TCP connection. It receives a Waker that needs to be called when epoll returns the relevant EPOLLIN event, but the future won't be running when that happens; it will be sitting idle in the scheduler. Obviously, the future can't wake itself up; someone else has to.
All I/O futures need a way to give their wakers to epoll. In fact, they need more than that: they need some sort of background service that drives epoll, so we can register wakers with it.
This service is commonly known as a reactor.
The reactor is a simple object holding the epoll descriptor and a map of tasks keyed by file descriptor, just like we had before. The difference is that instead of the map holding the TCP connections themselves, it holds the wakers.
thread_local! {
static REACTOR: Reactor = Reactor::new();
}
struct Reactor {
epoll: RawFd,
tasks: RefCell<HashMap<RawFd, Waker>>,
}
impl Reactor {
pub fn new() -> Reactor {
Reactor {
epoll: epoll::create(false).unwrap(),
tasks: RefCell::new(HashMap::new()),
}
}
}
To keep things simple, the reactor is a thread-local object, mutated through a RefCell. The RefCell matters because a thread-local is only reachable through a shared reference, yet the reactor will be modified by different tasks throughout the program.
The reactor needs to support a couple of simple operations.
Adding a task:
impl Reactor {
// Add a file descriptor with read and write interest.
//
// `waker` will be called when the descriptor becomes ready.
pub fn add(&self, fd: RawFd, waker: Waker) {
let event = epoll::Event::new(Events::EPOLLIN | Events::EPOLLOUT, fd as u64);
epoll::ctl(self.epoll, EPOLL_CTL_ADD, fd, event).unwrap();
self.tasks.borrow_mut().insert(fd, waker);
}
}
Removing a task:
impl Reactor {
// Remove the given descriptor from epoll.
//
// It will no longer receive any notifications.
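pub fn remove(&self, fd: RawFd) {
// stop watching the descriptor (the event argument is ignored for EPOLL_CTL_DEL)
epoll::ctl(self.epoll, EPOLL_CTL_DEL, fd, Event::new(Events::empty(), 0)).unwrap();
self.tasks.borrow_mut().remove(&fd);
}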
}
And driving epoll.
We'll be running the reactor in a loop, just like we were running epoll in a loop before. It works exactly the same way, except all the reactor has to do is wake the associated future for every event. Remember, this will trigger the scheduler to run the future later, and continue the cycle.
impl Reactor {
// Drive tasks forward, blocking forever until an event arrives.
pub fn wait(&self) {
let mut events = [Event::new(Events::empty(), 0); 1024];
let timeout = -1; // forever
let num_events = epoll::wait(self.epoll, timeout, &mut events).unwrap();
for event in &events[..num_events] {
let fd = event.data as i32;
// wake the task
if let Some(waker) = self.tasks.borrow().get(&fd) {
waker.wake();
}
}
}
}
Great, now we have a simple reactor interface.
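Since epoll is Linux-specific, here is a portable sketch of the same bookkeeping, assuming readiness events arrive on an `std::sync::mpsc` channel instead of from `epoll::wait`. The channel-based event source, the `Reactor::new` signature that also returns a `Sender`, and the `i32` alias for `RawFd` are all stand-ins; what carries over is the map from descriptor to waker and the "wake whoever was interested" loop.

```rust
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Arc;

// Stand-in for std::os::fd::RawFd so the sketch builds on any platform.
type RawFd = i32;

#[derive(Clone)]
struct Waker(Arc<dyn Fn() + Send + Sync>);

impl Waker {
    fn wake(&self) {
        (self.0)()
    }
}

// Hypothetical reactor: readiness events arrive on a channel instead of from
// epoll, but the bookkeeping (descriptor -> waker) is the same.
struct Reactor {
    events: Receiver<RawFd>,
    tasks: RefCell<HashMap<RawFd, Waker>>,
}

impl Reactor {
    // Returns the reactor plus a Sender that plays the role of the kernel.
    fn new() -> (Reactor, Sender<RawFd>) {
        let (tx, rx) = channel();
        let reactor = Reactor {
            events: rx,
            tasks: RefCell::new(HashMap::new()),
        };
        (reactor, tx)
    }

    fn add(&self, fd: RawFd, waker: Waker) {
        self.tasks.borrow_mut().insert(fd, waker);
    }

    fn remove(&self, fd: RawFd) {
        self.tasks.borrow_mut().remove(&fd);
    }

    // Block until at least one event arrives, then wake every interested task.
    fn wait(&self) {
        let first = self.events.recv().unwrap();
        for fd in std::iter::once(first).chain(self.events.try_iter()) {
            if let Some(waker) = self.tasks.borrow().get(&fd) {
                waker.wake();
            }
        }
    }
}
```

Events for descriptors nobody registered (or that were removed) are simply ignored, just like the `if let Some(waker)` in the epoll version.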
But all of this is still a little abstract. What does it really mean to call wake?
Scheduling Tasks
We have a reactor, now we need a scheduler to run our tasks.
One thing to keep in mind is that the scheduler must be global and thread-safe, because wakers are Send, meaning wake may be called concurrently from other threads.
static SCHEDULER: Scheduler = Scheduler { /* ... */ };
#[derive(Default)]
struct Scheduler {
// ...
}
We want to be able to spawn tasks onto our scheduler, just like we could spawn threads. For now, we'll only handle spawning tasks that don't return anything, to avoid having to implement a version of JoinHandle.
To start, we'll probably need some sort of list of tasks to run, guarded by a Mutex to be thread-safe.
struct Scheduler {
tasks: Mutex<Vec<Box<dyn Future<Output = ()> + Send>>>,
}
impl Scheduler {
pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) {
self.tasks.lock().unwrap().push(Box::new(task));
}
pub fn run(&self) {
for task in self.tasks.lock().unwrap().iter_mut() {
// ...
}
}
}
Remember, futures are only polled when they are able to make progress. They should always be able to make progress at the start, but after that we don't touch them until someone calls wake.
There are a couple of ways we could go about this. We could just store a HashMap of tasks with a status flag that indicates whether or not the task was woken, but that means we would have to iterate through the entire map to find out which tasks are runnable. While this isn't incredibly expensive, there is a better way.
Instead of storing every spawned task in the map, we'll only store runnable tasks in a queue.
use std::collections::VecDeque;
type SharedTask = Arc<Mutex<dyn Future<Output = ()> + Send>>;
#[derive(Default)]
struct Scheduler {
runnable: Mutex<VecDeque<SharedTask>>,
}
The types will make sense soon.
When a task is spawned, it's pushed onto the back of the queue:
impl Scheduler {
pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) {
self.runnable.lock().unwrap().push_back(Arc::new(Mutex::new(task)));
}
}
The scheduler pops tasks off the queue one by one and calls poll:
impl Scheduler {
fn run(&self) {
loop {
// pop a runnable task off the queue
let task = self.runnable.lock().unwrap().pop_front();
if let Some(task) = task {
// and poll it (`waker` is the missing piece; we'll build it next)
task.try_lock().unwrap().poll(waker);
}
}
}
}
Notice that we don't even really need a Mutex around the task, because it's only ever going to be accessed by the main thread, but removing it would mean reaching for unsafe. We'll settle for try_lock().unwrap() for now.
Now for the important bit, the waker. The beautiful part of our run queue is that when a task is woken, it's simply pushed back onto the queue.
impl Scheduler {
fn run(&self) {
loop {
// pop a runnable task off the queue
let task = self.runnable.lock().unwrap().pop_front();
if let Some(task) = task {
let t2 = task.clone();
// create a waker that pushes the task back on
let wake = Arc::new(move || {
SCHEDULER.runnable.lock().unwrap().push_back(t2.clone());
});
// poll the task
task.try_lock().unwrap().poll(Waker(wake));
}
}
}
}
This is why the task needed to be reference counted: it's not owned by the scheduler, it's referenced by the queue as well as wherever the waker is being stored. In fact, the same task might be on the queue multiple times at once, and the waker might be cloned all over the place.
Once we've dealt with all runnable tasks, we need to block on the reactor until another task becomes ready.[5] Once a task becomes ready, the reactor will call wake, which pushes the future back onto our queue so we can run it again, continuing the cycle.
pub fn run(&self) {
loop {
loop {
// pop a runnable task off the queue
let Some(task) = self.runnable.lock().unwrap().pop_front() else { break };
let t2 = task.clone();
// create a waker that pushes the task back on
let wake = Arc::new(move || {
SCHEDULER.runnable.lock().unwrap().push_back(t2.clone());
});
// poll the task
task.lock().unwrap().poll(Waker(wake));
}
// if there are no runnable tasks, block on epoll until something becomes ready
REACTOR.with(|reactor| reactor.wait()); // 👈
}
}
Perfect.
...ignoring the Arc<Mutex<T>> clutter.
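Here is the scheduler half as a self-contained sketch: the run queue plus the push-back-on-wake waker. Because there is no reactor here, the hypothetical `run_until_idle` stops once the queue is empty instead of blocking in `REACTOR.with(|reactor| reactor.wait())`, and the scheduler is passed around as an `Arc` rather than stored in a `static`. `YieldN` and `Pending` are hypothetical test tasks.

```rust
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

#[derive(Clone)]
struct Waker(Arc<dyn Fn() + Send + Sync>);

impl Waker {
    fn wake(&self) {
        (self.0)()
    }
}

trait Future {
    type Output;
    fn poll(&mut self, waker: Waker) -> Option<Self::Output>;
}

type SharedTask = Arc<Mutex<dyn Future<Output = ()> + Send>>;

#[derive(Default)]
struct Scheduler {
    runnable: Mutex<VecDeque<SharedTask>>,
}

impl Scheduler {
    fn spawn(&self, task: impl Future<Output = ()> + Send + 'static) {
        self.runnable.lock().unwrap().push_back(Arc::new(Mutex::new(task)));
    }
}

// Hypothetical driver: drain the run queue and return how many polls happened.
// With no reactor to block on, it simply stops once nothing is runnable.
fn run_until_idle(scheduler: &Arc<Scheduler>) -> usize {
    let mut polls = 0;
    loop {
        let Some(task) = scheduler.runnable.lock().unwrap().pop_front() else {
            break;
        };
        let (sched, t2) = (Arc::clone(scheduler), Arc::clone(&task));
        // waking a task just pushes it onto the back of the queue again
        let waker = Waker(Arc::new(move || {
            sched.runnable.lock().unwrap().push_back(t2.clone());
        }));
        task.lock().unwrap().poll(waker);
        polls += 1;
    }
    polls
}

// Hypothetical task: wakes itself `remaining` times before completing.
struct YieldN {
    remaining: usize,
}

impl Future for YieldN {
    type Output = ();
    fn poll(&mut self, waker: Waker) -> Option<()> {
        if self.remaining == 0 {
            return Some(());
        }
        self.remaining -= 1;
        waker.wake(); // ask to be polled again
        None
    }
}

// Hypothetical task that never arranges a wake-up.
struct Pending;

impl Future for Pending {
    type Output = ();
    fn poll(&mut self, _waker: Waker) -> Option<()> {
        None
    }
}
```

A task that returns `None` without arranging a wake-up, like `Pending`, is polled once and then never again. That is exactly the situation a real I/O task is in until the reactor calls its waker.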
Alright! Together, the scheduler and reactor form a runtime for our futures. The scheduler keeps track of which tasks are runnable and polls them, and the reactor marks tasks as runnable when epoll tells us something they are interested in has become ready.
trait Future {
type Output;
fn poll(&mut self, waker: Waker) -> Option<Self::Output>;
}
static SCHEDULER: Scheduler = Scheduler { /* ... */ };
// The scheduler.
#[derive(Default)]
struct Scheduler {
runnable: Mutex<VecDeque<SharedTask>>,
}
type SharedTask = Arc<Mutex<dyn Future<Output = ()> + Send>>;
impl Scheduler {
pub fn spawn(&self, task: impl Future<Output = ()> + Send + 'static);
pub fn run(&self);
}
thread_local! {
static REACTOR: Reactor = Reactor::new();
}
// The reactor.
struct Reactor {
epoll: RawFd,
tasks: RefCell<HashMap<RawFd, Waker>>,
}
impl Reactor {
pub fn new() -> Reactor;
pub fn add(&self, fd: RawFd, waker: Waker);
pub fn remove(&self, fd: RawFd);
pub fn wait(&self);
}
We've written the runtime, now let's try to use it!
An Async Server
It's time to actually write the tasks that our scheduler is going to run. Like before, we'll use enums as state machines to manage the different states of our program. The difference is that this time, each task will manage its own state independently of other tasks, instead of having the entire program revolve around a messy event loop.
To start everything off, we need to write the main task. This task will be pushed on and off the scheduler's run queue for the entirety of our program.
fn main() {
SCHEDULER.spawn(Main::Start);
SCHEDULER.run();
}
enum Main {
Start,
}
impl Future for Main {
type Output = ();
fn poll(&mut self, waker: Waker) -> Option<()> {
// ...
}
}
Our task starts off just like before, creating the TCP listener and setting it to non-blocking mode.
// impl Future for Main {
fn poll(&mut self, waker: Waker) -> Option<()> {
if let Main::Start = self {
let listener = TcpListener::bind("localhost:3000").unwrap();
listener.set_nonblocking(true).unwrap();
}
None
}
Now we need to register the listener with epoll. We can do that using our new Reactor.
// impl Future for Main {
fn poll(&mut self, waker: Waker) -> Option<()> {
if let Main::Start = self {
// ...
REACTOR.with(|reactor| {
reactor.add(listener.as_raw_fd(), waker);
});
}
}
Notice how we give the reactor the waker provided to us by the scheduler. When a connection comes in, epoll will return an event and the Reactor will wake the task, causing the scheduler to push our task back onto the queue and poll us again. The waker keeps everything connected.
We now need a second state for the next time we're run: Accept. The main task will stay in the Accept state for the rest of the program, attempting to accept new connections.
enum Main {
Start,
Accept { listener: TcpListener }, // 👈
}
fn poll(&mut self, waker: Waker) -> Option<()> {
if let Main::Start = self {
// ...
*self = Main::Accept { listener };
}
if let Main::Accept { listener } = self {
match listener.accept() {
Ok((connection, _)) => {
// ...
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
return None;
}
Err(e) => panic!("{e}"),
}
}
None
}
If the listener is not ready, we can simply return None. Remember, this tells the scheduler the future is not yet ready, and it will be rescheduled once the reactor wakes us.
If we do accept a new connection, we need to again set it to non-blocking mode.
fn poll(&mut self, waker: Waker) -> Option<()> {
if let Main::Start = self {
// ...
}
if let Main::Accept { listener } = self {
match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap(); // 👈
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return None,
Err(e) => panic!("{e}"),
}
}
None
}
And now we need to spawn a new task to handle the request.
fn poll(&mut self, waker: Waker) -> Option<()> {
if let Main::Start = self {
// ...
}
if let Main::Accept { listener } = self {
match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
SCHEDULER.spawn(Handler { // 👈
connection,
state: HandlerState::Start,
});
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return None,
Err(e) => panic!("{e}"),
}
}
None
}
The handler task looks similar to before, but now it manages the connection itself along with its current state, which is the same as ConnectionState from earlier plus an extra Start state.
struct Handler {
connection: TcpStream,
state: HandlerState,
}
enum HandlerState {
Start,
Read {
request: [u8; 1024],
read: usize,
},
Write {
response: &'static [u8],
written: usize,
},
Flush,
}
The handler task starts by registering its connection with the reactor, to be notified when the connection is ready to be read from or written to. Again, it hands over its waker so that the reactor can tell the scheduler when to run it again.
impl Future for Handler {
type Output = ();
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
if let HandlerState::Start = self.state {
// start by registering our connection for notifications
REACTOR.with(|reactor| {
reactor.add(self.connection.as_raw_fd(), waker);
});
self.state = HandlerState::Read {
request: [0u8; 1024],
read: 0,
};
}
// ...
}
}
The Read, Write, and Flush states work exactly the same as before, but now when we encounter WouldBlock, we can simply return None, knowing that we'll be run again once our future is woken.
// impl Future for Handler {
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
if let HandlerState::Start = self.state {
// ...
}
// read the request
if let HandlerState::Read { request, read } = &mut self.state {
loop {
match self.connection.read(&mut request[*read..]) {
Ok(0) => {
println!("client disconnected unexpectedly");
return Some(());
}
Ok(n) => *read += n,
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return None, // 👈
Err(e) => panic!("{e}"),
}
// did we reach the end of the request?
let read = *read;
if read >= 4 && &request[read - 4..read] == b"\r\n\r\n" {
break;
}
}
// we're done, print the request
let request = String::from_utf8_lossy(&request[..*read]);
println!("{}", request);
// and move into the write state
let response = /* ... */;
self.state = HandlerState::Write {
response: response.as_bytes(),
written: 0,
};
}
// write the response
if let HandlerState::Write { response, written } = &mut self.state {
// self.connection.write...
// successfully wrote the response, try flushing next
self.state = HandlerState::Flush;
}
// flush the response
if let HandlerState::Flush = self.state {
match self.connection.flush() {
Ok(_) => {}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return None, // 👈
Err(e) => panic!("{e}"),
}
}
}
Notice how much nicer things are when tasks are independent, encapsulated objects?
At the end of the task's lifecycle, it removes its connection from the reactor and returns Some. It will never be run again after that point.
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
// ...
REACTOR.with(|reactor| {
reactor.remove(self.connection.as_raw_fd());
});
Some(())
}
Perfect! Our new server looks a lot nicer. Individual tasks are completely independent of each other, and we can spawn new tasks just like threads.
fn main() {
SCHEDULER.spawn(Main::Start);
SCHEDULER.run();
}
// main task: accept loop
enum Main {
Start,
Accept { listener: TcpListener },
}
impl Future for Main {
type Output = ();
fn poll(&mut self, waker: Waker) -> Option<()> {
if let Main::Start = self {
// ...
REACTOR.with(|reactor| {
reactor.add(listener.as_raw_fd(), waker);
});
*self = Main::Accept { listener };
}
if let Main::Accept { listener } = self {
match listener.accept() {
Ok((connection, _)) => {
// ...
SCHEDULER.spawn(Handler {
connection,
state: HandlerState::Start,
});
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return None,
Err(e) => panic!("{e}"),
}
}
None
}
}
// handler task: handles every connection
struct Handler {
connection: TcpStream,
state: HandlerState,
}
enum HandlerState {
Start,
Read {
request: [u8; 1024],
read: usize,
},
Write {
response: &'static [u8],
written: usize,
},
Flush,
}
impl Future for Handler {
type Output = ();
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
if let HandlerState::Start = self.state {
REACTOR.with(|reactor| {
reactor.add(self.connection.as_raw_fd(), waker);
});
self.state = HandlerState::Read { /* .. */ };
}
if let HandlerState::Read { request, read } = &mut self.state {
// ...
self.state = HandlerState::Write { /* .. */ };
}
if let HandlerState::Write { response, written } = &mut self.state {
// ...
self.state = HandlerState::Flush;
}
if let HandlerState::Flush = self.state {
// ...
}
REACTOR.with(|reactor| {
reactor.remove(self.connection.as_raw_fd());
});
Some(())
}
}
Andd....
$ curl localhost:3000
# => Hello world!
It works!
A Functional Server
With this new future abstraction, our server is much nicer than before. Futures get to manage their state independently, the scheduler gets to run tasks without worrying about epoll, and tasks can be spawned and woken without worrying about any of the lower level details of the scheduler. It really is a much nicer programming model.
It is nice that tasks are encapsulated, but we still have to write everything in a state-machine-like way. Granted, Rust makes this pretty easy to do with enums, but could we do better?
Looking at the two futures we've written, they have a lot in common. Each future has a number of states. At each state, some code is run. If that code completes successfully, we transition into the next state. If it encounters WouldBlock, we return None, indicating that the future is not yet ready.
This seems like something we can abstract over.
What we need is a way to create a future from some block of code, and a way to combine two futures, chaining them together.
Given a block of code, we need to be able to construct a future... sounds like a job for a closure?
fn future_fn<F>(f: F) -> impl Future
where
F: Fn(),
{
// ...
}
The closure probably also needs to mutate local state.
fn future_fn<F>(f: F) -> impl Future
where
F: FnMut(),
{
// ...
}
And it also needs access to the waker.
fn future_fn<F>(f: F) -> impl Future
where
F: FnMut(Waker),
{
// ...
}
And... it needs to return a value. An optional value, in case it's not ready yet. In fact, we can just copy the signature of poll, because that's really what this closure is.
fn poll_fn<F, T>(f: F) -> impl Future<Output = T>
where
F: FnMut(Waker) -> Option<T>,
{
// ...
}
Implementing poll_fn doesn't seem too hard; we just need a wrapper struct that implements Future and delegates poll to the closure.
fn poll_fn<F, T>(f: F) -> impl Future<Output = T>
where
F: FnMut(Waker) -> Option<T>,
{
struct FromFn<F>(F);
impl<F, T> Future for FromFn<F>
where
F: FnMut(Waker) -> Option<T>,
{
type Output = T;
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
(self.0)(waker)
}
}
FromFn(f)
}
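Here's a runnable version of `poll_fn` with the article's simplified, Waker-based `Future` trait, plus a hypothetical `noop_waker` helper for polling by hand. The closure captures its counter by `move`, so the counter lives inside the future between polls.

```rust
use std::sync::Arc;

#[allow(dead_code)] // the callback is never invoked in this sketch
#[derive(Clone)]
struct Waker(Arc<dyn Fn() + Send + Sync>);

trait Future {
    type Output;
    fn poll(&mut self, waker: Waker) -> Option<Self::Output>;
}

fn poll_fn<F, T>(f: F) -> impl Future<Output = T>
where
    F: FnMut(Waker) -> Option<T>,
{
    struct FromFn<F>(F);

    impl<F, T> Future for FromFn<F>
    where
        F: FnMut(Waker) -> Option<T>,
    {
        type Output = T;
        fn poll(&mut self, waker: Waker) -> Option<T> {
            // delegate straight to the closure
            (self.0)(waker)
        }
    }

    FromFn(f)
}

// Hypothetical helper: a waker that does nothing, for polling futures by hand.
fn noop_waker() -> Waker {
    Waker(Arc::new(|| {}))
}
```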
Alright. Let's try reworking the main task to use the new poll_fn helper. We can easily stick the code of the Main::Start state into a closure.
fn main() {
SCHEDULER.spawn(listen());
SCHEDULER.run();
}
fn listen() -> impl Future<Output = ()> {
let start = poll_fn(|waker| {
let listener = TcpListener::bind("localhost:3000").unwrap();
listener.set_nonblocking(true).unwrap();
REACTOR.with(|reactor| {
reactor.add(listener.as_raw_fd(), waker);
});
Some(listener)
});
// ...
}
Remember, Main::Start never waits on any I/O, so it's immediately ready with the listener.
We can also use poll_fn to write the Main::Accept future.
fn listen() -> impl Future<Output = ()> {
let start = poll_fn(|waker| {
// ...
Some(listener)
});
let accept = poll_fn(|_| match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
SCHEDULER.spawn(Handler {
connection,
state: HandlerState::Start,
});
None
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
Err(e) => panic!("{e}"),
});
}
On the other hand, accept always returns None, because we want it to be called every time a new connection comes in. It runs for the entirety of our program.
We have our two task states, now we need a way to chain them together.
fn chain<F1, F2>(future1: F1, future2: F2) -> impl Future<Output = F2::Output>
where
F1: Future,
F2: Future
{
// ...
}
Hm, that doesn't really work.
The second future will need to access the output of the first, the TCP listener.
Instead of chaining the second future directly, we have to chain a closure over the first future's output. That way the closure can use the output of the first future to construct the second.
fn chain<T1, F, T2>(future1: T1, chain: F) -> impl Future<Output = T2::Output>
where
T1: Future,
F: FnOnce(T1::Output) -> T2,
T2: Future
{
// ...
}
That seems better.
We might as well be fancy and make chain a method on the Future trait. That way we can call .chain as a postfix method on any future.
trait Future {
// ...
fn chain<F, T>(self, chain: F) -> Chain<Self, F, T>
where
F: FnOnce(Self::Output) -> T,
T: Future,
Self: Sized,
{
// ...
}
}
enum Chain<T1, F, T2> {
// ...
}
That looks right, let's try implementing it!
The Chain future is a generalization of our state machines, so it is itself a mini state machine. It starts off by polling the first future, holding onto the transition closure for when it finishes.
enum Chain<T1, F, T2> {
First { future1: T1, transition: F },
}
impl<T1, F, T2> Future for Chain<T1, F, T2>
where
T1: Future,
F: FnOnce(T1::Output) -> T2,
T2: Future,
{
type Output = T2::Output;
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
if let Chain::First { future1, transition } = self {
// poll the first future
match future1.poll(waker) {
// ...
}
}
}
}
Once the first future is finished, it constructs the second future using the transition closure, and starts polling it:
enum Chain<T1, F, T2> {
First { future1: T1, transition: F },
Second { future2: T2 },
}
impl<T1, F, T2> Future for Chain<T1, F, T2>
where
T1: Future,
F: FnOnce(T1::Output) -> T2,
T2: Future,
{
type Output = T2::Output;
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
if let Chain::First { future1, transition } = self {
// poll the first future
match future1.poll(waker.clone()) {
Some(value) => {
// first future is done, transition into the second
let future2 = (transition)(value); // 👈
*self = Chain::Second { future2 };
}
// first future is not ready, return
None => return None,
}
}
if let Chain::Second { future2 } = self {
// first future is already done, poll the second
return future2.poll(waker); // 👈
}
None
}
}
Notice how the same waker is used to poll both futures. This means that notifications for both futures will be propagated to the Chain parent future, depending on its state.
Hm... that doesn't actually seem to work:
error[E0507]: cannot move out of `*transition` which is behind a mutable reference
--> src/main.rs:182:33
|
182 | let future2 = (transition)(value);
| ^^^^^^^^^^^^ move occurs because `*transition` has type `F`,
which does not implement the `Copy` trait
Oh right, transition is an FnOnce closure, meaning it is consumed the first time it is called. We only ever call it once based on our state machine, but the compiler doesn't know that.
We can wrap it in an Option and use take to call it, replacing it with None and allowing us to get an owned value. This is a common pattern when working with state machines.
enum Chain<T1, F, T2> {
First { future1: T1, transition: Option<F> }, // 👈
Second { future2: T2 },
}
// ...
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
if let Chain::First { future1, transition } = self {
match future1.poll(waker.clone()) {
Some(value) => {
let future2 = (transition.take().unwrap())(value); // 👈
*self = Chain::Second { future2 };
}
None => return None,
}
}
// ...
}
Perfect. Now the chain method can simply construct our Chain future in its starting state.
trait Future {
// ...
fn chain<F, T>(self, transition: F) -> Chain<Self, F, T>
where
F: FnOnce(Self::Output) -> T,
T: Future,
Self: Sized,
{
Chain::First {
future1: self,
transition: Some(transition),
}
}
}
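For reference, here is the whole `Chain` combinator assembled into one self-contained sketch, including the `Option::take` fix and the `chain` default method. It uses the article's simplified `Future` trait and `Waker`, and exercises `Chain` by chaining two `poll_fn` futures, the first of which isn't ready on its first poll.

```rust
use std::sync::Arc;

#[allow(dead_code)] // the callback is never invoked in this sketch
#[derive(Clone)]
struct Waker(Arc<dyn Fn() + Send + Sync>);

trait Future {
    type Output;
    fn poll(&mut self, waker: Waker) -> Option<Self::Output>;

    fn chain<F, T>(self, transition: F) -> Chain<Self, F, T>
    where
        F: FnOnce(Self::Output) -> T,
        T: Future,
        Self: Sized,
    {
        Chain::First { future1: self, transition: Some(transition) }
    }
}

enum Chain<T1, F, T2> {
    First { future1: T1, transition: Option<F> },
    Second { future2: T2 },
}

impl<T1, F, T2> Future for Chain<T1, F, T2>
where
    T1: Future,
    F: FnOnce(T1::Output) -> T2,
    T2: Future,
{
    type Output = T2::Output;

    fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
        if let Chain::First { future1, transition } = self {
            match future1.poll(waker.clone()) {
                Some(value) => {
                    // take() moves the FnOnce out from behind &mut, leaving None
                    let future2 = (transition.take().unwrap())(value);
                    *self = Chain::Second { future2 };
                }
                None => return None,
            }
        }
        if let Chain::Second { future2 } = self {
            return future2.poll(waker);
        }
        None
    }
}

fn poll_fn<F, T>(f: F) -> impl Future<Output = T>
where
    F: FnMut(Waker) -> Option<T>,
{
    struct FromFn<F>(F);

    impl<F, T> Future for FromFn<F>
    where
        F: FnMut(Waker) -> Option<T>,
    {
        type Output = T;
        fn poll(&mut self, waker: Waker) -> Option<T> {
            (self.0)(waker)
        }
    }

    FromFn(f)
}
```

On the second poll, the first future completes, the transition closure runs exactly once, and the second future is polled in the same call.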
Alright. Where were we... oh right, the main future!
fn listen() -> impl Future<Output = ()> {
let start = poll_fn(|waker| {
// ...
Some(listener)
});
let accept = poll_fn(|_| match listener.accept() {
// ...
});
}
We can combine the two futures using our new chain method:
fn listen() -> impl Future<Output = ()> {
poll_fn(|waker| {
// ...
Some(listener)
})
.chain(|listener| { // 👈
poll_fn(move |_| match listener.accept() {
// ...
})
})
}
Huh, that seems really nice! Gone is our manual state machine; our listen function can now be expressed in terms of simple closures!
fn main() {
SCHEDULER.spawn(listen());
SCHEDULER.run();
}
fn listen() -> impl Future<Output = ()> {
poll_fn(|waker| {
let listener = TcpListener::bind("localhost:3000").unwrap();
// ...
REACTOR.with(|reactor| {
reactor.add(listener.as_raw_fd(), waker);
});
Some(listener)
})
.chain(|listener| {
poll_fn(move |_| match listener.accept() {
Ok((connection, _)) => {
// ...
SCHEDULER.spawn(Handler {
connection,
state: HandlerState::Start,
});
None
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
Err(e) => panic!("{e}"),
})
})
}
It's too good to be true!
We can convert the connection handler to a closure-based future just like we did with the main one. To start, we'll separate it out into a function that returns a Future.
fn listen() -> impl Future<Output = ()> {
poll_fn(|waker| {
// ...
})
.chain(|listener| {
poll_fn(move |_| match listener.accept() {
Ok((connection, _)) => {
// ...
SCHEDULER.spawn(handle(connection)); // 👈
None
}
// ...
})
})
}
fn handle(connection: TcpStream) -> impl Future<Output = ()> {
// ...
}
The first state, HandlerState::Start, is a simple poll_fn closure that registers the connection with the reactor and immediately returns.
fn handle(connection: TcpStream) -> impl Future<Output = ()> {
poll_fn(move |waker| {
REACTOR.with(|reactor| {
reactor.add(connection.as_raw_fd(), waker);
});
Some(())
})
}
The second state, HandlerState::Read, can be chained on quite easily. It initializes its local request state on the stack and moves it into the future, allowing the future to own its state.
fn handle(mut connection: TcpStream) -> impl Future<Output = ()> {
poll_fn(move |waker| {
// ...
})
.chain(move |_| {
let mut read = 0;
let mut request = [0u8; 1024]; // 👈
poll_fn(move |_| {
loop {
// try reading from the stream
match connection.read(&mut request[read..]) {
Ok(0) => {
println!("client disconnected unexpectedly");
return Some(());
}
Ok(n) => read += n,
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return None,
Err(e) => panic!("{e}"),
}
// did we reach the end of the request?
let read = read;
if read >= 4 && &request[read - 4..read] == b"\r\n\r\n" {
break;
}
}
// we're done, print the request
let request = String::from_utf8_lossy(&request[..read]);
println!("{request}");
Some(())
})
})
}
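The Read step works because its closure owns the buffer and the byte count, so it can pick up where it left off after a WouldBlock. The sketch below pulls that closure out into a plain function so it can be exercised without a socket, and leaves the waker argument out for brevity. Trickle is a hypothetical reader invented for this example: it hands out a few bytes at a time and returns WouldBlock on every other call, standing in for a non-blocking TcpStream.

```rust
use std::io::{self, Read};

// Hypothetical reader: at most `chunk` bytes per call, and `WouldBlock`
// on every other call, imitating a non-blocking socket.
struct Trickle {
    data: Vec<u8>,
    pos: usize,
    chunk: usize,
    blocked: bool,
}

impl Read for Trickle {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.blocked = !self.blocked;
        if self.blocked {
            return Err(io::ErrorKind::WouldBlock.into());
        }
        let n = self.chunk.min(buf.len()).min(self.data.len() - self.pos);
        buf[..n].copy_from_slice(&self.data[self.pos..self.pos + n]);
        self.pos += n;
        Ok(n)
    }
}

// The body of the Read step as a standalone closure: `request` and `read`
// are moved into the closure, so they survive between calls.
fn read_request(mut connection: impl Read) -> impl FnMut() -> Option<String> {
    let mut request = [0u8; 1024];
    let mut read = 0;
    move || loop {
        match connection.read(&mut request[read..]) {
            Ok(0) => return Some(String::new()), // client disconnected
            Ok(n) => read += n,
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => return None,
            Err(e) => panic!("{e}"),
        }
        // did we reach the end of the request?
        if read >= 4 && &request[read - 4..read] == b"\r\n\r\n" {
            return Some(String::from_utf8_lossy(&request[..read]).into_owned());
        }
    }
}
```

Calling the returned closure repeatedly yields None while the reader would block, and the full request once the blank line arrives; no progress is lost between calls.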
HandlerState::Write and HandlerState::Flush can be chained on the same way.
fn handle(mut connection: TcpStream) -> impl Future<Output = ()> {
poll_fn(move |waker| {
// REACTOR.register...
})
.chain(move |_| {
// connection.read...
})
.chain(move |_| {
let response = /* ... */;
let mut written = 0;
poll_fn(move |_| {
loop {
match connection.write(response[written..].as_bytes()) {
Ok(0) => {
println!("client disconnected unexpectedly");
return Some(());
}
Ok(n) => written += n,
Err(e) if e.kind() == io::ErrorKind::WouldBlock => return None,
Err(e) => panic!("{e}"),
}
// did we write the whole response yet?
if written == response.len() {
break;
}
}
Some(())
})
})
.chain(move |_| {
poll_fn(move |_| {
match connection.flush() {
Ok(_) => {}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
return None;
}
Err(e) => panic!("{e}"),
};
REACTOR.with(|reactor| reactor.remove(connection.as_raw_fd()));
Some(())
})
})
}
It's perfect.
Uhhhh...
error[E0382]: use of moved value: `connection`
--> src/main.rs:59:12
|
51 | fn handle(mut connection: TcpStream) -> impl Future<Output = ()> {
| -------------- move occurs because `connection` has type `TcpStream`, which does not implement the `Copy` trait
52 | poll_fn(move |waker| {
| ------------- value moved into closure here
53 | REACTOR.with(|reactor| {
54 | reactor.add(connection.as_raw_fd(), waker);
| ---------- variable moved due to use in closure
...
59 | .chain(move |_| {
| ^^^^^^^^ value used here after move
...
66 | match connection.read(&mut request[read..]) {
| ---------- use occurs due to use in closure
error[E0382]: use of moved value: `connection`
// ...
Hmm....
All of our futures use move closures, meaning they take ownership of the connection. There can only be one owner of the connection, though. Guess they shouldn't be move closures?
error[E0373]: closure may outlive the current function, but it borrows `connection`, which is owned by the current function
--> src/main.rs:52:13
|
52 | poll_fn(|waker| {
| ^^^^^^^^ may outlive borrowed value `connection`
53 | REACTOR.with(|reactor| {
54 | reactor.add(connection.as_raw_fd(), waker);
| ---------- `connection` is borrowed here
|
note: closure is returned here
--> src/main.rs:52:5
|
52 | / poll_fn(|waker| {
53 | | REACTOR.with(|reactor| {
54 | | reactor.add(connection.as_raw_fd(), waker);
55 | | });
... |
128 | | })
129 | | })
| |______^
help: to force the closure to take ownership of `connection` (and any other referenced variables), use the `move` keyword
|
52 | poll_fn(move |waker| {
| ++++
That doesn't seem to work either. The connection needs to live somewhere. What if we only move it into the first future, and have the rest of the futures borrow it?
error[E0382]: use of moved value: `connection`
--> src/main.rs:59:12
|
51 | fn handle(mut connection: TcpStream) -> impl Future<Output = ()> {
| -------------- move occurs because `connection` has type `TcpStream`, which does not implement the `Copy` trait
52 | poll_fn(move |waker| {
| ------------- value moved into closure here
53 | REACTOR.with(|reactor| {
54 | reactor.add(connection.as_raw_fd(), waker);
| ---------- variable moved due to use in closure
...
59 | .chain(|_| {
| ^^^ value used here after move
...
66 | match connection.read(&mut request[read..]) {
| ---------- use occurs due to use in closure
Nope, that doesn't work either.
Under the hood, our chained futures look something like this. The first future owns the connection, and the rest borrow from it.
enum Handle {
Start {
connection: TcpStream,
},
Read {
connection: &'??? TcpStream
}
}
Which, of course, doesn't make much sense. Once the state transitions into Read, the connection from Start is dropped, and we have nothing to reference.
So how did this work when we were writing futures manually?
struct Handler {
connection: TcpStream,
state: HandlerState,
}
enum HandlerState { /* ... */ }
Right, the connection lived in the outer struct. Maybe we can write another one of those future helpers that allows us to reference some data stored in an outer future?
Something like this:
struct WithData<D, F> {
data: D,
future: F,
}
Seems simple enough. We should be able to construct the future such that it can capture a reference to the data. We can use a closure, just like we did with chain:
impl<D, F> WithData<D, F> {
pub fn new(data: D, construct: impl Fn(&D) -> F) -> WithData<D, F> {
let future = construct(&data);
WithData { data, future }
}
}
WithData can implement Future by simply delegating to the inner future:
impl<D, F> Future for WithData<D, F>
where
F: Future,
{
type Output = F::Output;
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
self.future.poll(waker)
}
}
Now we should be able to wrap our future in WithData, giving connection a place to live even after it is returned... and everything should work!
fn handle(connection: TcpStream) -> impl Future<Output = ()> {
WithData::new(connection, |connection| {
poll_fn(move |waker| {
// ...
})
.chain(move |_| {
let mut request = [0u8; 1024];
let mut read = 0;
poll_fn(move |_| {
// ...
})
})
.chain(move |_| {
let response = /* ... */;
let mut written = 0;
poll_fn(move |_| {
// ...
})
})
.chain(move |_| {
poll_fn(move |_| {
// ...
})
})
})
}
A little funky, but if it works...
error: lifetime may not live long enough
--> src/main.rs:53:9
|
52 | WithData::new(connection, |connection| {
| ----------- return type of closure `Chain<Chain<Chain<impl Future<Output = ()>, [closure@src/bin/play.rs:60:16: 86:10], impl Future<Output = ()>>, [closure@src/bin/play.rs:87:16: 113:10], impl Future<Output = ()>>, [closure@src/bin/play.rs:114:16: 130:10], impl Future<Output = ()>>` contains a lifetime `'2`
| |
| has type `&'1 TcpStream`
53 | / from_fn(move |waker| {
54 | | REACTOR.with(|reactor| {
55 | | reactor.add(connection.as_raw_fd(), waker);
56 | | });
... |
129 | | })
130 | | })
| |__________^ returning this value requires that `'1` must outlive `'2`
It couldn't be that easy, could it?
What a weird error message, too.
return type of closure `Chain<Chain<Chain<impl Future<Output = ()>, [closure@src/bin/play.rs:60:16: 86:10], impl Future<Output = ()>>, [closure@src/bin/play.rs:87:16: 113:10], impl Future<Output = ()>>, [closure@src/bin/play.rs:114:16: 130:10], impl Future<Output = ()>>` contains a lifetime `'2`
Alright, so, the giant chained future we pass to WithData contains a reference to the connection.
That's what we want, for the future to borrow the connection, right?
`connection` has type `&'1 TcpStream` ... returning this value requires that `'1` must outlive `'2`
Hmmm, nowhere in our WithData struct did we actually specify that the future borrows from the data. It seems Rust can't figure out the lifetimes without that. So... we should probably add a lifetime to WithData, right?
struct WithData<'data, D, F> {
data: D,
future: F,
}
error[E0392]: parameter `'data` is never used
--> src/bin/play.rs:160:17
|
160 | struct WithData<'data, D, F> {
| ^^^^^ unused parameter
|
= help: consider removing `'data`, referring to it in a field, or using a marker such as `PhantomData`
Adding PhantomData seems like an easy fix.
struct WithData<'data, D, F> {
data: D,
future: F,
_data: PhantomData<&'data D>,
}
The future does reference &'data D, so that sort of makes sense. Now in the constructor, we should say that the future borrows from the data:
impl<'data, D, F> WithData<'data, D, F>
where
F: Future + 'data, // 👈
{
pub fn new(
data: D,
construct: impl Fn(&'data D) -> F, // 👈
) -> WithData<'data, D, F> {
let future = construct(&data);
WithData {
data,
future,
_data: PhantomData,
}
}
}
And that should work, right? All the lifetimes are written out and make sense:
error[E0597]: `data` does not live long enough
--> src/bin/play.rs:172:30
|
167 | impl<'data, D, F> WithData<'data, D, F>
| ----- lifetime `'data` defined here
...
172 | let future = construct(&data);
| ----------^^^^^-
| | |
| | borrowed value does not live long enough
| argument requires that `data` is borrowed for `'data`
...
178 | }
| - `data` dropped here while still borrowed
error[E0505]: cannot move out of `data` because it is borrowed
--> src/bin/play.rs:174:13
|
167 | impl<'data, D, F> WithData<'data, D, F>
| ----- lifetime `'data` defined here
...
172 | let future = construct(&data);
| ----------------
| | |
| | borrow of `data` occurs here
| argument requires that `data` is borrowed for `'data`
173 | WithData {
174 | data,
| ^^^^ move out of `data` occurs here
Or... not.
Why doesn't this work?
`data` dropped here while still borrowed
Wait a minute, that's the same error message we got when we removed move from our future closures?! But the data does have a place to live now... doesn't it?
Hmm... actually, the second error is telling us that moving data is wrong too:
cannot move out of `data` because it is borrowed
That... actually makes sense. The future we construct borrows the data that lives on the stack. Once we move it, it's no longer in the same place on the stack. Its address changes, so the future's reference to the data is actually invalidated.
## Before Moving
┌─────────────┐
│0101001010010│
data: │001... │ ◄──────── &future.data
│ │
│ │
└─────────────┘
## After Moving
??? ◄──────── &future.data
┌─────────────┐
│0101001010010│
data: │001... │
│ │
│ │
└─────────────┘
We gave the data a place to live, but we didn't give it a stable place to live. It turns out, this is a well-known problem in Rust. What we're trying to create is called a self-referential struct, and it's not possible to do safely.
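The difference between an unstable and a stable place can be observed directly. In this small sketch (the addr and demo helpers are ours, introduced only for illustration), copying bytes into a Box gives them a new address on the heap, but moving the Box afterwards moves only the pointer, so the heap address stays the same. That is the kind of stable place a heap allocation provides.

```rust
// Returns the address of a value as a plain number so it can be compared.
fn addr<T>(value: &T) -> usize {
    value as *const T as usize
}

fn demo() -> (usize, usize, usize) {
    let data = [7u8; 64];
    let on_stack = addr(&data);

    // Moving the bytes into a Box copies them to the heap: new address.
    let boxed = Box::new(data);
    let on_heap = addr(&*boxed);

    // Moving the Box only moves the pointer; the heap bytes stay put.
    let moved = boxed;
    let after_move = addr(&*moved);

    (on_stack, on_heap, after_move)
}
```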
Back when our entire future state was in the Handler struct, there was no self-referencing going on. Everything just worked off the Handler. But now that we're trying to split our futures up into subtasks, we need a way for them to access the data independently.
So is it not possible?
Well...
We could allocate the data using Rc and clone the pointer into each of the futures. That way the futures get a stable pointer to the data on the heap, and it's only deallocated once every future holding a clone has been dropped.
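Each chained step gets its own Rc handle, and the value is only dropped once every handle is gone. A minimal sketch of that bookkeeping, using a hypothetical Connection stand-in instead of a real TcpStream and Rc::strong_count to observe the reference count:

```rust
use std::rc::Rc;

// Hypothetical stand-in for the connection.
struct Connection {
    id: u32,
}

fn share() -> (usize, usize) {
    let connection = Rc::new(Connection { id: 1 });
    let read_ref = Rc::clone(&connection);
    let write_ref = Rc::clone(&connection);

    // Each "step" owns its own pointer to the same heap allocation.
    let read_step = move || read_ref.id;
    let write_step = move || write_ref.id;
    assert_eq!(read_step() + write_step(), 2);

    let while_shared = Rc::strong_count(&connection);
    // Dropping a closure drops the Rc it captured.
    drop(read_step);
    drop(write_step);
    let after_steps = Rc::strong_count(&connection);
    (while_shared, after_steps)
}
```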
The code is about to get pretty ugly.
fn handle(connection: TcpStream) -> impl Future<Output = ()> {
let connection = Rc::new(connection); // 👈
let read_connection_ref = connection.clone();
let write_connection_ref = connection.clone();
let flush_connection_ref = connection.clone();
poll_fn(move |waker| {
// ...
})
.chain(move |_| {
// ...
poll_fn(move |_| {
let mut connection = &*read_connection_ref;
loop {
match connection.read(&mut request[read..]) {
// ...
}
}
// ...
})
})
.chain(move |_| {
// ...
poll_fn(move |_| {
let mut connection = &*write_connection_ref;
// ...
})
})
.chain(move |_| {
poll_fn(move |_| {
let mut connection = &*flush_connection_ref;
// ...
})
})
}
Oh no...
error[E0277]: `Rc<TcpStream>` cannot be sent between threads safely
--> src/main.rs:90:37
|
90 | SCHEDULER.spawn(handle(connection));
| ----- ^^^^^^^^^^^^^^^^^^ `Rc<TcpStream>` cannot be sent between threads safely
| |
| required by a bound introduced by this call
...
100 | fn handle(mut connection: TcpStream) -> impl Future<Output = ()> {
| ------------------------ within this `impl Future<Output = ()>`
Using Rc in our handler makes it !Send, and SCHEDULER.spawn requires the futures it is given to be Send. Even though the connection is only used internally within the future, and futures are only ever run by the main thread, we need to use an Arc to satisfy the compiler.
fn handle(connection: TcpStream) -> impl Future<Output = ()> {
let connection = Arc::new(connection); // 👈
// ...
}
A little sad, but at least it compiles.
Our server has no more manual state machines and is looking pretty clean.
A lot cleaner than when we started with epoll manually, even with the messy Arc business.
fn main() {
SCHEDULER.spawn(listen());
SCHEDULER.run();
}
fn listen() -> impl Future<Output = ()> {
poll_fn(|waker| {
let listener = TcpListener::bind("localhost:3000").unwrap();
// ...
REACTOR.with(|reactor| {
reactor.add(listener.as_raw_fd(), waker);
});
Some(listener)
})
.chain(|listener| {
poll_fn(move |_| match listener.accept() {
Ok((connection, _)) => {
// ...
SCHEDULER.spawn(handle(connection));
None
}
// ...
})
})
}
fn handle(connection: TcpStream) -> impl Future<Output = ()> {
let connection = Arc::new(connection);
let read_connection_ref = connection.clone();
let write_connection_ref = connection.clone();
let flush_connection_ref = connection.clone();
poll_fn(move |waker| {
REACTOR.with(|reactor| {
reactor.add(connection.as_raw_fd(), waker);
});
Some(())
})
.chain(move |_| {
let mut request = [0u8; 1024];
let mut read = 0;
poll_fn(move |_| {
// ...
})
})
.chain(move |_| {
let response = /* ... */;
let mut written = 0;
poll_fn(move |_| {
// ...
})
})
.chain(move |_| {
poll_fn(move |_| {
// ...
REACTOR.with(|reactor| {
reactor.remove(flush_connection_ref.as_raw_fd());
});
Some(())
})
})
}
And of course...
$ curl localhost:3000
# => Hello world!
It works!
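One detail makes the Arc version work at all: std implements Read and Write for &TcpStream, not only for TcpStream, so a shared reference obtained through the Arc is enough to read, write, and flush (the binding still has to be `mut`, because the trait methods take `&mut self`). The sketch below checks this over a loopback connection. echo_through_arc is a hypothetical helper; it binds to 127.0.0.1:0 to get a free port, and it also shows that an Arc<TcpStream> clone can be moved to another thread because it is Send.

```rust
use std::io::{Read, Write};
use std::net::{Shutdown, TcpListener, TcpStream};
use std::sync::Arc;
use std::thread;

// Sends `msg` through a shared `Arc<TcpStream>` and returns what the
// other end of the connection received.
fn echo_through_arc(msg: &'static [u8]) -> std::io::Result<Vec<u8>> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let addr = listener.local_addr()?;

    let client = Arc::new(TcpStream::connect(addr)?);
    let (mut server, _) = listener.accept()?;

    // `Arc<TcpStream>` is `Send`, so a clone can move to another thread.
    let writer = Arc::clone(&client);
    let handle = thread::spawn(move || -> std::io::Result<()> {
        // `Write` is implemented for `&TcpStream`: a shared reference is enough.
        let mut stream = &*writer;
        stream.write_all(msg)?;
        stream.flush()
    });
    handle.join().unwrap()?;

    // Close our write half so `read_to_end` on the server side sees EOF.
    client.shutdown(Shutdown::Write)?;

    let mut received = Vec::new();
    server.read_to_end(&mut received)?;
    Ok(received)
}
```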
A Graceful Server
Whew, that was a lot.
One last thing before we finish. To put our task model to the test, we can finally implement the graceful shutdown mechanism we discussed earlier.
Imagine we wanted to implement graceful shutdown for our server. When someone presses ctrl+c, instead of killing the program abruptly, we should stop accepting new connections and wait for any active requests to complete. Any requests that take more than 30 seconds to handle are killed as the server exits.
There are a couple of things we have to do to set this up. First, we have to actually detect the signal. On Linux, ctrl+c triggers the SIGINT signal, so we can use the signal_hook crate to wait until the signal is received.
use signal_hook::consts::signal::SIGINT;
use signal_hook::iterator::Signals;
fn ctrl_c() {
let mut signal = Signals::new(&[SIGINT]).unwrap();
let _ctrl_c = signal.forever().next().unwrap();
}
There's a problem though. forever().next() blocks the thread until the signal is received. Now that our server is async, that means calling ctrl_c() on the main thread will block the entire program.
Instead, we need to represent the ctrl+c signal as a future that resolves when it is received. Something like this.
fn ctrl_c() -> impl Future<Output = ()> {
poll_fn(move |waker| {
// ...
})
}
So how do we listen for the signal asynchronously?
We could register a signal handler with epoll, but we could also use this as an opportunity to learn about handling blocking tasks in an async program. There will be times when the only way to get what you want is through a blocking API, but you can't simply call it on the main thread. Instead, you can run the blocking work on a separate thread, and notify the main thread when it completes.
fn spawn_blocking(blocking_work: impl FnOnce() + Send + 'static) -> impl Future<Output = ()> {
// run the blocking work on a separate thread
std::thread::spawn(move || {
blocking_work();
});
poll_fn(|waker| {
// ???
})
}
The question is, how do we know when the work is done?
The blocking work is run on a separate thread, outside of the future. It needs access to the waker so it can notify the future when it completes. We only get access to the waker when the future is first polled, so the stored waker needs to start out as None.
We also need a flag that tells the future that the work has completed, in case the work completes before the future is even polled.
These two pieces of state can be stored inside a Mutex.
let state: Arc<Mutex<(bool, Option<Waker>)>> = Arc::default();
Once the thread completes the work, it must set the flag to true and call wake if a waker has been stored. It's fine if a waker hasn't been stored yet; the future will see the flag when it's first polled and return immediately.
fn spawn_blocking(blocking_work: impl FnOnce() + Send + 'static) -> impl Future<Output = ()> {
let state: Arc<Mutex<(bool, Option<Waker>)>> = Arc::default();
let state_handle = state.clone();
// run the blocking work on a separate thread
std::thread::spawn(move || {
// run the work
blocking_work();
// mark the task as done
let (done, waker) = &mut *state_handle.lock().unwrap();
*done = true;
// wake the waker
if let Some(waker) = waker.take() {
waker.wake();
}
});
poll_fn(|waker| {
// ...
})
}
Now the future needs to access the state and check if the work has completed yet. If not, it stores its waker and returns None, to be woken later when the work does complete.
fn spawn_blocking(blocking_work: impl FnOnce() + Send + 'static) -> impl Future<Output = ()> {
let state: Arc<Mutex<(bool, Option<Waker>)>> = Arc::default();
let state_handle = state.clone();
// run the blocking work on a separate thread
std::thread::spawn(move || {
// ...
});
poll_fn(move |waker| match &mut *state.lock().unwrap() {
// work is not completed, store our waker and come back later
(false, state) => {
*state = Some(waker);
None
}
// the work is completed
(true, _) => Some(()),
})
}
The future returned by spawn_blocking serves as an asynchronous version of JoinHandle. We can wait asynchronously on the main thread while the blocking work is run on a separate thread.
fn ctrl_c() -> impl Future<Output = ()> {
spawn_blocking(|| {
let mut signal = Signals::new(&[SIGINT]).unwrap();
let _ctrl_c = signal.forever().next().unwrap();
})
}
spawn_blocking is an extremely convenient abstraction for dealing with blocking APIs in an async program.
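The sketch below runs spawn_blocking end to end. It assumes toy Waker and Future types shaped like the ones in this post, generalizes the work to return a value (so the boolean flag becomes an Option<T> holding the result), and adds a hypothetical single-future block_on whose waker unparks the polling thread, standing in for the real scheduler.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Toy stand-ins (assumed) for this post's Waker and Future.
#[derive(Clone)]
struct Waker(Arc<dyn Fn() + Send + Sync>);

impl Waker {
    fn wake(&self) {
        (self.0)()
    }
}

trait Future {
    type Output;
    fn poll(&mut self, waker: Waker) -> Option<Self::Output>;
}

struct PollFn<F>(F);

impl<T, F: FnMut(Waker) -> Option<T>> Future for PollFn<F> {
    type Output = T;
    fn poll(&mut self, waker: Waker) -> Option<T> {
        (self.0)(waker)
    }
}

fn spawn_blocking<T: Send + 'static>(
    work: impl FnOnce() -> T + Send + 'static,
) -> impl Future<Output = T> {
    let state: Arc<Mutex<(Option<T>, Option<Waker>)>> = Arc::new(Mutex::new((None, None)));
    let handle = state.clone();
    thread::spawn(move || {
        let output = work();
        // store the result, then wake the future if it is already waiting
        let (result, waker) = &mut *handle.lock().unwrap();
        *result = Some(output);
        if let Some(waker) = waker.take() {
            waker.wake();
        }
    });
    PollFn(move |waker: Waker| {
        let (result, slot) = &mut *state.lock().unwrap();
        match result.take() {
            Some(output) => Some(output),
            // not done yet: leave our waker for the worker thread
            None => {
                *slot = Some(waker);
                None
            }
        }
    })
}

// Minimal single-future executor: the waker unparks the polling thread.
fn block_on<F: Future>(mut future: F) -> F::Output {
    let main = thread::current();
    let waker = Waker(Arc::new(move || main.unpark()));
    loop {
        if let Some(output) = future.poll(waker.clone()) {
            return output;
        }
        // an unpark that happened before this call makes park return at once
        thread::park();
    }
}
```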
Alright, we now have a future that waits for the ctrl+c signal!
If you remember from back when our server used blocking I/O, we wondered how to watch for the signal in a way that aborts the connection listener loop immediately after the signal arrives. We realized we needed some way to listen for both incoming connections, and the ctrl+c signal, at the same time.
Because accept was blocking, it wasn't that simple. But with futures, it's actually possible!
We can implement this as another future wrapper. Given two futures, we should be able to create a wrapper future that selects between the two futures' outputs, depending on which future completes first.
fn select<L, R>(left: L, right: R) -> Select<L, R> {
Select { left, right }
}
struct Select<L, R> {
left: L,
right: R
}
enum Either<L, R> {
Left(L),
Right(R)
}
impl<L, R> Future for Select<L, R>
where
L: Future,
R: Future,
{
type Output = Either<L::Output, R::Output>;
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
// ...
}
}
It turns out that the implementation of the select future is really simple. We just attempt to poll both futures and return when the first one resolves.
impl<L, R> Future for Select<L, R>
where
L: Future,
R: Future,
{
type Output = Either<L::Output, R::Output>;
fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
if let Some(output) = self.left.poll(waker.clone()) {
return Some(Either::Left(output));
}
if let Some(output) = self.right.poll(waker) {
return Some(Either::Right(output));
}
None
}
}
Because we pass the same waker to both futures, any progress in either future will notify us, and we can check if either of them completed.
It really is that simple.
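A runnable version of Select under the same assumptions (toy Waker and Future traits shaped like this post's), with a hypothetical ReadyAfter future that resolves on its n-th poll:

```rust
use std::sync::Arc;

// Toy stand-ins (assumed) for this post's Waker and Future.
#[derive(Clone)]
struct Waker(Arc<dyn Fn() + Send + Sync>);

trait Future {
    type Output;
    fn poll(&mut self, waker: Waker) -> Option<Self::Output>;
}

#[derive(Debug, PartialEq)]
enum Either<L, R> {
    Left(L),
    Right(R),
}

struct Select<L, R> {
    left: L,
    right: R,
}

fn select<L, R>(left: L, right: R) -> Select<L, R> {
    Select { left, right }
}

impl<L: Future, R: Future> Future for Select<L, R> {
    type Output = Either<L::Output, R::Output>;
    fn poll(&mut self, waker: Waker) -> Option<Self::Output> {
        // Both sides get the same waker, so progress on either wakes us.
        if let Some(output) = self.left.poll(waker.clone()) {
            return Some(Either::Left(output));
        }
        self.right.poll(waker).map(Either::Right)
    }
}

// Hypothetical test future: ready with `value` on its `ready_on`-th poll.
struct ReadyAfter {
    polls: u32,
    ready_on: u32,
    value: &'static str,
}

impl Future for ReadyAfter {
    type Output = &'static str;
    fn poll(&mut self, _waker: Waker) -> Option<&'static str> {
        self.polls += 1;
        (self.polls >= self.ready_on).then_some(self.value)
    }
}
```

Because the left future is always polled first, this select is biased: if both sides become ready on the same poll, Left wins.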
Now back to our main program.
fn listen() -> impl Future<Output = ()> {
poll_fn(|waker| {
let listener = TcpListener::bind("localhost:3000").unwrap();
// ...
})
.chain(|listener| {
poll_fn(move |_| match listener.accept() {
Ok((connection, _)) => {
// ...
SCHEDULER.spawn(handle(connection));
None
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
Err(e) => panic!("{e}"),
})
})
}
We can combine the TCP listener task with the ctrl+c listener using our new select combinator. This way we can listen for both at the same time:
fn listen() -> impl Future<Output = ()> {
poll_fn(|waker| {
// ...
})
.chain(|listener| {
let listen = poll_fn(move |_| match listener.accept() {
Ok((connection, _)) => {
// ...
SCHEDULER.spawn(handle(connection));
None
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
Err(e) => panic!("{e}"),
});
select(listen, ctrl_c())
})
}
The TCP listener task never resolves; remember, it represents the loop from our synchronous server. ctrl_c() can resolve though, so we need to chain on another task to handle the shutdown signal.
fn listen() -> impl Future<Output = ()> {
poll_fn(|waker| {
// ...
})
.chain(|listener| {
let listen = /* ... */;
select(listen, ctrl_c())
})
.chain(|_ctrl_c| graceful_shutdown())
}
fn graceful_shutdown() -> impl Future<Output = ()> {
// ...
}
Now we need to implement the rest of our shutdown logic. Once the shutdown signal is received, we wait at most thirty seconds for any active requests to complete before shutting down.
This sounds like another use case for select: either thirty seconds elapse, or all active requests complete.
fn graceful_shutdown() -> impl Future<Output = ()> {
let timer = /* ... */;
let request_counter = /* ... */;
select(timer, request_counter).chain(|_| {
poll_fn(|waker| {
// graceful shutdown process complete, now we actually exit
println!("Graceful shutdown complete");
std::process::exit(0)
})
})
}
All we need to do now is create the two futures for our shutdown conditions.
First we need a timer. Of course, we can't simply call thread::sleep, because it's a blocking function. But we could run it through spawn_blocking and use the future it returns as our timer.
Note that there are ways to build async timers around epoll, but that's out of scope for this article.
use std::thread;
use std::time::Duration;
fn graceful_shutdown() -> impl Future<Output = ()> {
let timer = spawn_blocking(|| thread::sleep(Duration::from_secs(30)));
let request_counter = /* ... */;
select(timer, request_counter).chain(|_| {
poll_fn(|waker| {
// graceful shutdown process complete, now we actually exit
println!("Graceful shutdown complete");
std::process::exit(0)
})
})
}
That was simple enough.
Now for the main shutdown condition. For us to know when all active requests are completed, we'll need a counter for active requests.
We can keep the counter local to our listen future, and increment or decrement it whenever tasks are spawned or complete.
fn listen() -> impl Future<Output = ()> {
let tasks = Arc::new(Mutex::new(0));
poll_fn(|waker| {
// ...
})
.chain(move |listener| {
let listen = poll_fn(move |_| match listener.accept() {
Ok((connection, _)) => {
// increment the counter
*tasks.lock().unwrap() += 1; // 👈
// give the spawned task its own handle to the counter
let tasks = tasks.clone();
let handle_connection = handle(connection).chain(move |_| {
poll_fn(move |_| {
// decrement the counter
*tasks.lock().unwrap() -= 1; // 👈
Some(())
})
});
SCHEDULER.spawn(handle_connection);
None
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
Err(e) => panic!("{e}"),
});
select(listen, ctrl_c())
})
.chain(|_ctrl_c| graceful_shutdown())
}
Notice how the decrement of the counter is chained onto the connection handler, so the decrement actually happens from within the spawned task, after it completes.
A task counter is great, but we need a little more than that. We can't just check for tasks == 0 in a loop; we need the shutdown handler to be notified when the last task completes.
And for that, the connection handler that completes needs access to the shutdown handler's waker.
What we need is similar to the spawn_blocking solution we created earlier, except instead of a boolean flag, we need a counter. We can wrap up all this state into a small struct.
#[derive(Default)]
struct Counter {
state: Mutex<(usize, Option<Waker>)>,
}
impl Counter {
fn increment(&self) {
let (count, _) = &mut *self.state.lock().unwrap();
*count += 1;
}
fn decrement(&self) {
let (count, waker) = &mut *self.state.lock().unwrap();
*count -= 1;
// we were the last task
if *count == 0 {
// wake the waiting task
if let Some(waker) = waker.take() {
waker.wake();
}
}
}
fn wait_for_zero(self: Arc<Self>) -> impl Future<Output = ()> {
poll_fn(move |waker| {
match &mut *self.state.lock().unwrap() {
// work is completed
(0, _) => Some(()),
// work is not completed, store our waker and come back later
(_, state) => {
*state = Some(waker);
None
}
}
})
}
}
When the future returned by wait_for_zero is first polled and the count isn't zero yet, it stores its waker in the counter state before returning None. Now the task that calls decrement and sees that it was the last active task can simply call wake, notifying the caller of wait_for_zero.
When the shutdown handler is woken, it will see that the counter is at zero and shut down the program.
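To check the wake-on-zero behavior in isolation, here is the Counter logic with a toy closure-based Waker (assumed, like this post's) and the body of wait_for_zero's poll closure pulled out into a plain poll_zero method, a name invented for this sketch. The waker used in the usage check only counts how often it is called.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

// Toy stand-in (assumed) for this post's Waker.
#[derive(Clone)]
struct Waker(Arc<dyn Fn() + Send + Sync>);

impl Waker {
    fn wake(&self) {
        (self.0)()
    }
}

#[derive(Default)]
struct Counter {
    state: Mutex<(usize, Option<Waker>)>,
}

impl Counter {
    fn increment(&self) {
        self.state.lock().unwrap().0 += 1;
    }

    fn decrement(&self) {
        let (count, waker) = &mut *self.state.lock().unwrap();
        *count -= 1;
        // the last task out wakes whoever is waiting
        if *count == 0 {
            if let Some(waker) = waker.take() {
                waker.wake();
            }
        }
    }

    // The body of wait_for_zero's poll closure, as a plain method.
    fn poll_zero(&self, waker: Waker) -> Option<()> {
        let (count, slot) = &mut *self.state.lock().unwrap();
        if *count == 0 {
            Some(())
        } else {
            *slot = Some(waker);
            None
        }
    }
}
```

Note that decrement calls wake while still holding the lock. That's fine as long as waking only schedules the task rather than polling it inline; a waker that polled synchronously would try to take the same lock again.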
Now we can replace our manual counter with the Counter object.
fn listen() -> impl Future<Output = ()> {
let tasks = Arc::new(Counter::default()); // 👈
let tasks_ref = tasks.clone();
poll_fn(|waker| {
// ...
})
.chain(move |listener| {
let listen = poll_fn(move |_| match listener.accept() {
Ok((connection, _)) => {
connection.set_nonblocking(true).unwrap();
// increment the counter
tasks.increment(); // 👈
let tasks = tasks.clone();
let handle_connection = handle(connection).chain(|_| {
poll_fn(move |_| {
// decrement the counter
tasks.decrement(); // 👈
Some(())
})
});
SCHEDULER.spawn(handle_connection);
None
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => None::<()>,
Err(e) => panic!("{e}"),
});
select(listen, ctrl_c())
})
.chain(|_ctrl_c| graceful_shutdown(tasks_ref)) // 👈
}
And our graceful shutdown handler can use wait_for_zero to wait until all the active tasks are complete. Once they are, or the timer elapses, the graceful shutdown condition is met and the program will exit.
fn graceful_shutdown(tasks: Arc<Counter>) -> impl Future<Output = ()> {
let timer = spawn_blocking(|| thread::sleep(Duration::from_secs(30)));
let request_counter = tasks.wait_for_zero(); // 👈
select(timer, request_counter).chain(|_| {
poll_fn(|_| {
// graceful shutdown process complete, now we actually exit
println!("Graceful shutdown complete");
std::process::exit(0)
})
})
}
And that's it!
Now if you start the server and hit ctrl+c, it will exit immediately, without blocking for another connection.
$ cargo run
^C
# => Graceful shutdown complete
$ |
Looking Back
Well, that was quite the journey.
Our server is looking pretty good now. From threads, to an epoll event loop, to futures and closure combinators, we've come a long way. There is some manual work that we could abstract over even further, but overall our program is relatively clean.
Compared to our original multithreaded program, our code is still clearly more complex. However, it's also a lot more powerful. Composing futures is trivial, and we were able to express complex control flow that would have been very difficult to do with threads. We can even still call out to blocking functions without interrupting our async runtime.
There must be a price to pay for all this power, right?
Back To Reality
Now that we've thoroughly explored concurrency and async ourselves, let's see how it works in the real world.
The standard library defines a Future trait, which looks remarkably similar to the trait we designed.
pub trait Future {
type Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>;
}
However, there are a few noticeable differences.
The first is that instead of a Waker argument, poll takes a &mut Context. It turns out this isn't much of a difference at all, because Context is essentially a wrapper around a Waker.
impl<'a> Context<'a> {
pub fn from_waker(waker: &'a Waker) -> Context<'a> { /* ... */ }
pub fn waker(&self) -> &'a Waker { /* ... */ }
}
And Waker, along with a few other utility methods, has the familiar wake method.
impl Waker {
pub fn wake(self) { /* ... */ }
}
Constructing a Waker is a little more complicated, but it's essentially a manual trait object just like the Arc<dyn Fn()> we used for our version. All of that happens through the RawWaker type, which you can check out yourself.
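For the common case, std also provides the Wake trait, which builds a Waker from an Arc without touching RawWaker directly, much like our Arc<dyn Fn()>. A sketch with a hypothetical CountingWaker that records wake-ups, plus a poll_once helper (also ours) that wraps the waker in a Context:

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Records how many times it has been woken.
struct CountingWaker {
    wakes: AtomicUsize,
}

impl Wake for CountingWaker {
    fn wake(self: Arc<Self>) {
        self.wakes.fetch_add(1, Ordering::SeqCst);
    }
}

// Wraps the waker in a `Context` and polls `future` once.
fn poll_once<F: Future>(future: F, counter: Arc<CountingWaker>) -> Poll<F::Output> {
    let waker = Waker::from(counter);
    let mut cx = Context::from_waker(&waker);
    let future = pin!(future);
    future.poll(&mut cx)
}
```

A future that calls cx.waker().wake_by_ref() before returning Pending ends up invoking CountingWaker::wake, just as our reactor invoked the stored closure.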
The second difference is that instead of returning an Option, poll returns a new type called Poll... which is really just a rebranding of Option.
pub enum Poll<T> {
Ready(T),
Pending,
}
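Because Poll and Option carry the same information, translating between our toy poll's Option and std's Poll is mechanical. Two hypothetical helpers make the correspondence explicit:

```rust
use std::task::Poll;

// Our toy `poll` returned `Option<T>`; std's returns `Poll<T>`.
fn option_to_poll<T>(value: Option<T>) -> Poll<T> {
    match value {
        Some(output) => Poll::Ready(output),
        None => Poll::Pending,
    }
}

fn poll_to_option<T>(poll: Poll<T>) -> Option<T> {
    match poll {
        Poll::Ready(output) => Some(output),
        Poll::Pending => None,
    }
}
```

Poll also comes with its own small helpers, such as is_ready and map, so in practice you rarely need to convert.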
The final difference is a little more complicated.
Pinning
Instead of poll taking a mutable reference to self, it takes a pinned mutable reference to self: Pin<&mut Self>.
What is Pin, you ask?
#[derive(Copy, Clone)]
pub struct Pin<P> {
pointer: P,
}
Huh. That doesn't seem very useful.
It turns out, what makes Pin special is how you create one:
impl<P: Deref> Pin<P> {
pub fn new(pointer: P) -> Pin<P> where P::Target: Unpin { /* ... */ }
pub unsafe fn new_unchecked(pointer: P) -> Pin<P> { /* ... */ }
}
impl<P: Deref> Deref for Pin<P> {
type Target = P::Target;
fn deref(&self) -> &P::Target { /* ... */ }
}
impl<P: DerefMut> DerefMut for Pin<P>
where
P::Target: Unpin
{
fn deref_mut(&mut self) -> &mut P::Target { /* ... */ }
}
So you can only create a Pin<&mut T> safely if T is Unpin... what's Unpin?
pub auto trait Unpin {}
/// A marker type which does not implement `Unpin`.
pub struct PhantomPinned;
impl !Unpin for PhantomPinned {}
Unpin seems to be automatically implemented for all types except PhantomPinned. So creating a Pin is safe, except for types that contain PhantomPinned? And Pin just dereferences to T normally? All of this seems a little useless.
There is a point to it all though, and it goes back to a problem we ran into earlier. Remember when we tried creating a self-referential struct to hold our task state but it wouldn't work, so we ended up having to allocate our task state with an Arc? It was a bit unfortunate, and it turns out that you actually can create self-referential structs with a little bit of unsafe code, and avoid that Arc allocation.
The problem is that you can't just go handing out a self-referential struct in general, because as we realized, moving a self-referential struct breaks its internal references and is unsound.
struct SelfReferential {
counter: u8, // (X)
state: FutureState
}
enum FutureState {
First { counter_ptr: *mut u8 } // self-referentially points to `counter` (X)
// ...
}
let mut future1 = SelfReferential::new();
future1.poll(/* ... */);
let mut moved = future1; // move it
// unsound! `counter_ptr` still points to the old stack location of `counter`
moved.poll(/* ... */);
This is where Pin comes in. You can only create a Pin<&mut T> if you guarantee that the T will stay in a stable location until it is dropped, meaning that any self-references will remain valid.
For most types, Pin doesn't mean anything, which is why Unpin exists. Unpin essentially tells Pin that a type is not self-referential, so pinning it is completely safe and always valid. Pin will even hand out mutable references to Unpin types and let you use mem::swap or mem::replace to move them around. Because you can't safely create a self-referential struct, Unpin is the default and implemented by types automatically.
If you did want to create a self-referential future, though, you can use the PhantomPinned marker struct to make it !Unpin. Pinning a !Unpin type directly with Pin::new_unchecked requires unsafe, and because poll requires Pin<&mut Self>, you can't call it on a self-referential future through a plain &mut; you have to pin the future first and promise not to move it.
let mut future = SelfReferential::new();
// SAFETY: we never move `future`
let pinned = unsafe { Pin::new_unchecked(&mut future) };
pinned.poll(/* ... */);
Notice that you can move the future around all you want before pinning it, because the self-references are only created after you first call poll. Once you do pin it though, you must uphold the Pin safety contract.
There are a couple of safe ways to create a pin, though, even for !Unpin types.
The first way is with Box::pin.
let mut future1: Pin<Box<SelfReferential>> = Box::pin(SelfReferential::new());
future1.as_mut().poll(/* ... */);
let mut moved = future1;
moved.as_mut().poll(/* ... */);
At first glance this may seem unsound, but remember, Box is an allocation. Once the future is allocated, it has a stable location on the heap, so you can move the Box pointer around all you want; the internal references will remain stable.
The second way you can safely create a pin is with the pin! macro.
use std::pin::pin;
let mut future1: Pin<&mut SelfReferential> = pin!(SelfReferential::new());
future1.as_mut().poll(/* ... */);
With pin!, you can safely pin a struct without even allocating it on the heap! The trick is that pin! takes ownership of the future, making it impossible to access except through the Pin<&mut T>, which, remember, will never give you a mutable reference if T isn't Unpin. The T is completely hidden and thus safe from being tampered with.
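Here is a sketch of pin! in practice. It assumes a hypothetical poll_once helper and a do-nothing NoopWaker built with std's Wake trait. The helper pins a future on the stack and polls it once. Async blocks are !Unpin, yet neither unsafe nor a Box is needed.

```rust
use std::future::{self, Future};
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// A waker that does nothing when woken; enough for polling a future once.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

// Polls `fut` exactly once, pinning it on the stack with `pin!`.
fn poll_once<F: Future>(fut: F) -> Poll<F::Output> {
    // `fut` is moved into the macro; from here on it's only reachable
    // through the `Pin<&mut F>`
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    fut.as_mut().poll(&mut cx)
}

fn main() {
    // async blocks are `!Unpin`, but `pin!` lets us poll them safely
    println!("{:?}", poll_once(async { 1 + 1 })); // Ready(2)
    println!("{:?}", poll_once(future::pending::<()>())); // Pending
}
```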
Pin is a common point of confusion around futures, but once you understand why it exists, the solution is pretty ingenious.
async/await
Alright, that's the standard Future trait.
pub trait Future {
type Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>;
}
So how do we use it?
The futures crate is where all the useful helpers live: functions like poll_fn, which we wrote before, and combinators like map and then, which is what we called chain.
use std::task::Poll;
use futures::{FutureExt, future::poll_fn};
let future = poll_fn(|_| Poll::Ready(1))
.then(|x| poll_fn(move |_| Poll::Ready(x + 1)))
.then(|x| poll_fn(move |_| {
println!("{x}");
Poll::Ready(())
}));
But even with these helpers, as we found out, writing async code is still a little cumbersome. It's a shift from the simple synchronous code we're used to: maybe not as drastic as a manual epoll event loop, but still a big change.
It turns out there's actually another way to write futures in Rust, with the async/await syntax.
Instead of using poll_fn to create futures, you can attach the async keyword to functions:
async fn foo() -> usize {
1
}
An async function is really just a function that returns an async block:
fn foo() -> impl Future<Output = usize> {
async { 1 }
}
Which is really just a function that returns a poll_fn future:
fn foo() -> impl Future<Output = usize> {
poll_fn(|_| Poll::Ready(1))
}
The magic comes with the await keyword. await waits for the completion of another future, propagating Poll::Pending until the future is resolved.
async fn foo() {
let one = one().await;
let two = two().await;
assert_eq!(one + 1, two);
}
async fn two() -> usize {
one().await + 1
}
async fn one() -> usize {
1
}
Under the hood, the compiler transforms this into manual state machines, similar to the ones we created with those combinators:
fn foo() -> impl Future<Output = ()> {
one()
.then(|one| two().then(move |two| poll_fn(move |_| Poll::Ready((one, two)))))
.then(|(one, two)| poll_fn(move |_| Poll::Ready(assert_eq!(one + 1, two))))
}
fn two() -> impl Future<Output = usize> {
one().then(|one| poll_fn(move |_| Poll::Ready(one + 1)))
}
fn one() -> impl Future<Output = usize> {
poll_fn(|_| Poll::Ready(1))
}
Which, as we know all too well, translates into a huge manual state machine that looks something like this:
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
enum FooFuture {
One(OneFuture),
Two(usize, TwoFuture),
}
impl Future for FooFuture {
type Output = ();
// none of these futures are self-referential, so they're all Unpin and we
// can safely get a plain `&mut Self` out of the `Pin`
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
if let FooFuture::One(f) = this {
match Pin::new(f).poll(cx) {
Poll::Ready(one) => *this = FooFuture::Two(one, TwoFuture(OneFuture)),
Poll::Pending => return Poll::Pending,
}
}
if let FooFuture::Two(one, f) = this {
match Pin::new(f).poll(cx) {
Poll::Ready(two) => {
assert_eq!(*one + 1, two);
return Poll::Ready(());
}
Poll::Pending => return Poll::Pending,
}
}
// every path above returns once we reach the `Two` state
unreachable!()
}
}
struct TwoFuture(OneFuture);
impl Future for TwoFuture {
type Output = usize;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.get_mut().0).poll(cx) {
Poll::Ready(one) => Poll::Ready(one + 1),
Poll::Pending => Poll::Pending,
}
}
}
struct OneFuture;
impl Future for OneFuture {
type Output = usize;
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Ready(1)
}
}
... but async/await removes all of that headache. Of course, we aren't actually doing any I/O here, so the futures are mostly useless, but you can imagine how helpful this would be for our web server.
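This sketch shows await propagating Poll::Pending through a future that actually yields. YieldOnce is a hypothetical future that returns Pending once before completing. one and two mirror the functions above. The manual loop, which uses a do-nothing waker, counts how many polls the outer future needs.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// A hypothetical future that is pending on its first poll and ready on
// the next, standing in for real I/O.
struct YieldOnce {
    yielded: bool,
}

impl Future for YieldOnce {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if self.yielded {
            return Poll::Ready(());
        }
        self.yielded = true;
        // ask to be polled again
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}

async fn one() -> usize {
    YieldOnce { yielded: false }.await; // a Pending here propagates outward
    1
}

async fn two() -> usize {
    one().await + 1
}

// A waker that ignores wakeups; the loop below simply polls again.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

// Polls `two()` until it completes; returns (output, number of polls).
fn poll_two_to_completion() -> (usize, usize) {
    let mut fut = pin!(two());
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut polls = 0;
    loop {
        polls += 1;
        if let Poll::Ready(output) = fut.as_mut().poll(&mut cx) {
            return (output, polls);
        }
    }
}

fn main() {
    println!("{:?}", poll_two_to_completion()); // (2, 2)
}
```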
In fact, it's even better than the combinators. With async functions, you can hold onto local state across await points!
async fn foo() {
let x = vec![1, 2, 3];
bar(&x).await;
baz(&x).await;
println!("{x:?}");
}
After implementing futures ourselves, we can really appreciate the convenience of this. Under the hood, the compiler has to generate a self-referential future to give bar and baz access to the state.
struct FooFuture {
x: Vec<i32>, // (X)
state: FooFutureState,
}
enum FooFutureState {
Bar(BarFuture),
Baz(BazFuture),
}
struct BarFuture { x: *const Vec<i32> /* pointer to (X)! */ }
struct BazFuture { x: *const Vec<i32> /* pointer to (X)! */ }
The compiler takes care of all the unsafe code involved in this, allowing us to work with local state just like we would in a regular function. For this reason, the futures generated by async blocks or functions are !Unpin.
async/await removes any complexity that remained with writing futures compared to synchronous code. After implementing futures manually, it almost feels like magic!
A Tokio Server
So far we've only been looking at how Future works; we haven't discussed how to actually run one, or how to do any I/O. The thing is, the standard library doesn't provide any of that: it only provides the bare essential types and traits to get started.
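std gives us Future, Context, Waker and the Wake trait, but no executor. Here is a minimal sketch of one. block_on and ThreadWaker are hypothetical names, not tokio's API. It runs a single future on the current thread and parks the thread while the future is pending. It has no reactor, so it can only drive futures that something else wakes.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Wakes the thread blocked inside `block_on` by unparking it.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

// Runs one future to completion on the current thread.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(output) => return output,
            // sleep until a waker calls `unpark`; `park` may also return
            // spuriously, which is harmless since we just poll again
            Poll::Pending => thread::park(),
        }
    }
}

fn main() {
    println!("{}", block_on(async { 40 + 2 })); // 42
}
```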
If you want to actually write an async application, you have to use an external runtime. The most popular general-purpose runtime is tokio. tokio provides a task scheduler, a reactor, and a pool to run blocking tasks, just like we wrote earlier, but it also provides timers, async channels, and various other useful types and utilities for async code. On top of that, tokio is multi-threaded by default, distributing async tasks across threads to take advantage of all your CPU cores. The core ideas behind tokio are very similar to the async runtime we wrote ourselves, but you can read more about its design in this excellent blog post.
It's time to write our final web server, this time using the standard Future trait and tokio.
Tokio applications begin with the #[tokio::main] macro. Under the hood, this macro spins up the runtime and runs the async code in main.
#[tokio::main]
async fn main() {
// ...
}
Tokio does its best to mirror the standard library for most of its types. For example, tokio::net::TcpListener works just like std::net::TcpListener, except with async methods. Any interactions with epoll and the reactor are hidden under the hood.
use std::io;
use tokio::net::{TcpListener, TcpStream};
#[tokio::main]
async fn main() {
let listener = TcpListener::bind("localhost:3000").await.unwrap();
loop {
let (connection, _) = listener.accept().await.unwrap();
if let Err(e) = handle_connection(connection).await {
println!("failed to handle connection: {e}")
}
}
}
async fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
// ...
}
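That isn't quite right, though: awaiting handle_connection inside the accept loop means we only serve one connection at a time, just like our very first server. Instead, we need to spawn the connection handler as its own task. We can do so with the tokio::spawn function, which takes a future and runs it concurrently with the rest of the program.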
#[tokio::main]
async fn main() {
let listener = TcpListener::bind("localhost:3000").await.unwrap();
loop {
let (connection, _) = listener.accept().await.unwrap();
tokio::spawn(async move { // 👈
if let Err(e) = handle_connection(connection).await {
println!("failed to handle connection: {e}")
}
});
}
}
Now for the connection handler. With the AsyncReadExt trait and the await keyword, we can read from the TCP stream almost exactly like we did before.
use tokio::io::AsyncReadExt;
async fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
let mut read = 0;
let mut request = [0u8; 1024];
loop {
// try reading from the stream
let num_bytes = connection.read(&mut request[read..]).await?; // 👈
// the client disconnected
if num_bytes == 0 {
println!("client disconnected unexpectedly");
return Ok(());
}
// keep track of how many bytes we've read
read += num_bytes;
// have we reached the end of the request?
if request[..read].ends_with(b"\r\n\r\n") {
break;
}
}
let request = String::from_utf8_lossy(&request[..read]);
println!("{request}");
// ...
}
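The read loop itself doesn't depend on tokio. Here is a std-only sketch of the same logic against any io::Read. read_request and the Trickle test reader are hypothetical names. The reader returns one byte per call to mimic a request arriving in small pieces. Checking request[..read].ends_with(...) stays correct even when fewer than four bytes have arrived.

```rust
use std::io::{self, Read};

// A hypothetical test reader that returns at most one byte per `read`
// call, standing in for a TCP stream that delivers a request in pieces.
struct Trickle<'a>(&'a [u8]);

impl<'a> Read for Trickle<'a> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let data: &'a [u8] = self.0;
        match data.split_first() {
            Some((&byte, rest)) if !buf.is_empty() => {
                buf[0] = byte;
                self.0 = rest;
                Ok(1)
            }
            // no more input (or no room in `buf`): report 0 bytes read
            _ => Ok(0),
        }
    }
}

// Reads until the blank line that ends the request headers, returning
// `None` if the stream ends first (the "disconnected" case).
fn read_request(mut stream: impl Read) -> io::Result<Option<String>> {
    let mut read = 0;
    let mut request = [0u8; 1024];
    loop {
        let num_bytes = stream.read(&mut request[read..])?;
        if num_bytes == 0 {
            return Ok(None);
        }
        read += num_bytes;
        // `ends_with` is simply false while fewer than 4 bytes have
        // arrived, so there's no `read - 4` to underflow
        if request[..read].ends_with(b"\r\n\r\n") {
            break;
        }
    }
    Ok(Some(String::from_utf8_lossy(&request[..read]).into_owned()))
}

fn main() -> io::Result<()> {
    let request = read_request(Trickle(b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"))?;
    println!("{request:?}");
    Ok(())
}
```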
Writing the response works the same as well.
use tokio::io::AsyncWriteExt;
async fn handle_connection(mut connection: TcpStream) -> io::Result<()> {
// ...
// "Hello World!" in HTTP
let response = /* ... */;
let mut written = 0;
loop {
// write the remaining response bytes
let num_bytes = connection.write(response[written..].as_bytes()).await?; // 👈
// the client disconnected
if num_bytes == 0 {
println!("client disconnected unexpectedly");
return Ok(());
}
written += num_bytes;
// have we written the whole response yet?
if written == response.len() {
break;
}
}
connection.flush().await
}
Well that was easy.
You may notice that our program is exactly the same as our original server, except for a couple of uses of the async and await keywords. With async/await, we really can have our cake and eat it too.
Now to implement graceful shutdown.
The first step is to detect the ctrl+c signal. With tokio, this is as simple as using the tokio::signal::ctrl_c function, an async function that completes once the ctrl+c signal is received. We can also use tokio's select! macro, a more powerful version of the select combinator we implemented earlier.
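use std::sync::{Arc, atomic::AtomicUsize};
use tokio::net::TcpListener;
use tokio::select;
use tokio::signal::ctrl_c;
use tokio::sync::Notify;
#[tokio::main]
async fn main() {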
let listener = TcpListener::bind("localhost:3000").await.unwrap();
let state = Arc::new((AtomicUsize::new(0), Notify::new()));
loop {
select! {
// new incoming connection
result = listener.accept() => {
let (connection, _) = result.unwrap();
tokio::spawn(async move {
// ..
});
}
// ctrl+c signal
shutdown = ctrl_c() => {
let timer = /* ... */;
let request_counter = /* .. */;
select! {
_ = timer => {}
_ = request_counter => {}
}
println!("Gracefully shutting down.");
return;
}
}
}
}
select! runs the branch for whichever future completes first and cancels (drops) the futures of the other branches. This lets us select between incoming connections and the ctrl+c signal and run the appropriate code for each.
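Here is a rough std-only sketch of a two-way select, for intuition about what select! does. select, Either and poll_once are hypothetical names, not tokio's API. Whichever future finishes first wins, and returning from the async fn drops the other one. Dropping is all that cancellation means for a Rust future. This sketch always polls the first future first, whereas tokio's select! picks a random branch to poll first by default.

```rust
use std::future::{poll_fn, Future};
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

#[derive(Debug, PartialEq)]
enum Either<A, B> {
    Left(A),
    Right(B),
}

// Resolves with the output of whichever future finishes first. Returning
// drops the other, unfinished future, which cancels it.
async fn select<A: Future, B: Future>(a: A, b: B) -> Either<A::Output, B::Output> {
    let mut a = pin!(a);
    let mut b = pin!(b);
    poll_fn(|cx| {
        // biased: `a` is always polled first
        if let Poll::Ready(x) = a.as_mut().poll(cx) {
            return Poll::Ready(Either::Left(x));
        }
        if let Poll::Ready(y) = b.as_mut().poll(cx) {
            return Poll::Ready(Either::Right(y));
        }
        // both pending; a well-behaved future arranges to wake `cx`'s waker
        Poll::Pending
    })
    .await
}

// A waker that ignores wakeups, so we can poll once without an executor.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

fn poll_once<F: Future>(fut: F) -> Poll<F::Output> {
    let fut = pin!(fut);
    let waker = Waker::from(Arc::new(NoopWaker));
    fut.poll(&mut Context::from_waker(&waker))
}

fn main() {
    // like the shutdown branch: the "accept" future never finishes here
    let winner = poll_once(select(std::future::pending::<()>(), async { "ctrl+c" }));
    println!("{winner:?}"); // Ready(Right("ctrl+c"))
}
```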
Now we need to create the graceful shutdown condition.
For the timer, we can use tokio's asynchronous sleep function. Under the hood, this hooks into a custom timer system, a much more efficient version of our spawn_blocking timers. You can read more about how that works in this other excellent post.
select! {
// new incoming connection
result = listener.accept() => {
// ...
}
// ctrl+c signal
shutdown = ctrl_c() => {
let timer = tokio::time::sleep(Duration::from_secs(30));
let request_counter = /* .. */;
select! {
_ = timer => {}
_ = request_counter => {}
}
println!("Gracefully shutting down.");
return;
}
}
Now for the active request counter. Instead of managing wakers manually, we can use a simple counter and take advantage of tokio's Notify type, which allows tasks to notify each other, or wait to be notified.
use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
use tokio::sync::Notify;
let state = Arc::new((AtomicUsize::new(0), Notify::new()));
When a request comes in, we increment the counter, and when it completes, we decrement it. If the counter reaches zero, the last active task calls notify_one, which wakes up the main task, letting it know that all active tasks have completed.
select! {
result = listener.accept() => {
let (connection, _) = result.unwrap();
let state = state.clone();
// increment the counter
state.0.fetch_add(1, Ordering::Relaxed);
tokio::spawn(async move {
if let Err(e) = handle_connection(connection).await {
// ...
}
// decrement the counter
let count = state.0.fetch_sub(1, Ordering::Relaxed);
if count == 1 {
// we were the last active task
state.1.notify_one();
}
});
}
shutdown = ctrl_c() => {
// ...
}
}
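The count == 1 check works because fetch_sub returns the value before the subtraction. The task that sees 1 is therefore the one that brought the counter to zero. Below is a std-only sketch in which plain threads stand in for tokio tasks; count_last_finishers is a hypothetical helper. A Barrier makes all the simulated requests overlap, so exactly one of them sees itself as the last. Without the barrier, requests that don't overlap could each bring the counter back to zero.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;

// Simulates `n` overlapping requests on plain threads. Returns how many of
// them observed themselves as the last to finish, plus the final count.
fn count_last_finishers(n: usize) -> (usize, usize) {
    let active = Arc::new(AtomicUsize::new(0));
    let last_finishers = Arc::new(AtomicUsize::new(0));
    // holds every request back until all of them have been counted
    let barrier = Arc::new(Barrier::new(n));

    let mut handles = Vec::new();
    for _ in 0..n {
        // increment before spawning, like the accept loop does
        active.fetch_add(1, Ordering::Relaxed);
        let (active, last_finishers, barrier) =
            (active.clone(), last_finishers.clone(), barrier.clone());
        handles.push(thread::spawn(move || {
            barrier.wait();
            // `fetch_sub` returns the value *before* subtracting, so
            // seeing 1 means this request brought the count to zero
            if active.fetch_sub(1, Ordering::Relaxed) == 1 {
                last_finishers.fetch_add(1, Ordering::Relaxed);
            }
        }));
    }
    for handle in handles {
        handle.join().unwrap();
    }
    (
        last_finishers.load(Ordering::Relaxed),
        active.load(Ordering::Relaxed),
    )
}

fn main() {
    println!("{:?}", count_last_finishers(8)); // (1, 0)
}
```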
The shutdown handler can now simply select between the timer and Notify::notified, which will resolve when someone calls notify_one, indicating that the last active request has completed.
select! {
result = listener.accept() => {
// ...
}
shutdown = ctrl_c() => {
// a 30 second timer
let timer = tokio::time::sleep(Duration::from_secs(30));
// notified by the last active task
let notification = state.1.notified();
// if the count isn't zero, we have to wait
if state.0.load(Ordering::Relaxed) != 0 {
// wait for either the timer or notification to resolve
select! {
_ = timer => {}
_ = notification => {}
}
}
println!("Gracefully shutting down.");
return;
}
}
Beautiful, isn't it?
With tokio and async/await, we don't even have to think about wakers, reactors, or anything else that goes on under the hood. All the building blocks are provided for us; we just have to put them together.
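For comparison, the same shutdown condition can be written in blocking std code: wait until there are no active requests or a timeout elapses, whichever comes first. This is a swapped-in technique, not what tokio does. A Mutex<usize> with Condvar::wait_timeout_while stands in for the AtomicUsize, Notify and sleep. finish_request and wait_for_requests are hypothetical helpers.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

// The number of active requests, plus a condvar signalled when it hits zero.
type Requests = (Mutex<usize>, Condvar);

// Called by a (hypothetical) request handler when it finishes.
fn finish_request(state: &Requests) {
    let (count, cvar) = state;
    let mut count = count.lock().unwrap();
    *count -= 1;
    if *count == 0 {
        cvar.notify_all();
    }
}

// Blocks until there are no active requests or `timeout` elapses,
// whichever comes first. Returns `true` if all requests finished in time.
fn wait_for_requests(state: &Requests, timeout: Duration) -> bool {
    let (count, cvar) = state;
    let guard = count.lock().unwrap();
    // the condition is checked before waiting and after every wakeup, so a
    // notification sent before we started waiting can't be missed
    let (_guard, result) = cvar
        .wait_timeout_while(guard, timeout, |active| *active != 0)
        .unwrap();
    !result.timed_out()
}

fn main() {
    // two requests in flight, finishing after roughly 10ms and 20ms
    let state = Arc::new((Mutex::new(2usize), Condvar::new()));
    for i in 0..2u64 {
        let state = Arc::clone(&state);
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(10 * (i + 1)));
            finish_request(&state);
        });
    }
    let finished = wait_for_requests(&state, Duration::from_secs(30));
    println!("all requests finished: {finished}");
}
```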
Afterword
Whew, that was quite the journey!
We started from the simplest web server, tried multithreading, and then worked our way up to a custom asynchronous runtime built on epoll. All to implement graceful shutdown.
Then we circled back and implemented graceful shutdown with tokio in just a few extra lines of code.
Hopefully this article has helped you appreciate the power of async Rust, and taught you a little more about how it works under the hood. All the code for this article is available on GitHub.
1. There's really no way to force-flush a network socket, so flush is actually a no-op on TcpStream, but we'll call it anyway to be true to io::Write.
2. The real scheduler is of course more complicated than this. You can read about it here.
3. For example, flushing the TLB.
4. In fact on Linux, threads and processes are just "tasks" with different configuration settings.
5. Technically, our scheduler is unfair. A fair scheduler would check epoll every once in a while even if there are tasks that are runnable, to avoid starving certain tasks.