1use egg::{Id, Language, RecExpr};
2use std::{fmt::{self, Display, Formatter}};
3use thiserror::Error;
4
/// A sketch over the language `L`: a [`RecExpr`] whose nodes are [`SketchNode`]s.
///
/// Sketches can be parsed from s-expressions via `str::parse` (see the `FromOp` impl for
/// `SketchNode`) and printed back via `Display`.
pub type Sketch<L> = RecExpr<SketchNode<L>>;
9
10#[derive(Debug, Hash, PartialEq, Eq, Clone, PartialOrd, Ord)]
13pub enum SketchNode<L> {
14 Any,
18 Node(L),
22 Contains(Id),
26 OnlyContains(Id),
30 Or(Vec<Id>),
34}
35
36#[derive(Debug, Hash, PartialEq, Eq, Clone)]
37pub enum SketchDiscriminant<L: Language> {
38 Any,
39 Node(L::Discriminant),
40 Contains,
41 OnlyContains,
42 Or,
43}
44
45impl<L: Language> Language for SketchNode<L> {
46 type Discriminant = SketchDiscriminant<L>;
47
48 #[inline(always)]
49 fn discriminant(&self) -> Self::Discriminant {
50 match self {
51 SketchNode::Any => SketchDiscriminant::Any,
52 SketchNode::Node(n) => SketchDiscriminant::Node(n.discriminant()),
53 SketchNode::Contains(_) => SketchDiscriminant::Contains,
54 SketchNode::OnlyContains(_) => SketchDiscriminant::OnlyContains,
55 SketchNode::Or(_) => SketchDiscriminant::Or
56 }
57 }
58
59 fn matches(&self, _other: &Self) -> bool {
60 panic!("Should never call this")
61 }
62
63 fn children(&self) -> &[Id] {
64 match self {
65 Self::Any => &[],
66 Self::Node(n) => n.children(),
67 Self::Contains(s) => std::slice::from_ref(s),
68 Self::OnlyContains(s) => std::slice::from_ref(s),
69 Self::Or(ss) => ss.as_slice(),
70 }
71 }
72
73 fn children_mut(&mut self) -> &mut [Id] {
74 match self {
75 Self::Any => &mut [],
76 Self::Node(n) => n.children_mut(),
77 Self::Contains(s) => std::slice::from_mut(s),
78 Self::OnlyContains(s) => std::slice::from_mut(s),
79 Self::Or(ss) => ss.as_mut_slice(),
80 }
81 }
82}
83
84impl<L: Language + Display> Display for SketchNode<L> {
85 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
86 match self {
87 Self::Any => write!(f, "?"),
88 Self::Node(node) => Display::fmt(node, f),
89 Self::Contains(_) => write!(f, "contains"),
90 Self::OnlyContains(_) => write!(f, "onlyContains"),
91 Self::Or(_) => write!(f, "or"),
92 }
93 }
94}
95
96#[derive(Debug, Error)]
97pub enum SketchParseError<E> {
98 #[error("wrong number of children: {0:?}")]
99 BadChildren(egg::FromOpError),
100
101 #[error(transparent)]
102 BadOp(E),
103}
104
105impl<L: egg::FromOp> egg::FromOp for SketchNode<L> {
106 type Error = SketchParseError<L::Error>;
107
108 fn from_op(op: &str, children: Vec<Id>) -> Result<Self, Self::Error> {
109 match op {
110 "?" => {
111 if children.len() == 0 {
112 Ok(Self::Any)
113 } else {
114 Err(SketchParseError::BadChildren(egg::FromOpError::new(
115 op, children,
116 )))
117 }
118 }
119 "contains" => {
120 if children.len() == 1 {
121 Ok(Self::Contains(children[0]))
122 } else {
123 Err(SketchParseError::BadChildren(egg::FromOpError::new(
124 op, children,
125 )))
126 }
127 }
128 "onlyContains" => {
129 if children.len() == 1 {
130 Ok(Self::OnlyContains(children[0]))
131 } else {
132 Err(SketchParseError::BadChildren(egg::FromOpError::new(
133 op, children,
134 )))
135 }
136 }
137 "or" => Ok(Self::Or(children)),
138 _ => L::from_op(op, children)
139 .map(Self::Node)
140 .map_err(SketchParseError::BadOp),
141 }
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// A sketch parses into the expected nodes and prints back to the same string.
    #[test]
    fn parse_and_print() {
        let string = "(contains (f ?))";
        let sketch = string.parse::<Sketch<SymbolLang>>().unwrap();

        let mut sketch_ref = RecExpr::default();
        let any = sketch_ref.add(SketchNode::Any);
        let f = sketch_ref.add(SketchNode::Node(SymbolLang::new("f", vec![any])));
        let _ = sketch_ref.add(SketchNode::Contains(f));

        assert_eq!(sketch, sketch_ref);
        assert_eq!(sketch.to_string(), string);
    }

    /// Every sketch operator, including variadic `or`, survives a parse/print round trip.
    #[test]
    fn round_trips_every_sketch_operator() {
        let string = "(or (onlyContains (f ?)) (contains ?) ?)";
        let sketch = string.parse::<Sketch<SymbolLang>>().unwrap();
        assert_eq!(sketch.to_string(), string);
    }

    /// Fixed-arity sketch operators reject a wrong number of children.
    #[test]
    fn rejects_wrong_arity() {
        let bad_inputs = &[
            "(? ?)",
            "(contains)",
            "(contains ? ?)",
            "(onlyContains)",
            "(onlyContains ? ?)",
        ];
        for bad in bad_inputs {
            assert!(
                bad.parse::<Sketch<SymbolLang>>().is_err(),
                "{} should not parse",
                bad
            );
        }
    }
}