1#[cfg(test)]
4mod tests;
5
6use crate::thread_pool_manager::{PlottingThreadPoolManager, PlottingThreadPoolPair};
7use futures::channel::oneshot;
8use futures::channel::oneshot::Canceled;
9use futures::future::Either;
10use rayon::{
11 ThreadBuilder, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder, current_thread_index,
12};
13use std::future::Future;
14use std::num::NonZeroUsize;
15use std::ops::Deref;
16use std::pin::{Pin, pin};
17use std::process::exit;
18use std::task::{Context, Poll};
19use std::{fmt, io, iter, thread};
20use thread_priority::{ThreadPriority, set_current_thread_priority};
21use tokio::runtime::Handle;
22use tokio::task;
23use tracing::{debug, warn};
24
/// Upper bound applied to the automatically recommended number of farming threads
const MAX_DEFAULT_FARMING_THREADS: usize = 32;
27
/// Wrapper around a tokio [`task::JoinHandle`] that waits for the task to finish when dropped,
/// optionally aborting the task first
#[derive(Debug)]
pub struct AsyncJoinOnDrop<T> {
    /// Handle of the wrapped task, taken out (leaving `None`) by the `Drop` implementation
    handle: Option<task::JoinHandle<T>>,
    /// Whether the task is aborted before it is joined on drop
    abort_on_drop: bool,
}
34
35impl<T> Drop for AsyncJoinOnDrop<T> {
36 #[inline]
37 fn drop(&mut self) {
38 if let Some(handle) = self.handle.take() {
39 if self.abort_on_drop {
40 handle.abort();
41 }
42
43 if !handle.is_finished() {
44 task::block_in_place(move || {
45 let _ = Handle::current().block_on(handle);
46 });
47 }
48 }
49 }
50}
51
52impl<T> AsyncJoinOnDrop<T> {
53 #[inline]
55 pub fn new(handle: task::JoinHandle<T>, abort_on_drop: bool) -> Self {
56 Self {
57 handle: Some(handle),
58 abort_on_drop,
59 }
60 }
61}
62
63impl<T> Future for AsyncJoinOnDrop<T> {
64 type Output = Result<T, task::JoinError>;
65
66 #[inline]
67 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68 Pin::new(
69 self.handle
70 .as_mut()
71 .expect("Only dropped in Drop impl; qed"),
72 )
73 .poll(cx)
74 }
75}
76
/// Wrapper around a thread join handle that joins the thread when dropped.
///
/// The handle is stored in an `Option` so that the `Drop` implementation can take ownership of it.
pub(crate) struct JoinOnDrop(Option<thread::JoinHandle<()>>);
79
80impl Drop for JoinOnDrop {
81 #[inline]
82 fn drop(&mut self) {
83 self.0
84 .take()
85 .expect("Always called exactly once; qed")
86 .join()
87 .expect("Panic if background thread panicked");
88 }
89}
90
91impl JoinOnDrop {
92 #[inline]
94 pub(crate) fn new(handle: thread::JoinHandle<()>) -> Self {
95 Self(Some(handle))
96 }
97}
98
99impl Deref for JoinOnDrop {
100 type Target = thread::JoinHandle<()>;
101
102 #[inline]
103 fn deref(&self) -> &Self::Target {
104 self.0.as_ref().expect("Only dropped in Drop impl; qed")
105 }
106}
107
108pub fn run_future_in_dedicated_thread<CreateFut, Fut, T>(
111 create_future: CreateFut,
112 thread_name: String,
113) -> io::Result<impl Future<Output = Result<T, Canceled>> + Send>
114where
115 CreateFut: (FnOnce() -> Fut) + Send + 'static,
116 Fut: Future<Output = T> + 'static,
117 T: Send + 'static,
118{
119 let (drop_tx, drop_rx) = oneshot::channel::<()>();
120 let (result_tx, result_rx) = oneshot::channel();
121 let handle = Handle::current();
122 let join_handle = thread::Builder::new().name(thread_name).spawn(move || {
123 let _tokio_handle_guard = handle.enter();
124
125 let future = pin!(create_future());
126
127 let result = match handle.block_on(futures::future::select(future, drop_rx)) {
128 Either::Left((result, _)) => result,
129 Either::Right(_) => {
130 return;
132 }
133 };
134 if let Err(_error) = result_tx.send(result) {
135 debug!(
136 thread_name = ?thread::current().name(),
137 "Future finished, but receiver was already dropped",
138 );
139 }
140 })?;
141 let join_on_drop = JoinOnDrop::new(join_handle);
143
144 Ok(async move {
145 let result = result_rx.await;
146 drop(drop_tx);
147 drop(join_on_drop);
148 result
149 })
150}
151
/// Set of CPU cores, optionally accompanied by the hardware topology it was derived from
#[derive(Clone)]
pub struct CpuCoreSet {
    /// Indices of CPU cores that belong to this set
    cores: Vec<usize>,
    /// Hardware topology used for cache-aware truncation and thread pinning, `None` if it is not
    /// available
    #[cfg(feature = "numa")]
    topology: Option<std::sync::Arc<hwlocality::Topology>>,
}
160
161impl fmt::Debug for CpuCoreSet {
162 #[inline]
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 let mut s = f.debug_struct("CpuCoreSet");
165 #[cfg(not(feature = "numa"))]
166 if self.cores.array_windows::<2>().all(|&[a, b]| a + 1 == b) {
167 s.field(
168 "cores",
169 &format!(
170 "{}-{}",
171 self.cores.first().expect("List of cores is not empty; qed"),
172 self.cores.last().expect("List of cores is not empty; qed")
173 ),
174 );
175 } else {
176 s.field(
177 "cores",
178 &self
179 .cores
180 .iter()
181 .map(usize::to_string)
182 .collect::<Vec<_>>()
183 .join(","),
184 );
185 }
186 #[cfg(feature = "numa")]
187 {
188 use hwlocality::cpu::cpuset::CpuSet;
189 use hwlocality::ffi::PositiveInt;
190
191 s.field(
192 "cores",
193 &CpuSet::from_iter(
194 self.cores.iter().map(|&core| {
195 PositiveInt::try_from(core).expect("Valid CPU core index; qed")
196 }),
197 ),
198 );
199 }
200 s.finish_non_exhaustive()
201 }
202}
203
204impl CpuCoreSet {
205 pub fn cpu_cores(&self) -> &[usize] {
207 &self.cores
208 }
209
210 pub fn truncate(&mut self, num_cores: usize) {
217 let num_cores = num_cores.clamp(1, self.cores.len());
218
219 #[cfg(feature = "numa")]
220 if let Some(topology) = &self.topology {
221 use hwlocality::object::attributes::ObjectAttributes;
222 use hwlocality::object::types::ObjectType;
223
224 let mut grouped_by_l2_cache_size_and_core_count =
225 std::collections::HashMap::<(usize, usize), Vec<usize>>::new();
226 topology
227 .objects_with_type(ObjectType::L2Cache)
228 .for_each(|object| {
229 let l2_cache_size =
230 if let Some(ObjectAttributes::Cache(cache)) = object.attributes() {
231 cache
232 .size()
233 .map(|size| size.get() as usize)
234 .unwrap_or_default()
235 } else {
236 0
237 };
238 if let Some(cpuset) = object.complete_cpuset() {
239 let cpuset = cpuset
240 .into_iter()
241 .map(usize::from)
242 .filter(|core| self.cores.contains(core))
243 .collect::<Vec<_>>();
244 let cpuset_len = cpuset.len();
245
246 if !cpuset.is_empty() {
247 grouped_by_l2_cache_size_and_core_count
248 .entry((l2_cache_size, cpuset_len))
249 .or_default()
250 .extend(cpuset);
251 }
252 }
253 });
254
255 if grouped_by_l2_cache_size_and_core_count
257 .values()
258 .flatten()
259 .count()
260 == self.cores.len()
261 {
262 self.cores = grouped_by_l2_cache_size_and_core_count
266 .into_values()
267 .flat_map(|cores| {
268 let limit = cores.len() * num_cores / self.cores.len();
269 cores.into_iter().take(limit.max(1))
271 })
272 .collect();
273
274 self.cores.sort();
275
276 return;
277 }
278 }
279 self.cores.truncate(num_cores);
280 }
281
282 pub fn pin_current_thread(&self) {
284 #[cfg(feature = "numa")]
285 if let Some(topology) = &self.topology {
286 use hwlocality::cpu::binding::CpuBindingFlags;
287 use hwlocality::cpu::cpuset::CpuSet;
288 use hwlocality::current_thread_id;
289 use hwlocality::ffi::PositiveInt;
290
291 let cpu_cores = CpuSet::from_iter(
293 self.cores
294 .iter()
295 .map(|&core| PositiveInt::try_from(core).expect("Valid CPU core index; qed")),
296 );
297
298 if let Err(error) =
299 topology.bind_thread_cpu(current_thread_id(), &cpu_cores, CpuBindingFlags::empty())
300 {
301 warn!(%error, ?cpu_cores, "Failed to pin thread to CPU cores")
302 }
303 }
304 }
305}
306
307pub fn recommended_number_of_farming_threads() -> usize {
310 #[cfg(feature = "numa")]
311 match hwlocality::Topology::new().map(std::sync::Arc::new) {
312 Ok(topology) => {
313 return topology
314 .objects_at_depth(hwlocality::object::depth::Depth::NUMANode)
316 .filter_map(|node| node.cpuset())
318 .map(|cpuset| cpuset.iter_set().count())
320 .find(|&count| count > 0)
321 .unwrap_or_else(num_cpus::get)
322 .min(MAX_DEFAULT_FARMING_THREADS);
323 }
324 Err(error) => {
325 warn!(%error, "Failed to get NUMA topology");
326 }
327 }
328 num_cpus::get().min(MAX_DEFAULT_FARMING_THREADS)
329}
330
331pub fn all_cpu_cores() -> Vec<CpuCoreSet> {
336 #[cfg(feature = "numa")]
337 match hwlocality::Topology::new().map(std::sync::Arc::new) {
338 Ok(topology) => {
339 let cpu_cores = topology
340 .objects_with_type(hwlocality::object::types::ObjectType::L3Cache)
342 .filter_map(|node| node.cpuset())
344 .map(|cpuset| cpuset.iter_set().map(usize::from).collect::<Vec<_>>())
346 .filter(|cores| !cores.is_empty())
347 .map(|cores| CpuCoreSet {
348 cores,
349 topology: Some(std::sync::Arc::clone(&topology)),
350 })
351 .collect::<Vec<_>>();
352
353 if !cpu_cores.is_empty() {
354 return cpu_cores;
355 }
356 }
357 Err(error) => {
358 warn!(%error, "Failed to get L3 cache topology");
359 }
360 }
361 vec![CpuCoreSet {
362 cores: (0..num_cpus::get()).collect(),
363 #[cfg(feature = "numa")]
364 topology: None,
365 }]
366}
367
368pub fn parse_cpu_cores_sets(
371 s: &str,
372) -> Result<Vec<CpuCoreSet>, Box<dyn std::error::Error + Send + Sync>> {
373 #[cfg(feature = "numa")]
374 let topology = hwlocality::Topology::new().map(std::sync::Arc::new).ok();
375
376 s.split(' ')
377 .map(|s| {
378 let mut cores = Vec::new();
379 for s in s.split(',') {
380 let mut parts = s.split('-');
381 let range_start = parts
382 .next()
383 .ok_or(
384 "Bad string format, must be comma separated list of CPU cores or ranges",
385 )?
386 .parse()?;
387
388 if let Some(range_end) = parts.next() {
389 let range_end = range_end.parse()?;
390
391 cores.extend(range_start..=range_end);
392 } else {
393 cores.push(range_start);
394 }
395 }
396
397 Ok(CpuCoreSet {
398 cores,
399 #[cfg(feature = "numa")]
400 topology: topology.clone(),
401 })
402 })
403 .collect()
404}
405
406pub fn thread_pool_core_indices(
408 thread_pool_size: Option<NonZeroUsize>,
409 thread_pools: Option<NonZeroUsize>,
410) -> Vec<CpuCoreSet> {
411 thread_pool_core_indices_internal(all_cpu_cores(), thread_pool_size, thread_pools)
412}
413
414fn thread_pool_core_indices_internal(
415 all_cpu_cores: Vec<CpuCoreSet>,
416 thread_pool_size: Option<NonZeroUsize>,
417 thread_pools: Option<NonZeroUsize>,
418) -> Vec<CpuCoreSet> {
419 #[cfg(feature = "numa")]
420 let topology = &all_cpu_cores
421 .first()
422 .expect("Not empty according to function description; qed")
423 .topology;
424
425 let thread_pools = thread_pools
428 .map(|thread_pools| thread_pools.get())
429 .or_else(|| thread_pool_size.map(|_| all_cpu_cores.len()));
430
431 if let Some(thread_pools) = thread_pools {
432 let mut thread_pool_core_indices = Vec::<CpuCoreSet>::with_capacity(thread_pools);
433
434 let total_cpu_cores = all_cpu_cores.iter().flat_map(|set| set.cpu_cores()).count();
435
436 if let Some(thread_pool_size) = thread_pool_size {
437 let mut cpu_cores_iterator = iter::repeat(
440 all_cpu_cores
441 .iter()
442 .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
443 .copied(),
444 )
445 .flatten();
446
447 for _ in 0..thread_pools {
448 let cpu_cores = cpu_cores_iterator
449 .by_ref()
450 .take(thread_pool_size.get())
451 .map(|core_index| core_index % total_cpu_cores)
454 .collect();
455
456 thread_pool_core_indices.push(CpuCoreSet {
457 cores: cpu_cores,
458 #[cfg(feature = "numa")]
459 topology: topology.clone(),
460 });
461 }
462 } else {
463 let all_cpu_cores = all_cpu_cores
467 .iter()
468 .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
469 .copied()
470 .collect::<Vec<_>>();
471
472 thread_pool_core_indices = all_cpu_cores
473 .chunks(total_cpu_cores.div_ceil(thread_pools))
474 .map(|cpu_cores| CpuCoreSet {
475 cores: cpu_cores.to_vec(),
476 #[cfg(feature = "numa")]
477 topology: topology.clone(),
478 })
479 .collect();
480 }
481 thread_pool_core_indices
482 } else {
483 all_cpu_cores
485 }
486}
487
488fn create_plotting_thread_pool_manager_thread_pool_pair(
489 thread_prefix: &'static str,
490 thread_pool_index: usize,
491 cpu_core_set: CpuCoreSet,
492 thread_priority: Option<ThreadPriority>,
493) -> Result<ThreadPool, ThreadPoolBuildError> {
494 let thread_name =
495 move |thread_index| format!("{thread_prefix}-{thread_pool_index}.{thread_index}");
496 let panic_handler = move |panic_info| {
500 if let Some(index) = current_thread_index() {
501 eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
502 } else {
503 eprintln!("rayon panic handler called on non-rayon thread: {panic_info:?}");
505 }
506 exit(1);
507 };
508
509 ThreadPoolBuilder::new()
510 .thread_name(thread_name)
511 .num_threads(cpu_core_set.cpu_cores().len())
512 .panic_handler(panic_handler)
513 .spawn_handler({
514 let handle = Handle::current();
515
516 rayon_custom_spawn_handler(move |thread| {
517 let cpu_core_set = cpu_core_set.clone();
518 let handle = handle.clone();
519
520 move || {
521 cpu_core_set.pin_current_thread();
522 if let Some(thread_priority) = thread_priority
523 && let Err(error) = set_current_thread_priority(thread_priority)
524 {
525 warn!(%error, "Failed to set thread priority");
526 }
527 drop(cpu_core_set);
528
529 let _guard = handle.enter();
530
531 task::block_in_place(|| thread.run())
532 }
533 })
534 })
535 .build()
536}
537
538pub fn create_plotting_thread_pool_manager<I>(
549 mut cpu_core_sets: I,
550 thread_priority: Option<ThreadPriority>,
551) -> Result<PlottingThreadPoolManager, ThreadPoolBuildError>
552where
553 I: ExactSizeIterator<Item = (CpuCoreSet, CpuCoreSet)>,
554{
555 let total_thread_pools = cpu_core_sets.len();
556
557 PlottingThreadPoolManager::new(
558 |thread_pool_index| {
559 let (plotting_cpu_core_set, replotting_cpu_core_set) = cpu_core_sets
560 .next()
561 .expect("Number of thread pools is the same as cpu core sets; qed");
562
563 Ok(PlottingThreadPoolPair {
564 plotting: create_plotting_thread_pool_manager_thread_pool_pair(
565 "plotting",
566 thread_pool_index,
567 plotting_cpu_core_set,
568 thread_priority,
569 )?,
570 replotting: create_plotting_thread_pool_manager_thread_pool_pair(
571 "replotting",
572 thread_pool_index,
573 replotting_cpu_core_set,
574 thread_priority,
575 )?,
576 })
577 },
578 NonZeroUsize::new(total_thread_pools)
579 .expect("Thread pool is guaranteed to be non-empty; qed"),
580 )
581}
582
583pub fn rayon_custom_spawn_handler<SpawnHandlerBuilder, SpawnHandler, SpawnHandlerResult>(
589 mut spawn_handler_builder: SpawnHandlerBuilder,
590) -> impl FnMut(ThreadBuilder) -> io::Result<()>
591where
592 SpawnHandlerBuilder: (FnMut(ThreadBuilder) -> SpawnHandler) + Clone,
593 SpawnHandler: (FnOnce() -> SpawnHandlerResult) + Send + 'static,
594 SpawnHandlerResult: Send + 'static,
595{
596 move |thread: ThreadBuilder| {
597 let mut b = thread::Builder::new();
598 if let Some(name) = thread.name() {
599 b = b.name(name.to_owned());
600 }
601 if let Some(stack_size) = thread.stack_size() {
602 b = b.stack_size(stack_size);
603 }
604
605 b.spawn(spawn_handler_builder(thread))?;
606 Ok(())
607 }
608}
609
610pub fn tokio_rayon_spawn_handler() -> impl FnMut(ThreadBuilder) -> io::Result<()> {
613 let handle = Handle::current();
614
615 rayon_custom_spawn_handler(move |thread| {
616 let handle = handle.clone();
617
618 move || {
619 let _guard = handle.enter();
620
621 task::block_in_place(|| thread.run())
622 }
623 })
624}