1mod gpu_encoders_manager;
4pub mod metrics;
5
6use crate::plotter::gpu::gpu_encoders_manager::GpuRecordsEncoderManager;
7use crate::plotter::gpu::metrics::GpuPlotterMetrics;
8use crate::plotter::{Plotter, SectorPlottingProgress};
9use crate::utils::AsyncJoinOnDrop;
10use ab_core_primitives::ed25519::Ed25519PublicKey;
11use ab_core_primitives::sectors::SectorIndex;
12use ab_core_primitives::solutions::ShardCommitmentHash;
13use ab_data_retrieval::piece_getter::PieceGetter;
14use ab_erasure_coding::ErasureCoding;
15use ab_farmer_components::FarmerProtocolInfo;
16use ab_farmer_components::plotting::{
17 DownloadSectorOptions, EncodeSectorOptions, PlottingError, download_sector, encode_sector,
18 write_sector,
19};
20use ab_proof_of_space_gpu::GpuRecordsEncoder;
21use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
22use async_trait::async_trait;
23use bytes::Bytes;
24use event_listener_primitives::{Bag, HandlerId};
25use futures::channel::mpsc;
26use futures::stream::FuturesUnordered;
27use futures::{FutureExt, Sink, SinkExt, StreamExt, select, stream};
28use prometheus_client::registry::Registry;
29use std::error::Error;
30use std::fmt;
31use std::future::pending;
32use std::num::TryFromIntError;
33use std::pin::pin;
34use std::sync::Arc;
35use std::sync::atomic::{AtomicBool, Ordering};
36use std::task::Poll;
37use std::time::Instant;
38use tokio::task::yield_now;
39use tracing::{Instrument, warn};
40
/// Shared, thread-safe callback that receives three arguments by reference.
pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
/// [`Bag`] holding [`HandlerFn3`] callbacks for arguments of types `A`, `B` and `C`.
type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
44
/// Event handler collections shared between the plotter and its plotting tasks.
#[derive(Default, Debug)]
struct Handlers {
    /// Callbacks invoked with the public key, sector index and progress value on every plotting
    /// progress update.
    plotting_progress: Handler3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
}
49
/// Plotter that downloads sectors with a piece getter and encodes them with GPU records encoders.
///
/// Dropping the plotter sets the `abort_early` flag and closes the channel used to hand plotting
/// tasks to the background task.
pub struct GpuPlotter<PG> {
    /// Source of pieces; cloned into every plotting task and passed to `download_sector`.
    piece_getter: PG,
    /// Semaphore whose permit is acquired before a sector is plotted; the permit is kept alive
    /// until the stream carrying the finished sector is dropped.
    downloading_semaphore: Arc<Semaphore>,
    /// Hands out a GPU records encoder to each plotting task for the encoding stage.
    gpu_records_encoders_manager: GpuRecordsEncoderManager,
    /// Mutex that every plotting task locks (and immediately releases) before downloading.
    global_mutex: Arc<AsyncMutex<()>>,
    /// Erasure coding instance passed to `download_sector`.
    erasure_coding: ErasureCoding,
    /// Registered progress callbacks.
    handlers: Arc<Handlers>,
    /// Sending half of the channel through which spawned plotting tasks are handed to the
    /// background task.
    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
    /// Background task that receives plotting tasks and drives them to completion.
    _background_tasks: AsyncJoinOnDrop<()>,
    /// Flag set on drop and observed by the encoding stage to stop early.
    abort_early: Arc<AtomicBool>,
    /// Optional metrics; `None` when no registry was provided to the constructor.
    metrics: Option<Arc<GpuPlotterMetrics>>,
}
63
64impl<PG> fmt::Debug for GpuPlotter<PG> {
65 #[inline]
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 f.debug_struct("GpuPlotter").finish_non_exhaustive()
68 }
69}
70
71impl<PG> Drop for GpuPlotter<PG> {
72 #[inline]
73 fn drop(&mut self) {
74 self.abort_early.store(true, Ordering::Release);
75 self.tasks_sender.close_channel();
76 }
77}
78
79#[async_trait]
80impl<PG> Plotter for GpuPlotter<PG>
81where
82 PG: PieceGetter + Clone + Send + Sync + 'static,
83{
84 async fn has_free_capacity(&self) -> Result<bool, String> {
85 Ok(self.downloading_semaphore.try_acquire().is_some())
86 }
87
88 async fn plot_sector(
89 &self,
90 public_key: Ed25519PublicKey,
91 shard_commitments_root: ShardCommitmentHash,
92 sector_index: SectorIndex,
93 farmer_protocol_info: FarmerProtocolInfo,
94 pieces_in_sector: u16,
95 _replotting: bool,
96 progress_sender: mpsc::Sender<SectorPlottingProgress>,
97 ) {
98 let start = Instant::now();
99
100 let downloading_permit = self.downloading_semaphore.acquire_arc().await;
103
104 self.plot_sector_internal(
105 start,
106 downloading_permit,
107 public_key,
108 shard_commitments_root,
109 sector_index,
110 farmer_protocol_info,
111 pieces_in_sector,
112 progress_sender,
113 )
114 .await
115 }
116
117 async fn try_plot_sector(
118 &self,
119 public_key: Ed25519PublicKey,
120 shard_commitments_root: ShardCommitmentHash,
121 sector_index: SectorIndex,
122 farmer_protocol_info: FarmerProtocolInfo,
123 pieces_in_sector: u16,
124 _replotting: bool,
125 progress_sender: mpsc::Sender<SectorPlottingProgress>,
126 ) -> bool {
127 let start = Instant::now();
128
129 let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
130 return false;
131 };
132
133 self.plot_sector_internal(
134 start,
135 downloading_permit,
136 public_key,
137 shard_commitments_root,
138 sector_index,
139 farmer_protocol_info,
140 pieces_in_sector,
141 progress_sender,
142 )
143 .await;
144
145 true
146 }
147}
148
149impl<PG> GpuPlotter<PG>
150where
151 PG: PieceGetter + Clone + Send + Sync + 'static,
152{
153 pub fn new(
157 piece_getter: PG,
158 downloading_semaphore: Arc<Semaphore>,
159 gpu_records_encoders: Vec<GpuRecordsEncoder>,
160 global_mutex: Arc<AsyncMutex<()>>,
161 erasure_coding: ErasureCoding,
162 registry: Option<&mut Registry>,
163 ) -> Result<Self, TryFromIntError> {
164 let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
165
166 let background_tasks = AsyncJoinOnDrop::new(
168 tokio::spawn(async move {
169 let background_tasks = FuturesUnordered::new();
170 let mut background_tasks = pin!(background_tasks);
171 background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
173
174 loop {
175 select! {
176 maybe_background_task = tasks_receiver.next().fuse() => {
177 let Some(background_task) = maybe_background_task else {
178 break;
179 };
180
181 background_tasks.push(background_task);
182 },
183 _ = background_tasks.select_next_some() => {
184 }
186 }
187 }
188 }),
189 true,
190 );
191
192 let abort_early = Arc::new(AtomicBool::new(false));
193 let gpu_records_encoders_manager = GpuRecordsEncoderManager::new(gpu_records_encoders)?;
194 let metrics = registry.map(|registry| {
195 Arc::new(GpuPlotterMetrics::new(
196 registry,
197 gpu_records_encoders_manager.gpu_records_encoders(),
198 ))
199 });
200
201 Ok(Self {
202 piece_getter,
203 downloading_semaphore,
204 gpu_records_encoders_manager,
205 global_mutex,
206 erasure_coding,
207 handlers: Arc::default(),
208 tasks_sender,
209 _background_tasks: background_tasks,
210 abort_early,
211 metrics,
212 })
213 }
214
215 pub fn on_plotting_progress(
217 &self,
218 callback: HandlerFn3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
219 ) -> HandlerId {
220 self.handlers.plotting_progress.add(callback)
221 }
222
223 #[expect(clippy::too_many_arguments)]
224 async fn plot_sector_internal<PS>(
225 &self,
226 start: Instant,
227 downloading_permit: SemaphoreGuardArc,
228 public_key: Ed25519PublicKey,
229 shard_commitments_root: ShardCommitmentHash,
230 sector_index: SectorIndex,
231 farmer_protocol_info: FarmerProtocolInfo,
232 pieces_in_sector: u16,
233 mut progress_sender: PS,
234 ) where
235 PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
236 PS::Error: Error,
237 {
238 if let Some(metrics) = &self.metrics {
239 metrics.sector_plotting.inc();
240 }
241
242 let progress_updater = ProgressUpdater {
243 public_key,
244 sector_index,
245 handlers: Arc::clone(&self.handlers),
246 metrics: self.metrics.clone(),
247 };
248
249 let plotting_fut = {
250 let piece_getter = self.piece_getter.clone();
251 let gpu_records_encoders_manager = self.gpu_records_encoders_manager.clone();
252 let global_mutex = Arc::clone(&self.global_mutex);
253 let erasure_coding = self.erasure_coding.clone();
254 let abort_early = Arc::clone(&self.abort_early);
255 let metrics = self.metrics.clone();
256
257 async move {
258 let downloaded_sector = {
260 if !progress_updater
261 .update_progress_and_events(
262 &mut progress_sender,
263 SectorPlottingProgress::Downloading,
264 )
265 .await
266 {
267 return;
268 }
269
270 global_mutex.lock().await;
272
273 let downloading_start = Instant::now();
274 let public_key_hash = &public_key.hash();
275
276 let downloaded_sector_fut = download_sector(DownloadSectorOptions {
277 public_key_hash,
278 shard_commitments_root: &shard_commitments_root,
279 sector_index,
280 piece_getter: &piece_getter,
281 farmer_protocol_info,
282 erasure_coding: &erasure_coding,
283 pieces_in_sector,
284 });
285
286 let downloaded_sector = match downloaded_sector_fut.await {
287 Ok(downloaded_sector) => downloaded_sector,
288 Err(error) => {
289 warn!(%error, "Failed to download sector");
290
291 progress_updater
292 .update_progress_and_events(
293 &mut progress_sender,
294 SectorPlottingProgress::Error {
295 error: format!("Failed to download sector: {error}"),
296 },
297 )
298 .await;
299
300 return;
301 }
302 };
303
304 if !progress_updater
305 .update_progress_and_events(
306 &mut progress_sender,
307 SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
308 )
309 .await
310 {
311 return;
312 }
313
314 downloaded_sector
315 };
316
317 let (sector, plotted_sector) = {
319 let mut records_encoder = gpu_records_encoders_manager.get_encoder().await;
320 if let Some(metrics) = &metrics {
321 metrics.plotting_capacity_used.inc();
322 }
323
324 yield_now().await;
326
327 if !progress_updater
328 .update_progress_and_events(
329 &mut progress_sender,
330 SectorPlottingProgress::Encoding,
331 )
332 .await
333 {
334 if let Some(metrics) = &metrics {
335 metrics.plotting_capacity_used.dec();
336 }
337 return;
338 }
339
340 let encoding_start = Instant::now();
341
342 let plotting_result = tokio::task::block_in_place(move || {
343 let encoded_sector = encode_sector(
344 downloaded_sector,
345 EncodeSectorOptions {
346 sector_index,
347 records_encoder: &mut *records_encoder,
348 abort_early: &abort_early,
349 },
350 )?;
351
352 if abort_early.load(Ordering::Acquire) {
353 return Err(PlottingError::AbortEarly);
354 }
355
356 drop(records_encoder);
357
358 let mut sector = Vec::new();
359
360 write_sector(&encoded_sector, &mut sector)?;
361
362 Ok((sector, encoded_sector.plotted_sector))
363 });
364
365 if let Some(metrics) = &metrics {
366 metrics.plotting_capacity_used.dec();
367 }
368
369 match plotting_result {
370 Ok(plotting_result) => {
371 if !progress_updater
372 .update_progress_and_events(
373 &mut progress_sender,
374 SectorPlottingProgress::Encoded(encoding_start.elapsed()),
375 )
376 .await
377 {
378 return;
379 }
380
381 plotting_result
382 }
383 Err(PlottingError::AbortEarly) => {
384 return;
385 }
386 Err(error) => {
387 progress_updater
388 .update_progress_and_events(
389 &mut progress_sender,
390 SectorPlottingProgress::Error {
391 error: format!("Failed to encode sector: {error}"),
392 },
393 )
394 .await;
395
396 return;
397 }
398 }
399 };
400
401 progress_updater
402 .update_progress_and_events(
403 &mut progress_sender,
404 SectorPlottingProgress::Finished {
405 plotted_sector,
406 time: start.elapsed(),
407 sector: Box::pin({
408 let mut sector = Some(Ok(Bytes::from(sector)));
409
410 stream::poll_fn(move |_cx| {
411 let _downloading_permit = &downloading_permit;
413
414 Poll::Ready(sector.take())
415 })
416 }),
417 },
418 )
419 .await;
420 }
421 };
422
423 let plotting_task =
425 AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
426 if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
427 warn!(%error, "Failed to send plotting task");
428
429 let progress = SectorPlottingProgress::Error {
430 error: format!("Failed to send plotting task: {error}"),
431 };
432
433 self.handlers
434 .plotting_progress
435 .call_simple(&public_key, §or_index, &progress);
436 }
437 }
438}
439
/// Per-sector helper that fans a progress update out to metrics, registered handlers and the
/// progress sink.
struct ProgressUpdater {
    /// Public key passed to the progress handlers.
    public_key: Ed25519PublicKey,
    /// Sector index passed to the progress handlers.
    sector_index: SectorIndex,
    /// Registered progress callbacks.
    handlers: Arc<Handlers>,
    /// Optional metrics updated on every progress change.
    metrics: Option<Arc<GpuPlotterMetrics>>,
}
446
447impl ProgressUpdater {
448 async fn update_progress_and_events<PS>(
450 &self,
451 progress_sender: &mut PS,
452 progress: SectorPlottingProgress,
453 ) -> bool
454 where
455 PS: Sink<SectorPlottingProgress> + Unpin,
456 PS::Error: Error,
457 {
458 if let Some(metrics) = &self.metrics {
459 match &progress {
460 SectorPlottingProgress::Downloading => {
461 metrics.sector_downloading.inc();
462 }
463 SectorPlottingProgress::Downloaded(time) => {
464 metrics.sector_downloading_time.observe(time.as_secs_f64());
465 metrics.sector_downloaded.inc();
466 }
467 SectorPlottingProgress::Encoding => {
468 metrics.sector_encoding.inc();
469 }
470 SectorPlottingProgress::Encoded(time) => {
471 metrics.sector_encoding_time.observe(time.as_secs_f64());
472 metrics.sector_encoded.inc();
473 }
474 SectorPlottingProgress::Finished { time, .. } => {
475 metrics.sector_plotting_time.observe(time.as_secs_f64());
476 metrics.sector_plotted.inc();
477 }
478 SectorPlottingProgress::Error { .. } => {
479 metrics.sector_plotting_error.inc();
480 }
481 }
482 }
483 self.handlers.plotting_progress.call_simple(
484 &self.public_key,
485 &self.sector_index,
486 &progress,
487 );
488
489 if let Err(error) = progress_sender.send(progress).await {
490 warn!(%error, "Failed to send progress update");
491
492 false
493 } else {
494 true
495 }
496 }
497}