floem_reactive/
memo.rs

1use std::{
2    any::Any,
3    cell::{Cell, RefCell},
4    collections::HashSet,
5    marker::PhantomData,
6    rc::Rc,
7};
8
9use crate::{
10    effect::{observer_clean_up, EffectPriority, EffectTrait},
11    id::Id,
12    read::SignalTrack,
13    runtime::{Runtime, RUNTIME},
14    scope::Scope,
15    signal::{ReadSignal, RwSignal, WriteSignal},
16    write::SignalUpdate,
17    SignalGet, SignalWith,
18};
19
20/// A memoized derived value that only recomputes when one of its tracked
21/// dependencies changes, and only notifies dependents when its value changes.
22///
23/// Unlike the previous implementation, this is driven by dependency invalidation
24/// rather than an `Effect` that eagerly recomputes.
25pub struct Memo<T: PartialEq + 'static> {
26    getter: ReadSignal<T>,
27    memo_id: Id,
28}
29
30impl<T: PartialEq + 'static> Copy for Memo<T> {}
31
32impl<T: PartialEq + 'static> Clone for Memo<T> {
33    fn clone(&self) -> Self {
34        *self
35    }
36}
37
38impl<T: Clone + PartialEq + 'static> SignalGet<T> for Memo<T>
39where
40    ReadSignal<T>: SignalGet<T>,
41{
42    fn id(&self) -> crate::id::Id {
43        self.getter.id
44    }
45
46    fn get_untracked(&self) -> T
47    where
48        T: 'static,
49    {
50        self.ensure_fresh();
51        self.getter.get_untracked()
52    }
53
54    fn get(&self) -> T
55    where
56        T: 'static,
57    {
58        self.ensure_fresh();
59        self.getter.get()
60    }
61
62    fn try_get(&self) -> Option<T>
63    where
64        T: 'static,
65    {
66        self.ensure_fresh();
67        self.getter.try_get()
68    }
69
70    fn try_get_untracked(&self) -> Option<T>
71    where
72        T: 'static,
73    {
74        self.ensure_fresh();
75        self.getter.try_get_untracked()
76    }
77}
78
79impl<T: PartialEq + 'static> SignalTrack<T> for Memo<T> {
80    fn id(&self) -> crate::id::Id {
81        self.getter.id
82    }
83}
84
85impl<T: PartialEq + 'static> SignalWith<T> for Memo<T>
86where
87    ReadSignal<T>: SignalWith<T>,
88{
89    fn id(&self) -> crate::id::Id {
90        self.getter.id
91    }
92
93    fn with<O>(&self, f: impl FnOnce(&T) -> O) -> O
94    where
95        T: 'static,
96    {
97        self.ensure_fresh();
98        self.getter.with(f)
99    }
100
101    fn with_untracked<O>(&self, f: impl FnOnce(&T) -> O) -> O
102    where
103        T: 'static,
104    {
105        self.ensure_fresh();
106        self.getter.with_untracked(f)
107    }
108
109    fn try_with<O>(&self, f: impl FnOnce(Option<&T>) -> O) -> O
110    where
111        T: 'static,
112    {
113        self.ensure_fresh();
114        self.getter.try_with(f)
115    }
116
117    fn try_with_untracked<O>(&self, f: impl FnOnce(Option<&T>) -> O) -> O
118    where
119        T: 'static,
120    {
121        self.ensure_fresh();
122        self.getter.try_with_untracked(f)
123    }
124}
125
126/// Create a Memo which takes the computed value of the given function, and triggers
127/// the reactive system when the computed value is different from the last computed value.
128#[deprecated(
129    since = "0.2.0",
130    note = "Use Memo::new instead; this will be removed in a future release"
131)]
132#[cfg_attr(debug_assertions, track_caller)]
133pub fn create_memo<T>(f: impl Fn(Option<&T>) -> T + 'static) -> Memo<T>
134where
135    T: PartialEq + 'static,
136{
137    Memo::new(f)
138}
139
140impl<T: PartialEq + 'static> Memo<T> {
141    #[cfg_attr(debug_assertions, track_caller)]
142    pub fn new(f: impl Fn(Option<&T>) -> T + 'static) -> Self {
143        Runtime::assert_ui_thread();
144
145        let memo_id = Id::next();
146        let state = Rc::new(MemoState::new(memo_id, f));
147
148        memo_id.set_scope();
149        let effect: Rc<dyn EffectTrait> = state.clone();
150        RUNTIME.with(|runtime| runtime.register_effect(&effect));
151
152        let initial = state.compute_initial();
153        let (getter, setter) = RwSignal::new_split(initial);
154        state.set_signal(getter, setter);
155        state.mark_clean();
156
157        Memo { getter, memo_id }
158    }
159
160    fn ensure_fresh(&self) {
161        self.with_state(|state| state.ensure_fresh());
162    }
163
164    fn with_state<O>(&self, f: impl FnOnce(&MemoState<T>) -> O) -> Option<O> {
165        RUNTIME.with(|runtime| {
166            runtime
167                .get_effect(self.memo_id)
168                .and_then(|effect| effect.as_any().downcast_ref::<MemoState<T>>().map(f))
169        })
170    }
171}
172
/// Boxed compute closure: maps the previous value, if any, to the next one.
type ComputeFn<T> = Box<dyn Fn(Option<&T>) -> T + 'static>;
174
175struct MemoState<T: PartialEq + 'static> {
176    id: Id,
177    compute: ComputeFn<T>,
178    getter: RefCell<Option<ReadSignal<T>>>,
179    setter: RefCell<Option<WriteSignal<T>>>,
180    dirty: Cell<bool>,
181    observers: RefCell<HashSet<Id>>,
182    _phantom: PhantomData<T>,
183}
184
185impl<T: PartialEq + 'static> MemoState<T> {
186    fn new(id: Id, compute: impl Fn(Option<&T>) -> T + 'static) -> Self {
187        Self {
188            id,
189            compute: Box::new(compute),
190            getter: RefCell::new(None),
191            setter: RefCell::new(None),
192            dirty: Cell::new(true),
193            observers: RefCell::new(HashSet::new()),
194            _phantom: PhantomData,
195        }
196    }
197
198    fn compute_initial(&self) -> T {
199        let effect = RUNTIME
200            .with(|runtime| runtime.get_effect(self.id))
201            .expect("memo registered before initial compute");
202
203        let prev_effect =
204            RUNTIME.with(|runtime| runtime.current_effect.borrow_mut().replace(effect));
205        let scope = Scope(self.id, PhantomData);
206        let value = scope.enter(|| (self.compute)(None));
207
208        RUNTIME.with(|runtime| *runtime.current_effect.borrow_mut() = prev_effect);
209        value
210    }
211
212    fn set_signal(&self, getter: ReadSignal<T>, setter: WriteSignal<T>) {
213        self.getter.replace(Some(getter));
214        self.setter.replace(Some(setter));
215    }
216
217    fn mark_clean(&self) {
218        self.dirty.set(false);
219    }
220
221    fn ensure_fresh(&self) {
222        if !self.dirty.get() {
223            return;
224        }
225        self.recompute();
226    }
227
228    fn recompute(&self) {
229        let getter = self
230            .getter
231            .borrow()
232            .as_ref()
233            .copied()
234            .expect("memo getter set");
235        Runtime::assert_ui_thread();
236        let effect = RUNTIME
237            .with(|runtime| runtime.get_effect(self.id))
238            .expect("memo registered");
239
240        observer_clean_up(&effect);
241
242        let prev_effect =
243            RUNTIME.with(|runtime| runtime.current_effect.borrow_mut().replace(effect));
244        let scope = Scope(self.id, PhantomData);
245        let (changed, new_value) = scope.enter(|| {
246            getter.try_with_untracked(|prev| {
247                let new_value = (self.compute)(prev);
248                let changed = match prev {
249                    Some(previous) => new_value != *previous,
250                    None => true,
251                };
252                (changed, new_value)
253            })
254        });
255        RUNTIME.with(|runtime| *runtime.current_effect.borrow_mut() = prev_effect);
256
257        if changed {
258            if let Some(setter) = self.setter.borrow().as_ref() {
259                setter.set(new_value);
260            }
261        }
262
263        self.dirty.set(false);
264    }
265}
266
267impl<T: PartialEq + 'static> Drop for MemoState<T> {
268    fn drop(&mut self) {
269        if RUNTIME
270            .try_with(|runtime| runtime.remove_effect(self.id))
271            .is_ok()
272        {
273            self.id.dispose();
274        }
275    }
276}
277
278impl<T> EffectTrait for MemoState<T>
279where
280    T: PartialEq + 'static,
281{
282    fn id(&self) -> Id {
283        self.id
284    }
285
286    fn run(&self) -> bool {
287        self.dirty.set(true);
288        self.ensure_fresh();
289        true
290    }
291
292    fn add_observer(&self, id: Id) {
293        self.observers.borrow_mut().insert(id);
294    }
295
296    fn clear_observers(&self) -> HashSet<Id> {
297        std::mem::take(&mut *self.observers.borrow_mut())
298    }
299
300    fn priority(&self) -> EffectPriority {
301        EffectPriority::High
302    }
303
304    fn as_any(&self) -> &dyn Any {
305        self
306    }
307}