1use std::{ops::Deref, rc::Rc, sync::Arc};
2
3use parking_lot::{Mutex, MutexGuard};
4
5use crate::{
6 id::Id,
7 runtime::Runtime,
8 signal::{SignalValue, TrackedRef, TrackedRefCell},
9};
10
11pub struct SyncReadRef<'a, T> {
12 _handle: Arc<Mutex<T>>,
13 pub(crate) guard: MutexGuard<'a, T>,
14}
15
16pub struct LocalReadRef<'a, T> {
17 _handle: Rc<TrackedRefCell<T>>,
18 pub(crate) guard: TrackedRef<'a, T>,
19}
20
21pub enum ReadRef<'a, T> {
22 Sync(SyncReadRef<'a, T>),
23 Local(LocalReadRef<'a, T>),
24}
25
26impl<'a, T> Deref for SyncReadRef<'a, T> {
27 type Target = T;
28 fn deref(&self) -> &Self::Target {
29 &self.guard
30 }
31}
32
33impl<'a, T> Deref for LocalReadRef<'a, T> {
34 type Target = T;
35 fn deref(&self) -> &Self::Target {
36 &self.guard
37 }
38}
39
40impl<'a, T> Deref for ReadRef<'a, T> {
41 type Target = T;
42 fn deref(&self) -> &Self::Target {
43 match self {
44 ReadRef::Sync(v) => &v.guard,
45 ReadRef::Local(v) => &v.guard,
46 }
47 }
48}
49
50impl<'a, T> SyncReadRef<'a, T> {
51 pub(crate) fn new(handle: Arc<Mutex<T>>) -> Self {
52 let guard = handle.lock();
53 let guard = unsafe { std::mem::transmute::<MutexGuard<'_, T>, MutexGuard<'a, T>>(guard) };
55 Self {
56 _handle: handle,
57 guard,
58 }
59 }
60}
61
62impl<'a, T> LocalReadRef<'a, T> {
63 pub(crate) fn new(handle: Rc<TrackedRefCell<T>>) -> Self {
64 let guard = handle.borrow();
65 let guard = unsafe { std::mem::transmute::<TrackedRef<'_, T>, TrackedRef<'a, T>>(guard) };
67 Self {
68 _handle: handle,
69 guard,
70 }
71 }
72}
73
74pub trait SignalGet<T: Clone> {
75 fn id(&self) -> Id;
77
78 fn get_untracked(&self) -> T
79 where
80 T: 'static,
81 {
82 self.try_get_untracked().unwrap()
83 }
84
85 fn get(&self) -> T
86 where
87 T: 'static,
88 {
89 self.try_get().unwrap()
90 }
91
92 #[cfg_attr(debug_assertions, track_caller)]
93 fn try_get(&self) -> Option<T>
94 where
95 T: 'static,
96 {
97 self.id().signal().map(|signal| {
98 if matches!(signal.value, SignalValue::Local(_)) {
99 Runtime::assert_ui_thread();
100 }
101 signal.get()
102 })
103 }
104
105 #[cfg_attr(debug_assertions, track_caller)]
106 fn try_get_untracked(&self) -> Option<T>
107 where
108 T: 'static,
109 {
110 self.id().signal().map(|signal| {
111 if matches!(signal.value, SignalValue::Local(_)) {
112 Runtime::assert_ui_thread();
113 }
114 signal.get_untracked()
115 })
116 }
117}
118
119pub trait SignalTrack<T> {
120 fn id(&self) -> Id;
121 #[cfg_attr(debug_assertions, track_caller)]
124 fn track(&self) {
125 let signal = self.id().signal().unwrap();
126 if matches!(signal.value, SignalValue::Local(_)) {
127 Runtime::assert_ui_thread();
128 }
129 signal.subscribe();
130 }
131
132 #[cfg_attr(debug_assertions, track_caller)]
135 fn try_track(&self) {
136 if let Some(signal) = self.id().signal() {
137 if matches!(signal.value, SignalValue::Local(_)) {
138 Runtime::assert_ui_thread();
139 }
140 signal.subscribe();
141 }
142 }
143}
144
145pub trait SignalWith<T> {
146 fn id(&self) -> Id;
148
149 #[cfg_attr(debug_assertions, track_caller)]
150 fn with<O>(&self, f: impl FnOnce(&T) -> O) -> O
151 where
152 T: 'static,
153 {
154 let signal = self.id().signal().unwrap();
155 if matches!(signal.value, SignalValue::Local(_)) {
156 Runtime::assert_ui_thread();
157 }
158 signal.with(f)
159 }
160
161 #[cfg_attr(debug_assertions, track_caller)]
162 fn with_untracked<O>(&self, f: impl FnOnce(&T) -> O) -> O
163 where
164 T: 'static,
165 {
166 let signal = self.id().signal().unwrap();
167 if matches!(signal.value, SignalValue::Local(_)) {
168 Runtime::assert_ui_thread();
169 }
170 signal.with_untracked(f)
171 }
172
173 #[cfg_attr(debug_assertions, track_caller)]
174 fn try_with<O>(&self, f: impl FnOnce(Option<&T>) -> O) -> O
175 where
176 T: 'static,
177 {
178 if let Some(signal) = self.id().signal() {
179 if matches!(signal.value, SignalValue::Local(_)) {
180 Runtime::assert_ui_thread();
181 }
182 signal.with(|v| f(Some(v)))
183 } else {
184 f(None)
185 }
186 }
187
188 fn try_with_untracked<O>(&self, f: impl FnOnce(Option<&T>) -> O) -> O
189 where
190 T: 'static,
191 {
192 if let Some(signal) = self.id().signal() {
193 if matches!(signal.value, SignalValue::Local(_)) {
194 Runtime::assert_ui_thread();
195 }
196 signal.with_untracked(|v| f(Some(v)))
197 } else {
198 f(None)
199 }
200 }
201}
202
203pub trait SignalRead<T> {
204 fn id(&self) -> Id;
206
207 fn read(&self) -> ReadRef<'_, T>
209 where
210 T: 'static,
211 {
212 self.try_read().unwrap()
213 }
214
215 fn read_untracked(&self) -> ReadRef<'_, T>
217 where
218 T: 'static,
219 {
220 self.try_read_untracked().unwrap()
221 }
222
223 fn try_read(&self) -> Option<ReadRef<'_, T>>
226 where
227 T: 'static;
228
229 fn try_read_untracked(&self) -> Option<ReadRef<'_, T>>
232 where
233 T: 'static;
234}