Skip to main content

ab_farmer/plotter/
cpu.rs

1//! CPU plotter
2
3pub mod metrics;
4
5use crate::plotter::cpu::metrics::CpuPlotterMetrics;
6use crate::plotter::{Plotter, SectorPlottingProgress};
7use crate::thread_pool_manager::PlottingThreadPoolManager;
8use crate::utils::AsyncJoinOnDrop;
9use ab_core_primitives::ed25519::Ed25519PublicKey;
10use ab_core_primitives::sectors::SectorIndex;
11use ab_core_primitives::solutions::ShardCommitmentHash;
12use ab_data_retrieval::piece_getter::PieceGetter;
13use ab_erasure_coding::ErasureCoding;
14use ab_farmer_components::FarmerProtocolInfo;
15use ab_farmer_components::plotting::{
16    CpuRecordsEncoder, DownloadSectorOptions, EncodeSectorOptions, PlottingError, download_sector,
17    encode_sector, write_sector,
18};
19use ab_proof_of_space::Table;
20use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
21use async_trait::async_trait;
22use bytes::Bytes;
23use event_listener_primitives::{Bag, HandlerId};
24use futures::channel::mpsc;
25use futures::stream::FuturesUnordered;
26use futures::{FutureExt, Sink, SinkExt, StreamExt, select, stream};
27use prometheus_client::registry::Registry;
28use std::any::type_name;
29use std::error::Error;
30use std::future::pending;
31use std::marker::PhantomData;
32use std::num::NonZeroUsize;
33use std::pin::pin;
34use std::sync::Arc;
35use std::sync::atomic::{AtomicBool, Ordering};
36use std::task::Poll;
37use std::time::Instant;
38use std::{fmt, iter};
39use tokio::task::yield_now;
40use tracing::{Instrument, warn};
41
42/// Type alias used for event handlers
43pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
44type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
45
46#[derive(Default, Debug)]
47struct Handlers {
48    plotting_progress: Handler3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
49}
50
51/// CPU plotter
52pub struct CpuPlotter<PG, PosTable> {
53    piece_getter: PG,
54    downloading_semaphore: Arc<Semaphore>,
55    plotting_thread_pool_manager: PlottingThreadPoolManager,
56    record_encoding_concurrency: NonZeroUsize,
57    global_mutex: Arc<AsyncMutex<()>>,
58    erasure_coding: ErasureCoding,
59    handlers: Arc<Handlers>,
60    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
61    _background_tasks: AsyncJoinOnDrop<()>,
62    abort_early: Arc<AtomicBool>,
63    metrics: Option<Arc<CpuPlotterMetrics>>,
64    _phantom: PhantomData<PosTable>,
65}
66
67impl<PG, PosTable> fmt::Debug for CpuPlotter<PG, PosTable> {
68    #[inline]
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_struct("CpuPlotter").finish_non_exhaustive()
71    }
72}
73
74impl<PG, PosTable> Drop for CpuPlotter<PG, PosTable> {
75    #[inline]
76    fn drop(&mut self) {
77        self.abort_early.store(true, Ordering::Release);
78        self.tasks_sender.close_channel();
79    }
80}
81
82#[async_trait]
83impl<PG, PosTable> Plotter for CpuPlotter<PG, PosTable>
84where
85    PG: PieceGetter + Clone + Send + Sync + 'static,
86    PosTable: Table,
87{
88    async fn has_free_capacity(&self) -> Result<bool, String> {
89        Ok(self.downloading_semaphore.try_acquire().is_some())
90    }
91
92    async fn plot_sector(
93        &self,
94        public_key: Ed25519PublicKey,
95        shard_commitments_root: ShardCommitmentHash,
96        sector_index: SectorIndex,
97        farmer_protocol_info: FarmerProtocolInfo,
98        pieces_in_sector: u16,
99        replotting: bool,
100        progress_sender: mpsc::Sender<SectorPlottingProgress>,
101    ) {
102        let start = Instant::now();
103
104        // Done outside the future below as a backpressure, ensuring that it is not possible to
105        // schedule unbounded number of plotting tasks
106        let downloading_permit = self.downloading_semaphore.acquire_arc().await;
107
108        self.plot_sector_internal(
109            start,
110            downloading_permit,
111            public_key,
112            shard_commitments_root,
113            sector_index,
114            farmer_protocol_info,
115            pieces_in_sector,
116            replotting,
117            progress_sender,
118        )
119        .await;
120    }
121
122    async fn try_plot_sector(
123        &self,
124        public_key: Ed25519PublicKey,
125        shard_commitments_root: ShardCommitmentHash,
126        sector_index: SectorIndex,
127        farmer_protocol_info: FarmerProtocolInfo,
128        pieces_in_sector: u16,
129        replotting: bool,
130        progress_sender: mpsc::Sender<SectorPlottingProgress>,
131    ) -> bool {
132        let start = Instant::now();
133
134        let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
135            return false;
136        };
137
138        self.plot_sector_internal(
139            start,
140            downloading_permit,
141            public_key,
142            shard_commitments_root,
143            sector_index,
144            farmer_protocol_info,
145            pieces_in_sector,
146            replotting,
147            progress_sender,
148        )
149        .await;
150
151        true
152    }
153}
154
155impl<PG, PosTable> CpuPlotter<PG, PosTable>
156where
157    PG: PieceGetter + Clone + Send + Sync + 'static,
158    PosTable: Table,
159{
160    /// Create a new instance
161    pub fn new(
162        piece_getter: PG,
163        downloading_semaphore: Arc<Semaphore>,
164        plotting_thread_pool_manager: PlottingThreadPoolManager,
165        record_encoding_concurrency: NonZeroUsize,
166        global_mutex: Arc<AsyncMutex<()>>,
167        erasure_coding: ErasureCoding,
168        registry: Option<&mut Registry>,
169    ) -> Self {
170        let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
171
172        // Basically runs plotting tasks in the background and allows to abort on drop
173        let background_tasks = AsyncJoinOnDrop::new(
174            tokio::spawn(async move {
175                let background_tasks = FuturesUnordered::new();
176                let mut background_tasks = pin!(background_tasks);
177                // Just so that `FuturesUnordered` will never end
178                background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
179
180                loop {
181                    select! {
182                        maybe_background_task = tasks_receiver.next().fuse() => {
183                            let Some(background_task) = maybe_background_task else {
184                                break;
185                            };
186
187                            background_tasks.push(background_task);
188                        },
189                        _ = background_tasks.select_next_some() => {
190                            // Nothing to do
191                        }
192                    }
193                }
194            }),
195            true,
196        );
197
198        let abort_early = Arc::new(AtomicBool::new(false));
199        let metrics = registry.map(|registry| {
200            Arc::new(CpuPlotterMetrics::new(
201                registry,
202                type_name::<PosTable>(),
203                plotting_thread_pool_manager.thread_pool_pairs(),
204            ))
205        });
206
207        Self {
208            piece_getter,
209            downloading_semaphore,
210            plotting_thread_pool_manager,
211            record_encoding_concurrency,
212            global_mutex,
213            erasure_coding,
214            handlers: Arc::default(),
215            tasks_sender,
216            _background_tasks: background_tasks,
217            abort_early,
218            metrics,
219            _phantom: PhantomData,
220        }
221    }
222
223    /// Subscribe to plotting progress notifications
224    pub fn on_plotting_progress(
225        &self,
226        callback: HandlerFn3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
227    ) -> HandlerId {
228        self.handlers.plotting_progress.add(callback)
229    }
230
231    #[expect(clippy::too_many_arguments)]
232    async fn plot_sector_internal<PS>(
233        &self,
234        start: Instant,
235        downloading_permit: SemaphoreGuardArc,
236        public_key: Ed25519PublicKey,
237        shard_commitments_root: ShardCommitmentHash,
238        sector_index: SectorIndex,
239        farmer_protocol_info: FarmerProtocolInfo,
240        pieces_in_sector: u16,
241        replotting: bool,
242        mut progress_sender: PS,
243    ) where
244        PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
245        PS::Error: Error,
246    {
247        if let Some(metrics) = &self.metrics {
248            metrics.sector_plotting.inc();
249        }
250
251        let progress_updater = ProgressUpdater {
252            public_key,
253            sector_index,
254            handlers: Arc::clone(&self.handlers),
255            metrics: self.metrics.clone(),
256        };
257
258        let plotting_fut = {
259            let piece_getter = self.piece_getter.clone();
260            let plotting_thread_pool_manager = self.plotting_thread_pool_manager.clone();
261            let record_encoding_concurrency = self.record_encoding_concurrency;
262            let global_mutex = Arc::clone(&self.global_mutex);
263            let erasure_coding = self.erasure_coding.clone();
264            let abort_early = Arc::clone(&self.abort_early);
265            let metrics = self.metrics.clone();
266
267            async move {
268                // Downloading
269                let downloaded_sector = {
270                    if !progress_updater
271                        .update_progress_and_events(
272                            &mut progress_sender,
273                            SectorPlottingProgress::Downloading,
274                        )
275                        .await
276                    {
277                        return;
278                    }
279
280                    // Take mutex briefly to make sure plotting is allowed right now
281                    global_mutex.lock().await;
282
283                    let downloading_start = Instant::now();
284                    let public_key_hash = &public_key.hash();
285
286                    let downloaded_sector_fut = download_sector(DownloadSectorOptions {
287                        public_key_hash,
288                        shard_commitments_root: &shard_commitments_root,
289                        sector_index,
290                        piece_getter: &piece_getter,
291                        farmer_protocol_info,
292                        erasure_coding: &erasure_coding,
293                        pieces_in_sector,
294                    });
295
296                    let downloaded_sector = match downloaded_sector_fut.await {
297                        Ok(downloaded_sector) => downloaded_sector,
298                        Err(error) => {
299                            warn!(%error, "Failed to download sector");
300
301                            progress_updater
302                                .update_progress_and_events(
303                                    &mut progress_sender,
304                                    SectorPlottingProgress::Error {
305                                        error: format!("Failed to download sector: {error}"),
306                                    },
307                                )
308                                .await;
309
310                            return;
311                        }
312                    };
313
314                    if !progress_updater
315                        .update_progress_and_events(
316                            &mut progress_sender,
317                            SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
318                        )
319                        .await
320                    {
321                        return;
322                    }
323
324                    downloaded_sector
325                };
326
327                // Plotting
328                let (sector, plotted_sector) = {
329                    let thread_pools = plotting_thread_pool_manager.get_thread_pools().await;
330                    if let Some(metrics) = &metrics {
331                        metrics.plotting_capacity_used.inc();
332                    }
333
334                    // Give a chance to interrupt plotting if necessary
335                    yield_now().await;
336
337                    if !progress_updater
338                        .update_progress_and_events(
339                            &mut progress_sender,
340                            SectorPlottingProgress::Encoding,
341                        )
342                        .await
343                    {
344                        if let Some(metrics) = &metrics {
345                            metrics.plotting_capacity_used.dec();
346                        }
347                        return;
348                    }
349
350                    let encoding_start = Instant::now();
351
352                    let plotting_result = tokio::task::block_in_place(move || {
353                        let thread_pool = if replotting {
354                            &thread_pools.replotting
355                        } else {
356                            &thread_pools.plotting
357                        };
358
359                        let encoded_sector = thread_pool.install(|| {
360                            // TODO: Reuse global table generator (this comment is in many files)
361                            let generator = PosTable::generator();
362                            let generators = iter::repeat_n(
363                                generator.clone(),
364                                record_encoding_concurrency.get(),
365                            )
366                            .collect::<Vec<_>>();
367                            let mut records_encoder = CpuRecordsEncoder::<PosTable>::new(
368                                &generators,
369                                &erasure_coding,
370                                &global_mutex,
371                            );
372
373                            encode_sector(
374                                downloaded_sector,
375                                EncodeSectorOptions {
376                                    sector_index,
377                                    records_encoder: &mut records_encoder,
378                                    abort_early: &abort_early,
379                                },
380                            )
381                        })?;
382
383                        if abort_early.load(Ordering::Acquire) {
384                            return Err(PlottingError::AbortEarly);
385                        }
386
387                        drop(thread_pools);
388
389                        let mut sector = Vec::new();
390
391                        write_sector(&encoded_sector, &mut sector)?;
392
393                        Ok((sector, encoded_sector.plotted_sector))
394                    });
395
396                    if let Some(metrics) = &metrics {
397                        metrics.plotting_capacity_used.dec();
398                    }
399
400                    match plotting_result {
401                        Ok(plotting_result) => {
402                            if !progress_updater
403                                .update_progress_and_events(
404                                    &mut progress_sender,
405                                    SectorPlottingProgress::Encoded(encoding_start.elapsed()),
406                                )
407                                .await
408                            {
409                                return;
410                            }
411
412                            plotting_result
413                        }
414                        Err(PlottingError::AbortEarly) => {
415                            return;
416                        }
417                        Err(error) => {
418                            progress_updater
419                                .update_progress_and_events(
420                                    &mut progress_sender,
421                                    SectorPlottingProgress::Error {
422                                        error: format!("Failed to encode sector: {error}"),
423                                    },
424                                )
425                                .await;
426
427                            return;
428                        }
429                    }
430                };
431
432                progress_updater
433                    .update_progress_and_events(
434                        &mut progress_sender,
435                        SectorPlottingProgress::Finished {
436                            plotted_sector,
437                            time: start.elapsed(),
438                            sector: Box::pin({
439                                let mut sector = Some(Ok(Bytes::from(sector)));
440
441                                stream::poll_fn(move |_cx| {
442                                    // Just so that permit is dropped with stream itself
443                                    let _downloading_permit = &downloading_permit;
444
445                                    Poll::Ready(sector.take())
446                                })
447                            }),
448                        },
449                    )
450                    .await;
451            }
452        };
453
454        // Spawn a separate task such that `block_in_place` inside will not affect anything else
455        let plotting_task =
456            AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
457        if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
458            warn!(%error, "Failed to send plotting task");
459
460            let progress = SectorPlottingProgress::Error {
461                error: format!("Failed to send plotting task: {error}"),
462            };
463
464            self.handlers
465                .plotting_progress
466                .call_simple(&public_key, &sector_index, &progress);
467        }
468    }
469}
470
471struct ProgressUpdater {
472    public_key: Ed25519PublicKey,
473    sector_index: SectorIndex,
474    handlers: Arc<Handlers>,
475    metrics: Option<Arc<CpuPlotterMetrics>>,
476}
477
478impl ProgressUpdater {
479    /// Returns `true` on success and `false` if progress receiver channel is gone
480    async fn update_progress_and_events<PS>(
481        &self,
482        progress_sender: &mut PS,
483        progress: SectorPlottingProgress,
484    ) -> bool
485    where
486        PS: Sink<SectorPlottingProgress> + Unpin,
487        PS::Error: Error,
488    {
489        if let Some(metrics) = &self.metrics {
490            match &progress {
491                SectorPlottingProgress::Downloading => {
492                    metrics.sector_downloading.inc();
493                }
494                SectorPlottingProgress::Downloaded(time) => {
495                    metrics.sector_downloading_time.observe(time.as_secs_f64());
496                    metrics.sector_downloaded.inc();
497                }
498                SectorPlottingProgress::Encoding => {
499                    metrics.sector_encoding.inc();
500                }
501                SectorPlottingProgress::Encoded(time) => {
502                    metrics.sector_encoding_time.observe(time.as_secs_f64());
503                    metrics.sector_encoded.inc();
504                }
505                SectorPlottingProgress::Finished { time, .. } => {
506                    metrics.sector_plotting_time.observe(time.as_secs_f64());
507                    metrics.sector_plotted.inc();
508                }
509                SectorPlottingProgress::Error { .. } => {
510                    metrics.sector_plotting_error.inc();
511                }
512            }
513        }
514        self.handlers.plotting_progress.call_simple(
515            &self.public_key,
516            &self.sector_index,
517            &progress,
518        );
519
520        if let Err(error) = progress_sender.send(progress).await {
521            warn!(%error, "Failed to send progress update");
522
523            false
524        } else {
525            true
526        }
527    }
528}