openzeppelin_relayer/api/middleware/
concurrency.rs1use actix_web::{
2 body::MessageBody,
3 dev::{Service, ServiceRequest, ServiceResponse, Transform},
4 Error,
5};
6use futures::future::{ready, LocalBoxFuture, Ready};
7use std::{
8 rc::Rc,
9 sync::Arc,
10 task::{Context, Poll},
11};
12use tokio::sync::Semaphore;
13
14use crate::metrics::IN_FLIGHT_REQUESTS;
15
16pub struct ConcurrencyLimiter {
18 semaphore: Arc<Semaphore>,
19 max_permits: usize,
20}
21
22impl ConcurrencyLimiter {
23 pub fn new(max_concurrent: usize) -> Self {
24 Self {
25 semaphore: Arc::new(Semaphore::new(max_concurrent)),
26 max_permits: max_concurrent,
27 }
28 }
29}
30
31impl<S, B> Transform<S, ServiceRequest> for ConcurrencyLimiter
32where
33 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
34 S::Future: 'static,
35 B: MessageBody + 'static,
36{
37 type Response = ServiceResponse<B>;
38 type Error = Error;
39 type Transform = ConcurrencyLimiterService<S>;
40 type InitError = ();
41 type Future = Ready<Result<Self::Transform, Self::InitError>>;
42
43 fn new_transform(&self, service: S) -> Self::Future {
44 ready(Ok(ConcurrencyLimiterService {
45 service: Rc::new(service),
46 semaphore: Arc::clone(&self.semaphore),
47 max_permits: self.max_permits,
48 }))
49 }
50}
51
52pub struct ConcurrencyLimiterService<S> {
53 service: Rc<S>,
54 semaphore: Arc<Semaphore>,
55 max_permits: usize,
56}
57
58impl<S, B> Service<ServiceRequest> for ConcurrencyLimiterService<S>
59where
60 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
61 S::Future: 'static,
62 B: MessageBody + 'static,
63{
64 type Response = ServiceResponse<B>;
65 type Error = Error;
66 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
67
68 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
69 self.service.poll_ready(cx)
70 }
71
72 fn call(&self, req: ServiceRequest) -> Self::Future {
73 let service = Rc::clone(&self.service);
74 let semaphore = Arc::clone(&self.semaphore);
75 let max_permits = self.max_permits;
76
77 Box::pin(async move {
78 let endpoint = req.path().to_string();
79
80 let permit = match semaphore.try_acquire() {
82 Ok(permit) => permit,
83 Err(_) => {
84 let current_in_flight = max_permits;
86 tracing::warn!(
87 "Concurrency limit reached for {}: {} requests in flight",
88 endpoint,
89 current_in_flight
90 );
91
92 return Err(actix_web::error::ErrorTooManyRequests(
93 serde_json::json!({
94 "error": "Too many concurrent requests",
95 "max_concurrent": max_permits,
96 })
97 .to_string(),
98 ));
99 }
100 };
101
102 IN_FLIGHT_REQUESTS.with_label_values(&[&endpoint]).inc();
104
105 let result = service.call(req).await;
107
108 IN_FLIGHT_REQUESTS.with_label_values(&[&endpoint]).dec();
110
111 drop(permit);
113
114 result
115 })
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{test, web, App, HttpResponse};
    use std::rc::Rc;
    use std::time::Duration;
    use tokio::time::sleep;

    #[actix_rt::test]
    async fn test_concurrency_limiter_allows_requests_within_limit() {
        let app = test::init_service(App::new().wrap(ConcurrencyLimiter::new(2)).route(
            "/test",
            web::get().to(|| async { HttpResponse::Ok().body("ok") }),
        ))
        .await;

        let req1 = test::TestRequest::get().uri("/test").to_request();
        let req2 = test::TestRequest::get().uri("/test").to_request();

        let resp1 = test::call_service(&app, req1).await;
        let resp2 = test::call_service(&app, req2).await;

        assert!(resp1.status().is_success());
        assert!(resp2.status().is_success());
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_rejects_excess_requests() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (tx, mut rx) = mpsc::unbounded_channel();
        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
                "/slow",
                web::get().to(move || {
                    let tx = tx.clone();
                    async move {
                        let _ = tx.send(());
                        sleep(Duration::from_millis(100)).await;
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        let app_clone = Rc::clone(&app);
        let req1 = test::TestRequest::get().uri("/slow").to_request();
        let resp1_handle =
            task::spawn_local(async move { test::call_service(&*app_clone, req1).await });

        // Wait until the first request is inside the handler, holding the only permit.
        rx.recv().await.unwrap();

        let req2 = test::TestRequest::get().uri("/slow").to_request();
        let resp2_result = test::try_call_service(&*app, req2).await;

        assert!(resp2_result.is_err());
        let err = resp2_result.unwrap_err();
        assert_eq!(err.as_response_error().status_code().as_u16(), 429);

        let resp1 = resp1_handle.await.unwrap();
        assert!(resp1.status().is_success());
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_releases_permits_after_completion() {
        let app = test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
            "/test",
            web::get().to(|| async { HttpResponse::Ok().body("ok") }),
        ))
        .await;

        let req1 = test::TestRequest::get().uri("/test").to_request();
        let resp1 = test::call_service(&app, req1).await;
        assert!(resp1.status().is_success());

        // The single permit must have been returned for this to succeed.
        let req2 = test::TestRequest::get().uri("/test").to_request();
        let resp2 = test::call_service(&app, req2).await;
        assert!(resp2.status().is_success());
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_error_response_format() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (tx, mut rx) = mpsc::unbounded_channel();
        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
                "/slow",
                web::get().to(move || {
                    let tx = tx.clone();
                    async move {
                        let _ = tx.send(());
                        sleep(Duration::from_millis(100)).await;
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        let app_clone = Rc::clone(&app);
        let req1 = test::TestRequest::get().uri("/slow").to_request();
        let resp1_handle =
            task::spawn_local(async move { test::call_service(&*app_clone, req1).await });

        rx.recv().await.unwrap();

        let req2 = test::TestRequest::get().uri("/slow").to_request();
        let resp2_result = test::try_call_service(&*app, req2).await;

        assert!(resp2_result.is_err());
        let err = resp2_result.unwrap_err();
        assert_eq!(err.as_response_error().status_code().as_u16(), 429);

        let error_str = format!("{err}");
        assert!(error_str.contains("Too many concurrent requests"));
        assert!(error_str.contains("max_concurrent"));

        let json: serde_json::Value = serde_json::from_str(&error_str).unwrap();
        assert_eq!(json["max_concurrent"], 1);

        // Let the first request finish rather than leaving the task dangling.
        let resp1 = resp1_handle.await.unwrap();
        assert!(resp1.status().is_success());
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_handles_multiple_endpoints() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (tx, mut rx) = mpsc::unbounded_channel();
        let app = Rc::new(
            test::init_service(
                App::new()
                    .wrap(ConcurrencyLimiter::new(1))
                    .route(
                        "/endpoint1",
                        web::get().to(move || {
                            let tx = tx.clone();
                            async move {
                                let _ = tx.send(());
                                sleep(Duration::from_millis(100)).await;
                                HttpResponse::Ok().body("ok1")
                            }
                        }),
                    )
                    .route(
                        "/endpoint2",
                        web::get().to(|| async { HttpResponse::Ok().body("ok2") }),
                    ),
            )
            .await,
        );

        let app_clone = Rc::clone(&app);
        let req1 = test::TestRequest::get().uri("/endpoint1").to_request();
        let resp1_handle =
            task::spawn_local(async move { test::call_service(&*app_clone, req1).await });

        rx.recv().await.unwrap();

        let req2 = test::TestRequest::get().uri("/endpoint2").to_request();
        let resp2_result = test::try_call_service(&*app, req2).await;
        assert!(
            resp2_result.is_err(),
            "Global limit should apply across endpoints"
        );
        let err = resp2_result.unwrap_err();
        assert_eq!(err.as_response_error().status_code().as_u16(), 429);

        let resp1 = resp1_handle.await.unwrap();
        assert!(resp1.status().is_success());

        let req3 = test::TestRequest::get().uri("/endpoint2").to_request();
        let resp3 = test::call_service(&*app, req3).await;
        assert!(resp3.status().is_success());
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_with_zero_limit() {
        let app = test::init_service(App::new().wrap(ConcurrencyLimiter::new(0)).route(
            "/test",
            web::get().to(|| async { HttpResponse::Ok().body("ok") }),
        ))
        .await;

        let req = test::TestRequest::get().uri("/test").to_request();
        let resp_result = test::try_call_service(&app, req).await;

        assert!(resp_result.is_err());
        let err = resp_result.unwrap_err();
        assert_eq!(err.as_response_error().status_code().as_u16(), 429);
    }

    #[actix_rt::test]
    async fn test_concurrency_limiter_metrics_tracking() {
        use tokio::sync::mpsc;
        use tokio::sync::Barrier;
        use tokio::task;

        IN_FLIGHT_REQUESTS
            .with_label_values(&["/metrics-test"])
            .set(0.0);

        let (tx, mut rx) = mpsc::unbounded_channel();
        let barrier = Rc::new(Barrier::new(2));
        let barrier_clone = Rc::clone(&barrier);

        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(2)).route(
                "/metrics-test",
                web::get().to(move || {
                    let tx = tx.clone();
                    let barrier = barrier_clone.clone();
                    async move {
                        let _ = tx.send(());
                        barrier.wait().await;
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        let initial_value = IN_FLIGHT_REQUESTS
            .with_label_values(&["/metrics-test"])
            .get();
        assert_eq!(initial_value, 0.0, "Should start at 0");

        let app_clone = Rc::clone(&app);
        let req1 = test::TestRequest::get().uri("/metrics-test").to_request();
        let handle = task::spawn_local(async move { test::call_service(&*app_clone, req1).await });

        rx.recv().await.unwrap();

        let in_flight = IN_FLIGHT_REQUESTS
            .with_label_values(&["/metrics-test"])
            .get();
        assert_eq!(in_flight, 1.0, "Should be 1 while request is in flight");

        // Release the handler, which is parked on the barrier.
        barrier.wait().await;

        let resp = handle.await.unwrap();
        assert!(resp.status().is_success());

        let final_value = IN_FLIGHT_REQUESTS
            .with_label_values(&["/metrics-test"])
            .get();
        assert_eq!(final_value, 0.0, "Should be back to 0 after completion");
    }

    #[actix_rt::test]
    async fn test_permits_released_on_handler_error() {
        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
                "/error",
                web::get().to(|| async { HttpResponse::InternalServerError().body("error") }),
            ))
            .await,
        );

        let req1 = test::TestRequest::get().uri("/error").to_request();
        let resp1 = test::call_service(&*app, req1).await;
        assert_eq!(resp1.status(), 500);

        let req2 = test::TestRequest::get().uri("/error").to_request();
        let resp2 = test::call_service(&*app, req2).await;
        assert_eq!(resp2.status(), 500);
    }

    #[actix_rt::test]
    async fn test_concurrent_requests_up_to_limit() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (start_tx, mut start_rx) = mpsc::unbounded_channel();

        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(3)).route(
                "/test",
                web::get().to(move || {
                    let tx = start_tx.clone();
                    async move {
                        let _ = tx.send(());
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        let app1 = Rc::clone(&app);
        let app2 = Rc::clone(&app);
        let app3 = Rc::clone(&app);

        let h1 = task::spawn_local(async move {
            test::call_service(&*app1, test::TestRequest::get().uri("/test").to_request()).await
        });
        let h2 = task::spawn_local(async move {
            test::call_service(&*app2, test::TestRequest::get().uri("/test").to_request()).await
        });
        let h3 = task::spawn_local(async move {
            test::call_service(&*app3, test::TestRequest::get().uri("/test").to_request()).await
        });

        // The sender lives inside the app (which we still own), so the channel
        // cannot close before all three handlers have signalled.
        for _ in 0..3 {
            start_rx
                .recv()
                .await
                .expect("handler start signal channel closed unexpectedly");
        }

        let r1 = h1.await.unwrap();
        let r2 = h2.await.unwrap();
        let r3 = h3.await.unwrap();

        assert!(r1.status().is_success());
        assert!(r2.status().is_success());
        assert!(r3.status().is_success());
    }

    #[actix_rt::test]
    async fn test_metric_tracking_multiple_endpoints() {
        use tokio::sync::mpsc;

        IN_FLIGHT_REQUESTS
            .with_label_values(&["/endpoint-a"])
            .set(0.0);
        IN_FLIGHT_REQUESTS
            .with_label_values(&["/endpoint-b"])
            .set(0.0);

        let (tx_a, mut rx_a) = mpsc::unbounded_channel();
        let (tx_b, mut rx_b) = mpsc::unbounded_channel();

        let app = Rc::new(
            test::init_service(
                App::new()
                    .wrap(ConcurrencyLimiter::new(2))
                    .route(
                        "/endpoint-a",
                        web::get().to(move || {
                            let tx = tx_a.clone();
                            async move {
                                let _ = tx.send(());
                                sleep(Duration::from_millis(100)).await;
                                HttpResponse::Ok().body("a")
                            }
                        }),
                    )
                    .route(
                        "/endpoint-b",
                        web::get().to(move || {
                            let tx = tx_b.clone();
                            async move {
                                let _ = tx.send(());
                                sleep(Duration::from_millis(100)).await;
                                HttpResponse::Ok().body("b")
                            }
                        }),
                    ),
            )
            .await,
        );

        let app_clone_a = Rc::clone(&app);
        let handle_a = tokio::task::spawn_local(async move {
            test::call_service(
                &*app_clone_a,
                test::TestRequest::get().uri("/endpoint-a").to_request(),
            )
            .await
        });
        rx_a.recv().await.unwrap();

        let app_clone_b = Rc::clone(&app);
        let handle_b = tokio::task::spawn_local(async move {
            test::call_service(
                &*app_clone_b,
                test::TestRequest::get().uri("/endpoint-b").to_request(),
            )
            .await
        });
        rx_b.recv().await.unwrap();

        let metric_a = IN_FLIGHT_REQUESTS.with_label_values(&["/endpoint-a"]).get();
        let metric_b = IN_FLIGHT_REQUESTS.with_label_values(&["/endpoint-b"]).get();

        assert_eq!(metric_a, 1.0, "Endpoint-a should have 1 in-flight request");
        assert_eq!(metric_b, 1.0, "Endpoint-b should have 1 in-flight request");

        // Drain both requests so the per-endpoint gauges are left clean.
        assert!(handle_a.await.unwrap().status().is_success());
        assert!(handle_b.await.unwrap().status().is_success());

        let metric_a = IN_FLIGHT_REQUESTS.with_label_values(&["/endpoint-a"]).get();
        let metric_b = IN_FLIGHT_REQUESTS.with_label_values(&["/endpoint-b"]).get();
        assert_eq!(metric_a, 0.0, "Endpoint-a should be back to 0");
        assert_eq!(metric_b, 0.0, "Endpoint-b should be back to 0");
    }

    #[actix_rt::test]
    async fn test_429_response_includes_correct_limit() {
        use tokio::sync::mpsc;
        use tokio::task;

        let (tx, mut rx) = mpsc::unbounded_channel();
        let max_limit = 3;

        let app = Rc::new(
            test::init_service(App::new().wrap(ConcurrencyLimiter::new(max_limit)).route(
                "/limited",
                web::get().to(move || {
                    let tx = tx.clone();
                    async move {
                        let _ = tx.send(());
                        sleep(Duration::from_secs(10)).await;
                        HttpResponse::Ok().body("ok")
                    }
                }),
            ))
            .await,
        );

        // Occupy every permit; each spawned request parks in the handler.
        for _ in 0..max_limit {
            let app_clone = Rc::clone(&app);
            task::spawn_local(async move {
                test::call_service(
                    &*app_clone,
                    test::TestRequest::get().uri("/limited").to_request(),
                )
                .await
            });
            rx.recv().await.unwrap();
        }

        let req = test::TestRequest::get().uri("/limited").to_request();
        let result = test::try_call_service(&*app, req).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        let error_str = format!("{err}");

        assert!(error_str.contains(&max_limit.to_string()));
        let json: serde_json::Value = serde_json::from_str(&error_str).unwrap();
        assert_eq!(json["max_concurrent"].as_u64(), Some(max_limit as u64));
    }

    #[actix_rt::test]
    async fn test_metrics_return_to_zero_after_sequential_requests() {
        IN_FLIGHT_REQUESTS.with_label_values(&["/ok"]).set(0.0);

        let app = test::init_service(App::new().wrap(ConcurrencyLimiter::new(1)).route(
            "/ok",
            web::get().to(|| async { HttpResponse::Ok().body("ok") }),
        ))
        .await;

        let req1 = test::TestRequest::get().uri("/ok").to_request();
        let resp1 = test::call_service(&app, req1).await;
        assert!(resp1.status().is_success());

        let metric1 = IN_FLIGHT_REQUESTS.with_label_values(&["/ok"]).get();
        assert_eq!(
            metric1, 0.0,
            "Metrics should be decremented after completion"
        );

        let req2 = test::TestRequest::get().uri("/ok").to_request();
        let resp2 = test::call_service(&app, req2).await;
        assert!(resp2.status().is_success());

        let metric2 = IN_FLIGHT_REQUESTS.with_label_values(&["/ok"]).get();
        assert_eq!(metric2, 0.0, "Metrics should remain clean");
    }
}
628}