1use crate::*;
2use analysis::{one_shot_analysis, SemiLatticeAnalysis};
3use hashcons::ExprHashCons;
4use sketch::SketchNode;
5use std::cmp::Ordering;
6
7pub fn eclass_satisfies_sketch<L: Language, A: Analysis<L>>(
9 s: &Sketch<L>,
10 egraph: &EGraph<L, A>,
11 id: Id,
12) -> bool {
13 satisfies_sketch(s, egraph).contains(&id)
14}
15
16pub fn satisfies_sketch<L: Language, A: Analysis<L>>(
18 s: &Sketch<L>,
19 egraph: &EGraph<L, A>,
20) -> HashSet<Id> {
21 assert!(egraph.clean);
22 let mut memo = HashMap::<Id, HashSet<Id>>::default();
23 let sketch_nodes = s.as_ref();
24 let sketch_root = Id::from(sketch_nodes.len() - 1);
25 satisfies_sketch_rec(sketch_nodes, sketch_root, egraph, &mut memo)
26}
27
28fn satisfies_sketch_rec<L: Language, A: Analysis<L>>(
29 s_nodes: &[SketchNode<L>],
30 s_index: Id,
31 egraph: &EGraph<L, A>,
32 memo: &mut HashMap<Id, HashSet<Id>>,
33) -> HashSet<Id> {
34 match memo.get(&s_index) {
35 Some(value) => return value.clone(),
36 None => (),
37 };
38
39 let result = match &s_nodes[usize::from(s_index)] {
40 SketchNode::Any =>
41 egraph.classes().map(|c| c.id).collect(),
42 SketchNode::Node(node) => {
43 let children_matches = node
44 .children()
45 .iter()
46 .map(|sid| satisfies_sketch_rec(s_nodes, *sid, egraph, memo))
47 .collect::<Vec<_>>();
48
49 if let Some(potential_ids) = egraph.classes_for_op(&node.discriminant()) {
50 potential_ids
51 .filter(|&id| {
52 let eclass = &egraph[id];
53
54 let mnode = &node.clone().map_children(|_| Id::from(0));
55 eclass.for_each_matching_node(mnode, |matched| {
56 let children_match = children_matches
57 .iter()
58 .zip(matched.children())
59 .all(|(matches, id)| matches.contains(id));
60 if children_match {
61 Err(())
62 } else {
63 Ok(())
64 }
65 })
66 .is_err()
67 })
68 .collect()
69 } else {
70 HashSet::default()
71 }
72 }
73 SketchNode::Contains(sid) => {
74 let contained_matched = satisfies_sketch_rec(s_nodes, *sid, egraph, memo);
75
76 let mut data = egraph
77 .classes()
78 .map(|eclass| (eclass.id, contained_matched.contains(&eclass.id)))
79 .collect::<HashMap<_, bool>>();
80
81 one_shot_analysis(egraph, SatisfiesContainsAnalysis, &mut data);
82
83 data.iter()
84 .flat_map(|(&id, &is_match)| if is_match { Some(id) } else { None })
85 .collect()
86 }
87 SketchNode::OnlyContains(sid) => {
88 let contained_matched = satisfies_sketch_rec(s_nodes, *sid, egraph, memo);
89
90 let mut data = egraph
91 .classes()
92 .map(|eclass| (eclass.id, contained_matched.contains(&eclass.id)))
93 .collect::<HashMap<_, bool>>();
94
95 one_shot_analysis(egraph, SatisfiesOnlyContainsAnalysis, &mut data);
96
97 data.iter()
98 .flat_map(|(&id, &is_match)| if is_match { Some(id) } else { None })
99 .collect()
100 }
101 SketchNode::Or(sids) => {
102 let matches = sids
103 .iter()
104 .map(|sid| satisfies_sketch_rec(s_nodes, *sid, egraph, memo));
105 matches
106 .reduce(|a, b| a.union(&b).cloned().collect())
107 .expect("empty or sketch")
108 }
109 };
110
111 memo.insert(s_index, result.clone());
112 result
113}
114
115pub struct SatisfiesContainsAnalysis;
116impl<L: Language, A: Analysis<L>> SemiLatticeAnalysis<L, A> for SatisfiesContainsAnalysis {
117 type Data = bool;
118
119 fn make<'a>(
120 &mut self,
121 _egraph: &EGraph<L, A>,
122 enode: &L,
123 analysis_of: &'a impl Fn(Id) -> &'a Self::Data,
124 ) -> Self::Data
125 where
126 Self::Data: 'a,
127 {
128 enode.any(|c| *analysis_of(c))
129 }
130
131 fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
132 let r = *a || b;
133 let dm = DidMerge(r != *a, r != b);
134 *a = r;
135 dm
136 }
137}
138
139pub struct SatisfiesOnlyContainsAnalysis;
140impl<L: Language, A: Analysis<L>> SemiLatticeAnalysis<L, A> for SatisfiesOnlyContainsAnalysis {
141 type Data = bool;
142
143 fn make<'a>(
144 &mut self,
145 _egraph: &EGraph<L, A>,
146 enode: &L,
147 analysis_of: &'a impl Fn(Id) -> &'a Self::Data,
148 ) -> Self::Data
149 where
150 Self::Data: 'a,
151 {
152 if enode.children().is_empty() {
153 false
154 } else {
155 enode.all(|c| *analysis_of(c))
156 }
157 }
158
159 fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
160 let r = *a || b;
161 let dm = DidMerge(r != *a, r != b);
162 *a = r;
163 dm
164 }
165}
166
167pub fn eclass_extract_sketch<L, A, CF>(
169 s: &Sketch<L>,
170 cost_f: CF,
171 egraph: &EGraph<L, A>,
172 id: Id,
173) -> Option<(CF::Cost, RecExpr<L>)>
174where
175 L: Language,
176 A: Analysis<L>,
177 CF: CostFunction<L>,
178 CF::Cost: 'static + Ord,
179{
180 assert_eq!(egraph.find(id), id);
181 let (exprs, eclass_to_best) = extract_sketch(s, cost_f, egraph);
182 eclass_to_best
183 .get(&id)
184 .map(|(best_cost, best_id)| (best_cost.clone(), exprs.extract(*best_id)))
185}
186
187fn extract_sketch<L, A, CF>(
188 sketch: &Sketch<L>,
189 mut cost_f: CF,
190 egraph: &EGraph<L, A>,
191) -> (ExprHashCons<L>, HashMap<Id, (CF::Cost, Id)>)
192where
193 L: Language,
194 A: Analysis<L>,
195 CF: CostFunction<L>,
196 CF::Cost: 'static + Ord,
197{
198 assert!(egraph.clean);
199 let mut memo = HashMap::<Id, HashMap<Id, (CF::Cost, Id)>>::default();
200 let sketch_root = Id::from(sketch.as_ref().len() - 1);
201 let mut exprs = ExprHashCons::new();
202
203 let mut extracted = HashMap::default();
204 let analysis = ExtractAnalysis {
205 exprs: &mut exprs,
206 cost_f: &mut cost_f,
207 };
208 one_shot_analysis(&egraph, analysis, &mut extracted);
209
210 let res = extract_sketch_rec(
211 sketch,
212 sketch_root,
213 &mut cost_f,
214 egraph,
215 &mut exprs,
216 &extracted,
217 &mut memo,
218 );
219 (exprs, res)
220}
221
222fn extract_sketch_rec<L, A, CF>(
223 sketch: &Sketch<L>,
224 sketch_id: Id,
225 cost_f: &mut CF,
226 egraph: &EGraph<L, A>,
227 exprs: &mut ExprHashCons<L>,
228 extracted: &HashMap<Id, (CF::Cost, Id)>,
229 memo: &mut HashMap<Id, HashMap<Id, (CF::Cost, Id)>>,
230) -> HashMap<Id, (CF::Cost, Id)>
231where
232 L: Language,
233 A: Analysis<L>,
234 CF: CostFunction<L>,
235 CF::Cost: 'static + Ord,
236{
237 match memo.get(&sketch_id) {
238 Some(value) => return value.clone(),
239 None => (),
240 };
241
242 let result = match &sketch[sketch_id] {
243 SketchNode::Any => extracted.clone(),
244 SketchNode::Node(sketch_node) => {
245 let children_matches = sketch_node
247 .children()
248 .iter()
249 .map(|sid| {
250 extract_sketch_rec(sketch, *sid, cost_f, egraph, exprs, extracted, memo)
251 })
252 .collect::<Vec<_>>();
253
254 if let Some(potential_ids) = egraph.classes_for_op(&sketch_node.discriminant()) {
255 potential_ids
256 .flat_map(|id| {
257 let eclass = &egraph[id];
258 let mut candidates = Vec::new();
259
260 let mnode = &sketch_node.clone().map_children(|_| Id::from(0));
261 let _ = eclass.for_each_matching_node::<()>(mnode, |matched| {
262 let mut matches = Vec::new();
265 for (cm, id) in children_matches.iter().zip(matched.children()) {
267 if let Some(m) = cm.get(id) {
268 matches.push(m);
269 } else {
270 break;
271 }
272 }
273
274 if matches.len() == matched.len() {
275 let mut node_to_child_indices = sketch_node.clone();
277 for (child_index, id) in node_to_child_indices.children_mut().into_iter().enumerate() {
278 *id = Id::from(child_index);
279 }
280
281 let to_match: HashMap<_, _> =
282 node_to_child_indices.children().iter().zip(matches.iter()).collect();
283
284 candidates.push((
285 cost_f.cost(&node_to_child_indices, |c| to_match[&c].0.clone()),
286 exprs.add(node_to_child_indices.clone().map_children(|c| to_match[&c].1)),
287 ));
288 }
289
290 Ok(())
291 });
292
293 candidates
294 .into_iter()
295 .min_by(|x, y| x.0.cmp(&y.0))
296 .map(|best| (id, best))
297 })
298 .collect()
299 } else {
300 HashMap::default()
301 }
302 }
303 SketchNode::Contains(sid) => {
304 let contained_matches =
305 extract_sketch_rec(sketch, *sid, cost_f, egraph, exprs, extracted, memo);
306
307 let mut data = egraph
308 .classes()
309 .map(|eclass| (eclass.id, contained_matches.get(&eclass.id).cloned()))
310 .collect::<HashMap<_, _>>();
311
312 let analysis = ExtractContainsAnalysis {
313 exprs,
314 cost_f,
315 extracted,
316 };
317
318 one_shot_analysis(egraph, analysis, &mut data);
319
320 data.into_iter()
321 .flat_map(|(id, maybe_best)| maybe_best.map(|b| (id, b)))
322 .collect()
323 }
324 SketchNode::OnlyContains(sid) => {
325 let contained_matches =
326 extract_sketch_rec(sketch, *sid, cost_f, egraph, exprs, extracted, memo);
327
328 let mut data = egraph
329 .classes()
330 .map(|eclass| (eclass.id, contained_matches.get(&eclass.id).cloned()))
331 .collect::<HashMap<_, _>>();
332
333 let analysis = ExtractOnlyContainsAnalysis {
334 exprs,
335 cost_f,
336 };
337
338 one_shot_analysis(egraph, analysis, &mut data);
339
340 data.into_iter()
341 .flat_map(|(id, maybe_best)| maybe_best.map(|b| (id, b)))
342 .collect()
343 }
344 SketchNode::Or(sids) => {
345 let matches = sids
346 .iter()
347 .map(|sid| {
348 extract_sketch_rec(sketch, *sid, cost_f, egraph, exprs, extracted, memo)
349 })
350 .collect::<Vec<_>>();
351 let mut matching_ids = HashSet::default();
352 for m in &matches {
353 matching_ids.extend(m.keys());
354 }
355
356 matching_ids
357 .iter()
358 .flat_map(|id| {
359 let mut candidates = Vec::new();
360 for ms in &matches {
361 candidates.extend(ms.get(id));
362 }
363 candidates
364 .into_iter()
365 .min_by(|x, y| x.0.cmp(&y.0))
366 .map(|best| (*id, best.clone()))
367 })
368 .collect()
369 }
370 };
371
372 memo.insert(sketch_id, result.clone());
380 result
381}
382
383pub struct ExtractContainsAnalysis<'a, L, CF>
384where
385 L: Language,
386 CF: CostFunction<L>,
387{
388 exprs: &'a mut ExprHashCons<L>,
389 cost_f: &'a mut CF,
390 extracted: &'a HashMap<Id, (CF::Cost, Id)>,
391}
392
393fn merge_best_option<Cost>(
394 a: &mut Option<(Cost, Id)>,
395 b: Option<(Cost, Id)>) -> DidMerge
396where
397 Cost: 'static + Ord,
398{
399 let ord = match (&a, &b) {
400 (None, None) => Ordering::Equal,
401 (Some(_), None) => Ordering::Less,
402 (None, Some(_)) => Ordering::Greater,
403 (&Some((ref ca, _)), &Some((ref cb, _))) => ca.cmp(cb),
404 };
405 match ord {
406 Ordering::Equal => DidMerge(false, false),
407 Ordering::Less => DidMerge(false, true),
408 Ordering::Greater => {
409 *a = b;
410 DidMerge(true, false)
411 }
412 }
413}
414
415impl<'a, L, A, CF> SemiLatticeAnalysis<L, A> for ExtractContainsAnalysis<'a, L, CF>
416where
417 L: Language,
418 A: Analysis<L>,
419 CF: CostFunction<L>,
420 CF::Cost: 'static + Ord,
421{
422 type Data = Option<(CF::Cost, Id)>;
423
424 fn make<'b>(
425 &mut self,
426 egraph: &EGraph<L, A>,
427 enode: &L,
428 analysis_of: &'b impl Fn(Id) -> &'b Self::Data,
429 ) -> Self::Data
430 where
431 Self::Data: 'b,
432 {
433 let mut candidates = Vec::new();
434 extract_common::push_extract_contains_candidates(&mut candidates,
435 self.exprs, self.cost_f, self.extracted,
436 egraph, enode, |c, _, _, _, _| { analysis_of(c).clone() }
437 );
438 candidates.into_iter().min_by(|x, y| x.0.cmp(&y.0))
439 }
440
441 fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
442 merge_best_option(a, b)
443 }
444}
445
446pub struct ExtractOnlyContainsAnalysis<'a, L, CF>
447where
448 L: Language,
449 CF: CostFunction<L>,
450{
451 exprs: &'a mut ExprHashCons<L>,
452 cost_f: &'a mut CF,
453}
454
455impl<'a, L, A, CF> SemiLatticeAnalysis<L, A> for ExtractOnlyContainsAnalysis<'a, L, CF>
456where
457 L: Language,
458 A: Analysis<L>,
459 CF: CostFunction<L>,
460 CF::Cost: 'static + Ord,
461{
462 type Data = Option<(CF::Cost, Id)>;
463
464 fn make<'b>(
465 &mut self,
466 egraph: &EGraph<L, A>,
467 enode: &L,
468 analysis_of: &'b impl Fn(Id) -> &'b Self::Data,
469 ) -> Self::Data
470 where
471 Self::Data: 'b,
472 {
473 extract_common::extract_only_contains_candidate(
474 self.exprs, self.cost_f,
475 egraph, enode,|c, _, _, _| { analysis_of(c).clone() }
476 )
477 }
478
479 fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
480 merge_best_option(a, b)
481 }
482}
483
484pub(crate) struct ExtractAnalysis<'a, L, CF> {
485 pub(crate) exprs: &'a mut ExprHashCons<L>,
486 pub(crate) cost_f: &'a mut CF,
487}
488
489impl<'a, L, A, CF> SemiLatticeAnalysis<L, A> for ExtractAnalysis<'a, L, CF>
490where
491 L: Language,
492 A: Analysis<L>,
493 CF: CostFunction<L>,
494 CF::Cost: 'static,
495{
496 type Data = (CF::Cost, Id);
497
498 fn make<'b>(
499 &mut self,
500 _egraph: &EGraph<L, A>,
501 enode: &L,
502 analysis_of: &'b impl Fn(Id) -> &'b Self::Data,
503 ) -> Self::Data
504 where
505 Self::Data: 'b,
506 {
507 let expr_node = enode.clone().map_children(|c| (*analysis_of)(c).1);
508 let expr = self.exprs.add(expr_node);
509 (
510 self.cost_f.cost(enode, |c| (*analysis_of)(c).0.clone()),
511 expr,
512 )
513 }
514
515 fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
516 if a.0 < b.0 {
517 DidMerge(false, true)
518 } else if a.0 == b.0 {
519 DidMerge(false, false)
520 } else {
521 *a = b;
522 DidMerge(true, false)
523 }
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Extracting with a sketch identical to the input expression must return that
    /// expression. This must hold both before and after one iteration of the
    /// `transpose` cancellation rewrites.
    #[test]
    fn bug202509() {
        let expr_a =
            "(>> (>> transpose transpose) (>> (>> transpose transpose) (>> transpose transpose)))"
                .parse::<RecExpr<SymbolLang>>()
                .unwrap();

        let sketch =
            "(>> (>> transpose transpose) (>> (>> transpose transpose) (>> transpose transpose)))"
                .parse::<Sketch<SymbolLang>>()
                .unwrap();

        let mut egraph = EGraph::<_, ()>::default();
        let a_root = egraph.add_expr(&expr_a);

        egraph.rebuild();

        let (_, best_expr) = crate::util::comparing_eclass_extract_sketch(
            &sketch, AstSize, AstSize, &egraph, a_root,
        )
        .unwrap();
        assert_eq!(best_expr.to_string(), expr_a.to_string());

        let rules = vec![
            rewrite!("transpose-id-1"; "(>> (>> transpose transpose) ?x)" => "?x"),
            rewrite!("transpose-id-2"; "(>> ?x (>> transpose transpose))" => "?x"),
        ];

        let runner = egg::Runner::default()
            .with_scheduler(egg::SimpleScheduler)
            .with_iter_limit(1)
            .with_egraph(egraph)
            .run(&rules);
        let egraph = runner.egraph;
        let (_, best_expr) = crate::util::comparing_eclass_extract_sketch(
            &sketch,
            AstSize,
            AstSize,
            &egraph,
            egraph.find(a_root),
        )
        .unwrap();
        assert_eq!(best_expr.to_string(), expr_a.to_string());
    }

    /// Checks a `contains` sketch on a small e-graph, before and after merging two
    /// classes, and the cost/term of the extracted result.
    #[test]
    fn simple_extract() {
        let sketch = "(contains (f ?))".parse::<Sketch<SymbolLang>>().unwrap();

        let a_expr = "(g (f (v x)))".parse::<RecExpr<SymbolLang>>().unwrap();
        let b_expr = "(h (g (f (u x))))".parse::<RecExpr<SymbolLang>>().unwrap();
        let c_expr = "(h (g x))".parse::<RecExpr<SymbolLang>>().unwrap();

        let mut egraph = EGraph::<SymbolLang, ()>::default();
        let a = egraph.add_expr(&a_expr);
        let b = egraph.add_expr(&b_expr);
        let c = egraph.add_expr(&c_expr);

        egraph.rebuild();

        let sat1 = satisfies_sketch(&sketch, &egraph);
        assert_eq!(sat1.len(), 5);
        assert!(sat1.contains(&a));
        assert!(sat1.contains(&b));
        assert!(!sat1.contains(&c));

        egraph.union(a, b);
        egraph.rebuild();

        let sat2 = satisfies_sketch(&sketch, &egraph);
        assert_eq!(sat2.len(), 4);
        assert!(sat2.contains(&a));
        assert!(sat2.contains(&egraph.find(b)));
        assert!(!sat2.contains(&c));

        let (best_cost, best_expr) =
            crate::util::comparing_eclass_extract_sketch(&sketch, AstSize, AstSize, &egraph, a)
                .unwrap();
        assert_eq!(best_cost, 4);
        assert_eq!(best_expr, a_expr);
    }

    /// Compares `contains` and `onlyContains` on the same e-graph, before and after
    /// unioning `(id x)` with `x`.
    #[test]
    fn contains_only() {
        let sketch = "(contains (id x))".parse::<Sketch<SymbolLang>>().unwrap();
        let sketch_only = "(onlyContains (id x))".parse::<Sketch<SymbolLang>>().unwrap();

        let expr = "(op x x x)".parse::<RecExpr<SymbolLang>>().unwrap();
        let id_expr = "(id x)".parse::<RecExpr<SymbolLang>>().unwrap();
        let x_expr = "x".parse::<RecExpr<SymbolLang>>().unwrap();

        let mut egraph = EGraph::<SymbolLang, ()>::default();
        let e = egraph.add_expr(&expr);
        let id = egraph.add_expr(&id_expr);
        let x = egraph.add_expr(&x_expr);
        egraph.rebuild();

        let sat1 = satisfies_sketch(&sketch, &egraph);
        assert_eq!(sat1.len(), 1);
        assert!(sat1.contains(&id));
        let sat2 = satisfies_sketch(&sketch_only, &egraph);
        assert_eq!(sat2.len(), 1);
        assert!(sat2.contains(&id));

        egraph.union(id, x);
        egraph.rebuild();

        let sat3 = satisfies_sketch(&sketch, &egraph);
        assert_eq!(sat3.len(), 2);
        assert!(sat3.contains(&egraph.find(id)));
        assert!(sat3.contains(&egraph.find(e)));
        let sat4 = satisfies_sketch(&sketch_only, &egraph);
        assert_eq!(sat4.len(), 2);
        assert!(sat4.contains(&egraph.find(id)));
        assert!(sat4.contains(&egraph.find(e)));

        let (best_cost, best_expr) =
            crate::util::comparing_eclass_extract_sketch(&sketch, AstSize, AstSize, &egraph, e)
                .unwrap();
        assert_eq!(best_cost, 5);
        {
            let mut expected_expr = RecExpr::default();
            let x1 = expected_expr.add(SymbolLang::leaf("x"));
            let x2 = expected_expr.add(SymbolLang::leaf("x"));
            let id = expected_expr.add(SymbolLang::new("id", vec![x2]));
            // Root node; its id is not needed afterwards.
            expected_expr.add(SymbolLang::new("op", vec![id, x1, x1]));
            assert_eq!(best_expr, expected_expr);
        }

        let (only_best_cost, only_best_expr) = crate::util::comparing_eclass_extract_sketch(
            &sketch_only,
            AstSize,
            AstSize,
            &egraph,
            e,
        )
        .unwrap();
        assert_eq!(only_best_cost, 7);
        {
            let mut only_expected_expr = RecExpr::default();
            let x = only_expected_expr.add(SymbolLang::leaf("x"));
            let id = only_expected_expr.add(SymbolLang::new("id", vec![x]));
            // Root node; its id is not needed afterwards.
            only_expected_expr.add(SymbolLang::new("op", vec![id, id, id]));
            assert_eq!(only_best_expr, only_expected_expr);
        }
    }
}