1use std::{
2 cell::RefCell,
3 collections::HashSet,
4 hash::{DefaultHasher, Hash, Hasher},
5 marker::PhantomData,
6 ops::{Range, RangeInclusive},
7 rc::Rc,
8};
9
10use floem_reactive::{
11 Effect, ReadSignal, RwSignal, Scope, SignalGet, SignalTrack, SignalUpdate, SignalWith,
12};
13use peniko::kurbo::Rect;
14use smallvec::SmallVec;
15use taffy::{Dimension, FlexDirection, tree::NodeId};
16use understory_virtual_list::{
17 ExtentModel, FixedExtentModel, PrefixSumExtentModel, compute_visible_strip,
18};
19
20use crate::{
21 event::listener::{EventListenerTrait, UpdatePhaseBoxTreeCommit},
22 prop_extractor,
23 style::{FlexDirectionProp, recalc::StyleReason},
24 view::{FinalizeFn, IntoView, LayoutNodeCx, View, ViewId},
25};
26
27use super::{Diff, DiffOpAdd, FxIndexSet, HashRun, apply_diff, diff};
28
/// Boxed factory that builds the view for one data item and returns it together with
/// the reactive [`Scope`] it was created in; `VirtualStack` keeps the `(ViewId, Scope)`
/// pair for every child it builds.
pub type VirtViewFn<T> = Box<dyn Fn(T) -> (Box<dyn View>, Scope)>;
30
// Style properties read by `VirtualStack` in its style pass. The flex direction decides
// which axis (rows or columns) is virtualized.
prop_extractor! {
    pub VirtualExtractor {
        pub direction: FlexDirectionProp,
    }
}
36
/// How a virtual stack determines each item's extent along its main axis.
enum VirtualItemSize<T> {
    /// Every item has the same extent, produced by the closure; it is evaluated
    /// each time the visible window is recomputed.
    Fixed(Rc<dyn Fn() -> f64>),
    /// Each item's extent is computed from the item itself.
    Fn(Rc<dyn Fn(&T) -> f64>),
    /// Uniform extent learned from layout: `None` until the first child has been
    /// measured, then `Some(extent)`.
    Assume(Option<f64>),
}

// Implemented by hand: `#[derive(Clone)]` would add a `T: Clone` bound, yet only the
// `Rc` handles (never a `T`) are duplicated here.
impl<T> Clone for VirtualItemSize<T> {
    fn clone(&self) -> Self {
        match self {
            Self::Fixed(extent) => Self::Fixed(Rc::clone(extent)),
            Self::Fn(extent_of) => Self::Fn(Rc::clone(extent_of)),
            Self::Assume(measured) => Self::Assume(*measured),
        }
    }
}
53
/// Extent model kept across runs of the data effect, so it can be updated in place
/// instead of being rebuilt on every run.
enum CachedExtentModel {
    /// Uniform item extent; used by `Fixed` and `Assume` sizing.
    Fixed(FixedExtentModel<f64>),
    /// Per-item extents; used by `Fn` sizing and rebuilt when the item count changes.
    PrefixSum(PrefixSumExtentModel<f64>),
}
59
60pub trait VirtualVector<T> {
62 fn total_len(&self) -> usize;
63
64 fn is_empty(&self) -> bool {
65 self.total_len() == 0
66 }
67
68 fn slice(&self, range: Range<usize>) -> impl Iterator<Item = T>;
69
70 fn enumerate(self) -> Enumerate<Self, T>
71 where
72 Self: Sized,
73 {
74 Enumerate {
75 inner: self,
76 phantom: PhantomData,
77 }
78 }
79}
80
/// Main-axis extent of the whole content plus the direction it is measured along.
///
/// `VirtualStack` shares it through `Rc<RefCell<_>>` with the taffy measure closure
/// installed by `virtual_stack`. That closure reports `size` as the node's height
/// (column directions) or width (row directions).
#[derive(Clone)]
struct ContentSize {
    /// Total content extent along the main axis (the strip's `content_extent`).
    size: f64,
    /// Layout direction; decides whether `size` is a height or a width.
    direction: FlexDirection,
}
87
/// A stack that builds views only for the items in its visible window. Two spacer
/// layout nodes, sized `before_size` and `after_size`, stand in for the items that are
/// not built.
pub struct VirtualStack<T>
where
    T: 'static,
{
    id: ViewId,
    /// Extracted style properties (the flex direction).
    style: VirtualExtractor,
    /// Reactive copy of the styled flex direction; the data effect tracks it.
    pub(crate) direction: RwSignal<FlexDirection>,
    /// How item extents along the main axis are determined.
    item_size: RwSignal<VirtualItemSize<T>>,
    /// Currently built children with the scope each was created in (maintained by `apply_diff`).
    children: Vec<Option<(ViewId, Scope)>>,
    /// Data index of the first built child; maps child positions back to data indices.
    first_child_idx: usize,
    /// Data indices of selected items; applied to children during the style pass.
    selected_idx: HashSet<usize>,
    /// Factory for views of items entering the window.
    view_fn: VirtViewFn<T>,
    /// Main-axis length of the spacer standing in for the items before the window.
    before_size: f64,
    /// Main-axis length of the spacer standing in for the items after the window.
    after_size: f64,
    /// Total content extent and direction, shared with the taffy measure closure.
    content_size: Rc<RefCell<ContentSize>>,
    /// Lazily created `(before, after)` spacer leaves in the taffy tree.
    space_nodes: Option<(NodeId, NodeId)>,
    /// Scroll position along the main axis, written by the event handler.
    scroll_offset: RwSignal<f64>,
    /// Main-axis size of the parent's content rect, written by the event handler.
    viewport_size: RwSignal<f64>,
}
109
110impl<T: Clone> VirtualStack<T> {
111 pub fn new<DF, I>(data_fn: DF) -> VirtualStack<T>
112 where
113 DF: Fn() -> I + 'static,
114 I: VirtualVector<T>,
115 T: Hash + Eq + IntoView + 'static,
116 {
117 Self::full(
118 data_fn,
119 |item| {
120 let mut hasher = DefaultHasher::new();
121 item.hash(&mut hasher);
122 hasher.finish()
123 },
124 |item| item.into_view(),
125 )
126 }
127
128 pub fn with_view<DF, I, V>(data_fn: DF, view_fn: impl Fn(T) -> V + 'static) -> VirtualStack<T>
129 where
130 DF: Fn() -> I + 'static,
131 I: VirtualVector<T>,
132 T: Hash + Eq + 'static,
133 V: IntoView,
134 {
135 Self::full(
136 data_fn,
137 |item| {
138 let mut hasher = DefaultHasher::new();
139 item.hash(&mut hasher);
140 hasher.finish()
141 },
142 move |item| view_fn(item).into_view(),
143 )
144 }
145
146 pub fn with_key<DF, I, K>(data_fn: DF, key_fn: impl Fn(&T) -> K + 'static) -> VirtualStack<T>
147 where
148 DF: Fn() -> I + 'static,
149 I: VirtualVector<T>,
150 T: IntoView + 'static,
151 K: Hash + Eq + 'static,
152 {
153 Self::full(data_fn, key_fn, |item| item.into_view())
154 }
155
156 pub fn full<DF, I, KF, K, VF, V>(data_fn: DF, key_fn: KF, view_fn: VF) -> VirtualStack<T>
157 where
158 DF: Fn() -> I + 'static,
159 I: VirtualVector<T>,
160 KF: Fn(&T) -> K + 'static,
161 K: Eq + Hash + 'static,
162 VF: Fn(T) -> V + 'static,
163 V: IntoView + 'static,
164 T: 'static,
165 {
166 virtual_stack(data_fn, key_fn, view_fn)
167 }
168}
169
170impl<T> VirtualStack<T> {
171 pub fn item_size_fixed(self, size: impl Fn() -> f64 + 'static) -> Self {
172 self.item_size.set(VirtualItemSize::Fixed(Rc::new(size)));
173 self
174 }
175
176 pub fn item_size_fn(self, size: impl Fn(&T) -> f64 + 'static) -> Self {
177 self.item_size.set(VirtualItemSize::Fn(Rc::new(size)));
178 self
179 }
180
181 pub fn first_layout(self) -> Self {
182 self.item_size.set(VirtualItemSize::Assume(None));
183 self
184 }
185
186 fn ensure_space_nodes(&mut self) -> (NodeId, NodeId) {
187 if self.space_nodes.is_none() {
188 let before = self
189 .id
190 .taffy()
191 .borrow_mut()
192 .new_leaf(taffy::Style::DEFAULT)
193 .unwrap();
194 let after = self
195 .id
196 .taffy()
197 .borrow_mut()
198 .new_leaf(taffy::Style::DEFAULT)
199 .unwrap();
200 self.space_nodes = Some((before, after));
201 }
202 let (before_node, after_node) = self.space_nodes.unwrap();
203 let direction = self.content_size.borrow().direction;
204 let _ = self.id.taffy().borrow_mut().set_style(
205 before_node,
206 taffy::style::Style {
207 size: match direction {
208 FlexDirection::Column | FlexDirection::ColumnReverse => taffy::prelude::Size {
209 width: Dimension::auto(),
210 height: Dimension::length(self.before_size as f32),
211 },
212 FlexDirection::Row | FlexDirection::RowReverse => taffy::prelude::Size {
213 width: Dimension::length(self.before_size as f32),
214 height: Dimension::auto(),
215 },
216 },
217 ..Default::default()
218 },
219 );
220 let _ = self.id.taffy().borrow_mut().set_style(
221 after_node,
222 taffy::style::Style {
223 size: match direction {
224 FlexDirection::Column | FlexDirection::ColumnReverse => taffy::prelude::Size {
225 width: Dimension::auto(),
226 height: Dimension::length(self.after_size as f32),
227 },
228 FlexDirection::Row | FlexDirection::RowReverse => taffy::prelude::Size {
229 width: Dimension::length(self.after_size as f32),
230 height: Dimension::auto(),
231 },
232 },
233 ..Default::default()
234 },
235 );
236 (before_node, after_node)
237 }
238}
239
/// Message posted by the data effect (via `update_state`) and consumed by the view's
/// `update`. It describes the newly computed visible window.
pub(crate) struct VirtualStackState<T> {
    /// Changes to the built children relative to the previous run's window.
    diff: Diff<T>,
    /// Data index of the first item in the window.
    first_idx: usize,
    /// Main-axis extent of the items before the window.
    before_size: f64,
    /// Main-axis extent of the items after the window.
    after_size: f64,
    /// Main-axis extent of the whole content.
    content_size: f64,
}
247
/// Creates a [`VirtualStack`] that builds views only for the items in its visible window.
///
/// * `each_fn` – returns the current items. It is called inside an `Effect`, which also
///   reads the scroll-offset, viewport-size and direction signals, so the window is
///   recomputed when any of them changes.
/// * `key_fn` – identity of an item. The keys of the window's items are diffed between
///   runs, so only entering items get new views.
/// * `view_fn` – builds the view for an item, in a child of the scope that is current
///   when this function is called.
pub fn virtual_stack<T, IF, I, KF, K, VF, V>(
    each_fn: IF,
    key_fn: KF,
    view_fn: VF,
) -> VirtualStack<T>
where
    T: 'static,
    IF: Fn() -> I + 'static,
    I: VirtualVector<T>,
    KF: Fn(&T) -> K + 'static,
    K: Eq + Hash + 'static,
    VF: Fn(T) -> V + 'static,
    V: IntoView + 'static,
{
    let id = ViewId::new();
    // Box-tree commits are where `event` reads the scroll offset and viewport size.
    id.register_listener(UpdatePhaseBoxTreeCommit::listener_key());

    let item_size: RwSignal<VirtualItemSize<T>> = RwSignal::new(VirtualItemSize::Assume(None));
    let direction = RwSignal::new(FlexDirection::Column);
    let scroll_offset = RwSignal::new(0.0_f64);
    let viewport_size = RwSignal::new(0.0_f64);

    let content_size = Rc::new(RefCell::new(ContentSize {
        size: 0.0,
        direction: FlexDirection::Column,
    }));
    let content_size_for_measure = content_size.clone();

    let taffy_node = id.taffy_node();
    // Custom measure: on the main axis the node reports the full virtual content extent
    // (unless taffy already knows it). An unknown cross axis is reported as 0.
    id.taffy()
        .borrow_mut()
        .set_node_context(
            taffy_node,
            Some(LayoutNodeCx::Custom {
                measure: Box::new(
                    move |known_dimensions, _available_space, _node_id, _style, _cx| {
                        let data = content_size_for_measure.borrow();
                        let main = data.size as f32;
                        match data.direction {
                            FlexDirection::Column | FlexDirection::ColumnReverse => taffy::Size {
                                width: known_dimensions.width.unwrap_or(0.0),
                                height: known_dimensions.height.unwrap_or(main),
                            },
                            FlexDirection::Row | FlexDirection::RowReverse => taffy::Size {
                                width: known_dimensions.width.unwrap_or(main),
                                height: known_dimensions.height.unwrap_or(0.0),
                            },
                        }
                    },
                ),
                finalize: None::<Box<FinalizeFn>>,
            }),
        )
        .unwrap();

    // Extent model reused across effect runs (see `CachedExtentModel`).
    let cached_model: Rc<RefCell<Option<CachedExtentModel>>> = Rc::new(RefCell::new(None));

    // Data effect. Its return value is the previous run's `(before_size, content_size,
    // keys of the window)`, which is fed back as `prev` on the next run.
    Effect::new(move |prev: Option<(f64, f64, HashRun<FxIndexSet<K>>)>| {
        let items_vector = each_fn();
        let total_len = items_vector.total_len();
        let scroll = scroll_offset.get();
        let viewport = viewport_size.get();
        direction.track();

        let mut items = Vec::new();
        let mut before_size = 0.0_f64;
        let mut after_size = 0.0_f64;
        let mut content_sz = 0.0_f64;
        let mut start = 0usize;

        let mut cached = cached_model.borrow_mut();

        // Each branch updates (or creates) the matching cached model, then computes the
        // visible strip from it. The two trailing `0.0` arguments to
        // `compute_visible_strip` are presumably overscan/padding amounts — TODO confirm.
        item_size.with(|s| match s {
            VirtualItemSize::Fixed(size_fn) => {
                let extent = size_fn();
                let model = match cached.as_mut() {
                    Some(CachedExtentModel::Fixed(m)) => {
                        m.set_len(total_len);
                        m.set_extent(extent);
                        m
                    }
                    _ => {
                        *cached = Some(CachedExtentModel::Fixed(FixedExtentModel::new(
                            total_len, extent,
                        )));
                        match cached.as_mut().unwrap() {
                            CachedExtentModel::Fixed(m) => m,
                            _ => unreachable!(),
                        }
                    }
                };
                let strip = compute_visible_strip(model, scroll, viewport, 0.0, 0.0);
                start = strip.start;
                before_size = strip.before_extent;
                after_size = strip.after_extent;
                content_sz = strip.content_extent;
                for item in items_vector.slice(strip.start..strip.end) {
                    items.push(item);
                }
            }
            VirtualItemSize::Fn(size_fn) => {
                // NOTE(review): the prefix-sum model is rebuilt only when the item count
                // changes. Replacing items (or changing their sizes) without changing
                // the length reuses stale extents.
                let model = match cached.as_mut() {
                    Some(CachedExtentModel::PrefixSum(m)) if m.len() == total_len => m,
                    _ => {
                        let mut m = PrefixSumExtentModel::<f64>::new();
                        // Building the table needs every item, not just the window.
                        let all: Vec<T> = items_vector.slice(0..total_len).collect();
                        m.rebuild(all, &|item: &T| size_fn(item));
                        *cached = Some(CachedExtentModel::PrefixSum(m));
                        match cached.as_mut().unwrap() {
                            CachedExtentModel::PrefixSum(m) => m,
                            _ => unreachable!(),
                        }
                    }
                };
                let strip = compute_visible_strip(model, scroll, viewport, 0.0, 0.0);
                start = strip.start;
                before_size = strip.before_extent;
                after_size = strip.after_extent;
                content_sz = strip.content_extent;
                for item in items_vector.slice(strip.start..strip.end) {
                    items.push(item);
                }
            }
            VirtualItemSize::Assume(assumed) => {
                // Until the first item has been measured, a placeholder extent of 10.0
                // is used for the size estimates.
                let extent = assumed.unwrap_or(10.0);
                let model = match cached.as_mut() {
                    Some(CachedExtentModel::Fixed(m)) => {
                        m.set_len(total_len);
                        m.set_extent(extent);
                        m
                    }
                    _ => {
                        *cached = Some(CachedExtentModel::Fixed(FixedExtentModel::new(
                            total_len, extent,
                        )));
                        match cached.as_mut().unwrap() {
                            CachedExtentModel::Fixed(m) => m,
                            _ => unreachable!(),
                        }
                    }
                };
                let strip = compute_visible_strip(model, scroll, viewport, 0.0, 0.0);
                start = strip.start;
                before_size = strip.before_extent;
                after_size = strip.after_extent;
                content_sz = strip.content_extent;
                if assumed.is_none() {
                    // Build only the first item so `event` can measure it and switch
                    // to `Assume(Some(extent))`.
                    if total_len > 0
                        && let Some(item) = items_vector.slice(0..1).next()
                    {
                        items.push(item);
                    }
                } else {
                    for item in items_vector.slice(strip.start..strip.end) {
                        items.push(item);
                    }
                }
            }
        });

        let hashed_items = items.iter().map(&key_fn).collect::<FxIndexSet<_>>();
        let (prev_before, prev_content, diff) =
            if let Some((prev_before, prev_content, HashRun(prev_hash))) = prev {
                // Later runs: diff the window's keys and attach the items that are
                // newly added (by their position in `items`).
                let mut diff = diff(&prev_hash, &hashed_items);
                let mut items = items
                    .into_iter()
                    .map(Some)
                    .collect::<SmallVec<[Option<_>; 128]>>();
                for added in &mut diff.added {
                    added.view = Some(items[added.at].take().unwrap());
                }
                (prev_before, prev_content, diff)
            } else {
                // First run: every item in the window is an addition.
                let mut diff = Diff::default();
                for (i, item) in items.into_iter().enumerate() {
                    diff.added.push(DiffOpAdd {
                        at: i,
                        view: Some(item),
                    });
                }
                (0.0, 0.0, diff)
            };

        // Only notify the view when the children or the spacer/content sizes changed.
        if !diff.is_empty() || prev_before != before_size || prev_content != content_sz {
            id.update_state(VirtualStackState {
                diff,
                first_idx: start,
                before_size,
                after_size,
                content_size: content_sz,
            });
        }
        (before_size, content_sz, HashRun(hashed_items))
    });

    let view_fn = Box::new(Scope::current().enter_child(move |e| view_fn(e).into_any()));

    VirtualStack {
        id,
        style: Default::default(),
        direction,
        item_size,
        children: Vec::new(),
        selected_idx: HashSet::with_capacity(1),
        first_child_idx: 0,
        view_fn,
        before_size: 0.0,
        after_size: 0.0,
        content_size,
        space_nodes: None,
        scroll_offset,
        viewport_size,
    }
}
481
impl<T> View for VirtualStack<T> {
    fn id(&self) -> ViewId {
        self.id
    }

    fn debug_name(&self) -> std::borrow::Cow<'static, str> {
        "VirtualStack".into()
    }

    /// Handles two kinds of state messages:
    /// * a `VirtualStackState<T>` from the data effect, which applies the child diff
    ///   and resizes the spacers;
    /// * a `usize` item index to select and scroll to.
    fn update(&mut self, cx: &mut crate::context::UpdateCx, state: Box<dyn std::any::Any>) {
        match state.downcast::<VirtualStackState<T>>() {
            Ok(state) => {
                if self.before_size == state.before_size
                    && self.content_size.borrow().size == state.content_size
                    && state.diff.is_empty()
                {
                    return;
                }
                self.before_size = state.before_size;
                self.after_size = state.after_size;
                self.content_size.borrow_mut().size = state.content_size;
                self.first_child_idx = state.first_idx;
                apply_diff(
                    self.id(),
                    cx.window_state,
                    state.diff,
                    &mut self.children,
                    &self.view_fn,
                );
                let (before, after) = self.ensure_space_nodes();
                let taffy = self.id.taffy();
                let mut taffy = taffy.borrow_mut();
                let this_node = self.id.taffy_node();
                // Spacers go first and last among this node's taffy children.
                // NOTE(review): this runs on every applied update. It assumes
                // `apply_diff` resets the taffy child list; otherwise the spacers would
                // be inserted repeatedly. Confirm.
                taffy.insert_child_at_index(this_node, 0, before).unwrap();
                taffy.add_child(this_node, after).unwrap();
                self.id.request_style(StyleReason::style_pass());
                self.id.request_layout();
            }
            Err(state) => {
                if let Ok(idx) = state.downcast::<usize>() {
                    self.id.request_style(StyleReason::style_pass());
                    self.scroll_to_idx(*idx);
                    self.selected_idx.clear();
                    self.selected_idx.insert(*idx);
                }
            }
        }
    }

    /// Mirrors the styled direction into the `direction` signal and the shared
    /// `ContentSize`. It then marks each child selected or not, based on its data index.
    fn style_pass(&mut self, cx: &mut crate::context::StyleCx<'_>) {
        if self.style.read(cx) {
            cx.window_state.request_paint(self.id);
            let dir = self.style.direction();
            self.direction.set(dir);
            self.content_size.borrow_mut().direction = dir;
        }
        // Child position + `first_child_idx` gives the item's index in the data.
        for (child_id_index, child) in self.id.children().into_iter().enumerate() {
            if self
                .selected_idx
                .contains(&(child_id_index + self.first_child_idx))
            {
                child.parent_set_selected();
            } else {
                child.parent_clear_selected();
            }
        }
    }

    /// On each box-tree commit, updates the scroll-offset and viewport-size signals the
    /// data effect reads. In `Assume(None)` mode it also records the first child's
    /// measured main-axis size.
    fn event(&mut self, cx: &mut crate::context::EventCx) -> crate::event::EventPropagation {
        if UpdatePhaseBoxTreeCommit::extract(&cx.event).is_some() {
            // Presumably the scroll translation currently applied to this view — TODO confirm.
            let translation = self.id.get_scroll_cx();
            let dir = self.direction.get_untracked();
            let new_scroll = match dir {
                FlexDirection::Row | FlexDirection::RowReverse => translation.x,
                FlexDirection::Column | FlexDirection::ColumnReverse => translation.y,
            }
            .max(0.0);

            // The parent's content rect is used as the viewport.
            let parent_rect = self
                .id
                .parent()
                .map(|id| id.get_content_rect_local())
                .unwrap_or_default();
            let new_viewport = match dir {
                FlexDirection::Row | FlexDirection::RowReverse => parent_rect.width(),
                FlexDirection::Column | FlexDirection::ColumnReverse => parent_rect.height(),
            };

            // Write only on change, so the data effect reading these signals is not
            // re-run needlessly.
            if new_scroll != self.scroll_offset.get_untracked() {
                self.scroll_offset.set(new_scroll);
            }
            if new_viewport != self.viewport_size.get_untracked() {
                self.viewport_size.set(new_viewport);
            }

            let is_unassumed = self
                .item_size
                .with_untracked(|s| matches!(s, VirtualItemSize::Assume(None)));
            if is_unassumed && let Some(Some((first_child, _))) = self.children.first() {
                let dir = self.direction.get_untracked();
                let rect = first_child.get_layout_rect_local();
                let size = match dir {
                    FlexDirection::Row | FlexDirection::RowReverse => rect.width(),
                    FlexDirection::Column | FlexDirection::ColumnReverse => rect.height(),
                };
                // A zero size means the child has not been laid out yet; try again later.
                if size > 0.0 {
                    self.item_size.set(VirtualItemSize::Assume(Some(size)));
                }
            }
        }
        crate::event::EventPropagation::Continue
    }
}
598
599impl<T> VirtualStack<T> {
600 pub fn scroll_to_idx(&self, index: usize) {
601 let (offset, size) = self.calculate_offset(index);
602 let rect = match self.direction.get_untracked() {
603 FlexDirection::Column | FlexDirection::ColumnReverse => {
604 Rect::from_origin_size((0.0, offset), (0.0, size))
605 }
606 FlexDirection::Row | FlexDirection::RowReverse => {
607 Rect::from_origin_size((offset, 0.0), (size, 0.0))
608 }
609 };
610 self.id.scroll_to(Some(rect));
611 }
612
613 fn calculate_offset(&self, index: usize) -> (f64, f64) {
614 self.item_size.with(|size| match size {
615 VirtualItemSize::Fixed(size_fn) => {
616 let s = size_fn();
617 (s * index as f64, s)
618 }
619 VirtualItemSize::Fn(_) => (0.0, 0.0),
620 VirtualItemSize::Assume(Some(s)) => (s * index as f64, *s),
621 VirtualItemSize::Assume(None) => (0.0, 0.0),
622 })
623 }
624}
625
626impl<T: Clone> VirtualVector<T> for imbl::Vector<T> {
629 fn total_len(&self) -> usize {
630 self.len()
631 }
632
633 fn slice(&self, range: Range<usize>) -> impl Iterator<Item = T> {
634 imbl::Vector::slice(&mut self.clone(), range).into_iter()
635 }
636}
637
638impl<T> VirtualVector<T> for Range<T>
639where
640 T: Copy + std::ops::Sub<Output = T> + std::ops::Add<Output = T> + PartialOrd + From<usize>,
641 usize: From<T>,
642 Range<T>: Iterator<Item = T>,
643{
644 fn total_len(&self) -> usize {
645 (self.end - self.start).into()
646 }
647
648 fn slice(&self, range: Range<usize>) -> impl Iterator<Item = T> {
649 let start = self.start + T::from(range.start);
650 let end = self.start + T::from(range.end);
651 start..end
652 }
653}
654
655impl<T> VirtualVector<T> for RangeInclusive<T>
656where
657 T: Copy + std::ops::Sub<Output = T> + std::ops::Add<Output = T> + PartialOrd + From<usize>,
658 usize: From<T>,
659 Range<T>: Iterator<Item = T>,
660{
661 fn total_len(&self) -> usize {
662 let diff = *self.end() - *self.start();
663 Into::<usize>::into(diff) + 1
664 }
665
666 fn slice(&self, range: Range<usize>) -> impl Iterator<Item = T> {
667 let start = *self.start() + T::from(range.start);
668 let end = *self.start() + T::from(range.end);
669 start..end
670 }
671}
672
673impl<T> VirtualVector<T> for RwSignal<Vec<T>>
674where
675 T: Clone + 'static,
676{
677 fn total_len(&self) -> usize {
678 self.with(|v| v.len())
679 }
680
681 #[allow(clippy::unnecessary_to_owned)]
682 fn slice(&self, range: Range<usize>) -> impl Iterator<Item = T> {
683 self.with(|v| v[range].to_vec().into_iter())
684 }
685}
686
687impl<T> VirtualVector<T> for ReadSignal<Vec<T>>
688where
689 T: Clone + 'static,
690{
691 fn total_len(&self) -> usize {
692 self.with(|v| v.len())
693 }
694
695 #[allow(clippy::unnecessary_to_owned)]
696 fn slice(&self, range: Range<usize>) -> impl Iterator<Item = T> {
697 self.with(|v| v[range].to_vec().into_iter())
698 }
699}
700
701pub struct Enumerate<V: VirtualVector<T>, T> {
702 inner: V,
703 phantom: PhantomData<T>,
704}
705
706impl<V: VirtualVector<T>, T> VirtualVector<(usize, T)> for Enumerate<V, T> {
707 fn total_len(&self) -> usize {
708 self.inner.total_len()
709 }
710
711 fn slice(&self, range: Range<usize>) -> impl Iterator<Item = (usize, T)> {
712 let start = range.start;
713 self.inner
714 .slice(range)
715 .enumerate()
716 .map(move |(i, e)| (i + start, e))
717 }
718}