ab_farmer/plotter/gpu.rs

1//! GPU plotter
2
3mod gpu_encoders_manager;
4pub mod metrics;
5
6use crate::plotter::gpu::gpu_encoders_manager::GpuRecordsEncoderManager;
7use crate::plotter::gpu::metrics::GpuPlotterMetrics;
8use crate::plotter::{Plotter, SectorPlottingProgress};
9use crate::utils::AsyncJoinOnDrop;
10use ab_core_primitives::ed25519::Ed25519PublicKey;
11use ab_core_primitives::sectors::SectorIndex;
12use ab_data_retrieval::piece_getter::PieceGetter;
13use ab_erasure_coding::ErasureCoding;
14use ab_farmer_components::FarmerProtocolInfo;
15use ab_farmer_components::plotting::{
16    DownloadSectorOptions, EncodeSectorOptions, PlottingError, download_sector, encode_sector,
17    write_sector,
18};
19use ab_proof_of_space_gpu::GpuRecordsEncoder;
20use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
21use async_trait::async_trait;
22use bytes::Bytes;
23use event_listener_primitives::{Bag, HandlerId};
24use futures::channel::mpsc;
25use futures::stream::FuturesUnordered;
26use futures::{FutureExt, Sink, SinkExt, StreamExt, select, stream};
27use prometheus_client::registry::Registry;
28use std::error::Error;
29use std::fmt;
30use std::future::pending;
31use std::num::TryFromIntError;
32use std::pin::pin;
33use std::sync::Arc;
34use std::sync::atomic::{AtomicBool, Ordering};
35use std::task::Poll;
36use std::time::Instant;
37use tokio::task::yield_now;
38use tracing::{Instrument, warn};
39
40/// Type alias used for event handlers
41pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
42type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
43
44#[derive(Default, Debug)]
45struct Handlers {
46    plotting_progress: Handler3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
47}
48
49/// GPU plotter
50pub struct GpuPlotter<PG> {
51    piece_getter: PG,
52    downloading_semaphore: Arc<Semaphore>,
53    gpu_records_encoders_manager: GpuRecordsEncoderManager,
54    global_mutex: Arc<AsyncMutex<()>>,
55    erasure_coding: ErasureCoding,
56    handlers: Arc<Handlers>,
57    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
58    _background_tasks: AsyncJoinOnDrop<()>,
59    abort_early: Arc<AtomicBool>,
60    metrics: Option<Arc<GpuPlotterMetrics>>,
61}
62
63impl<PG> fmt::Debug for GpuPlotter<PG> {
64    #[inline]
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("GpuPlotter").finish_non_exhaustive()
67    }
68}
69
70impl<PG> Drop for GpuPlotter<PG> {
71    #[inline]
72    fn drop(&mut self) {
73        self.abort_early.store(true, Ordering::Release);
74        self.tasks_sender.close_channel();
75    }
76}
77
78#[async_trait]
79impl<PG> Plotter for GpuPlotter<PG>
80where
81    PG: PieceGetter + Clone + Send + Sync + 'static,
82{
83    async fn has_free_capacity(&self) -> Result<bool, String> {
84        Ok(self.downloading_semaphore.try_acquire().is_some())
85    }
86
87    async fn plot_sector(
88        &self,
89        public_key: Ed25519PublicKey,
90        sector_index: SectorIndex,
91        farmer_protocol_info: FarmerProtocolInfo,
92        pieces_in_sector: u16,
93        _replotting: bool,
94        progress_sender: mpsc::Sender<SectorPlottingProgress>,
95    ) {
96        let start = Instant::now();
97
98        // Done outside the future below as a backpressure, ensuring that it is not possible to
99        // schedule unbounded number of plotting tasks
100        let downloading_permit = self.downloading_semaphore.acquire_arc().await;
101
102        self.plot_sector_internal(
103            start,
104            downloading_permit,
105            public_key,
106            sector_index,
107            farmer_protocol_info,
108            pieces_in_sector,
109            progress_sender,
110        )
111        .await
112    }
113
114    async fn try_plot_sector(
115        &self,
116        public_key: Ed25519PublicKey,
117        sector_index: SectorIndex,
118        farmer_protocol_info: FarmerProtocolInfo,
119        pieces_in_sector: u16,
120        _replotting: bool,
121        progress_sender: mpsc::Sender<SectorPlottingProgress>,
122    ) -> bool {
123        let start = Instant::now();
124
125        let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
126            return false;
127        };
128
129        self.plot_sector_internal(
130            start,
131            downloading_permit,
132            public_key,
133            sector_index,
134            farmer_protocol_info,
135            pieces_in_sector,
136            progress_sender,
137        )
138        .await;
139
140        true
141    }
142}
143
144impl<PG> GpuPlotter<PG>
145where
146    PG: PieceGetter + Clone + Send + Sync + 'static,
147{
148    /// Create a new instance.
149    ///
150    /// Returns an error if empty list of encoders is provided.
151    pub fn new(
152        piece_getter: PG,
153        downloading_semaphore: Arc<Semaphore>,
154        gpu_records_encoders: Vec<GpuRecordsEncoder>,
155        global_mutex: Arc<AsyncMutex<()>>,
156        erasure_coding: ErasureCoding,
157        registry: Option<&mut Registry>,
158    ) -> Result<Self, TryFromIntError> {
159        let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
160
161        // Basically runs plotting tasks in the background and allows to abort on drop
162        let background_tasks = AsyncJoinOnDrop::new(
163            tokio::spawn(async move {
164                let background_tasks = FuturesUnordered::new();
165                let mut background_tasks = pin!(background_tasks);
166                // Just so that `FuturesUnordered` will never end
167                background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
168
169                loop {
170                    select! {
171                        maybe_background_task = tasks_receiver.next().fuse() => {
172                            let Some(background_task) = maybe_background_task else {
173                                break;
174                            };
175
176                            background_tasks.push(background_task);
177                        },
178                        _ = background_tasks.select_next_some() => {
179                            // Nothing to do
180                        }
181                    }
182                }
183            }),
184            true,
185        );
186
187        let abort_early = Arc::new(AtomicBool::new(false));
188        let gpu_records_encoders_manager = GpuRecordsEncoderManager::new(gpu_records_encoders)?;
189        let metrics = registry.map(|registry| {
190            Arc::new(GpuPlotterMetrics::new(
191                registry,
192                gpu_records_encoders_manager.gpu_records_encoders(),
193            ))
194        });
195
196        Ok(Self {
197            piece_getter,
198            downloading_semaphore,
199            gpu_records_encoders_manager,
200            global_mutex,
201            erasure_coding,
202            handlers: Arc::default(),
203            tasks_sender,
204            _background_tasks: background_tasks,
205            abort_early,
206            metrics,
207        })
208    }
209
210    /// Subscribe to plotting progress notifications
211    pub fn on_plotting_progress(
212        &self,
213        callback: HandlerFn3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
214    ) -> HandlerId {
215        self.handlers.plotting_progress.add(callback)
216    }
217
218    #[allow(clippy::too_many_arguments)]
219    async fn plot_sector_internal<PS>(
220        &self,
221        start: Instant,
222        downloading_permit: SemaphoreGuardArc,
223        public_key: Ed25519PublicKey,
224        sector_index: SectorIndex,
225        farmer_protocol_info: FarmerProtocolInfo,
226        pieces_in_sector: u16,
227        mut progress_sender: PS,
228    ) where
229        PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
230        PS::Error: Error,
231    {
232        if let Some(metrics) = &self.metrics {
233            metrics.sector_plotting.inc();
234        }
235
236        let progress_updater = ProgressUpdater {
237            public_key,
238            sector_index,
239            handlers: Arc::clone(&self.handlers),
240            metrics: self.metrics.clone(),
241        };
242
243        let plotting_fut = {
244            let piece_getter = self.piece_getter.clone();
245            let gpu_records_encoders_manager = self.gpu_records_encoders_manager.clone();
246            let global_mutex = Arc::clone(&self.global_mutex);
247            let erasure_coding = self.erasure_coding.clone();
248            let abort_early = Arc::clone(&self.abort_early);
249            let metrics = self.metrics.clone();
250
251            async move {
252                // Downloading
253                let downloaded_sector = {
254                    if !progress_updater
255                        .update_progress_and_events(
256                            &mut progress_sender,
257                            SectorPlottingProgress::Downloading,
258                        )
259                        .await
260                    {
261                        return;
262                    }
263
264                    // Take mutex briefly to make sure plotting is allowed right now
265                    global_mutex.lock().await;
266
267                    let downloading_start = Instant::now();
268                    let public_key_hash = &public_key.hash();
269
270                    let downloaded_sector_fut = download_sector(DownloadSectorOptions {
271                        public_key_hash,
272                        sector_index,
273                        piece_getter: &piece_getter,
274                        farmer_protocol_info,
275                        erasure_coding: &erasure_coding,
276                        pieces_in_sector,
277                    });
278
279                    let downloaded_sector = match downloaded_sector_fut.await {
280                        Ok(downloaded_sector) => downloaded_sector,
281                        Err(error) => {
282                            warn!(%error, "Failed to download sector");
283
284                            progress_updater
285                                .update_progress_and_events(
286                                    &mut progress_sender,
287                                    SectorPlottingProgress::Error {
288                                        error: format!("Failed to download sector: {error}"),
289                                    },
290                                )
291                                .await;
292
293                            return;
294                        }
295                    };
296
297                    if !progress_updater
298                        .update_progress_and_events(
299                            &mut progress_sender,
300                            SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
301                        )
302                        .await
303                    {
304                        return;
305                    }
306
307                    downloaded_sector
308                };
309
310                // Plotting
311                let (sector, plotted_sector) = {
312                    let mut records_encoder = gpu_records_encoders_manager.get_encoder().await;
313                    if let Some(metrics) = &metrics {
314                        metrics.plotting_capacity_used.inc();
315                    }
316
317                    // Give a chance to interrupt plotting if necessary
318                    yield_now().await;
319
320                    if !progress_updater
321                        .update_progress_and_events(
322                            &mut progress_sender,
323                            SectorPlottingProgress::Encoding,
324                        )
325                        .await
326                    {
327                        if let Some(metrics) = &metrics {
328                            metrics.plotting_capacity_used.dec();
329                        }
330                        return;
331                    }
332
333                    let encoding_start = Instant::now();
334
335                    let plotting_result = tokio::task::block_in_place(move || {
336                        let encoded_sector = encode_sector(
337                            downloaded_sector,
338                            EncodeSectorOptions {
339                                sector_index,
340                                records_encoder: &mut *records_encoder,
341                                abort_early: &abort_early,
342                            },
343                        )?;
344
345                        if abort_early.load(Ordering::Acquire) {
346                            return Err(PlottingError::AbortEarly);
347                        }
348
349                        drop(records_encoder);
350
351                        let mut sector = Vec::new();
352
353                        write_sector(&encoded_sector, &mut sector)?;
354
355                        Ok((sector, encoded_sector.plotted_sector))
356                    });
357
358                    if let Some(metrics) = &metrics {
359                        metrics.plotting_capacity_used.dec();
360                    }
361
362                    match plotting_result {
363                        Ok(plotting_result) => {
364                            if !progress_updater
365                                .update_progress_and_events(
366                                    &mut progress_sender,
367                                    SectorPlottingProgress::Encoded(encoding_start.elapsed()),
368                                )
369                                .await
370                            {
371                                return;
372                            }
373
374                            plotting_result
375                        }
376                        Err(PlottingError::AbortEarly) => {
377                            return;
378                        }
379                        Err(error) => {
380                            progress_updater
381                                .update_progress_and_events(
382                                    &mut progress_sender,
383                                    SectorPlottingProgress::Error {
384                                        error: format!("Failed to encode sector: {error}"),
385                                    },
386                                )
387                                .await;
388
389                            return;
390                        }
391                    }
392                };
393
394                progress_updater
395                    .update_progress_and_events(
396                        &mut progress_sender,
397                        SectorPlottingProgress::Finished {
398                            plotted_sector,
399                            time: start.elapsed(),
400                            sector: Box::pin({
401                                let mut sector = Some(Ok(Bytes::from(sector)));
402
403                                stream::poll_fn(move |_cx| {
404                                    // Just so that permit is dropped with stream itself
405                                    let _downloading_permit = &downloading_permit;
406
407                                    Poll::Ready(sector.take())
408                                })
409                            }),
410                        },
411                    )
412                    .await;
413            }
414        };
415
416        // Spawn a separate task such that `block_in_place` inside will not affect anything else
417        let plotting_task =
418            AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
419        if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
420            warn!(%error, "Failed to send plotting task");
421
422            let progress = SectorPlottingProgress::Error {
423                error: format!("Failed to send plotting task: {error}"),
424            };
425
426            self.handlers
427                .plotting_progress
428                .call_simple(&public_key, &sector_index, &progress);
429        }
430    }
431}
432
433struct ProgressUpdater {
434    public_key: Ed25519PublicKey,
435    sector_index: SectorIndex,
436    handlers: Arc<Handlers>,
437    metrics: Option<Arc<GpuPlotterMetrics>>,
438}
439
440impl ProgressUpdater {
441    /// Returns `true` on success and `false` if progress receiver channel is gone
442    async fn update_progress_and_events<PS>(
443        &self,
444        progress_sender: &mut PS,
445        progress: SectorPlottingProgress,
446    ) -> bool
447    where
448        PS: Sink<SectorPlottingProgress> + Unpin,
449        PS::Error: Error,
450    {
451        if let Some(metrics) = &self.metrics {
452            match &progress {
453                SectorPlottingProgress::Downloading => {
454                    metrics.sector_downloading.inc();
455                }
456                SectorPlottingProgress::Downloaded(time) => {
457                    metrics.sector_downloading_time.observe(time.as_secs_f64());
458                    metrics.sector_downloaded.inc();
459                }
460                SectorPlottingProgress::Encoding => {
461                    metrics.sector_encoding.inc();
462                }
463                SectorPlottingProgress::Encoded(time) => {
464                    metrics.sector_encoding_time.observe(time.as_secs_f64());
465                    metrics.sector_encoded.inc();
466                }
467                SectorPlottingProgress::Finished { time, .. } => {
468                    metrics.sector_plotting_time.observe(time.as_secs_f64());
469                    metrics.sector_plotted.inc();
470                }
471                SectorPlottingProgress::Error { .. } => {
472                    metrics.sector_plotting_error.inc();
473                }
474            }
475        }
476        self.handlers.plotting_progress.call_simple(
477            &self.public_key,
478            &self.sector_index,
479            &progress,
480        );
481
482        if let Err(error) = progress_sender.send(progress).await {
483            warn!(%error, "Failed to send progress update");
484
485            false
486        } else {
487            true
488        }
489    }
490}