1use std::{
2 any::Any,
3 cell::{Cell, RefCell},
4 collections::HashSet,
5 marker::PhantomData,
6 rc::Rc,
7};
8
9use crate::{
10 effect::{observer_clean_up, EffectPriority, EffectTrait},
11 id::Id,
12 read::SignalTrack,
13 runtime::{Runtime, RUNTIME},
14 scope::Scope,
15 signal::{ReadSignal, RwSignal, WriteSignal},
16 write::SignalUpdate,
17 SignalGet, SignalWith,
18};
19
20pub struct Memo<T: PartialEq + 'static> {
26 getter: ReadSignal<T>,
27 memo_id: Id,
28}
29
30impl<T: PartialEq + 'static> Copy for Memo<T> {}
31
32impl<T: PartialEq + 'static> Clone for Memo<T> {
33 fn clone(&self) -> Self {
34 *self
35 }
36}
37
38impl<T: Clone + PartialEq + 'static> SignalGet<T> for Memo<T>
39where
40 ReadSignal<T>: SignalGet<T>,
41{
42 fn id(&self) -> crate::id::Id {
43 self.getter.id
44 }
45
46 fn get_untracked(&self) -> T
47 where
48 T: 'static,
49 {
50 self.ensure_fresh();
51 self.getter.get_untracked()
52 }
53
54 fn get(&self) -> T
55 where
56 T: 'static,
57 {
58 self.ensure_fresh();
59 self.getter.get()
60 }
61
62 fn try_get(&self) -> Option<T>
63 where
64 T: 'static,
65 {
66 self.ensure_fresh();
67 self.getter.try_get()
68 }
69
70 fn try_get_untracked(&self) -> Option<T>
71 where
72 T: 'static,
73 {
74 self.ensure_fresh();
75 self.getter.try_get_untracked()
76 }
77}
78
79impl<T: PartialEq + 'static> SignalTrack<T> for Memo<T> {
80 fn id(&self) -> crate::id::Id {
81 self.getter.id
82 }
83}
84
85impl<T: PartialEq + 'static> SignalWith<T> for Memo<T>
86where
87 ReadSignal<T>: SignalWith<T>,
88{
89 fn id(&self) -> crate::id::Id {
90 self.getter.id
91 }
92
93 fn with<O>(&self, f: impl FnOnce(&T) -> O) -> O
94 where
95 T: 'static,
96 {
97 self.ensure_fresh();
98 self.getter.with(f)
99 }
100
101 fn with_untracked<O>(&self, f: impl FnOnce(&T) -> O) -> O
102 where
103 T: 'static,
104 {
105 self.ensure_fresh();
106 self.getter.with_untracked(f)
107 }
108
109 fn try_with<O>(&self, f: impl FnOnce(Option<&T>) -> O) -> O
110 where
111 T: 'static,
112 {
113 self.ensure_fresh();
114 self.getter.try_with(f)
115 }
116
117 fn try_with_untracked<O>(&self, f: impl FnOnce(Option<&T>) -> O) -> O
118 where
119 T: 'static,
120 {
121 self.ensure_fresh();
122 self.getter.try_with_untracked(f)
123 }
124}
125
126#[deprecated(
129 since = "0.2.0",
130 note = "Use Memo::new instead; this will be removed in a future release"
131)]
132#[cfg_attr(debug_assertions, track_caller)]
133pub fn create_memo<T>(f: impl Fn(Option<&T>) -> T + 'static) -> Memo<T>
134where
135 T: PartialEq + 'static,
136{
137 Memo::new(f)
138}
139
140impl<T: PartialEq + 'static> Memo<T> {
141 #[cfg_attr(debug_assertions, track_caller)]
142 pub fn new(f: impl Fn(Option<&T>) -> T + 'static) -> Self {
143 Runtime::assert_ui_thread();
144
145 let memo_id = Id::next();
146 let state = Rc::new(MemoState::new(memo_id, f));
147
148 memo_id.set_scope();
149 let effect: Rc<dyn EffectTrait> = state.clone();
150 RUNTIME.with(|runtime| runtime.register_effect(&effect));
151
152 let initial = state.compute_initial();
153 let (getter, setter) = RwSignal::new_split(initial);
154 state.set_signal(getter, setter);
155 state.mark_clean();
156
157 Memo { getter, memo_id }
158 }
159
160 fn ensure_fresh(&self) {
161 self.with_state(|state| state.ensure_fresh());
162 }
163
164 fn with_state<O>(&self, f: impl FnOnce(&MemoState<T>) -> O) -> Option<O> {
165 RUNTIME.with(|runtime| {
166 runtime
167 .get_effect(self.memo_id)
168 .and_then(|effect| effect.as_any().downcast_ref::<MemoState<T>>().map(f))
169 })
170 }
171}
172
/// Boxed memo computation: maps the previous value (if any) to the new value.
type ComputeFn<T> = Box<dyn for<'a> Fn(Option<&'a T>) -> T>;
174
175struct MemoState<T: PartialEq + 'static> {
176 id: Id,
177 compute: ComputeFn<T>,
178 getter: RefCell<Option<ReadSignal<T>>>,
179 setter: RefCell<Option<WriteSignal<T>>>,
180 dirty: Cell<bool>,
181 observers: RefCell<HashSet<Id>>,
182 _phantom: PhantomData<T>,
183}
184
185impl<T: PartialEq + 'static> MemoState<T> {
186 fn new(id: Id, compute: impl Fn(Option<&T>) -> T + 'static) -> Self {
187 Self {
188 id,
189 compute: Box::new(compute),
190 getter: RefCell::new(None),
191 setter: RefCell::new(None),
192 dirty: Cell::new(true),
193 observers: RefCell::new(HashSet::new()),
194 _phantom: PhantomData,
195 }
196 }
197
198 fn compute_initial(&self) -> T {
199 let effect = RUNTIME
200 .with(|runtime| runtime.get_effect(self.id))
201 .expect("memo registered before initial compute");
202
203 let prev_effect =
204 RUNTIME.with(|runtime| runtime.current_effect.borrow_mut().replace(effect));
205 let scope = Scope(self.id, PhantomData);
206 let value = scope.enter(|| (self.compute)(None));
207
208 RUNTIME.with(|runtime| *runtime.current_effect.borrow_mut() = prev_effect);
209 value
210 }
211
212 fn set_signal(&self, getter: ReadSignal<T>, setter: WriteSignal<T>) {
213 self.getter.replace(Some(getter));
214 self.setter.replace(Some(setter));
215 }
216
217 fn mark_clean(&self) {
218 self.dirty.set(false);
219 }
220
221 fn ensure_fresh(&self) {
222 if !self.dirty.get() {
223 return;
224 }
225 self.recompute();
226 }
227
228 fn recompute(&self) {
229 let getter = self
230 .getter
231 .borrow()
232 .as_ref()
233 .copied()
234 .expect("memo getter set");
235 Runtime::assert_ui_thread();
236 let effect = RUNTIME
237 .with(|runtime| runtime.get_effect(self.id))
238 .expect("memo registered");
239
240 observer_clean_up(&effect);
241
242 let prev_effect =
243 RUNTIME.with(|runtime| runtime.current_effect.borrow_mut().replace(effect));
244 let scope = Scope(self.id, PhantomData);
245 let (changed, new_value) = scope.enter(|| {
246 getter.try_with_untracked(|prev| {
247 let new_value = (self.compute)(prev);
248 let changed = match prev {
249 Some(previous) => new_value != *previous,
250 None => true,
251 };
252 (changed, new_value)
253 })
254 });
255 RUNTIME.with(|runtime| *runtime.current_effect.borrow_mut() = prev_effect);
256
257 if changed {
258 if let Some(setter) = self.setter.borrow().as_ref() {
259 setter.set(new_value);
260 }
261 }
262
263 self.dirty.set(false);
264 }
265}
266
267impl<T: PartialEq + 'static> Drop for MemoState<T> {
268 fn drop(&mut self) {
269 if RUNTIME
270 .try_with(|runtime| runtime.remove_effect(self.id))
271 .is_ok()
272 {
273 self.id.dispose();
274 }
275 }
276}
277
278impl<T> EffectTrait for MemoState<T>
279where
280 T: PartialEq + 'static,
281{
282 fn id(&self) -> Id {
283 self.id
284 }
285
286 fn run(&self) -> bool {
287 self.dirty.set(true);
288 self.ensure_fresh();
289 true
290 }
291
292 fn add_observer(&self, id: Id) {
293 self.observers.borrow_mut().insert(id);
294 }
295
296 fn clear_observers(&self) -> HashSet<Id> {
297 std::mem::take(&mut *self.observers.borrow_mut())
298 }
299
300 fn priority(&self) -> EffectPriority {
301 EffectPriority::High
302 }
303
304 fn as_any(&self) -> &dyn Any {
305 self
306 }
307}