Skip to main content

Maintain/Eliminate/Transform/
Safe.rs

//=============================================================================//
// File Path: Element/Maintain/Source/Eliminate/Transform/Safe.rs
//=============================================================================//
// Module: Safe - Safety predicates for inlinable initialisers
//
// An initialiser is "safe to inline" when:
//   1. Its expression-node count (see NodeCount) does not exceed MaxSize.
//   2. It does not contain an unsafe { } block.
//   3. None of its plain-identifier free variables is moved (consumed by
//      value) in the statements between the candidate declaration and the
//      substitution site (IsFreeVarSafe).
//
// Notable inclusions (expressions that ARE safe in Rust):
//   - ? operator  (Expr::Try)     - propagates the error, context unchanged.
//   - .await      (Expr::Await)   - async context is caller-determined.
//   - Closures    (Expr::Closure) - safe when the closure does not capture the
//     candidate binding (checked by Count, not here).
//=============================================================================//
19
20use syn::{
21	Expr,
22	Stmt,
23	visit::{Visit, visit_expr},
24};
25
26// ---------------------------------------------------------------------------
27// Public API
28// ---------------------------------------------------------------------------
29
30/// Returns true when E is safe to substitute at its single use site,
31/// considering only the expression itself (size and unsafe).
32pub fn IsSafe(E:&Expr, MaxSize:usize) -> bool { NodeCount(E) <= MaxSize && !ContainsUnsafe(E) }
33
34/// Returns true when every plain-identifier free variable inside Init remains
35/// live (not moved by value) in the statements that appear between the
36/// candidate let-binding (exclusive) and the substitution site (exclusive).
37///
38/// Stmts is the slice Block.stmts[CandidateIndex + 1..SubstSite] - i.e. the
39/// statements that execute after the let but before the single use.
40///
41/// The check is conservative: it only tracks Expr::Path single-segment bare
42/// identifiers. Qualified paths (a::b), method receivers (.foo()), and
43/// reference borrows (&x) are all ignored, so false positives (keeping a
44/// binding that would have been safe) are possible but false negatives that
45/// would introduce a compile error are not.
46pub fn IsFreeVarSafe(Init:&Expr, StmtsBetween:&[Stmt]) -> bool {
47	let FreeVars = CollectFreeIdents(Init);
48
49	if FreeVars.is_empty() || StmtsBetween.is_empty() {
50		return true;
51	}
52
53	for Name in &FreeVars {
54		if IsMovedInStmts(Name, StmtsBetween) {
55			return false;
56		}
57	}
58
59	true
60}
61
62// ---------------------------------------------------------------------------
63// Free-identifier collection
64// ---------------------------------------------------------------------------
65
66struct FreeIdentCollector {
67	Idents:Vec<String>,
68}
69
70impl<'ast> Visit<'ast> for FreeIdentCollector {
71	fn visit_expr_path(&mut self, Node:&'ast syn::ExprPath) {
72		if Node.qself.is_none() {
73			if let Some(Ident) = Node.path.get_ident() {
74				let Name = Ident.to_string();
75
76				if !self.Idents.contains(&Name) {
77					self.Idents.push(Name);
78				}
79			}
80		}
81	}
82}
83
84fn CollectFreeIdents(E:&Expr) -> Vec<String> {
85	let mut C = FreeIdentCollector { Idents:vec![] };
86
87	C.visit_expr(E);
88
89	C.Idents
90}
91
92// ---------------------------------------------------------------------------
93// Move detection
94// ---------------------------------------------------------------------------
95
96/// Returns true when Target is consumed by value in any of the statements.
97/// We detect by-value consumption conservatively: if Target appears as a
98/// plain Expr::Path (bare identifier, no & / &mut / ref prefix) in a
99/// position that syntactically transfers ownership:
100///   - a function or method call argument
101///   - a struct / tuple-struct field value (shorthand or explicit)
102///   - the right-hand side of a let initialiser or assignment
103///   - an array element, tuple element, or return expression
104fn IsMovedInStmts(Target:&str, Stmts:&[Stmt]) -> bool {
105	for Stmt in Stmts {
106		let mut Detector = MoveDetector { Target, Found:false };
107
108		Detector.visit_stmt(Stmt);
109
110		if Detector.Found {
111			return true;
112		}
113	}
114
115	false
116}
117
118struct MoveDetector<'a> {
119	Target:&'a str,
120
121	Found:bool,
122}
123
124impl<'ast> Visit<'ast> for MoveDetector<'ast> {
125	// Plain identifier used as a call argument, struct field, array element,
126	// tuple element, return value, or let/assign RHS - all by-value moves.
127	fn visit_expr_path(&mut self, Node:&'ast syn::ExprPath) {
128		if self.Found {
129			return;
130		}
131
132		if Node.qself.is_none() {
133			if let Some(I) = Node.path.get_ident() {
134				if I == self.Target {
135					self.Found = true;
136				}
137			}
138		}
139	}
140
141	// Do not descend into reference expressions: &Target and &mut Target
142	// borrow rather than move, so they are safe.
143	fn visit_expr_reference(&mut self, _Node:&'ast syn::ExprReference) {}
144
145	// Do not descend into method call receivers: self.foo(target_as_self)
146	// is a borrow in almost all cases; conservatively skip.
147	// We DO still visit the arguments via the default walk, so explicit
148	// by-value arguments to method calls are still caught.
149	fn visit_expr_method_call(&mut self, Node:&'ast syn::ExprMethodCall) {
150		if self.Found {
151			return;
152		}
153
154		// Visit arguments only; skip the receiver (self.receiver).
155		for Arg in &Node.args {
156			self.visit_expr(Arg);
157		}
158	}
159}
160
161// ---------------------------------------------------------------------------
162// Node counting
163// ---------------------------------------------------------------------------
164
165struct NodeCounter {
166	pub Count:usize,
167}
168
169impl<'ast> Visit<'ast> for NodeCounter {
170	fn visit_expr(&mut self, Node:&'ast Expr) {
171		self.Count += 1;
172
173		visit_expr(self, Node);
174	}
175}
176
177pub fn NodeCount(E:&Expr) -> usize {
178	let mut Counter = NodeCounter { Count:0 };
179
180	Counter.visit_expr(E);
181
182	Counter.Count
183}
184
185// ---------------------------------------------------------------------------
186// Unsafe detection
187// ---------------------------------------------------------------------------
188
189struct UnsafeDetector {
190	pub Found:bool,
191}
192
193impl<'ast> Visit<'ast> for UnsafeDetector {
194	fn visit_expr_unsafe(&mut self, _Node:&'ast syn::ExprUnsafe) { self.Found = true; }
195}
196
197pub fn ContainsUnsafe(E:&Expr) -> bool {
198	let mut Detector = UnsafeDetector { Found:false };
199
200	Detector.visit_expr(E);
201
202	Detector.Found
203}
204
205// ---------------------------------------------------------------------------
206// Unit tests
207// ---------------------------------------------------------------------------
208
#[cfg(test)]
mod Tests {

	use super::*;

	fn ParseExpr(Src:&str) -> Expr { syn::parse_str(Src).expect("parse expression") }

	/// Parses Src as a source file and returns the body statements of its
	/// first item, which must be a function. Panics otherwise. Otherwise a
	/// malformed fixture would yield an empty statement list, and every
	/// "safe" assertion would pass vacuously.
	fn ParseStmts(Src:&str) -> Vec<Stmt> {
		let File:syn::File = syn::parse_str(Src).expect("parse");

		match File.items.first() {
			Some(syn::Item::Fn(F)) => F.block.stmts.clone(),
			_ => panic!("fixture must begin with a fn item: {}", Src),
		}
	}

	#[test]
	fn LiteralIsSafe() {
		assert!(IsSafe(&ParseExpr("42"), 100));
	}

	#[test]
	fn UnsafeBlockIsNotSafe() {
		assert!(!IsSafe(&ParseExpr("unsafe { *ptr }"), 100));
	}

	#[test]
	fn NestedUnsafeBlockIsNotSafe() {
		// An unsafe block buried inside a call argument is still found.
		assert!(!IsSafe(&ParseExpr("foo(unsafe { bar() })"), 100));
	}

	#[test]
	fn QuestionMarkIsSafe() {
		assert!(IsSafe(&ParseExpr("foo().map_err(|e| e)?"), 100));
	}

	#[test]
	fn AwaitIsSafe() {
		assert!(IsSafe(&ParseExpr("service.call().await"), 100));
	}

	#[test]
	fn ClosureIsSafe() {
		assert!(IsSafe(&ParseExpr("|| { 1 + 1 }"), 100));
	}

	#[test]
	fn StructLiteralIsSafe() {
		assert!(IsSafe(&ParseExpr("Opts { a: 1, b: 2 }"), 100));
	}

	#[test]
	fn MacroCallIsSafe() {
		assert!(IsSafe(&ParseExpr(r#"json!({ "key": val })"#), 100));
	}

	#[test]
	fn OversizedExprFails() {
		let Big = ParseExpr("a + b + c + d + e + f + g + h");

		assert!(!IsSafe(&Big, 5));
	}

	#[test]
	fn NodeCountLiteral() {
		assert_eq!(NodeCount(&ParseExpr("1")), 1);
	}

	#[test]
	fn NodeCountBinary() {
		assert_eq!(NodeCount(&ParseExpr("A + B")), 3);
	}

	#[test]
	fn NodeCountCall() {
		assert_eq!(NodeCount(&ParseExpr("f(A, B)")), 4);
	}

	// IsFreeVarSafe tests

	#[test]
	fn FreeVarSafeWhenNoStmtsBetween() {
		// Init = var.clone(), no statements between decl and use: safe.
		assert!(IsFreeVarSafe(&ParseExpr("var.clone()"), &[]));
	}

	#[test]
	fn FreeVarSafeWhenVarOnlyBorrowed() {
		// &var between decl and use: borrow, not a move.
		let Stmts = ParseStmts("fn f() { let _r = &var; }");

		assert!(IsFreeVarSafe(&ParseExpr("var.clone()"), &Stmts));
	}

	#[test]
	fn FreeVarSafeWhenVarOnlyUsedAsReceiver() {
		// var.len() auto-borrows its receiver; it is not a move.
		let Stmts = ParseStmts("fn f() { let _n = var.len(); }");

		assert!(IsFreeVarSafe(&ParseExpr("var.clone()"), &Stmts));
	}

	#[test]
	fn FreeVarUnsafeWhenMovedIntoCall() {
		// fn g(x: String) consumes x by value.
		let Stmts = ParseStmts("fn f() { g(var); }");

		assert!(!IsFreeVarSafe(&ParseExpr("var.clone()"), &Stmts));
	}

	#[test]
	fn FreeVarUnsafeWhenMovedIntoMethodArgument() {
		// Method-call arguments are by-value positions.
		let Stmts = ParseStmts("fn f() { items.push(var); }");

		assert!(!IsFreeVarSafe(&ParseExpr("var.clone()"), &Stmts));
	}

	#[test]
	fn FreeVarUnsafeWhenMovedIntoStructField() {
		// Struct field shorthand { path } moves path.
		let Stmts = ParseStmts("fn f() { let _r = Foo { path }; }");

		assert!(!IsFreeVarSafe(&ParseExpr("path.clone()"), &Stmts));
	}

	#[test]
	fn FreeVarSafeWhenDifferentVarMoved() {
		// A different variable is moved; the one we care about is untouched.
		let Stmts = ParseStmts("fn f() { let _r = Foo { other }; }");

		assert!(IsFreeVarSafe(&ParseExpr("path.clone()"), &Stmts));
	}

	#[test]
	fn FreeVarUnsafeWhenMovedIntoRpcRequest() {
		// Exact pattern from AirClient::get_file_info.
		// path is moved into FileInfoRequest { request_id, path }.
		let Stmts = ParseStmts(
			r#"fn f() {
                client.get_file_info(Request::new(FileInfoRequest { request_id, path })).await;
            }"#,
		);

		assert!(!IsFreeVarSafe(&ParseExpr("path.clone()"), &Stmts));
	}
}