ab_farmer/plotter/
cpu.rs

1//! CPU plotter
2
3pub mod metrics;
4
5use crate::plotter::cpu::metrics::CpuPlotterMetrics;
6use crate::plotter::{Plotter, SectorPlottingProgress};
7use crate::thread_pool_manager::PlottingThreadPoolManager;
8use crate::utils::AsyncJoinOnDrop;
9use ab_core_primitives::ed25519::Ed25519PublicKey;
10use ab_core_primitives::sectors::SectorIndex;
11use ab_core_primitives::solutions::ShardCommitmentHash;
12use ab_data_retrieval::piece_getter::PieceGetter;
13use ab_erasure_coding::ErasureCoding;
14use ab_farmer_components::FarmerProtocolInfo;
15use ab_farmer_components::plotting::{
16    CpuRecordsEncoder, DownloadSectorOptions, EncodeSectorOptions, PlottingError, download_sector,
17    encode_sector, write_sector,
18};
19use ab_proof_of_space::Table;
20use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
21use async_trait::async_trait;
22use bytes::Bytes;
23use event_listener_primitives::{Bag, HandlerId};
24use futures::channel::mpsc;
25use futures::stream::FuturesUnordered;
26use futures::{FutureExt, Sink, SinkExt, StreamExt, select, stream};
27use prometheus_client::registry::Registry;
28use std::any::type_name;
29use std::error::Error;
30use std::fmt;
31use std::future::pending;
32use std::marker::PhantomData;
33use std::num::NonZeroUsize;
34use std::pin::pin;
35use std::sync::Arc;
36use std::sync::atomic::{AtomicBool, Ordering};
37use std::task::Poll;
38use std::time::Instant;
39use tokio::task::yield_now;
40use tracing::{Instrument, warn};
41
/// Type alias used for event handlers
///
/// A shared, thread-safe callback that receives references to three arguments.
pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
/// Container of registered [`HandlerFn3`] callbacks, callbacks are added with `add()` and invoked
/// with `call_simple()` elsewhere in this file
type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
45
/// Event handlers registered on the plotter
#[derive(Default, Debug)]
struct Handlers {
    /// Invoked with public key, sector index and the new progress value on every progress update
    /// of a sector
    plotting_progress: Handler3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
}
50
/// CPU plotter
///
/// Downloads and encodes sectors in spawned background tasks, reporting progress through both a
/// per-sector progress sender and handlers registered on the plotter.
pub struct CpuPlotter<PG, PosTable> {
    /// Cloned into every plotting task and used to download the sector
    piece_getter: PG,
    /// A permit is acquired before a plotting task is scheduled and travels with that task, which
    /// bounds the number of sectors being processed at once
    downloading_semaphore: Arc<Semaphore>,
    /// Provides the thread pools on which sector encoding runs
    plotting_thread_pool_manager: PlottingThreadPoolManager,
    /// Number of table generator clones handed to the records encoder
    record_encoding_concurrency: NonZeroUsize,
    /// Locked briefly before downloading starts and also handed to the records encoder
    global_mutex: Arc<AsyncMutex<()>>,
    /// Used for both sector downloading and encoding
    erasure_coding: ErasureCoding,
    /// Handlers notified about plotting progress
    handlers: Arc<Handlers>,
    /// Hands spawned plotting tasks over to the background task, closed on drop
    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
    /// Background task that owns and drives spawned plotting tasks; held here only so that it is
    /// dropped together with the plotter
    _background_tasks: AsyncJoinOnDrop<()>,
    /// Set to `true` on drop so that in-progress encoding can stop early
    abort_early: Arc<AtomicBool>,
    /// Metrics, present only if a registry was provided on creation
    metrics: Option<Arc<CpuPlotterMetrics>>,
    /// Ties the proof-of-space table type to the plotter without storing a value of it
    _phantom: PhantomData<PosTable>,
}
66
67impl<PG, PosTable> fmt::Debug for CpuPlotter<PG, PosTable> {
68    #[inline]
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_struct("CpuPlotter").finish_non_exhaustive()
71    }
72}
73
74impl<PG, PosTable> Drop for CpuPlotter<PG, PosTable> {
75    #[inline]
76    fn drop(&mut self) {
77        self.abort_early.store(true, Ordering::Release);
78        self.tasks_sender.close_channel();
79    }
80}
81
82#[async_trait]
83impl<PG, PosTable> Plotter for CpuPlotter<PG, PosTable>
84where
85    PG: PieceGetter + Clone + Send + Sync + 'static,
86    PosTable: Table,
87{
88    async fn has_free_capacity(&self) -> Result<bool, String> {
89        Ok(self.downloading_semaphore.try_acquire().is_some())
90    }
91
92    async fn plot_sector(
93        &self,
94        public_key: Ed25519PublicKey,
95        shard_commitments_root: ShardCommitmentHash,
96        sector_index: SectorIndex,
97        farmer_protocol_info: FarmerProtocolInfo,
98        pieces_in_sector: u16,
99        replotting: bool,
100        progress_sender: mpsc::Sender<SectorPlottingProgress>,
101    ) {
102        let start = Instant::now();
103
104        // Done outside the future below as a backpressure, ensuring that it is not possible to
105        // schedule unbounded number of plotting tasks
106        let downloading_permit = self.downloading_semaphore.acquire_arc().await;
107
108        self.plot_sector_internal(
109            start,
110            downloading_permit,
111            public_key,
112            shard_commitments_root,
113            sector_index,
114            farmer_protocol_info,
115            pieces_in_sector,
116            replotting,
117            progress_sender,
118        )
119        .await
120    }
121
122    async fn try_plot_sector(
123        &self,
124        public_key: Ed25519PublicKey,
125        shard_commitments_root: ShardCommitmentHash,
126        sector_index: SectorIndex,
127        farmer_protocol_info: FarmerProtocolInfo,
128        pieces_in_sector: u16,
129        replotting: bool,
130        progress_sender: mpsc::Sender<SectorPlottingProgress>,
131    ) -> bool {
132        let start = Instant::now();
133
134        let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
135            return false;
136        };
137
138        self.plot_sector_internal(
139            start,
140            downloading_permit,
141            public_key,
142            shard_commitments_root,
143            sector_index,
144            farmer_protocol_info,
145            pieces_in_sector,
146            replotting,
147            progress_sender,
148        )
149        .await;
150
151        true
152    }
153}
154
155impl<PG, PosTable> CpuPlotter<PG, PosTable>
156where
157    PG: PieceGetter + Clone + Send + Sync + 'static,
158    PosTable: Table,
159{
160    /// Create a new instance
161    pub fn new(
162        piece_getter: PG,
163        downloading_semaphore: Arc<Semaphore>,
164        plotting_thread_pool_manager: PlottingThreadPoolManager,
165        record_encoding_concurrency: NonZeroUsize,
166        global_mutex: Arc<AsyncMutex<()>>,
167        erasure_coding: ErasureCoding,
168        registry: Option<&mut Registry>,
169    ) -> Self {
170        let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
171
172        // Basically runs plotting tasks in the background and allows to abort on drop
173        let background_tasks = AsyncJoinOnDrop::new(
174            tokio::spawn(async move {
175                let background_tasks = FuturesUnordered::new();
176                let mut background_tasks = pin!(background_tasks);
177                // Just so that `FuturesUnordered` will never end
178                background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
179
180                loop {
181                    select! {
182                        maybe_background_task = tasks_receiver.next().fuse() => {
183                            let Some(background_task) = maybe_background_task else {
184                                break;
185                            };
186
187                            background_tasks.push(background_task);
188                        },
189                        _ = background_tasks.select_next_some() => {
190                            // Nothing to do
191                        }
192                    }
193                }
194            }),
195            true,
196        );
197
198        let abort_early = Arc::new(AtomicBool::new(false));
199        let metrics = registry.map(|registry| {
200            Arc::new(CpuPlotterMetrics::new(
201                registry,
202                type_name::<PosTable>(),
203                plotting_thread_pool_manager.thread_pool_pairs(),
204            ))
205        });
206
207        Self {
208            piece_getter,
209            downloading_semaphore,
210            plotting_thread_pool_manager,
211            record_encoding_concurrency,
212            global_mutex,
213            erasure_coding,
214            handlers: Arc::default(),
215            tasks_sender,
216            _background_tasks: background_tasks,
217            abort_early,
218            metrics,
219            _phantom: PhantomData,
220        }
221    }
222
223    /// Subscribe to plotting progress notifications
224    pub fn on_plotting_progress(
225        &self,
226        callback: HandlerFn3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
227    ) -> HandlerId {
228        self.handlers.plotting_progress.add(callback)
229    }
230
231    #[expect(clippy::too_many_arguments)]
232    async fn plot_sector_internal<PS>(
233        &self,
234        start: Instant,
235        downloading_permit: SemaphoreGuardArc,
236        public_key: Ed25519PublicKey,
237        shard_commitments_root: ShardCommitmentHash,
238        sector_index: SectorIndex,
239        farmer_protocol_info: FarmerProtocolInfo,
240        pieces_in_sector: u16,
241        replotting: bool,
242        mut progress_sender: PS,
243    ) where
244        PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
245        PS::Error: Error,
246    {
247        if let Some(metrics) = &self.metrics {
248            metrics.sector_plotting.inc();
249        }
250
251        let progress_updater = ProgressUpdater {
252            public_key,
253            sector_index,
254            handlers: Arc::clone(&self.handlers),
255            metrics: self.metrics.clone(),
256        };
257
258        let plotting_fut = {
259            let piece_getter = self.piece_getter.clone();
260            let plotting_thread_pool_manager = self.plotting_thread_pool_manager.clone();
261            let record_encoding_concurrency = self.record_encoding_concurrency;
262            let global_mutex = Arc::clone(&self.global_mutex);
263            let erasure_coding = self.erasure_coding.clone();
264            let abort_early = Arc::clone(&self.abort_early);
265            let metrics = self.metrics.clone();
266
267            async move {
268                // Downloading
269                let downloaded_sector = {
270                    if !progress_updater
271                        .update_progress_and_events(
272                            &mut progress_sender,
273                            SectorPlottingProgress::Downloading,
274                        )
275                        .await
276                    {
277                        return;
278                    }
279
280                    // Take mutex briefly to make sure plotting is allowed right now
281                    global_mutex.lock().await;
282
283                    let downloading_start = Instant::now();
284                    let public_key_hash = &public_key.hash();
285
286                    let downloaded_sector_fut = download_sector(DownloadSectorOptions {
287                        public_key_hash,
288                        shard_commitments_root: &shard_commitments_root,
289                        sector_index,
290                        piece_getter: &piece_getter,
291                        farmer_protocol_info,
292                        erasure_coding: &erasure_coding,
293                        pieces_in_sector,
294                    });
295
296                    let downloaded_sector = match downloaded_sector_fut.await {
297                        Ok(downloaded_sector) => downloaded_sector,
298                        Err(error) => {
299                            warn!(%error, "Failed to download sector");
300
301                            progress_updater
302                                .update_progress_and_events(
303                                    &mut progress_sender,
304                                    SectorPlottingProgress::Error {
305                                        error: format!("Failed to download sector: {error}"),
306                                    },
307                                )
308                                .await;
309
310                            return;
311                        }
312                    };
313
314                    if !progress_updater
315                        .update_progress_and_events(
316                            &mut progress_sender,
317                            SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
318                        )
319                        .await
320                    {
321                        return;
322                    }
323
324                    downloaded_sector
325                };
326
327                // Plotting
328                let (sector, plotted_sector) = {
329                    let thread_pools = plotting_thread_pool_manager.get_thread_pools().await;
330                    if let Some(metrics) = &metrics {
331                        metrics.plotting_capacity_used.inc();
332                    }
333
334                    // Give a chance to interrupt plotting if necessary
335                    yield_now().await;
336
337                    if !progress_updater
338                        .update_progress_and_events(
339                            &mut progress_sender,
340                            SectorPlottingProgress::Encoding,
341                        )
342                        .await
343                    {
344                        if let Some(metrics) = &metrics {
345                            metrics.plotting_capacity_used.dec();
346                        }
347                        return;
348                    }
349
350                    let encoding_start = Instant::now();
351
352                    let plotting_result = tokio::task::block_in_place(move || {
353                        let thread_pool = if replotting {
354                            &thread_pools.replotting
355                        } else {
356                            &thread_pools.plotting
357                        };
358
359                        let encoded_sector = thread_pool.install(|| {
360                            // TODO: Reuse global table generator (this comment is in many files)
361                            let generator = PosTable::generator();
362                            let generators = (0..record_encoding_concurrency.get())
363                                .map(|_| generator.clone())
364                                .collect::<Vec<_>>();
365                            let mut records_encoder = CpuRecordsEncoder::<PosTable>::new(
366                                &generators,
367                                &erasure_coding,
368                                &global_mutex,
369                            );
370
371                            encode_sector(
372                                downloaded_sector,
373                                EncodeSectorOptions {
374                                    sector_index,
375                                    records_encoder: &mut records_encoder,
376                                    abort_early: &abort_early,
377                                },
378                            )
379                        })?;
380
381                        if abort_early.load(Ordering::Acquire) {
382                            return Err(PlottingError::AbortEarly);
383                        }
384
385                        drop(thread_pools);
386
387                        let mut sector = Vec::new();
388
389                        write_sector(&encoded_sector, &mut sector)?;
390
391                        Ok((sector, encoded_sector.plotted_sector))
392                    });
393
394                    if let Some(metrics) = &metrics {
395                        metrics.plotting_capacity_used.dec();
396                    }
397
398                    match plotting_result {
399                        Ok(plotting_result) => {
400                            if !progress_updater
401                                .update_progress_and_events(
402                                    &mut progress_sender,
403                                    SectorPlottingProgress::Encoded(encoding_start.elapsed()),
404                                )
405                                .await
406                            {
407                                return;
408                            }
409
410                            plotting_result
411                        }
412                        Err(PlottingError::AbortEarly) => {
413                            return;
414                        }
415                        Err(error) => {
416                            progress_updater
417                                .update_progress_and_events(
418                                    &mut progress_sender,
419                                    SectorPlottingProgress::Error {
420                                        error: format!("Failed to encode sector: {error}"),
421                                    },
422                                )
423                                .await;
424
425                            return;
426                        }
427                    }
428                };
429
430                progress_updater
431                    .update_progress_and_events(
432                        &mut progress_sender,
433                        SectorPlottingProgress::Finished {
434                            plotted_sector,
435                            time: start.elapsed(),
436                            sector: Box::pin({
437                                let mut sector = Some(Ok(Bytes::from(sector)));
438
439                                stream::poll_fn(move |_cx| {
440                                    // Just so that permit is dropped with stream itself
441                                    let _downloading_permit = &downloading_permit;
442
443                                    Poll::Ready(sector.take())
444                                })
445                            }),
446                        },
447                    )
448                    .await;
449            }
450        };
451
452        // Spawn a separate task such that `block_in_place` inside will not affect anything else
453        let plotting_task =
454            AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
455        if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
456            warn!(%error, "Failed to send plotting task");
457
458            let progress = SectorPlottingProgress::Error {
459                error: format!("Failed to send plotting task: {error}"),
460            };
461
462            self.handlers
463                .plotting_progress
464                .call_simple(&public_key, &sector_index, &progress);
465        }
466    }
467}
468
/// Reports progress of plotting of one particular sector to metrics, handlers and progress sender
struct ProgressUpdater {
    /// Public key passed to handlers together with every progress update
    public_key: Ed25519PublicKey,
    /// Index of the sector whose progress is reported
    sector_index: SectorIndex,
    /// Handlers to notify about every progress update
    handlers: Arc<Handlers>,
    /// Metrics to update on every progress update, `None` if metrics are not collected
    metrics: Option<Arc<CpuPlotterMetrics>>,
}
475
476impl ProgressUpdater {
477    /// Returns `true` on success and `false` if progress receiver channel is gone
478    async fn update_progress_and_events<PS>(
479        &self,
480        progress_sender: &mut PS,
481        progress: SectorPlottingProgress,
482    ) -> bool
483    where
484        PS: Sink<SectorPlottingProgress> + Unpin,
485        PS::Error: Error,
486    {
487        if let Some(metrics) = &self.metrics {
488            match &progress {
489                SectorPlottingProgress::Downloading => {
490                    metrics.sector_downloading.inc();
491                }
492                SectorPlottingProgress::Downloaded(time) => {
493                    metrics.sector_downloading_time.observe(time.as_secs_f64());
494                    metrics.sector_downloaded.inc();
495                }
496                SectorPlottingProgress::Encoding => {
497                    metrics.sector_encoding.inc();
498                }
499                SectorPlottingProgress::Encoded(time) => {
500                    metrics.sector_encoding_time.observe(time.as_secs_f64());
501                    metrics.sector_encoded.inc();
502                }
503                SectorPlottingProgress::Finished { time, .. } => {
504                    metrics.sector_plotting_time.observe(time.as_secs_f64());
505                    metrics.sector_plotted.inc();
506                }
507                SectorPlottingProgress::Error { .. } => {
508                    metrics.sector_plotting_error.inc();
509                }
510            }
511        }
512        self.handlers.plotting_progress.call_simple(
513            &self.public_key,
514            &self.sector_index,
515            &progress,
516        );
517
518        if let Err(error) = progress_sender.send(progress).await {
519            warn!(%error, "Failed to send progress update");
520
521            false
522        } else {
523            true
524        }
525    }
526}