1use std::{
47 collections::HashMap,
48 sync::{
49 Arc,
50 atomic::{AtomicBool, AtomicU64, Ordering},
51 },
52};
53
54use anyhow::Result;
55use futures_util::{SinkExt, StreamExt, stream::SplitSink};
56use serde_json::Value;
57use tokio::{
58 net::{TcpListener, TcpStream},
59 sync::{Mutex, oneshot},
60};
61use tokio_tungstenite::{
62 MaybeTlsStream,
63 WebSocketStream,
64 accept_async,
65 connect_async,
66 tungstenite::{Message, Utf8Bytes},
67};
68
69#[derive(Clone)]
71pub struct SharedSecret(pub [u8; 32]);
72
73impl SharedSecret {
74 pub fn random() -> Self {
75 Self(rand::random::<[u8; 32]>())
79 }
80
81 pub fn as_hex(&self) -> String { hex::encode(self.0) }
82
83 pub fn from_hex(Hex:&str) -> Result<Self> {
84 let Bytes = hex::decode(Hex)?;
85
86 if Bytes.len() != 32 {
87 anyhow::bail!("shared secret must be 32 bytes (got {})", Bytes.len());
88 }
89
90 let mut Out = [0u8; 32];
91
92 Out.copy_from_slice(&Bytes);
93
94 Ok(Self(Out))
95 }
96}
97
98pub type HandlerFn =
101 Arc<dyn Fn(Value) -> futures_util::future::BoxFuture<'static, Result<Value, String>> + Send + Sync>;
102
103#[derive(Default)]
105pub struct HandlerRegistry {
106 Handlers:Mutex<HashMap<String, HandlerFn>>,
107}
108
109impl HandlerRegistry {
110 pub fn new() -> Arc<Self> { Arc::new(Self::default()) }
111
112 pub async fn Register(&self, Method:String, Handler:HandlerFn) {
113 self.Handlers.lock().await.insert(Method, Handler);
114 }
115
116 pub async fn Lookup(&self, Method:&str) -> Option<HandlerFn> { self.Handlers.lock().await.get(Method).cloned() }
117}
118
119pub async fn ServeLocal(Port:u16, Secret:SharedSecret, Registry:Arc<HandlerRegistry>) -> Result<()> {
126 let Address = format!("127.0.0.1:{}", Port);
127
128 let Listener = TcpListener::bind(&Address).await?;
129
130 tracing::info!(target: "Mist::WebSocket", "server listening on {}", Address);
131
132 let PortStr = format!("{}", Port);
136
137 CommonLibrary::Telemetry::CaptureEvent::Fn(
138 "land:mist:server:start",
139 Some(vec![("address", Address.as_str()), ("port", PortStr.as_str())]),
140 );
141
142 loop {
143 let (Stream, Peer) = match Listener.accept().await {
144 Ok(P) => P,
145
146 Err(Error) => {
147 tracing::warn!(target: "Mist::WebSocket", "accept error: {}", Error);
148
149 continue;
150 },
151 };
152
153 let SecretClone = Secret.clone();
154
155 let RegistryClone = Registry.clone();
156
157 tokio::spawn(async move {
158 if let Err(Error) = HandleConnection(Stream, SecretClone, RegistryClone).await {
159 tracing::warn!(target: "Mist::WebSocket", "connection from {} closed with error: {}", Peer, Error);
160 }
161 });
162 }
163}
164
165async fn HandleConnection(Stream:TcpStream, _Secret:SharedSecret, Registry:Arc<HandlerRegistry>) -> Result<()> {
166 let WebSocketStream = accept_async(Stream).await?;
173
174 let (mut Sink, mut Source) = WebSocketStream.split();
175
176 while let Some(MessageResult) = Source.next().await {
177 let Message = match MessageResult {
178 Ok(M) => M,
179
180 Err(Error) => {
181 tracing::debug!(target: "Mist::WebSocket", "frame read error: {}", Error);
182
183 break;
184 },
185 };
186
187 match Message {
188 Message::Text(Text) => {
189 let Envelope:Value = match serde_json::from_str(&Text) {
190 Ok(V) => V,
191
192 Err(Error) => {
193 tracing::debug!(target: "Mist::WebSocket", "bad text frame: {}", Error);
194
195 continue;
196 },
197 };
198
199 let Method = Envelope.get("method").and_then(|V| V.as_str()).unwrap_or("");
200
201 let Identifier = Envelope.get("id").cloned().unwrap_or(Value::Null);
202
203 let Params = Envelope.get("params").cloned().unwrap_or(Value::Array(vec![]));
204
205 if Method.is_empty() {
206 continue;
207 }
208
209 let Handler = Registry.Lookup(Method).await;
210
211 let Response = match Handler {
212 Some(H) => {
213 match H(Params).await {
214 Ok(Value) => serde_json::json!({ "id": Identifier, "result": Value }),
215
216 Err(ErrorMessage) => serde_json::json!({ "id": Identifier, "error": ErrorMessage }),
217 }
218 },
219
220 None => {
221 serde_json::json!({
222 "id": Identifier,
223 "error": format!("Unknown method: {}", Method),
224 })
225 },
226 };
227
228 if Identifier.is_null() {
229 continue;
231 }
232
233 if let Err(Error) = Sink.send(Message::Text(Utf8Bytes::from(Response.to_string()))).await {
234 tracing::debug!(target: "Mist::WebSocket", "send error: {}", Error);
235
236 break;
237 }
238 },
239
240 Message::Binary(Bytes) => {
241 tracing::trace!(target: "Mist::WebSocket", "binary frame ({} bytes) ignored - reserved for phase 2", Bytes.len());
242 },
243
244 Message::Close(_) => break,
245
246 _ => {},
247 }
248 }
249
250 Ok(())
251}
252
253type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<Result<Value, String>>>>>;
255
256pub struct Client {
259 Sink:Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
263
264 Pending:PendingMap,
265
266 NextIdentifier:AtomicU64,
267
268 Closed:AtomicBool,
269}
270
271impl Client {
272 pub async fn connect(Address:&str) -> Result<Arc<Self>> {
277 let (Stream, _Response) = connect_async(Address).await?;
278
279 let (Sink, mut Source) = Stream.split();
280
281 let Sink = Arc::new(Mutex::new(Sink));
282
283 let Pending:PendingMap = Arc::new(Mutex::new(HashMap::new()));
284
285 let SelfReference = Arc::new(Self {
286 Sink,
287 Pending:Pending.clone(),
288 NextIdentifier:AtomicU64::new(1),
289 Closed:AtomicBool::new(false),
290 });
291
292 let SelfForReader = SelfReference.clone();
295
296 tokio::spawn(async move {
297 while let Some(MessageResult) = Source.next().await {
298 let Frame = match MessageResult {
299 Ok(M) => M,
300 Err(_) => break,
301 };
302
303 match Frame {
304 Message::Text(Text) => {
305 if let Ok(Envelope) = serde_json::from_str::<Value>(&Text) {
306 let Identifier = Envelope.get("id").and_then(|V| V.as_u64());
307
308 if let Some(Identifier) = Identifier {
309 let Sender = SelfForReader.Pending.lock().await.remove(&Identifier);
310
311 if let Some(Sender) = Sender {
312 let Result = if let Some(ErrorValue) = Envelope.get("error") {
313 Err(ErrorValue.to_string())
314 } else {
315 Ok(Envelope.get("result").cloned().unwrap_or(Value::Null))
316 };
317
318 let _ = Sender.send(Result);
319 }
320 }
321 }
322 },
323 Message::Close(_) => break,
324 _ => {},
325 }
326 }
327
328 SelfForReader.Closed.store(true, Ordering::Relaxed);
329
330 let mut Pending = SelfForReader.Pending.lock().await;
332
333 for (_, Sender) in Pending.drain() {
334 let _ = Sender.send(Err("connection closed".into()));
335 }
336 });
337
338 Ok(SelfReference)
339 }
340
341 pub async fn invoke(&self, Method:&str, Params:Value) -> Result<Value, String> {
345 if self.Closed.load(Ordering::Relaxed) {
346 return Err("connection closed".into());
347 }
348
349 let Identifier = self.NextIdentifier.fetch_add(1, Ordering::Relaxed);
350
351 let (Tx, Rx) = oneshot::channel();
352
353 self.Pending.lock().await.insert(Identifier, Tx);
354
355 let Envelope = serde_json::json!({ "id": Identifier, "method": Method, "params": Params });
356
357 let Text = Envelope.to_string();
358
359 let SendResult = self.Sink.lock().await.send(Message::Text(Utf8Bytes::from(Text))).await;
360
361 if SendResult.is_err() {
362 self.Pending.lock().await.remove(&Identifier);
363
364 return Err("send failed".into());
365 }
366
367 Rx.await.map_err(|_| "request cancelled".to_string())?
368 }
369
370 pub async fn notify(&self, Method:&str, Params:Value) -> Result<(), String> {
372 if self.Closed.load(Ordering::Relaxed) {
373 return Err("connection closed".into());
374 }
375
376 let Envelope = serde_json::json!({ "id": Value::Null, "method": Method, "params": Params });
377
378 let Text = Envelope.to_string();
379
380 self.Sink
381 .lock()
382 .await
383 .send(Message::Text(Utf8Bytes::from(Text)))
384 .await
385 .map_err(|Error| Error.to_string())
386 }
387
388 pub fn is_closed(&self) -> bool { self.Closed.load(Ordering::Relaxed) }
389}