1use std::{collections::HashMap, path::PathBuf, sync::Arc};
7
8use async_trait::async_trait;
9use base64::Engine;
10use bytes::Bytes;
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13use tracing::{debug, info, instrument};
14
15use crate::{
16 Transport::{
17 Strategy::{TransportStats, TransportStrategy, TransportType},
18 TransportConfig,
19 },
20 WASM::{
21 HostBridge::HostBridgeImpl,
22 MemoryManager::{MemoryLimits, MemoryManagerImpl},
23 Runtime::{WASMConfig, WASMRuntime},
24 WASMStats,
25 },
26};
27
/// WASM-backed implementation of [`TransportStrategy`].
///
/// Every piece of state is held behind an `Arc`, so clones produced by the
/// derived `Clone` share the same runtime, module table, connection flag and
/// statistics.
#[derive(Clone, Debug)]
pub struct WASMTransportImpl {
	/// WASM runtime created by the constructors; exposed via `runtime()`.
	runtime:Arc<WASMRuntime>,
	/// Memory manager; its `current_usage_mb()` feeds `WASMStats::total_memory_mb`.
	memory_manager:Arc<RwLock<MemoryManagerImpl>>,
	/// Host bridge instance; exposed via `bridge()`.
	bridge:Arc<HostBridgeImpl>,
	/// Bookkeeping for loaded modules, keyed by module id.
	modules:Arc<RwLock<HashMap<String, WASMModuleInfo>>>,
	// Kept so callers can supply a transport config, but no code visible in this
	// file reads it yet — hence the dead-code allowance.
	#[allow(dead_code)]
	config:TransportConfig,
	/// Connection flag: `true` after construction and `connect()`, `false` after `close()`.
	connected:Arc<RwLock<bool>>,
	/// Aggregate transport counters, updated by `call_wasm_function`.
	stats:Arc<RwLock<TransportStats>>,
}
47
/// Bookkeeping record for one WASM module tracked by the transport.
#[derive(Debug, Clone)]
pub struct WASMModuleInfo {
	/// Module identifier. The transport's module table is keyed by module id.
	pub id:String,
	/// Optional display name. NOTE(review): how it is populated is not visible here.
	pub name:Option<String>,
	/// Optional filesystem path. Presumably the file the module was loaded from — TODO confirm.
	pub path:Option<PathBuf>,
	/// Load timestamp. NOTE(review): unit/epoch not shown in this file; presumably
	/// Unix seconds like `FunctionCallStats::last_call_at` — confirm with the loader.
	pub loaded_at:u64,
	/// Per-function call statistics keyed by function name; `call_wasm_function`
	/// creates entries on demand.
	pub function_stats:HashMap<String, FunctionCallStats>,
}
62
/// Per-function call counters (serde-serializable).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallStats {
	/// Number of calls recorded via `record_call`.
	pub call_count:u64,
	/// Sum of the `time_us` values passed to `record_call`. The transport passes
	/// elapsed microseconds.
	pub total_time_us:u64,
	/// Unix time, in whole seconds, of the most recent `record_call`; `None` if
	/// no call has been recorded.
	pub last_call_at:Option<u64>,
	/// Number of failures recorded via `record_error`.
	pub error_count:u64,
}
75
76impl FunctionCallStats {
77 pub fn record_call(&mut self, time_us:u64) {
79 self.call_count += 1;
80 self.total_time_us += time_us;
81 self.last_call_at = Some(
82 std::time::SystemTime::now()
83 .duration_since(std::time::UNIX_EPOCH)
84 .map(|d| d.as_secs())
85 .unwrap_or(0),
86 );
87 }
88
89 pub fn record_error(&mut self) { self.error_count += 1; }
91}
92
93impl Default for FunctionCallStats {
94 fn default() -> Self { Self { call_count:0, total_time_us:0, last_call_at:None, error_count:0 } }
95}
96
97impl WASMTransportImpl {
98 pub fn new(enable_wasi:bool, memory_limit_mb:u64, max_execution_time_ms:u64) -> anyhow::Result<Self> {
100 let config = WASMConfig::new(memory_limit_mb, max_execution_time_ms, enable_wasi);
101
102 let runtime_result = tokio::runtime::Runtime::new()
105 .map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
106 .block_on(WASMRuntime::new(config.clone()))
107 .map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
108 let runtime = Arc::new(runtime_result);
109
110 let memory_limits = MemoryLimits::new(memory_limit_mb, (memory_limit_mb as f64 * 0.75) as u64, 100);
111 let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
112 let bridge = Arc::new(HostBridgeImpl::new());
113
114 Ok(Self {
115 runtime,
116 memory_manager,
117 bridge,
118 modules:Arc::new(RwLock::new(HashMap::new())),
119 config:TransportConfig::default(),
120 connected:Arc::new(RwLock::new(true)), stats:Arc::new(RwLock::new(TransportStats::default())),
122 })
123 }
124
125 pub fn with_config(wasm_config:WASMConfig, transport_config:TransportConfig) -> anyhow::Result<Self> {
127 let runtime_result = tokio::runtime::Runtime::new()
128 .map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
129 .block_on(WASMRuntime::new(wasm_config.clone()))
130 .map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
131 let runtime = Arc::new(runtime_result);
132
133 let memory_limits = MemoryLimits::new(
134 wasm_config.memory_limit_mb,
135 (wasm_config.memory_limit_mb as f64 * 0.75) as u64,
136 100,
137 );
138 let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
139 let bridge = Arc::new(HostBridgeImpl::new());
140
141 Ok(Self {
142 runtime,
143 memory_manager,
144 bridge,
145 modules:Arc::new(RwLock::new(HashMap::new())),
146 config:transport_config,
147 connected:Arc::new(RwLock::new(true)),
148 stats:Arc::new(RwLock::new(TransportStats::default())),
149 })
150 }
151
152 pub fn runtime(&self) -> &Arc<WASMRuntime> { &self.runtime }
154
155 pub fn memory_manager(&self) -> &Arc<RwLock<MemoryManagerImpl>> { &self.memory_manager }
157
158 pub fn bridge(&self) -> &Arc<HostBridgeImpl> { &self.bridge }
160
161 pub async fn get_modules(&self) -> HashMap<String, WASMModuleInfo> { self.modules.read().await.clone() }
163
164 pub async fn get_wasm_stats(&self) -> WASMStats {
166 let memory_manager = self.memory_manager.read().await;
167 let managers = self.modules.read().await;
168
169 WASMStats {
170 modules_loaded:managers.len(),
171 active_instances:managers.len(), total_memory_mb:memory_manager.current_usage_mb() as u64,
173 total_execution_time_ms:0, function_calls:self.stats.read().await.messages_sent,
175 }
176 }
177
178 #[instrument(skip(self, module_id, function_name, args))]
180 pub async fn call_wasm_function(
181 &self,
182 module_id:&str,
183 function_name:&str,
184 args:Vec<Bytes>,
185 ) -> anyhow::Result<Bytes> {
186 let start = std::time::Instant::now();
187
188 debug!(
189 "Calling WASM function: {}::{} with {} arguments",
190 module_id,
191 function_name,
192 args.len()
193 );
194
195 let modules = self.modules.read().await;
196 let _module = modules
197 .get(module_id)
198 .ok_or_else(|| anyhow::anyhow!("Module not found: {}", module_id))?;
199
200 let response = Bytes::new();
203
204 let mut modules_mut = self.modules.write().await;
206 if let Some(module) = modules_mut.get_mut(module_id) {
207 let stats = module.function_stats.entry(function_name.to_string()).or_default();
208 stats.record_call(start.elapsed().as_micros() as u64);
209 }
210
211 drop(modules_mut);
212
213 let mut stats = self.stats.write().await;
215 stats.record_sent(args.iter().map(|b| b.len() as u64).sum(), start.elapsed().as_micros() as u64);
216 stats.record_received(response.len() as u64);
217
218 Ok(response)
219 }
220}
221
222#[async_trait]
223impl TransportStrategy for WASMTransportImpl {
224 type Error = WASMTransportError;
225
226 #[instrument(skip(self))]
227 async fn connect(&self) -> Result<(), Self::Error> {
228 info!("WASM transport connecting");
229
230 *self.connected.write().await = true;
232
233 info!("WASM transport connected");
234
235 Ok(())
236 }
237
238 #[instrument(skip(self, request))]
239 async fn send(&self, request:&[u8]) -> Result<Vec<u8>, Self::Error> {
240 let start = std::time::Instant::now();
241
242 if !self.is_connected() {
243 return Err(WASMTransportError::NotConnected);
244 }
245
246 debug!("Sending WASM transport request ({} bytes)", request.len());
247
248 let request_str =
251 std::str::from_utf8(request).map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?;
252
253 let parts:Vec<&str> = request_str.splitn(3, ':').collect();
254 if parts.len() < 3 {
255 return Err(WASMTransportError::InvalidRequest("Invalid request format".to_string()));
256 }
257
258 let module_id = parts[0];
259 let function_name = parts[1];
260 let args_base64 = parts[2];
261
262 use base64::engine::general_purpose::STANDARD;
264 let args = vec![Bytes::from(
265 STANDARD
266 .decode(args_base64)
267 .map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?,
268 )];
269
270 let response = self
272 .call_wasm_function(module_id, function_name, args)
273 .await
274 .map_err(|e| WASMTransportError::FunctionCallFailed(e.to_string()))?;
275
276 let response_vec = response.to_vec();
278
279 let latency_us = start.elapsed().as_micros() as u64;
280
281 debug!("WASM transport request completed in {}µs", latency_us);
282
283 Ok(response_vec)
284 }
285
286 #[instrument(skip(self, data))]
287 async fn send_no_response(&self, data:&[u8]) -> Result<(), Self::Error> {
288 if !self.is_connected() {
289 return Err(WASMTransportError::NotConnected);
290 }
291
292 debug!("Sending WASM transport request without response ({} bytes)", data.len());
293
294 self.send(data).await?;
296 Ok(())
297 }
298
299 #[instrument(skip(self))]
300 async fn close(&self) -> Result<(), Self::Error> {
301 info!("Closing WASM transport");
302
303 *self.connected.write().await = false;
304
305 info!("WASM transport closed");
306
307 Ok(())
308 }
309
310 fn is_connected(&self) -> bool { self.connected.blocking_read().to_owned() }
311
312 fn transport_type(&self) -> TransportType { TransportType::WASM }
313}
314
/// Errors returned by [`WASMTransportImpl`]'s [`TransportStrategy`] implementation.
#[derive(Debug, thiserror::Error)]
pub enum WASMTransportError {
	/// A referenced module is not loaded. Not constructed by the code in this
	/// file: a missing module in `send` currently surfaces as `FunctionCallFailed`.
	#[error("Module not found: {0}")]
	ModuleNotFound(String),

	/// A referenced function does not exist. Not constructed by the code in this file.
	#[error("Function not found: {0}")]
	FunctionNotFound(String),

	/// Wraps the error message of a failed `call_wasm_function` during `send`.
	#[error("Function call failed: {0}")]
	FunctionCallFailed(String),

	/// Memory-related failure. Not constructed by the code in this file.
	#[error("Memory error: {0}")]
	MemoryError(String),

	/// Runtime failure. Not constructed by the code in this file.
	#[error("Runtime error: {0}")]
	RuntimeError(String),

	/// The request bytes were not UTF-8, lacked the `module:function:args` shape,
	/// or carried invalid base64 arguments.
	#[error("Invalid request: {0}")]
	InvalidRequest(String),

	/// Returned by `send` / `send_no_response` when the transport is closed.
	#[error("Not connected")]
	NotConnected,

	/// Module compilation failure. Not constructed by the code in this file.
	#[error("Compilation failed: {0}")]
	CompilationFailed(String),

	/// Operation timed out. Not constructed by the code in this file.
	#[error("Timeout")]
	Timeout,
}
354
#[cfg(test)]
mod tests {
	use super::*;
	use crate::Transport::Strategy::TransportStrategy;

	#[test]
	fn test_wasm_transport_creation() {
		let transport = WASMTransportImpl::new(true, 512, 30000).expect("transport should be created");
		assert!(transport.is_connected());
	}

	#[test]
	fn test_function_call_stats() {
		let mut stats = FunctionCallStats::default();
		stats.record_call(100);
		assert_eq!(stats.call_count, 1);
		assert_eq!(stats.total_time_us, 100);
		assert!(stats.last_call_at.is_some());
	}

	#[test]
	fn test_function_call_stats_record_error() {
		let mut stats = FunctionCallStats::default();
		stats.record_error();
		stats.record_error();
		assert_eq!(stats.error_count, 2);
		// Errors must not be counted as successful calls.
		assert_eq!(stats.call_count, 0);
	}

	#[tokio::test]
	async fn test_wasm_transport_not_connected_after_close() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
		// Assert the close itself succeeded instead of silently discarding the result.
		transport.close().await.expect("close should succeed");
		assert!(!transport.is_connected());
	}

	#[tokio::test]
	async fn test_get_wasm_stats() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
		let stats = transport.get_wasm_stats().await;
		assert_eq!(stats.modules_loaded, 0);
		assert_eq!(stats.active_instances, 0);
	}

	#[tokio::test]
	async fn test_call_unknown_module_is_error() {
		let transport = WASMTransportImpl::new(true, 512, 30000).unwrap();
		let err = transport
			.call_wasm_function("missing", "f", Vec::new())
			.await
			.expect_err("unknown module must be rejected");
		assert!(err.to_string().contains("Module not found"));
	}
}