// egg_sketches/sketch.rs

use egg::{Id, Language, RecExpr};
use std::fmt::{self, Display, Formatter};
use thiserror::Error;

5/// A [`Sketch`] is a program pattern that is satisfied by a family of programs.
6///
7/// It can also be seen as an incomplete or partial program as it can leave details unspecified.
8pub type Sketch<L> = RecExpr<SketchNode<L>>;
9
10/// The language of [`Sketch`]es.
11///
12#[derive(Debug, Hash, PartialEq, Eq, Clone, PartialOrd, Ord)]
13pub enum SketchNode<L> {
14    /// Any program of the underlying [`Language`].
15    ///
16    /// Corresponds to the `?` syntax.
17    Any,
18    /// Programs made from this [`Language`] node whose children satisfy the given sketches.
19    ///
20    /// Corresponds to the `(language_node s1 .. sn)` syntax.
21    Node(L),
22    /// Programs that contain *at least one* sub-program satisfying the given sketch.
23    ///
24    /// Corresponds to the `(contains s)` syntax.
25    Contains(Id),
26    /// Programs that *only* contain sub-programs satisfying the given sketch.
27    ///
28    /// Corresponds to the `(onlyContains s)` syntax.
29    OnlyContains(Id),
30    /// Programs that satisfy any of these sketches.
31    ///
32    /// Corresponds to the `(or s1 .. sn)` syntax.
33    Or(Vec<Id>),
34}
35
36#[derive(Debug, Hash, PartialEq, Eq, Clone)]
37pub enum SketchDiscriminant<L: Language> {
38    Any,
39    Node(L::Discriminant),
40    Contains,
41    OnlyContains,
42    Or,
43}
44
45impl<L: Language> Language for SketchNode<L> {
46    type Discriminant = SketchDiscriminant<L>;
47
48    #[inline(always)]
49    fn discriminant(&self) -> Self::Discriminant {
50        match self {
51            SketchNode::Any => SketchDiscriminant::Any,
52            SketchNode::Node(n) => SketchDiscriminant::Node(n.discriminant()),
53            SketchNode::Contains(_) => SketchDiscriminant::Contains,
54            SketchNode::OnlyContains(_) => SketchDiscriminant::OnlyContains,
55            SketchNode::Or(_) => SketchDiscriminant::Or
56        }
57    }
58
59    fn matches(&self, _other: &Self) -> bool {
60        panic!("Should never call this")
61    }
62
63    fn children(&self) -> &[Id] {
64        match self {
65            Self::Any => &[],
66            Self::Node(n) => n.children(),
67            Self::Contains(s) => std::slice::from_ref(s),
68            Self::OnlyContains(s) => std::slice::from_ref(s),
69            Self::Or(ss) => ss.as_slice(),
70        }
71    }
72
73    fn children_mut(&mut self) -> &mut [Id] {
74        match self {
75            Self::Any => &mut [],
76            Self::Node(n) => n.children_mut(),
77            Self::Contains(s) => std::slice::from_mut(s),
78            Self::OnlyContains(s) => std::slice::from_mut(s),
79            Self::Or(ss) => ss.as_mut_slice(),
80        }
81    }
82}
83
84impl<L: Language + Display> Display for SketchNode<L> {
85    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
86        match self {
87            Self::Any => write!(f, "?"),
88            Self::Node(node) => Display::fmt(node, f),
89            Self::Contains(_) => write!(f, "contains"),
90            Self::OnlyContains(_) => write!(f, "onlyContains"),
91            Self::Or(_) => write!(f, "or"),
92        }
93    }
94}
95
96#[derive(Debug, Error)]
97pub enum SketchParseError<E> {
98    #[error("wrong number of children: {0:?}")]
99    BadChildren(egg::FromOpError),
100
101    #[error(transparent)]
102    BadOp(E),
103}
104
105impl<L: egg::FromOp> egg::FromOp for SketchNode<L> {
106    type Error = SketchParseError<L::Error>;
107
108    fn from_op(op: &str, children: Vec<Id>) -> Result<Self, Self::Error> {
109        match op {
110            "?" => {
111                if children.len() == 0 {
112                    Ok(Self::Any)
113                } else {
114                    Err(SketchParseError::BadChildren(egg::FromOpError::new(
115                        op, children,
116                    )))
117                }
118            }
119            "contains" => {
120                if children.len() == 1 {
121                    Ok(Self::Contains(children[0]))
122                } else {
123                    Err(SketchParseError::BadChildren(egg::FromOpError::new(
124                        op, children,
125                    )))
126                }
127            }
128            "onlyContains" => {
129                if children.len() == 1 {
130                    Ok(Self::OnlyContains(children[0]))
131                } else {
132                    Err(SketchParseError::BadChildren(egg::FromOpError::new(
133                        op, children,
134                    )))
135                }
136            }
137            "or" => Ok(Self::Or(children)),
138            _ => L::from_op(op, children)
139                .map(Self::Node)
140                .map_err(SketchParseError::BadOp),
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// A `contains` sketch parses into the expected nodes and prints back unchanged.
    #[test]
    fn parse_and_print() {
        let string = "(contains (f ?))";
        let sketch = string.parse::<Sketch<SymbolLang>>().unwrap();

        let mut sketch_ref = RecExpr::default();
        let any = sketch_ref.add(SketchNode::Any);
        let f = sketch_ref.add(SketchNode::Node(SymbolLang::new("f", vec![any])));
        let _ = sketch_ref.add(SketchNode::Contains(f));

        assert_eq!(sketch, sketch_ref);
        assert_eq!(sketch.to_string(), string);
    }

    /// A variadic `or` sketch parses into the expected nodes and prints back unchanged.
    #[test]
    fn parse_and_print_or() {
        let string = "(or ? (f ?))";
        let sketch = string.parse::<Sketch<SymbolLang>>().unwrap();

        let mut sketch_ref = RecExpr::default();
        let first = sketch_ref.add(SketchNode::Any);
        let any = sketch_ref.add(SketchNode::Any);
        let f = sketch_ref.add(SketchNode::Node(SymbolLang::new("f", vec![any])));
        let _ = sketch_ref.add(SketchNode::Or(vec![first, f]));

        assert_eq!(sketch, sketch_ref);
        assert_eq!(sketch.to_string(), string);
    }

    /// Fixed-arity sketch operators reject the wrong number of children.
    #[test]
    fn reject_wrong_arity() {
        let bad_inputs = [
            "(? ?)",
            "(contains)",
            "(contains ? ?)",
            "(onlyContains)",
            "(onlyContains ? ?)",
        ];
        for bad in bad_inputs {
            assert!(
                bad.parse::<Sketch<SymbolLang>>().is_err(),
                "{:?} should not parse",
                bad
            );
        }
    }
}