ab_farmer/plotter/
gpu.rs

1//! GPU plotter
2
3mod gpu_encoders_manager;
4pub mod metrics;
5
6use crate::plotter::gpu::gpu_encoders_manager::GpuRecordsEncoderManager;
7use crate::plotter::gpu::metrics::GpuPlotterMetrics;
8use crate::plotter::{Plotter, SectorPlottingProgress};
9use crate::utils::AsyncJoinOnDrop;
10use ab_core_primitives::ed25519::Ed25519PublicKey;
11use ab_core_primitives::sectors::SectorIndex;
12use ab_core_primitives::solutions::ShardCommitmentHash;
13use ab_data_retrieval::piece_getter::PieceGetter;
14use ab_erasure_coding::ErasureCoding;
15use ab_farmer_components::FarmerProtocolInfo;
16use ab_farmer_components::plotting::{
17    DownloadSectorOptions, EncodeSectorOptions, PlottingError, download_sector, encode_sector,
18    write_sector,
19};
20use ab_proof_of_space_gpu::GpuRecordsEncoder;
21use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
22use async_trait::async_trait;
23use bytes::Bytes;
24use event_listener_primitives::{Bag, HandlerId};
25use futures::channel::mpsc;
26use futures::stream::FuturesUnordered;
27use futures::{FutureExt, Sink, SinkExt, StreamExt, select, stream};
28use prometheus_client::registry::Registry;
29use std::error::Error;
30use std::fmt;
31use std::future::pending;
32use std::num::TryFromIntError;
33use std::pin::pin;
34use std::sync::Arc;
35use std::sync::atomic::{AtomicBool, Ordering};
36use std::task::Poll;
37use std::time::Instant;
38use tokio::task::yield_now;
39use tracing::{Instrument, warn};
40
/// Type alias used for event handlers: a shared, thread-safe callback that receives three
/// arguments by reference
pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
/// `Bag` of [`HandlerFn3`] callbacks; in this file callbacks are registered with `add` and invoked
/// with `call_simple`
type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
44
/// Event handlers registered on the plotter
#[derive(Default, Debug)]
struct Handlers {
    /// Callbacks invoked with the public key, the sector index and the progress value of every
    /// plotting progress update
    plotting_progress: Handler3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
}
49
/// GPU plotter
///
/// Downloads sectors with the provided piece getter, encodes them with GPU records encoders and
/// reports progress through a progress sender as well as through subscribed event handlers.
pub struct GpuPlotter<PG> {
    /// Cloned into every plotting task and passed to `download_sector`
    piece_getter: PG,
    /// A permit is taken before a sector is scheduled; it is moved into the plotting task and ends
    /// up owned by the stream of the finished sector
    downloading_semaphore: Arc<Semaphore>,
    /// Source of the GPU records encoder each plotting task obtains through `get_encoder()`
    gpu_records_encoders_manager: GpuRecordsEncoderManager,
    /// Locked and immediately released by every plotting task right before downloading starts
    global_mutex: Arc<AsyncMutex<()>>,
    /// Cloned into every plotting task and passed to `download_sector`
    erasure_coding: ErasureCoding,
    /// Event handlers, shared with the progress updater of every plotting task
    handlers: Arc<Handlers>,
    /// Channel through which spawned plotting tasks are handed to the background task; closed on
    /// drop
    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
    /// Background task that receives plotting tasks from `tasks_sender` and polls them
    _background_tasks: AsyncJoinOnDrop<()>,
    /// Flag set to `true` on drop; handed to `encode_sector` and checked again after encoding
    abort_early: Arc<AtomicBool>,
    /// Metrics, present only when a registry was provided to `new`
    metrics: Option<Arc<GpuPlotterMetrics>>,
}
63
64impl<PG> fmt::Debug for GpuPlotter<PG> {
65    #[inline]
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        f.debug_struct("GpuPlotter").finish_non_exhaustive()
68    }
69}
70
71impl<PG> Drop for GpuPlotter<PG> {
72    #[inline]
73    fn drop(&mut self) {
74        self.abort_early.store(true, Ordering::Release);
75        self.tasks_sender.close_channel();
76    }
77}
78
79#[async_trait]
80impl<PG> Plotter for GpuPlotter<PG>
81where
82    PG: PieceGetter + Clone + Send + Sync + 'static,
83{
84    async fn has_free_capacity(&self) -> Result<bool, String> {
85        Ok(self.downloading_semaphore.try_acquire().is_some())
86    }
87
88    async fn plot_sector(
89        &self,
90        public_key: Ed25519PublicKey,
91        shard_commitments_root: ShardCommitmentHash,
92        sector_index: SectorIndex,
93        farmer_protocol_info: FarmerProtocolInfo,
94        pieces_in_sector: u16,
95        _replotting: bool,
96        progress_sender: mpsc::Sender<SectorPlottingProgress>,
97    ) {
98        let start = Instant::now();
99
100        // Done outside the future below as a backpressure, ensuring that it is not possible to
101        // schedule unbounded number of plotting tasks
102        let downloading_permit = self.downloading_semaphore.acquire_arc().await;
103
104        self.plot_sector_internal(
105            start,
106            downloading_permit,
107            public_key,
108            shard_commitments_root,
109            sector_index,
110            farmer_protocol_info,
111            pieces_in_sector,
112            progress_sender,
113        )
114        .await
115    }
116
117    async fn try_plot_sector(
118        &self,
119        public_key: Ed25519PublicKey,
120        shard_commitments_root: ShardCommitmentHash,
121        sector_index: SectorIndex,
122        farmer_protocol_info: FarmerProtocolInfo,
123        pieces_in_sector: u16,
124        _replotting: bool,
125        progress_sender: mpsc::Sender<SectorPlottingProgress>,
126    ) -> bool {
127        let start = Instant::now();
128
129        let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
130            return false;
131        };
132
133        self.plot_sector_internal(
134            start,
135            downloading_permit,
136            public_key,
137            shard_commitments_root,
138            sector_index,
139            farmer_protocol_info,
140            pieces_in_sector,
141            progress_sender,
142        )
143        .await;
144
145        true
146    }
147}
148
149impl<PG> GpuPlotter<PG>
150where
151    PG: PieceGetter + Clone + Send + Sync + 'static,
152{
153    /// Create a new instance.
154    ///
155    /// Returns an error if empty list of encoders is provided.
156    pub fn new(
157        piece_getter: PG,
158        downloading_semaphore: Arc<Semaphore>,
159        gpu_records_encoders: Vec<GpuRecordsEncoder>,
160        global_mutex: Arc<AsyncMutex<()>>,
161        erasure_coding: ErasureCoding,
162        registry: Option<&mut Registry>,
163    ) -> Result<Self, TryFromIntError> {
164        let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
165
166        // Basically runs plotting tasks in the background and allows to abort on drop
167        let background_tasks = AsyncJoinOnDrop::new(
168            tokio::spawn(async move {
169                let background_tasks = FuturesUnordered::new();
170                let mut background_tasks = pin!(background_tasks);
171                // Just so that `FuturesUnordered` will never end
172                background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
173
174                loop {
175                    select! {
176                        maybe_background_task = tasks_receiver.next().fuse() => {
177                            let Some(background_task) = maybe_background_task else {
178                                break;
179                            };
180
181                            background_tasks.push(background_task);
182                        },
183                        _ = background_tasks.select_next_some() => {
184                            // Nothing to do
185                        }
186                    }
187                }
188            }),
189            true,
190        );
191
192        let abort_early = Arc::new(AtomicBool::new(false));
193        let gpu_records_encoders_manager = GpuRecordsEncoderManager::new(gpu_records_encoders)?;
194        let metrics = registry.map(|registry| {
195            Arc::new(GpuPlotterMetrics::new(
196                registry,
197                gpu_records_encoders_manager.gpu_records_encoders(),
198            ))
199        });
200
201        Ok(Self {
202            piece_getter,
203            downloading_semaphore,
204            gpu_records_encoders_manager,
205            global_mutex,
206            erasure_coding,
207            handlers: Arc::default(),
208            tasks_sender,
209            _background_tasks: background_tasks,
210            abort_early,
211            metrics,
212        })
213    }
214
215    /// Subscribe to plotting progress notifications
216    pub fn on_plotting_progress(
217        &self,
218        callback: HandlerFn3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
219    ) -> HandlerId {
220        self.handlers.plotting_progress.add(callback)
221    }
222
    /// Spawns a task that plots a single sector and submits it to the background task.
    ///
    /// The spawned task goes through these stages, reporting each one through
    /// `ProgressUpdater::update_progress_and_events`:
    /// 1. `Downloading`/`Downloaded`: briefly takes the global mutex, then downloads the sector
    /// 2. `Encoding`/`Encoded`: obtains a GPU records encoder, encodes the sector inside
    ///    `block_in_place` and serializes it into a `Vec<u8>`
    /// 3. `Finished`: sends the plotted sector together with a single-item stream of its bytes;
    ///    the stream owns `downloading_permit`
    ///
    /// The task stops early when a progress update can't be sent, when downloading or encoding
    /// fails (an `Error` progress is reported) or when encoding was aborted early (nothing is
    /// reported).
    ///
    /// `start` is the instant used to compute the total time reported in `Finished`.
    ///
    /// If the spawned task can't be submitted to the background task, only the event handlers are
    /// notified with an `Error` progress.
    #[expect(clippy::too_many_arguments)]
    async fn plot_sector_internal<PS>(
        &self,
        start: Instant,
        downloading_permit: SemaphoreGuardArc,
        public_key: Ed25519PublicKey,
        shard_commitments_root: ShardCommitmentHash,
        sector_index: SectorIndex,
        farmer_protocol_info: FarmerProtocolInfo,
        pieces_in_sector: u16,
        mut progress_sender: PS,
    ) where
        PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
        PS::Error: Error,
    {
        if let Some(metrics) = &self.metrics {
            metrics.sector_plotting.inc();
        }

        let progress_updater = ProgressUpdater {
            public_key,
            sector_index,
            handlers: Arc::clone(&self.handlers),
            metrics: self.metrics.clone(),
        };

        let plotting_fut = {
            // The future is spawned as a separate task below, so it works with its own clones
            // instead of borrowing from `self`
            let piece_getter = self.piece_getter.clone();
            let gpu_records_encoders_manager = self.gpu_records_encoders_manager.clone();
            let global_mutex = Arc::clone(&self.global_mutex);
            let erasure_coding = self.erasure_coding.clone();
            let abort_early = Arc::clone(&self.abort_early);
            let metrics = self.metrics.clone();

            async move {
                // Downloading
                let downloaded_sector = {
                    if !progress_updater
                        .update_progress_and_events(
                            &mut progress_sender,
                            SectorPlottingProgress::Downloading,
                        )
                        .await
                    {
                        return;
                    }

                    // Take mutex briefly to make sure plotting is allowed right now
                    // (the guard is a temporary that is dropped at the end of this statement)
                    global_mutex.lock().await;

                    let downloading_start = Instant::now();
                    let public_key_hash = &public_key.hash();

                    let downloaded_sector_fut = download_sector(DownloadSectorOptions {
                        public_key_hash,
                        shard_commitments_root: &shard_commitments_root,
                        sector_index,
                        piece_getter: &piece_getter,
                        farmer_protocol_info,
                        erasure_coding: &erasure_coding,
                        pieces_in_sector,
                    });

                    let downloaded_sector = match downloaded_sector_fut.await {
                        Ok(downloaded_sector) => downloaded_sector,
                        Err(error) => {
                            warn!(%error, "Failed to download sector");

                            // The task ends right after, so the result of sending is not checked
                            progress_updater
                                .update_progress_and_events(
                                    &mut progress_sender,
                                    SectorPlottingProgress::Error {
                                        error: format!("Failed to download sector: {error}"),
                                    },
                                )
                                .await;

                            return;
                        }
                    };

                    if !progress_updater
                        .update_progress_and_events(
                            &mut progress_sender,
                            SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
                        )
                        .await
                    {
                        return;
                    }

                    downloaded_sector
                };

                // Plotting
                let (sector, plotted_sector) = {
                    let mut records_encoder = gpu_records_encoders_manager.get_encoder().await;
                    // Incremented once an encoder is held; every exit path below that is written
                    // out explicitly decrements it again
                    if let Some(metrics) = &metrics {
                        metrics.plotting_capacity_used.inc();
                    }

                    // Give a chance to interrupt plotting if necessary
                    yield_now().await;

                    if !progress_updater
                        .update_progress_and_events(
                            &mut progress_sender,
                            SectorPlottingProgress::Encoding,
                        )
                        .await
                    {
                        if let Some(metrics) = &metrics {
                            metrics.plotting_capacity_used.dec();
                        }
                        return;
                    }

                    let encoding_start = Instant::now();

                    // Encoding is synchronous, so it runs inside `block_in_place`
                    let plotting_result = tokio::task::block_in_place(move || {
                        let encoded_sector = encode_sector(
                            downloaded_sector,
                            EncodeSectorOptions {
                                sector_index,
                                records_encoder: &mut *records_encoder,
                                abort_early: &abort_early,
                            },
                        )?;

                        // Checked again after encoding so that an abort requested in the meantime
                        // skips serialization
                        if abort_early.load(Ordering::Acquire) {
                            return Err(PlottingError::AbortEarly);
                        }

                        // The encoder is not needed for serialization, let go of it first
                        drop(records_encoder);

                        let mut sector = Vec::new();

                        write_sector(&encoded_sector, &mut sector)?;

                        Ok((sector, encoded_sector.plotted_sector))
                    });

                    if let Some(metrics) = &metrics {
                        metrics.plotting_capacity_used.dec();
                    }

                    match plotting_result {
                        Ok(plotting_result) => {
                            if !progress_updater
                                .update_progress_and_events(
                                    &mut progress_sender,
                                    SectorPlottingProgress::Encoded(encoding_start.elapsed()),
                                )
                                .await
                            {
                                return;
                            }

                            plotting_result
                        }
                        Err(PlottingError::AbortEarly) => {
                            // Aborted on request, this is not reported as an error
                            return;
                        }
                        Err(error) => {
                            progress_updater
                                .update_progress_and_events(
                                    &mut progress_sender,
                                    SectorPlottingProgress::Error {
                                        error: format!("Failed to encode sector: {error}"),
                                    },
                                )
                                .await;

                            return;
                        }
                    }
                };

                progress_updater
                    .update_progress_and_events(
                        &mut progress_sender,
                        SectorPlottingProgress::Finished {
                            plotted_sector,
                            time: start.elapsed(),
                            // Stream that yields the whole sector as a single item and then ends
                            sector: Box::pin({
                                let mut sector = Some(Ok(Bytes::from(sector)));

                                stream::poll_fn(move |_cx| {
                                    // Just so that permit is dropped with stream itself
                                    let _downloading_permit = &downloading_permit;

                                    Poll::Ready(sector.take())
                                })
                            }),
                        },
                    )
                    .await;
            }
        };

        // Spawn a separate task such that `block_in_place` inside will not affect anything else
        let plotting_task =
            AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
        if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
            warn!(%error, "Failed to send plotting task");

            let progress = SectorPlottingProgress::Error {
                error: format!("Failed to send plotting task: {error}"),
            };

            // `progress_sender` was moved into the plotting future, so only the event handlers
            // can be notified from here
            self.handlers
                .plotting_progress
                .call_simple(&public_key, &sector_index, &progress);
        }
    }
438}
439
/// Reports plotting progress of a single sector to metrics, event handlers and a progress sender
struct ProgressUpdater {
    /// Public key passed to event handlers with every update
    public_key: Ed25519PublicKey,
    /// Sector index passed to event handlers with every update
    sector_index: SectorIndex,
    /// Event handlers to notify about every update
    handlers: Arc<Handlers>,
    /// Metrics to update, `None` when metrics are not collected
    metrics: Option<Arc<GpuPlotterMetrics>>,
}
446
447impl ProgressUpdater {
448    /// Returns `true` on success and `false` if progress receiver channel is gone
449    async fn update_progress_and_events<PS>(
450        &self,
451        progress_sender: &mut PS,
452        progress: SectorPlottingProgress,
453    ) -> bool
454    where
455        PS: Sink<SectorPlottingProgress> + Unpin,
456        PS::Error: Error,
457    {
458        if let Some(metrics) = &self.metrics {
459            match &progress {
460                SectorPlottingProgress::Downloading => {
461                    metrics.sector_downloading.inc();
462                }
463                SectorPlottingProgress::Downloaded(time) => {
464                    metrics.sector_downloading_time.observe(time.as_secs_f64());
465                    metrics.sector_downloaded.inc();
466                }
467                SectorPlottingProgress::Encoding => {
468                    metrics.sector_encoding.inc();
469                }
470                SectorPlottingProgress::Encoded(time) => {
471                    metrics.sector_encoding_time.observe(time.as_secs_f64());
472                    metrics.sector_encoded.inc();
473                }
474                SectorPlottingProgress::Finished { time, .. } => {
475                    metrics.sector_plotting_time.observe(time.as_secs_f64());
476                    metrics.sector_plotted.inc();
477                }
478                SectorPlottingProgress::Error { .. } => {
479                    metrics.sector_plotting_error.inc();
480                }
481            }
482        }
483        self.handlers.plotting_progress.call_simple(
484            &self.public_key,
485            &self.sector_index,
486            &progress,
487        );
488
489        if let Err(error) = progress_sender.send(progress).await {
490            warn!(%error, "Failed to send progress update");
491
492            false
493        } else {
494            true
495        }
496    }
497}