// egg_sketches/recursive_extract.rs

1use crate::*;
2use hashcons::ExprHashCons;
3use sketch::SketchNode;
4use extract::ExtractAnalysis;
5
6/// Returns the best program satisfying `s` according to `cost_f` that is represented in the `id` e-class of `egraph`, if it exists.
7pub fn eclass_extract_sketch<L, A, CF>(
8  sketch: &Sketch<L>,
9  mut cost_f: CF,
10  egraph: &EGraph<L, A>,
11  id: Id,
12) -> Option<(CF::Cost, RecExpr<L>)>
13where
14  L: Language,
15  A: Analysis<L>,
16  CF: CostFunction<L>,
17  CF::Cost: 'static + Ord,
18{
19  assert!(egraph.clean);
20  let mut memo = HashMap::<(Id, Id), Option<(CF::Cost, Id)>>::default();
21  let sketch_root = Id::from(sketch.as_ref().len() - 1);
22  let mut exprs = ExprHashCons::new();
23
24  let mut extracted = HashMap::default();
25  let analysis = ExtractAnalysis {
26      exprs: &mut exprs,
27      cost_f: &mut cost_f,
28  };
29  analysis::one_shot_analysis(&egraph, analysis, &mut extracted);
30
31  let best_option = extract_rec(
32    id,
33    sketch,
34    sketch_root,
35    &mut cost_f,
36    egraph,
37    &mut exprs,
38    &extracted,
39    &mut memo,
40  );
41
42  best_option.map(|(best_cost, best_id)| (best_cost, exprs.extract(best_id)))
43}
44
45fn extract_rec<L, A, CF>(
46  egraph_id: Id,
47  sketch: &Sketch<L>,
48  sketch_id: Id,
49  cost_f: &mut CF,
50  egraph: &EGraph<L, A>,
51  exprs: &mut ExprHashCons<L>,
52  extracted: &HashMap<Id, (CF::Cost, Id)>,
53  memo: &mut HashMap<(Id, Id), Option<(CF::Cost, Id)>>,
54) -> Option<(CF::Cost, Id)>
55where
56  L: Language,
57  A: Analysis<L>,
58  CF: CostFunction<L>,
59  CF::Cost: 'static + Ord,
60{
61  assert_eq!(egraph.find(egraph_id), egraph_id);
62  match memo.get(&(egraph_id, sketch_id)) {
63    Some(value) => return value.clone(),
64    None => (),
65  };
66
67  let result = match &sketch[sketch_id] {
68    SketchNode::Any =>
69      extracted.get(&egraph_id).cloned(),
70    SketchNode::Node(sketch_node) => {
71      let eclass = &egraph[egraph_id];
72
73      // TODO: factorize code in extract_common
74      let mut candidates = Vec::new();
75      let mnode = &sketch_node.clone().map_children(|_| Id::from(0));
76      let _ = eclass.for_each_matching_node::<()>(mnode, |matched| {
77        let mut matches = Vec::new();
78        for (sid, id) in sketch_node.children().iter().zip(matched.children()) {
79          if let Some(m) = extract_rec(*id, sketch, *sid, cost_f, egraph, exprs, extracted, memo) {
80            matches.push(m);
81          } else {
82            break;
83          }
84        }
85
86        assert!(matched.all(|c| c == egraph.find(c)));
87        if matches.len() == matched.len() {
88          let mut node_to_child_indices = sketch_node.clone();
89          for (child_index, id) in node_to_child_indices.children_mut().into_iter().enumerate() {
90              *id = Id::from(child_index);
91          }
92          let to_match: HashMap<_, _> =
93              node_to_child_indices.children().iter().zip(matches.iter()).collect();
94          candidates.push((
95              cost_f.cost(&node_to_child_indices, |c| to_match[&c].0.clone()),
96              exprs.add(node_to_child_indices.clone().map_children(|c| to_match[&c].1)),
97          ));
98        }
99
100        Ok(())
101      });
102
103      candidates.into_iter().min_by(|x, y| x.0.cmp(&y.0))
104    }
105    SketchNode::Contains(inner_sketch_id) => {
106      // Avoid cycles.
107      // If we have visited the contains once, we do not need to
108      // visit it again as the cost in our setup only goes up.
109      memo.insert((egraph_id, sketch_id), None);
110
111      let eclass = &egraph[egraph_id];
112      let mut candidates = Vec::new();
113      // Base case: when this eclass satisfies inner sketch
114      candidates.extend(extract_rec(egraph_id, sketch, *inner_sketch_id, cost_f, egraph, exprs, extracted, memo));
115
116      // Recursive case: when children satisfy sketch
117      for enode in &eclass.nodes {
118        extract_common::push_extract_contains_candidates(&mut candidates,
119          exprs, cost_f, extracted, egraph, enode, |c, exprs, cost_f, extracted, egraph| {
120            extract_rec(c, sketch, sketch_id, cost_f, egraph, exprs, extracted, memo)
121          }
122        )
123      }
124
125      candidates.into_iter().min_by(|x, y| x.0.cmp(&y.0))
126    }
127    SketchNode::OnlyContains(inner_sketch_id) => {
128      // FIXME: duplicate code
129      // Avoid cycles.
130      // If we have visited the contains once, we do not need to
131      // visit it again as the cost in our setup only goes up.
132      memo.insert((egraph_id, sketch_id), None);
133
134      let eclass = &egraph[egraph_id];
135      let mut candidates = Vec::new();
136      // Base case: when this eclass satisfies inner sketch
137      candidates.extend(extract_rec(egraph_id, sketch, *inner_sketch_id, cost_f, egraph, exprs, extracted, memo));
138
139      // Recursive case: when children satisfy sketch
140      for enode in &eclass.nodes {
141        candidates.extend(extract_common::extract_only_contains_candidate(
142          exprs, cost_f, egraph, enode, |c, exprs, cost_f, egraph| {
143            extract_rec(c, sketch, sketch_id, cost_f, egraph, exprs, extracted, memo)
144          }
145        ))
146      }
147
148      candidates.into_iter().min_by(|x, y| x.0.cmp(&y.0))
149    }
150    SketchNode::Or(inner_sketch_ids) => {
151      inner_sketch_ids
152        .iter()
153        .flat_map(|sid| {
154            extract_rec(egraph_id, sketch, *sid, cost_f, egraph, exprs, extracted, memo)
155        })
156        .min_by(|x, y| x.0.cmp(&y.0))
157    }
158  };
159
160  /* DEBUG
161  if let SketchNode::Contains(_) = &sketch[sketch_id] {
162    if let Some((cost, _)) = &result {
163      println!("result for {:?}, {:?}: {:?}", sketch_id, egraph_id, cost);
164    }
165  } */
166
167  memo.insert((egraph_id, sketch_id), result.clone());
168  result
169}