1mod gpu_encoders_manager;
4pub mod metrics;
5
6use crate::plotter::gpu::gpu_encoders_manager::GpuRecordsEncoderManager;
7use crate::plotter::gpu::metrics::GpuPlotterMetrics;
8use crate::plotter::{Plotter, SectorPlottingProgress};
9use crate::utils::AsyncJoinOnDrop;
10use ab_core_primitives::ed25519::Ed25519PublicKey;
11use ab_core_primitives::sectors::SectorIndex;
12use ab_data_retrieval::piece_getter::PieceGetter;
13use ab_erasure_coding::ErasureCoding;
14use ab_farmer_components::FarmerProtocolInfo;
15use ab_farmer_components::plotting::{
16 DownloadSectorOptions, EncodeSectorOptions, PlottingError, download_sector, encode_sector,
17 write_sector,
18};
19use ab_proof_of_space_gpu::GpuRecordsEncoder;
20use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
21use async_trait::async_trait;
22use bytes::Bytes;
23use event_listener_primitives::{Bag, HandlerId};
24use futures::channel::mpsc;
25use futures::stream::FuturesUnordered;
26use futures::{FutureExt, Sink, SinkExt, StreamExt, select, stream};
27use prometheus_client::registry::Registry;
28use std::error::Error;
29use std::fmt;
30use std::future::pending;
31use std::num::TryFromIntError;
32use std::pin::pin;
33use std::sync::Arc;
34use std::sync::atomic::{AtomicBool, Ordering};
35use std::task::Poll;
36use std::time::Instant;
37use tokio::task::yield_now;
38use tracing::{Instrument, warn};
39
/// Shared, thread-safe callback that receives three arguments by reference.
pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
/// [`Bag`] of [`HandlerFn3`] callbacks for the same three argument types.
type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
43
/// Event handlers that can be registered on [`GpuPlotter`].
#[derive(Default, Debug)]
struct Handlers {
    /// Callbacks invoked with the public key, the sector index and the progress value for every
    /// plotting progress update
    plotting_progress: Handler3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
}
48
/// Sector plotter that encodes records using GPU records encoders.
pub struct GpuPlotter<PG> {
    /// Source of pieces, handed to `download_sector()`
    piece_getter: PG,
    /// A permit from this semaphore is taken before a sector is plotted and is kept until the
    /// resulting sector stream is dropped
    downloading_semaphore: Arc<Semaphore>,
    /// Hands out a GPU records encoder for the encoding stage
    gpu_records_encoders_manager: GpuRecordsEncoderManager,
    /// Acquired and immediately released before every sector download
    // NOTE(review): presumably allows another component to pause plotting by holding this mutex
    // — confirm against the code that shares it
    global_mutex: Arc<AsyncMutex<()>>,
    /// Erasure coding instance, handed to `download_sector()`
    erasure_coding: ErasureCoding,
    /// Registered progress callbacks
    handlers: Arc<Handlers>,
    /// Sending side of the channel through which spawned plotting tasks are handed to the
    /// background task
    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
    /// Handle of the background task that owns and drives the plotting tasks; never read, only
    /// kept alive for as long as the plotter exists
    _background_tasks: AsyncJoinOnDrop<()>,
    /// Set to `true` when the plotter is dropped; passed to sector encoding and checked right
    /// after it
    abort_early: Arc<AtomicBool>,
    /// Metrics, present only when a registry was provided to [`GpuPlotter::new()`]
    metrics: Option<Arc<GpuPlotterMetrics>>,
}
62
63impl<PG> fmt::Debug for GpuPlotter<PG> {
64 #[inline]
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("GpuPlotter").finish_non_exhaustive()
67 }
68}
69
impl<PG> Drop for GpuPlotter<PG> {
    /// Signals in-flight plotting to stop and closes the channel feeding the background task.
    #[inline]
    fn drop(&mut self) {
        // Raise the flag that is handed to sector encoding and checked right after it
        self.abort_early.store(true, Ordering::Release);
        // Closing the channel makes the background task's receiver yield `None`, which ends its
        // loop
        self.tasks_sender.close_channel();
    }
}
77
78#[async_trait]
79impl<PG> Plotter for GpuPlotter<PG>
80where
81 PG: PieceGetter + Clone + Send + Sync + 'static,
82{
83 async fn has_free_capacity(&self) -> Result<bool, String> {
84 Ok(self.downloading_semaphore.try_acquire().is_some())
85 }
86
87 async fn plot_sector(
88 &self,
89 public_key: Ed25519PublicKey,
90 sector_index: SectorIndex,
91 farmer_protocol_info: FarmerProtocolInfo,
92 pieces_in_sector: u16,
93 _replotting: bool,
94 progress_sender: mpsc::Sender<SectorPlottingProgress>,
95 ) {
96 let start = Instant::now();
97
98 let downloading_permit = self.downloading_semaphore.acquire_arc().await;
101
102 self.plot_sector_internal(
103 start,
104 downloading_permit,
105 public_key,
106 sector_index,
107 farmer_protocol_info,
108 pieces_in_sector,
109 progress_sender,
110 )
111 .await
112 }
113
114 async fn try_plot_sector(
115 &self,
116 public_key: Ed25519PublicKey,
117 sector_index: SectorIndex,
118 farmer_protocol_info: FarmerProtocolInfo,
119 pieces_in_sector: u16,
120 _replotting: bool,
121 progress_sender: mpsc::Sender<SectorPlottingProgress>,
122 ) -> bool {
123 let start = Instant::now();
124
125 let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
126 return false;
127 };
128
129 self.plot_sector_internal(
130 start,
131 downloading_permit,
132 public_key,
133 sector_index,
134 farmer_protocol_info,
135 pieces_in_sector,
136 progress_sender,
137 )
138 .await;
139
140 true
141 }
142}
143
144impl<PG> GpuPlotter<PG>
145where
146 PG: PieceGetter + Clone + Send + Sync + 'static,
147{
148 pub fn new(
152 piece_getter: PG,
153 downloading_semaphore: Arc<Semaphore>,
154 gpu_records_encoders: Vec<GpuRecordsEncoder>,
155 global_mutex: Arc<AsyncMutex<()>>,
156 erasure_coding: ErasureCoding,
157 registry: Option<&mut Registry>,
158 ) -> Result<Self, TryFromIntError> {
159 let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
160
161 let background_tasks = AsyncJoinOnDrop::new(
163 tokio::spawn(async move {
164 let background_tasks = FuturesUnordered::new();
165 let mut background_tasks = pin!(background_tasks);
166 background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
168
169 loop {
170 select! {
171 maybe_background_task = tasks_receiver.next().fuse() => {
172 let Some(background_task) = maybe_background_task else {
173 break;
174 };
175
176 background_tasks.push(background_task);
177 },
178 _ = background_tasks.select_next_some() => {
179 }
181 }
182 }
183 }),
184 true,
185 );
186
187 let abort_early = Arc::new(AtomicBool::new(false));
188 let gpu_records_encoders_manager = GpuRecordsEncoderManager::new(gpu_records_encoders)?;
189 let metrics = registry.map(|registry| {
190 Arc::new(GpuPlotterMetrics::new(
191 registry,
192 gpu_records_encoders_manager.gpu_records_encoders(),
193 ))
194 });
195
196 Ok(Self {
197 piece_getter,
198 downloading_semaphore,
199 gpu_records_encoders_manager,
200 global_mutex,
201 erasure_coding,
202 handlers: Arc::default(),
203 tasks_sender,
204 _background_tasks: background_tasks,
205 abort_early,
206 metrics,
207 })
208 }
209
210 pub fn on_plotting_progress(
212 &self,
213 callback: HandlerFn3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
214 ) -> HandlerId {
215 self.handlers.plotting_progress.add(callback)
216 }
217
218 #[allow(clippy::too_many_arguments)]
219 async fn plot_sector_internal<PS>(
220 &self,
221 start: Instant,
222 downloading_permit: SemaphoreGuardArc,
223 public_key: Ed25519PublicKey,
224 sector_index: SectorIndex,
225 farmer_protocol_info: FarmerProtocolInfo,
226 pieces_in_sector: u16,
227 mut progress_sender: PS,
228 ) where
229 PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
230 PS::Error: Error,
231 {
232 if let Some(metrics) = &self.metrics {
233 metrics.sector_plotting.inc();
234 }
235
236 let progress_updater = ProgressUpdater {
237 public_key,
238 sector_index,
239 handlers: Arc::clone(&self.handlers),
240 metrics: self.metrics.clone(),
241 };
242
243 let plotting_fut = {
244 let piece_getter = self.piece_getter.clone();
245 let gpu_records_encoders_manager = self.gpu_records_encoders_manager.clone();
246 let global_mutex = Arc::clone(&self.global_mutex);
247 let erasure_coding = self.erasure_coding.clone();
248 let abort_early = Arc::clone(&self.abort_early);
249 let metrics = self.metrics.clone();
250
251 async move {
252 let downloaded_sector = {
254 if !progress_updater
255 .update_progress_and_events(
256 &mut progress_sender,
257 SectorPlottingProgress::Downloading,
258 )
259 .await
260 {
261 return;
262 }
263
264 global_mutex.lock().await;
266
267 let downloading_start = Instant::now();
268 let public_key_hash = &public_key.hash();
269
270 let downloaded_sector_fut = download_sector(DownloadSectorOptions {
271 public_key_hash,
272 sector_index,
273 piece_getter: &piece_getter,
274 farmer_protocol_info,
275 erasure_coding: &erasure_coding,
276 pieces_in_sector,
277 });
278
279 let downloaded_sector = match downloaded_sector_fut.await {
280 Ok(downloaded_sector) => downloaded_sector,
281 Err(error) => {
282 warn!(%error, "Failed to download sector");
283
284 progress_updater
285 .update_progress_and_events(
286 &mut progress_sender,
287 SectorPlottingProgress::Error {
288 error: format!("Failed to download sector: {error}"),
289 },
290 )
291 .await;
292
293 return;
294 }
295 };
296
297 if !progress_updater
298 .update_progress_and_events(
299 &mut progress_sender,
300 SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
301 )
302 .await
303 {
304 return;
305 }
306
307 downloaded_sector
308 };
309
310 let (sector, plotted_sector) = {
312 let mut records_encoder = gpu_records_encoders_manager.get_encoder().await;
313 if let Some(metrics) = &metrics {
314 metrics.plotting_capacity_used.inc();
315 }
316
317 yield_now().await;
319
320 if !progress_updater
321 .update_progress_and_events(
322 &mut progress_sender,
323 SectorPlottingProgress::Encoding,
324 )
325 .await
326 {
327 if let Some(metrics) = &metrics {
328 metrics.plotting_capacity_used.dec();
329 }
330 return;
331 }
332
333 let encoding_start = Instant::now();
334
335 let plotting_result = tokio::task::block_in_place(move || {
336 let encoded_sector = encode_sector(
337 downloaded_sector,
338 EncodeSectorOptions {
339 sector_index,
340 records_encoder: &mut *records_encoder,
341 abort_early: &abort_early,
342 },
343 )?;
344
345 if abort_early.load(Ordering::Acquire) {
346 return Err(PlottingError::AbortEarly);
347 }
348
349 drop(records_encoder);
350
351 let mut sector = Vec::new();
352
353 write_sector(&encoded_sector, &mut sector)?;
354
355 Ok((sector, encoded_sector.plotted_sector))
356 });
357
358 if let Some(metrics) = &metrics {
359 metrics.plotting_capacity_used.dec();
360 }
361
362 match plotting_result {
363 Ok(plotting_result) => {
364 if !progress_updater
365 .update_progress_and_events(
366 &mut progress_sender,
367 SectorPlottingProgress::Encoded(encoding_start.elapsed()),
368 )
369 .await
370 {
371 return;
372 }
373
374 plotting_result
375 }
376 Err(PlottingError::AbortEarly) => {
377 return;
378 }
379 Err(error) => {
380 progress_updater
381 .update_progress_and_events(
382 &mut progress_sender,
383 SectorPlottingProgress::Error {
384 error: format!("Failed to encode sector: {error}"),
385 },
386 )
387 .await;
388
389 return;
390 }
391 }
392 };
393
394 progress_updater
395 .update_progress_and_events(
396 &mut progress_sender,
397 SectorPlottingProgress::Finished {
398 plotted_sector,
399 time: start.elapsed(),
400 sector: Box::pin({
401 let mut sector = Some(Ok(Bytes::from(sector)));
402
403 stream::poll_fn(move |_cx| {
404 let _downloading_permit = &downloading_permit;
406
407 Poll::Ready(sector.take())
408 })
409 }),
410 },
411 )
412 .await;
413 }
414 };
415
416 let plotting_task =
418 AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
419 if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
420 warn!(%error, "Failed to send plotting task");
421
422 let progress = SectorPlottingProgress::Error {
423 error: format!("Failed to send plotting task: {error}"),
424 };
425
426 self.handlers
427 .plotting_progress
428 .call_simple(&public_key, §or_index, &progress);
429 }
430 }
431}
432
/// Per-sector helper that records metrics, notifies handlers and forwards progress updates.
struct ProgressUpdater {
    /// Public key passed to progress handlers
    public_key: Ed25519PublicKey,
    /// Sector index passed to progress handlers
    sector_index: SectorIndex,
    /// Handlers to notify about every progress update
    handlers: Arc<Handlers>,
    /// Metrics to update on every progress update, if any
    metrics: Option<Arc<GpuPlotterMetrics>>,
}
439
440impl ProgressUpdater {
441 async fn update_progress_and_events<PS>(
443 &self,
444 progress_sender: &mut PS,
445 progress: SectorPlottingProgress,
446 ) -> bool
447 where
448 PS: Sink<SectorPlottingProgress> + Unpin,
449 PS::Error: Error,
450 {
451 if let Some(metrics) = &self.metrics {
452 match &progress {
453 SectorPlottingProgress::Downloading => {
454 metrics.sector_downloading.inc();
455 }
456 SectorPlottingProgress::Downloaded(time) => {
457 metrics.sector_downloading_time.observe(time.as_secs_f64());
458 metrics.sector_downloaded.inc();
459 }
460 SectorPlottingProgress::Encoding => {
461 metrics.sector_encoding.inc();
462 }
463 SectorPlottingProgress::Encoded(time) => {
464 metrics.sector_encoding_time.observe(time.as_secs_f64());
465 metrics.sector_encoded.inc();
466 }
467 SectorPlottingProgress::Finished { time, .. } => {
468 metrics.sector_plotting_time.observe(time.as_secs_f64());
469 metrics.sector_plotted.inc();
470 }
471 SectorPlottingProgress::Error { .. } => {
472 metrics.sector_plotting_error.inc();
473 }
474 }
475 }
476 self.handlers.plotting_progress.call_simple(
477 &self.public_key,
478 &self.sector_index,
479 &progress,
480 );
481
482 if let Err(error) = progress_sender.send(progress).await {
483 warn!(%error, "Failed to send progress update");
484
485 false
486 } else {
487 true
488 }
489 }
490}