// openzeppelin_relayer/api/middleware/timeout.rs

1use actix_web::{
2    body::MessageBody,
3    dev::{Service, ServiceRequest, ServiceResponse, Transform},
4    Error,
5};
6use futures::future::{ready, LocalBoxFuture, Ready};
7use std::{
8    rc::Rc,
9    task::{Context, Poll},
10    time::Duration,
11};
12use tokio::time::timeout;
13
14use crate::metrics::TIMEOUT_COUNTER;
15
/// Middleware that enforces a timeout on HTTP request handlers.
///
/// Each request forwarded through this middleware is raced against a fixed
/// deadline. If the wrapped service does not produce a result in time, the
/// pending call is abandoned, a warning is logged, a timeout metric is
/// recorded, and the request fails with `504 Gateway Timeout`.
///
/// The type only carries a `Duration`, so it is cheap to copy and can be shared
/// freely when building the app.
#[derive(Debug, Clone, Copy)]
pub struct TimeoutMiddleware {
    /// Maximum time a single request may spend in the wrapped service.
    duration: Duration,
}

impl TimeoutMiddleware {
    /// Creates a middleware whose per-request deadline is `seconds` whole seconds.
    ///
    /// With `seconds == 0`, a handler that is not ready on its first poll
    /// times out immediately.
    pub fn new(seconds: u64) -> Self {
        Self {
            duration: Duration::from_secs(seconds),
        }
    }
}
28
29impl<S, B> Transform<S, ServiceRequest> for TimeoutMiddleware
30where
31    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
32    S::Future: 'static,
33    B: MessageBody + 'static,
34{
35    type Response = ServiceResponse<B>;
36    type Error = Error;
37    type Transform = TimeoutMiddlewareService<S>;
38    type InitError = ();
39    type Future = Ready<Result<Self::Transform, Self::InitError>>;
40
41    fn new_transform(&self, service: S) -> Self::Future {
42        ready(Ok(TimeoutMiddlewareService {
43            service: Rc::new(service),
44            duration: self.duration,
45        }))
46    }
47}
48
/// Service produced by [`TimeoutMiddleware`] that forwards requests to the next
/// service while enforcing a deadline on each call.
pub struct TimeoutMiddlewareService<S> {
    /// Deadline applied to every call forwarded to `service`.
    duration: Duration,
    /// Downstream service; reference-counted so a clone can be moved into each
    /// boxed request future.
    service: Rc<S>,
}
53
54impl<S, B> Service<ServiceRequest> for TimeoutMiddlewareService<S>
55where
56    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
57    S::Future: 'static,
58    B: MessageBody + 'static,
59{
60    type Response = ServiceResponse<B>;
61    type Error = Error;
62    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
63
64    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
65        self.service.poll_ready(cx)
66    }
67
68    fn call(&self, req: ServiceRequest) -> Self::Future {
69        let service = Rc::clone(&self.service);
70        let duration = self.duration;
71
72        Box::pin(async move {
73            let path = req.path().to_string();
74            let method = req.method().to_string();
75
76            match timeout(duration, service.call(req)).await {
77                Ok(result) => result,
78                Err(_) => {
79                    // Timeout occurred
80                    tracing::warn!(
81                        "Request timeout: {} {} exceeded {}s",
82                        method,
83                        path,
84                        duration.as_secs()
85                    );
86
87                    // Record timeout metric
88                    TIMEOUT_COUNTER
89                        .with_label_values(&[path.as_str(), method.as_str(), "handler"])
90                        .inc();
91
92                    Err(actix_web::error::ErrorGatewayTimeout(
93                        "Request handler timeout",
94                    ))
95                }
96            }
97        })
98    }
99}
100
/// Integration tests for [`TimeoutMiddleware`].
///
/// `TimeoutMiddleware::new` only accepts whole seconds, so the timeout cases
/// necessarily wait roughly one to two seconds of real time each.
///
/// All async tests use `#[actix_web::test]` explicitly. `use actix_web::test`
/// brings actix's `test` attribute into scope, so a bare `#[test]` in this
/// module would resolve to it rather than the built-in one.
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{test, web, App, HttpResponse};
    use std::time::Duration;

    #[actix_web::test]
    async fn test_request_completes_before_timeout() {
        let app = test::init_service(App::new().wrap(TimeoutMiddleware::new(5)).route(
            "/fast",
            web::get().to(|| async { HttpResponse::Ok().body("OK") }),
        ))
        .await;

        let req = test::TestRequest::get().uri("/fast").to_request();
        let resp = test::call_service(&app, req).await;

        assert!(resp.status().is_success());
    }

    #[actix_web::test]
    async fn test_request_exceeds_timeout_returns_error() {
        let app = test::init_service(App::new().wrap(TimeoutMiddleware::new(1)).route(
            "/slow",
            web::get().to(|| async {
                tokio::time::sleep(Duration::from_secs(3)).await;
                HttpResponse::Ok().body("OK")
            }),
        ))
        .await;

        let req = test::TestRequest::get().uri("/slow").to_request();
        let result = test::try_call_service(&app, req).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(
            err.as_response_error().status_code(),
            actix_web::http::StatusCode::GATEWAY_TIMEOUT
        );
    }

    #[actix_web::test]
    async fn test_request_just_under_timeout_succeeds() {
        let app = test::init_service(App::new().wrap(TimeoutMiddleware::new(2)).route(
            "/almost",
            web::get().to(|| async {
                tokio::time::sleep(Duration::from_millis(500)).await;
                HttpResponse::Ok().body("OK")
            }),
        ))
        .await;

        let req = test::TestRequest::get().uri("/almost").to_request();
        let resp = test::call_service(&app, req).await;

        assert!(resp.status().is_success());
    }

    #[actix_web::test]
    async fn test_timeout_middleware_new_sets_duration() {
        let middleware = TimeoutMiddleware::new(10);
        assert_eq!(middleware.duration, Duration::from_secs(10));
    }

    #[actix_web::test]
    async fn test_post_request_timeout() {
        let app = test::init_service(App::new().wrap(TimeoutMiddleware::new(1)).route(
            "/slow-post",
            web::post().to(|| async {
                tokio::time::sleep(Duration::from_secs(3)).await;
                HttpResponse::Ok().body("OK")
            }),
        ))
        .await;

        let req = test::TestRequest::post()
            .uri("/slow-post")
            .set_payload("test body")
            .to_request();
        let result = test::try_call_service(&app, req).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(
            err.as_response_error().status_code(),
            actix_web::http::StatusCode::GATEWAY_TIMEOUT
        );
    }

    #[actix_web::test]
    async fn test_multiple_requests_independent_timeouts() {
        let app = test::init_service(
            App::new()
                .wrap(TimeoutMiddleware::new(2))
                .route(
                    "/fast",
                    web::get().to(|| async { HttpResponse::Ok().body("OK") }),
                )
                .route(
                    "/slow",
                    web::get().to(|| async {
                        tokio::time::sleep(Duration::from_secs(5)).await;
                        HttpResponse::Ok().body("OK")
                    }),
                ),
        )
        .await;

        // Fast request should succeed
        let req = test::TestRequest::get().uri("/fast").to_request();
        let resp = test::call_service(&app, req).await;
        assert!(resp.status().is_success());

        // Slow request should timeout
        let req = test::TestRequest::get().uri("/slow").to_request();
        let result = test::try_call_service(&app, req).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(
            err.as_response_error().status_code(),
            actix_web::http::StatusCode::GATEWAY_TIMEOUT
        );
    }

    /// A handler that fails quickly must surface its own error status, not be
    /// reported as a gateway timeout by the middleware.
    #[actix_web::test]
    async fn test_fast_handler_error_is_not_masked_as_timeout() {
        let app = test::init_service(App::new().wrap(TimeoutMiddleware::new(5)).route(
            "/bad",
            web::get().to(|| async {
                Err::<HttpResponse, actix_web::Error>(actix_web::error::ErrorBadRequest(
                    "bad input",
                ))
            }),
        ))
        .await;

        let req = test::TestRequest::get().uri("/bad").to_request();
        let resp = test::call_service(&app, req).await;

        assert_eq!(resp.status(), actix_web::http::StatusCode::BAD_REQUEST);
    }
}