openzeppelin_relayer/services/plugins/
shared_socket.rs

1//! Shared Socket Service
2//!
3//! This module provides a unified bidirectional Unix socket service for plugin communication.
4//! Instead of creating separate sockets for registration and API calls, all communication
5//! happens over a single shared socket, dramatically reducing overhead and complexity.
6//!
7//! ## Architecture
8//!
9//! **Single Shared Socket**: All plugins connect to `/tmp/relayer-plugin-shared.sock`
10//!
11//! **Bidirectional Communication**:
12//! - Plugins → Host: Register, ApiRequest, Trace, Shutdown
13//! - Host → Plugins: ApiResponse
14//!
//! **Connection Tagging (Security)**: Each connection is "tagged" with an execution_id
//! by its first Register message. Subsequent API requests on that connection are attributed
//! to the tagged ID and re-registration is rejected, which prevents spoofing attacks
//! (Plugin A cannot impersonate Plugin B over a tagged connection).
18//!
19//! ## Message Protocol
20//!
21//! All messages are JSON objects with a `type` field that discriminates the message type:
22//!
23//! ### Plugin → Host Messages
24//!
25//! **Register** (first message, required):
26//! ```json
27//! {
28//!   "type": "register",
29//!   "execution_id": "abc-123"
30//! }
31//! ```
32//!
33//! **ApiRequest** (call Relayer API):
34//! ```json
35//! {
36//!   "type": "api_request",
37//!   "request_id": "req-1",
38//!   "relayer_id": "relayer-1",
39//!   "method": "sendTransaction",
40//!   "payload": { "to": "0x...", "value": "100" }
41//! }
42//! ```
43//!
44//! **Trace** (observability event):
45//! ```json
46//! {
47//!   "type": "trace",
48//!   "trace": { "event": "processing", "timestamp": 1234567890 }
49//! }
50//! ```
51//!
52//! **Shutdown** (graceful close):
53//! ```json
54//! {
55//!   "type": "shutdown"
56//! }
57//! ```
58//!
59//! ### Host → Plugin Messages
60//!
61//! **ApiResponse** (Relayer API result):
62//! ```json
63//! {
64//!   "type": "api_response",
65//!   "request_id": "req-1",
66//!   "result": { "id": "tx-123", "status": "success" },
67//!   "error": null
68//! }
69//! ```
70//!
71//! ## Security Model
72//!
73//! The connection tagging mechanism prevents execution_id spoofing:
74//!
75//! 1. Plugin connects to shared socket
76//! 2. Plugin sends Register message with execution_id
77//! 3. Host "tags" the connection (file descriptor) with that execution_id
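//! 4. All subsequent API requests on that connection are attributed to the tagged ID
//! 5. Attempts to change execution_id (a second Register) are rejected and the connection is closed
//!
//! This ensures Plugin A cannot send requests pretending to be Plugin B over a tagged
//! connection, even though they share the same socket file. Note that Register only checks
//! that the execution_id is currently registered with the host. Legacy-format messages
//! (no `type` field) bypass tagging: they are handled even on an untagged connection and
//! keep their own `http_request_id`.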
83//!
84//! ## Backward Compatibility
85//!
//! The handle_connection method maintains backward compatibility with the legacy
//! Request/Response format from socket.rs. Messages that carry a `type` field are parsed only
//! as PluginMessage (a parse failure is logged and the line is skipped). Messages without a
//! `type` field are parsed as the legacy Request format and handled accordingly.
89//!
90//! ## Performance Benefits vs Per-Execution Sockets
91//!
92//! | Metric | Shared Socket | Per-Execution Socket |
93//! |--------|---------------|----------------------|
94//! | File descriptors | 1 per plugin | 2 per plugin |
95//! | Syscalls | ~50% fewer | Baseline |
96//! | Connection setup | Reuse existing | Create new each time |
97//! | Memory overhead | O(active executions) | O(active executions × 2) |
98//! | Debugging | Single stream | Two separate streams |
99//!
100//! ## Example Usage
101//!
102//! ```rust,no_run
103//! use openzeppelin_relayer::services::plugins::shared_socket::{
104//!     get_shared_socket_service, ensure_shared_socket_started
105//! };
106//!
107//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
108//! // Get the global shared socket instance
109//! let service = get_shared_socket_service()?;
110//!
111//! // Register an execution (returns RAII guard)
112//! let guard = service.register_execution("exec-123".to_string(), true).await;
113//!
114//! // Plugin connects and sends messages over the shared socket...
115//! // (handled automatically by the background listener)
116//!
//! // Collect traces when done (returns Some when emit_traces=true).
//! // Note: `into_receiver` consumes the guard, which unregisters the execution.
118//! if let Some(mut traces_rx) = guard.into_receiver() {
119//!     let traces = traces_rx.recv().await;
120//! }
121//! # Ok(())
122//! # }
123//! ```
124
125use super::config::get_config;
126use crate::jobs::JobProducerTrait;
127use crate::models::{
128    NetworkRepoModel, NotificationRepoModel, RelayerRepoModel, SignerRepoModel, ThinDataAppState,
129    TransactionRepoModel,
130};
131use crate::repositories::{
132    ApiKeyRepositoryTrait, NetworkRepository, PluginRepositoryTrait, RelayerRepository, Repository,
133    TransactionCounterTrait, TransactionRepository,
134};
135use crate::services::plugins::relayer_api::{RelayerApi, Request};
136use scc::HashMap as SccHashMap;
137use serde::{Deserialize, Serialize};
138use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
139use std::sync::Arc;
140use std::time::{Duration, Instant};
141use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
142use tokio::net::{UnixListener, UnixStream};
143use tokio::sync::{mpsc, watch, Semaphore};
144use tracing::{debug, info, warn};
145
146use super::PluginError;
147
148/// Unified message protocol for bidirectional communication
149#[derive(Debug, Serialize, Deserialize, Clone)]
150#[serde(tag = "type", rename_all = "snake_case")]
151pub enum PluginMessage {
152    /// Plugin registers its execution_id (first message from plugin)
153    Register { execution_id: String },
154    /// Plugin requests a Relayer API call
155    ApiRequest {
156        request_id: String,
157        relayer_id: String,
158        method: crate::services::plugins::relayer_api::PluginMethod,
159        payload: serde_json::Value,
160    },
161    /// Host responds to an API request
162    ApiResponse {
163        request_id: String,
164        result: Option<serde_json::Value>,
165        error: Option<String>,
166    },
167    /// Plugin sends a trace event (for observability)
168    Trace { trace: serde_json::Value },
169    /// Plugin signals completion
170    Shutdown,
171}
172
173/// Execution context for trace collection
174struct ExecutionContext {
175    /// Channel to send traces back to the execution (None when emit_traces=false)
176    /// When None, connection handler skips trace collection entirely for better performance
177    traces_tx: Option<mpsc::Sender<Vec<serde_json::Value>>>,
178    /// Creation timestamp for TTL cleanup
179    created_at: Instant,
180    /// The execution_id bound to this connection (for security)
181    /// Once set, all messages must match this ID to prevent spoofing
182    #[allow(dead_code)] // Used for security validation, not directly read
183    bound_execution_id: String,
184}
185
186/// RAII guard for execution registration that auto-unregisters on drop
187pub struct ExecutionGuard {
188    execution_id: String,
189    executions: Arc<SccHashMap<String, ExecutionContext>>,
190    rx: Option<mpsc::Receiver<Vec<serde_json::Value>>>,
191    /// Shared counter for tracking active executions (lock-free)
192    active_count: Arc<AtomicUsize>,
193    /// Whether this guard was successfully registered (insertion succeeded)
194    /// Only registered guards should decrement active_count on drop
195    registered: bool,
196}
197
198impl ExecutionGuard {
199    /// Get the trace receiver if tracing was enabled
200    /// Returns None if emit_traces=false was passed to register_execution
201    pub fn into_receiver(mut self) -> Option<mpsc::Receiver<Vec<serde_json::Value>>> {
202        self.rx.take()
203    }
204}
205
206impl Drop for ExecutionGuard {
207    fn drop(&mut self) {
208        // Auto-unregister on drop - synchronous with scc::HashMap (no spawn needed!)
209        // This eliminates the overhead of spawning a task for every request
210        //
211        // Only registered guards should remove entries and decrement counters.
212        // Non-registered guards (from duplicate execution_id) don't own the entry.
213        //
214        // For registered guards, only decrement if we actually removed the entry.
215        // This prevents double-decrement: if a long-running execution is GC'd by the
216        // stale entry cleanup task (which decrements the counter), and then the guard
217        // drops later, we must NOT decrement again.
218        if self.registered && self.executions.remove(&self.execution_id).is_some() {
219            self.active_count.fetch_sub(1, Ordering::AcqRel);
220        }
221    }
222}
223
224/// Shared socket service that handles multiple concurrent plugin executions
225pub struct SharedSocketService {
226    /// Socket path
227    socket_path: String,
228    /// Active execution contexts (execution_id -> ExecutionContext)
229    /// scc::HashMap provides lock-free reads and optimistic locking for writes
230    executions: Arc<SccHashMap<String, ExecutionContext>>,
231    /// Lock-free counter for active executions
232    active_count: Arc<AtomicUsize>,
233    /// Whether the listener has been started (instance-level flag)
234    started: AtomicBool,
235    /// Shutdown signal sender
236    shutdown_tx: watch::Sender<bool>,
237    /// Semaphore for connection limiting (prevents race conditions)
238    connection_semaphore: Arc<Semaphore>,
239    /// Connection idle timeout
240    idle_timeout: Duration,
241    /// Read timeout per line
242    read_timeout: Duration,
243}
244
245impl SharedSocketService {
246    /// Create a new shared socket service
247    pub fn new(socket_path: &str) -> Result<Self, PluginError> {
248        // Remove existing socket file if it exists (from previous runs or crashed processes)
249        let _ = std::fs::remove_file(socket_path);
250
251        let (shutdown_tx, _) = watch::channel(false);
252
253        // Use centralized config
254        let config = get_config();
255        let idle_timeout = Duration::from_secs(config.socket_idle_timeout_secs);
256        let read_timeout = Duration::from_secs(config.socket_read_timeout_secs);
257        let max_connections = config.socket_max_connections;
258
259        let executions: Arc<SccHashMap<String, ExecutionContext>> = Arc::new(SccHashMap::new());
260        let active_count = Arc::new(AtomicUsize::new(0));
261
262        // Spawn background cleanup task for stale executions (prevents memory leaks)
263        let executions_clone = executions.clone();
264        let active_count_clone = active_count.clone();
265        let mut cleanup_shutdown_rx = shutdown_tx.subscribe();
266        tokio::spawn(async move {
267            let mut interval = tokio::time::interval(Duration::from_secs(60));
268            loop {
269                tokio::select! {
270                    _ = interval.tick() => {}
271                    _ = cleanup_shutdown_rx.changed() => {
272                        if *cleanup_shutdown_rx.borrow() {
273                            break;
274                        }
275                    }
276                }
277                let now = Instant::now();
278                // scc::HashMap retain is lock-free per entry
279                let mut removed = 0usize;
280                executions_clone.retain(|_, ctx| {
281                    let keep = now.duration_since(ctx.created_at) < Duration::from_secs(300);
282                    if !keep {
283                        removed += 1;
284                    }
285                    keep
286                });
287                if removed > 0 {
288                    active_count_clone.fetch_sub(removed, Ordering::AcqRel);
289                }
290            }
291        });
292
293        Ok(Self {
294            socket_path: socket_path.to_string(),
295            executions,
296            active_count,
297            started: AtomicBool::new(false),
298            shutdown_tx,
299            connection_semaphore: Arc::new(Semaphore::new(max_connections)),
300            idle_timeout,
301            read_timeout,
302        })
303    }
304
305    pub fn socket_path(&self) -> &str {
306        &self.socket_path
307    }
308
309    /// Register an execution and return a guard that auto-unregisters on drop
310    /// This prevents memory leaks from forgotten unregister calls
311    ///
312    /// # Arguments
313    /// * `execution_id` - Unique identifier for this execution
314    /// * `emit_traces` - If false, skips channel creation and trace collection for better performance
315    pub async fn register_execution(
316        &self,
317        execution_id: String,
318        emit_traces: bool,
319    ) -> ExecutionGuard {
320        // Only create channel when traces are needed - saves allocation and channel overhead
321        let (tx, rx) = if emit_traces {
322            let (tx, rx) = mpsc::channel(1);
323            (Some(tx), Some(rx))
324        } else {
325            (None, None)
326        };
327
328        let ctx = ExecutionContext {
329            traces_tx: tx,
330            created_at: Instant::now(),
331            bound_execution_id: execution_id.clone(),
332        };
333
334        // scc::HashMap insert - returns Ok if new, Err if key existed (duplicate)
335        let registered = match self.executions.insert(execution_id.clone(), ctx) {
336            Ok(_) => {
337                self.active_count.fetch_add(1, Ordering::AcqRel);
338                true
339            }
340            Err((existing_key, _)) => {
341                tracing::warn!(
342                    execution_id = %existing_key,
343                    "Duplicate execution_id detected during registration, guard will not decrement counter"
344                );
345                false
346            }
347        };
348
349        ExecutionGuard {
350            execution_id,
351            executions: self.executions.clone(),
352            rx,
353            registered,
354            active_count: self.active_count.clone(),
355        }
356    }
357
358    /// Get current number of available connection slots
359    pub fn available_connection_slots(&self) -> usize {
360        self.connection_semaphore.available_permits()
361    }
362
363    /// Get current active connection count
364    pub fn active_connection_count(&self) -> usize {
365        get_config().socket_max_connections - self.connection_semaphore.available_permits()
366    }
367
368    /// Get current number of registered executions (lock-free via atomic counter)
369    pub async fn registered_executions_count(&self) -> usize {
370        self.active_count.load(Ordering::Relaxed)
371    }
372
373    /// Signal shutdown to the listener and wait for active connections to drain
374    pub async fn shutdown(&self) {
375        let _ = self.shutdown_tx.send(true);
376        info!("Shared socket service: shutdown signal sent");
377
378        // Wait for active connections to drain (max 30 seconds)
379        let max_wait = Duration::from_secs(30);
380        let start = Instant::now();
381
382        while start.elapsed() < max_wait {
383            let available = self.connection_semaphore.available_permits();
384            if available == get_config().socket_max_connections {
385                // All permits returned - no active connections
386                break;
387            }
388            tokio::time::sleep(Duration::from_millis(100)).await;
389        }
390
391        // Remove socket file after connections drained
392        let _ = std::fs::remove_file(&self.socket_path);
393        info!("Shared socket service: shutdown complete");
394    }
395
396    /// Start the shared socket service
397    /// This spawns a background task that listens for connections
398    /// Safe to call multiple times - will only start once per instance
399    #[allow(clippy::type_complexity)]
400    pub async fn start<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
401        self: Arc<Self>,
402        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
403    ) -> Result<(), PluginError>
404    where
405        J: JobProducerTrait + Send + Sync + 'static,
406        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
407        TR: TransactionRepository
408            + Repository<TransactionRepoModel, String>
409            + Send
410            + Sync
411            + 'static,
412        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
413        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
414        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
415        TCR: TransactionCounterTrait + Send + Sync + 'static,
416        PR: PluginRepositoryTrait + Send + Sync + 'static,
417        AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
418    {
419        // Check if already started (instance-level flag)
420        if self.started.swap(true, Ordering::Acquire) {
421            return Ok(());
422        }
423
424        // Create the listener and move it into the task
425        let listener = UnixListener::bind(&self.socket_path)
426            .map_err(|e| PluginError::SocketError(format!("Failed to bind listener: {e}")))?;
427        let executions = self.executions.clone();
428        let relayer_api = Arc::new(RelayerApi);
429        let socket_path = self.socket_path.clone();
430        let mut shutdown_rx = self.shutdown_tx.subscribe();
431        let connection_semaphore = self.connection_semaphore.clone();
432        let idle_timeout = self.idle_timeout;
433        let read_timeout = self.read_timeout;
434
435        debug!(
436            "Shared socket service: starting listener on {}",
437            socket_path
438        );
439
440        // Spawn the listener task
441        tokio::spawn(async move {
442            debug!("Shared socket service: listener task started");
443            loop {
444                tokio::select! {
445                    // Check for shutdown signal
446                    _ = shutdown_rx.changed() => {
447                        if *shutdown_rx.borrow() {
448                            info!("Shared socket service: shutting down listener");
449                            break;
450                        }
451                    }
452                    // Accept new connections
453                    accept_result = listener.accept() => {
454                        match accept_result {
455                            Ok((stream, _)) => {
456                                // Try to acquire semaphore permit (no race condition!)
457                                match connection_semaphore.clone().try_acquire_owned() {
458                                    Ok(permit) => {
459                                        debug!("Shared socket service: accepted new connection");
460
461                                        let relayer_api_clone = relayer_api.clone();
462                                        let state_clone = Arc::clone(&state);
463                                        let executions_clone = executions.clone();
464
465                                        tokio::spawn(async move {
466                                            // Permit held until task completes (auto-released on drop)
467                                            let _permit = permit;
468
469                                            let result = Self::handle_connection_with_timeout(
470                                                stream,
471                                                relayer_api_clone,
472                                                state_clone,
473                                                executions_clone,
474                                                idle_timeout,
475                                                read_timeout,
476                                            )
477                                            .await;
478
479                                            if let Err(e) = result {
480                                                debug!("Connection handler finished with error: {}", e);
481                                            }
482                                        });
483                                    }
484                                    Err(_) => {
485                                        warn!(
486                                            "Connection limit reached, rejecting new connection. \
487                                            Consider increasing PLUGIN_MAX_CONCURRENCY or PLUGIN_SOCKET_MAX_CONCURRENT_CONNECTIONS."
488                                        );
489                                        drop(stream);
490                                    }
491                                }
492                            }
493                            Err(e) => {
494                                warn!("Error accepting connection: {}", e);
495                            }
496                        }
497                    }
498                }
499            }
500
501            // Cleanup on shutdown
502            let _ = std::fs::remove_file(&socket_path);
503            info!("Shared socket service: listener stopped");
504        });
505
506        Ok(())
507    }
508
509    /// Handle connection with overall idle timeout
510    #[allow(clippy::type_complexity)]
511    async fn handle_connection_with_timeout<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
512        stream: UnixStream,
513        relayer_api: Arc<RelayerApi>,
514        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
515        executions: Arc<SccHashMap<String, ExecutionContext>>,
516        idle_timeout: Duration,
517        read_timeout: Duration,
518    ) -> Result<(), PluginError>
519    where
520        J: JobProducerTrait + Send + Sync + 'static,
521        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
522        TR: TransactionRepository
523            + Repository<TransactionRepoModel, String>
524            + Send
525            + Sync
526            + 'static,
527        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
528        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
529        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
530        TCR: TransactionCounterTrait + Send + Sync + 'static,
531        PR: PluginRepositoryTrait + Send + Sync + 'static,
532        AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
533    {
534        // Wrap the entire connection handling with an idle timeout
535        match tokio::time::timeout(
536            idle_timeout,
537            Self::handle_connection(stream, relayer_api, state, executions, read_timeout),
538        )
539        .await
540        {
541            Ok(result) => result,
542            Err(_) => {
543                debug!("Connection idle timeout reached");
544                Ok(())
545            }
546        }
547    }
548
549    /// Handle a connection from a plugin
550    ///
551    /// Security: The first message must be a Register message. Once registered,
552    /// the connection is "tagged" with that execution_id and cannot be changed.
553    /// This prevents Plugin A from spoofing Plugin B's execution_id.
554    #[allow(clippy::type_complexity)]
555    async fn handle_connection<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
556        stream: UnixStream,
557        relayer_api: Arc<RelayerApi>,
558        state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
559        executions: Arc<SccHashMap<String, ExecutionContext>>,
560        read_timeout: Duration,
561    ) -> Result<(), PluginError>
562    where
563        J: JobProducerTrait + Send + Sync + 'static,
564        RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
565        TR: TransactionRepository
566            + Repository<TransactionRepoModel, String>
567            + Send
568            + Sync
569            + 'static,
570        NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
571        NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
572        SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
573        TCR: TransactionCounterTrait + Send + Sync + 'static,
574        PR: PluginRepositoryTrait + Send + Sync + 'static,
575        AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
576    {
577        let (r, mut w) = stream.into_split();
578        let mut reader = BufReader::new(r).lines();
579
580        // Only allocate traces Vec when tracing is enabled (determined on Register)
581        let mut traces: Option<Vec<serde_json::Value>> = None;
582        // Track whether traces are enabled for this connection (set on Register)
583        let mut traces_enabled = false;
584
585        // Connection-bound execution_id (prevents spoofing)
586        // Once set, this cannot be changed for the lifetime of the connection
587        let mut bound_execution_id: Option<String> = None;
588
589        loop {
590            // Read line with timeout to prevent hanging connections
591            let line = match tokio::time::timeout(read_timeout, reader.next_line()).await {
592                Ok(Ok(Some(line))) => line,
593                Ok(Ok(None)) => break, // EOF
594                Ok(Err(e)) => {
595                    warn!("Error reading from connection: {}", e);
596                    break;
597                }
598                Err(_) => {
599                    debug!("Read timeout on connection");
600                    break;
601                }
602            };
603
604            debug!("Shared socket service: received message");
605
606            // Parse once, discriminate on "type" field for efficiency
607            let json_value: serde_json::Value = match serde_json::from_str(&line) {
608                Ok(v) => v,
609                Err(e) => {
610                    warn!("Failed to parse JSON: {}", e);
611                    continue;
612                }
613            };
614
615            let has_type_field = json_value.get("type").is_some();
616
617            if has_type_field {
618                // New unified protocol
619                let message: PluginMessage = match serde_json::from_value(json_value) {
620                    Ok(msg) => msg,
621                    Err(e) => {
622                        warn!("Failed to parse PluginMessage: {}", e);
623                        continue;
624                    }
625                };
626
627                // Handle message based on type
628                match message {
629                    PluginMessage::Register { execution_id } => {
630                        // First message must be Register
631                        if bound_execution_id.is_some() {
632                            warn!("Attempted to re-register connection (security violation)");
633                            break;
634                        }
635
636                        // Validate execution_id exists in registry and check if tracing is enabled
637                        // scc::HashMap read() is lock-free
638                        if let Some(has_traces) =
639                            executions.read(&execution_id, |_, ctx| ctx.traces_tx.is_some())
640                        {
641                            traces_enabled = has_traces;
642                        } else {
643                            warn!("Unknown execution_id: {}", execution_id);
644                            break;
645                        }
646
647                        debug!(
648                            execution_id = %execution_id,
649                            traces_enabled = traces_enabled,
650                            "Connection registered"
651                        );
652                        bound_execution_id = Some(execution_id);
653                    }
654
655                    PluginMessage::ApiRequest {
656                        request_id,
657                        relayer_id,
658                        method,
659                        payload,
660                    } => {
661                        // Must be registered first
662                        let exec_id = match &bound_execution_id {
663                            Some(id) => id,
664                            None => {
665                                warn!("ApiRequest before Register (security violation)");
666                                break;
667                            }
668                        };
669
670                        // Create Request for RelayerApi (method is already PluginMethod)
671                        let request = Request {
672                            request_id: request_id.clone(),
673                            relayer_id,
674                            method,
675                            payload,
676                            http_request_id: Some(exec_id.clone()),
677                        };
678
679                        // Handle the request
680                        let response = relayer_api.handle_request(request, &state).await;
681
682                        // Send ApiResponse back
683                        let api_response = PluginMessage::ApiResponse {
684                            request_id: response.request_id,
685                            result: response.result,
686                            error: response.error,
687                        };
688
689                        let response_str = serde_json::to_string(&api_response)
690                            .map_err(|e| PluginError::PluginError(e.to_string()))?
691                            + "\n";
692
693                        if let Err(e) = w.write_all(response_str.as_bytes()).await {
694                            warn!("Failed to write API response: {}", e);
695                            break;
696                        }
697
698                        if let Err(e) = w.flush().await {
699                            warn!("Failed to flush API response: {}", e);
700                            break;
701                        }
702                    }
703
704                    PluginMessage::Trace { trace } => {
705                        // Only collect traces if tracing is enabled for this execution
706                        if traces_enabled {
707                            if traces.is_none() {
708                                traces = Some(Vec::new());
709                            }
710                            if let Some(ref mut t) = traces {
711                                t.push(trace);
712                            }
713                        }
714                        // When traces_enabled=false, silently discard trace messages
715                    }
716
717                    PluginMessage::Shutdown => {
718                        debug!("Plugin requested shutdown");
719                        break;
720                    }
721
722                    PluginMessage::ApiResponse { .. } => {
723                        warn!("Received ApiResponse from plugin (invalid direction)");
724                        continue;
725                    }
726                }
727            } else {
728                // Legacy protocol (no "type" field)
729                if let Ok(request) = serde_json::from_value::<Request>(json_value.clone()) {
730                    // Legacy format - API requests are not trace events
731
732                    // Set execution_id from http_request_id or request_id if not bound
733                    if bound_execution_id.is_none() {
734                        let candidate_id = request
735                            .http_request_id
736                            .clone()
737                            .or_else(|| Some(request.request_id.clone()));
738
739                        // Validate execution_id exists (same as new protocol)
740                        // scc::HashMap read() is lock-free
741                        if let Some(ref id) = candidate_id {
742                            if let Some(has_traces) =
743                                executions.read(id, |_, ctx| ctx.traces_tx.is_some())
744                            {
745                                traces_enabled = has_traces;
746                                bound_execution_id = candidate_id;
747                            } else {
748                                debug!("Legacy request with unknown execution_id: {}", id);
749                            }
750                        }
751                    }
752
753                    // Handle legacy request
754                    let response = relayer_api.handle_request(request, &state).await;
755                    let response_str = serde_json::to_string(&response)
756                        .map_err(|e| PluginError::PluginError(e.to_string()))?
757                        + "\n";
758
759                    if let Err(e) = w.write_all(response_str.as_bytes()).await {
760                        warn!("Failed to write response: {}", e);
761                        break;
762                    }
763
764                    if let Err(e) = w.flush().await {
765                        warn!("Failed to flush response: {}", e);
766                        break;
767                    }
768                } else {
769                    warn!("Failed to parse message as either PluginMessage or legacy Request");
770                }
771            }
772        }
773
774        // Send traces back to caller if tracing was enabled
775        if traces_enabled {
776            if let Some(exec_id) = bound_execution_id {
777                // Get the sender from execution context (lock-free read)
778                let traces_tx = executions
779                    .read(&exec_id, |_, ctx| ctx.traces_tx.clone())
780                    .flatten();
781
782                if let Some(tx) = traces_tx {
783                    let collected_traces = traces.unwrap_or_default();
784                    let trace_count = collected_traces.len();
785                    // Short timeout: in-process channel send should be nearly instant
786                    // If receiver isn't ready in 100ms, drop traces rather than blocking
787                    match tokio::time::timeout(
788                        Duration::from_millis(100),
789                        tx.send(collected_traces),
790                    )
791                    .await
792                    {
793                        Ok(Ok(())) => {}
794                        Ok(Err(_)) => {
795                            if trace_count > 0 {
796                                warn!(
797                                    "Trace channel closed for execution_id: {} ({} traces lost)",
798                                    exec_id, trace_count
799                                );
800                            }
801                        }
802                        Err(_) => warn!("Timeout sending traces for execution_id: {}", exec_id),
803                    }
804                }
805            }
806        }
807        // When traces_enabled=false, no channel exists and we skip all trace-related work
808
809        debug!("Shared socket service: connection closed");
810        Ok(())
811    }
812}
813
814impl Drop for SharedSocketService {
815    fn drop(&mut self) {
816        // Signal shutdown (cleanup happens in shutdown() method)
817        let _ = self.shutdown_tx.send(true);
818        // Note: Socket file cleanup happens in shutdown() after connections drain
819        // Drop can't be async, so proper cleanup should use shutdown() method
820    }
821}
822
823/// Global shared socket service instance with proper error handling
824static SHARED_SOCKET: std::sync::OnceLock<Result<Arc<SharedSocketService>, String>> =
825    std::sync::OnceLock::new();
826
827/// Get or create the global shared socket service
828/// Returns error if initialization fails instead of panicking
829pub fn get_shared_socket_service() -> Result<Arc<SharedSocketService>, PluginError> {
830    let socket_path = "/tmp/relayer-plugin-shared.sock";
831
832    let result = SHARED_SOCKET.get_or_init(|| {
833        // Remove existing socket file if it exists (from previous runs)
834        let _ = std::fs::remove_file(socket_path);
835
836        match SharedSocketService::new(socket_path) {
837            Ok(service) => Ok(Arc::new(service)),
838            Err(e) => Err(e.to_string()),
839        }
840    });
841
842    match result {
843        Ok(service) => Ok(service.clone()),
844        Err(e) => Err(PluginError::SocketError(format!(
845            "Failed to create shared socket service: {e}"
846        ))),
847    }
848}
849
850/// Ensure the shared socket service is started
851#[allow(clippy::type_complexity)]
852pub async fn ensure_shared_socket_started<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>(
853    state: Arc<ThinDataAppState<J, RR, TR, NR, NFR, SR, TCR, PR, AKR>>,
854) -> Result<(), PluginError>
855where
856    J: JobProducerTrait + Send + Sync + 'static,
857    RR: RelayerRepository + Repository<RelayerRepoModel, String> + Send + Sync + 'static,
858    TR: TransactionRepository + Repository<TransactionRepoModel, String> + Send + Sync + 'static,
859    NR: NetworkRepository + Repository<NetworkRepoModel, String> + Send + Sync + 'static,
860    NFR: Repository<NotificationRepoModel, String> + Send + Sync + 'static,
861    SR: Repository<SignerRepoModel, String> + Send + Sync + 'static,
862    TCR: TransactionCounterTrait + Send + Sync + 'static,
863    PR: PluginRepositoryTrait + Send + Sync + 'static,
864    AKR: ApiKeyRepositoryTrait + Send + Sync + 'static,
865{
866    let service = get_shared_socket_service()?;
867    service.start(state).await
868}
869
870#[cfg(test)]
871mod tests {
872    use super::*;
873    use crate::utils::mocks::mockutils::create_mock_app_state;
874    use actix_web::web;
875    use tempfile::tempdir;
876    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
877    use tokio::net::UnixStream;
878
879    #[tokio::test]
880    async fn test_unified_protocol_register_and_api_request() {
881        let temp_dir = tempdir().unwrap();
882        let socket_path = temp_dir.path().join("shared.sock");
883
884        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
885        let state = create_mock_app_state(None, None, None, None, None, None).await;
886
887        // Start the service
888        service
889            .clone()
890            .start(Arc::new(web::ThinData(state)))
891            .await
892            .unwrap();
893
894        // Register execution
895        let execution_id = "test-exec-123".to_string();
896        let _guard = service.register_execution(execution_id.clone(), true).await;
897
898        // Give the listener time to start
899        tokio::time::sleep(Duration::from_millis(50)).await;
900
901        // Connect as plugin
902        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
903            .await
904            .unwrap();
905
906        // Send Register message
907        let register_msg = PluginMessage::Register {
908            execution_id: execution_id.clone(),
909        };
910        let msg_json = serde_json::to_string(&register_msg).unwrap() + "\n";
911        client.write_all(msg_json.as_bytes()).await.unwrap();
912
913        // Send ApiRequest
914        let api_request = PluginMessage::ApiRequest {
915            request_id: "req-1".to_string(),
916            relayer_id: "relayer-1".to_string(),
917            method: crate::services::plugins::relayer_api::PluginMethod::GetRelayerStatus,
918            payload: serde_json::json!({}),
919        };
920        let req_json = serde_json::to_string(&api_request).unwrap() + "\n";
921        client.write_all(req_json.as_bytes()).await.unwrap();
922        client.flush().await.unwrap();
923
924        // Read ApiResponse
925        let (r, _w) = client.into_split();
926        let mut reader = BufReader::new(r);
927        let mut response_line = String::new();
928        reader.read_line(&mut response_line).await.unwrap();
929
930        let response: PluginMessage = serde_json::from_str(&response_line).unwrap();
931        match response {
932            PluginMessage::ApiResponse { request_id, .. } => {
933                assert_eq!(request_id, "req-1");
934            }
935            _ => panic!("Expected ApiResponse, got {response:?}"),
936        }
937
938        service.shutdown().await;
939    }
940
941    #[tokio::test]
942    async fn test_connection_tagging_prevents_spoofing() {
943        let temp_dir = tempdir().unwrap();
944        let socket_path = temp_dir.path().join("shared2.sock");
945
946        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
947        let state = create_mock_app_state(None, None, None, None, None, None).await;
948
949        service
950            .clone()
951            .start(Arc::new(web::ThinData(state)))
952            .await
953            .unwrap();
954
955        let execution_id = "test-exec-456".to_string();
956        let _guard = service.register_execution(execution_id.clone(), true).await;
957
958        tokio::time::sleep(Duration::from_millis(50)).await;
959
960        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
961            .await
962            .unwrap();
963
964        // Register with execution_id
965        let register_msg = PluginMessage::Register {
966            execution_id: execution_id.clone(),
967        };
968        let msg_json = serde_json::to_string(&register_msg).unwrap() + "\n";
969        client.write_all(msg_json.as_bytes()).await.unwrap();
970
971        // Try to re-register with different execution_id (security violation)
972        let spoofed_register = PluginMessage::Register {
973            execution_id: "different-exec-id".to_string(),
974        };
975        let spoofed_json = serde_json::to_string(&spoofed_register).unwrap() + "\n";
976        client.write_all(spoofed_json.as_bytes()).await.unwrap();
977        client.flush().await.unwrap();
978
979        // Connection should be closed by server
980        tokio::time::sleep(Duration::from_millis(100)).await;
981
982        // Try to read - should get EOF since connection was closed
983        let (r, _w) = client.into_split();
984        let mut reader = BufReader::new(r);
985        let mut line = String::new();
986        let result = reader.read_line(&mut line).await;
987
988        // Should either get an error or EOF (0 bytes)
989        assert!(result.is_err() || result.unwrap() == 0);
990
991        service.shutdown().await;
992    }
993
994    #[tokio::test]
995    async fn test_backward_compatibility_with_legacy_format() {
996        let temp_dir = tempdir().unwrap();
997        let socket_path = temp_dir.path().join("shared3.sock");
998
999        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1000        let state = create_mock_app_state(None, None, None, None, None, None).await;
1001
1002        service
1003            .clone()
1004            .start(Arc::new(web::ThinData(state)))
1005            .await
1006            .unwrap();
1007
1008        let execution_id = "test-exec-789".to_string();
1009        let _guard = service.register_execution(execution_id.clone(), true).await;
1010
1011        tokio::time::sleep(Duration::from_millis(50)).await;
1012
1013        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1014            .await
1015            .unwrap();
1016
1017        // Send legacy Request format (without PluginMessage wrapper)
1018        let legacy_request = crate::services::plugins::relayer_api::Request {
1019            request_id: "legacy-1".to_string(),
1020            relayer_id: "relayer-1".to_string(),
1021            method: crate::services::plugins::relayer_api::PluginMethod::GetRelayerStatus,
1022            payload: serde_json::json!({}),
1023            http_request_id: Some(execution_id.clone()),
1024        };
1025        let legacy_json = serde_json::to_string(&legacy_request).unwrap() + "\n";
1026        client.write_all(legacy_json.as_bytes()).await.unwrap();
1027        client.flush().await.unwrap();
1028
1029        // Read legacy Response format
1030        let (r, _w) = client.into_split();
1031        let mut reader = BufReader::new(r);
1032        let mut response_line = String::new();
1033        reader.read_line(&mut response_line).await.unwrap();
1034
1035        let response: crate::services::plugins::relayer_api::Response =
1036            serde_json::from_str(&response_line).unwrap();
1037
1038        assert_eq!(response.request_id, "legacy-1");
1039        // Note: GetRelayerStatus might return an error if relayer doesn't exist
1040        // The important thing is we got a response in the correct format
1041        assert!(response.result.is_some() || response.error.is_some());
1042
1043        service.shutdown().await;
1044    }
1045
1046    #[tokio::test]
1047    async fn test_trace_collection() {
1048        let temp_dir = tempdir().unwrap();
1049        let socket_path = temp_dir.path().join("shared4.sock");
1050
1051        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1052        let state = create_mock_app_state(None, None, None, None, None, None).await;
1053
1054        service
1055            .clone()
1056            .start(Arc::new(web::ThinData(state)))
1057            .await
1058            .unwrap();
1059
1060        let execution_id = "test-exec-trace".to_string();
1061        let guard = service.register_execution(execution_id.clone(), true).await;
1062
1063        tokio::time::sleep(Duration::from_millis(50)).await;
1064
1065        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1066            .await
1067            .unwrap();
1068
1069        // Register
1070        let register_msg = PluginMessage::Register {
1071            execution_id: execution_id.clone(),
1072        };
1073        client
1074            .write_all((serde_json::to_string(&register_msg).unwrap() + "\n").as_bytes())
1075            .await
1076            .unwrap();
1077
1078        // Send trace events
1079        let trace1 = PluginMessage::Trace {
1080            trace: serde_json::json!({"event": "start", "timestamp": 1000}),
1081        };
1082        client
1083            .write_all((serde_json::to_string(&trace1).unwrap() + "\n").as_bytes())
1084            .await
1085            .unwrap();
1086
1087        let trace2 = PluginMessage::Trace {
1088            trace: serde_json::json!({"event": "processing", "timestamp": 2000}),
1089        };
1090        client
1091            .write_all((serde_json::to_string(&trace2).unwrap() + "\n").as_bytes())
1092            .await
1093            .unwrap();
1094
1095        // Shutdown
1096        let shutdown_msg = PluginMessage::Shutdown;
1097        client
1098            .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1099            .await
1100            .unwrap();
1101        client.flush().await.unwrap();
1102
1103        drop(client);
1104
1105        // Wait for connection to close and traces to be sent
1106        tokio::time::sleep(Duration::from_millis(100)).await;
1107
1108        // Collect traces
1109        let mut traces_rx = guard.into_receiver().expect("Traces should be enabled");
1110        let traces = traces_rx.recv().await.unwrap();
1111
1112        // Should have collected 2 trace events
1113        assert_eq!(traces.len(), 2);
1114        assert_eq!(traces[0]["event"], "start");
1115        assert_eq!(traces[1]["event"], "processing");
1116
1117        service.shutdown().await;
1118    }
1119
1120    #[tokio::test]
1121    async fn test_execution_guard_auto_unregister() {
1122        let temp_dir = tempdir().unwrap();
1123        let socket_path = temp_dir.path().join("shared_guard.sock");
1124
1125        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1126        let execution_id = "test-exec-guard".to_string();
1127
1128        {
1129            let _guard = service.register_execution(execution_id.clone(), true).await;
1130
1131            // Verify execution is registered (use atomic counter)
1132            assert_eq!(service.registered_executions_count().await, 1);
1133        }
1134        // Guard dropped here - synchronous removal with scc (no sleep needed!)
1135
1136        // Verify execution was auto-unregistered immediately
1137        assert_eq!(service.registered_executions_count().await, 0);
1138    }
1139
1140    #[tokio::test]
1141    async fn test_api_request_without_register_rejected() {
1142        let temp_dir = tempdir().unwrap();
1143        let socket_path = temp_dir.path().join("shared_no_register.sock");
1144
1145        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1146        let state = create_mock_app_state(None, None, None, None, None, None).await;
1147
1148        service
1149            .clone()
1150            .start(Arc::new(web::ThinData(state)))
1151            .await
1152            .unwrap();
1153
1154        tokio::time::sleep(Duration::from_millis(50)).await;
1155
1156        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1157            .await
1158            .unwrap();
1159
1160        // Send ApiRequest WITHOUT registering first (security violation)
1161        let api_request = PluginMessage::ApiRequest {
1162            request_id: "req-1".to_string(),
1163            relayer_id: "relayer-1".to_string(),
1164            method: crate::services::plugins::relayer_api::PluginMethod::GetRelayerStatus,
1165            payload: serde_json::json!({}),
1166        };
1167        let req_json = serde_json::to_string(&api_request).unwrap() + "\n";
1168        client.write_all(req_json.as_bytes()).await.unwrap();
1169        client.flush().await.unwrap();
1170
1171        // Connection should be closed by server
1172        tokio::time::sleep(Duration::from_millis(100)).await;
1173
1174        let (r, _w) = client.into_split();
1175        let mut reader = BufReader::new(r);
1176        let mut line = String::new();
1177        let result = reader.read_line(&mut line).await;
1178
1179        // Should get EOF (connection closed)
1180        assert!(result.is_err() || result.unwrap() == 0);
1181
1182        service.shutdown().await;
1183    }
1184
1185    #[tokio::test]
1186    async fn test_register_with_unknown_execution_id_rejected() {
1187        let temp_dir = tempdir().unwrap();
1188        let socket_path = temp_dir.path().join("shared_unknown_exec.sock");
1189
1190        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1191        let state = create_mock_app_state(None, None, None, None, None, None).await;
1192
1193        service
1194            .clone()
1195            .start(Arc::new(web::ThinData(state)))
1196            .await
1197            .unwrap();
1198
1199        tokio::time::sleep(Duration::from_millis(50)).await;
1200
1201        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1202            .await
1203            .unwrap();
1204
1205        // Try to register with an execution_id that doesn't exist in registry
1206        let register_msg = PluginMessage::Register {
1207            execution_id: "unknown-exec-id".to_string(),
1208        };
1209        let msg_json = serde_json::to_string(&register_msg).unwrap() + "\n";
1210        client.write_all(msg_json.as_bytes()).await.unwrap();
1211        client.flush().await.unwrap();
1212
1213        // Connection should be closed
1214        tokio::time::sleep(Duration::from_millis(100)).await;
1215
1216        let (r, _w) = client.into_split();
1217        let mut reader = BufReader::new(r);
1218        let mut line = String::new();
1219        let result = reader.read_line(&mut line).await;
1220
1221        assert!(result.is_err() || result.unwrap() == 0);
1222
1223        service.shutdown().await;
1224    }
1225
1226    #[tokio::test]
1227    async fn test_connection_limit_enforcement() {
1228        let temp_dir = tempdir().unwrap();
1229        let socket_path = temp_dir.path().join("shared_connection_limit.sock");
1230
1231        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1232        let state = create_mock_app_state(None, None, None, None, None, None).await;
1233
1234        service
1235            .clone()
1236            .start(Arc::new(web::ThinData(state)))
1237            .await
1238            .unwrap();
1239
1240        tokio::time::sleep(Duration::from_millis(50)).await;
1241
1242        // Check initial connection count
1243        let initial_permits = service.connection_semaphore.available_permits();
1244        let max_connections = get_config().socket_max_connections;
1245        assert_eq!(initial_permits, max_connections);
1246
1247        // Create a connection (should reduce available permits)
1248        let _client = UnixStream::connect(socket_path.to_str().unwrap())
1249            .await
1250            .unwrap();
1251
1252        tokio::time::sleep(Duration::from_millis(50)).await;
1253
1254        // Available permits should be reduced
1255        let after_connect = service.connection_semaphore.available_permits();
1256        assert!(after_connect < initial_permits);
1257
1258        service.shutdown().await;
1259    }
1260
1261    #[tokio::test]
1262    async fn test_idle_timeout() {
1263        let temp_dir = tempdir().unwrap();
1264        let socket_path = temp_dir.path().join("shared_idle_timeout.sock");
1265
1266        // Create service with short idle timeout for testing
1267        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1268        let state = create_mock_app_state(None, None, None, None, None, None).await;
1269
1270        service
1271            .clone()
1272            .start(Arc::new(web::ThinData(state)))
1273            .await
1274            .unwrap();
1275
1276        let execution_id = "test-exec-idle".to_string();
1277        let _guard = service.register_execution(execution_id.clone(), true).await;
1278
1279        tokio::time::sleep(Duration::from_millis(50)).await;
1280
1281        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1282            .await
1283            .unwrap();
1284
1285        // Register
1286        let register_msg = PluginMessage::Register { execution_id };
1287        client
1288            .write_all((serde_json::to_string(&register_msg).unwrap() + "\n").as_bytes())
1289            .await
1290            .unwrap();
1291        client.flush().await.unwrap();
1292
1293        // Wait longer than idle timeout (configured in service)
1294        // Note: idle_timeout is from config, but we can test that connection stays alive
1295        // within a reasonable time
1296        tokio::time::sleep(Duration::from_millis(100)).await;
1297
1298        // Connection should still be alive if we're within timeout
1299        // Send a Shutdown message to verify connection is still up
1300        let shutdown_msg = PluginMessage::Shutdown;
1301        let write_result = client
1302            .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1303            .await;
1304
1305        assert!(write_result.is_ok(), "Connection should still be alive");
1306
1307        service.shutdown().await;
1308    }
1309
1310    #[tokio::test]
1311    async fn test_read_timeout_handling() {
1312        let temp_dir = tempdir().unwrap();
1313        let socket_path = temp_dir.path().join("shared_read_timeout.sock");
1314
1315        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1316        let state = create_mock_app_state(None, None, None, None, None, None).await;
1317
1318        service
1319            .clone()
1320            .start(Arc::new(web::ThinData(state)))
1321            .await
1322            .unwrap();
1323
1324        let execution_id = "test-exec-read-timeout".to_string();
1325        let _guard = service.register_execution(execution_id.clone(), true).await;
1326
1327        tokio::time::sleep(Duration::from_millis(50)).await;
1328
1329        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1330            .await
1331            .unwrap();
1332
1333        // Register
1334        let register_msg = PluginMessage::Register { execution_id };
1335        client
1336            .write_all((serde_json::to_string(&register_msg).unwrap() + "\n").as_bytes())
1337            .await
1338            .unwrap();
1339        client.flush().await.unwrap();
1340
1341        // Don't send anything else - connection should be cleaned up after read timeout
1342        // Read timeout is configured in service (from config)
1343
1344        // Wait a bit (but not as long as full timeout)
1345        tokio::time::sleep(Duration::from_millis(200)).await;
1346
1347        // Connection should still be valid for a short time
1348        drop(client);
1349
1350        service.shutdown().await;
1351    }
1352
1353    #[tokio::test]
1354    async fn test_multiple_api_requests_same_connection() {
1355        let temp_dir = tempdir().unwrap();
1356        let socket_path = temp_dir.path().join("shared_multiple_requests.sock");
1357
1358        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1359        let state = create_mock_app_state(None, None, None, None, None, None).await;
1360
1361        service
1362            .clone()
1363            .start(Arc::new(web::ThinData(state)))
1364            .await
1365            .unwrap();
1366
1367        let execution_id = "test-exec-multi".to_string();
1368        let _guard = service.register_execution(execution_id.clone(), true).await;
1369
1370        tokio::time::sleep(Duration::from_millis(50)).await;
1371
1372        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1373            .await
1374            .unwrap();
1375
1376        // Register
1377        let register_msg = PluginMessage::Register {
1378            execution_id: execution_id.clone(),
1379        };
1380        client
1381            .write_all((serde_json::to_string(&register_msg).unwrap() + "\n").as_bytes())
1382            .await
1383            .unwrap();
1384
1385        let (r, mut w) = client.into_split();
1386        let mut reader = BufReader::new(r);
1387
1388        // Send multiple API requests
1389        for i in 1..=3 {
1390            let api_request = PluginMessage::ApiRequest {
1391                request_id: format!("req-{i}"),
1392                relayer_id: "relayer-1".to_string(),
1393                method: crate::services::plugins::relayer_api::PluginMethod::GetRelayerStatus,
1394                payload: serde_json::json!({}),
1395            };
1396            w.write_all((serde_json::to_string(&api_request).unwrap() + "\n").as_bytes())
1397                .await
1398                .unwrap();
1399            w.flush().await.unwrap();
1400
1401            // Read response
1402            let mut response_line = String::new();
1403            reader.read_line(&mut response_line).await.unwrap();
1404
1405            let response: PluginMessage = serde_json::from_str(&response_line).unwrap();
1406            match response {
1407                PluginMessage::ApiResponse { request_id, .. } => {
1408                    assert_eq!(request_id, format!("req-{i}"));
1409                }
1410                _ => panic!("Expected ApiResponse"),
1411            }
1412        }
1413
1414        service.shutdown().await;
1415    }
1416
1417    #[tokio::test]
1418    async fn test_shutdown_signal() {
1419        let temp_dir = tempdir().unwrap();
1420        let socket_path = temp_dir.path().join("shared_shutdown_signal.sock");
1421
1422        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1423        let state = create_mock_app_state(None, None, None, None, None, None).await;
1424
1425        service
1426            .clone()
1427            .start(Arc::new(web::ThinData(state)))
1428            .await
1429            .unwrap();
1430
1431        tokio::time::sleep(Duration::from_millis(50)).await;
1432
1433        // Verify socket file exists
1434        assert!(std::path::Path::new(socket_path.to_str().unwrap()).exists());
1435
1436        // Shutdown the service
1437        service.shutdown().await;
1438
1439        // Give time for cleanup
1440        tokio::time::sleep(Duration::from_millis(100)).await;
1441
1442        // Socket file should be removed
1443        assert!(!std::path::Path::new(socket_path.to_str().unwrap()).exists());
1444    }
1445
1446    #[tokio::test]
1447    async fn test_malformed_json_handling() {
1448        let temp_dir = tempdir().unwrap();
1449        let socket_path = temp_dir.path().join("shared_malformed.sock");
1450
1451        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1452        let state = create_mock_app_state(None, None, None, None, None, None).await;
1453
1454        service
1455            .clone()
1456            .start(Arc::new(web::ThinData(state)))
1457            .await
1458            .unwrap();
1459
1460        let execution_id = "test-exec-malformed".to_string();
1461        let _guard = service.register_execution(execution_id.clone(), true).await;
1462
1463        tokio::time::sleep(Duration::from_millis(50)).await;
1464
1465        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1466            .await
1467            .unwrap();
1468
1469        // Register first
1470        let register_msg = PluginMessage::Register {
1471            execution_id: execution_id.clone(),
1472        };
1473        client
1474            .write_all((serde_json::to_string(&register_msg).unwrap() + "\n").as_bytes())
1475            .await
1476            .unwrap();
1477
1478        // Send malformed JSON
1479        client
1480            .write_all(b"{ this is not valid json }\n")
1481            .await
1482            .unwrap();
1483        client.flush().await.unwrap();
1484
1485        // Connection should remain open (malformed messages are logged and skipped)
1486        tokio::time::sleep(Duration::from_millis(100)).await;
1487
1488        // Send valid shutdown message to verify connection is still up
1489        let shutdown_msg = PluginMessage::Shutdown;
1490        let write_result = client
1491            .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1492            .await;
1493
1494        assert!(
1495            write_result.is_ok(),
1496            "Connection should still be alive after malformed JSON"
1497        );
1498
1499        service.shutdown().await;
1500    }
1501
1502    #[tokio::test]
1503    async fn test_invalid_message_direction() {
1504        let temp_dir = tempdir().unwrap();
1505        let socket_path = temp_dir.path().join("shared_invalid_direction.sock");
1506
1507        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1508        let state = create_mock_app_state(None, None, None, None, None, None).await;
1509
1510        service
1511            .clone()
1512            .start(Arc::new(web::ThinData(state)))
1513            .await
1514            .unwrap();
1515
1516        let execution_id = "test-exec-invalid-dir".to_string();
1517        let _guard = service.register_execution(execution_id.clone(), true).await;
1518
1519        tokio::time::sleep(Duration::from_millis(50)).await;
1520
1521        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1522            .await
1523            .unwrap();
1524
1525        // Register
1526        let register_msg = PluginMessage::Register {
1527            execution_id: execution_id.clone(),
1528        };
1529        client
1530            .write_all((serde_json::to_string(&register_msg).unwrap() + "\n").as_bytes())
1531            .await
1532            .unwrap();
1533
1534        // Plugin tries to send ApiResponse (invalid direction - only Host sends ApiResponse)
1535        let invalid_msg = PluginMessage::ApiResponse {
1536            request_id: "invalid".to_string(),
1537            result: Some(serde_json::json!({})),
1538            error: None,
1539        };
1540        client
1541            .write_all((serde_json::to_string(&invalid_msg).unwrap() + "\n").as_bytes())
1542            .await
1543            .unwrap();
1544        client.flush().await.unwrap();
1545
1546        // Connection should remain open (invalid messages are logged and skipped)
1547        tokio::time::sleep(Duration::from_millis(100)).await;
1548
1549        // Verify connection is still alive
1550        let shutdown_msg = PluginMessage::Shutdown;
1551        let write_result = client
1552            .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1553            .await;
1554
1555        assert!(write_result.is_ok());
1556
1557        service.shutdown().await;
1558    }
1559
1560    #[tokio::test]
1561    async fn test_stale_execution_cleanup() {
1562        let temp_dir = tempdir().unwrap();
1563        let socket_path = temp_dir.path().join("shared_stale_cleanup.sock");
1564
1565        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1566
1567        // Register an execution manually with old timestamp
1568        let execution_id = "stale-exec".to_string();
1569        let (tx, _rx) = mpsc::channel(1);
1570        // scc::HashMap insert
1571        let _ = service.executions.insert(
1572            execution_id.clone(),
1573            ExecutionContext {
1574                traces_tx: Some(tx),
1575                created_at: Instant::now() - Duration::from_secs(400), // 6+ minutes old
1576                bound_execution_id: execution_id.clone(),
1577            },
1578        );
1579
1580        // Verify it's registered using scc's contains()
1581        assert!(service.executions.contains(&execution_id));
1582
1583        // Wait for cleanup task to run (it runs every 60 seconds, but we can't wait that long)
1584        // Instead, we verify the cleanup logic by checking the code in new()
1585        // The actual cleanup test would require mocking time or waiting 60+ seconds
1586
1587        // For this test, we just verify the logic exists and doesn't panic
1588        drop(service);
1589    }
1590
1591    #[tokio::test]
1592    async fn test_socket_path_getter() {
1593        let temp_dir = tempdir().unwrap();
1594        let socket_path = temp_dir.path().join("shared_path.sock");
1595
1596        let service = SharedSocketService::new(socket_path.to_str().unwrap()).unwrap();
1597
1598        assert_eq!(service.socket_path(), socket_path.to_str().unwrap());
1599    }
1600
1601    #[tokio::test]
1602    async fn test_trace_send_timeout() {
1603        let temp_dir = tempdir().unwrap();
1604        let socket_path = temp_dir.path().join("shared_trace_timeout.sock");
1605
1606        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1607        let state = create_mock_app_state(None, None, None, None, None, None).await;
1608
1609        service
1610            .clone()
1611            .start(Arc::new(web::ThinData(state)))
1612            .await
1613            .unwrap();
1614
1615        let execution_id = "test-exec-trace-timeout".to_string();
1616        let guard = service.register_execution(execution_id.clone(), true).await;
1617
1618        // Don't consume the receiver - this will cause the channel to fill up
1619        drop(guard);
1620
1621        tokio::time::sleep(Duration::from_millis(50)).await;
1622
1623        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
1624            .await
1625            .unwrap();
1626
1627        // Register
1628        let register_msg = PluginMessage::Register {
1629            execution_id: execution_id.clone(),
1630        };
1631        client
1632            .write_all((serde_json::to_string(&register_msg).unwrap() + "\n").as_bytes())
1633            .await
1634            .unwrap();
1635
1636        // Send trace
1637        let trace = PluginMessage::Trace {
1638            trace: serde_json::json!({"event": "test"}),
1639        };
1640        client
1641            .write_all((serde_json::to_string(&trace).unwrap() + "\n").as_bytes())
1642            .await
1643            .unwrap();
1644
1645        // Shutdown
1646        let shutdown_msg = PluginMessage::Shutdown;
1647        client
1648            .write_all((serde_json::to_string(&shutdown_msg).unwrap() + "\n").as_bytes())
1649            .await
1650            .unwrap();
1651        client.flush().await.unwrap();
1652
1653        drop(client);
1654
1655        // Wait for connection to close - should handle timeout gracefully
1656        tokio::time::sleep(Duration::from_millis(200)).await;
1657
1658        service.shutdown().await;
1659    }
1660
1661    #[tokio::test]
1662    async fn test_get_shared_socket_service() {
1663        // Test the global singleton
1664        let service1 = get_shared_socket_service();
1665        assert!(service1.is_ok());
1666
1667        let service2 = get_shared_socket_service();
1668        assert!(service2.is_ok());
1669
1670        // Should return the same instance
1671        let svc1 = service1.unwrap();
1672        let svc2 = service2.unwrap();
1673        let path1 = svc1.socket_path();
1674        let path2 = svc2.socket_path();
1675        assert_eq!(path1, path2);
1676    }
1677
1678    #[tokio::test]
1679    async fn test_duplicate_execution_id_does_not_corrupt_counter() {
1680        let temp_dir = tempdir().unwrap();
1681        let socket_path = temp_dir.path().join("shared_duplicate_exec.sock");
1682
1683        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1684        let execution_id = "duplicate-exec-id".to_string();
1685
1686        // Initial count should be 0
1687        assert_eq!(service.registered_executions_count().await, 0);
1688
1689        // Register first execution
1690        let guard1 = service.register_execution(execution_id.clone(), true).await;
1691        assert_eq!(service.registered_executions_count().await, 1);
1692
1693        // Try to register with same execution_id (duplicate)
1694        // This should NOT increment the counter (insertion will fail)
1695        let guard2 = service.register_execution(execution_id.clone(), true).await;
1696        // Counter should still be 1 (not 2)
1697        assert_eq!(service.registered_executions_count().await, 1);
1698
1699        // Drop the duplicate guard first - should NOT decrement counter
1700        // (because it was never successfully registered)
1701        drop(guard2);
1702        assert_eq!(service.registered_executions_count().await, 1);
1703
1704        // Drop the original guard - should decrement counter
1705        drop(guard1);
1706        assert_eq!(service.registered_executions_count().await, 0);
1707    }
1708
1709    #[tokio::test]
1710    async fn test_execution_guard_registered_field() {
1711        let temp_dir = tempdir().unwrap();
1712        let socket_path = temp_dir.path().join("shared_registered_field.sock");
1713
1714        let service = Arc::new(SharedSocketService::new(socket_path.to_str().unwrap()).unwrap());
1715
1716        // Register a unique execution_id
1717        let execution_id_1 = "unique-exec-1".to_string();
1718        let guard1 = service
1719            .register_execution(execution_id_1.clone(), true)
1720            .await;
1721        assert_eq!(service.registered_executions_count().await, 1);
1722
1723        // Register another unique execution_id
1724        let execution_id_2 = "unique-exec-2".to_string();
1725        let guard2 = service
1726            .register_execution(execution_id_2.clone(), false)
1727            .await;
1728        assert_eq!(service.registered_executions_count().await, 2);
1729
1730        // into_receiver should work regardless of registered status
1731        let rx = guard1.into_receiver();
1732        assert!(rx.is_some()); // emit_traces=true
1733
1734        // guard2 had emit_traces=false
1735        let rx2 = guard2.into_receiver();
1736        assert!(rx2.is_none()); // emit_traces=false
1737
1738        // After guards are consumed via into_receiver, counter should be decremented
1739        assert_eq!(service.registered_executions_count().await, 0);
1740    }
1741}