1#[cfg(test)]
4mod tests;
5
6use crate::thread_pool_manager::{PlottingThreadPoolManager, PlottingThreadPoolPair};
7use futures::channel::oneshot;
8use futures::channel::oneshot::Canceled;
9use futures::future::Either;
10use gdt_cpus::{AffinityMask, CpuInfo, ThreadPriority, set_thread_affinity, set_thread_priority};
11use rayon::{
12 ThreadBuilder, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder, current_thread_index,
13};
14use std::collections::HashMap;
15use std::future::Future;
16use std::num::NonZeroUsize;
17use std::ops::Deref;
18use std::pin::{Pin, pin};
19use std::process::exit;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use std::{fmt, io, iter, thread};
23use tokio::runtime::Handle;
24use tokio::task;
25use tracing::{debug, warn};
26
/// Upper bound applied to the value returned by [`recommended_number_of_farming_threads`]
const MAX_DEFAULT_FARMING_THREADS: usize = 32;
29
30#[derive(Debug)]
32pub struct AsyncJoinOnDrop<T> {
33 handle: Option<task::JoinHandle<T>>,
34 abort_on_drop: bool,
35}
36
37impl<T> Drop for AsyncJoinOnDrop<T> {
38 #[inline]
39 fn drop(&mut self) {
40 if let Some(handle) = self.handle.take() {
41 if self.abort_on_drop {
42 handle.abort();
43 }
44
45 if !handle.is_finished() {
46 task::block_in_place(move || {
47 let _: Result<_, _> = Handle::current().block_on(handle);
48 });
49 }
50 }
51 }
52}
53
54impl<T> AsyncJoinOnDrop<T> {
55 #[inline]
57 pub fn new(handle: task::JoinHandle<T>, abort_on_drop: bool) -> Self {
58 Self {
59 handle: Some(handle),
60 abort_on_drop,
61 }
62 }
63}
64
65impl<T> Future for AsyncJoinOnDrop<T> {
66 type Output = Result<T, task::JoinError>;
67
68 #[inline]
69 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
70 Pin::new(
71 self.handle
72 .as_mut()
73 .expect("Only dropped in Drop impl; qed"),
74 )
75 .poll(cx)
76 }
77}
78
/// Wrapper around a [`thread::JoinHandle`] that joins the thread when dropped.
///
/// Dropping panics if the joined thread itself panicked.
pub(crate) struct JoinOnDrop(Option<thread::JoinHandle<()>>);

impl Drop for JoinOnDrop {
    #[inline]
    fn drop(&mut self) {
        let join_handle = self.0.take().expect("Always called exactly once; qed");

        join_handle
            .join()
            .expect("Panic if background thread panicked");
    }
}

impl JoinOnDrop {
    /// Wraps `handle` so that the thread is joined when the returned value is dropped.
    #[inline]
    pub(crate) fn new(handle: thread::JoinHandle<()>) -> Self {
        Self(Some(handle))
    }
}

impl Deref for JoinOnDrop {
    type Target = thread::JoinHandle<()>;

    /// Gives access to the wrapped join handle.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self(maybe_handle) = self;

        maybe_handle
            .as_ref()
            .expect("Only dropped in Drop impl; qed")
    }
}
109
110pub fn run_future_in_dedicated_thread<CreateFut, Fut, T>(
113 create_future: CreateFut,
114 thread_name: String,
115) -> io::Result<impl Future<Output = Result<T, Canceled>> + Send>
116where
117 CreateFut: (FnOnce() -> Fut) + Send + 'static,
118 Fut: Future<Output = T> + 'static,
119 T: Send + 'static,
120{
121 let (drop_tx, drop_rx) = oneshot::channel::<()>();
122 let (result_tx, result_rx) = oneshot::channel();
123 let handle = Handle::current();
124 let join_handle = thread::Builder::new().name(thread_name).spawn(move || {
125 let _tokio_handle_guard = handle.enter();
126
127 let future = pin!(create_future());
128
129 let result = match handle.block_on(futures::future::select(future, drop_rx)) {
130 Either::Left((result, _)) => result,
131 Either::Right(_) => {
132 return;
134 }
135 };
136 if let Err(_error) = result_tx.send(result) {
137 debug!(
138 thread_name = ?thread::current().name(),
139 "Future finished, but receiver was already dropped",
140 );
141 }
142 })?;
143 let join_on_drop = JoinOnDrop::new(join_handle);
145
146 Ok(async move {
147 let result = result_rx.await;
148 drop(drop_tx);
149 drop(join_on_drop);
150 result
151 })
152}
153
154#[derive(Clone)]
156pub struct CpuCoreSet {
157 affinity_mask: AffinityMask,
159 cpu_info: Option<Arc<CpuInfo>>,
160}
161
162impl fmt::Debug for CpuCoreSet {
163 #[inline]
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 f.debug_struct("CpuCoreSet")
166 .field("affinity_mask", &self.affinity_mask)
167 .finish_non_exhaustive()
168 }
169}
170
171impl CpuCoreSet {
172 pub fn cpu_cores(&self) -> &AffinityMask {
174 &self.affinity_mask
175 }
176
177 pub fn truncate(&mut self, num_cores: usize) {
182 let num_cores = num_cores.clamp(1, self.affinity_mask.count());
183
184 let Some(cpu_info) = &self.cpu_info else {
185 self.affinity_mask = self.affinity_mask.iter().take(num_cores).collect();
186 return;
187 };
188
189 let mut grouped_by_l2_cache_size_and_core_count = HashMap::<_, Vec<_>>::new();
190 for domain in &cpu_info.l2_domains {
191 let domain_mask = domain.mask.intersection(&self.affinity_mask);
192
193 if domain_mask.is_empty() {
194 continue;
195 }
196
197 grouped_by_l2_cache_size_and_core_count
198 .entry((domain.size_bytes, domain_mask.count()))
199 .or_default()
200 .extend(domain_mask.iter());
201 }
202
203 if grouped_by_l2_cache_size_and_core_count
205 .values()
206 .flatten()
207 .count()
208 != self.affinity_mask.count()
209 {
210 self.affinity_mask = self.affinity_mask.iter().take(num_cores).collect();
211 return;
212 }
213
214 self.affinity_mask = grouped_by_l2_cache_size_and_core_count
218 .into_values()
219 .flat_map(|cores| {
220 let limit = (cores.len() * num_cores).div_ceil(self.affinity_mask.count());
221 cores.into_iter().take(limit)
223 })
224 .take(num_cores)
225 .collect();
226 }
227
228 pub fn pin_current_thread(&self) {
230 if let Err(error) = set_thread_affinity(&self.affinity_mask) {
231 warn!(%error, cpu_cores = ?self.affinity_mask, "Failed to pin thread to CPU cores");
232 }
233 }
234}
235
236pub fn recommended_number_of_farming_threads() -> usize {
239 match CpuInfo::detect() {
240 Ok(cpu_info) => (0..cpu_info.numa_node_count)
241 .map(|numa_node| cpu_info.numa_node_mask(numa_node).count())
242 .find(|&count| count > 0)
243 .unwrap_or_else(num_cpus::get)
244 .min(MAX_DEFAULT_FARMING_THREADS),
245 Err(error) => {
246 warn!(%error, "Failed to get CPU info");
247
248 num_cpus::get().min(MAX_DEFAULT_FARMING_THREADS)
249 }
250 }
251}
252
253pub fn all_cpu_cores() -> Vec<CpuCoreSet> {
258 match CpuInfo::detect() {
259 Ok(cpu_info) => {
260 let cpu_info = Arc::new(cpu_info);
261 let cpu_cores = cpu_info
262 .l3_domains
263 .iter()
264 .map(|domain| domain.mask)
265 .filter(|affinity_mask| !affinity_mask.is_empty())
266 .map(|affinity_mask| CpuCoreSet {
267 affinity_mask,
268 cpu_info: Some(Arc::clone(&cpu_info)),
269 })
270 .collect::<Vec<_>>();
271
272 if cpu_cores.is_empty() {
273 vec![CpuCoreSet {
274 affinity_mask: (0..num_cpus::get()).collect(),
275 cpu_info: Some(cpu_info),
276 }]
277 } else {
278 cpu_cores
279 }
280 }
281 Err(error) => {
282 warn!(%error, "Failed to get L3 cache topology");
283
284 vec![CpuCoreSet {
285 affinity_mask: (0..num_cpus::get()).collect(),
286 cpu_info: None,
287 }]
288 }
289 }
290}
291
292pub fn parse_cpu_cores_sets(
295 s: &str,
296) -> Result<Vec<CpuCoreSet>, Box<dyn std::error::Error + Send + Sync>> {
297 let cpu_info = CpuInfo::detect().ok().map(Arc::new);
298
299 s.split(' ')
300 .map(|s| {
301 let mut cores = AffinityMask::empty();
302 for s in s.split(',') {
303 let mut parts = s.split('-');
304 let range_start = parts
305 .next()
306 .ok_or(
307 "Bad string format, must be comma separated list of CPU cores or ranges",
308 )?
309 .parse()?;
310
311 if let Some(range_end) = parts.next() {
312 let range_end = range_end.parse()?;
313
314 for core in range_start..=range_end {
317 cores.add(core);
318 }
319 } else {
320 cores.add(range_start);
321 }
322 }
323
324 Ok(CpuCoreSet {
325 affinity_mask: cores,
326 cpu_info: cpu_info.clone(),
327 })
328 })
329 .collect()
330}
331
332pub fn thread_pool_core_indices(
334 thread_pool_size: Option<NonZeroUsize>,
335 thread_pools: Option<NonZeroUsize>,
336) -> Vec<CpuCoreSet> {
337 thread_pool_core_indices_internal(all_cpu_cores(), thread_pool_size, thread_pools)
338}
339
340fn thread_pool_core_indices_internal(
341 all_cpu_cores: Vec<CpuCoreSet>,
342 thread_pool_size: Option<NonZeroUsize>,
343 thread_pools: Option<NonZeroUsize>,
344) -> Vec<CpuCoreSet> {
345 let cpu_info = &all_cpu_cores
346 .first()
347 .expect("Not empty according to function description; qed")
348 .cpu_info;
349
350 let thread_pools = thread_pools
353 .map(NonZeroUsize::get)
354 .or_else(|| thread_pool_size.map(|_| all_cpu_cores.len()));
355
356 if let Some(thread_pools) = thread_pools {
357 let all_cpu_cores = &all_cpu_cores;
358 let mut thread_pool_core_indices = Vec::<CpuCoreSet>::with_capacity(thread_pools);
359
360 let total_cpu_cores = all_cpu_cores.iter().flat_map(CpuCoreSet::cpu_cores).count();
361
362 if let Some(thread_pool_size) = thread_pool_size {
363 let mut cpu_cores_iterator = iter::repeat_with(|| {
366 all_cpu_cores
367 .iter()
368 .flat_map(|cpu_core_set| cpu_core_set.affinity_mask.iter())
369 })
370 .flatten();
371
372 for _ in 0..thread_pools {
373 let cpu_cores = cpu_cores_iterator
374 .by_ref()
375 .take(thread_pool_size.get())
376 .map(|core_index| core_index % total_cpu_cores)
379 .collect();
380
381 thread_pool_core_indices.push(CpuCoreSet {
382 affinity_mask: cpu_cores,
383 cpu_info: cpu_info.clone(),
384 });
385 }
386 } else {
387 let all_cpu_cores = all_cpu_cores
391 .iter()
392 .flat_map(|cpu_core_set| cpu_core_set.affinity_mask.iter())
393 .collect::<Vec<_>>();
394
395 thread_pool_core_indices = all_cpu_cores
396 .chunks(total_cpu_cores.div_ceil(thread_pools))
397 .map(|cpu_cores| CpuCoreSet {
398 affinity_mask: AffinityMask::from_cores(cpu_cores),
399 cpu_info: cpu_info.clone(),
400 })
401 .collect();
402 }
403 thread_pool_core_indices
404 } else {
405 all_cpu_cores
407 }
408}
409
410fn create_plotting_thread_pool_manager_thread_pool_pair(
411 thread_prefix: &'static str,
412 thread_pool_index: usize,
413 cpu_core_set: CpuCoreSet,
414 thread_priority: Option<ThreadPriority>,
415) -> Result<ThreadPool, ThreadPoolBuildError> {
416 let thread_name =
417 move |thread_index| format!("{thread_prefix}-{thread_pool_index}.{thread_index}");
418 let panic_handler = move |panic_info| {
422 if let Some(index) = current_thread_index() {
423 eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
424 } else {
425 eprintln!("rayon panic handler called on non-rayon thread: {panic_info:?}");
427 }
428 exit(1);
429 };
430
431 ThreadPoolBuilder::new()
432 .thread_name(thread_name)
433 .num_threads(cpu_core_set.cpu_cores().count())
434 .panic_handler(panic_handler)
435 .spawn_handler({
436 let handle = Handle::current();
437
438 rayon_custom_spawn_handler(move |thread| {
439 let cpu_core_set = cpu_core_set.clone();
440 let handle = handle.clone();
441
442 move || {
443 cpu_core_set.pin_current_thread();
444 if let Some(thread_priority) = thread_priority
445 && let Err(error) = set_thread_priority(thread_priority)
446 {
447 warn!(%error, "Failed to set thread priority");
448 }
449 drop(cpu_core_set);
450
451 let _guard = handle.enter();
452
453 task::block_in_place(|| thread.run());
454 }
455 })
456 })
457 .build()
458}
459
460pub fn create_plotting_thread_pool_manager<I>(
471 mut cpu_core_sets: I,
472 thread_priority: Option<ThreadPriority>,
473) -> Result<PlottingThreadPoolManager, ThreadPoolBuildError>
474where
475 I: ExactSizeIterator<Item = (CpuCoreSet, CpuCoreSet)>,
476{
477 let total_thread_pools = cpu_core_sets.len();
478
479 PlottingThreadPoolManager::new(
480 |thread_pool_index| {
481 let (plotting_cpu_core_set, replotting_cpu_core_set) = cpu_core_sets
482 .next()
483 .expect("Number of thread pools is the same as cpu core sets; qed");
484
485 Ok(PlottingThreadPoolPair {
486 plotting: create_plotting_thread_pool_manager_thread_pool_pair(
487 "plotting",
488 thread_pool_index,
489 plotting_cpu_core_set,
490 thread_priority,
491 )?,
492 replotting: create_plotting_thread_pool_manager_thread_pool_pair(
493 "replotting",
494 thread_pool_index,
495 replotting_cpu_core_set,
496 thread_priority,
497 )?,
498 })
499 },
500 NonZeroUsize::new(total_thread_pools)
501 .expect("Thread pool is guaranteed to be non-empty; qed"),
502 )
503}
504
505pub fn rayon_custom_spawn_handler<SpawnHandlerBuilder, SpawnHandler, SpawnHandlerResult>(
511 mut spawn_handler_builder: SpawnHandlerBuilder,
512) -> impl FnMut(ThreadBuilder) -> io::Result<()>
513where
514 SpawnHandlerBuilder: (FnMut(ThreadBuilder) -> SpawnHandler) + Clone,
515 SpawnHandler: (FnOnce() -> SpawnHandlerResult) + Send + 'static,
516 SpawnHandlerResult: Send + 'static,
517{
518 move |thread: ThreadBuilder| {
519 let mut b = thread::Builder::new();
520 if let Some(name) = thread.name() {
521 b = b.name(name.to_owned());
522 }
523 if let Some(stack_size) = thread.stack_size() {
524 b = b.stack_size(stack_size);
525 }
526
527 b.spawn(spawn_handler_builder(thread))?;
528 Ok(())
529 }
530}
531
532pub fn tokio_rayon_spawn_handler() -> impl FnMut(ThreadBuilder) -> io::Result<()> {
535 let handle = Handle::current();
536
537 rayon_custom_spawn_handler(move |thread| {
538 let handle = handle.clone();
539
540 move || {
541 let _guard = handle.enter();
542
543 task::block_in_place(|| thread.run());
544 }
545 })
546}