1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use tracing::{debug, info, instrument, warn};
12use wasmtime::{Caller, Linker};
13
14use crate::WASM::HostBridge::{FunctionSignature, HostBridgeImpl as HostBridge, HostFunctionCallback};
15
16pub struct HostFunctionRegistry {
18 functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
20 #[allow(dead_code)]
22 bridge:Arc<HostBridge>,
23}
24
25#[derive(Debug, Clone)]
27struct RegisteredHostFunction {
28 #[allow(dead_code)]
30 name:String,
31 #[allow(dead_code)]
33 signature:FunctionSignature,
34 callback:Option<HostFunctionCallback>,
36 #[allow(dead_code)]
38 registered_at:u64,
39 stats:FunctionStats,
41}
42
/// Call statistics kept for each registered host function.
///
/// The default value has every counter at zero and no recorded last call.
/// `PartialEq`/`Eq` are derived so snapshots can be compared directly.
///
/// NOTE(review): nothing in this file updates these counters after
/// registration; the field meanings below are taken from the field names and
/// should be confirmed against whatever code maintains them.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FunctionStats {
    /// Number of calls recorded.
    pub call_count:u64,
    /// Accumulated execution time; presumably nanoseconds, per the `_ns` suffix.
    pub total_execution_ns:u64,
    /// Time of the most recent call, if any; the unit is not established in this file.
    pub last_call_at:Option<u64>,
    /// Number of failed calls recorded.
    pub error_count:u64,
}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ExportConfig {
59 pub auto_export:bool,
61 pub enable_stats:bool,
63 pub max_functions:usize,
65 pub name_prefix:Option<String>,
67}
68
69impl Default for ExportConfig {
70 fn default() -> Self {
71 Self {
72 auto_export:true,
73 enable_stats:true,
74 max_functions:1000,
75 name_prefix:Some("host_".to_string()),
76 }
77 }
78}
79
80pub struct FunctionExportImpl {
82 registry:Arc<HostFunctionRegistry>,
83 config:ExportConfig,
84}
85
86impl FunctionExportImpl {
87 pub fn new(bridge:Arc<HostBridge>) -> Self {
89 Self {
90 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
91 config:ExportConfig::default(),
92 }
93 }
94
95 pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
97 Self {
98 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
99 config,
100 }
101 }
102
103 #[instrument(skip(self, callback))]
105 pub async fn register_function(
106 &self,
107 name:&str,
108 signature:FunctionSignature,
109 callback:HostFunctionCallback,
110 ) -> Result<()> {
111 info!("Registering host function for export: {}", name);
112
113 let functions = self.registry.functions.read().await;
114
115 if functions.len() >= self.config.max_functions {
117 return Err(anyhow::anyhow!(
118 "Maximum number of exported functions reached: {}",
119 self.config.max_functions
120 ));
121 }
122
123 drop(functions);
124
125 let mut functions = self.registry.functions.write().await;
126
127 let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
128
129 functions.insert(
130 name.to_string(),
131 RegisteredHostFunction {
132 name:name.to_string(),
133 signature,
134 callback:Some(callback),
135 registered_at,
136 stats:FunctionStats::default(),
137 },
138 );
139
140 debug!("Host function registered for WASM export: {}", name);
141 Ok(())
142 }
143
144 #[instrument(skip(self, callbacks))]
146 pub async fn register_functions(
147 &self,
148 signatures:Vec<FunctionSignature>,
149 callbacks:Vec<HostFunctionCallback>,
150 ) -> Result<()> {
151 if signatures.len() != callbacks.len() {
152 return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
153 }
154
155 for (sig, callback) in signatures.into_iter().zip(callbacks) {
156 let name = sig.name.clone();
157 self.register_function(&name, sig, callback).await?;
158 }
159
160 Ok(())
161 }
162
163 #[instrument(skip(self, linker))]
165 pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
166 where
167 T: Send + 'static, {
168 info!(
169 "Exporting {} host functions to linker",
170 self.registry.functions.read().await.len()
171 );
172
173 let functions = self.registry.functions.read().await;
174
175 for (name, func) in functions.iter() {
176 self.export_single_function(linker, name, func)?;
177 }
178
179 info!("All host functions exported to linker");
180 Ok(())
181 }
182
183 fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
185 where
186 T: Send + 'static, {
187 debug!("Exporting function: {}", name);
188
189 let callback = func
190 .callback
191 .ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
192
193 let func_name = if let Some(prefix) = &self.config.name_prefix {
194 format!("{}{}", prefix, name)
195 } else {
196 name.to_string()
197 };
198
199 let func_name_for_debug = func_name.clone();
200 let func_name_inner = func_name.clone();
201
202 let _wrapped_callback =
204 move |_caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
205 let _start = std::time::Instant::now();
206
207 let args_bytes:Result<Vec<bytes::Bytes>, _> = args
209 .iter()
210 .map(|arg| {
211 match arg {
212 wasmtime::Val::I32(i) => {
213 serde_json::to_vec(i)
214 .map(bytes::Bytes::from)
215 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
216 },
217 wasmtime::Val::I64(i) => {
218 serde_json::to_vec(i)
219 .map(bytes::Bytes::from)
220 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
221 },
222 wasmtime::Val::F32(f) => {
223 serde_json::to_vec(f)
224 .map(bytes::Bytes::from)
225 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
226 },
227 wasmtime::Val::F64(f) => {
228 serde_json::to_vec(f)
229 .map(bytes::Bytes::from)
230 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
231 },
232 _ => Err(anyhow::anyhow!("Unsupported argument type")),
233 }
234 })
235 .collect();
236
237 let args_bytes = args_bytes.map_err(|_| {
238 warn!("Error converting arguments for function '{}'", func_name_inner);
239 wasmtime::Trap::StackOverflow
240 })?;
241
242 let result = callback(args_bytes);
244
245 match result {
246 Ok(response_bytes) => {
247 let result_val:serde_json::Value = serde_json::from_slice(&response_bytes).map_err(|_| {
249 warn!("Error deserializing response for function '{}'", func_name_inner);
250 wasmtime::Trap::StackOverflow
251 })?;
252
253 let ret_val = match result_val {
254 serde_json::Value::Number(n) => {
255 if let Some(i) = n.as_i64() {
256 wasmtime::Val::I32(i as i32)
257 } else if let Some(f) = n.as_f64() {
258 wasmtime::Val::I64(f as i64)
259 } else {
260 warn!("Invalid number format for function '{}'", func_name_inner);
261 return Err(wasmtime::Trap::StackOverflow);
262 }
263 },
264 _ => {
265 warn!("Unsupported response type for function '{}'", func_name_inner);
266 return Err(wasmtime::Trap::StackOverflow);
267 },
268 };
269
270 Ok(vec![ret_val])
271 },
272 Err(e) => {
273 debug!("Host function '{}' returned error: {}", func_name_inner, e);
275 Err(wasmtime::Trap::StackOverflow)
276 },
277 }
278 };
279
280 let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
282
283 let func_name_for_logging = func_name.clone();
287 linker
288 .func_wrap(
289 "_host", &func_name,
291 move |_caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
292 let start = std::time::Instant::now();
294
295 let args_bytes = match serde_json::to_vec(&input_param).map(bytes::Bytes::from) {
297 Ok(b) => b,
298 Err(e) => {
299 warn!("Serialization error for function '{}': {}", func_name_for_logging, e);
300 return -1i32;
301 },
302 };
303
304 let result = callback(vec![args_bytes]);
306
307 match result {
308 Ok(response_bytes) => {
309 let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
311 Ok(v) => v,
312 Err(_) => {
313 warn!("Error deserializing response for function '{}'", func_name_for_logging);
314 return -1i32;
315 },
316 };
317
318 let ret_val = match result_val {
320 serde_json::Value::Number(n) => {
321 if let Some(i) = n.as_i64() {
322 i as i32
323 } else if let Some(f) = n.as_f64() {
324 f as i32
325 } else {
326 warn!("Invalid number format for function '{}'", func_name_for_logging);
327 -1i32
328 }
329 },
330 serde_json::Value::Bool(b) => {
331 if b {
332 1i32
333 } else {
334 0i32
335 }
336 },
337 _ => {
338 warn!(
339 "Unsupported response type for function '{}', expected number or bool",
340 func_name_for_logging
341 );
342 -1i32
343 },
344 };
345
346 let duration = start.elapsed();
348 debug!(
349 "[FunctionExport] Host function '{}' executed successfully in {}µs",
350 func_name_for_logging,
351 duration.as_micros()
352 );
353
354 ret_val
355 },
356 Err(e) => {
357 debug!(
359 "[FunctionExport] Host function '{}' returned error: {}",
360 func_name_for_logging, e
361 );
362 -1i32
364 },
365 }
366 },
367 )
368 .map_err(|e| {
369 warn!("[FunctionExport] Failed to wrap host function '{}': {}", func_name_for_debug, e);
370 e
371 })?;
372
373 debug!(
374 "[FunctionExport] Host function '{}' registered successfully",
375 func_name_for_debug
376 );
377
378 Ok(())
379 }
380
381 #[allow(dead_code)]
383 fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
384 Ok(wasmparser::FuncType::new([], []))
387 }
388
389 pub async fn get_function_names(&self) -> Vec<String> {
391 self.registry.functions.read().await.keys().cloned().collect()
392 }
393
394 pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
396 self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
397 }
398
399 #[instrument(skip(self))]
401 pub async fn unregister_function(&self, name:&str) -> Result<bool> {
402 let mut functions = self.registry.functions.write().await;
403 let removed = functions.remove(name).is_some();
404
405 if removed {
406 info!("Unregistered host function: {}", name);
407 } else {
408 warn!("Attempted to unregister non-existent function: {}", name);
409 }
410
411 Ok(removed)
412 }
413
414 pub async fn clear(&self) {
416 info!("Clearing all registered host functions");
417 self.registry.functions.write().await.clear();
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): assumes `ParamType` and `ReturnType` are exported from the
    // same module as `FunctionSignature` — confirm the path.
    use crate::WASM::HostBridge::{ParamType, ReturnType};

    /// Builds an exporter with default configuration. The bridge type is only
    /// in scope under its `HostBridge` alias, so that is the name used here.
    fn new_export() -> FunctionExportImpl { FunctionExportImpl::new(Arc::new(HostBridge::new())) }

    /// Builds a synchronous `(i32) -> i32` signature with the given name.
    fn i32_signature(name:&str) -> FunctionSignature {
        FunctionSignature {
            name:name.to_string(),
            param_types:vec![ParamType::I32],
            return_type:Some(ReturnType::I32),
            is_async:false,
        }
    }

    #[tokio::test]
    async fn test_function_export_creation() {
        let export = new_export();

        assert!(export.get_function_names().await.is_empty());
    }

    #[tokio::test]
    async fn test_register_function() {
        let export = new_export();

        let callback = |args:Vec<bytes::Bytes>| Ok(args.first().cloned().unwrap_or_default());

        let result:anyhow::Result<()> = export.register_function("echo", i32_signature("echo"), callback).await;
        assert!(result.is_ok());
        assert_eq!(export.get_function_names().await, vec!["echo".to_string()]);
    }

    #[tokio::test]
    async fn test_register_function_respects_max_functions() {
        let config = ExportConfig { max_functions:1, ..ExportConfig::default() };
        let export = FunctionExportImpl::with_config(Arc::new(HostBridge::new()), config);

        let first = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
        let second = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

        let accepted:anyhow::Result<()> = export.register_function("first", i32_signature("first"), first).await;
        assert!(accepted.is_ok());

        let rejected:anyhow::Result<()> = export.register_function("second", i32_signature("second"), second).await;
        assert!(rejected.is_err());
        assert_eq!(export.get_function_names().await.len(), 1);
    }

    #[tokio::test]
    async fn test_unregister_function() {
        let export = new_export();

        let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
        // Fail loudly if setup breaks instead of ignoring the result.
        export
            .register_function("test", i32_signature("test"), callback)
            .await
            .expect("registering the fixture function should succeed");

        let removed:bool = export.unregister_function("test").await.unwrap();
        assert!(removed);
        assert!(export.get_function_names().await.is_empty());

        // A second removal finds nothing and reports `false`.
        let removed_again:bool = export.unregister_function("test").await.unwrap();
        assert!(!removed_again);
    }

    #[test]
    fn test_export_config_default() {
        let config = ExportConfig::default();
        assert!(config.auto_export);
        assert!(config.enable_stats);
        assert_eq!(config.max_functions, 1000);
        assert_eq!(config.name_prefix.as_deref(), Some("host_"));
    }

    #[test]
    fn test_function_stats_default() {
        let stats = FunctionStats::default();
        assert_eq!(stats.call_count, 0);
        assert_eq!(stats.total_execution_ns, 0);
        assert_eq!(stats.last_call_at, None);
        assert_eq!(stats.error_count, 0);
    }
}