// egg_sketches/extract.rs

1use crate::*;
2use analysis::{one_shot_analysis, SemiLatticeAnalysis};
3use hashcons::ExprHashCons;
4use sketch::SketchNode;
5use std::cmp::Ordering;
6
7/// Is the `id` e-class of `egraph` representing at least one program satisfying `s`?
8pub fn eclass_satisfies_sketch<L: Language, A: Analysis<L>>(
9    s: &Sketch<L>,
10    egraph: &EGraph<L, A>,
11    id: Id,
12) -> bool {
13    satisfies_sketch(s, egraph).contains(&id)
14}
15
16/// Returns the set of e-classes of `egraph` that represent at least one program satisfying `s`.
17pub fn satisfies_sketch<L: Language, A: Analysis<L>>(
18    s: &Sketch<L>,
19    egraph: &EGraph<L, A>,
20) -> HashSet<Id> {
21    assert!(egraph.clean);
22    let mut memo = HashMap::<Id, HashSet<Id>>::default();
23    let sketch_nodes = s.as_ref();
24    let sketch_root = Id::from(sketch_nodes.len() - 1);
25    satisfies_sketch_rec(sketch_nodes, sketch_root, egraph, &mut memo)
26}
27
28fn satisfies_sketch_rec<L: Language, A: Analysis<L>>(
29    s_nodes: &[SketchNode<L>],
30    s_index: Id,
31    egraph: &EGraph<L, A>,
32    memo: &mut HashMap<Id, HashSet<Id>>,
33) -> HashSet<Id> {
34    match memo.get(&s_index) {
35        Some(value) => return value.clone(),
36        None => (),
37    };
38
39    let result = match &s_nodes[usize::from(s_index)] {
40        SketchNode::Any =>
41            egraph.classes().map(|c| c.id).collect(),
42        SketchNode::Node(node) => {
43            let children_matches = node
44                .children()
45                .iter()
46                .map(|sid| satisfies_sketch_rec(s_nodes, *sid, egraph, memo))
47                .collect::<Vec<_>>();
48
49            if let Some(potential_ids) = egraph.classes_for_op(&node.discriminant()) {
50                potential_ids
51                    .filter(|&id| {
52                        let eclass = &egraph[id];
53
54                        let mnode = &node.clone().map_children(|_| Id::from(0));
55                        eclass.for_each_matching_node(mnode, |matched| {
56                            let children_match = children_matches
57                                .iter()
58                                .zip(matched.children())
59                                .all(|(matches, id)| matches.contains(id));
60                            if children_match {
61                                Err(())
62                            } else {
63                                Ok(())
64                            }
65                        })
66                        .is_err()
67                    })
68                    .collect()
69            } else {
70                HashSet::default()
71            }
72        }
73        SketchNode::Contains(sid) => {
74            let contained_matched = satisfies_sketch_rec(s_nodes, *sid, egraph, memo);
75
76            let mut data = egraph
77                .classes()
78                .map(|eclass| (eclass.id, contained_matched.contains(&eclass.id)))
79                .collect::<HashMap<_, bool>>();
80
81            one_shot_analysis(egraph, SatisfiesContainsAnalysis, &mut data);
82
83            data.iter()
84                .flat_map(|(&id, &is_match)| if is_match { Some(id) } else { None })
85                .collect()
86        }
87        SketchNode::OnlyContains(sid) => {
88            let contained_matched = satisfies_sketch_rec(s_nodes, *sid, egraph, memo);
89
90            let mut data = egraph
91                .classes()
92                .map(|eclass| (eclass.id, contained_matched.contains(&eclass.id)))
93                .collect::<HashMap<_, bool>>();
94
95            one_shot_analysis(egraph, SatisfiesOnlyContainsAnalysis, &mut data);
96
97            data.iter()
98                .flat_map(|(&id, &is_match)| if is_match { Some(id) } else { None })
99                .collect()
100        }
101        SketchNode::Or(sids) => {
102            let matches = sids
103                .iter()
104                .map(|sid| satisfies_sketch_rec(s_nodes, *sid, egraph, memo));
105            matches
106                .reduce(|a, b| a.union(&b).cloned().collect())
107                .expect("empty or sketch")
108        }
109    };
110
111    memo.insert(s_index, result.clone());
112    result
113}
114
115pub struct SatisfiesContainsAnalysis;
116impl<L: Language, A: Analysis<L>> SemiLatticeAnalysis<L, A> for SatisfiesContainsAnalysis {
117    type Data = bool;
118
119    fn make<'a>(
120        &mut self,
121        _egraph: &EGraph<L, A>,
122        enode: &L,
123        analysis_of: &'a impl Fn(Id) -> &'a Self::Data,
124    ) -> Self::Data
125    where
126        Self::Data: 'a,
127    {
128        enode.any(|c| *analysis_of(c))
129    }
130
131    fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
132        let r = *a || b;
133        let dm = DidMerge(r != *a, r != b);
134        *a = r;
135        dm
136    }
137}
138
139pub struct SatisfiesOnlyContainsAnalysis;
140impl<L: Language, A: Analysis<L>> SemiLatticeAnalysis<L, A> for SatisfiesOnlyContainsAnalysis {
141    type Data = bool;
142
143    fn make<'a>(
144        &mut self,
145        _egraph: &EGraph<L, A>,
146        enode: &L,
147        analysis_of: &'a impl Fn(Id) -> &'a Self::Data,
148    ) -> Self::Data
149    where
150        Self::Data: 'a,
151    {
152        if enode.children().is_empty() {
153            false
154        } else {
155            enode.all(|c| *analysis_of(c))
156        }
157    }
158
159    fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
160        let r = *a || b;
161        let dm = DidMerge(r != *a, r != b);
162        *a = r;
163        dm
164    }
165}
166
167/// Returns the best program satisfying `s` according to `cost_f` that is represented in the `id` e-class of `egraph`, if it exists.
168pub fn eclass_extract_sketch<L, A, CF>(
169    s: &Sketch<L>,
170    cost_f: CF,
171    egraph: &EGraph<L, A>,
172    id: Id,
173) -> Option<(CF::Cost, RecExpr<L>)>
174where
175    L: Language,
176    A: Analysis<L>,
177    CF: CostFunction<L>,
178    CF::Cost: 'static + Ord,
179{
180    assert_eq!(egraph.find(id), id);
181    let (exprs, eclass_to_best) = extract_sketch(s, cost_f, egraph);
182    eclass_to_best
183        .get(&id)
184        .map(|(best_cost, best_id)| (best_cost.clone(), exprs.extract(*best_id)))
185}
186
187fn extract_sketch<L, A, CF>(
188    sketch: &Sketch<L>,
189    mut cost_f: CF,
190    egraph: &EGraph<L, A>,
191) -> (ExprHashCons<L>, HashMap<Id, (CF::Cost, Id)>)
192where
193    L: Language,
194    A: Analysis<L>,
195    CF: CostFunction<L>,
196    CF::Cost: 'static + Ord,
197{
198    assert!(egraph.clean);
199    let mut memo = HashMap::<Id, HashMap<Id, (CF::Cost, Id)>>::default();
200    let sketch_root = Id::from(sketch.as_ref().len() - 1);
201    let mut exprs = ExprHashCons::new();
202
203    let mut extracted = HashMap::default();
204    let analysis = ExtractAnalysis {
205        exprs: &mut exprs,
206        cost_f: &mut cost_f,
207    };
208    one_shot_analysis(&egraph, analysis, &mut extracted);
209
210    let res = extract_sketch_rec(
211        sketch,
212        sketch_root,
213        &mut cost_f,
214        egraph,
215        &mut exprs,
216        &extracted,
217        &mut memo,
218    );
219    (exprs, res)
220}
221
222fn extract_sketch_rec<L, A, CF>(
223    sketch: &Sketch<L>,
224    sketch_id: Id,
225    cost_f: &mut CF,
226    egraph: &EGraph<L, A>,
227    exprs: &mut ExprHashCons<L>,
228    extracted: &HashMap<Id, (CF::Cost, Id)>,
229    memo: &mut HashMap<Id, HashMap<Id, (CF::Cost, Id)>>,
230) -> HashMap<Id, (CF::Cost, Id)>
231where
232    L: Language,
233    A: Analysis<L>,
234    CF: CostFunction<L>,
235    CF::Cost: 'static + Ord,
236{
237    match memo.get(&sketch_id) {
238        Some(value) => return value.clone(),
239        None => (),
240    };
241
242    let result = match &sketch[sketch_id] {
243        SketchNode::Any => extracted.clone(),
244        SketchNode::Node(sketch_node) => {
245            // for each child, contains map from eclass-id to best
246            let children_matches = sketch_node
247                .children()
248                .iter()
249                .map(|sid| {
250                    extract_sketch_rec(sketch, *sid, cost_f, egraph, exprs, extracted, memo)
251                })
252                .collect::<Vec<_>>();
253
254            if let Some(potential_ids) = egraph.classes_for_op(&sketch_node.discriminant()) {
255                potential_ids
256                    .flat_map(|id| {
257                        let eclass = &egraph[id];
258                        let mut candidates = Vec::new();
259
260                        let mnode = &sketch_node.clone().map_children(|_| Id::from(0));
261                        let _ = eclass.for_each_matching_node::<()>(mnode, |matched| {
262                            // matched is a enode with children being e-classes
263
264                            let mut matches = Vec::new();
265                            // for each child, matches lists the best
266                            for (cm, id) in children_matches.iter().zip(matched.children()) {
267                                if let Some(m) = cm.get(id) {
268                                    matches.push(m);
269                                } else {
270                                    break;
271                                }
272                            }
273
274                            if matches.len() == matched.len() {
275                                // for each child, map to the best based on child index
276                                let mut node_to_child_indices = sketch_node.clone();
277                                for (child_index, id) in node_to_child_indices.children_mut().into_iter().enumerate() {
278                                    *id = Id::from(child_index);
279                                }
280                                
281                                let to_match: HashMap<_, _> =
282                                    node_to_child_indices.children().iter().zip(matches.iter()).collect();
283
284                                candidates.push((
285                                    cost_f.cost(&node_to_child_indices, |c| to_match[&c].0.clone()),
286                                    exprs.add(node_to_child_indices.clone().map_children(|c| to_match[&c].1)),
287                                ));
288                            }
289
290                            Ok(())
291                        });
292
293                        candidates
294                            .into_iter()
295                            .min_by(|x, y| x.0.cmp(&y.0))
296                            .map(|best| (id, best))
297                    })
298                    .collect()
299            } else {
300                HashMap::default()
301            }
302        }
303        SketchNode::Contains(sid) => {
304            let contained_matches =
305                extract_sketch_rec(sketch, *sid, cost_f, egraph, exprs, extracted, memo);
306
307            let mut data = egraph
308                .classes()
309                .map(|eclass| (eclass.id, contained_matches.get(&eclass.id).cloned()))
310                .collect::<HashMap<_, _>>();
311
312            let analysis = ExtractContainsAnalysis {
313                exprs,
314                cost_f,
315                extracted,
316            };
317
318            one_shot_analysis(egraph, analysis, &mut data);
319
320            data.into_iter()
321                .flat_map(|(id, maybe_best)| maybe_best.map(|b| (id, b)))
322                .collect()
323        }
324        SketchNode::OnlyContains(sid) => {
325            let contained_matches =
326                extract_sketch_rec(sketch, *sid, cost_f, egraph, exprs, extracted, memo);
327
328            let mut data = egraph
329                .classes()
330                .map(|eclass| (eclass.id, contained_matches.get(&eclass.id).cloned()))
331                .collect::<HashMap<_, _>>();
332
333            let analysis = ExtractOnlyContainsAnalysis {
334                exprs,
335                cost_f,
336            };
337
338            one_shot_analysis(egraph, analysis, &mut data);
339
340            data.into_iter()
341                .flat_map(|(id, maybe_best)| maybe_best.map(|b| (id, b)))
342                .collect()
343        }
344        SketchNode::Or(sids) => {
345            let matches = sids
346                .iter()
347                .map(|sid| {
348                    extract_sketch_rec(sketch, *sid, cost_f, egraph, exprs, extracted, memo)
349                })
350                .collect::<Vec<_>>();
351            let mut matching_ids = HashSet::default();
352            for m in &matches {
353                matching_ids.extend(m.keys());
354            }
355
356            matching_ids
357                .iter()
358                .flat_map(|id| {
359                    let mut candidates = Vec::new();
360                    for ms in &matches {
361                        candidates.extend(ms.get(id));
362                    }
363                    candidates
364                        .into_iter()
365                        .min_by(|x, y| x.0.cmp(&y.0))
366                        .map(|best| (*id, best.clone()))
367                })
368                .collect()
369        }
370    };
371
372    /* DEBUG
373        println!("result for sketch node {:?}", sketch_id);
374        for (id, (cost, expr_id)) in &result {
375            println!("- e-class {} result of cost {:?}: {}", id, cost, exprs.extract(*expr_id));
376        }
377    */
378
379    memo.insert(sketch_id, result.clone());
380    result
381}
382
383pub struct ExtractContainsAnalysis<'a, L, CF>
384where
385    L: Language,
386    CF: CostFunction<L>,
387{
388    exprs: &'a mut ExprHashCons<L>,
389    cost_f: &'a mut CF,
390    extracted: &'a HashMap<Id, (CF::Cost, Id)>,
391}
392
393fn merge_best_option<Cost>(
394    a: &mut Option<(Cost, Id)>,
395    b: Option<(Cost, Id)>) -> DidMerge
396where
397    Cost: 'static + Ord,
398{
399    let ord = match (&a, &b) {
400        (None, None) => Ordering::Equal,
401        (Some(_), None) => Ordering::Less,
402        (None, Some(_)) => Ordering::Greater,
403        (&Some((ref ca, _)), &Some((ref cb, _))) => ca.cmp(cb),
404    };
405    match ord {
406        Ordering::Equal => DidMerge(false, false),
407        Ordering::Less => DidMerge(false, true),
408        Ordering::Greater => {
409            *a = b;
410            DidMerge(true, false)
411        }
412    }
413}
414
415impl<'a, L, A, CF> SemiLatticeAnalysis<L, A> for ExtractContainsAnalysis<'a, L, CF>
416where
417    L: Language,
418    A: Analysis<L>,
419    CF: CostFunction<L>,
420    CF::Cost: 'static + Ord,
421{
422    type Data = Option<(CF::Cost, Id)>;
423
424    fn make<'b>(
425        &mut self,
426        egraph: &EGraph<L, A>,
427        enode: &L,
428        analysis_of: &'b impl Fn(Id) -> &'b Self::Data,
429    ) -> Self::Data
430    where
431        Self::Data: 'b,
432    {
433        let mut candidates = Vec::new();
434        extract_common::push_extract_contains_candidates(&mut candidates,
435            self.exprs, self.cost_f, self.extracted,
436            egraph, enode, |c, _, _, _, _| { analysis_of(c).clone() }
437        );
438        candidates.into_iter().min_by(|x, y| x.0.cmp(&y.0))
439    }
440
441    fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
442        merge_best_option(a, b)
443    }
444}
445
446pub struct ExtractOnlyContainsAnalysis<'a, L, CF>
447where
448    L: Language,
449    CF: CostFunction<L>,
450{
451    exprs: &'a mut ExprHashCons<L>,
452    cost_f: &'a mut CF,
453}
454
455impl<'a, L, A, CF> SemiLatticeAnalysis<L, A> for ExtractOnlyContainsAnalysis<'a, L, CF>
456where
457    L: Language,
458    A: Analysis<L>,
459    CF: CostFunction<L>,
460    CF::Cost: 'static + Ord,
461{
462    type Data = Option<(CF::Cost, Id)>;
463
464    fn make<'b>(
465        &mut self,
466        egraph: &EGraph<L, A>,
467        enode: &L,
468        analysis_of: &'b impl Fn(Id) -> &'b Self::Data,
469    ) -> Self::Data
470    where
471        Self::Data: 'b,
472    {
473        extract_common::extract_only_contains_candidate(
474            self.exprs, self.cost_f,
475            egraph, enode,|c, _, _, _| { analysis_of(c).clone() }
476        )
477    }
478
479    fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
480        merge_best_option(a, b)
481    }
482}
483
484pub(crate) struct ExtractAnalysis<'a, L, CF> {
485    pub(crate) exprs: &'a mut ExprHashCons<L>,
486    pub(crate) cost_f: &'a mut CF,
487}
488
489impl<'a, L, A, CF> SemiLatticeAnalysis<L, A> for ExtractAnalysis<'a, L, CF>
490where
491    L: Language,
492    A: Analysis<L>,
493    CF: CostFunction<L>,
494    CF::Cost: 'static,
495{
496    type Data = (CF::Cost, Id);
497
498    fn make<'b>(
499        &mut self,
500        _egraph: &EGraph<L, A>,
501        enode: &L,
502        analysis_of: &'b impl Fn(Id) -> &'b Self::Data,
503    ) -> Self::Data
504    where
505        Self::Data: 'b,
506    {
507        let expr_node = enode.clone().map_children(|c| (*analysis_of)(c).1);
508        let expr = self.exprs.add(expr_node);
509        (
510            self.cost_f.cost(enode, |c| (*analysis_of)(c).0.clone()),
511            expr,
512        )
513    }
514
515    fn merge(&mut self, a: &mut Self::Data, b: Self::Data) -> DidMerge {
516        if a.0 < b.0 {
517            DidMerge(false, true)
518        } else if a.0 == b.0 {
519            DidMerge(false, false)
520        } else {
521            *a = b;
522            DidMerge(true, false)
523        }
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Regression test: extraction with a fully concrete sketch must still
    /// return the original term after rewrites have added cheaper equivalent
    /// terms to the same e-class.
    #[test]
    fn bug202509() {
        let expr_a =
            "(>> (>> transpose transpose) (>> (>> transpose transpose) (>> transpose transpose)))"
                .parse::<RecExpr<SymbolLang>>()
                .unwrap();

        let sketch =
            "(>> (>> transpose transpose) (>> (>> transpose transpose) (>> transpose transpose)))"
                .parse::<Sketch<SymbolLang>>()
                .unwrap();

        let mut egraph = EGraph::<_, ()>::default();
        let a_root = egraph.add_expr(&expr_a);

        egraph.rebuild();

        let (_, best_expr) = crate::util::comparing_eclass_extract_sketch(
            &sketch, AstSize, AstSize, &egraph, a_root,
        )
        .unwrap();
        assert_eq!(best_expr.to_string(), expr_a.to_string());

        let rules = vec![
            rewrite!("transpose-id-1";  "(>> (>> transpose transpose) ?x)" => "?x"),
            rewrite!("transpose-id-2";  "(>> ?x (>> transpose transpose))" => "?x"),
        ];

        let runner = egg::Runner::default()
            .with_scheduler(egg::SimpleScheduler)
            .with_iter_limit(1)
            .with_egraph(egraph)
            .run(&rules);
        let egraph = runner.egraph;
        let (_, best_expr) = crate::util::comparing_eclass_extract_sketch(
            &sketch,
            AstSize,
            AstSize,
            &egraph,
            egraph.find(a_root),
        )
        .unwrap();
        assert_eq!(best_expr.to_string(), expr_a.to_string());
    }

    #[test]
    fn simple_extract() {
        let sketch = "(contains (f ?))".parse::<Sketch<SymbolLang>>().unwrap();

        let a_expr = "(g (f (v x)))".parse::<RecExpr<SymbolLang>>().unwrap();
        let b_expr = "(h (g (f (u x))))".parse::<RecExpr<SymbolLang>>().unwrap();
        let c_expr = "(h (g x))".parse::<RecExpr<SymbolLang>>().unwrap();

        let mut egraph = EGraph::<SymbolLang, ()>::default();
        let a = egraph.add_expr(&a_expr);
        let b = egraph.add_expr(&b_expr);
        let c = egraph.add_expr(&c_expr);

        egraph.rebuild();

        let sat1 = satisfies_sketch(&sketch, &egraph);
        assert_eq!(sat1.len(), 5);
        assert!(sat1.contains(&a));
        assert!(sat1.contains(&b));
        assert!(!sat1.contains(&c));

        egraph.union(a, b);
        egraph.rebuild();

        // After the union either `a` or `b` may be the representative, so
        // always go through `find` instead of relying on the root choice.
        let ab = egraph.find(a);

        let sat2 = satisfies_sketch(&sketch, &egraph);
        assert_eq!(sat2.len(), 4);
        assert!(sat2.contains(&ab));
        assert!(sat2.contains(&egraph.find(b)));
        assert!(!sat2.contains(&egraph.find(c)));

        let (best_cost, best_expr) =
            crate::util::comparing_eclass_extract_sketch(&sketch, AstSize, AstSize, &egraph, ab)
                .unwrap();
        assert_eq!(best_cost, 4);
        assert_eq!(best_expr, a_expr);
    }

    #[test]
    fn contains_only() {
        let sketch = "(contains (id x))".parse::<Sketch<SymbolLang>>().unwrap();
        let sketch_only = "(onlyContains (id x))"
            .parse::<Sketch<SymbolLang>>()
            .unwrap();

        let expr = "(op x x x)".parse::<RecExpr<SymbolLang>>().unwrap();
        let id_expr = "(id x)".parse::<RecExpr<SymbolLang>>().unwrap();
        let x_expr = "x".parse::<RecExpr<SymbolLang>>().unwrap();

        let mut egraph = EGraph::<SymbolLang, ()>::default();
        let e = egraph.add_expr(&expr);
        let id = egraph.add_expr(&id_expr);
        let x = egraph.add_expr(&x_expr);
        egraph.rebuild();

        let sat1 = satisfies_sketch(&sketch, &egraph);
        assert_eq!(sat1.len(), 1);
        assert!(sat1.contains(&id));
        let sat2 = satisfies_sketch(&sketch_only, &egraph);
        assert_eq!(sat2.len(), 1);
        assert!(sat2.contains(&id));

        egraph.union(id, x);
        egraph.rebuild();

        let sat3 = satisfies_sketch(&sketch, &egraph);
        assert_eq!(sat3.len(), 2);
        assert!(sat3.contains(&egraph.find(id)));
        assert!(sat3.contains(&egraph.find(e)));
        let sat4 = satisfies_sketch(&sketch_only, &egraph);
        assert_eq!(sat4.len(), 2);
        assert!(sat4.contains(&egraph.find(id)));
        assert!(sat4.contains(&egraph.find(e)));

        let (best_cost, best_expr) = crate::util::comparing_eclass_extract_sketch(
            &sketch,
            AstSize,
            AstSize,
            &egraph,
            egraph.find(e),
        )
        .unwrap();
        assert_eq!(best_cost, 5);
        {
            let mut expected_expr = RecExpr::default();
            let x1 = expected_expr.add(SymbolLang::leaf("x"));
            let x2 = expected_expr.add(SymbolLang::leaf("x"));
            let id = expected_expr.add(SymbolLang::new("id", vec![x2]));
            expected_expr.add(SymbolLang::new("op", vec![id, x1, x1]));
            assert_eq!(best_expr, expected_expr);
        }

        let (only_best_cost, only_best_expr) = crate::util::comparing_eclass_extract_sketch(
            &sketch_only,
            AstSize,
            AstSize,
            &egraph,
            egraph.find(e),
        )
        .unwrap();
        assert_eq!(only_best_cost, 7);
        {
            let mut only_expected_expr = RecExpr::default();
            let x = only_expected_expr.add(SymbolLang::leaf("x"));
            let id = only_expected_expr.add(SymbolLang::new("id", vec![x]));
            only_expected_expr.add(SymbolLang::new("op", vec![id, id, id]));
            assert_eq!(only_best_expr, only_expected_expr);
        }
    }
}