openzeppelin_relayer/api/middleware/concurrency.rs

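//! Concurrency-limiting middleware for the HTTP API.
//!
//! Each request must obtain a permit from a shared `tokio::sync::Semaphore`.
//! When no permit is available the request is rejected immediately with
//! `429 Too Many Requests` and a small JSON body describing the limit, and the
//! `IN_FLIGHT_REQUESTS` gauge tracks the number of requests currently being
//! served per endpoint path.
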
use actix_web::{
    body::MessageBody,
    dev::{Service, ServiceRequest, ServiceResponse, Transform},
    Error,
};
use futures::future::{ready, LocalBoxFuture, Ready};
use std::{
    rc::Rc,
    sync::Arc,
    task::{Context, Poll},
};
use tokio::sync::Semaphore;

use crate::metrics::IN_FLIGHT_REQUESTS;

/// Middleware that limits concurrent requests using a semaphore.
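///
/// Requests beyond the configured maximum are rejected with
/// `429 Too Many Requests` and a JSON body naming the limit.
///
/// # Example
///
/// A minimal wiring sketch (the route and handler below are illustrative, not
/// part of this module):
///
/// ```rust,ignore
/// use actix_web::{web, App, HttpResponse};
///
/// let app = App::new()
///     // Allow at most 100 requests to be processed concurrently.
///     .wrap(ConcurrencyLimiter::new(100))
///     .route("/health", web::get().to(|| async { HttpResponse::Ok().finish() }));
/// ```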
pub struct ConcurrencyLimiter {
    semaphore: Arc<Semaphore>,
    max_permits: usize,
}

impl ConcurrencyLimiter {
    /// Creates a limiter that allows at most `max_concurrent` requests in flight.
    pub fn new(max_concurrent: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            max_permits: max_concurrent,
        }
    }
}

impl<S, B> Transform<S, ServiceRequest> for ConcurrencyLimiter
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Transform = ConcurrencyLimiterService<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(ConcurrencyLimiterService {
            service: Rc::new(service),
            semaphore: Arc::clone(&self.semaphore),
            max_permits: self.max_permits,
        }))
    }
}

/// Service produced by [`ConcurrencyLimiter`]; shares the limiter's semaphore
/// across all requests handled by the wrapped service.
pub struct ConcurrencyLimiterService<S> {
    service: Rc<S>,
    semaphore: Arc<Semaphore>,
    max_permits: usize,
}

impl<S, B> Service<ServiceRequest> for ConcurrencyLimiterService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let service = Rc::clone(&self.service);
        let semaphore = Arc::clone(&self.semaphore);
        let max_permits = self.max_permits;

        Box::pin(async move {
            let endpoint = req.path().to_string();

            // Try to acquire a permit without waiting
            let permit = match semaphore.try_acquire() {
                Ok(permit) => permit,
                Err(_) => {
                    // No permits available: every permit is held, so the number
                    // of requests in flight equals the configured maximum
                    let current_in_flight = max_permits;
                    tracing::warn!(
                        "Concurrency limit reached for {}: {} requests in flight",
                        endpoint,
                        current_in_flight
                    );

                    return Err(actix_web::error::ErrorTooManyRequests(
                        serde_json::json!({
                            "error": "Too many concurrent requests",
                            "max_concurrent": max_permits,
                        })
                        .to_string(),
                    ));
                }
            };

            // Increment the in-flight counter for this endpoint
            IN_FLIGHT_REQUESTS.with_label_values(&[&endpoint]).inc();

            // Call the wrapped service
            let result = service.call(req).await;

            // Decrement the in-flight counter when done
            IN_FLIGHT_REQUESTS.with_label_values(&[&endpoint]).dec();

            // Drop the permit explicitly to release it back to the semaphore
            drop(permit);

            result
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{test, web, App, HttpResponse};
    use std::rc::Rc;
    use std::time::Duration;
    use tokio::time::sleep;

    #[actix_rt::test]
    async fn test_concurrency_limiter_allows_requests_within_limit() {
        let app = test::init_service(App::new().wrap(ConcurrencyLimiter::new(2)).route(
            "/test",
            web::get().to(|| async { HttpResponse::Ok().body("ok") }),
        ))
        .await;

        // Make 2 sequential requests - both should succeed since limit is 2
        let req1 = test::TestRequest::get().uri("/test").to_request();
        let req2 = test::TestRequest::get().uri("/test").to_request();

        let resp1 = test::call_service(&app, req1).await;
        let resp2 = test::call_service(&app, req2).await;

        assert!(resp1.status().is_success());
        assert!(resp2.status().is_success());
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_rejects_excess_requests() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (tx, mut rx) = mpsc::unbounded_channel();
        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
                "/slow",
                web::get().to(move || {
                    let tx = tx.clone();
                    async move {
                        // Signal that we've started processing
                        let _ = tx.send(());
                        sleep(Duration::from_millis(100)).await;
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        // Start a slow request that will hold the permit
        let app_clone = Rc::clone(&app);
        let req1 = test::TestRequest::get().uri("/slow").to_request();
        let resp1_handle =
            task::spawn_local(async move { test::call_service(&*app_clone, req1).await });

        // Wait for confirmation that the first request acquired the permit
        rx.recv().await.unwrap();

        // Try to make another request while the first is still processing
        let req2 = test::TestRequest::get().uri("/slow").to_request();
        let resp2_result = test::try_call_service(&*app, req2).await;

        // Second request should be rejected with 429
        assert!(resp2_result.is_err());
        let err = resp2_result.unwrap_err();
        assert_eq!(err.as_response_error().status_code().as_u16(), 429);

        // Wait for first request to complete
        let resp1 = resp1_handle.await.unwrap();
        assert!(resp1.status().is_success());
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_releases_permits_after_completion() {
        let app = test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
            "/test",
            web::get().to(|| async { HttpResponse::Ok().body("ok") }),
        ))
        .await;

        // Make first request
        let req1 = test::TestRequest::get().uri("/test").to_request();
        let resp1 = test::call_service(&app, req1).await;
        assert!(resp1.status().is_success());

        // After first request completes, permit should be released
        // Make second request - should succeed
        let req2 = test::TestRequest::get().uri("/test").to_request();
        let resp2 = test::call_service(&app, req2).await;
        assert!(resp2.status().is_success());
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_error_response_format() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (tx, mut rx) = mpsc::unbounded_channel();
        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
                "/slow",
                web::get().to(move || {
                    let tx = tx.clone();
                    async move {
                        let _ = tx.send(());
                        sleep(Duration::from_millis(100)).await;
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        // Start a slow request
        let app_clone = Rc::clone(&app);
        let req1 = test::TestRequest::get().uri("/slow").to_request();
        let _resp1_handle =
            task::spawn_local(async move { test::call_service(&*app_clone, req1).await });

        // Wait for confirmation that the permit was acquired
        rx.recv().await.unwrap();

        // Try to make another request
        let req2 = test::TestRequest::get().uri("/slow").to_request();
        let resp2_result = test::try_call_service(&*app, req2).await;

        assert!(resp2_result.is_err());
        let err = resp2_result.unwrap_err();
        assert_eq!(err.as_response_error().status_code().as_u16(), 429);

        // Check error response body contains expected JSON
        let error_str = format!("{err}");
        assert!(error_str.contains("Too many concurrent requests"));
        assert!(error_str.contains("max_concurrent"));

        // Verify it's valid JSON
        let json: serde_json::Value = serde_json::from_str(&error_str).unwrap();
        assert_eq!(json["max_concurrent"], 1);
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_handles_multiple_endpoints() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (tx, mut rx) = mpsc::unbounded_channel();
        let app = Rc::new(
            test::init_service(
                App::new()
                    .wrap(ConcurrencyLimiter::new(1))
                    .route(
                        "/endpoint1",
                        web::get().to(move || {
                            let tx = tx.clone();
                            async move {
                                let _ = tx.send(());
                                sleep(Duration::from_millis(100)).await;
                                HttpResponse::Ok().body("ok1")
                            }
                        }),
                    )
                    .route(
                        "/endpoint2",
                        web::get().to(|| async { HttpResponse::Ok().body("ok2") }),
                    ),
            )
            .await,
        );

        // Start request to endpoint1 that holds the permit
        let app_clone = Rc::clone(&app);
        let req1 = test::TestRequest::get().uri("/endpoint1").to_request();
        let resp1_handle =
            task::spawn_local(async move { test::call_service(&*app_clone, req1).await });

        // Wait for endpoint1 to acquire the permit
        rx.recv().await.unwrap();

        // Try to make concurrent request to endpoint2 - should be rejected (global limit)
        let req2 = test::TestRequest::get().uri("/endpoint2").to_request();
        let resp2_result = test::try_call_service(&*app, req2).await;
        assert!(
            resp2_result.is_err(),
            "Global limit should apply across endpoints"
        );
        let err = resp2_result.unwrap_err();
        assert_eq!(err.as_response_error().status_code().as_u16(), 429);

        // Wait for first request to complete
        let resp1 = resp1_handle.await.unwrap();
        assert!(resp1.status().is_success());

        // Now endpoint2 should succeed
        let req3 = test::TestRequest::get().uri("/endpoint2").to_request();
        let resp3 = test::call_service(&*app, req3).await;
        assert!(resp3.status().is_success());
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_with_zero_limit() {
        let app = test::init_service(App::new().wrap(ConcurrencyLimiter::new(0)).route(
            "/test",
            web::get().to(|| async { HttpResponse::Ok().body("ok") }),
        ))
        .await;

        // With limit of 0, all requests should be rejected
        let req = test::TestRequest::get().uri("/test").to_request();
        let resp_result = test::try_call_service(&app, req).await;

        assert!(resp_result.is_err());
        let err = resp_result.unwrap_err();
        assert_eq!(err.as_response_error().status_code().as_u16(), 429);
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_metrics_tracking() {
        use tokio::sync::mpsc;
        use tokio::sync::Barrier;
        use tokio::task;

        // Reset metrics before test
        IN_FLIGHT_REQUESTS
            .with_label_values(&["/metrics-test"])
            .set(0.0);

        let (tx, mut rx) = mpsc::unbounded_channel();
        let barrier = Rc::new(Barrier::new(2));
        let barrier_clone = Rc::clone(&barrier);

        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(2)).route(
                "/metrics-test",
                web::get().to(move || {
                    let tx = tx.clone();
                    let barrier = barrier_clone.clone();
                    async move {
                        let _ = tx.send(());
                        barrier.wait().await;
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        let initial_value = IN_FLIGHT_REQUESTS
            .with_label_values(&["/metrics-test"])
            .get();
        assert_eq!(initial_value, 0.0, "Should start at 0");

        // Start a request that will wait
        let app_clone = Rc::clone(&app);
        let req1 = test::TestRequest::get().uri("/metrics-test").to_request();
        let handle = task::spawn_local(async move { test::call_service(&*app_clone, req1).await });

        // Wait until request is processing
        rx.recv().await.unwrap();

        // Metric should be incremented while request is in flight
        let in_flight = IN_FLIGHT_REQUESTS
            .with_label_values(&["/metrics-test"])
            .get();
        assert_eq!(in_flight, 1.0, "Should be 1 while request is in flight");

        // Release the barrier to let request complete
        barrier.wait().await;

        // Wait for request to finish
        let resp = handle.await.unwrap();
        assert!(resp.status().is_success());

        // After completion, should be back to 0
        let final_value = IN_FLIGHT_REQUESTS
            .with_label_values(&["/metrics-test"])
            .get();
        assert_eq!(final_value, 0.0, "Should be back to 0 after completion");
    }

    #[actix_rt::test]
    async fn test_permits_released_on_handler_error() {
        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
                "/error",
                web::get().to(|| async { HttpResponse::InternalServerError().body("error") }),
            ))
            .await,
        );

        // Make a request whose handler responds with a 500 error
        let req1 = test::TestRequest::get().uri("/error").to_request();
        let resp1 = test::call_service(&*app, req1).await;
        assert_eq!(resp1.status(), 500);

        // The permit should be released, so the next request should reach the
        // handler (and receive the same 500) instead of being rejected with 429
        let req2 = test::TestRequest::get().uri("/error").to_request();
        let resp2 = test::call_service(&*app, req2).await;
        assert_eq!(resp2.status(), 500);
    }

    #[actix_rt::test]
    async fn test_concurrent_requests_up_to_limit() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (start_tx, mut start_rx) = mpsc::unbounded_channel();

        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(3)).route(
                "/test",
                web::get().to(move || {
                    let tx = start_tx.clone();
                    async move {
                        let _ = tx.send(());
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        // Spawn 3 concurrent requests (at the limit)
        let app1 = Rc::clone(&app);
        let app2 = Rc::clone(&app);
        let app3 = Rc::clone(&app);

        let h1 = task::spawn_local(async move {
            test::call_service(&*app1, test::TestRequest::get().uri("/test").to_request()).await
        });
        let h2 = task::spawn_local(async move {
            test::call_service(&*app2, test::TestRequest::get().uri("/test").to_request()).await
        });
        let h3 = task::spawn_local(async move {
            test::call_service(&*app3, test::TestRequest::get().uri("/test").to_request()).await
        });

        // Wait for all 3 requests to start
        let mut count = 0;
        while count < 3 {
            if start_rx.recv().await.is_some() {
                count += 1;
            }
        }

        let r1 = h1.await.unwrap();
        let r2 = h2.await.unwrap();
        let r3 = h3.await.unwrap();

        assert!(r1.status().is_success());
        assert!(r2.status().is_success());
        assert!(r3.status().is_success());
    }

    #[actix_rt::test]
    async fn test_metric_tracking_multiple_endpoints() {
        use tokio::sync::mpsc;

        // Reset metrics
        IN_FLIGHT_REQUESTS
            .with_label_values(&["/endpoint-a"])
            .set(0.0);
        IN_FLIGHT_REQUESTS
            .with_label_values(&["/endpoint-b"])
            .set(0.0);

        let (tx_a, mut rx_a) = mpsc::unbounded_channel();
        let (tx_b, mut rx_b) = mpsc::unbounded_channel();

        let app = Rc::new(
            test::init_service(
                App::new()
                    .wrap(ConcurrencyLimiter::new(2))
                    .route(
                        "/endpoint-a",
                        web::get().to(move || {
                            let tx = tx_a.clone();
                            async move {
                                let _ = tx.send(());
                                sleep(Duration::from_millis(100)).await;
                                HttpResponse::Ok().body("a")
                            }
                        }),
                    )
                    .route(
                        "/endpoint-b",
                        web::get().to(move || {
                            let tx = tx_b.clone();
                            async move {
                                let _ = tx.send(());
                                sleep(Duration::from_millis(100)).await;
                                HttpResponse::Ok().body("b")
                            }
                        }),
                    ),
            )
            .await,
        );

        // Start request to endpoint-a
        let app_clone_a = Rc::clone(&app);
        let _handle_a = tokio::task::spawn_local(async move {
            test::call_service(
                &*app_clone_a,
                test::TestRequest::get().uri("/endpoint-a").to_request(),
            )
            .await
        });
        rx_a.recv().await.unwrap();

        // Start request to endpoint-b
        let app_clone_b = Rc::clone(&app);
        let _handle_b = tokio::task::spawn_local(async move {
            test::call_service(
                &*app_clone_b,
                test::TestRequest::get().uri("/endpoint-b").to_request(),
            )
            .await
        });
        rx_b.recv().await.unwrap();

        // Both metrics should be tracked separately
        let metric_a = IN_FLIGHT_REQUESTS.with_label_values(&["/endpoint-a"]).get();
        let metric_b = IN_FLIGHT_REQUESTS.with_label_values(&["/endpoint-b"]).get();

        assert_eq!(metric_a, 1.0, "Endpoint-a should have 1 in-flight request");
        assert_eq!(metric_b, 1.0, "Endpoint-b should have 1 in-flight request");
    }

    #[actix_rt::test]
    async fn test_429_response_includes_correct_limit() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (tx, mut rx) = mpsc::unbounded_channel();
        let max_limit = 3;

        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(max_limit)).route(
                "/limited",
                web::get().to(move || {
                    let tx = tx.clone();
                    async move {
                        let _ = tx.send(());
                        sleep(Duration::from_secs(10)).await;
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        // Fill all slots
        for _ in 0..max_limit {
            let app_clone = Rc::clone(&app);
            task::spawn_local(async move {
                test::call_service(
                    &*app_clone,
                    test::TestRequest::get().uri("/limited").to_request(),
                )
                .await
            });
            rx.recv().await.unwrap();
        }

        // Try to exceed limit
        let req = test::TestRequest::get().uri("/limited").to_request();
        let result = test::try_call_service(&*app, req).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        let error_str = format!("{err}");

        // Verify the limit value is reported correctly
        assert!(error_str.contains(&max_limit.to_string()));
        let json: serde_json::Value = serde_json::from_str(&error_str).unwrap();
        assert_eq!(json["max_concurrent"].as_u64(), Some(max_limit as u64));
    }

    #[actix_rt::test]
    async fn test_permits_released_after_panic_handling() {
        // Note: no panic is actually induced here; this only verifies that the
        // middleware releases the permit and decrements the metric between
        // back-to-back requests
        let app = test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
            "/ok",
            web::get().to(|| async { HttpResponse::Ok().body("ok") }),
        ))
        .await;

        // Make successful request
        let req1 = test::TestRequest::get().uri("/ok").to_request();
        let resp1 = test::call_service(&app, req1).await;
        assert!(resp1.status().is_success());

        // Metrics should be back to 0
        let metric1 = IN_FLIGHT_REQUESTS.with_label_values(&["/ok"]).get();
        assert_eq!(
            metric1, 0.0,
            "Metrics should be decremented after completion"
        );

        // Next request should succeed (permit was released)
        let req2 = test::TestRequest::get().uri("/ok").to_request();
        let resp2 = test::call_service(&app, req2).await;
        assert!(resp2.status().is_success());

        let metric2 = IN_FLIGHT_REQUESTS.with_label_values(&["/ok"]).get();
        assert_eq!(metric2, 0.0, "Metrics should remain clean");
    }
}