ab_farmer/plotter/
cpu.rs

1//! CPU plotter
2
3pub mod metrics;
4
5use crate::plotter::cpu::metrics::CpuPlotterMetrics;
6use crate::plotter::{Plotter, SectorPlottingProgress};
7use crate::thread_pool_manager::PlottingThreadPoolManager;
8use crate::utils::AsyncJoinOnDrop;
9use ab_core_primitives::ed25519::Ed25519PublicKey;
10use ab_core_primitives::sectors::SectorIndex;
11use ab_data_retrieval::piece_getter::PieceGetter;
12use ab_erasure_coding::ErasureCoding;
13use ab_farmer_components::FarmerProtocolInfo;
14use ab_farmer_components::plotting::{
15    CpuRecordsEncoder, DownloadSectorOptions, EncodeSectorOptions, PlottingError, download_sector,
16    encode_sector, write_sector,
17};
18use ab_proof_of_space::Table;
19use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
20use async_trait::async_trait;
21use bytes::Bytes;
22use event_listener_primitives::{Bag, HandlerId};
23use futures::channel::mpsc;
24use futures::stream::FuturesUnordered;
25use futures::{FutureExt, Sink, SinkExt, StreamExt, select, stream};
26use prometheus_client::registry::Registry;
27use std::any::type_name;
28use std::error::Error;
29use std::fmt;
30use std::future::pending;
31use std::marker::PhantomData;
32use std::num::NonZeroUsize;
33use std::pin::pin;
34use std::sync::Arc;
35use std::sync::atomic::{AtomicBool, Ordering};
36use std::task::Poll;
37use std::time::Instant;
38use tokio::task::yield_now;
39use tracing::{Instrument, warn};
40
/// Type alias used for event handlers
///
/// A shared callback that receives references to three arguments and can be called from any
/// thread.
pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
/// Collection of [`HandlerFn3`] callbacks stored in a [`Bag`]
type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
44
/// Event handlers registered on the plotter
#[derive(Default, Debug)]
struct Handlers {
    /// Callbacks invoked with the public key, the sector index and the new progress state each
    /// time plotting progress is reported
    plotting_progress: Handler3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
}
49
/// CPU plotter
///
/// Downloads sector pieces with `PG` and encodes them on CPU thread pools using the `PosTable`
/// proof-of-space table implementation.
pub struct CpuPlotter<PG, PosTable> {
    /// Source of pieces used when downloading a sector
    piece_getter: PG,
    /// Limits how many sectors can be in flight; a permit is taken before plotting starts and is
    /// released only once the stream carrying the finished sector is dropped
    downloading_semaphore: Arc<Semaphore>,
    /// Hands out the plotting/replotting thread pools used for encoding
    plotting_thread_pool_manager: PlottingThreadPoolManager,
    /// Number of table generator clones created for the records encoder
    record_encoding_concurrency: NonZeroUsize,
    /// Locked briefly before downloading starts and also passed to the records encoder
    global_mutex: Arc<AsyncMutex<()>>,
    /// Erasure coding instance shared by downloading and encoding
    erasure_coding: ErasureCoding,
    /// Subscribers to plotting progress notifications
    handlers: Arc<Handlers>,
    /// Channel through which spawned plotting tasks are handed to the background task
    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
    /// Background task that owns and drives all spawned plotting tasks
    // NOTE(review): presumably aborts the task when dropped (constructed with `true`) — confirm
    // against `AsyncJoinOnDrop`
    _background_tasks: AsyncJoinOnDrop<()>,
    /// Set to `true` when the plotter is dropped so that in-progress encoding can stop early
    abort_early: Arc<AtomicBool>,
    /// Metrics, only present when a registry was provided at construction
    metrics: Option<Arc<CpuPlotterMetrics>>,
    /// Ties the plotter to the `PosTable` type without storing a value of it
    _phantom: PhantomData<PosTable>,
}
65
66impl<PG, PosTable> fmt::Debug for CpuPlotter<PG, PosTable> {
67    #[inline]
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        f.debug_struct("CpuPlotter").finish_non_exhaustive()
70    }
71}
72
73impl<PG, PosTable> Drop for CpuPlotter<PG, PosTable> {
74    #[inline]
75    fn drop(&mut self) {
76        self.abort_early.store(true, Ordering::Release);
77        self.tasks_sender.close_channel();
78    }
79}
80
81#[async_trait]
82impl<PG, PosTable> Plotter for CpuPlotter<PG, PosTable>
83where
84    PG: PieceGetter + Clone + Send + Sync + 'static,
85    PosTable: Table,
86{
87    async fn has_free_capacity(&self) -> Result<bool, String> {
88        Ok(self.downloading_semaphore.try_acquire().is_some())
89    }
90
91    async fn plot_sector(
92        &self,
93        public_key: Ed25519PublicKey,
94        sector_index: SectorIndex,
95        farmer_protocol_info: FarmerProtocolInfo,
96        pieces_in_sector: u16,
97        replotting: bool,
98        progress_sender: mpsc::Sender<SectorPlottingProgress>,
99    ) {
100        let start = Instant::now();
101
102        // Done outside the future below as a backpressure, ensuring that it is not possible to
103        // schedule unbounded number of plotting tasks
104        let downloading_permit = self.downloading_semaphore.acquire_arc().await;
105
106        self.plot_sector_internal(
107            start,
108            downloading_permit,
109            public_key,
110            sector_index,
111            farmer_protocol_info,
112            pieces_in_sector,
113            replotting,
114            progress_sender,
115        )
116        .await
117    }
118
119    async fn try_plot_sector(
120        &self,
121        public_key: Ed25519PublicKey,
122        sector_index: SectorIndex,
123        farmer_protocol_info: FarmerProtocolInfo,
124        pieces_in_sector: u16,
125        replotting: bool,
126        progress_sender: mpsc::Sender<SectorPlottingProgress>,
127    ) -> bool {
128        let start = Instant::now();
129
130        let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
131            return false;
132        };
133
134        self.plot_sector_internal(
135            start,
136            downloading_permit,
137            public_key,
138            sector_index,
139            farmer_protocol_info,
140            pieces_in_sector,
141            replotting,
142            progress_sender,
143        )
144        .await;
145
146        true
147    }
148}
149
150impl<PG, PosTable> CpuPlotter<PG, PosTable>
151where
152    PG: PieceGetter + Clone + Send + Sync + 'static,
153    PosTable: Table,
154{
155    /// Create a new instance
156    #[allow(clippy::too_many_arguments)]
157    pub fn new(
158        piece_getter: PG,
159        downloading_semaphore: Arc<Semaphore>,
160        plotting_thread_pool_manager: PlottingThreadPoolManager,
161        record_encoding_concurrency: NonZeroUsize,
162        global_mutex: Arc<AsyncMutex<()>>,
163        erasure_coding: ErasureCoding,
164        registry: Option<&mut Registry>,
165    ) -> Self {
166        let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
167
168        // Basically runs plotting tasks in the background and allows to abort on drop
169        let background_tasks = AsyncJoinOnDrop::new(
170            tokio::spawn(async move {
171                let background_tasks = FuturesUnordered::new();
172                let mut background_tasks = pin!(background_tasks);
173                // Just so that `FuturesUnordered` will never end
174                background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
175
176                loop {
177                    select! {
178                        maybe_background_task = tasks_receiver.next().fuse() => {
179                            let Some(background_task) = maybe_background_task else {
180                                break;
181                            };
182
183                            background_tasks.push(background_task);
184                        },
185                        _ = background_tasks.select_next_some() => {
186                            // Nothing to do
187                        }
188                    }
189                }
190            }),
191            true,
192        );
193
194        let abort_early = Arc::new(AtomicBool::new(false));
195        let metrics = registry.map(|registry| {
196            Arc::new(CpuPlotterMetrics::new(
197                registry,
198                type_name::<PosTable>(),
199                plotting_thread_pool_manager.thread_pool_pairs(),
200            ))
201        });
202
203        Self {
204            piece_getter,
205            downloading_semaphore,
206            plotting_thread_pool_manager,
207            record_encoding_concurrency,
208            global_mutex,
209            erasure_coding,
210            handlers: Arc::default(),
211            tasks_sender,
212            _background_tasks: background_tasks,
213            abort_early,
214            metrics,
215            _phantom: PhantomData,
216        }
217    }
218
219    /// Subscribe to plotting progress notifications
220    pub fn on_plotting_progress(
221        &self,
222        callback: HandlerFn3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
223    ) -> HandlerId {
224        self.handlers.plotting_progress.add(callback)
225    }
226
227    #[allow(clippy::too_many_arguments)]
228    async fn plot_sector_internal<PS>(
229        &self,
230        start: Instant,
231        downloading_permit: SemaphoreGuardArc,
232        public_key: Ed25519PublicKey,
233        sector_index: SectorIndex,
234        farmer_protocol_info: FarmerProtocolInfo,
235        pieces_in_sector: u16,
236        replotting: bool,
237        mut progress_sender: PS,
238    ) where
239        PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
240        PS::Error: Error,
241    {
242        if let Some(metrics) = &self.metrics {
243            metrics.sector_plotting.inc();
244        }
245
246        let progress_updater = ProgressUpdater {
247            public_key,
248            sector_index,
249            handlers: Arc::clone(&self.handlers),
250            metrics: self.metrics.clone(),
251        };
252
253        let plotting_fut = {
254            let piece_getter = self.piece_getter.clone();
255            let plotting_thread_pool_manager = self.plotting_thread_pool_manager.clone();
256            let record_encoding_concurrency = self.record_encoding_concurrency;
257            let global_mutex = Arc::clone(&self.global_mutex);
258            let erasure_coding = self.erasure_coding.clone();
259            let abort_early = Arc::clone(&self.abort_early);
260            let metrics = self.metrics.clone();
261
262            async move {
263                // Downloading
264                let downloaded_sector = {
265                    if !progress_updater
266                        .update_progress_and_events(
267                            &mut progress_sender,
268                            SectorPlottingProgress::Downloading,
269                        )
270                        .await
271                    {
272                        return;
273                    }
274
275                    // Take mutex briefly to make sure plotting is allowed right now
276                    global_mutex.lock().await;
277
278                    let downloading_start = Instant::now();
279                    let public_key_hash = &public_key.hash();
280
281                    let downloaded_sector_fut = download_sector(DownloadSectorOptions {
282                        public_key_hash,
283                        sector_index,
284                        piece_getter: &piece_getter,
285                        farmer_protocol_info,
286                        erasure_coding: &erasure_coding,
287                        pieces_in_sector,
288                    });
289
290                    let downloaded_sector = match downloaded_sector_fut.await {
291                        Ok(downloaded_sector) => downloaded_sector,
292                        Err(error) => {
293                            warn!(%error, "Failed to download sector");
294
295                            progress_updater
296                                .update_progress_and_events(
297                                    &mut progress_sender,
298                                    SectorPlottingProgress::Error {
299                                        error: format!("Failed to download sector: {error}"),
300                                    },
301                                )
302                                .await;
303
304                            return;
305                        }
306                    };
307
308                    if !progress_updater
309                        .update_progress_and_events(
310                            &mut progress_sender,
311                            SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
312                        )
313                        .await
314                    {
315                        return;
316                    }
317
318                    downloaded_sector
319                };
320
321                // Plotting
322                let (sector, plotted_sector) = {
323                    let thread_pools = plotting_thread_pool_manager.get_thread_pools().await;
324                    if let Some(metrics) = &metrics {
325                        metrics.plotting_capacity_used.inc();
326                    }
327
328                    // Give a chance to interrupt plotting if necessary
329                    yield_now().await;
330
331                    if !progress_updater
332                        .update_progress_and_events(
333                            &mut progress_sender,
334                            SectorPlottingProgress::Encoding,
335                        )
336                        .await
337                    {
338                        if let Some(metrics) = &metrics {
339                            metrics.plotting_capacity_used.dec();
340                        }
341                        return;
342                    }
343
344                    let encoding_start = Instant::now();
345
346                    let plotting_result = tokio::task::block_in_place(move || {
347                        let thread_pool = if replotting {
348                            &thread_pools.replotting
349                        } else {
350                            &thread_pools.plotting
351                        };
352
353                        let encoded_sector = thread_pool.install(|| {
354                            // TODO: Reuse global table generator (this comment is in many files)
355                            let generator = PosTable::generator();
356                            let generators = (0..record_encoding_concurrency.get())
357                                .map(|_| generator.clone())
358                                .collect::<Vec<_>>();
359                            let mut records_encoder = CpuRecordsEncoder::<PosTable>::new(
360                                &generators,
361                                &erasure_coding,
362                                &global_mutex,
363                            );
364
365                            encode_sector(
366                                downloaded_sector,
367                                EncodeSectorOptions {
368                                    sector_index,
369                                    records_encoder: &mut records_encoder,
370                                    abort_early: &abort_early,
371                                },
372                            )
373                        })?;
374
375                        if abort_early.load(Ordering::Acquire) {
376                            return Err(PlottingError::AbortEarly);
377                        }
378
379                        drop(thread_pools);
380
381                        let mut sector = Vec::new();
382
383                        write_sector(&encoded_sector, &mut sector)?;
384
385                        Ok((sector, encoded_sector.plotted_sector))
386                    });
387
388                    if let Some(metrics) = &metrics {
389                        metrics.plotting_capacity_used.dec();
390                    }
391
392                    match plotting_result {
393                        Ok(plotting_result) => {
394                            if !progress_updater
395                                .update_progress_and_events(
396                                    &mut progress_sender,
397                                    SectorPlottingProgress::Encoded(encoding_start.elapsed()),
398                                )
399                                .await
400                            {
401                                return;
402                            }
403
404                            plotting_result
405                        }
406                        Err(PlottingError::AbortEarly) => {
407                            return;
408                        }
409                        Err(error) => {
410                            progress_updater
411                                .update_progress_and_events(
412                                    &mut progress_sender,
413                                    SectorPlottingProgress::Error {
414                                        error: format!("Failed to encode sector: {error}"),
415                                    },
416                                )
417                                .await;
418
419                            return;
420                        }
421                    }
422                };
423
424                progress_updater
425                    .update_progress_and_events(
426                        &mut progress_sender,
427                        SectorPlottingProgress::Finished {
428                            plotted_sector,
429                            time: start.elapsed(),
430                            sector: Box::pin({
431                                let mut sector = Some(Ok(Bytes::from(sector)));
432
433                                stream::poll_fn(move |_cx| {
434                                    // Just so that permit is dropped with stream itself
435                                    let _downloading_permit = &downloading_permit;
436
437                                    Poll::Ready(sector.take())
438                                })
439                            }),
440                        },
441                    )
442                    .await;
443            }
444        };
445
446        // Spawn a separate task such that `block_in_place` inside will not affect anything else
447        let plotting_task =
448            AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
449        if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
450            warn!(%error, "Failed to send plotting task");
451
452            let progress = SectorPlottingProgress::Error {
453                error: format!("Failed to send plotting task: {error}"),
454            };
455
456            self.handlers
457                .plotting_progress
458                .call_simple(&public_key, &sector_index, &progress);
459        }
460    }
461}
462
/// Reports plotting progress of a single sector to metrics, progress handlers and a progress
/// sink
struct ProgressUpdater {
    /// Public key passed to progress handlers
    public_key: Ed25519PublicKey,
    /// Index of the sector whose progress is being reported
    sector_index: SectorIndex,
    /// Handlers notified on every progress update
    handlers: Arc<Handlers>,
    /// Metrics updated on every progress update, if enabled
    metrics: Option<Arc<CpuPlotterMetrics>>,
}
469
470impl ProgressUpdater {
471    /// Returns `true` on success and `false` if progress receiver channel is gone
472    async fn update_progress_and_events<PS>(
473        &self,
474        progress_sender: &mut PS,
475        progress: SectorPlottingProgress,
476    ) -> bool
477    where
478        PS: Sink<SectorPlottingProgress> + Unpin,
479        PS::Error: Error,
480    {
481        if let Some(metrics) = &self.metrics {
482            match &progress {
483                SectorPlottingProgress::Downloading => {
484                    metrics.sector_downloading.inc();
485                }
486                SectorPlottingProgress::Downloaded(time) => {
487                    metrics.sector_downloading_time.observe(time.as_secs_f64());
488                    metrics.sector_downloaded.inc();
489                }
490                SectorPlottingProgress::Encoding => {
491                    metrics.sector_encoding.inc();
492                }
493                SectorPlottingProgress::Encoded(time) => {
494                    metrics.sector_encoding_time.observe(time.as_secs_f64());
495                    metrics.sector_encoded.inc();
496                }
497                SectorPlottingProgress::Finished { time, .. } => {
498                    metrics.sector_plotting_time.observe(time.as_secs_f64());
499                    metrics.sector_plotted.inc();
500                }
501                SectorPlottingProgress::Error { .. } => {
502                    metrics.sector_plotting_error.inc();
503                }
504            }
505        }
506        self.handlers.plotting_progress.call_simple(
507            &self.public_key,
508            &self.sector_index,
509            &progress,
510        );
511
512        if let Err(error) = progress_sender.send(progress).await {
513            warn!(%error, "Failed to send progress update");
514
515            false
516        } else {
517            true
518        }
519    }
520}