ab_farmer/utils.rs

//! Various utilities used by farmer or with farmer

#[cfg(test)]
mod tests;

use crate::thread_pool_manager::{PlottingThreadPoolManager, PlottingThreadPoolPair};
use futures::channel::oneshot;
use futures::channel::oneshot::Canceled;
use futures::future::Either;
use rayon::{
    ThreadBuilder, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder, current_thread_index,
};
use std::future::Future;
use std::num::NonZeroUsize;
use std::ops::Deref;
use std::pin::{Pin, pin};
use std::process::exit;
use std::task::{Context, Poll};
use std::{fmt, io, iter, thread};
use thread_priority::{ThreadPriority, set_current_thread_priority};
use tokio::runtime::Handle;
use tokio::task;
use tracing::{debug, warn};

/// It doesn't make a lot of sense to have a huge number of farming threads, 32 is plenty
const MAX_DEFAULT_FARMING_THREADS: usize = 32;

/// Joins async join handle on drop
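///
/// # Example
///
/// A minimal usage sketch (not from the original docs), assuming a multi-threaded Tokio runtime
/// is already running:
///
/// ```ignore
/// // Wrap a spawned task so it is aborted and joined when the wrapper is dropped
/// let handle = tokio::spawn(async { 42 });
/// let task = AsyncJoinOnDrop::new(handle, true);
/// // `AsyncJoinOnDrop` is itself a future resolving to the task's result
/// assert_eq!(task.await.unwrap(), 42);
/// ```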
#[derive(Debug)]
pub struct AsyncJoinOnDrop<T> {
    handle: Option<task::JoinHandle<T>>,
    abort_on_drop: bool,
}

impl<T> Drop for AsyncJoinOnDrop<T> {
    #[inline]
    fn drop(&mut self) {
        if let Some(handle) = self.handle.take() {
            if self.abort_on_drop {
                handle.abort();
            }

            if !handle.is_finished() {
                task::block_in_place(move || {
                    let _ = Handle::current().block_on(handle);
                });
            }
        }
    }
}

impl<T> AsyncJoinOnDrop<T> {
    /// Create a new instance.
    #[inline]
    pub fn new(handle: task::JoinHandle<T>, abort_on_drop: bool) -> Self {
        Self {
            handle: Some(handle),
            abort_on_drop,
        }
    }
}

impl<T> Future for AsyncJoinOnDrop<T> {
    type Output = Result<T, task::JoinError>;

    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        Pin::new(
            self.handle
                .as_mut()
                .expect("Only dropped in Drop impl; qed"),
        )
        .poll(cx)
    }
}

/// Joins synchronous join handle on drop
pub(crate) struct JoinOnDrop(Option<thread::JoinHandle<()>>);

impl Drop for JoinOnDrop {
    #[inline]
    fn drop(&mut self) {
        self.0
            .take()
            .expect("Always called exactly once; qed")
            .join()
            .expect("Panic if background thread panicked");
    }
}

impl JoinOnDrop {
    /// Create a new instance.
    #[inline]
    pub(crate) fn new(handle: thread::JoinHandle<()>) -> Self {
        Self(Some(handle))
    }
}

impl Deref for JoinOnDrop {
    type Target = thread::JoinHandle<()>;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.0.as_ref().expect("Only dropped in Drop impl; qed")
    }
}

/// Runs a future on a dedicated thread with the specified name. Blocks on drop until the
/// background thread running the future has stopped too, ensuring nothing is left in memory.
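///
/// # Example
///
/// A minimal usage sketch (not from the original docs), assuming this is called from within a
/// Tokio runtime context:
///
/// ```ignore
/// let result_fut = run_future_in_dedicated_thread(
///     || async { 2 + 2 },
///     "computation".to_string(),
/// )
/// .expect("Failed to spawn dedicated thread");
/// // Resolves once the background thread has produced the value
/// assert_eq!(result_fut.await, Ok(4));
/// ```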
pub fn run_future_in_dedicated_thread<CreateFut, Fut, T>(
    create_future: CreateFut,
    thread_name: String,
) -> io::Result<impl Future<Output = Result<T, Canceled>> + Send>
where
    CreateFut: (FnOnce() -> Fut) + Send + 'static,
    Fut: Future<Output = T> + 'static,
    T: Send + 'static,
{
    let (drop_tx, drop_rx) = oneshot::channel::<()>();
    let (result_tx, result_rx) = oneshot::channel();
    let handle = Handle::current();
    let join_handle = thread::Builder::new().name(thread_name).spawn(move || {
        let _tokio_handle_guard = handle.enter();

        let future = pin!(create_future());

        let result = match handle.block_on(futures::future::select(future, drop_rx)) {
            Either::Left((result, _)) => result,
            Either::Right(_) => {
                // Outer future was dropped, nothing left to do
                return;
            }
        };
        if let Err(_error) = result_tx.send(result) {
            debug!(
                thread_name = ?thread::current().name(),
                "Future finished, but receiver was already dropped",
            );
        }
    })?;
    // Ensure thread will not be left hanging forever
    let join_on_drop = JoinOnDrop::new(join_handle);

    Ok(async move {
        let result = result_rx.await;
        drop(drop_tx);
        drop(join_on_drop);
        result
    })
}

/// Abstraction for CPU core set
#[derive(Clone)]
pub struct CpuCoreSet {
    /// CPU cores that belong to this set
    cores: Vec<usize>,
    #[cfg(feature = "numa")]
    topology: Option<std::sync::Arc<hwlocality::Topology>>,
}

impl fmt::Debug for CpuCoreSet {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut s = f.debug_struct("CpuCoreSet");
        #[cfg(not(feature = "numa"))]
        if self.cores.array_windows::<2>().all(|&[a, b]| a + 1 == b) {
            s.field(
                "cores",
                &format!(
                    "{}-{}",
                    self.cores.first().expect("List of cores is not empty; qed"),
                    self.cores.last().expect("List of cores is not empty; qed")
                ),
            );
        } else {
            s.field(
                "cores",
                &self
                    .cores
                    .iter()
                    .map(usize::to_string)
                    .collect::<Vec<_>>()
                    .join(","),
            );
        }
        #[cfg(feature = "numa")]
        {
            use hwlocality::cpu::cpuset::CpuSet;
            use hwlocality::ffi::PositiveInt;

            s.field(
                "cores",
                &CpuSet::from_iter(
                    self.cores.iter().map(|&core| {
                        PositiveInt::try_from(core).expect("Valid CPU core index; qed")
                    }),
                ),
            );
        }
        s.finish_non_exhaustive()
    }
}

impl CpuCoreSet {
    /// Get CPU core numbers in this set
    pub fn cpu_cores(&self) -> &[usize] {
        &self.cores
    }

    /// Truncates the list of CPU cores to the specified number.
    ///
    /// Truncation will take into account L2 and L3 cache topology in order to use half of the
    /// actual physical cores and half of each core type in case of heterogeneous CPUs.
    ///
    /// If `num_cores` is zero, it is clamped to one, since a zero number of cores is not allowed.
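    ///
    /// # Example
    ///
    /// A minimal usage sketch (not from the original docs); the exact cores retained depend on
    /// the machine's topology:
    ///
    /// ```ignore
    /// let mut cpu_core_set = all_cpu_cores()
    ///     .into_iter()
    ///     .next()
    ///     .expect("There is always at least one CPU core set");
    /// let half = cpu_core_set.cpu_cores().len() / 2;
    /// // Keep roughly half of the cores in this set (at least one core is always retained)
    /// cpu_core_set.truncate(half);
    /// ```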
    pub fn truncate(&mut self, num_cores: usize) {
        let num_cores = num_cores.clamp(1, self.cores.len());

        #[cfg(feature = "numa")]
        if let Some(topology) = &self.topology {
            use hwlocality::object::attributes::ObjectAttributes;
            use hwlocality::object::types::ObjectType;

            let mut grouped_by_l2_cache_size_and_core_count =
                std::collections::HashMap::<(usize, usize), Vec<usize>>::new();
            topology
                .objects_with_type(ObjectType::L2Cache)
                .for_each(|object| {
                    let l2_cache_size =
                        if let Some(ObjectAttributes::Cache(cache)) = object.attributes() {
                            cache
                                .size()
                                .map(|size| size.get() as usize)
                                .unwrap_or_default()
                        } else {
                            0
                        };
                    if let Some(cpuset) = object.complete_cpuset() {
                        let cpuset = cpuset
                            .into_iter()
                            .map(usize::from)
                            .filter(|core| self.cores.contains(core))
                            .collect::<Vec<_>>();
                        let cpuset_len = cpuset.len();

                        if !cpuset.is_empty() {
                            grouped_by_l2_cache_size_and_core_count
                                .entry((l2_cache_size, cpuset_len))
                                .or_default()
                                .extend(cpuset);
                        }
                    }
                });

            // Make sure all CPU cores in this set were found
            if grouped_by_l2_cache_size_and_core_count
                .values()
                .flatten()
                .count()
                == self.cores.len()
            {
                // Walk through groups of cores for each (L2 cache size + number of cores in set)
                // tuple and pull number of CPU cores proportional to the fraction of the cores that
                // should be returned according to function argument
                self.cores = grouped_by_l2_cache_size_and_core_count
                    .into_values()
                    .flat_map(|cores| {
                        let limit = cores.len() * num_cores / self.cores.len();
                        // At least 1 CPU core is needed
                        cores.into_iter().take(limit.max(1))
                    })
                    .collect();

                self.cores.sort();

                return;
            }
        }
        self.cores.truncate(num_cores);
    }

    /// Pin the current thread to all CPU cores in this set (not just one CPU core)
    pub fn pin_current_thread(&self) {
        #[cfg(feature = "numa")]
        if let Some(topology) = &self.topology {
            use hwlocality::cpu::binding::CpuBindingFlags;
            use hwlocality::cpu::cpuset::CpuSet;
            use hwlocality::current_thread_id;
            use hwlocality::ffi::PositiveInt;

            // Build the cpuset from all CPU cores in this set
            let cpu_cores = CpuSet::from_iter(
                self.cores
                    .iter()
                    .map(|&core| PositiveInt::try_from(core).expect("Valid CPU core index; qed")),
            );

            if let Err(error) =
                topology.bind_thread_cpu(current_thread_id(), &cpu_cores, CpuBindingFlags::empty())
            {
                warn!(%error, ?cpu_cores, "Failed to pin thread to CPU cores")
            }
        }
    }
}

/// Recommended number of farming threads, equal to the number of CPU cores in the first
/// NUMA node
pub fn recommended_number_of_farming_threads() -> usize {
    #[cfg(feature = "numa")]
    match hwlocality::Topology::new().map(std::sync::Arc::new) {
        Ok(topology) => {
            return topology
                // Iterate over NUMA nodes
                .objects_at_depth(hwlocality::object::depth::Depth::NUMANode)
                // For each NUMA node get its CPU set
                .filter_map(|node| node.cpuset())
                // Get number of CPU cores
                .map(|cpuset| cpuset.iter_set().count())
                .find(|&count| count > 0)
                .unwrap_or_else(num_cpus::get)
                .min(MAX_DEFAULT_FARMING_THREADS);
        }
        Err(error) => {
            warn!(%error, "Failed to get NUMA topology");
        }
    }
    num_cpus::get().min(MAX_DEFAULT_FARMING_THREADS)
}

/// Get all CPU cores, grouped into sets according to NUMA nodes or L3 cache groups on large CPUs.
///
/// The returned vector is guaranteed to have at least one element, and every set contains a
/// non-zero number of CPU cores.
pub fn all_cpu_cores() -> Vec<CpuCoreSet> {
    #[cfg(feature = "numa")]
    match hwlocality::Topology::new().map(std::sync::Arc::new) {
        Ok(topology) => {
            let cpu_cores = topology
                // Iterate over groups of L3 caches
                .objects_with_type(hwlocality::object::types::ObjectType::L3Cache)
                // For each L3 cache group get its CPU set
                .filter_map(|node| node.cpuset())
                // For each CPU set extract individual cores
                .map(|cpuset| cpuset.iter_set().map(usize::from).collect::<Vec<_>>())
                .filter(|cores| !cores.is_empty())
                .map(|cores| CpuCoreSet {
                    cores,
                    topology: Some(std::sync::Arc::clone(&topology)),
                })
                .collect::<Vec<_>>();

            if !cpu_cores.is_empty() {
                return cpu_cores;
            }
        }
        Err(error) => {
            warn!(%error, "Failed to get L3 cache topology");
        }
    }
    vec![CpuCoreSet {
        cores: (0..num_cpus::get()).collect(),
        #[cfg(feature = "numa")]
        topology: None,
    }]
}

/// Parse a space-separated set of groups of CPU cores (individual cores are comma-separated) into
/// a vector of CPU core sets that can be used for creation of plotting/replotting thread pools.
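///
/// # Example
///
/// A minimal sketch of the accepted format (not from the original docs):
///
/// ```ignore
/// // Two CPU core sets: cores 0-3 for the first and cores 8, 9 and 10 for the second
/// let sets = parse_cpu_cores_sets("0-3 8,9,10").expect("Valid CPU cores string");
/// assert_eq!(sets[0].cpu_cores(), &[0, 1, 2, 3]);
/// assert_eq!(sets[1].cpu_cores(), &[8, 9, 10]);
/// ```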
pub fn parse_cpu_cores_sets(
    s: &str,
) -> Result<Vec<CpuCoreSet>, Box<dyn std::error::Error + Send + Sync>> {
    #[cfg(feature = "numa")]
    let topology = hwlocality::Topology::new().map(std::sync::Arc::new).ok();

    s.split(' ')
        .map(|s| {
            let mut cores = Vec::new();
            for s in s.split(',') {
                let mut parts = s.split('-');
                let range_start = parts
                    .next()
                    .ok_or(
                        "Bad string format, must be comma separated list of CPU cores or ranges",
                    )?
                    .parse()?;

                if let Some(range_end) = parts.next() {
                    let range_end = range_end.parse()?;

                    cores.extend(range_start..=range_end);
                } else {
                    cores.push(range_start);
                }
            }

            Ok(CpuCoreSet {
                cores,
                #[cfg(feature = "numa")]
                topology: topology.clone(),
            })
        })
        .collect()
}

/// CPU core indices for each thread pool
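///
/// # Example
///
/// A minimal usage sketch (not from the original docs), with explicitly specified sizes:
///
/// ```ignore
/// use std::num::NonZeroUsize;
///
/// // Four thread pools with eight CPU cores each; passing `None` for both arguments falls back
/// // to the physical CPU layout from `all_cpu_cores()`
/// let core_indices = thread_pool_core_indices(NonZeroUsize::new(8), NonZeroUsize::new(4));
/// assert_eq!(core_indices.len(), 4);
/// ```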
pub fn thread_pool_core_indices(
    thread_pool_size: Option<NonZeroUsize>,
    thread_pools: Option<NonZeroUsize>,
) -> Vec<CpuCoreSet> {
    thread_pool_core_indices_internal(all_cpu_cores(), thread_pool_size, thread_pools)
}

fn thread_pool_core_indices_internal(
    all_cpu_cores: Vec<CpuCoreSet>,
    thread_pool_size: Option<NonZeroUsize>,
    thread_pools: Option<NonZeroUsize>,
) -> Vec<CpuCoreSet> {
    #[cfg(feature = "numa")]
    let topology = &all_cpu_cores
        .first()
        .expect("Not empty according to function description; qed")
        .topology;

    // In case the number of thread pools is not specified, but the user did customize the thread
    // pool size, default to the auto-detected number of thread pools
    let thread_pools = thread_pools
        .map(|thread_pools| thread_pools.get())
        .or_else(|| thread_pool_size.map(|_| all_cpu_cores.len()));

    if let Some(thread_pools) = thread_pools {
        let mut thread_pool_core_indices = Vec::<CpuCoreSet>::with_capacity(thread_pools);

        let total_cpu_cores = all_cpu_cores.iter().flat_map(|set| set.cpu_cores()).count();

        if let Some(thread_pool_size) = thread_pool_size {
            // If thread pool size is fixed, loop over all CPU cores as many times as necessary and
            // assign contiguous ranges of CPU cores to corresponding thread pools
            let mut cpu_cores_iterator = iter::repeat(
                all_cpu_cores
                    .iter()
                    .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
                    .copied(),
            )
            .flatten();

            for _ in 0..thread_pools {
                let cpu_cores = cpu_cores_iterator
                    .by_ref()
                    .take(thread_pool_size.get())
                    // To loop over all CPU cores multiple times, modulo naively obtained CPU
                    // cores by the total available number of CPU cores
                    .map(|core_index| core_index % total_cpu_cores)
                    .collect();

                thread_pool_core_indices.push(CpuCoreSet {
                    cores: cpu_cores,
                    #[cfg(feature = "numa")]
                    topology: topology.clone(),
                });
            }
        } else {
            // If thread pool size is not fixed, create thread pools with
            // `total_cpu_cores / thread_pools` threads each

            let all_cpu_cores = all_cpu_cores
                .iter()
                .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
                .copied()
                .collect::<Vec<_>>();

            thread_pool_core_indices = all_cpu_cores
                .chunks(total_cpu_cores.div_ceil(thread_pools))
                .map(|cpu_cores| CpuCoreSet {
                    cores: cpu_cores.to_vec(),
                    #[cfg(feature = "numa")]
                    topology: topology.clone(),
                })
                .collect();
        }
        thread_pool_core_indices
    } else {
        // If everything is set to defaults, use physical layout of CPUs
        all_cpu_cores
    }
}

fn create_plotting_thread_pool_manager_thread_pool_pair(
    thread_prefix: &'static str,
    thread_pool_index: usize,
    cpu_core_set: CpuCoreSet,
    thread_priority: Option<ThreadPriority>,
) -> Result<ThreadPool, ThreadPoolBuildError> {
    let thread_name =
        move |thread_index| format!("{thread_prefix}-{thread_pool_index}.{thread_index}");
    // TODO: remove this panic handler when rayon logs panic_info
    // https://github.com/rayon-rs/rayon/issues/1208
    // (we'll lose the thread name, because it's not stored within rayon's WorkerThread)
    let panic_handler = move |panic_info| {
        if let Some(index) = current_thread_index() {
            eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
        } else {
            // We want to guarantee exit, rather than panicking in a panic handler.
            eprintln!("rayon panic handler called on non-rayon thread: {panic_info:?}");
        }
        exit(1);
    };

    ThreadPoolBuilder::new()
        .thread_name(thread_name)
        .num_threads(cpu_core_set.cpu_cores().len())
        .panic_handler(panic_handler)
        .spawn_handler({
            let handle = Handle::current();

            rayon_custom_spawn_handler(move |thread| {
                let cpu_core_set = cpu_core_set.clone();
                let handle = handle.clone();

                move || {
                    cpu_core_set.pin_current_thread();
                    if let Some(thread_priority) = thread_priority
                        && let Err(error) = set_current_thread_priority(thread_priority)
                    {
                        warn!(%error, "Failed to set thread priority");
                    }
                    drop(cpu_core_set);

                    let _guard = handle.enter();

                    task::block_in_place(|| thread.run())
                }
            })
        })
        .build()
}

/// Create a thread pool manager.
///
/// Creates a thread pool pair for each pair of CPU core sets, with the number of plotting and
/// replotting threads corresponding to the number of cores in each set, and pins threads to all of
/// those CPU cores (each thread to all cores in a set, not one thread per core). Each thread will
/// also have the Tokio context available.
///
/// The easiest way to obtain CPU core sets is [`all_cpu_cores`], or [`thread_pool_core_indices`] in
/// case support for user customizations is desired. The sets will then have to be composed into
/// pairs for this function.
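///
/// # Example
///
/// A minimal usage sketch (not from the original docs), assuming a multi-threaded Tokio runtime
/// and default thread priority:
///
/// ```ignore
/// // Use the same CPU core set for plotting and replotting in each pair
/// let thread_pool_manager = create_plotting_thread_pool_manager(
///     all_cpu_cores()
///         .into_iter()
///         .map(|cpu_core_set| (cpu_core_set.clone(), cpu_core_set)),
///     None,
/// )
/// .expect("Failed to create thread pools");
/// ```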
pub fn create_plotting_thread_pool_manager<I>(
    mut cpu_core_sets: I,
    thread_priority: Option<ThreadPriority>,
) -> Result<PlottingThreadPoolManager, ThreadPoolBuildError>
where
    I: ExactSizeIterator<Item = (CpuCoreSet, CpuCoreSet)>,
{
    let total_thread_pools = cpu_core_sets.len();

    PlottingThreadPoolManager::new(
        |thread_pool_index| {
            let (plotting_cpu_core_set, replotting_cpu_core_set) = cpu_core_sets
                .next()
                .expect("Number of thread pools is the same as cpu core sets; qed");

            Ok(PlottingThreadPoolPair {
                plotting: create_plotting_thread_pool_manager_thread_pool_pair(
                    "plotting",
                    thread_pool_index,
                    plotting_cpu_core_set,
                    thread_priority,
                )?,
                replotting: create_plotting_thread_pool_manager_thread_pool_pair(
                    "replotting",
                    thread_pool_index,
                    replotting_cpu_core_set,
                    thread_priority,
                )?,
            })
        },
        NonZeroUsize::new(total_thread_pools)
            .expect("Thread pool is guaranteed to be non-empty; qed"),
    )
}

/// This function is supposed to be used with [`rayon::ThreadPoolBuilder::spawn_handler()`] to
/// spawn a handler with custom logic defined by `spawn_handler_builder`.
///
/// `spawn_handler_builder` is called with a thread builder to create a spawn handler, which in
/// turn will run rayon's thread with the desired environment.
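///
/// # Example
///
/// A minimal usage sketch (not from the original docs) that spawns rayon threads without any
/// extra setup:
///
/// ```ignore
/// let thread_pool = rayon::ThreadPoolBuilder::new()
///     .spawn_handler(rayon_custom_spawn_handler(|thread| move || thread.run()))
///     .build()
///     .expect("Failed to build thread pool");
/// ```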
pub fn rayon_custom_spawn_handler<SpawnHandlerBuilder, SpawnHandler, SpawnHandlerResult>(
    mut spawn_handler_builder: SpawnHandlerBuilder,
) -> impl FnMut(ThreadBuilder) -> io::Result<()>
where
    SpawnHandlerBuilder: (FnMut(ThreadBuilder) -> SpawnHandler) + Clone,
    SpawnHandler: (FnOnce() -> SpawnHandlerResult) + Send + 'static,
    SpawnHandlerResult: Send + 'static,
{
    move |thread: ThreadBuilder| {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }

        b.spawn(spawn_handler_builder(thread))?;
        Ok(())
    }
}

/// This function is supposed to be used with [`rayon::ThreadPoolBuilder::spawn_handler()`] to
/// inherit the current tokio runtime.
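///
/// # Example
///
/// A minimal usage sketch (not from the original docs), assuming this is called from within a
/// multi-threaded Tokio runtime:
///
/// ```ignore
/// // Threads in this pool can access the Tokio context via `Handle::current()`
/// let thread_pool = rayon::ThreadPoolBuilder::new()
///     .spawn_handler(tokio_rayon_spawn_handler())
///     .build()
///     .expect("Failed to build thread pool");
/// ```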
pub fn tokio_rayon_spawn_handler() -> impl FnMut(ThreadBuilder) -> io::Result<()> {
    let handle = Handle::current();

    rayon_custom_spawn_handler(move |thread| {
        let handle = handle.clone();

        move || {
            let _guard = handle.enter();

            task::block_in_place(|| thread.run())
        }
    })
}