Maintain/Eliminate/Transform/
Safe.rs1use syn::{
21 Expr,
22 Stmt,
23 visit::{Visit, visit_expr},
24};
25
26pub fn IsSafe(E:&Expr, MaxSize:usize) -> bool { NodeCount(E) <= MaxSize && !ContainsUnsafe(E) }
33
34pub fn IsFreeVarSafe(Init:&Expr, StmtsBetween:&[Stmt]) -> bool {
47 let FreeVars = CollectFreeIdents(Init);
48
49 if FreeVars.is_empty() || StmtsBetween.is_empty() {
50 return true;
51 }
52
53 for Name in &FreeVars {
54 if IsMovedInStmts(Name, StmtsBetween) {
55 return false;
56 }
57 }
58
59 true
60}
61
62struct FreeIdentCollector {
67 Idents:Vec<String>,
68}
69
70impl<'ast> Visit<'ast> for FreeIdentCollector {
71 fn visit_expr_path(&mut self, Node:&'ast syn::ExprPath) {
72 if Node.qself.is_none() {
73 if let Some(Ident) = Node.path.get_ident() {
74 let Name = Ident.to_string();
75
76 if !self.Idents.contains(&Name) {
77 self.Idents.push(Name);
78 }
79 }
80 }
81 }
82}
83
84fn CollectFreeIdents(E:&Expr) -> Vec<String> {
85 let mut C = FreeIdentCollector { Idents:vec![] };
86
87 C.visit_expr(E);
88
89 C.Idents
90}
91
92fn IsMovedInStmts(Target:&str, Stmts:&[Stmt]) -> bool {
105 for Stmt in Stmts {
106 let mut Detector = MoveDetector { Target, Found:false };
107
108 Detector.visit_stmt(Stmt);
109
110 if Detector.Found {
111 return true;
112 }
113 }
114
115 false
116}
117
118struct MoveDetector<'a> {
119 Target:&'a str,
120
121 Found:bool,
122}
123
124impl<'ast> Visit<'ast> for MoveDetector<'ast> {
125 fn visit_expr_path(&mut self, Node:&'ast syn::ExprPath) {
128 if self.Found {
129 return;
130 }
131
132 if Node.qself.is_none() {
133 if let Some(I) = Node.path.get_ident() {
134 if I == self.Target {
135 self.Found = true;
136 }
137 }
138 }
139 }
140
141 fn visit_expr_reference(&mut self, _Node:&'ast syn::ExprReference) {}
144
145 fn visit_expr_method_call(&mut self, Node:&'ast syn::ExprMethodCall) {
150 if self.Found {
151 return;
152 }
153
154 for Arg in &Node.args {
156 self.visit_expr(Arg);
157 }
158 }
159}
160
161struct NodeCounter {
166 pub Count:usize,
167}
168
169impl<'ast> Visit<'ast> for NodeCounter {
170 fn visit_expr(&mut self, Node:&'ast Expr) {
171 self.Count += 1;
172
173 visit_expr(self, Node);
174 }
175}
176
177pub fn NodeCount(E:&Expr) -> usize {
178 let mut Counter = NodeCounter { Count:0 };
179
180 Counter.visit_expr(E);
181
182 Counter.Count
183}
184
185struct UnsafeDetector {
190 pub Found:bool,
191}
192
193impl<'ast> Visit<'ast> for UnsafeDetector {
194 fn visit_expr_unsafe(&mut self, _Node:&'ast syn::ExprUnsafe) { self.Found = true; }
195}
196
197pub fn ContainsUnsafe(E:&Expr) -> bool {
198 let mut Detector = UnsafeDetector { Found:false };
199
200 Detector.visit_expr(E);
201
202 Detector.Found
203}
204
#[cfg(test)]
mod Tests {

	use super::*;

	/// Parses a single expression; malformed test input is a test bug.
	fn ParseExpr(Src:&str) -> Expr { syn::parse_str(Src).expect("parse expression") }

	/// Parses `Src` as a file and returns the body statements of its first item
	/// when that item is a function; any other first item yields no statements.
	fn ParseStmts(Src:&str) -> Vec<Stmt> {
		let File:syn::File = syn::parse_str(Src).expect("parse");

		match &File.items[0] {
			syn::Item::Fn(F) => F.block.stmts.clone(),
			_ => Vec::new(),
		}
	}

	#[test]
	fn LiteralIsSafe() {
		let Literal = ParseExpr("42");

		assert!(IsSafe(&Literal, 100));
	}

	#[test]
	fn UnsafeBlockIsNotSafe() {
		let Block = ParseExpr("unsafe { *ptr }");

		assert!(!IsSafe(&Block, 100));
	}

	#[test]
	fn QuestionMarkIsSafe() {
		let Try = ParseExpr("foo().map_err(|e| e)?");

		assert!(IsSafe(&Try, 100));
	}

	#[test]
	fn AwaitIsSafe() {
		let Awaited = ParseExpr("service.call().await");

		assert!(IsSafe(&Awaited, 100));
	}

	#[test]
	fn ClosureIsSafe() {
		let Closure = ParseExpr("|| { 1 + 1 }");

		assert!(IsSafe(&Closure, 100));
	}

	#[test]
	fn StructLiteralIsSafe() {
		let Literal = ParseExpr("Opts { a: 1, b: 2 }");

		assert!(IsSafe(&Literal, 100));
	}

	#[test]
	fn MacroCallIsSafe() {
		let Call = ParseExpr(r#"json!({ "key": val })"#);

		assert!(IsSafe(&Call, 100));
	}

	#[test]
	fn OversizedExprFails() {
		assert!(!IsSafe(&ParseExpr("a + b + c + d + e + f + g + h"), 5));
	}

	#[test]
	fn NodeCountLiteral() {
		let Count = NodeCount(&ParseExpr("1"));

		assert_eq!(Count, 1);
	}

	#[test]
	fn NodeCountBinary() {
		let Count = NodeCount(&ParseExpr("A + B"));

		assert_eq!(Count, 3);
	}

	#[test]
	fn NodeCountCall() {
		let Count = NodeCount(&ParseExpr("f(A, B)"));

		assert_eq!(Count, 4);
	}

	#[test]
	fn FreeVarSafeWhenNoStmtsBetween() {
		let Init = ParseExpr("var.clone()");

		assert!(IsFreeVarSafe(&Init, &[]));
	}

	#[test]
	fn FreeVarSafeWhenVarOnlyBorrowed() {
		let Init = ParseExpr("var.clone()");

		assert!(IsFreeVarSafe(&Init, &ParseStmts("fn f() { let _r = &var; }")));
	}

	#[test]
	fn FreeVarUnsafeWhenMovedIntoCall() {
		let Init = ParseExpr("var.clone()");

		assert!(!IsFreeVarSafe(&Init, &ParseStmts("fn f() { g(var); }")));
	}

	#[test]
	fn FreeVarUnsafeWhenMovedIntoStructField() {
		let Init = ParseExpr("path.clone()");

		assert!(!IsFreeVarSafe(&Init, &ParseStmts("fn f() { let _r = Foo { path }; }")));
	}

	#[test]
	fn FreeVarSafeWhenDifferentVarMoved() {
		let Init = ParseExpr("path.clone()");

		assert!(IsFreeVarSafe(&Init, &ParseStmts("fn f() { let _r = Foo { other }; }")));
	}

	#[test]
	fn FreeVarUnsafeWhenMovedIntoRpcRequest() {
		let Init = ParseExpr("path.clone()");

		let Between = ParseStmts(
			r#"fn f() {
				client.get_file_info(Request::new(FileInfoRequest { request_id, path })).await;
			}"#,
		);

		assert!(!IsFreeVarSafe(&Init, &Between));
	}
}