1#[cfg(test)]
4mod tests;
5
6use crate::thread_pool_manager::{PlottingThreadPoolManager, PlottingThreadPoolPair};
7use futures::channel::oneshot;
8use futures::channel::oneshot::Canceled;
9use futures::future::Either;
10use rayon::{
11 ThreadBuilder, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder, current_thread_index,
12};
13use std::future::Future;
14use std::num::NonZeroUsize;
15use std::ops::Deref;
16use std::pin::{Pin, pin};
17use std::process::exit;
18use std::task::{Context, Poll};
19use std::{fmt, io, iter, thread};
20use thread_priority::{ThreadPriority, set_current_thread_priority};
21use tokio::runtime::Handle;
22use tokio::task;
23use tracing::{debug, warn};
24
/// Upper bound applied to the value returned by `recommended_number_of_farming_threads`.
const MAX_DEFAULT_FARMING_THREADS: usize = 32;
27
/// Wrapper around a tokio [`task::JoinHandle`] that waits for the task to finish when dropped.
///
/// If `abort_on_drop` is set, the task is aborted before waiting for it. Awaiting the wrapper
/// yields the same result as awaiting the inner join handle.
///
/// NOTE(review): waiting in `Drop` goes through `task::block_in_place` and `Handle::current()`.
/// Presumably this must therefore only be dropped inside a multi-threaded tokio runtime — confirm.
#[derive(Debug)]
pub struct AsyncJoinOnDrop<T> {
    /// Join handle of the task; only taken (set to `None`) by the `Drop` impl.
    handle: Option<task::JoinHandle<T>>,
    /// Whether `Drop` aborts the task before waiting for it.
    abort_on_drop: bool,
}
34
35impl<T> Drop for AsyncJoinOnDrop<T> {
36 #[inline]
37 fn drop(&mut self) {
38 if let Some(handle) = self.handle.take() {
39 if self.abort_on_drop {
40 handle.abort();
41 }
42
43 if !handle.is_finished() {
44 task::block_in_place(move || {
45 let _ = Handle::current().block_on(handle);
46 });
47 }
48 }
49 }
50}
51
52impl<T> AsyncJoinOnDrop<T> {
53 #[inline]
55 pub fn new(handle: task::JoinHandle<T>, abort_on_drop: bool) -> Self {
56 Self {
57 handle: Some(handle),
58 abort_on_drop,
59 }
60 }
61}
62
63impl<T> Future for AsyncJoinOnDrop<T> {
64 type Output = Result<T, task::JoinError>;
65
66 #[inline]
67 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68 Pin::new(
69 self.handle
70 .as_mut()
71 .expect("Only dropped in Drop impl; qed"),
72 )
73 .poll(cx)
74 }
75}
76
/// Owns a [`thread::JoinHandle`] and joins the thread when dropped.
///
/// Dereferences to the inner join handle, so it can still be inspected (e.g. `thread()`).
pub(crate) struct JoinOnDrop(Option<thread::JoinHandle<()>>);

impl Drop for JoinOnDrop {
    /// Blocks until the owned thread exits.
    ///
    /// # Panics
    ///
    /// Panics if the joined thread panicked.
    #[inline]
    fn drop(&mut self) {
        let join_handle = self.0.take().expect("Always called exactly once; qed");
        let outcome = join_handle.join();
        outcome.expect("Panic if background thread panicked");
    }
}

impl JoinOnDrop {
    /// Takes ownership of `handle`; the thread is joined when the returned value is dropped.
    #[inline]
    pub(crate) fn new(handle: thread::JoinHandle<()>) -> Self {
        let slot = Some(handle);
        Self(slot)
    }
}

impl Deref for JoinOnDrop {
    type Target = thread::JoinHandle<()>;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // The handle is only taken in `Drop`, after which no deref can happen
        let inner = &self.0;
        inner.as_ref().expect("Only dropped in Drop impl; qed")
    }
}
107
/// Runs the future created by `create_future` on a dedicated OS thread named `thread_name`.
///
/// `create_future` is called on the new thread, and the resulting future is driven there with
/// `block_on`. This uses the tokio runtime handle that is current when this function is called.
/// The future is created and polled only on that thread, which is why `Fut` is not required to be
/// `Send`. The returned future resolves to the inner future's output, or to `Err(Canceled)` if
/// the thread dropped its result sender without sending anything.
///
/// Once a result (or cancellation) is received, the returned future does two things:
/// - It drops `drop_tx`, which signals the thread to stop.
/// - It then joins the thread. Joining panics if the thread panicked (see `JoinOnDrop`), so a
///   panic inside the inner future surfaces as a panic of the returned future.
///
/// NOTE(review): if the returned future is dropped before it completes, its captured `drop_tx`
/// and `join_on_drop` are dropped without the explicit ordering below. Correctness then relies on
/// `drop_tx` being dropped before the join (presumably capture order); otherwise the join would
/// wait for the inner future to finish on its own — confirm.
///
/// # Errors
///
/// Returns an error if the OS thread could not be spawned.
///
/// # Panics
///
/// Panics if called outside of a tokio runtime context (`Handle::current()`).
pub fn run_future_in_dedicated_thread<CreateFut, Fut, T>(
    create_future: CreateFut,
    thread_name: String,
) -> io::Result<impl Future<Output = Result<T, Canceled>> + Send>
where
    CreateFut: (FnOnce() -> Fut) + Send + 'static,
    Fut: Future<Output = T> + 'static,
    T: Send + 'static,
{
    // Dropping `drop_tx` tells the dedicated thread to stop polling the future
    let (drop_tx, drop_rx) = oneshot::channel::<()>();
    // Carries the future's output back to the caller
    let (result_tx, result_rx) = oneshot::channel();
    let handle = Handle::current();
    let join_handle = thread::Builder::new().name(thread_name).spawn(move || {
        // Enter the runtime context so tokio APIs used by the future can find the runtime
        let _tokio_handle_guard = handle.enter();

        let future = pin!(create_future());

        // Race the future against the stop signal
        let result = match handle.block_on(futures::future::select(future, drop_rx)) {
            Either::Left((result, _)) => result,
            Either::Right(_) => {
                // Stop signal: the future is dropped here without producing a result
                return;
            }
        };
        if let Err(_error) = result_tx.send(result) {
            debug!(
                thread_name = ?thread::current().name(),
                "Future finished, but receiver was already dropped",
            );
        }
    })?;
    let join_on_drop = JoinOnDrop::new(join_handle);

    Ok(async move {
        let result = result_rx.await;
        // Signal the thread first, then join it; the join panics if the thread panicked
        drop(drop_tx);
        drop(join_on_drop);
        result
    })
}
151
/// Set of CPU core indices that a thread or thread pool can be pinned to.
#[derive(Clone)]
pub struct CpuCoreSet {
    /// CPU core indices in this set.
    cores: Vec<usize>,
    /// Hardware topology used for L2-cache-aware truncation and for pinning threads. It is `None`
    /// when topology detection failed or was not used for this set.
    #[cfg(feature = "numa")]
    topology: Option<std::sync::Arc<hwlocality::Topology>>,
}
160
161impl fmt::Debug for CpuCoreSet {
162 #[inline]
163 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164 let mut s = f.debug_struct("CpuCoreSet");
165 #[cfg(not(feature = "numa"))]
166 if self.cores.array_windows::<2>().all(|&[a, b]| a + 1 == b) {
167 s.field(
168 "cores",
169 &format!(
170 "{}-{}",
171 self.cores.first().expect("List of cores is not empty; qed"),
172 self.cores.last().expect("List of cores is not empty; qed")
173 ),
174 );
175 } else {
176 s.field(
177 "cores",
178 &self
179 .cores
180 .iter()
181 .map(usize::to_string)
182 .collect::<Vec<_>>()
183 .join(","),
184 );
185 }
186 #[cfg(feature = "numa")]
187 {
188 use hwlocality::cpu::cpuset::CpuSet;
189 use hwlocality::ffi::PositiveInt;
190
191 s.field(
192 "cores",
193 &CpuSet::from_iter(
194 self.cores.iter().map(|&core| {
195 PositiveInt::try_from(core).expect("Valid CPU core index; qed")
196 }),
197 ),
198 );
199 }
200 s.finish_non_exhaustive()
201 }
202}
203
204impl CpuCoreSet {
205 pub fn cpu_cores(&self) -> &[usize] {
207 &self.cores
208 }
209
210 pub fn truncate(&mut self, num_cores: usize) {
217 let num_cores = num_cores.clamp(1, self.cores.len());
218
219 #[cfg(feature = "numa")]
220 if let Some(topology) = &self.topology {
221 use hwlocality::object::attributes::ObjectAttributes;
222 use hwlocality::object::types::ObjectType;
223
224 let mut grouped_by_l2_cache_size_and_core_count =
225 std::collections::HashMap::<(usize, usize), Vec<usize>>::new();
226 topology
227 .objects_with_type(ObjectType::L2Cache)
228 .for_each(|object| {
229 let l2_cache_size =
230 if let Some(ObjectAttributes::Cache(cache)) = object.attributes() {
231 cache
232 .size()
233 .map(|size| size.get() as usize)
234 .unwrap_or_default()
235 } else {
236 0
237 };
238 if let Some(cpuset) = object.complete_cpuset() {
239 let cpuset = cpuset
240 .into_iter()
241 .map(usize::from)
242 .filter(|core| self.cores.contains(core))
243 .collect::<Vec<_>>();
244 let cpuset_len = cpuset.len();
245
246 if !cpuset.is_empty() {
247 grouped_by_l2_cache_size_and_core_count
248 .entry((l2_cache_size, cpuset_len))
249 .or_default()
250 .extend(cpuset);
251 }
252 }
253 });
254
255 if grouped_by_l2_cache_size_and_core_count
257 .values()
258 .flatten()
259 .count()
260 == self.cores.len()
261 {
262 self.cores = grouped_by_l2_cache_size_and_core_count
266 .into_values()
267 .flat_map(|cores| {
268 let limit = cores.len() * num_cores / self.cores.len();
269 cores.into_iter().take(limit.max(1))
271 })
272 .collect();
273
274 self.cores.sort();
275
276 return;
277 }
278 }
279 self.cores.truncate(num_cores);
280 }
281
282 pub fn pin_current_thread(&self) {
284 #[cfg(feature = "numa")]
285 if let Some(topology) = &self.topology {
286 use hwlocality::cpu::binding::CpuBindingFlags;
287 use hwlocality::cpu::cpuset::CpuSet;
288 use hwlocality::current_thread_id;
289 use hwlocality::ffi::PositiveInt;
290
291 let cpu_cores = CpuSet::from_iter(
293 self.cores
294 .iter()
295 .map(|&core| PositiveInt::try_from(core).expect("Valid CPU core index; qed")),
296 );
297
298 if let Err(error) =
299 topology.bind_thread_cpu(current_thread_id(), &cpu_cores, CpuBindingFlags::empty())
300 {
301 warn!(%error, ?cpu_cores, "Failed to pin thread to CPU cores")
302 }
303 }
304 }
305}
306
307pub fn recommended_number_of_farming_threads() -> usize {
310 #[cfg(feature = "numa")]
311 match hwlocality::Topology::new().map(std::sync::Arc::new) {
312 Ok(topology) => {
313 return topology
314 .objects_at_depth(hwlocality::object::depth::Depth::NUMANode)
316 .filter_map(|node| node.cpuset())
318 .map(|cpuset| cpuset.iter_set().count())
320 .find(|&count| count > 0)
321 .unwrap_or_else(num_cpus::get)
322 .min(MAX_DEFAULT_FARMING_THREADS);
323 }
324 Err(error) => {
325 warn!(%error, "Failed to get NUMA topology");
326 }
327 }
328 num_cpus::get().min(MAX_DEFAULT_FARMING_THREADS)
329}
330
331pub fn all_cpu_cores() -> Vec<CpuCoreSet> {
336 #[cfg(feature = "numa")]
337 match hwlocality::Topology::new().map(std::sync::Arc::new) {
338 Ok(topology) => {
339 let cpu_cores = topology
340 .objects_with_type(hwlocality::object::types::ObjectType::L3Cache)
342 .filter_map(|node| node.cpuset())
344 .map(|cpuset| cpuset.iter_set().map(usize::from).collect::<Vec<_>>())
346 .filter(|cores| !cores.is_empty())
347 .map(|cores| CpuCoreSet {
348 cores,
349 topology: Some(std::sync::Arc::clone(&topology)),
350 })
351 .collect::<Vec<_>>();
352
353 if !cpu_cores.is_empty() {
354 return cpu_cores;
355 }
356 }
357 Err(error) => {
358 warn!(%error, "Failed to get L3 cache topology");
359 }
360 }
361 vec![CpuCoreSet {
362 cores: (0..num_cpus::get()).collect(),
363 #[cfg(feature = "numa")]
364 topology: None,
365 }]
366}
367
368pub fn parse_cpu_cores_sets(
371 s: &str,
372) -> Result<Vec<CpuCoreSet>, Box<dyn std::error::Error + Send + Sync>> {
373 #[cfg(feature = "numa")]
374 let topology = hwlocality::Topology::new().map(std::sync::Arc::new).ok();
375
376 s.split(' ')
377 .map(|s| {
378 let mut cores = Vec::new();
379 for s in s.split(',') {
380 let mut parts = s.split('-');
381 let range_start = parts
382 .next()
383 .ok_or(
384 "Bad string format, must be comma separated list of CPU cores or ranges",
385 )?
386 .parse()?;
387
388 if let Some(range_end) = parts.next() {
389 let range_end = range_end.parse()?;
390
391 cores.extend(range_start..=range_end);
392 } else {
393 cores.push(range_start);
394 }
395 }
396
397 Ok(CpuCoreSet {
398 cores,
399 #[cfg(feature = "numa")]
400 topology: topology.clone(),
401 })
402 })
403 .collect()
404}
405
406pub fn thread_pool_core_indices(
408 thread_pool_size: Option<NonZeroUsize>,
409 thread_pools: Option<NonZeroUsize>,
410) -> Vec<CpuCoreSet> {
411 thread_pool_core_indices_internal(all_cpu_cores(), thread_pool_size, thread_pools)
412}
413
414fn thread_pool_core_indices_internal(
415 all_cpu_cores: Vec<CpuCoreSet>,
416 thread_pool_size: Option<NonZeroUsize>,
417 thread_pools: Option<NonZeroUsize>,
418) -> Vec<CpuCoreSet> {
419 #[cfg(feature = "numa")]
420 let topology = &all_cpu_cores
421 .first()
422 .expect("Not empty according to function description; qed")
423 .topology;
424
425 let thread_pools = thread_pools
428 .map(|thread_pools| thread_pools.get())
429 .or_else(|| thread_pool_size.map(|_| all_cpu_cores.len()));
430
431 if let Some(thread_pools) = thread_pools {
432 let mut thread_pool_core_indices = Vec::<CpuCoreSet>::with_capacity(thread_pools);
433
434 let total_cpu_cores = all_cpu_cores.iter().flat_map(|set| set.cpu_cores()).count();
435
436 if let Some(thread_pool_size) = thread_pool_size {
437 let mut cpu_cores_iterator = iter::repeat(
440 all_cpu_cores
441 .iter()
442 .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
443 .copied(),
444 )
445 .flatten();
446
447 for _ in 0..thread_pools {
448 let cpu_cores = cpu_cores_iterator
449 .by_ref()
450 .take(thread_pool_size.get())
451 .map(|core_index| core_index % total_cpu_cores)
454 .collect();
455
456 thread_pool_core_indices.push(CpuCoreSet {
457 cores: cpu_cores,
458 #[cfg(feature = "numa")]
459 topology: topology.clone(),
460 });
461 }
462 } else {
463 let all_cpu_cores = all_cpu_cores
466 .iter()
467 .flat_map(|cpu_core_set| cpu_core_set.cores.iter())
468 .copied()
469 .collect::<Vec<_>>();
470
471 thread_pool_core_indices = all_cpu_cores
472 .chunks(total_cpu_cores.div_ceil(thread_pools))
473 .map(|cpu_cores| CpuCoreSet {
474 cores: cpu_cores.to_vec(),
475 #[cfg(feature = "numa")]
476 topology: topology.clone(),
477 })
478 .collect();
479 }
480 thread_pool_core_indices
481 } else {
482 all_cpu_cores
484 }
485}
486
487fn create_plotting_thread_pool_manager_thread_pool_pair(
488 thread_prefix: &'static str,
489 thread_pool_index: usize,
490 cpu_core_set: CpuCoreSet,
491 thread_priority: Option<ThreadPriority>,
492) -> Result<ThreadPool, ThreadPoolBuildError> {
493 let thread_name =
494 move |thread_index| format!("{thread_prefix}-{thread_pool_index}.{thread_index}");
495 let panic_handler = move |panic_info| {
499 if let Some(index) = current_thread_index() {
500 eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
501 } else {
502 eprintln!("rayon panic handler called on non-rayon thread: {panic_info:?}");
504 }
505 exit(1);
506 };
507
508 ThreadPoolBuilder::new()
509 .thread_name(thread_name)
510 .num_threads(cpu_core_set.cpu_cores().len())
511 .panic_handler(panic_handler)
512 .spawn_handler({
513 let handle = Handle::current();
514
515 rayon_custom_spawn_handler(move |thread| {
516 let cpu_core_set = cpu_core_set.clone();
517 let handle = handle.clone();
518
519 move || {
520 cpu_core_set.pin_current_thread();
521 if let Some(thread_priority) = thread_priority
522 && let Err(error) = set_current_thread_priority(thread_priority)
523 {
524 warn!(%error, "Failed to set thread priority");
525 }
526 drop(cpu_core_set);
527
528 let _guard = handle.enter();
529
530 task::block_in_place(|| thread.run())
531 }
532 })
533 })
534 .build()
535}
536
537pub fn create_plotting_thread_pool_manager<I>(
547 mut cpu_core_sets: I,
548 thread_priority: Option<ThreadPriority>,
549) -> Result<PlottingThreadPoolManager, ThreadPoolBuildError>
550where
551 I: ExactSizeIterator<Item = (CpuCoreSet, CpuCoreSet)>,
552{
553 let total_thread_pools = cpu_core_sets.len();
554
555 PlottingThreadPoolManager::new(
556 |thread_pool_index| {
557 let (plotting_cpu_core_set, replotting_cpu_core_set) = cpu_core_sets
558 .next()
559 .expect("Number of thread pools is the same as cpu core sets; qed");
560
561 Ok(PlottingThreadPoolPair {
562 plotting: create_plotting_thread_pool_manager_thread_pool_pair(
563 "plotting",
564 thread_pool_index,
565 plotting_cpu_core_set,
566 thread_priority,
567 )?,
568 replotting: create_plotting_thread_pool_manager_thread_pool_pair(
569 "replotting",
570 thread_pool_index,
571 replotting_cpu_core_set,
572 thread_priority,
573 )?,
574 })
575 },
576 NonZeroUsize::new(total_thread_pools)
577 .expect("Thread pool is guaranteed to be non-empty; qed"),
578 )
579}
580
581pub fn rayon_custom_spawn_handler<SpawnHandlerBuilder, SpawnHandler, SpawnHandlerResult>(
587 mut spawn_handler_builder: SpawnHandlerBuilder,
588) -> impl FnMut(ThreadBuilder) -> io::Result<()>
589where
590 SpawnHandlerBuilder: (FnMut(ThreadBuilder) -> SpawnHandler) + Clone,
591 SpawnHandler: (FnOnce() -> SpawnHandlerResult) + Send + 'static,
592 SpawnHandlerResult: Send + 'static,
593{
594 move |thread: ThreadBuilder| {
595 let mut b = thread::Builder::new();
596 if let Some(name) = thread.name() {
597 b = b.name(name.to_owned());
598 }
599 if let Some(stack_size) = thread.stack_size() {
600 b = b.stack_size(stack_size);
601 }
602
603 b.spawn(spawn_handler_builder(thread))?;
604 Ok(())
605 }
606}
607
608pub fn tokio_rayon_spawn_handler() -> impl FnMut(ThreadBuilder) -> io::Result<()> {
611 let handle = Handle::current();
612
613 rayon_custom_spawn_handler(move |thread| {
614 let handle = handle.clone();
615
616 move || {
617 let _guard = handle.enter();
618
619 task::block_in_place(|| thread.run())
620 }
621 })
622}