// openzeppelin_relayer/api/middleware/timeout.rs
use actix_web::{
2 body::MessageBody,
3 dev::{Service, ServiceRequest, ServiceResponse, Transform},
4 Error,
5};
6use futures::future::{ready, LocalBoxFuture, Ready};
7use std::{
8 rc::Rc,
9 task::{Context, Poll},
10 time::Duration,
11};
12use tokio::time::timeout;
13
14use crate::metrics::TIMEOUT_COUNTER;
15
/// Actix-web middleware factory that puts an upper bound on how long the
/// wrapped service may take to produce a response.
///
/// The configured limit is copied into each [`TimeoutMiddlewareService`]
/// built by the `Transform` implementation.
#[derive(Debug, Clone)]
pub struct TimeoutMiddleware {
    /// Maximum time the wrapped service is given to produce a response.
    duration: Duration,
}
20
21impl TimeoutMiddleware {
22 pub fn new(seconds: u64) -> Self {
23 Self {
24 duration: Duration::from_secs(seconds),
25 }
26 }
27}
28
29impl<S, B> Transform<S, ServiceRequest> for TimeoutMiddleware
30where
31 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
32 S::Future: 'static,
33 B: MessageBody + 'static,
34{
35 type Response = ServiceResponse<B>;
36 type Error = Error;
37 type Transform = TimeoutMiddlewareService<S>;
38 type InitError = ();
39 type Future = Ready<Result<Self::Transform, Self::InitError>>;
40
41 fn new_transform(&self, service: S) -> Self::Future {
42 ready(Ok(TimeoutMiddlewareService {
43 service: Rc::new(service),
44 duration: self.duration,
45 }))
46 }
47}
48
/// Service wrapper created by [`TimeoutMiddleware`]; it forwards requests to
/// the inner service and gives up once `duration` has elapsed.
#[derive(Debug)]
pub struct TimeoutMiddlewareService<S> {
    /// Inner service, reference-counted so each request future can own a
    /// handle to it.
    service: Rc<S>,
    /// Maximum time the inner service is given to produce a response.
    duration: Duration,
}
53
54impl<S, B> Service<ServiceRequest> for TimeoutMiddlewareService<S>
55where
56 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
57 S::Future: 'static,
58 B: MessageBody + 'static,
59{
60 type Response = ServiceResponse<B>;
61 type Error = Error;
62 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
63
64 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65 self.service.poll_ready(cx)
66 }
67
68 fn call(&self, req: ServiceRequest) -> Self::Future {
69 let service = Rc::clone(&self.service);
70 let duration = self.duration;
71
72 Box::pin(async move {
73 let path = req.path().to_string();
74 let method = req.method().to_string();
75
76 match timeout(duration, service.call(req)).await {
77 Ok(result) => result,
78 Err(_) => {
79 tracing::warn!(
81 "Request timeout: {} {} exceeded {}s",
82 method,
83 path,
84 duration.as_secs()
85 );
86
87 TIMEOUT_COUNTER
89 .with_label_values(&[path.as_str(), method.as_str(), "handler"])
90 .inc();
91
92 Err(actix_web::error::ErrorGatewayTimeout(
93 "Request handler timeout",
94 ))
95 }
96 }
97 })
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::StatusCode, test, web, App, HttpResponse};
    use std::time::Duration;

    /// Handler that answers immediately with `200 OK`.
    async fn respond_ok() -> HttpResponse {
        HttpResponse::Ok().body("OK")
    }

    /// Handler that sleeps for `delay` and then answers with `200 OK`.
    async fn respond_after(delay: Duration) -> HttpResponse {
        tokio::time::sleep(delay).await;
        HttpResponse::Ok().body("OK")
    }

    /// Asserts that `err` renders as an HTTP `504 Gateway Timeout`.
    fn assert_gateway_timeout(err: Error) {
        assert_eq!(
            err.as_response_error().status_code(),
            StatusCode::GATEWAY_TIMEOUT
        );
    }

    #[actix_web::test]
    async fn test_request_completes_before_timeout() {
        let service = test::init_service(
            App::new()
                .wrap(TimeoutMiddleware::new(5))
                .route("/fast", web::get().to(respond_ok)),
        )
        .await;

        let request = test::TestRequest::get().uri("/fast").to_request();
        let response = test::call_service(&service, request).await;

        assert!(response.status().is_success());
    }

    #[actix_web::test]
    async fn test_request_exceeds_timeout_returns_error() {
        let service = test::init_service(
            App::new().wrap(TimeoutMiddleware::new(1)).route(
                "/slow",
                web::get().to(|| respond_after(Duration::from_secs(3))),
            ),
        )
        .await;

        let request = test::TestRequest::get().uri("/slow").to_request();
        let outcome = test::try_call_service(&service, request).await;

        assert!(outcome.is_err());
        assert_gateway_timeout(outcome.unwrap_err());
    }

    #[actix_web::test]
    async fn test_request_just_under_timeout_succeeds() {
        let service = test::init_service(
            App::new().wrap(TimeoutMiddleware::new(2)).route(
                "/almost",
                web::get().to(|| respond_after(Duration::from_millis(500))),
            ),
        )
        .await;

        let request = test::TestRequest::get().uri("/almost").to_request();
        let response = test::call_service(&service, request).await;

        assert!(response.status().is_success());
    }

    #[actix_web::test]
    async fn test_timeout_middleware_new_sets_duration() {
        let expected = Duration::from_secs(10);
        let built = TimeoutMiddleware::new(10);

        assert_eq!(built.duration, expected);
    }

    #[actix_web::test]
    async fn test_post_request_timeout() {
        let service = test::init_service(
            App::new().wrap(TimeoutMiddleware::new(1)).route(
                "/slow-post",
                web::post().to(|| respond_after(Duration::from_secs(3))),
            ),
        )
        .await;

        let request = test::TestRequest::post()
            .uri("/slow-post")
            .set_payload("test body")
            .to_request();
        let outcome = test::try_call_service(&service, request).await;

        assert!(outcome.is_err());
        assert_gateway_timeout(outcome.unwrap_err());
    }

    #[actix_web::test]
    async fn test_multiple_requests_independent_timeouts() {
        let service = test::init_service(
            App::new()
                .wrap(TimeoutMiddleware::new(2))
                .route("/fast", web::get().to(respond_ok))
                .route(
                    "/slow",
                    web::get().to(|| respond_after(Duration::from_secs(5))),
                ),
        )
        .await;

        // The quick route must be unaffected by the limit.
        let fast_request = test::TestRequest::get().uri("/fast").to_request();
        let fast_response = test::call_service(&service, fast_request).await;
        assert!(fast_response.status().is_success());

        // The slow route on the same app must still be cut off.
        let slow_request = test::TestRequest::get().uri("/slow").to_request();
        let slow_outcome = test::try_call_service(&service, slow_request).await;
        assert!(slow_outcome.is_err());
        assert_gateway_timeout(slow_outcome.unwrap_err());
    }
}