1use super::config::get_config;
126use crate::jobs::JobProducerTrait;
127use crate::models::{
128 NetworkRepoModel, NotificationRepoModel, RelayerRepoModel, SignerRepoModel, ThinDataAppState,
129 TransactionRepoModel,
130};
131use crate::repositories::{
132 ApiKeyRepositoryTrait, NetworkRepository, PluginRepositoryTrait, RelayerRepository, Repository,
133 TransactionCounterTrait, TransactionRepository,
134};
135use crate::services::plugins::relayer_api::{RelayerApi, Request};
136use scc::HashMap as SccHashMap;
137use serde::{Deserialize, Serialize};
138use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
139use std::sync::Arc;
140use std::time::{Duration, Instant};
141use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
142use tokio::net::{UnixListener, UnixStream};
143use tokio::sync::{mpsc, watch, Semaphore};
144use tracing::{debug, info, warn};
145
146use super::PluginError;
147
148#[derive(Debug, Serialize, Deserialize, Clone)]
150#[serde(tag = "type", rename_all = "snake_case")]
151pub enum PluginMessage {
152 Register { execution_id: String },
154 ApiRequest {
156 request_id: String,
157 relayer_id: String,
158 method: crate::services::plugins::relayer_api::PluginMethod,
159 payload: serde_json::Value,
160 },
161 ApiResponse {
163 request_id: String,
164 result: Option<serde_json::Value>,
165 error: Option<String>,
166 },
167 Trace { trace: serde_json::Value },
169 Shutdown,
171}
172
173struct ExecutionContext {
175 traces_tx: Option<mpsc::Sender<Vec<serde_json::Value>>>,
178 created_at: Instant,
180 #[allow(dead_code)] bound_execution_id: String,
184}
185
186pub struct ExecutionGuard {
188 execution_id: String,
189 executions: Arc<SccHashMap<String, ExecutionContext>>,
190 rx: Option<mpsc::Receiver<Vec<serde_json::Value>>>,
191 active_count: Arc<AtomicUsize>,
193 registered: bool,
196}
197
198impl ExecutionGuard {
199 pub fn into_receiver(mut self) -> Option<mpsc::Receiver<Vec<serde_json::Value>>> {
202 self.rx.take()
203 }
204}
205
206impl Drop for ExecutionGuard {
207 fn drop(&mut self) {
208 if self.registered && self.executions.remove(&self.execution_id).is_some() {
219 self.active_count.fetch_sub(1, Ordering::AcqRel);
220 }
221 }
222}
223
224pub struct SharedSocketService {
226 socket_path: String,
228 executions: Arc<SccHashMap<String, ExecutionContext>>,
231 active_count: Arc<AtomicUsize>,
233 started: AtomicBool,
235 shutdown_tx: watch::Sender<bool>,
237 connection_semaphore: Arc<Semaphore>,
239 idle_timeout: Duration,
241 read_timeout: Duration,
243}
244
245impl SharedSocketService {
246 pub fn new(socket_path: &str) -> Result<Self, PluginError> {
248 let _ = std::fs::remove_file(socket_path);
250
251 let (shutdown_tx, _) = watch::channel(false);
252
253 let config = get_config();
255 let idle_timeout = Duration::from_secs(config.socket_idle_timeout_secs);
256 let read_timeout = Duration::from_secs(config.socket_read_timeout_secs);
257 let max_connections = config.socket_max_connections;
258
259 let executions: Arc<SccHashMap<String, ExecutionContext>> = Arc::new(SccHashMap::new());
260 let active_count = Arc::new(AtomicUsize::new(0));
261
262 let executions_clone = executions.clone();
264 let active_count_clone = active_count.clone();
265 let mut cleanup_shutdown_rx = shutdown_tx.subscribe();
266 tokio::spawn(async move {
267 let mut interval = tokio::time::interval(Duration::from_secs(60));
268 loop {
269 tokio::select! {
270 _ = interval.tick() => {}
271 _ = cleanup_shutdown_rx.changed() => {
272 if *cleanup_shutdown_rx.borrow() {
273 break;
274 }
275 }
276 }
277 let now = Instant::now();
278 let mut removed = 0usize;
280 executions_clone.retain(|_, ctx| {
281 let keep = now.duration_since(ctx.created_at) < Duration::from_secs(300);
282 if !keep {
283 removed += 1;
284 }
285 keep
286 });
287 if removed > 0 {
288 active_count_clone.fetch_sub(removed, Ordering::AcqRel);
289 }
290 }
291 });
292
293 Ok(Self {
294 socket_path: socket_path.to_string(),
295 executions,
296 active_count,
297 started: AtomicBool::new(false),
298 shutdown_tx,
299 connection_semaphore: Arc::new(Semaphore::new(max_connections)),
300 idle_timeout,
301 read_timeout,
302 })
303 }
304
305 pub fn socket_path(&self) -> &str {
306 &self.socket_path
307 }
308
309 pub async fn register_execution(
316 &self,
317 execution_id: String,
318 emit_traces: bool,
319 ) -> ExecutionGuard {
320 let (tx, rx) = if emit_traces {
322 let (tx, rx) = mpsc::channel(1);
323 (Some(tx), Some(rx))
324 } else {
325 (None, None)
326 };
327
328 let ctx = ExecutionContext {
329 traces_tx: tx,
330 created_at: Instant::now(),
331 bound_execution_id: execution_id.clone(),
332 };
333
334 let registered = match self.executions.insert(execution_id.clone(), ctx) {
336 Ok(_) => {
337 self.active_count.fetch_add(1, Ordering::AcqRel);
338 true
339 }
340 Err((existing_key, _)) => {
341 tracing::warn!(
342 execution_id = %existing_key,
343 "Duplicate execution_id detected during registration, guard will not decrement counter"
344 );
345 false
346 }
347 };
348
349 ExecutionGuard {
350 execution_id,
351 executions: self.executions.clone(),
352 rx,
353 registered,
354 active_count: self.active_count.clone(),
355 }
356 }
357
358 pub fn available_connection_slots(&self) -> usize {
360 self.connection_semaphore.available_permits()
361 }
362
363 pub fn active_connection_count(&self) -> usize {
365 get_config().socket_max_connections - self.connection_semaphore.available_permits()
366 }
367
368 pub async fn registered_executions_count(&self) -> usize {
370 self.active_count.load(Ordering::Relaxed)
371 }
372
373 pub async fn shutdown(&self) {
375 let _ = self.shutdown_tx.send(true);
376 info!("Shared socket service: shutdown signal sent");
377
378 let max_wait = Duration::from_secs(30);
380 let start = Instant::now();
381
382 while start.elapsed() < max_wait {
383 let available = self.connection_semaphore.available_permits();
384 if available == get_config().socket_max_connections {
385 break;
387 }
388 tokio::time::sleep(Duration::from_millis(100)).await;
389 }
390
391 let _ = std::fs::remove_file(&self.socket_path);
393 info!("Shared socket service: shutdown complete");
394 }
395
396 #[allow(clippy::type_complexity)]
400 pub async fn start<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
401 self: Arc<Self>,
402 state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
403 ) -> Result<(), PluginError>
404 where
405 J: JobProducerTrait + Send + Sync + 'static,
406 RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
407 TR: TransactionRepository
408 + Repository<TransactionRepoModel, String>
409 + Send
410 + Sync
411 + 'static,
412 NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
413 NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
414 SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
415 TCR: TransactionCounterTrait + Send + Sync + 'static,
416 PR: PluginRepositoryTrait + Send + Sync + 'static,
417 AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
418 {
419 if self.started.swap(true, Ordering::Acquire) {
421 return Ok(());
422 }
423
424 let listener = UnixListener::bind(&self.socket_path)
426 .map_err(|e| PluginError::SocketError(format!("Failed to bind listener: {e}")))?;
427 let executions = self.executions.clone();
428 let relayer_api = Arc::new(RelayerApi);
429 let socket_path = self.socket_path.clone();
430 let mut shutdown_rx = self.shutdown_tx.subscribe();
431 let connection_semaphore = self.connection_semaphore.clone();
432 let idle_timeout = self.idle_timeout;
433 let read_timeout = self.read_timeout;
434
435 debug!(
436 "Shared socket service: starting listener on {}",
437 socket_path
438 );
439
440 tokio::spawn(async move {
442 debug!("Shared socket service: listener task started");
443 loop {
444 tokio::select! {
445 _ = shutdown_rx.changed() => {
447 if *shutdown_rx.borrow() {
448 info!("Shared socket service: shutting down listener");
449 break;
450 }
451 }
452 accept_result = listener.accept() => {
454 match accept_result {
455 Ok((stream, _)) => {
456 match connection_semaphore.clone().try_acquire_owned() {
458 Ok(permit) => {
459 debug!("Shared socket service: accepted new connection");
460
461 let relayer_api_clone = relayer_api.clone();
462 let state_clone = Arc::clone(&state);
463 let executions_clone = executions.clone();
464
465 tokio::spawn(async move {
466 let _permit = permit;
468
469 let result = Self::handle_connection_with_timeout(
470 stream,
471 relayer_api_clone,
472 state_clone,
473 executions_clone,
474 idle_timeout,
475 read_timeout,
476 )
477 .await;
478
479 if let Err(e) = result {
480 debug!("Connection handler finished with error: {}", e);
481 }
482 });
483 }
484 Err(_) => {
485 warn!(
486 "Connection limit reached, rejecting new connection. \
487 Consider increasing PLUGIN_MAX_CONCURRENCY or PLUGIN_SOCKET_MAX_CONCURRENT_CONNECTIONS."
488 );
489 drop(stream);
490 }
491 }
492 }
493 Err(e) => {
494 warn!("Error accepting connection: {}", e);
495 }
496 }
497 }
498 }
499 }
500
501 let _ = std::fs::remove_file(&socket_path);
503 info!("Shared socket service: listener stopped");
504 });
505
506 Ok(())
507 }
508
509 #[allow(clippy::type_complexity)]
511 async fn handle_connection_with_timeout<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
512 stream: UnixStream,
513 relayer_api: Arc<RelayerApi>,
514 state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
515 executions: Arc<SccHashMap<String, ExecutionContext>>,
516 idle_timeout: Duration,
517 read_timeout: Duration,
518 ) -> Result<(), PluginError>
519 where
520 J: JobProducerTrait + Send + Sync + 'static,
521 RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
522 TR: TransactionRepository
523 + Repository<TransactionRepoModel, String>
524 + Send
525 + Sync
526 + 'static,
527 NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
528 NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
529 SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
530 TCR: TransactionCounterTrait + Send + Sync + 'static,
531 PR: PluginRepositoryTrait + Send + Sync + 'static,
532 AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
533 {
534 match tokio::time::timeout(
536 idle_timeout,
537 Self::handle_connection(stream, relayer_api, state, executions, read_timeout),
538 )
539 .await
540 {
541 Ok(result) => result,
542 Err(_) => {
543 debug!("Connection idle timeout reached");
544 Ok(())
545 }
546 }
547 }
548
549 #[allow(clippy::type_complexity)]
555 async fn handle_connection<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
556 stream: UnixStream,
557 relayer_api: Arc<RelayerApi>,
558 state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
559 executions: Arc<SccHashMap<String, ExecutionContext>>,
560 read_timeout: Duration,
561 ) -> Result<(), PluginError>
562 where
563 J: JobProducerTrait + Send + Sync + 'static,
564 RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
565 TR: TransactionRepository
566 + Repository<TransactionRepoModel, String>
567 + Send
568 + Sync
569 + 'static,
570 NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
571 NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
572 SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
573 TCR: TransactionCounterTrait + Send + Sync + 'static,
574 PR: PluginRepositoryTrait + Send + Sync + 'static,
575 AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
576 {
577 let (r, mut w) = stream.into_split();
578 let mut reader = BufReader::new(r).lines();
579
580 let mut traces: Option<Vec<serde_json::Value>> = None;
582 let mut traces_enabled = false;
584
585 let mut bound_execution_id: Option<String> = None;
588
589 loop {
590 let line = match tokio::time::timeout(read_timeout, reader.next_line()).await {
592 Ok(Ok(Some(line))) => line,
593 Ok(Ok(None)) => break, Ok(Err(e)) => {
595 warn!("Error reading from connection: {}", e);
596 break;
597 }
598 Err(_) => {
599 debug!("Read timeout on connection");
600 break;
601 }
602 };
603
604 debug!("Shared socket service: received message");
605
606 let json_value: serde_json::Value = match serde_json::from_str(&line) {
608 Ok(v) => v,
609 Err(e) => {
610 warn!("Failed to parse JSON: {}", e);
611 continue;
612 }
613 };
614
615 let has_type_field = json_value.get("type").is_some();
616
617 if has_type_field {
618 let message: PluginMessage = match serde_json::from_value(json_value) {
620 Ok(msg) => msg,
621 Err(e) => {
622 warn!("Failed to parse PluginMessage: {}", e);
623 continue;
624 }
625 };
626
627 match message {
629 PluginMessage::Register { execution_id } => {
630 if bound_execution_id.is_some() {
632 warn!("Attempted to re-register connection (security violation)");
633 break;
634 }
635
636 if let Some(has_traces) =
639 executions.read(&execution_id, |_, ctx| ctx.traces_tx.is_some())
640 {
641 traces_enabled = has_traces;
642 } else {
643 warn!("Unknown execution_id: {}", execution_id);
644 break;
645 }
646
647 debug!(
648 execution_id = %execution_id,
649 traces_enabled = traces_enabled,
650 "Connection registered"
651 );
652 bound_execution_id = Some(execution_id);
653 }
654
655 PluginMessage::ApiRequest {
656 request_id,
657 relayer_id,
658 method,
659 payload,
660 } => {
661 let exec_id = match &bound_execution_id {
663 Some(id) => id,
664 None => {
665 warn!("ApiRequest before Register (security violation)");
666 break;
667 }
668 };
669
670 let request = Request {
672 request_id: request_id.clone(),
673 relayer_id,
674 method,
675 payload,
676 http_request_id: Some(exec_id.clone()),
677 };
678
679 let response = relayer_api.handle_request(request, &state).await;
681
682 let api_response = PluginMessage::ApiResponse {
684 request_id: response.request_id,
685 result: response.result,
686 error: response.error,
687 };
688
689 let response_str = serde_json::to_string(&api_response)
690 .map_err(|e| PluginError::PluginError(e.to_string()))?
691 + "\n";
692
693 if let Err(e) = w.write_all(response_str.as_bytes()).await {
694 warn!("Failed to write API response: {}", e);
695 break;
696 }
697
698 if let Err(e) = w.flush().await {
699 warn!("Failed to flush API response: {}", e);
700 break;
701 }
702 }
703
704 PluginMessage::Trace { trace } => {
705 if traces_enabled {
707 if traces.is_none() {
708 traces = Some(Vec::new());
709 }
710 if let Some(ref mut t) = traces {
711 t.push(trace);
712 }
713 }
714 }
716
717 PluginMessage::Shutdown => {
718 debug!("Plugin requested shutdown");
719 break;
720 }
721
722 PluginMessage::ApiResponse { .. } => {
723 warn!("Received ApiResponse from plugin (invalid direction)");
724 continue;
725 }
726 }
727 } else {
728 if let Ok(request) = serde_json::from_value::<Request>(json_value.clone()) {
730 if bound_execution_id.is_none() {
734 let candidate_id = request
735 .http_request_id
736 .clone()
737 .or_else(|| Some(request.request_id.clone()));
738
739 if let Some(ref id) = candidate_id {
742 if let Some(has_traces) =
743 executions.read(id, |_, ctx| ctx.traces_tx.is_some())
744 {
745 traces_enabled = has_traces;
746 bound_execution_id = candidate_id;
747 } else {
748 debug!("Legacy request with unknown execution_id: {}", id);
749 }
750 }
751 }
752
753 let response = relayer_api.handle_request(request, &state).await;
755 let response_str = serde_json::to_string(&response)
756 .map_err(|e| PluginError::PluginError(e.to_string()))?
757 + "\n";
758
759 if let Err(e) = w.write_all(response_str.as_bytes()).await {
760 warn!("Failed to write response: {}", e);
761 break;
762 }
763
764 if let Err(e) = w.flush().await {
765 warn!("Failed to flush response: {}", e);
766 break;
767 }
768 } else {
769 warn!("Failed to parse message as either PluginMessage or legacy Request");
770 }
771 }
772 }
773
774 if traces_enabled {
776 if let Some(exec_id) = bound_execution_id {
777 let traces_tx = executions
779 .read(&exec_id, |_, ctx| ctx.traces_tx.clone())
780 .flatten();
781
782 if let Some(tx) = traces_tx {
783 let collected_traces = traces.unwrap_or_default();
784 let trace_count = collected_traces.len();
785 match tokio::time::timeout(
788 Duration::from_millis(100),
789 tx.send(collected_traces),
790 )
791 .await
792 {
793 Ok(Ok(())) => {}
794 Ok(Err(_)) => {
795 if trace_count > 0 {
796 warn!(
797 "Trace channel closed for execution_id: {} ({} traces lost)",
798 exec_id, trace_count
799 );
800 }
801 }
802 Err(_) => warn!("Timeout sending traces for execution_id: {}", exec_id),
803 }
804 }
805 }
806 }
807 debug!("Shared socket service: connection closed");
810 Ok(())
811 }
812}
813
814impl Drop for SharedSocketService {
815 fn drop(&mut self) {
816 let _ = self.shutdown_tx.send(true);
818 }
821}
822
823static SHARED_SOCKET: std::sync::OnceLock<Result<Arc<SharedSocketService>, String>> =
825 std::sync::OnceLock::new();
826
827pub fn get_shared_socket_service() -> Result<Arc<SharedSocketService>, PluginError> {
830 let socket_path = "/tmp/relayer-plugin-shared.sock";
831
832 let result = SHARED_SOCKET.get_or_init(|| {
833 let _ = std::fs::remove_file(socket_path);
835
836 match SharedSocketService::new(socket_path) {
837 Ok(service) => Ok(Arc::new(service)),
838 Err(e) => Err(e.to_string()),
839 }
840 });
841
842 match result {
843 Ok(service) => Ok(service.clone()),
844 Err(e) => Err(PluginError::SocketError(format!(
845 "Failed to create shared socket service: {e}"
846 ))),
847 }
848}
849
850#[allow(clippy::type_complexity)]
852pub async fn ensure_shared_socket_started<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
853 state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
854) -> Result<(), PluginError>
855where
856 J: JobProducerTrait + Send + Sync + 'static,
857 RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
858 TR: TransactionRepository + Repository<TransactionRepoModel, String> + Send + Sync + 'static,
859 NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
860 NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
861 SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
862 TCR: TransactionCounterTrait + Send + Sync + 'static,
863 PR: PluginRepositoryTrait + Send + Sync + 'static,
864 AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
865{
866 let service = get_shared_socket_service()?;
867 service.start(state).await
868}
869
870#[cfg(test)]
871mod tests {
872 use super::*;
873 use crate::utils::mocks::mockutils::create_mock_app_state;
874 use actix_web::web;
875 use tempfile::tempdir;
876 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
877 use tokio::net::UnixStream;
878
879 #[tokio::test]
880 async fn test_unified_protocol_register_and_api_request() {
881 let temp_dir = tempdir().unwrap();
882 let socket_path = temp_dir.path().join("shared.sock");
883
884 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
885 let state = create_mock_app_state(None, None, None, None, None, None).await;
886
887 service
889 .clone()
890 .start(Arc::new(web::ThinData(state)))
891 .await
892 .unwrap();
893
894 let execution_id = "test-exec-123".to_string();
896 let _guard = service.register_execution(execution_id.clone(), true).await;
897
898 tokio::time::sleep(Duration::from_millis(50)).await;
900
901 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
903 .await
904 .unwrap();
905
906 let register_msg = PluginMessage::Register {
908 execution_id: execution_id.clone(),
909 };
910 let msg_json = serde_json::to_string(®ister_msg).unwrap() + "\n";
911 client.write_all(msg_json.as_bytes()).await.unwrap();
912
913 let api_request = PluginMessage::ApiRequest {
915 request_id: "req-1".to_string(),
916 relayer_id: "relayer-1".to_string(),
917 method: crate::services::plugins::relayer_api::PluginMethod::GetRelayerStatus,
918 payload: serde_json::json!({}),
919 };
920 let req_json = serde_json::to_string(&api_request).unwrap() + "\n";
921 client.write_all(req_json.as_bytes()).await.unwrap();
922 client.flush().await.unwrap();
923
924 let (r, _w) = client.into_split();
926 let mut reader = BufReader::new(r);
927 let mut response_line = String::new();
928 reader.read_line(&mut response_line).await.unwrap();
929
930 let response: PluginMessage = serde_json::from_str(&response_line).unwrap();
931 match response {
932 PluginMessage::ApiResponse { request_id, .. } => {
933 assert_eq!(request_id, "req-1");
934 }
935 _ => panic!("Expected ApiResponse, got {response:?}"),
936 }
937
938 service.shutdown().await;
939 }
940
941 #[tokio::test]
942 async fn test_connection_tagging_prevents_spoofing() {
943 let temp_dir = tempdir().unwrap();
944 let socket_path = temp_dir.path().join("shared2.sock");
945
946 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
947 let state = create_mock_app_state(None, None, None, None, None, None).await;
948
949 service
950 .clone()
951 .start(Arc::new(web::ThinData(state)))
952 .await
953 .unwrap();
954
955 let execution_id = "test-exec-456".to_string();
956 let _guard = service.register_execution(execution_id.clone(), true).await;
957
958 tokio::time::sleep(Duration::from_millis(50)).await;
959
960 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
961 .await
962 .unwrap();
963
964 let register_msg = PluginMessage::Register {
966 execution_id: execution_id.clone(),
967 };
968 let msg_json = serde_json::to_string(®ister_msg).unwrap() + "\n";
969 client.write_all(msg_json.as_bytes()).await.unwrap();
970
971 let spoofed_register = PluginMessage::Register {
973 execution_id: "different-exec-id".to_string(),
974 };
975 let spoofed_json = serde_json::to_string(&spoofed_register).unwrap() + "\n";
976 client.write_all(spoofed_json.as_bytes()).await.unwrap();
977 client.flush().await.unwrap();
978
979 tokio::time::sleep(Duration::from_millis(100)).await;
981
982 let (r, _w) = client.into_split();
984 let mut reader = BufReader::new(r);
985 let mut line = String::new();
986 let result = reader.read_line(&mut line).await;
987
988 assert!(result.is_err() || result.unwrap() == 0);
990
991 service.shutdown().await;
992 }
993
994 #[tokio::test]
995 async fn test_backward_compatibility_with_legacy_format() {
996 let temp_dir = tempdir().unwrap();
997 let socket_path = temp_dir.path().join("shared3.sock");
998
999 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1000 let state = create_mock_app_state(None, None, None, None, None, None).await;
1001
1002 service
1003 .clone()
1004 .start(Arc::new(web::ThinData(state)))
1005 .await
1006 .unwrap();
1007
1008 let execution_id = "test-exec-789".to_string();
1009 let _guard = service.register_execution(execution_id.clone(), true).await;
1010
1011 tokio::time::sleep(Duration::from_millis(50)).await;
1012
1013 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1014 .await
1015 .unwrap();
1016
1017 let legacy_request = crate::services::plugins::relayer_api::Request {
1019 request_id: "legacy-1".to_string(),
1020 relayer_id: "relayer-1".to_string(),
1021 method: crate::services::plugins::relayer_api::PluginMethod::GetRelayerStatus,
1022 payload: serde_json::json!({}),
1023 http_request_id: Some(execution_id.clone()),
1024 };
1025 let legacy_json = serde_json::to_string(&legacy_request).unwrap() + "\n";
1026 client.write_all(legacy_json.as_bytes()).await.unwrap();
1027 client.flush().await.unwrap();
1028
1029 let (r, _w) = client.into_split();
1031 let mut reader = BufReader::new(r);
1032 let mut response_line = String::new();
1033 reader.read_line(&mut response_line).await.unwrap();
1034
1035 let response: crate::services::plugins::relayer_api::Response =
1036 serde_json::from_str(&response_line).unwrap();
1037
1038 assert_eq!(response.request_id, "legacy-1");
1039 assert!(response.result.is_some() || response.error.is_some());
1042
1043 service.shutdown().await;
1044 }
1045
1046 #[tokio::test]
1047 async fn test_trace_collection() {
1048 let temp_dir = tempdir().unwrap();
1049 let socket_path = temp_dir.path().join("shared4.sock");
1050
1051 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1052 let state = create_mock_app_state(None, None, None, None, None, None).await;
1053
1054 service
1055 .clone()
1056 .start(Arc::new(web::ThinData(state)))
1057 .await
1058 .unwrap();
1059
1060 let execution_id = "test-exec-trace".to_string();
1061 let guard = service.register_execution(execution_id.clone(), true).await;
1062
1063 tokio::time::sleep(Duration::from_millis(50)).await;
1064
1065 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1066 .await
1067 .unwrap();
1068
1069 let register_msg = PluginMessage::Register {
1071 execution_id: execution_id.clone(),
1072 };
1073 client
1074 .write_all((serde_json::to_string(®ister_msg).unwrap() + "\n").as_bytes())
1075 .await
1076 .unwrap();
1077
1078 let trace1 = PluginMessage::Trace {
1080 trace: serde_json::json!({"event": "start", "timestamp": 1000}),
1081 };
1082 client
1083 .write_all((serde_json::to_string(&trace1).unwrap() + "\n").as_bytes())
1084 .await
1085 .unwrap();
1086
1087 let trace2 = PluginMessage::Trace {
1088 trace: serde_json::json!({"event": "processing", "timestamp": 2000}),
1089 };
1090 client
1091 .write_all((serde_json::to_string(&trace2).unwrap() + "\n").as_bytes())
1092 .await
1093 .unwrap();
1094
1095 let shutdown_msg = PluginMessage::Shutdown;
1097 client
1098 .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1099 .await
1100 .unwrap();
1101 client.flush().await.unwrap();
1102
1103 drop(client);
1104
1105 tokio::time::sleep(Duration::from_millis(100)).await;
1107
1108 let mut traces_rx = guard.into_receiver().expect("Traces should be enabled");
1110 let traces = traces_rx.recv().await.unwrap();
1111
1112 assert_eq!(traces.len(), 2);
1114 assert_eq!(traces[0]["event"], "start");
1115 assert_eq!(traces[1]["event"], "processing");
1116
1117 service.shutdown().await;
1118 }
1119
1120 #[tokio::test]
1121 async fn test_execution_guard_auto_unregister() {
1122 let temp_dir = tempdir().unwrap();
1123 let socket_path = temp_dir.path().join("shared_guard.sock");
1124
1125 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1126 let execution_id = "test-exec-guard".to_string();
1127
1128 {
1129 let _guard = service.register_execution(execution_id.clone(), true).await;
1130
1131 assert_eq!(service.registered_executions_count().await, 1);
1133 }
1134 assert_eq!(service.registered_executions_count().await, 0);
1138 }
1139
1140 #[tokio::test]
1141 async fn test_api_request_without_register_rejected() {
1142 let temp_dir = tempdir().unwrap();
1143 let socket_path = temp_dir.path().join("shared_no_register.sock");
1144
1145 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1146 let state = create_mock_app_state(None, None, None, None, None, None).await;
1147
1148 service
1149 .clone()
1150 .start(Arc::new(web::ThinData(state)))
1151 .await
1152 .unwrap();
1153
1154 tokio::time::sleep(Duration::from_millis(50)).await;
1155
1156 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1157 .await
1158 .unwrap();
1159
1160 let api_request = PluginMessage::ApiRequest {
1162 request_id: "req-1".to_string(),
1163 relayer_id: "relayer-1".to_string(),
1164 method: crate::services::plugins::relayer_api::PluginMethod::GetRelayerStatus,
1165 payload: serde_json::json!({}),
1166 };
1167 let req_json = serde_json::to_string(&api_request).unwrap() + "\n";
1168 client.write_all(req_json.as_bytes()).await.unwrap();
1169 client.flush().await.unwrap();
1170
1171 tokio::time::sleep(Duration::from_millis(100)).await;
1173
1174 let (r, _w) = client.into_split();
1175 let mut reader = BufReader::new(r);
1176 let mut line = String::new();
1177 let result = reader.read_line(&mut line).await;
1178
1179 assert!(result.is_err() || result.unwrap() == 0);
1181
1182 service.shutdown().await;
1183 }
1184
1185 #[tokio::test]
1186 async fn test_register_with_unknown_execution_id_rejected() {
1187 let temp_dir = tempdir().unwrap();
1188 let socket_path = temp_dir.path().join("shared_unknown_exec.sock");
1189
1190 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1191 let state = create_mock_app_state(None, None, None, None, None, None).await;
1192
1193 service
1194 .clone()
1195 .start(Arc::new(web::ThinData(state)))
1196 .await
1197 .unwrap();
1198
1199 tokio::time::sleep(Duration::from_millis(50)).await;
1200
1201 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1202 .await
1203 .unwrap();
1204
1205 let register_msg = PluginMessage::Register {
1207 execution_id: "unknown-exec-id".to_string(),
1208 };
1209 let msg_json = serde_json::to_string(®ister_msg).unwrap() + "\n";
1210 client.write_all(msg_json.as_bytes()).await.unwrap();
1211 client.flush().await.unwrap();
1212
1213 tokio::time::sleep(Duration::from_millis(100)).await;
1215
1216 let (r, _w) = client.into_split();
1217 let mut reader = BufReader::new(r);
1218 let mut line = String::new();
1219 let result = reader.read_line(&mut line).await;
1220
1221 assert!(result.is_err() || result.unwrap() == 0);
1222
1223 service.shutdown().await;
1224 }
1225
1226 #[tokio::test]
1227 async fn test_connection_limit_enforcement() {
1228 let temp_dir = tempdir().unwrap();
1229 let socket_path = temp_dir.path().join("shared_connection_limit.sock");
1230
1231 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1232 let state = create_mock_app_state(None, None, None, None, None, None).await;
1233
1234 service
1235 .clone()
1236 .start(Arc::new(web::ThinData(state)))
1237 .await
1238 .unwrap();
1239
1240 tokio::time::sleep(Duration::from_millis(50)).await;
1241
1242 let initial_permits = service.connection_semaphore.available_permits();
1244 let max_connections = get_config().socket_max_connections;
1245 assert_eq!(initial_permits, max_connections);
1246
1247 let _client = UnixStream::connect(socket_path.to_str().unwrap())
1249 .await
1250 .unwrap();
1251
1252 tokio::time::sleep(Duration::from_millis(50)).await;
1253
1254 let after_connect = service.connection_semaphore.available_permits();
1256 assert!(after_connect < initial_permits);
1257
1258 service.shutdown().await;
1259 }
1260
1261 #[tokio::test]
1262 async fn test_idle_timeout() {
1263 let temp_dir = tempdir().unwrap();
1264 let socket_path = temp_dir.path().join("shared_idle_timeout.sock");
1265
1266 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1268 let state = create_mock_app_state(None, None, None, None, None, None).await;
1269
1270 service
1271 .clone()
1272 .start(Arc::new(web::ThinData(state)))
1273 .await
1274 .unwrap();
1275
1276 let execution_id = "test-exec-idle".to_string();
1277 let _guard = service.register_execution(execution_id.clone(), true).await;
1278
1279 tokio::time::sleep(Duration::from_millis(50)).await;
1280
1281 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1282 .await
1283 .unwrap();
1284
1285 let register_msg = PluginMessage::Register { execution_id };
1287 client
1288 .write_all((serde_json::to_string(®ister_msg).unwrap() + "\n").as_bytes())
1289 .await
1290 .unwrap();
1291 client.flush().await.unwrap();
1292
1293 tokio::time::sleep(Duration::from_millis(100)).await;
1297
1298 let shutdown_msg = PluginMessage::Shutdown;
1301 let write_result = client
1302 .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1303 .await;
1304
1305 assert!(write_result.is_ok(), "Connection should still be alive");
1306
1307 service.shutdown().await;
1308 }
1309
1310 #[tokio::test]
1311 async fn test_read_timeout_handling() {
1312 let temp_dir = tempdir().unwrap();
1313 let socket_path = temp_dir.path().join("shared_read_timeout.sock");
1314
1315 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1316 let state = create_mock_app_state(None, None, None, None, None, None).await;
1317
1318 service
1319 .clone()
1320 .start(Arc::new(web::ThinData(state)))
1321 .await
1322 .unwrap();
1323
1324 let execution_id = "test-exec-read-timeout".to_string();
1325 let _guard = service.register_execution(execution_id.clone(), true).await;
1326
1327 tokio::time::sleep(Duration::from_millis(50)).await;
1328
1329 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1330 .await
1331 .unwrap();
1332
1333 let register_msg = PluginMessage::Register { execution_id };
1335 client
1336 .write_all((serde_json::to_string(®ister_msg).unwrap() + "\n").as_bytes())
1337 .await
1338 .unwrap();
1339 client.flush().await.unwrap();
1340
1341 tokio::time::sleep(Duration::from_millis(200)).await;
1346
1347 drop(client);
1349
1350 service.shutdown().await;
1351 }
1352
1353 #[tokio::test]
1354 async fn test_multiple_api_requests_same_connection() {
1355 let temp_dir = tempdir().unwrap();
1356 let socket_path = temp_dir.path().join("shared_multiple_requests.sock");
1357
1358 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1359 let state = create_mock_app_state(None, None, None, None, None, None).await;
1360
1361 service
1362 .clone()
1363 .start(Arc::new(web::ThinData(state)))
1364 .await
1365 .unwrap();
1366
1367 let execution_id = "test-exec-multi".to_string();
1368 let _guard = service.register_execution(execution_id.clone(), true).await;
1369
1370 tokio::time::sleep(Duration::from_millis(50)).await;
1371
1372 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1373 .await
1374 .unwrap();
1375
1376 let register_msg = PluginMessage::Register {
1378 execution_id: execution_id.clone(),
1379 };
1380 client
1381 .write_all((serde_json::to_string(®ister_msg).unwrap() + "\n").as_bytes())
1382 .await
1383 .unwrap();
1384
1385 let (r, mut w) = client.into_split();
1386 let mut reader = BufReader::new(r);
1387
1388 for i in 1..=3 {
1390 let api_request = PluginMessage::ApiRequest {
1391 request_id: format!("req-{i}"),
1392 relayer_id: "relayer-1".to_string(),
1393 method: crate::services::plugins::relayer_api::PluginMethod::GetRelayerStatus,
1394 payload: serde_json::json!({}),
1395 };
1396 w.write_all((serde_json::to_string(&api_request).unwrap() + "\n").as_bytes())
1397 .await
1398 .unwrap();
1399 w.flush().await.unwrap();
1400
1401 let mut response_line = String::new();
1403 reader.read_line(&mut response_line).await.unwrap();
1404
1405 let response: PluginMessage = serde_json::from_str(&response_line).unwrap();
1406 match response {
1407 PluginMessage::ApiResponse { request_id, .. } => {
1408 assert_eq!(request_id, format!("req-{i}"));
1409 }
1410 _ => panic!("Expected ApiResponse"),
1411 }
1412 }
1413
1414 service.shutdown().await;
1415 }
1416
1417 #[tokio::test]
1418 async fn test_shutdown_signal() {
1419 let temp_dir = tempdir().unwrap();
1420 let socket_path = temp_dir.path().join("shared_shutdown_signal.sock");
1421
1422 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1423 let state = create_mock_app_state(None, None, None, None, None, None).await;
1424
1425 service
1426 .clone()
1427 .start(Arc::new(web::ThinData(state)))
1428 .await
1429 .unwrap();
1430
1431 tokio::time::sleep(Duration::from_millis(50)).await;
1432
1433 assert!(std::path::Path::new(socket_path.to_str().unwrap()).exists());
1435
1436 service.shutdown().await;
1438
1439 tokio::time::sleep(Duration::from_millis(100)).await;
1441
1442 assert!(!std::path::Path::new(socket_path.to_str().unwrap()).exists());
1444 }
1445
1446 #[tokio::test]
1447 async fn test_malformed_json_handling() {
1448 let temp_dir = tempdir().unwrap();
1449 let socket_path = temp_dir.path().join("shared_malformed.sock");
1450
1451 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1452 let state = create_mock_app_state(None, None, None, None, None, None).await;
1453
1454 service
1455 .clone()
1456 .start(Arc::new(web::ThinData(state)))
1457 .await
1458 .unwrap();
1459
1460 let execution_id = "test-exec-malformed".to_string();
1461 let _guard = service.register_execution(execution_id.clone(), true).await;
1462
1463 tokio::time::sleep(Duration::from_millis(50)).await;
1464
1465 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1466 .await
1467 .unwrap();
1468
1469 let register_msg = PluginMessage::Register {
1471 execution_id: execution_id.clone(),
1472 };
1473 client
1474 .write_all((serde_json::to_string(®ister_msg).unwrap() + "\n").as_bytes())
1475 .await
1476 .unwrap();
1477
1478 client
1480 .write_all(b"{ this is not valid json }\n")
1481 .await
1482 .unwrap();
1483 client.flush().await.unwrap();
1484
1485 tokio::time::sleep(Duration::from_millis(100)).await;
1487
1488 let shutdown_msg = PluginMessage::Shutdown;
1490 let write_result = client
1491 .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1492 .await;
1493
1494 assert!(
1495 write_result.is_ok(),
1496 "Connection should still be alive after malformed JSON"
1497 );
1498
1499 service.shutdown().await;
1500 }
1501
1502 #[tokio::test]
1503 async fn test_invalid_message_direction() {
1504 let temp_dir = tempdir().unwrap();
1505 let socket_path = temp_dir.path().join("shared_invalid_direction.sock");
1506
1507 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1508 let state = create_mock_app_state(None, None, None, None, None, None).await;
1509
1510 service
1511 .clone()
1512 .start(Arc::new(web::ThinData(state)))
1513 .await
1514 .unwrap();
1515
1516 let execution_id = "test-exec-invalid-dir".to_string();
1517 let _guard = service.register_execution(execution_id.clone(), true).await;
1518
1519 tokio::time::sleep(Duration::from_millis(50)).await;
1520
1521 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1522 .await
1523 .unwrap();
1524
1525 let register_msg = PluginMessage::Register {
1527 execution_id: execution_id.clone(),
1528 };
1529 client
1530 .write_all((serde_json::to_string(®ister_msg).unwrap() + "\n").as_bytes())
1531 .await
1532 .unwrap();
1533
1534 let invalid_msg = PluginMessage::ApiResponse {
1536 request_id: "invalid".to_string(),
1537 result: Some(serde_json::json!({})),
1538 error: None,
1539 };
1540 client
1541 .write_all((serde_json::to_string(&invalid_msg).unwrap() + "\n").as_bytes())
1542 .await
1543 .unwrap();
1544 client.flush().await.unwrap();
1545
1546 tokio::time::sleep(Duration::from_millis(100)).await;
1548
1549 let shutdown_msg = PluginMessage::Shutdown;
1551 let write_result = client
1552 .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1553 .await;
1554
1555 assert!(write_result.is_ok());
1556
1557 service.shutdown().await;
1558 }
1559
1560 #[tokio::test]
1561 async fn test_stale_execution_cleanup() {
1562 let temp_dir = tempdir().unwrap();
1563 let socket_path = temp_dir.path().join("shared_stale_cleanup.sock");
1564
1565 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1566
1567 let execution_id = "stale-exec".to_string();
1569 let (tx, _rx) = mpsc::channel(1);
1570 let _ = service.executions.insert(
1572 execution_id.clone(),
1573 ExecutionContext {
1574 traces_tx: Some(tx),
1575 created_at: Instant::now() - Duration::from_secs(400), bound_execution_id: execution_id.clone(),
1577 },
1578 );
1579
1580 assert!(service.executions.contains(&execution_id));
1582
1583 drop(service);
1589 }
1590
1591 #[tokio::test]
1592 async fn test_socket_path_getter() {
1593 let temp_dir = tempdir().unwrap();
1594 let socket_path = temp_dir.path().join("shared_path.sock");
1595
1596 let service = SharedSocketService::new(socket_path.to_str().unwrap()).unwrap();
1597
1598 assert_eq!(service.socket_path(), socket_path.to_str().unwrap());
1599 }
1600
1601 #[tokio::test]
1602 async fn test_trace_send_timeout() {
1603 let temp_dir = tempdir().unwrap();
1604 let socket_path = temp_dir.path().join("shared_trace_timeout.sock");
1605
1606 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1607 let state = create_mock_app_state(None, None, None, None, None, None).await;
1608
1609 service
1610 .clone()
1611 .start(Arc::new(web::ThinData(state)))
1612 .await
1613 .unwrap();
1614
1615 let execution_id = "test-exec-trace-timeout".to_string();
1616 let guard = service.register_execution(execution_id.clone(), true).await;
1617
1618 drop(guard);
1620
1621 tokio::time::sleep(Duration::from_millis(50)).await;
1622
1623 let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1624 .await
1625 .unwrap();
1626
1627 let register_msg = PluginMessage::Register {
1629 execution_id: execution_id.clone(),
1630 };
1631 client
1632 .write_all((serde_json::to_string(®ister_msg).unwrap() + "\n").as_bytes())
1633 .await
1634 .unwrap();
1635
1636 let trace = PluginMessage::Trace {
1638 trace: serde_json::json!({"event": "test"}),
1639 };
1640 client
1641 .write_all((serde_json::to_string(&trace).unwrap() + "\n").as_bytes())
1642 .await
1643 .unwrap();
1644
1645 let shutdown_msg = PluginMessage::Shutdown;
1647 client
1648 .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1649 .await
1650 .unwrap();
1651 client.flush().await.unwrap();
1652
1653 drop(client);
1654
1655 tokio::time::sleep(Duration::from_millis(200)).await;
1657
1658 service.shutdown().await;
1659 }
1660
1661 #[tokio::test]
1662 async fn test_get_shared_socket_service() {
1663 let service1 = get_shared_socket_service();
1665 assert!(service1.is_ok());
1666
1667 let service2 = get_shared_socket_service();
1668 assert!(service2.is_ok());
1669
1670 let svc1 = service1.unwrap();
1672 let svc2 = service2.unwrap();
1673 let path1 = svc1.socket_path();
1674 let path2 = svc2.socket_path();
1675 assert_eq!(path1, path2);
1676 }
1677
1678 #[tokio::test]
1679 async fn test_duplicate_execution_id_does_not_corrupt_counter() {
1680 let temp_dir = tempdir().unwrap();
1681 let socket_path = temp_dir.path().join("shared_duplicate_exec.sock");
1682
1683 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1684 let execution_id = "duplicate-exec-id".to_string();
1685
1686 assert_eq!(service.registered_executions_count().await, 0);
1688
1689 let guard1 = service.register_execution(execution_id.clone(), true).await;
1691 assert_eq!(service.registered_executions_count().await, 1);
1692
1693 let guard2 = service.register_execution(execution_id.clone(), true).await;
1696 assert_eq!(service.registered_executions_count().await, 1);
1698
1699 drop(guard2);
1702 assert_eq!(service.registered_executions_count().await, 1);
1703
1704 drop(guard1);
1706 assert_eq!(service.registered_executions_count().await, 0);
1707 }
1708
1709 #[tokio::test]
1710 async fn test_execution_guard_registered_field() {
1711 let temp_dir = tempdir().unwrap();
1712 let socket_path = temp_dir.path().join("shared_registered_field.sock");
1713
1714 let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1715
1716 let execution_id_1 = "unique-exec-1".to_string();
1718 let guard1 = service
1719 .register_execution(execution_id_1.clone(), true)
1720 .await;
1721 assert_eq!(service.registered_executions_count().await, 1);
1722
1723 let execution_id_2 = "unique-exec-2".to_string();
1725 let guard2 = service
1726 .register_execution(execution_id_2.clone(), false)
1727 .await;
1728 assert_eq!(service.registered_executions_count().await, 2);
1729
1730 let rx = guard1.into_receiver();
1732 assert!(rx.is_some()); let rx2 = guard2.into_receiver();
1736 assert!(rx2.is_none()); assert_eq!(service.registered_executions_count().await, 0);
1740 }
1741}