// floem/views/virtual_stack.rs

1use std::{
2    collections::HashSet,
3    hash::{DefaultHasher, Hash, Hasher},
4    marker::PhantomData,
5    ops::{Range, RangeInclusive},
6    rc::Rc,
7};
8
9use floem_reactive::{
10    Effect, ReadSignal, RwSignal, Scope, SignalGet, SignalTrack, SignalUpdate, SignalWith,
11    WriteSignal,
12};
13use peniko::kurbo::{Rect, Size};
14use smallvec::SmallVec;
15use taffy::{FlexDirection, style::Dimension, tree::NodeId};
16
17use crate::{
18    context::ComputeLayoutCx,
19    id::ViewId,
20    prop_extractor,
21    style::{FlexDirectionProp, Style},
22    view::{self, IntoView, View},
23};
24
25use super::{Diff, DiffOpAdd, FxIndexSet, HashRun, apply_diff, diff};
26
27pub type VirtViewFn<T> = Box<dyn Fn(T) -> (Box<dyn View>, Scope)>;
28
29prop_extractor! {
30    pub VirtualExtractor {
31        pub direction: FlexDirectionProp,
32    }
33}
34
/// How [`VirtualStack`] determines the main-axis size of its items.
enum VirtualItemSize<T> {
    /// Each item's size is computed from the item itself.
    Fn(Rc<dyn Fn(&T) -> f64>),
    /// Every item has the size returned by the function.
    Fixed(Rc<dyn Fn() -> f64>),
    /// This will try to calculate the size of the items using the computed layout.
    /// Holds `None` until the first rendered item has been measured.
    Assume(Option<f64>),
}

// Implemented by hand because `#[derive(Clone)]` would add an unnecessary `T: Clone` bound.
impl<T> Clone for VirtualItemSize<T> {
    fn clone(&self) -> Self {
        // The closures live behind `Rc`, so cloning only bumps reference counts.
        match self {
            Self::Fn(size_of) => Self::Fn(Rc::clone(size_of)),
            Self::Fixed(size) => Self::Fixed(Rc::clone(size)),
            Self::Assume(measured) => Self::Assume(*measured),
        }
    }
}
50
51/// A trait that can be implemented on a type so that the type can be used in a [`virtual_stack`] or [`virtual_list`](super::virtual_list()).
52pub trait VirtualVector<T> {
53    fn total_len(&self) -> usize;
54
55    fn is_empty(&self) -> bool {
56        self.total_len() == 0
57    }
58
59    fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T>;
60
61    fn enumerate(self) -> Enumerate<Self, T>
62    where
63        Self: Sized,
64    {
65        Enumerate {
66            inner: self,
67            phantom: PhantomData,
68        }
69    }
70}
71
72/// A virtual stack that is like a [`dyn_stack`](super::dyn_stack()) but also lazily loads items for performance. See [`virtual_stack`].
73pub struct VirtualStack<T>
74where
75    T: 'static,
76{
77    id: ViewId,
78    first_content_id: Option<ViewId>,
79    style: VirtualExtractor,
80    pub(crate) direction: RwSignal<FlexDirection>,
81    item_size: RwSignal<VirtualItemSize<T>>,
82    children: Vec<Option<(ViewId, Scope)>>,
83    /// the index out of all of the items that is the first in the virtualized set. This is used to map an index to a [`ViewId`].
84    first_child_idx: usize,
85    selected_idx: HashSet<usize>,
86    viewport: Rect,
87    set_viewport: WriteSignal<Rect>,
88    view_fn: VirtViewFn<T>,
89    before_size: f64,
90    content_size: f64,
91    before_node: Option<NodeId>,
92}
93impl<T: std::clone::Clone> VirtualStack<T> {
94    // For types that implement all constraints
95    pub fn new<DF, I>(data_fn: DF) -> VirtualStack<T>
96    where
97        DF: Fn() -> I + 'static,
98        I: VirtualVector<T>,
99        T: Hash + Eq + IntoView + 'static,
100    {
101        Self::full(
102            data_fn,
103            |item| {
104                let mut hasher = DefaultHasher::new();
105                item.hash(&mut hasher);
106                hasher.finish()
107            },
108            |item| item.into_view(),
109        )
110    }
111
112    // For types that are hashable but need custom view
113    pub fn with_view<DF, I, V>(data_fn: DF, view_fn: impl Fn(T) -> V + 'static) -> VirtualStack<T>
114    where
115        DF: Fn() -> I + 'static,
116        I: VirtualVector<T>,
117        T: Hash + Eq + 'static,
118        V: IntoView,
119    {
120        Self::full(
121            data_fn,
122            |item| {
123                let mut hasher = DefaultHasher::new();
124                item.hash(&mut hasher);
125                hasher.finish()
126            },
127            move |item| view_fn(item).into_view(),
128        )
129    }
130
131    // For types that implement IntoView but need custom keys
132    pub fn with_key<DF, I, K>(data_fn: DF, key_fn: impl Fn(&T) -> K + 'static) -> VirtualStack<T>
133    where
134        DF: Fn() -> I + 'static,
135        I: VirtualVector<T>,
136        T: IntoView + 'static,
137        K: Hash + Eq + 'static,
138    {
139        Self::full(data_fn, key_fn, |item| item.into_view())
140    }
141
142    pub fn full<DF, I, KF, K, VF, V>(data_fn: DF, key_fn: KF, view_fn: VF) -> VirtualStack<T>
143    where
144        DF: Fn() -> I + 'static,
145        I: VirtualVector<T>,
146        KF: Fn(&T) -> K + 'static,
147        K: Eq + Hash + 'static,
148        VF: Fn(T) -> V + 'static,
149        V: IntoView + 'static,
150        T: 'static,
151    {
152        virtual_stack(data_fn, key_fn, view_fn)
153    }
154}
155
156impl<T> VirtualStack<T> {
157    pub fn item_size_fixed(self, size: impl Fn() -> f64 + 'static) -> Self {
158        self.item_size.set(VirtualItemSize::Fixed(Rc::new(size)));
159        self
160    }
161
162    pub fn item_size_fn(self, size: impl Fn(&T) -> f64 + 'static) -> Self {
163        self.item_size.set(VirtualItemSize::Fn(Rc::new(size)));
164        self
165    }
166}
167
168pub(crate) struct VirtualStackState<T> {
169    diff: Diff<T>,
170    first_idx: usize,
171    before_size: f64,
172    content_size: f64,
173}
174
175/// A View that is like a [`dyn_stack`](super::dyn_stack()) but also lazily loads the items as they appear in a [scroll view](super::scroll())
176///
177/// This virtualization/lazy loading is done for performance and allows for lists of millions of items to be used with very high performance.
178///
179/// By default, this view tries to calculate and assume the size of the items in the list by calculating the size of the first item that is loaded.
180/// If all of your items are not of a consistent size in the relevant axis (ie a consistent width when flex_row or a consistent height when in flex_col) you will need to specify the size of the items using [`item_size_fixed`](VirtualStack::item_size_fixed) or [`item_size_fn`](VirtualStack::item_size_fn).
181///
182/// ## Example
183/// ```
184/// use floem::prelude::*;
185///
186/// VirtualStack::new(move || 1..=1000000)
187///     .style(|s| {
188///         s.flex_col().class(LabelClass, |s| {
189///             // give each of the numbers some vertical padding and make them take up the full width of the stack
190///             s.padding_vert(2.5).width_full().justify_center()
191///         })
192///     })
193///     .scroll()
194///     .style(|s| s.size(200., 500.).border(1.0))
195///     .container()
196///     .style(|s| {
197///         s.size_full()
198///             .items_center()
199///             .justify_center()
200/// });
201/// ```
202pub fn virtual_stack<T, IF, I, KF, K, VF, V>(
203    each_fn: IF,
204    key_fn: KF,
205    view_fn: VF,
206) -> VirtualStack<T>
207where
208    T: 'static,
209    IF: Fn() -> I + 'static,
210    I: VirtualVector<T>,
211    KF: Fn(&T) -> K + 'static,
212    K: Eq + Hash + 'static,
213    VF: Fn(T) -> V + 'static,
214    V: IntoView + 'static,
215{
216    let id = ViewId::new();
217
218    let (viewport, set_viewport) = RwSignal::new_split(Rect::ZERO);
219
220    let item_size = RwSignal::new(VirtualItemSize::Assume(None));
221
222    let direction = RwSignal::new(FlexDirection::Row);
223    Effect::new(move |_| {
224        direction.track();
225        id.request_style();
226    });
227
228    Effect::new(move |prev| {
229        let mut items_vector = each_fn();
230        let viewport = viewport.get();
231        let min = match direction.get() {
232            FlexDirection::Column | FlexDirection::ColumnReverse => viewport.y0,
233            FlexDirection::Row | FlexDirection::RowReverse => viewport.x0,
234        };
235        let max = match direction.get() {
236            FlexDirection::Column | FlexDirection::ColumnReverse => viewport.height() + viewport.y0,
237            FlexDirection::Row | FlexDirection::RowReverse => viewport.width() + viewport.x0,
238        };
239        let mut items = Vec::new();
240
241        let mut before_size = 0.0;
242        let mut content_size = 0.0;
243        let mut start = 0;
244        item_size.with(|s| match s {
245            VirtualItemSize::Fixed(item_size) => {
246                let item_size = item_size();
247                let total_len = items_vector.total_len();
248                start = if item_size > 0.0 {
249                    (min / item_size).floor() as usize
250                } else {
251                    0
252                };
253                let end = if item_size > 0.0 {
254                    ((max / item_size).ceil() as usize).min(total_len)
255                } else {
256                    // TODO: Log an error
257                    (start + 1).min(total_len)
258                };
259                before_size = item_size * (start.min(total_len)) as f64;
260
261                for item in items_vector.slice(start..end) {
262                    items.push(item);
263                }
264
265                content_size = item_size * total_len as f64;
266            }
267            VirtualItemSize::Fn(size_fn) => {
268                let mut main_axis = 0.0;
269                let total_len = items_vector.total_len();
270                for (idx, item) in items_vector.slice(0..total_len).enumerate() {
271                    let item_size = size_fn(&item);
272                    content_size += item_size;
273                    if main_axis + item_size < min {
274                        main_axis += item_size;
275                        before_size += item_size;
276                        start = idx;
277                        continue;
278                    }
279
280                    if main_axis <= max {
281                        main_axis += item_size;
282                        items.push(item);
283                    }
284                }
285            }
286            VirtualItemSize::Assume(None) => {
287                // For the initial run with Assume(None), we need to render at least one item
288                let total_len = items_vector.total_len();
289                if total_len > 0 {
290                    // Add just the first item so we can measure it
291                    items.push(items_vector.slice(0..1).next().unwrap());
292
293                    // Set minimal sizes for the first render
294                    before_size = 0.0;
295                    content_size = total_len as f64 * 10.0; // Temporary content size to ensure rendering
296                }
297            }
298            VirtualItemSize::Assume(Some(item_size)) => {
299                // Once we have the assumed size, behave like Fixed size
300                let total_len = items_vector.total_len();
301                start = if *item_size > 0.0 {
302                    (min / item_size).floor() as usize
303                } else {
304                    0
305                };
306                let end = if *item_size > 0.0 {
307                    ((max / item_size).ceil() as usize).min(total_len)
308                } else {
309                    // TODO: Log an error
310                    (start + 1).min(total_len)
311                };
312                before_size = item_size * (start.min(total_len)) as f64;
313
314                for item in items_vector.slice(start..end) {
315                    items.push(item);
316                }
317                content_size = item_size * total_len as f64;
318            }
319        });
320
321        let hashed_items = items.iter().map(&key_fn).collect::<FxIndexSet<_>>();
322        let (prev_before_size, prev_content_size, diff) =
323            if let Some((prev_before_size, prev_content_size, HashRun(prev_hash_run))) = prev {
324                let mut diff = diff(&prev_hash_run, &hashed_items);
325                let mut items = items
326                    .into_iter()
327                    .map(|i| Some(i))
328                    .collect::<SmallVec<[Option<_>; 128]>>();
329                for added in &mut diff.added {
330                    added.view = Some(items[added.at].take().unwrap());
331                }
332                (prev_before_size, prev_content_size, diff)
333            } else {
334                let mut diff = Diff::default();
335                for (i, item) in items.into_iter().enumerate() {
336                    diff.added.push(DiffOpAdd {
337                        at: i,
338                        view: Some(item),
339                    });
340                }
341                (0.0, 0.0, diff)
342            };
343
344        if !diff.is_empty() || prev_before_size != before_size || prev_content_size != content_size
345        {
346            id.update_state(VirtualStackState {
347                diff,
348                first_idx: start,
349                before_size,
350                content_size,
351            });
352        }
353        (before_size, content_size, HashRun(hashed_items))
354    });
355
356    let view_fn = Box::new(Scope::current().enter_child(move |e| view_fn(e).into_any()));
357
358    VirtualStack {
359        id,
360        first_content_id: None,
361        style: Default::default(),
362        direction,
363        item_size,
364        children: Vec::new(),
365        selected_idx: HashSet::with_capacity(1),
366        first_child_idx: 0,
367        viewport: Rect::ZERO,
368        set_viewport,
369        view_fn,
370        before_size: 0.0,
371        content_size: 0.0,
372        before_node: None,
373    }
374}
375
376impl<T> View for VirtualStack<T> {
377    fn id(&self) -> ViewId {
378        self.id
379    }
380
381    fn debug_name(&self) -> std::borrow::Cow<'static, str> {
382        "VirtualStack".into()
383    }
384
385    fn update(&mut self, cx: &mut crate::context::UpdateCx, state: Box<dyn std::any::Any>) {
386        if state.is::<VirtualStackState<T>>() {
387            if let Ok(state) = state.downcast::<VirtualStackState<T>>() {
388                if self.before_size == state.before_size
389                    && self.content_size == state.content_size
390                    && state.diff.is_empty()
391                {
392                    return;
393                }
394                self.before_size = state.before_size;
395                self.content_size = state.content_size;
396                self.first_child_idx = state.first_idx;
397                apply_diff(
398                    self.id(),
399                    cx.window_state,
400                    state.diff,
401                    &mut self.children,
402                    &self.view_fn,
403                );
404                self.id.request_all();
405            }
406        } else if state.is::<usize>() {
407            if let Ok(idx) = state.downcast::<usize>() {
408                self.id.request_style_recursive();
409                self.scroll_to_idx(*idx);
410                self.selected_idx.clear();
411                self.selected_idx.insert(*idx);
412            }
413        }
414    }
415
416    fn style_pass(&mut self, cx: &mut crate::context::StyleCx<'_>) {
417        if self.style.read(cx) {
418            cx.window_state.request_paint(self.id);
419            self.direction.set(self.style.direction());
420        }
421        for (child_id_index, child) in self.id.children().into_iter().enumerate() {
422            if self
423                .selected_idx
424                .contains(&(child_id_index + self.first_child_idx))
425            {
426                cx.save();
427                cx.selected();
428                cx.style_view(child);
429                cx.restore();
430            } else {
431                cx.style_view(child);
432            }
433        }
434    }
435
436    fn view_style(&self) -> Option<crate::style::Style> {
437        let style = match self.direction.get_untracked() {
438            // using min width and height because we strongly assume that these are respected
439            FlexDirection::Column | FlexDirection::ColumnReverse => {
440                Style::new().min_height(self.content_size)
441            }
442            FlexDirection::Row | FlexDirection::RowReverse => {
443                Style::new().min_width(self.content_size)
444            }
445        };
446        Some(style)
447    }
448
449    fn layout(&mut self, cx: &mut crate::context::LayoutCx) -> taffy::tree::NodeId {
450        cx.layout_node(self.id(), true, |cx| {
451            let mut content_nodes = self
452                .id
453                .children()
454                .into_iter()
455                .map(|id| id.view().borrow_mut().layout(cx))
456                .collect::<Vec<_>>();
457
458            if self.before_node.is_none() {
459                self.before_node = Some(
460                    self.id
461                        .taffy()
462                        .borrow_mut()
463                        .new_leaf(taffy::style::Style::DEFAULT)
464                        .unwrap(),
465                );
466            }
467            let before_node = self.before_node.unwrap();
468            let _ = self.id.taffy().borrow_mut().set_style(
469                before_node,
470                taffy::style::Style {
471                    size: match self.direction.get_untracked() {
472                        FlexDirection::Column | FlexDirection::ColumnReverse => {
473                            taffy::prelude::Size {
474                                width: Dimension::auto(),
475                                height: Dimension::length(self.before_size as f32),
476                            }
477                        }
478                        FlexDirection::Row | FlexDirection::RowReverse => taffy::prelude::Size {
479                            width: Dimension::length(self.before_size as f32),
480                            height: Dimension::auto(),
481                        },
482                    },
483                    ..Default::default()
484                },
485            );
486            self.first_content_id = self.id.children().first().copied();
487            let mut nodes = vec![before_node];
488            nodes.append(&mut content_nodes);
489            nodes
490        })
491    }
492
493    fn compute_layout(&mut self, cx: &mut ComputeLayoutCx<'_>) -> Option<Rect> {
494        let viewport = cx.current_viewport();
495        if self.viewport != viewport {
496            self.viewport = viewport;
497            self.set_viewport.set(viewport);
498        }
499
500        let layout = view::default_compute_layout(self.id, cx);
501
502        let new_size = self.item_size.with(|s| match s {
503            VirtualItemSize::Assume(None) => {
504                if let Some(first_content) = self.first_content_id {
505                    let taffy_layout = first_content.get_layout()?;
506                    let size = taffy_layout.size;
507                    if size.width == 0. || size.height == 0. {
508                        return None;
509                    }
510                    let rect = Size::new(size.width as f64, size.height as f64).to_rect();
511                    let relevant_size = match self.direction.get_untracked() {
512                        FlexDirection::Column | FlexDirection::ColumnReverse => rect.height(),
513                        FlexDirection::Row | FlexDirection::RowReverse => rect.width(),
514                    };
515                    Some(relevant_size)
516                } else {
517                    None
518                }
519            }
520            _ => None,
521        });
522        if let Some(new_size) = new_size {
523            self.item_size.set(VirtualItemSize::Assume(Some(new_size)));
524        }
525
526        layout
527    }
528}
529
530impl<T> VirtualStack<T> {
531    /// Scrolls to bring the item at the given index into view
532    pub fn scroll_to_idx(&self, index: usize) {
533        let (offset, size) = self.calculate_offset(index);
534
535        // Create a rectangle at the calculated offset
536        let rect = match self.direction.get_untracked() {
537            FlexDirection::Column | FlexDirection::ColumnReverse => {
538                Rect::from_origin_size((0.0, offset), (0.0, size))
539            }
540            FlexDirection::Row | FlexDirection::RowReverse => {
541                Rect::from_origin_size((offset, 0.0), (size, 0.0))
542            }
543        };
544
545        self.id.scroll_to(Some(rect));
546    }
547
548    /// Calculates the offset position for an item at the given index
549    fn calculate_offset(&self, index: usize) -> (f64, f64) {
550        self.item_size.with(|size| match size {
551            // For fixed size items, we can calculate the offset directly
552            VirtualItemSize::Fixed(item_size) => {
553                let size = item_size();
554                (size * index as f64, size)
555            }
556
557            // For items with a size function, we would need to sum up sizes
558            VirtualItemSize::Fn(_size_fn) => {
559                // TODO? This method just doesn't work for variable item size.
560                // this will make it so that if arrow keys are used on a virtual list
561                // with item size fn, it won't scroll.
562                (0., 0.)
563            }
564
565            // For assumed size items, use the assumed size if available
566            VirtualItemSize::Assume(Some(size)) => (size * index as f64, *size),
567
568            // If we don't have size information yet, default to 0
569            VirtualItemSize::Assume(None) => (0.0, 0.),
570        })
571    }
572}
573
574impl<T: Clone> VirtualVector<T> for imbl::Vector<T> {
575    fn total_len(&self) -> usize {
576        self.len()
577    }
578
579    fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
580        self.slice(range).into_iter()
581    }
582}
583
584impl<T> VirtualVector<T> for Range<T>
585where
586    T: Copy + std::ops::Sub<Output = T> + std::ops::Add<Output = T> + PartialOrd + From<usize>,
587    usize: From<T>,
588    std::ops::Range<T>: Iterator<Item = T>,
589{
590    fn total_len(&self) -> usize {
591        // Convert the difference between end and start to usize
592        (self.end - self.start).into()
593    }
594
595    fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
596        let start = self.start + T::from(range.start);
597        let end = self.start + T::from(range.end);
598
599        // Create a new range for the slice
600        start..end
601    }
602}
603impl<T> VirtualVector<T> for RangeInclusive<T>
604where
605    T: Copy + std::ops::Sub<Output = T> + std::ops::Add<Output = T> + PartialOrd + From<usize>,
606    usize: From<T>,
607    std::ops::Range<T>: Iterator<Item = T>,
608{
609    fn total_len(&self) -> usize {
610        // For inclusive range, we need to add 1 to include the end value
611        let diff = *self.end() - *self.start();
612        Into::<usize>::into(diff) + 1
613    }
614
615    fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
616        let start = *self.start() + T::from(range.start);
617        let end = *self.start() + T::from(range.end);
618        // Create a new range for the slice (non-inclusive since that's what the Range parameter specifies)
619        start..end
620    }
621}
622
623impl<T> VirtualVector<T> for RwSignal<Vec<T>>
624where
625    T: Clone + 'static,
626{
627    fn total_len(&self) -> usize {
628        self.with(|v| v.len())
629    }
630
631    // false positive on the clippy
632    #[allow(clippy::unnecessary_to_owned)]
633    fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
634        self.with(|v| v[range].to_vec().into_iter())
635    }
636}
637
638impl<T> VirtualVector<T> for ReadSignal<Vec<T>>
639where
640    T: Clone + 'static,
641{
642    fn total_len(&self) -> usize {
643        self.with(|v| v.len())
644    }
645
646    // false positive on the clippy
647    #[allow(clippy::unnecessary_to_owned)]
648    fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
649        self.with(|v| v[range].to_vec().into_iter())
650    }
651}
652
653pub struct Enumerate<V: VirtualVector<T>, T> {
654    inner: V,
655    phantom: PhantomData<T>,
656}
657
658impl<V: VirtualVector<T>, T> VirtualVector<(usize, T)> for Enumerate<V, T> {
659    fn total_len(&self) -> usize {
660        self.inner.total_len()
661    }
662
663    fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = (usize, T)> {
664        let start = range.start;
665        self.inner
666            .slice(range)
667            .enumerate()
668            .map(move |(i, e)| (i + start, e))
669    }
670}