Skip to main content

ab_farmer/
utils.rs

1//! Various utilities used by farmer or with farmer
2
3#[cfg(test)]
4mod tests;
5
6use crate::thread_pool_manager::{PlottingThreadPoolManager, PlottingThreadPoolPair};
7use futures::channel::oneshot;
8use futures::channel::oneshot::Canceled;
9use futures::future::Either;
10use gdt_cpus::{AffinityMask, CpuInfo, ThreadPriority, set_thread_affinity, set_thread_priority};
11use rayon::{
12    ThreadBuilder, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder, current_thread_index,
13};
14use std::collections::HashMap;
15use std::future::Future;
16use std::num::NonZeroUsize;
17use std::ops::Deref;
18use std::pin::{Pin, pin};
19use std::process::exit;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use std::{fmt, io, iter, thread};
23use tokio::runtime::Handle;
24use tokio::task;
25use tracing::{debug, warn};
26
/// Upper bound for the automatically recommended number of farming threads; a huge number of
/// farming threads doesn't make a lot of sense, 32 is plenty
const MAX_DEFAULT_FARMING_THREADS: usize = 32;
29
30/// Joins async join handle on drop
31#[derive(Debug)]
32pub struct AsyncJoinOnDrop<T> {
33    handle: Option<task::JoinHandle<T>>,
34    abort_on_drop: bool,
35}
36
37impl<T> Drop for AsyncJoinOnDrop<T> {
38    #[inline]
39    fn drop(&mut self) {
40        if let Some(handle) = self.handle.take() {
41            if self.abort_on_drop {
42                handle.abort();
43            }
44
45            if !handle.is_finished() {
46                task::block_in_place(move || {
47                    let _: Result<_, _> = Handle::current().block_on(handle);
48                });
49            }
50        }
51    }
52}
53
54impl<T> AsyncJoinOnDrop<T> {
55    /// Create a new instance.
56    #[inline]
57    pub fn new(handle: task::JoinHandle<T>, abort_on_drop: bool) -> Self {
58        Self {
59            handle: Some(handle),
60            abort_on_drop,
61        }
62    }
63}
64
65impl<T> Future for AsyncJoinOnDrop<T> {
66    type Output = Result<T, task::JoinError>;
67
68    #[inline]
69    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
70        Pin::new(
71            self.handle
72                .as_mut()
73                .expect("Only dropped in Drop impl; qed"),
74        )
75        .poll(cx)
76    }
77}
78
/// Wrapper around a thread's [`thread::JoinHandle`] that joins the thread when dropped
pub(crate) struct JoinOnDrop(Option<thread::JoinHandle<()>>);

impl JoinOnDrop {
    /// Create new instance
    #[inline]
    pub(crate) fn new(handle: thread::JoinHandle<()>) -> Self {
        let handle = Some(handle);

        Self(handle)
    }
}

impl Deref for JoinOnDrop {
    type Target = thread::JoinHandle<()>;

    /// Gives access to the wrapped join handle
    #[inline]
    fn deref(&self) -> &Self::Target {
        let maybe_handle = &self.0;

        maybe_handle
            .as_ref()
            .expect("Only dropped in Drop impl; qed")
    }
}

impl Drop for JoinOnDrop {
    #[inline]
    fn drop(&mut self) {
        let handle = self.0.take().expect("Always called exactly once; qed");

        // A panic of the background thread is propagated to the thread dropping the wrapper
        handle.join().expect("Panic if background thread panicked");
    }
}
109
110/// Runs future on a dedicated thread with the specified name, will block on drop until background
111/// thread with future is stopped too, ensuring nothing is left in memory
112pub fn run_future_in_dedicated_thread<CreateFut, Fut, T>(
113    create_future: CreateFut,
114    thread_name: String,
115) -> io::Result<impl Future<Output = Result<T, Canceled>> + Send>
116where
117    CreateFut: (FnOnce() -> Fut) + Send + 'static,
118    Fut: Future<Output = T> + 'static,
119    T: Send + 'static,
120{
121    let (drop_tx, drop_rx) = oneshot::channel::<()>();
122    let (result_tx, result_rx) = oneshot::channel();
123    let handle = Handle::current();
124    let join_handle = thread::Builder::new().name(thread_name).spawn(move || {
125        let _tokio_handle_guard = handle.enter();
126
127        let future = pin!(create_future());
128
129        let result = match handle.block_on(futures::future::select(future, drop_rx)) {
130            Either::Left((result, _)) => result,
131            Either::Right(_) => {
132                // Outer future was dropped, nothing left to do
133                return;
134            }
135        };
136        if let Err(_error) = result_tx.send(result) {
137            debug!(
138                thread_name = ?thread::current().name(),
139                "Future finished, but receiver was already dropped",
140            );
141        }
142    })?;
143    // Ensure thread will not be left hanging forever
144    let join_on_drop = JoinOnDrop::new(join_handle);
145
146    Ok(async move {
147        let result = result_rx.await;
148        drop(drop_tx);
149        drop(join_on_drop);
150        result
151    })
152}
153
154/// Abstraction for CPU core set
155#[derive(Clone)]
156pub struct CpuCoreSet {
157    /// CPU cores that belong to this set
158    affinity_mask: AffinityMask,
159    cpu_info: Option<Arc<CpuInfo>>,
160}
161
162impl fmt::Debug for CpuCoreSet {
163    #[inline]
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        f.debug_struct("CpuCoreSet")
166            .field("affinity_mask", &self.affinity_mask)
167            .finish_non_exhaustive()
168    }
169}
170
171impl CpuCoreSet {
172    /// Get cpu core numbers in this set
173    pub fn cpu_cores(&self) -> &AffinityMask {
174        &self.affinity_mask
175    }
176
177    /// Will truncate a set of CPU cores to this number.
178    ///
179    /// Truncation will take into account L2 cache topology to use a half of the actual physical
180    /// cores and half of each core type in the case of heterogeneous CPUs.
181    pub fn truncate(&mut self, num_cores: usize) {
182        let num_cores = num_cores.clamp(1, self.affinity_mask.count());
183
184        let Some(cpu_info) = &self.cpu_info else {
185            self.affinity_mask = self.affinity_mask.iter().take(num_cores).collect();
186            return;
187        };
188
189        let mut grouped_by_l2_cache_size_and_core_count = HashMap::<_, Vec<_>>::new();
190        for domain in &cpu_info.l2_domains {
191            let domain_mask = domain.mask.intersection(&self.affinity_mask);
192
193            if domain_mask.is_empty() {
194                continue;
195            }
196
197            grouped_by_l2_cache_size_and_core_count
198                .entry((domain.size_bytes, domain_mask.count()))
199                .or_default()
200                .extend(domain_mask.iter());
201        }
202
203        // Make sure all CPU cores in this set were found
204        if grouped_by_l2_cache_size_and_core_count
205            .values()
206            .flatten()
207            .count()
208            != self.affinity_mask.count()
209        {
210            self.affinity_mask = self.affinity_mask.iter().take(num_cores).collect();
211            return;
212        }
213
214        // Walk through groups of cores for each (L2 cache size + number of cores in set) tuple and
215        // pull number of CPU cores proportional to the fraction of the cores that should be
216        // returned according to function argument
217        self.affinity_mask = grouped_by_l2_cache_size_and_core_count
218            .into_values()
219            .flat_map(|cores| {
220                let limit = (cores.len() * num_cores).div_ceil(self.affinity_mask.count());
221                // At least 1 CPU core is needed
222                cores.into_iter().take(limit)
223            })
224            .take(num_cores)
225            .collect();
226    }
227
228    /// Pin current thread to this CPU core set
229    pub fn pin_current_thread(&self) {
230        if let Err(error) = set_thread_affinity(&self.affinity_mask) {
231            warn!(%error, cpu_cores = ?self.affinity_mask, "Failed to pin thread to CPU cores");
232        }
233    }
234}
235
236/// Recommended number of thread pool size for farming, equal to number of CPU cores in the first
237/// NUMA node
238pub fn recommended_number_of_farming_threads() -> usize {
239    match CpuInfo::detect() {
240        Ok(cpu_info) => (0..cpu_info.numa_node_count)
241            .map(|numa_node| cpu_info.numa_node_mask(numa_node).count())
242            .find(|&count| count > 0)
243            .unwrap_or_else(num_cpus::get)
244            .min(MAX_DEFAULT_FARMING_THREADS),
245        Err(error) => {
246            warn!(%error, "Failed to get CPU info");
247
248            num_cpus::get().min(MAX_DEFAULT_FARMING_THREADS)
249        }
250    }
251}
252
253/// Get all cpu cores, grouped into sets according to NUMA nodes or L3 cache groups on large CPUs.
254///
255/// Returned vector is guaranteed to have at least one element and have non-zero number of CPU cores
256/// in each set.
257pub fn all_cpu_cores() -> Vec<CpuCoreSet> {
258    match CpuInfo::detect() {
259        Ok(cpu_info) => {
260            let cpu_info = Arc::new(cpu_info);
261            let cpu_cores = cpu_info
262                .l3_domains
263                .iter()
264                .map(|domain| domain.mask)
265                .filter(|affinity_mask| !affinity_mask.is_empty())
266                .map(|affinity_mask| CpuCoreSet {
267                    affinity_mask,
268                    cpu_info: Some(Arc::clone(&cpu_info)),
269                })
270                .collect::<Vec<_>>();
271
272            if cpu_cores.is_empty() {
273                vec![CpuCoreSet {
274                    affinity_mask: (0..num_cpus::get()).collect(),
275                    cpu_info: Some(cpu_info),
276                }]
277            } else {
278                cpu_cores
279            }
280        }
281        Err(error) => {
282            warn!(%error, "Failed to get L3 cache topology");
283
284            vec![CpuCoreSet {
285                affinity_mask: (0..num_cpus::get()).collect(),
286                cpu_info: None,
287            }]
288        }
289    }
290}
291
292/// Parse space-separated set of groups of CPU cores (individual cores are coma-separated) into
293/// vector of CPU core sets that can be used for creation of plotting/replotting thread pools.
294pub fn parse_cpu_cores_sets(
295    s: &str,
296) -> Result<Vec<CpuCoreSet>, Box<dyn std::error::Error + Send + Sync>> {
297    let cpu_info = CpuInfo::detect().ok().map(Arc::new);
298
299    s.split(' ')
300        .map(|s| {
301            let mut cores = AffinityMask::empty();
302            for s in s.split(',') {
303                let mut parts = s.split('-');
304                let range_start = parts
305                    .next()
306                    .ok_or(
307                        "Bad string format, must be comma separated list of CPU cores or ranges",
308                    )?
309                    .parse()?;
310
311                if let Some(range_end) = parts.next() {
312                    let range_end = range_end.parse()?;
313
314                    // TODO: https://github.com/gdt-tools/gdt-cpus-rs/pull/11
315                    // cores.extend(range_start..=range_end);
316                    for core in range_start..=range_end {
317                        cores.add(core);
318                    }
319                } else {
320                    cores.add(range_start);
321                }
322            }
323
324            Ok(CpuCoreSet {
325                affinity_mask: cores,
326                cpu_info: cpu_info.clone(),
327            })
328        })
329        .collect()
330}
331
332/// Thread indices for each thread pool
333pub fn thread_pool_core_indices(
334    thread_pool_size: Option<NonZeroUsize>,
335    thread_pools: Option<NonZeroUsize>,
336) -> Vec<CpuCoreSet> {
337    thread_pool_core_indices_internal(all_cpu_cores(), thread_pool_size, thread_pools)
338}
339
340fn thread_pool_core_indices_internal(
341    all_cpu_cores: Vec<CpuCoreSet>,
342    thread_pool_size: Option<NonZeroUsize>,
343    thread_pools: Option<NonZeroUsize>,
344) -> Vec<CpuCoreSet> {
345    let cpu_info = &all_cpu_cores
346        .first()
347        .expect("Not empty according to function description; qed")
348        .cpu_info;
349
350    // In case number of thread pools is not specified, but user did customize thread pool size,
351    // default to auto-detected number of thread pools
352    let thread_pools = thread_pools
353        .map(NonZeroUsize::get)
354        .or_else(|| thread_pool_size.map(|_| all_cpu_cores.len()));
355
356    if let Some(thread_pools) = thread_pools {
357        let all_cpu_cores = &all_cpu_cores;
358        let mut thread_pool_core_indices = Vec::<CpuCoreSet>::with_capacity(thread_pools);
359
360        let total_cpu_cores = all_cpu_cores.iter().flat_map(CpuCoreSet::cpu_cores).count();
361
362        if let Some(thread_pool_size) = thread_pool_size {
363            // If thread pool size is fixed, loop over all CPU cores as many times as necessary and
364            // assign contiguous ranges of CPU cores to corresponding thread pools
365            let mut cpu_cores_iterator = iter::repeat_with(|| {
366                all_cpu_cores
367                    .iter()
368                    .flat_map(|cpu_core_set| cpu_core_set.affinity_mask.iter())
369            })
370            .flatten();
371
372            for _ in 0..thread_pools {
373                let cpu_cores = cpu_cores_iterator
374                    .by_ref()
375                    .take(thread_pool_size.get())
376                    // To loop over all CPU cores multiple times, modulo naively obtained CPU
377                    // cores by the total available number of CPU cores
378                    .map(|core_index| core_index % total_cpu_cores)
379                    .collect();
380
381                thread_pool_core_indices.push(CpuCoreSet {
382                    affinity_mask: cpu_cores,
383                    cpu_info: cpu_info.clone(),
384                });
385            }
386        } else {
387            // If thread pool size is not fixed, create threads pools with
388            // `total_cpu_cores/thread_pools` threads
389
390            let all_cpu_cores = all_cpu_cores
391                .iter()
392                .flat_map(|cpu_core_set| cpu_core_set.affinity_mask.iter())
393                .collect::<Vec<_>>();
394
395            thread_pool_core_indices = all_cpu_cores
396                .chunks(total_cpu_cores.div_ceil(thread_pools))
397                .map(|cpu_cores| CpuCoreSet {
398                    affinity_mask: AffinityMask::from_cores(cpu_cores),
399                    cpu_info: cpu_info.clone(),
400                })
401                .collect();
402        }
403        thread_pool_core_indices
404    } else {
405        // If everything is set to defaults, use physical layout of CPUs
406        all_cpu_cores
407    }
408}
409
410fn create_plotting_thread_pool_manager_thread_pool_pair(
411    thread_prefix: &'static str,
412    thread_pool_index: usize,
413    cpu_core_set: CpuCoreSet,
414    thread_priority: Option<ThreadPriority>,
415) -> Result<ThreadPool, ThreadPoolBuildError> {
416    let thread_name =
417        move |thread_index| format!("{thread_prefix}-{thread_pool_index}.{thread_index}");
418    // TODO: remove this panic handler when rayon logs panic_info
419    // https://github.com/rayon-rs/rayon/issues/1208
420    // (we'll lose the thread name, because it's not stored within rayon's WorkerThread)
421    let panic_handler = move |panic_info| {
422        if let Some(index) = current_thread_index() {
423            eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
424        } else {
425            // We want to guarantee exit, rather than panicking in a panic handler.
426            eprintln!("rayon panic handler called on non-rayon thread: {panic_info:?}");
427        }
428        exit(1);
429    };
430
431    ThreadPoolBuilder::new()
432        .thread_name(thread_name)
433        .num_threads(cpu_core_set.cpu_cores().count())
434        .panic_handler(panic_handler)
435        .spawn_handler({
436            let handle = Handle::current();
437
438            rayon_custom_spawn_handler(move |thread| {
439                let cpu_core_set = cpu_core_set.clone();
440                let handle = handle.clone();
441
442                move || {
443                    cpu_core_set.pin_current_thread();
444                    if let Some(thread_priority) = thread_priority
445                        && let Err(error) = set_thread_priority(thread_priority)
446                    {
447                        warn!(%error, "Failed to set thread priority");
448                    }
449                    drop(cpu_core_set);
450
451                    let _guard = handle.enter();
452
453                    task::block_in_place(|| thread.run());
454                }
455            })
456        })
457        .build()
458}
459
460/// Create thread pools manager.
461///
462/// Creates thread pool pairs for each of CPU core set pair with number of plotting and replotting
463/// threads corresponding to number of cores in each set and pins threads to all of those CPU cores
464/// (each thread to all cors in a set, not thread per core). Each thread will also have Tokio
465/// context available.
466///
467/// The easiest way to obtain CPUs is using [`all_cpu_cores`], but [`thread_pool_core_indices`] in
468/// case support for user customizations is desired. They will then have to be composed into pairs
469/// for this function.
470pub fn create_plotting_thread_pool_manager<I>(
471    mut cpu_core_sets: I,
472    thread_priority: Option<ThreadPriority>,
473) -> Result<PlottingThreadPoolManager, ThreadPoolBuildError>
474where
475    I: ExactSizeIterator<Item = (CpuCoreSet, CpuCoreSet)>,
476{
477    let total_thread_pools = cpu_core_sets.len();
478
479    PlottingThreadPoolManager::new(
480        |thread_pool_index| {
481            let (plotting_cpu_core_set, replotting_cpu_core_set) = cpu_core_sets
482                .next()
483                .expect("Number of thread pools is the same as cpu core sets; qed");
484
485            Ok(PlottingThreadPoolPair {
486                plotting: create_plotting_thread_pool_manager_thread_pool_pair(
487                    "plotting",
488                    thread_pool_index,
489                    plotting_cpu_core_set,
490                    thread_priority,
491                )?,
492                replotting: create_plotting_thread_pool_manager_thread_pool_pair(
493                    "replotting",
494                    thread_pool_index,
495                    replotting_cpu_core_set,
496                    thread_priority,
497                )?,
498            })
499        },
500        NonZeroUsize::new(total_thread_pools)
501            .expect("Thread pool is guaranteed to be non-empty; qed"),
502    )
503}
504
505/// This function is supposed to be used with [`rayon::ThreadPoolBuilder::spawn_handler()`] to
506/// spawn handler with a custom logic defined by `spawn_hook_builder`.
507///
508/// `spawn_hook_builder` is called with thread builder to create `spawn_handler` that in turn will
509/// be spawn rayon's thread with desired environment.
510pub fn rayon_custom_spawn_handler<SpawnHandlerBuilder, SpawnHandler, SpawnHandlerResult>(
511    mut spawn_handler_builder: SpawnHandlerBuilder,
512) -> impl FnMut(ThreadBuilder) -> io::Result<()>
513where
514    SpawnHandlerBuilder: (FnMut(ThreadBuilder) -> SpawnHandler) + Clone,
515    SpawnHandler: (FnOnce() -> SpawnHandlerResult) + Send + 'static,
516    SpawnHandlerResult: Send + 'static,
517{
518    move |thread: ThreadBuilder| {
519        let mut b = thread::Builder::new();
520        if let Some(name) = thread.name() {
521            b = b.name(name.to_owned());
522        }
523        if let Some(stack_size) = thread.stack_size() {
524            b = b.stack_size(stack_size);
525        }
526
527        b.spawn(spawn_handler_builder(thread))?;
528        Ok(())
529    }
530}
531
532/// This function is supposed to be used with [`rayon::ThreadPoolBuilder::spawn_handler()`] to
533/// inherit current tokio runtime.
534pub fn tokio_rayon_spawn_handler() -> impl FnMut(ThreadBuilder) -> io::Result<()> {
535    let handle = Handle::current();
536
537    rayon_custom_spawn_handler(move |thread| {
538        let handle = handle.clone();
539
540        move || {
541            let _guard = handle.enter();
542
543            task::block_in_place(|| thread.run());
544        }
545    })
546}