Rust 使用 epoll 的例子

通过代码示例学习 epoll 调用机制。
use std::collections::HashMap;
use std::io;
use std::io::prelude::*;
use std::net::{TcpListener, TcpStream};
use std::os::unix::io::{AsRawFd, RawFd};
// Helper macro borrowed from the open-source project mio
// (https://github.com/tokio-rs/mio) for invoking system functions from the
// `libc` crate.
//
// It calls `libc::<fn>(<args>)` inside an `unsafe` block and converts the raw
// return value into a `Result`: a return value of `-1` becomes
// `Err(std::io::Error::last_os_error())`, anything else becomes `Ok(value)`.
#[allow(unused_macros)]
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ) ) => {{
        let ret = unsafe { libc::$fn($($arg),*) };
        if ret != -1 {
            Ok(ret)
        } else {
            Err(std::io::Error::last_os_error())
        }
    }};
}
/// Per-connection state tracked by the server for one client request.
#[derive(Debug)]
pub struct RequestContext {
/// The client connection this context reads the request from and writes the response to.
pub stream: TcpStream,
/// Value taken from the request's `content-length` header; stays 0 until one is parsed.
pub content_length: usize,
/// Request data accumulated by `read_cb` across read events.
pub buf: Vec<u8>,
}
impl RequestContext {
    /// Creates the context for a newly accepted connection with an empty
    /// buffer and a content length of 0.
    fn new(stream: TcpStream) -> Self {
        Self {
            stream,
            buf: Vec::new(),
            content_length: 0,
        }
    }

    /// Handles a readable event for this connection.
    ///
    /// Reads one chunk from the socket, appends exactly the bytes received to
    /// `self.buf`, and re-arms the one-shot epoll registration: for writing
    /// once enough data has arrived (or the peer closed its side), otherwise
    /// for reading again.
    ///
    /// `key` is the epoll token of this connection and `epoll_fd` the epoll
    /// instance it is registered with.
    ///
    /// # Errors
    ///
    /// Returns any read error other than `WouldBlock`/`Interrupted`, and any
    /// error from updating the epoll registration.
    fn read_cb(&mut self, key: u64, epoll_fd: RawFd) -> io::Result<()> {
        let fd = self.stream.as_raw_fd();
        let mut chunk = [0u8; 4096];
        let received = match self.stream.read(&mut chunk) {
            Ok(n) => n,
            Err(e)
                if matches!(
                    e.kind(),
                    io::ErrorKind::WouldBlock | io::ErrorKind::Interrupted
                ) =>
            {
                // Nothing was read this time; wait for the next readable event.
                return modify_interest(epoll_fd, fd, listener_read_event(key));
            }
            Err(e) => return Err(e),
        };

        if received == 0 {
            // The peer closed its side: no more request data will arrive.
            // Switch to the write phase, which also tears the connection down.
            println!(
                "peer closed connection {} after {} bytes",
                key,
                self.buf.len()
            );
            return modify_interest(epoll_fd, fd, listener_write_event(key));
        }

        // Only the first `received` bytes of the scratch buffer hold data; the
        // rest is zero padding and must not be parsed or stored.
        let data = &chunk[..received];
        if let Ok(text) = std::str::from_utf8(data) {
            self.parse_and_set_content_length(text);
        }
        self.buf.extend_from_slice(data);

        // NOTE(review): this compares the total number of bytes received
        // (request line and headers included) against the content length,
        // so it is only an approximation of "the whole body has arrived".
        if self.buf.len() >= self.content_length {
            println!("got all data: {} bytes", self.buf.len());
            modify_interest(epoll_fd, fd, listener_write_event(key))
        } else {
            modify_interest(epoll_fd, fd, listener_read_event(key))
        }
    }

    /// Looks for a `content-length` header in `data` and stores its value.
    ///
    /// Only chunks containing the text `HTTP` are inspected. The header name
    /// is matched case-insensitively and surrounding whitespace around the
    /// value is ignored. A value that is not a valid number is reported on
    /// stderr and ignored, because the header comes from the client and must
    /// not be able to crash the server.
    fn parse_and_set_content_length(&mut self, data: &str) {
        if !data.contains("HTTP") {
            return;
        }
        let parsed = data.lines().find_map(|line| {
            let lowered = line.to_ascii_lowercase();
            let value = lowered.strip_prefix("content-length:")?;
            Some(value.trim().parse::<usize>())
        });
        match parsed {
            Some(Ok(len)) => {
                self.content_length = len;
                println!("set content length: {} bytes", self.content_length);
            }
            Some(Err(e)) => eprintln!("ignoring invalid content-length header: {}", e),
            None => {}
        }
    }

    /// Handles a writable event: sends the canned response and shuts the
    /// connection down.
    ///
    /// A failure to send the response is logged, not returned, so that the
    /// connection is still cleaned up. The socket is deregistered from epoll
    /// here, but its descriptor is closed only when `self.stream` is dropped;
    /// closing it by hand as well would close the same descriptor twice.
    ///
    /// # Errors
    ///
    /// Returns shutdown errors other than `NotConnected`, and any error from
    /// removing the epoll registration.
    fn write_cb(&mut self, key: u64, epoll_fd: RawFd) -> io::Result<()> {
        // `write_all` keeps writing until the whole response has been handed
        // to the kernel; a plain `write` may send only a prefix.
        match self.stream.write_all(HTTP_RESP) {
            Ok(()) => println!("answered from request {}", key),
            Err(e) => eprintln!("could not answer to request {}, {}", key, e),
        };
        match self.stream.shutdown(std::net::Shutdown::Both) {
            Ok(()) => {}
            // The peer is already gone; there is nothing left to shut down.
            Err(e) if e.kind() == io::ErrorKind::NotConnected => {}
            Err(e) => return Err(e),
        }
        remove_interest(epoll_fd, self.stream.as_raw_fd())
    }
}
/// Interest mask for "readable", delivered once and then disabled until re-armed.
const READ_FLAGS: i32 = libc::EPOLLIN | libc::EPOLLONESHOT;
/// Interest mask for "writable", delivered once and then disabled until re-armed.
const WRITE_FLAGS: i32 = libc::EPOLLOUT | libc::EPOLLONESHOT;
/// Canned HTTP response sent to every client.
///
/// Line endings are spelled out as `\r\n`, and an empty line separates the
/// headers from the 5-byte body, as the HTTP/1.1 message format requires.
/// Without that empty line a client would read `Hello` as another header line.
const HTTP_RESP: &[u8] =
    b"HTTP/1.1 200 OK\r\ncontent-type: text/html\r\ncontent-length: 5\r\n\r\nHello";
/// Runs a single-threaded HTTP server on 127.0.0.1:8000 driven by epoll.
///
/// The listening socket is registered under a fixed token; every accepted
/// connection gets the next token and a `RequestContext`. All registrations
/// use one-shot flags, so each handler re-arms its descriptor after an event.
///
/// A failure on a single connection is logged and that connection is
/// dropped; it does not stop the server.
///
/// # Errors
///
/// Returns an error when the listener or the epoll instance cannot be set
/// up, when the listener cannot be re-armed, or when `epoll_wait` fails for
/// a reason other than being interrupted by a signal.
fn main() -> io::Result<()> {
    /// Token identifying the listening socket in epoll events.
    const LISTENER_KEY: u64 = 100;
    /// Maximum number of events fetched by one `epoll_wait` call.
    const MAX_EVENTS: usize = 1024;
    /// Timeout of one `epoll_wait` call, in milliseconds.
    const WAIT_TIMEOUT_MS: libc::c_int = 1000;

    let mut request_contexts: HashMap<u64, RequestContext> = HashMap::new();
    let mut events: Vec<libc::epoll_event> = Vec::with_capacity(MAX_EVENTS);
    // Connection tokens are handed out by incrementing from the listener's token.
    let mut key = LISTENER_KEY;

    let listener = TcpListener::bind("127.0.0.1:8000")?;
    listener.set_nonblocking(true)?;
    let listener_fd = listener.as_raw_fd();

    let epoll_fd = epoll_create()?;
    add_interest(epoll_fd, listener_fd, listener_read_event(LISTENER_KEY))?;

    loop {
        println!("requests in flight: {}", request_contexts.len());
        events.clear();

        // `epoll_wait` blocks until something happens and then fills `events`.
        // It returns when:
        // - at least one event occurred,
        // - the timeout expired,
        // - the call was interrupted by a signal.
        let ready = match syscall!(epoll_wait(
            epoll_fd,
            events.as_mut_ptr(),
            MAX_EVENTS as libc::c_int,
            WAIT_TIMEOUT_MS,
        )) {
            Ok(count) => count as usize,
            // A signal interrupted the wait; that is not a failure, just retry.
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        // SAFETY: the kernel wrote `ready` events into the vector's buffer and
        // `ready` cannot exceed MAX_EVENTS, which is the reserved capacity.
        unsafe { events.set_len(ready) };

        for ev in &events {
            // Copy the fields out first: `epoll_event` may be a packed struct,
            // so its fields must not be borrowed (formatting macros borrow).
            let token: u64 = ev.u64;
            let raw_events: u32 = ev.events;
            let flags = raw_events as i32;

            if token == LISTENER_KEY {
                match listener.accept() {
                    Ok((stream, addr)) => {
                        println!("new client: {}", addr);
                        key += 1;
                        let registered = stream.set_nonblocking(true).and_then(|()| {
                            add_interest(epoll_fd, stream.as_raw_fd(), listener_read_event(key))
                        });
                        match registered {
                            Ok(()) => {
                                request_contexts.insert(key, RequestContext::new(stream));
                            }
                            // Dropping `stream` here closes the rejected connection.
                            Err(e) => eprintln!("couldn't register client {}: {}", addr, e),
                        }
                    }
                    Err(e) => eprintln!("couldn't accept: {}", e),
                };
                // The listener is one-shot as well, so it must be re-armed.
                modify_interest(epoll_fd, listener_fd, listener_read_event(LISTENER_KEY))?;
                continue;
            }

            let finished = match request_contexts.get_mut(&token) {
                None => continue,
                Some(context) => {
                    let outcome = if flags & libc::EPOLLIN != 0 {
                        context.read_cb(token, epoll_fd).map(|()| false)
                    } else if flags & libc::EPOLLOUT != 0 {
                        context.write_cb(token, epoll_fd).map(|()| true)
                    } else {
                        // Neither readable nor writable: no handler will re-arm
                        // the one-shot registration, so the connection would
                        // otherwise stay in the map forever.
                        println!("unexpected events: {}", raw_events);
                        Ok(true)
                    };
                    outcome.unwrap_or_else(|e| {
                        eprintln!("dropping connection {}: {}", token, e);
                        true
                    })
                }
            };
            if finished {
                // Dropping the context drops its `TcpStream`, which closes the socket.
                request_contexts.remove(&token);
            }
        }
    }
}
/*
创建并返回一个 epoll 资源
*/
fn epoll_create() -> io::Result<RawFd> {
let fd = syscall!(epoll_create1(0))?;
if let Ok(flags) = syscall!(fcntl(fd, libc::F_GETFD)) {
let _ = syscall!(fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC));
}
Ok(fd)
}
/// Builds an epoll event that asks for `READ_FLAGS` and carries `key` as its
/// user data.
fn listener_read_event(key: u64) -> libc::epoll_event {
    let events = READ_FLAGS as u32;
    libc::epoll_event { events, u64: key }
}
/// Builds an epoll event that asks for `WRITE_FLAGS` and carries `key` as its
/// user data.
fn listener_write_event(key: u64) -> libc::epoll_event {
    let events = WRITE_FLAGS as u32;
    libc::epoll_event { events, u64: key }
}
/// Closes `fd` through `libc::close`; any error reported by the call is
/// deliberately discarded.
fn close(fd: RawFd) {
    drop(syscall!(close(fd)));
}
// `epoll_ctl` lets the program add, modify, or delete event registrations
// dynamically as its needs change over time.

/// Registers `fd` with the epoll instance `epoll_fd` using `event`
/// (`EPOLL_CTL_ADD`).
///
/// # Errors
///
/// Returns the OS error when `epoll_ctl` fails.
fn add_interest(epoll_fd: RawFd, fd: RawFd, mut event: libc::epoll_event) -> io::Result<()> {
    let event_ptr: *mut libc::epoll_event = &mut event;
    syscall!(epoll_ctl(epoll_fd, libc::EPOLL_CTL_ADD, fd, event_ptr)).map(|_| ())
}
/// Replaces the registration of `fd` in the epoll instance `epoll_fd` with
/// `event` (`EPOLL_CTL_MOD`).
///
/// # Errors
///
/// Returns the OS error when `epoll_ctl` fails.
fn modify_interest(epoll_fd: RawFd, fd: RawFd, mut event: libc::epoll_event) -> io::Result<()> {
    let event_ptr: *mut libc::epoll_event = &mut event;
    syscall!(epoll_ctl(epoll_fd, libc::EPOLL_CTL_MOD, fd, event_ptr)).map(|_| ())
}
fn remove_interest(epoll_fd: RawFd, fd: RawFd) -> io::Result<()> {
syscall!(epoll_ctl(
epoll_fd,
libc::EPOLL_CTL_DEL,
fd,
std::ptr::null_mut()
))?;
Ok(())
}
Published on 2024-03-01, updated on 2025-02-10