1use std::{
2 collections::HashSet,
3 hash::{DefaultHasher, Hash, Hasher},
4 marker::PhantomData,
5 ops::{Range, RangeInclusive},
6 rc::Rc,
7};
8
9use floem_reactive::{
10 Effect, ReadSignal, RwSignal, Scope, SignalGet, SignalTrack, SignalUpdate, SignalWith,
11 WriteSignal,
12};
13use peniko::kurbo::{Rect, Size};
14use smallvec::SmallVec;
15use taffy::{FlexDirection, style::Dimension, tree::NodeId};
16
17use crate::{
18 context::ComputeLayoutCx,
19 id::ViewId,
20 prop_extractor,
21 style::{FlexDirectionProp, Style},
22 view::{self, IntoView, View},
23};
24
25use super::{Diff, DiffOpAdd, FxIndexSet, HashRun, apply_diff, diff};
26
/// Boxed callback that turns one data item of type `T` into a type-erased view,
/// returned together with a [`Scope`].
// NOTE(review): the scope is presumably the reactive scope that owns the view's state so it
// can be disposed with the view — confirm against `apply_diff`.
pub type VirtViewFn<T> = Box<dyn Fn(T) -> (Box<dyn View>, Scope)>;
28
// Generates `VirtualExtractor`, which the stack reads during its style pass to obtain the
// flex direction (`direction`) configured on the view.
prop_extractor! {
    pub VirtualExtractor {
        pub direction: FlexDirectionProp,
    }
}
34
/// Strategy used to determine how much main-axis space each item occupies.
enum VirtualItemSize<T> {
    /// The size is computed per item from the item itself.
    Fn(Rc<dyn Fn(&T) -> f64>),
    /// Every item has the same size, obtained by calling the closure.
    Fixed(Rc<dyn Fn() -> f64>),
    /// Every item is assumed to have the same size as the first laid-out child.
    ///
    /// Holds `None` until a non-zero size has been measured, `Some(size)` afterwards.
    Assume(Option<f64>),
}
41impl<T> Clone for VirtualItemSize<T> {
42 fn clone(&self) -> Self {
43 match self {
44 VirtualItemSize::Fn(rc) => VirtualItemSize::Fn(rc.clone()),
45 VirtualItemSize::Fixed(rc) => VirtualItemSize::Fixed(rc.clone()),
46 VirtualItemSize::Assume(x) => VirtualItemSize::Assume(*x),
47 }
48 }
49}
50
51pub trait VirtualVector<T> {
53 fn total_len(&self) -> usize;
54
55 fn is_empty(&self) -> bool {
56 self.total_len() == 0
57 }
58
59 fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T>;
60
61 fn enumerate(self) -> Enumerate<Self, T>
62 where
63 Self: Sized,
64 {
65 Enumerate {
66 inner: self,
67 phantom: PhantomData,
68 }
69 }
70}
71
/// A stack view that only keeps child views alive for the slice of data intersecting the
/// current viewport.
pub struct VirtualStack<T>
where
    T: 'static,
{
    /// Identifier of this view.
    id: ViewId,
    /// First child view recorded during the last `layout` pass; measured in
    /// `compute_layout` while the item size is still unknown.
    first_content_id: Option<ViewId>,
    /// Extracted style properties (the flex direction).
    style: VirtualExtractor,
    /// Main-axis direction, updated from `style` during the style pass.
    pub(crate) direction: RwSignal<FlexDirection>,
    /// How the main-axis size of an item is determined.
    item_size: RwSignal<VirtualItemSize<T>>,
    /// Child slots handed to `apply_diff` when the visible window changes.
    children: Vec<Option<(ViewId, Scope)>>,
    /// Data index of the first child, as reported by the last applied state.
    first_child_idx: usize,
    /// Data indices whose child is styled as selected.
    selected_idx: HashSet<usize>,
    /// Viewport seen by the most recent `compute_layout` call.
    viewport: Rect,
    /// Write half of the viewport signal read by the windowing effect.
    set_viewport: WriteSignal<Rect>,
    /// Builds the view for a data item.
    view_fn: VirtViewFn<T>,
    /// Main-axis size of the spacer node placed before the children.
    before_size: f64,
    /// Main-axis size of the whole content, applied as a minimum width/height.
    content_size: f64,
    /// Taffy leaf node used as the leading spacer; created on first layout.
    before_node: Option<NodeId>,
}
93impl<T: std::clone::Clone> VirtualStack<T> {
94 pub fn new<DF, I>(data_fn: DF) -> VirtualStack<T>
96 where
97 DF: Fn() -> I + 'static,
98 I: VirtualVector<T>,
99 T: Hash + Eq + IntoView + 'static,
100 {
101 Self::full(
102 data_fn,
103 |item| {
104 let mut hasher = DefaultHasher::new();
105 item.hash(&mut hasher);
106 hasher.finish()
107 },
108 |item| item.into_view(),
109 )
110 }
111
112 pub fn with_view<DF, I, V>(data_fn: DF, view_fn: impl Fn(T) -> V + 'static) -> VirtualStack<T>
114 where
115 DF: Fn() -> I + 'static,
116 I: VirtualVector<T>,
117 T: Hash + Eq + 'static,
118 V: IntoView,
119 {
120 Self::full(
121 data_fn,
122 |item| {
123 let mut hasher = DefaultHasher::new();
124 item.hash(&mut hasher);
125 hasher.finish()
126 },
127 move |item| view_fn(item).into_view(),
128 )
129 }
130
131 pub fn with_key<DF, I, K>(data_fn: DF, key_fn: impl Fn(&T) -> K + 'static) -> VirtualStack<T>
133 where
134 DF: Fn() -> I + 'static,
135 I: VirtualVector<T>,
136 T: IntoView + 'static,
137 K: Hash + Eq + 'static,
138 {
139 Self::full(data_fn, key_fn, |item| item.into_view())
140 }
141
142 pub fn full<DF, I, KF, K, VF, V>(data_fn: DF, key_fn: KF, view_fn: VF) -> VirtualStack<T>
143 where
144 DF: Fn() -> I + 'static,
145 I: VirtualVector<T>,
146 KF: Fn(&T) -> K + 'static,
147 K: Eq + Hash + 'static,
148 VF: Fn(T) -> V + 'static,
149 V: IntoView + 'static,
150 T: 'static,
151 {
152 virtual_stack(data_fn, key_fn, view_fn)
153 }
154}
155
156impl<T> VirtualStack<T> {
157 pub fn item_size_fixed(self, size: impl Fn() -> f64 + 'static) -> Self {
158 self.item_size.set(VirtualItemSize::Fixed(Rc::new(size)));
159 self
160 }
161
162 pub fn item_size_fn(self, size: impl Fn(&T) -> f64 + 'static) -> Self {
163 self.item_size.set(VirtualItemSize::Fn(Rc::new(size)));
164 self
165 }
166}
167
/// State message sent from the windowing effect to the view's `update` method.
pub(crate) struct VirtualStackState<T> {
    /// Children to add/remove/move relative to the previously visible window.
    diff: Diff<T>,
    /// Data index of the first item in the visible window.
    first_idx: usize,
    /// Main-axis size occupied by the items before the visible window.
    before_size: f64,
    /// Main-axis size of all items together.
    content_size: f64,
}
174
175pub fn virtual_stack<T, IF, I, KF, K, VF, V>(
203 each_fn: IF,
204 key_fn: KF,
205 view_fn: VF,
206) -> VirtualStack<T>
207where
208 T: 'static,
209 IF: Fn() -> I + 'static,
210 I: VirtualVector<T>,
211 KF: Fn(&T) -> K + 'static,
212 K: Eq + Hash + 'static,
213 VF: Fn(T) -> V + 'static,
214 V: IntoView + 'static,
215{
216 let id = ViewId::new();
217
218 let (viewport, set_viewport) = RwSignal::new_split(Rect::ZERO);
219
220 let item_size = RwSignal::new(VirtualItemSize::Assume(None));
221
222 let direction = RwSignal::new(FlexDirection::Row);
223 Effect::new(move |_| {
224 direction.track();
225 id.request_style();
226 });
227
228 Effect::new(move |prev| {
229 let mut items_vector = each_fn();
230 let viewport = viewport.get();
231 let min = match direction.get() {
232 FlexDirection::Column | FlexDirection::ColumnReverse => viewport.y0,
233 FlexDirection::Row | FlexDirection::RowReverse => viewport.x0,
234 };
235 let max = match direction.get() {
236 FlexDirection::Column | FlexDirection::ColumnReverse => viewport.height() + viewport.y0,
237 FlexDirection::Row | FlexDirection::RowReverse => viewport.width() + viewport.x0,
238 };
239 let mut items = Vec::new();
240
241 let mut before_size = 0.0;
242 let mut content_size = 0.0;
243 let mut start = 0;
244 item_size.with(|s| match s {
245 VirtualItemSize::Fixed(item_size) => {
246 let item_size = item_size();
247 let total_len = items_vector.total_len();
248 start = if item_size > 0.0 {
249 (min / item_size).floor() as usize
250 } else {
251 0
252 };
253 let end = if item_size > 0.0 {
254 ((max / item_size).ceil() as usize).min(total_len)
255 } else {
256 (start + 1).min(total_len)
258 };
259 before_size = item_size * (start.min(total_len)) as f64;
260
261 for item in items_vector.slice(start..end) {
262 items.push(item);
263 }
264
265 content_size = item_size * total_len as f64;
266 }
267 VirtualItemSize::Fn(size_fn) => {
268 let mut main_axis = 0.0;
269 let total_len = items_vector.total_len();
270 for (idx, item) in items_vector.slice(0..total_len).enumerate() {
271 let item_size = size_fn(&item);
272 content_size += item_size;
273 if main_axis + item_size < min {
274 main_axis += item_size;
275 before_size += item_size;
276 start = idx;
277 continue;
278 }
279
280 if main_axis <= max {
281 main_axis += item_size;
282 items.push(item);
283 }
284 }
285 }
286 VirtualItemSize::Assume(None) => {
287 let total_len = items_vector.total_len();
289 if total_len > 0 {
290 items.push(items_vector.slice(0..1).next().unwrap());
292
293 before_size = 0.0;
295 content_size = total_len as f64 * 10.0; }
297 }
298 VirtualItemSize::Assume(Some(item_size)) => {
299 let total_len = items_vector.total_len();
301 start = if *item_size > 0.0 {
302 (min / item_size).floor() as usize
303 } else {
304 0
305 };
306 let end = if *item_size > 0.0 {
307 ((max / item_size).ceil() as usize).min(total_len)
308 } else {
309 (start + 1).min(total_len)
311 };
312 before_size = item_size * (start.min(total_len)) as f64;
313
314 for item in items_vector.slice(start..end) {
315 items.push(item);
316 }
317 content_size = item_size * total_len as f64;
318 }
319 });
320
321 let hashed_items = items.iter().map(&key_fn).collect::<FxIndexSet<_>>();
322 let (prev_before_size, prev_content_size, diff) =
323 if let Some((prev_before_size, prev_content_size, HashRun(prev_hash_run))) = prev {
324 let mut diff = diff(&prev_hash_run, &hashed_items);
325 let mut items = items
326 .into_iter()
327 .map(|i| Some(i))
328 .collect::<SmallVec<[Option<_>; 128]>>();
329 for added in &mut diff.added {
330 added.view = Some(items[added.at].take().unwrap());
331 }
332 (prev_before_size, prev_content_size, diff)
333 } else {
334 let mut diff = Diff::default();
335 for (i, item) in items.into_iter().enumerate() {
336 diff.added.push(DiffOpAdd {
337 at: i,
338 view: Some(item),
339 });
340 }
341 (0.0, 0.0, diff)
342 };
343
344 if !diff.is_empty() || prev_before_size != before_size || prev_content_size != content_size
345 {
346 id.update_state(VirtualStackState {
347 diff,
348 first_idx: start,
349 before_size,
350 content_size,
351 });
352 }
353 (before_size, content_size, HashRun(hashed_items))
354 });
355
356 let view_fn = Box::new(Scope::current().enter_child(move |e| view_fn(e).into_any()));
357
358 VirtualStack {
359 id,
360 first_content_id: None,
361 style: Default::default(),
362 direction,
363 item_size,
364 children: Vec::new(),
365 selected_idx: HashSet::with_capacity(1),
366 first_child_idx: 0,
367 viewport: Rect::ZERO,
368 set_viewport,
369 view_fn,
370 before_size: 0.0,
371 content_size: 0.0,
372 before_node: None,
373 }
374}
375
376impl<T> View for VirtualStack<T> {
377 fn id(&self) -> ViewId {
378 self.id
379 }
380
381 fn debug_name(&self) -> std::borrow::Cow<'static, str> {
382 "VirtualStack".into()
383 }
384
385 fn update(&mut self, cx: &mut crate::context::UpdateCx, state: Box<dyn std::any::Any>) {
386 if state.is::<VirtualStackState<T>>() {
387 if let Ok(state) = state.downcast::<VirtualStackState<T>>() {
388 if self.before_size == state.before_size
389 && self.content_size == state.content_size
390 && state.diff.is_empty()
391 {
392 return;
393 }
394 self.before_size = state.before_size;
395 self.content_size = state.content_size;
396 self.first_child_idx = state.first_idx;
397 apply_diff(
398 self.id(),
399 cx.window_state,
400 state.diff,
401 &mut self.children,
402 &self.view_fn,
403 );
404 self.id.request_all();
405 }
406 } else if state.is::<usize>() {
407 if let Ok(idx) = state.downcast::<usize>() {
408 self.id.request_style_recursive();
409 self.scroll_to_idx(*idx);
410 self.selected_idx.clear();
411 self.selected_idx.insert(*idx);
412 }
413 }
414 }
415
416 fn style_pass(&mut self, cx: &mut crate::context::StyleCx<'_>) {
417 if self.style.read(cx) {
418 cx.window_state.request_paint(self.id);
419 self.direction.set(self.style.direction());
420 }
421 for (child_id_index, child) in self.id.children().into_iter().enumerate() {
422 if self
423 .selected_idx
424 .contains(&(child_id_index + self.first_child_idx))
425 {
426 cx.save();
427 cx.selected();
428 cx.style_view(child);
429 cx.restore();
430 } else {
431 cx.style_view(child);
432 }
433 }
434 }
435
436 fn view_style(&self) -> Option<crate::style::Style> {
437 let style = match self.direction.get_untracked() {
438 FlexDirection::Column | FlexDirection::ColumnReverse => {
440 Style::new().min_height(self.content_size)
441 }
442 FlexDirection::Row | FlexDirection::RowReverse => {
443 Style::new().min_width(self.content_size)
444 }
445 };
446 Some(style)
447 }
448
449 fn layout(&mut self, cx: &mut crate::context::LayoutCx) -> taffy::tree::NodeId {
450 cx.layout_node(self.id(), true, |cx| {
451 let mut content_nodes = self
452 .id
453 .children()
454 .into_iter()
455 .map(|id| id.view().borrow_mut().layout(cx))
456 .collect::<Vec<_>>();
457
458 if self.before_node.is_none() {
459 self.before_node = Some(
460 self.id
461 .taffy()
462 .borrow_mut()
463 .new_leaf(taffy::style::Style::DEFAULT)
464 .unwrap(),
465 );
466 }
467 let before_node = self.before_node.unwrap();
468 let _ = self.id.taffy().borrow_mut().set_style(
469 before_node,
470 taffy::style::Style {
471 size: match self.direction.get_untracked() {
472 FlexDirection::Column | FlexDirection::ColumnReverse => {
473 taffy::prelude::Size {
474 width: Dimension::auto(),
475 height: Dimension::length(self.before_size as f32),
476 }
477 }
478 FlexDirection::Row | FlexDirection::RowReverse => taffy::prelude::Size {
479 width: Dimension::length(self.before_size as f32),
480 height: Dimension::auto(),
481 },
482 },
483 ..Default::default()
484 },
485 );
486 self.first_content_id = self.id.children().first().copied();
487 let mut nodes = vec![before_node];
488 nodes.append(&mut content_nodes);
489 nodes
490 })
491 }
492
493 fn compute_layout(&mut self, cx: &mut ComputeLayoutCx<'_>) -> Option<Rect> {
494 let viewport = cx.current_viewport();
495 if self.viewport != viewport {
496 self.viewport = viewport;
497 self.set_viewport.set(viewport);
498 }
499
500 let layout = view::default_compute_layout(self.id, cx);
501
502 let new_size = self.item_size.with(|s| match s {
503 VirtualItemSize::Assume(None) => {
504 if let Some(first_content) = self.first_content_id {
505 let taffy_layout = first_content.get_layout()?;
506 let size = taffy_layout.size;
507 if size.width == 0. || size.height == 0. {
508 return None;
509 }
510 let rect = Size::new(size.width as f64, size.height as f64).to_rect();
511 let relevant_size = match self.direction.get_untracked() {
512 FlexDirection::Column | FlexDirection::ColumnReverse => rect.height(),
513 FlexDirection::Row | FlexDirection::RowReverse => rect.width(),
514 };
515 Some(relevant_size)
516 } else {
517 None
518 }
519 }
520 _ => None,
521 });
522 if let Some(new_size) = new_size {
523 self.item_size.set(VirtualItemSize::Assume(Some(new_size)));
524 }
525
526 layout
527 }
528}
529
530impl<T> VirtualStack<T> {
531 pub fn scroll_to_idx(&self, index: usize) {
533 let (offset, size) = self.calculate_offset(index);
534
535 let rect = match self.direction.get_untracked() {
537 FlexDirection::Column | FlexDirection::ColumnReverse => {
538 Rect::from_origin_size((0.0, offset), (0.0, size))
539 }
540 FlexDirection::Row | FlexDirection::RowReverse => {
541 Rect::from_origin_size((offset, 0.0), (size, 0.0))
542 }
543 };
544
545 self.id.scroll_to(Some(rect));
546 }
547
548 fn calculate_offset(&self, index: usize) -> (f64, f64) {
550 self.item_size.with(|size| match size {
551 VirtualItemSize::Fixed(item_size) => {
553 let size = item_size();
554 (size * index as f64, size)
555 }
556
557 VirtualItemSize::Fn(_size_fn) => {
559 (0., 0.)
563 }
564
565 VirtualItemSize::Assume(Some(size)) => (size * index as f64, *size),
567
568 VirtualItemSize::Assume(None) => (0.0, 0.),
570 })
571 }
572}
573
574impl<T: Clone> VirtualVector<T> for imbl::Vector<T> {
575 fn total_len(&self) -> usize {
576 self.len()
577 }
578
579 fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
580 self.slice(range).into_iter()
581 }
582}
583
584impl<T> VirtualVector<T> for Range<T>
585where
586 T: Copy + std::ops::Sub<Output = T> + std::ops::Add<Output = T> + PartialOrd + From<usize>,
587 usize: From<T>,
588 std::ops::Range<T>: Iterator<Item = T>,
589{
590 fn total_len(&self) -> usize {
591 (self.end - self.start).into()
593 }
594
595 fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
596 let start = self.start + T::from(range.start);
597 let end = self.start + T::from(range.end);
598
599 start..end
601 }
602}
603impl<T> VirtualVector<T> for RangeInclusive<T>
604where
605 T: Copy + std::ops::Sub<Output = T> + std::ops::Add<Output = T> + PartialOrd + From<usize>,
606 usize: From<T>,
607 std::ops::Range<T>: Iterator<Item = T>,
608{
609 fn total_len(&self) -> usize {
610 let diff = *self.end() - *self.start();
612 Into::<usize>::into(diff) + 1
613 }
614
615 fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
616 let start = *self.start() + T::from(range.start);
617 let end = *self.start() + T::from(range.end);
618 start..end
620 }
621}
622
623impl<T> VirtualVector<T> for RwSignal<Vec<T>>
624where
625 T: Clone + 'static,
626{
627 fn total_len(&self) -> usize {
628 self.with(|v| v.len())
629 }
630
631 #[allow(clippy::unnecessary_to_owned)]
633 fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
634 self.with(|v| v[range].to_vec().into_iter())
635 }
636}
637
638impl<T> VirtualVector<T> for ReadSignal<Vec<T>>
639where
640 T: Clone + 'static,
641{
642 fn total_len(&self) -> usize {
643 self.with(|v| v.len())
644 }
645
646 #[allow(clippy::unnecessary_to_owned)]
648 fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = T> {
649 self.with(|v| v[range].to_vec().into_iter())
650 }
651}
652
/// Adapter returned by [`VirtualVector::enumerate`] that pairs every item with its position
/// in the underlying source.
pub struct Enumerate<V: VirtualVector<T>, T> {
    /// The wrapped data source.
    inner: V,
    /// Ties the item type `T` to the adapter without storing a value of it.
    phantom: PhantomData<T>,
}
657
658impl<V: VirtualVector<T>, T> VirtualVector<(usize, T)> for Enumerate<V, T> {
659 fn total_len(&self) -> usize {
660 self.inner.total_len()
661 }
662
663 fn slice(&mut self, range: Range<usize>) -> impl Iterator<Item = (usize, T)> {
664 let start = range.start;
665 self.inner
666 .slice(range)
667 .enumerate()
668 .map(move |(i, e)| (i + start, e))
669 }
670}