openzeppelin_relayer/services/plugins/
protocol.rs

//! Protocol types for pool server communication.
//!
//! Defines the newline-delimited JSON ("JSON lines") messages exchanged between
//! the Rust pool executor and the Node.js pool server over a Unix socket.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use super::{LogEntry, LogLevel};
10
11/// Execute request payload (boxed to reduce enum size)
12#[derive(Serialize, Debug)]
13pub struct ExecuteRequest {
14    #[serde(rename = "taskId")]
15    pub task_id: String,
16    #[serde(rename = "pluginId")]
17    pub plugin_id: String,
18    #[serde(rename = "compiledCode", skip_serializing_if = "Option::is_none")]
19    pub compiled_code: Option<String>,
20    #[serde(rename = "pluginPath", skip_serializing_if = "Option::is_none")]
21    pub plugin_path: Option<String>,
22    pub params: serde_json::Value,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub headers: Option<HashMap<String, Vec<String>>>,
25    #[serde(rename = "socketPath")]
26    pub socket_path: String,
27    #[serde(rename = "httpRequestId", skip_serializing_if = "Option::is_none")]
28    pub http_request_id: Option<String>,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub timeout: Option<u64>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub route: Option<String>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub config: Option<serde_json::Value>,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub method: Option<String>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub query: Option<serde_json::Value>,
39}
40
41/// Request messages sent to the pool server
42#[derive(Serialize, Debug)]
43#[serde(tag = "type", rename_all = "lowercase")]
44pub enum PoolRequest {
45    Execute(Box<ExecuteRequest>),
46    Precompile {
47        #[serde(rename = "taskId")]
48        task_id: String,
49        #[serde(rename = "pluginId")]
50        plugin_id: String,
51        #[serde(rename = "pluginPath", skip_serializing_if = "Option::is_none")]
52        plugin_path: Option<String>,
53        #[serde(rename = "sourceCode", skip_serializing_if = "Option::is_none")]
54        source_code: Option<String>,
55    },
56    Cache {
57        #[serde(rename = "taskId")]
58        task_id: String,
59        #[serde(rename = "pluginId")]
60        plugin_id: String,
61        #[serde(rename = "compiledCode")]
62        compiled_code: String,
63    },
64    Invalidate {
65        #[serde(rename = "taskId")]
66        task_id: String,
67        #[serde(rename = "pluginId")]
68        plugin_id: String,
69    },
70    Stats {
71        #[serde(rename = "taskId")]
72        task_id: String,
73    },
74    Health {
75        #[serde(rename = "taskId")]
76        task_id: String,
77    },
78    Shutdown {
79        #[serde(rename = "taskId")]
80        task_id: String,
81    },
82}
83
84/// Response from the pool server
85#[derive(Deserialize, Debug)]
86pub struct PoolResponse {
87    #[serde(rename = "taskId")]
88    pub task_id: String,
89    pub success: bool,
90    pub result: Option<serde_json::Value>,
91    pub error: Option<PoolError>,
92    pub logs: Option<Vec<PoolLogEntry>>,
93}
94
95/// Error details from the pool server
96#[derive(Deserialize, Debug)]
97pub struct PoolError {
98    pub message: String,
99    pub code: Option<String>,
100    pub status: Option<u16>,
101    pub details: Option<serde_json::Value>,
102}
103
104/// Log entry from plugin execution
105#[derive(Deserialize, Debug)]
106pub struct PoolLogEntry {
107    pub level: String,
108    pub message: String,
109}
110
111impl From<PoolLogEntry> for LogEntry {
112    fn from(entry: PoolLogEntry) -> Self {
113        let level = match entry.level.as_str() {
114            "error" => LogLevel::Error,
115            "warn" => LogLevel::Warn,
116            "info" => LogLevel::Info,
117            "debug" => LogLevel::Debug,
118            "result" => LogLevel::Result,
119            _ => LogLevel::Log,
120        };
121        LogEntry {
122            level,
123            message: entry.message,
124        }
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_request_execute_serialization() {
        let request = PoolRequest::Execute(Box::new(ExecuteRequest {
            task_id: "test-123".to_string(),
            plugin_id: "my-plugin".to_string(),
            compiled_code: Some("console.log('hello')".to_string()),
            plugin_path: None,
            params: serde_json::json!({"key": "value"}),
            headers: None,
            socket_path: "/tmp/test.sock".to_string(),
            http_request_id: Some("req-456".to_string()),
            timeout: Some(30000),
            route: None,
            config: None,
            method: None,
            query: None,
        }));

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"execute\""));
        assert!(json.contains("\"taskId\":\"test-123\""));
        assert!(json.contains("\"pluginId\":\"my-plugin\""));
    }

    /// Every unset optional field must be omitted from the wire format.
    #[test]
    fn test_execute_request_omits_unset_optionals() {
        let request = PoolRequest::Execute(Box::new(ExecuteRequest {
            task_id: "t".to_string(),
            plugin_id: "p".to_string(),
            compiled_code: None,
            plugin_path: Some("/plugins/p.ts".to_string()),
            params: serde_json::json!({}),
            headers: None,
            socket_path: "/tmp/s.sock".to_string(),
            http_request_id: None,
            timeout: None,
            route: None,
            config: None,
            method: None,
            query: None,
        }));

        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "execute",
                "taskId": "t",
                "pluginId": "p",
                "pluginPath": "/plugins/p.ts",
                "params": {},
                "socketPath": "/tmp/s.sock"
            })
        );
    }

    /// Every field, when set, must appear under its camelCase wire name.
    #[test]
    fn test_execute_request_serializes_all_fields_in_camel_case() {
        let mut headers = std::collections::HashMap::new();
        headers.insert("x-test".to_string(), vec!["1".to_string()]);

        let request = PoolRequest::Execute(Box::new(ExecuteRequest {
            task_id: "t".to_string(),
            plugin_id: "p".to_string(),
            compiled_code: Some("code".to_string()),
            plugin_path: Some("/p.ts".to_string()),
            params: serde_json::json!({"a": 1}),
            headers: Some(headers),
            socket_path: "/tmp/s.sock".to_string(),
            http_request_id: Some("req".to_string()),
            timeout: Some(1000),
            route: Some("/r".to_string()),
            config: Some(serde_json::json!({"c": true})),
            method: Some("POST".to_string()),
            query: Some(serde_json::json!({"q": "v"})),
        }));

        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "execute",
                "taskId": "t",
                "pluginId": "p",
                "compiledCode": "code",
                "pluginPath": "/p.ts",
                "params": {"a": 1},
                "headers": {"x-test": ["1"]},
                "socketPath": "/tmp/s.sock",
                "httpRequestId": "req",
                "timeout": 1000,
                "route": "/r",
                "config": {"c": true},
                "method": "POST",
                "query": {"q": "v"}
            })
        );
    }

    #[test]
    fn test_pool_request_precompile_serialization() {
        let request = PoolRequest::Precompile {
            task_id: "precompile-123".to_string(),
            plugin_id: "test-plugin".to_string(),
            plugin_path: Some("/plugins/test.ts".to_string()),
            source_code: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"precompile\""));
        assert!(json.contains("\"taskId\":\"precompile-123\""));
        assert!(json.contains("\"pluginPath\":\"/plugins/test.ts\""));
        // `None` optionals are skipped, not sent as null.
        assert!(!json.contains("sourceCode"));
    }

    #[test]
    fn test_pool_request_cache_serialization() {
        let request = PoolRequest::Cache {
            task_id: "cache-123".to_string(),
            plugin_id: "test-plugin".to_string(),
            compiled_code: "compiled code here".to_string(),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"cache\""));
        assert!(json.contains("\"compiledCode\":\"compiled code here\""));
    }

    #[test]
    fn test_pool_request_invalidate_serialization() {
        let request = PoolRequest::Invalidate {
            task_id: "inv-123".to_string(),
            plugin_id: "test-plugin".to_string(),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"invalidate\""));
        assert!(json.contains("\"pluginId\":\"test-plugin\""));
    }

    #[test]
    fn test_pool_request_stats_serialization() {
        let request = PoolRequest::Stats {
            task_id: "stats-123".to_string(),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"stats\""));
    }

    #[test]
    fn test_pool_request_health_serialization() {
        let request = PoolRequest::Health {
            task_id: "health-123".to_string(),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"health\""));
    }

    #[test]
    fn test_pool_request_shutdown_serialization() {
        let request = PoolRequest::Shutdown {
            task_id: "shutdown-123".to_string(),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"shutdown\""));
    }

    #[test]
    fn test_pool_response_success_deserialization() {
        let json = r#"{
            "taskId": "test-123",
            "success": true,
            "result": {"data": "hello"},
            "logs": [{"level": "info", "message": "test log"}]
        }"#;

        let response: PoolResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.task_id, "test-123");
        assert!(response.success);
        assert!(response.error.is_none());
        assert_eq!(response.result, Some(serde_json::json!({"data": "hello"})));

        let logs = response.logs.expect("logs should be present");
        assert_eq!(logs.len(), 1);
        assert_eq!(logs[0].level, "info");
        assert_eq!(logs[0].message, "test log");
    }

    /// Only `taskId` and `success` are required; everything else defaults to `None`.
    #[test]
    fn test_pool_response_minimal_deserialization() {
        let json = r#"{"taskId": "t", "success": true}"#;

        let response: PoolResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.task_id, "t");
        assert!(response.success);
        assert!(response.result.is_none());
        assert!(response.error.is_none());
        assert!(response.logs.is_none());
    }

    #[test]
    fn test_pool_response_error_deserialization() {
        let json = r#"{
            "taskId": "test-123",
            "success": false,
            "error": {
                "message": "Plugin failed",
                "code": "EXEC_ERROR",
                "status": 500
            },
            "logs": []
        }"#;

        let response: PoolResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.task_id, "test-123");
        assert!(!response.success);
        let err = response.error.expect("error should be present");
        assert_eq!(err.message, "Plugin failed");
        assert_eq!(err.code, Some("EXEC_ERROR".to_string()));
        assert_eq!(err.status, Some(500));
        assert!(err.details.is_none());
    }

    #[test]
    fn test_pool_log_entry_conversion() {
        let pool_entry = PoolLogEntry {
            level: "error".to_string(),
            message: "test error".to_string(),
        };

        let log_entry: LogEntry = pool_entry.into();
        assert!(matches!(log_entry.level, LogLevel::Error));
        assert_eq!(log_entry.message, "test error");
    }

    #[test]
    fn test_pool_log_entry_level_conversion() {
        // Covers every mapped level plus the fallback for unknown names.
        let cases = [
            ("log", LogLevel::Log),
            ("info", LogLevel::Info),
            ("warn", LogLevel::Warn),
            ("error", LogLevel::Error),
            ("debug", LogLevel::Debug),
            ("result", LogLevel::Result),
            ("unknown", LogLevel::Log),
        ];

        for (input, expected) in cases {
            let log_entry: LogEntry = PoolLogEntry {
                level: input.to_string(),
                message: "test".to_string(),
            }
            .into();
            assert_eq!(
                std::mem::discriminant(&log_entry.level),
                std::mem::discriminant(&expected),
                "Expected {:?} for input '{}', got {:?}",
                expected,
                input,
                log_entry.level
            );
        }
    }
}