1use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::UnixStream;
13use tokio::sync::Semaphore;
14
15use super::config::get_config;
16use super::protocol::{PoolRequest, PoolResponse};
17use super::PluginError;
18
19pub struct PoolConnection {
21 stream: UnixStream,
22 id: usize,
24}
25
26impl PoolConnection {
27 pub async fn new(socket_path: &str, id: usize) -> Result<Self, PluginError> {
28 let max_attempts = get_config().pool_connect_retries;
29 let mut attempts = 0;
30 let mut delay_ms = 10u64;
31
32 tracing::debug!(connection_id = id, socket_path = %socket_path, "Connecting to pool server");
33
34 loop {
35 match UnixStream::connect(socket_path).await {
36 Ok(stream) => {
37 if attempts > 0 {
38 tracing::debug!(
39 connection_id = id,
40 attempts = attempts,
41 "Connected to pool server after retries"
42 );
43 }
44 return Ok(Self { stream, id });
45 }
46 Err(e) => {
47 attempts += 1;
48
49 if attempts >= max_attempts {
50 return Err(PluginError::SocketError(format!(
51 "Failed to connect to pool after {max_attempts} attempts: {e}. \
52 Consider increasing PLUGIN_POOL_CONNECT_RETRIES or PLUGIN_POOL_MAX_CONNECTIONS."
53 )));
54 }
55
56 if attempts <= 3 || attempts % 5 == 0 {
57 tracing::debug!(
58 connection_id = id,
59 attempt = attempts,
60 max_attempts = max_attempts,
61 delay_ms = delay_ms,
62 "Retrying connection to pool server"
63 );
64 }
65
66 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
67 delay_ms = std::cmp::min(delay_ms * 2, 1000);
68 }
69 }
70 }
71 }
72
73 pub async fn send_request(
74 &mut self,
75 request: &PoolRequest,
76 ) -> Result<PoolResponse, PluginError> {
77 let request_task_id = Self::extract_task_id(request);
79
80 let json = serde_json::to_string(request)
81 .map_err(|e| PluginError::PluginError(format!("Failed to serialize request: {e}")))?;
82
83 if let Err(e) = self.stream.write_all(format!("{json}\n").as_bytes()).await {
84 return Err(PluginError::SocketError(format!(
85 "Failed to send request: {e}"
86 )));
87 }
88
89 if let Err(e) = self.stream.flush().await {
90 return Err(PluginError::SocketError(format!(
91 "Failed to flush request: {e}"
92 )));
93 }
94
95 let mut reader = BufReader::new(&mut self.stream);
96 let mut line = String::new();
97
98 if let Err(e) = reader.read_line(&mut line).await {
99 return Err(PluginError::SocketError(format!(
100 "Failed to read response: {e}"
101 )));
102 }
103
104 tracing::debug!(response_len = line.len(), "Received response from pool");
105
106 let response: PoolResponse = serde_json::from_str(&line)
107 .map_err(|e| PluginError::PluginError(format!("Failed to parse response: {e}")))?;
108
109 if response.task_id != request_task_id {
111 tracing::error!(
112 request_task_id = %request_task_id,
113 response_task_id = %response.task_id,
114 connection_id = self.id,
115 "Response task_id mismatch"
116 );
117 return Err(PluginError::PluginError(
118 "Internal plugin error: response task_id mismatch".to_string(),
119 ));
120 }
121
122 Ok(response)
123 }
124
125 fn extract_task_id(request: &PoolRequest) -> String {
127 match request {
128 PoolRequest::Execute(req) => req.task_id.clone(),
129 PoolRequest::Precompile { task_id, .. } => task_id.clone(),
130 PoolRequest::Cache { task_id, .. } => task_id.clone(),
131 PoolRequest::Invalidate { task_id, .. } => task_id.clone(),
132 PoolRequest::Stats { task_id } => task_id.clone(),
133 PoolRequest::Health { task_id } => task_id.clone(),
134 PoolRequest::Shutdown { task_id } => task_id.clone(),
135 }
136 }
137
138 pub async fn send_request_with_timeout(
139 &mut self,
140 request: &PoolRequest,
141 timeout_secs: u64,
142 ) -> Result<PoolResponse, PluginError> {
143 tokio::time::timeout(
144 Duration::from_secs(timeout_secs),
145 self.send_request(request),
146 )
147 .await
148 .map_err(|_| PluginError::SocketError("Request timed out".to_string()))?
149 }
150
151 pub fn id(&self) -> usize {
153 self.id
154 }
155}
156
157pub struct ConnectionPool {
162 socket_path: String,
163 #[allow(dead_code)]
165 max_connections: usize,
166 next_id: Arc<AtomicUsize>,
168 pub semaphore: Arc<Semaphore>,
170}
171
172impl ConnectionPool {
173 pub fn new(socket_path: String, max_connections: usize) -> Self {
174 Self {
175 socket_path,
176 max_connections,
177 next_id: Arc::new(AtomicUsize::new(0)),
178 semaphore: Arc::new(Semaphore::new(max_connections)),
179 }
180 }
181
182 pub async fn acquire_with_permit(
186 &self,
187 permit: Option<tokio::sync::OwnedSemaphorePermit>,
188 ) -> Result<PooledConnection<'_>, PluginError> {
189 let permit = match permit {
190 Some(p) => p,
191 None => {
192 let available_permits = self.semaphore.available_permits();
193 if available_permits == 0 {
194 tracing::warn!(
195 max_connections = self.max_connections,
196 "All connection permits exhausted - waiting for connection"
197 );
198 }
199 self.semaphore.clone().acquire_owned().await.map_err(|_| {
200 PluginError::PluginError("Connection semaphore closed".to_string())
201 })?
202 }
203 };
204
205 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
206 tracing::debug!(connection_id = id, "Creating connection");
207
208 let conn = PoolConnection::new(&self.socket_path, id).await?;
209
210 Ok(PooledConnection {
211 conn: Some(conn),
212 pool: self,
213 _permit: permit,
214 })
215 }
216
217 pub async fn acquire(&self) -> Result<PooledConnection<'_>, PluginError> {
219 self.acquire_with_permit(None).await
220 }
221
222 pub fn release(&self, conn: PoolConnection) {
224 let conn_id = conn.id();
225 tracing::debug!(connection_id = conn_id, "Connection closed");
226 drop(conn);
227 }
228
229 pub fn next_connection_id(&self) -> usize {
232 self.next_id.fetch_add(1, Ordering::Relaxed)
233 }
234}
235
/// Guard for a connection checked out of a [`ConnectionPool`].
///
/// While the guard is alive it holds one semaphore permit. On drop, its `Drop`
/// impl hands the connection back to the pool's `release`. The permit field is
/// dropped afterwards, which returns the permit to the semaphore.
pub struct PooledConnection<'a> {
    // Set to `Some` by `ConnectionPool::acquire_with_permit`; taken in `Drop`.
    conn: Option<PoolConnection>,
    // Pool that produced the connection; its `release` is called on drop.
    pool: &'a ConnectionPool,
    // Held only for its drop side effect of returning the permit to the semaphore.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
243
244impl<'a> PooledConnection<'a> {
245 pub async fn send_request_with_timeout(
246 &mut self,
247 request: &PoolRequest,
248 timeout_secs: u64,
249 ) -> Result<PoolResponse, PluginError> {
250 if let Some(ref mut conn) = self.conn {
251 conn.send_request_with_timeout(request, timeout_secs).await
252 } else {
253 Err(PluginError::PluginError(
254 "Connection already released".to_string(),
255 ))
256 }
257 }
258}
259
260impl<'a> Drop for PooledConnection<'a> {
261 fn drop(&mut self) {
262 if let Some(conn) = self.conn.take() {
263 self.pool.release(conn);
264 }
265 }
266}
267
#[cfg(test)]
mod tests {
    //! Unit tests for the connection pool.
    //!
    //! No pool server is started here. Tests that try to connect use socket
    //! paths that do not exist and assert on the resulting errors.

    use super::*;
    use crate::services::plugins::protocol::ExecuteRequest;

    #[test]
    fn test_connection_pool_creation() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 10);
        assert_eq!(pool.semaphore.available_permits(), 10);
    }

    #[test]
    fn test_connection_pool_creation_single_connection() {
        let pool = ConnectionPool::new("/tmp/single.sock".to_string(), 1);
        assert_eq!(pool.semaphore.available_permits(), 1);
    }

    #[test]
    fn test_connection_pool_creation_large_pool() {
        let pool = ConnectionPool::new("/tmp/large.sock".to_string(), 1000);
        assert_eq!(pool.semaphore.available_permits(), 1000);
    }

    #[test]
    fn test_connection_pool_stores_socket_path() {
        let path = "/var/run/custom.sock";
        let pool = ConnectionPool::new(path.to_string(), 5);
        assert_eq!(pool.socket_path, path);
    }

    #[test]
    fn test_connection_pool_stores_max_connections() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 42);
        assert_eq!(pool.max_connections, 42);
    }

    #[tokio::test]
    async fn test_connection_pool_semaphore_limits() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 2);

        let permit1 = pool.semaphore.clone().try_acquire_owned();
        assert!(permit1.is_ok());

        let permit2 = pool.semaphore.clone().try_acquire_owned();
        assert!(permit2.is_ok());

        // Both permits are still held, so a third must be refused.
        let permit3 = pool.semaphore.clone().try_acquire_owned();
        assert!(permit3.is_err());
    }

    #[tokio::test]
    async fn test_semaphore_permit_release_restores_capacity() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 2);

        let permit1 = pool.semaphore.clone().try_acquire_owned().unwrap();
        let permit2 = pool.semaphore.clone().try_acquire_owned().unwrap();

        assert_eq!(pool.semaphore.available_permits(), 0);

        drop(permit1);

        assert_eq!(pool.semaphore.available_permits(), 1);

        let permit3 = pool.semaphore.clone().try_acquire_owned();
        assert!(permit3.is_ok());

        drop(permit2);
        drop(permit3.unwrap());

        assert_eq!(pool.semaphore.available_permits(), 2);
    }

    #[tokio::test]
    async fn test_semaphore_async_acquire() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 1);

        let permit = pool.semaphore.clone().acquire_owned().await;
        assert!(permit.is_ok());
        let _permit = permit.unwrap();

        assert_eq!(pool.semaphore.available_permits(), 0);
    }

    #[test]
    fn test_connection_id_increment() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 10);
        assert_eq!(pool.next_connection_id(), 0);
        assert_eq!(pool.next_connection_id(), 1);
        assert_eq!(pool.next_connection_id(), 2);
    }

    #[test]
    fn test_connection_id_starts_at_zero() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 10);
        assert_eq!(pool.next_connection_id(), 0);
    }

    #[test]
    fn test_connection_id_monotonically_increasing() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 10);

        let mut last_id = pool.next_connection_id();
        for _ in 0..100 {
            let current_id = pool.next_connection_id();
            assert!(
                current_id > last_id,
                "IDs should be monotonically increasing"
            );
            last_id = current_id;
        }
    }

    #[test]
    fn test_connection_id_thread_safe() {
        use std::thread;

        const THREADS: usize = 10;
        const IDS_PER_THREAD: usize = 100;

        let pool = Arc::new(ConnectionPool::new("/tmp/test.sock".to_string(), 100));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let pool_clone = Arc::clone(&pool);
                thread::spawn(move || {
                    (0..IDS_PER_THREAD)
                        .map(|_| pool_clone.next_connection_id())
                        .collect::<Vec<_>>()
                })
            })
            .collect();

        let mut all_ids: Vec<usize> = handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();
        assert_eq!(all_ids.len(), THREADS * IDS_PER_THREAD);

        // After sorting, duplicates would be adjacent; dedup must remove nothing.
        all_ids.sort_unstable();
        let total = all_ids.len();
        all_ids.dedup();
        assert_eq!(all_ids.len(), total, "All IDs should be unique");
    }

    #[test]
    fn test_extract_task_id_from_execute_request() {
        let request = PoolRequest::Execute(Box::new(ExecuteRequest {
            task_id: "execute-task-123".to_string(),
            plugin_id: "test-plugin".to_string(),
            compiled_code: None,
            plugin_path: None,
            params: serde_json::json!({}),
            headers: None,
            socket_path: "/tmp/test.sock".to_string(),
            http_request_id: None,
            timeout: Some(30000),
            route: None,
            config: None,
            method: None,
            query: None,
        }));

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "execute-task-123");
    }

    #[test]
    fn test_extract_task_id_from_precompile_request() {
        let request = PoolRequest::Precompile {
            task_id: "precompile-task-456".to_string(),
            plugin_id: "test-plugin".to_string(),
            plugin_path: Some("/path/to/plugin.ts".to_string()),
            source_code: None,
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "precompile-task-456");
    }

    #[test]
    fn test_extract_task_id_from_cache_request() {
        let request = PoolRequest::Cache {
            task_id: "cache-task-789".to_string(),
            plugin_id: "test-plugin".to_string(),
            compiled_code: "compiled code".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "cache-task-789");
    }

    #[test]
    fn test_extract_task_id_from_invalidate_request() {
        let request = PoolRequest::Invalidate {
            task_id: "invalidate-task-abc".to_string(),
            plugin_id: "test-plugin".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "invalidate-task-abc");
    }

    #[test]
    fn test_extract_task_id_from_stats_request() {
        let request = PoolRequest::Stats {
            task_id: "stats-task-def".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "stats-task-def");
    }

    #[test]
    fn test_extract_task_id_from_health_request() {
        let request = PoolRequest::Health {
            task_id: "health-task-ghi".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "health-task-ghi");
    }

    #[test]
    fn test_extract_task_id_from_shutdown_request() {
        let request = PoolRequest::Shutdown {
            task_id: "shutdown-task-jkl".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "shutdown-task-jkl");
    }

    #[test]
    fn test_extract_task_id_preserves_special_characters() {
        let request = PoolRequest::Stats {
            task_id: "task-with-special_chars.and/slashes:colons".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "task-with-special_chars.and/slashes:colons");
    }

    #[test]
    fn test_extract_task_id_handles_empty_string() {
        let request = PoolRequest::Health {
            task_id: "".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "");
    }

    #[test]
    fn test_extract_task_id_handles_uuid_format() {
        let uuid = "550e8400-e29b-41d4-a716-446655440000";
        let request = PoolRequest::Stats {
            task_id: uuid.to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, uuid);
    }

    #[tokio::test]
    async fn test_acquire_without_server_fails() {
        let pool = ConnectionPool::new("/tmp/nonexistent_socket_12345.sock".to_string(), 10);

        let result = pool.acquire().await;
        assert!(result.is_err());

        match result {
            Err(PluginError::SocketError(msg)) => {
                assert!(msg.contains("Failed to connect"));
            }
            _ => panic!("Expected SocketError"),
        }
    }

    #[tokio::test]
    async fn test_acquire_with_pre_acquired_permit() {
        let pool = ConnectionPool::new("/tmp/nonexistent_socket_67890.sock".to_string(), 10);

        let permit = pool.semaphore.clone().acquire_owned().await.unwrap();
        assert_eq!(pool.semaphore.available_permits(), 9);

        let result = pool.acquire_with_permit(Some(permit)).await;
        assert!(result.is_err());

        // The permit was moved into the failed acquisition and must not leak.
        assert_eq!(
            pool.semaphore.available_permits(),
            10,
            "permit must be released when connecting fails"
        );
    }

    #[tokio::test]
    async fn test_pooled_connection_cannot_be_used_after_release() {
        let pool = ConnectionPool::new("/tmp/released_conn.sock".to_string(), 1);
        let permit = pool.semaphore.clone().acquire_owned().await.unwrap();
        assert_eq!(pool.semaphore.available_permits(), 0);

        // Build a guard whose connection has already been taken, as `Drop` does.
        let mut pooled = PooledConnection {
            conn: None,
            pool: &pool,
            _permit: permit,
        };

        let request = PoolRequest::Health {
            task_id: "released-task".to_string(),
        };
        match pooled.send_request_with_timeout(&request, 1).await {
            Err(PluginError::PluginError(msg)) => {
                assert_eq!(msg, "Connection already released");
            }
            _ => panic!("Expected PluginError for a released connection"),
        }

        // Dropping the guard must still hand the permit back.
        drop(pooled);
        assert_eq!(pool.semaphore.available_permits(), 1);
    }

    #[tokio::test]
    async fn test_acquire_error_message_contains_helpful_info() {
        let pool = ConnectionPool::new("/tmp/no_server_here_xyz.sock".to_string(), 10);

        match pool.acquire().await {
            Err(PluginError::SocketError(msg)) => {
                assert!(msg.contains("Failed to connect"), "Unexpected message: {msg}");
                assert!(
                    msg.contains("PLUGIN_POOL_CONNECT_RETRIES")
                        && msg.contains("PLUGIN_POOL_MAX_CONNECTIONS"),
                    "Error message should contain helpful info: {msg}"
                );
            }
            _ => panic!("Expected SocketError"),
        }
    }

    #[test]
    fn test_multiple_pools_independent() {
        let pool1 = ConnectionPool::new("/tmp/pool1.sock".to_string(), 5);
        let pool2 = ConnectionPool::new("/tmp/pool2.sock".to_string(), 10);

        assert_eq!(pool1.semaphore.available_permits(), 5);
        assert_eq!(pool2.semaphore.available_permits(), 10);

        // Each pool owns its own id counter.
        assert_eq!(pool1.next_connection_id(), 0);
        assert_eq!(pool2.next_connection_id(), 0);
        assert_eq!(pool1.next_connection_id(), 1);
        assert_eq!(pool2.next_connection_id(), 1);
    }

    #[tokio::test]
    async fn test_concurrent_semaphore_acquire() {
        let pool = Arc::new(ConnectionPool::new("/tmp/concurrent.sock".to_string(), 3));

        let mut handles = vec![];

        for i in 0..3 {
            let pool_clone = pool.clone();
            handles.push(tokio::spawn(async move {
                let permit = pool_clone.semaphore.clone().acquire_owned().await;
                assert!(permit.is_ok(), "Task {i} should acquire permit");
                tokio::time::sleep(Duration::from_millis(10)).await;
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(pool.semaphore.available_permits(), 3);
    }

    #[tokio::test]
    async fn test_semaphore_fairness() {
        use std::sync::atomic::AtomicU32;

        let pool = Arc::new(ConnectionPool::new("/tmp/fairness.sock".to_string(), 1));
        let counter = Arc::new(AtomicU32::new(0));

        // Hold the only permit so every spawned task has to wait.
        let permit = pool.semaphore.clone().acquire_owned().await.unwrap();

        let mut handles = vec![];

        for _ in 0..3 {
            let pool_clone = pool.clone();
            let counter_clone = counter.clone();
            handles.push(tokio::spawn(async move {
                let _permit = pool_clone.semaphore.clone().acquire_owned().await.unwrap();
                counter_clone.fetch_add(1, Ordering::SeqCst);
            }));
        }

        tokio::time::sleep(Duration::from_millis(50)).await;

        assert_eq!(counter.load(Ordering::SeqCst), 0);

        drop(permit);

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_zero_max_connections_has_no_permits() {
        let pool = ConnectionPool::new("/tmp/zero.sock".to_string(), 0);
        assert_eq!(pool.semaphore.available_permits(), 0);

        // The semaphore is open but empty: acquisition fails for lack of permits,
        // not because the semaphore is closed.
        let permit = pool.semaphore.clone().try_acquire_owned();
        assert!(matches!(
            permit,
            Err(tokio::sync::TryAcquireError::NoPermits)
        ));
    }

    #[test]
    fn test_socket_path_with_spaces() {
        let path = "/tmp/path with spaces/test.sock";
        let pool = ConnectionPool::new(path.to_string(), 5);
        assert_eq!(pool.socket_path, path);
    }

    #[test]
    fn test_socket_path_with_unicode() {
        let path = "/tmp/тест/套接字.sock";
        let pool = ConnectionPool::new(path.to_string(), 5);
        assert_eq!(pool.socket_path, path);
    }

    #[test]
    fn test_very_long_socket_path() {
        let path = format!("/tmp/{}/test.sock", "a".repeat(200));
        let pool = ConnectionPool::new(path.clone(), 5);
        assert_eq!(pool.socket_path, path);
    }
}