1use crate::*;
2use hashcons::ExprHashCons;
3use sketch::SketchNode;
4use extract::ExtractAnalysis;
5
6pub fn eclass_extract_sketch<L, A, CF>(
8 sketch: &Sketch<L>,
9 mut cost_f: CF,
10 egraph: &EGraph<L, A>,
11 id: Id,
12) -> Option<(CF::Cost, RecExpr<L>)>
13where
14 L: Language,
15 A: Analysis<L>,
16 CF: CostFunction<L>,
17 CF::Cost: 'static + Ord,
18{
19 assert!(egraph.clean);
20 let mut memo = HashMap::<(Id, Id), Option<(CF::Cost, Id)>>::default();
21 let sketch_root = Id::from(sketch.as_ref().len() - 1);
22 let mut exprs = ExprHashCons::new();
23
24 let mut extracted = HashMap::default();
25 let analysis = ExtractAnalysis {
26 exprs: &mut exprs,
27 cost_f: &mut cost_f,
28 };
29 analysis::one_shot_analysis(&egraph, analysis, &mut extracted);
30
31 let best_option = extract_rec(
32 id,
33 sketch,
34 sketch_root,
35 &mut cost_f,
36 egraph,
37 &mut exprs,
38 &extracted,
39 &mut memo,
40 );
41
42 best_option.map(|(best_cost, best_id)| (best_cost, exprs.extract(best_id)))
43}
44
45fn extract_rec<L, A, CF>(
46 egraph_id: Id,
47 sketch: &Sketch<L>,
48 sketch_id: Id,
49 cost_f: &mut CF,
50 egraph: &EGraph<L, A>,
51 exprs: &mut ExprHashCons<L>,
52 extracted: &HashMap<Id, (CF::Cost, Id)>,
53 memo: &mut HashMap<(Id, Id), Option<(CF::Cost, Id)>>,
54) -> Option<(CF::Cost, Id)>
55where
56 L: Language,
57 A: Analysis<L>,
58 CF: CostFunction<L>,
59 CF::Cost: 'static + Ord,
60{
61 assert_eq!(egraph.find(egraph_id), egraph_id);
62 match memo.get(&(egraph_id, sketch_id)) {
63 Some(value) => return value.clone(),
64 None => (),
65 };
66
67 let result = match &sketch[sketch_id] {
68 SketchNode::Any =>
69 extracted.get(&egraph_id).cloned(),
70 SketchNode::Node(sketch_node) => {
71 let eclass = &egraph[egraph_id];
72
73 let mut candidates = Vec::new();
75 let mnode = &sketch_node.clone().map_children(|_| Id::from(0));
76 let _ = eclass.for_each_matching_node::<()>(mnode, |matched| {
77 let mut matches = Vec::new();
78 for (sid, id) in sketch_node.children().iter().zip(matched.children()) {
79 if let Some(m) = extract_rec(*id, sketch, *sid, cost_f, egraph, exprs, extracted, memo) {
80 matches.push(m);
81 } else {
82 break;
83 }
84 }
85
86 assert!(matched.all(|c| c == egraph.find(c)));
87 if matches.len() == matched.len() {
88 let mut node_to_child_indices = sketch_node.clone();
89 for (child_index, id) in node_to_child_indices.children_mut().into_iter().enumerate() {
90 *id = Id::from(child_index);
91 }
92 let to_match: HashMap<_, _> =
93 node_to_child_indices.children().iter().zip(matches.iter()).collect();
94 candidates.push((
95 cost_f.cost(&node_to_child_indices, |c| to_match[&c].0.clone()),
96 exprs.add(node_to_child_indices.clone().map_children(|c| to_match[&c].1)),
97 ));
98 }
99
100 Ok(())
101 });
102
103 candidates.into_iter().min_by(|x, y| x.0.cmp(&y.0))
104 }
105 SketchNode::Contains(inner_sketch_id) => {
106 memo.insert((egraph_id, sketch_id), None);
110
111 let eclass = &egraph[egraph_id];
112 let mut candidates = Vec::new();
113 candidates.extend(extract_rec(egraph_id, sketch, *inner_sketch_id, cost_f, egraph, exprs, extracted, memo));
115
116 for enode in &eclass.nodes {
118 extract_common::push_extract_contains_candidates(&mut candidates,
119 exprs, cost_f, extracted, egraph, enode, |c, exprs, cost_f, extracted, egraph| {
120 extract_rec(c, sketch, sketch_id, cost_f, egraph, exprs, extracted, memo)
121 }
122 )
123 }
124
125 candidates.into_iter().min_by(|x, y| x.0.cmp(&y.0))
126 }
127 SketchNode::OnlyContains(inner_sketch_id) => {
128 memo.insert((egraph_id, sketch_id), None);
133
134 let eclass = &egraph[egraph_id];
135 let mut candidates = Vec::new();
136 candidates.extend(extract_rec(egraph_id, sketch, *inner_sketch_id, cost_f, egraph, exprs, extracted, memo));
138
139 for enode in &eclass.nodes {
141 candidates.extend(extract_common::extract_only_contains_candidate(
142 exprs, cost_f, egraph, enode, |c, exprs, cost_f, egraph| {
143 extract_rec(c, sketch, sketch_id, cost_f, egraph, exprs, extracted, memo)
144 }
145 ))
146 }
147
148 candidates.into_iter().min_by(|x, y| x.0.cmp(&y.0))
149 }
150 SketchNode::Or(inner_sketch_ids) => {
151 inner_sketch_ids
152 .iter()
153 .flat_map(|sid| {
154 extract_rec(egraph_id, sketch, *sid, cost_f, egraph, exprs, extracted, memo)
155 })
156 .min_by(|x, y| x.0.cmp(&y.0))
157 }
158 };
159
160 memo.insert((egraph_id, sketch_id), result.clone());
168 result
169}