1pub mod metrics;
4
5use crate::plotter::cpu::metrics::CpuPlotterMetrics;
6use crate::plotter::{Plotter, SectorPlottingProgress};
7use crate::thread_pool_manager::PlottingThreadPoolManager;
8use crate::utils::AsyncJoinOnDrop;
9use ab_core_primitives::ed25519::Ed25519PublicKey;
10use ab_core_primitives::sectors::SectorIndex;
11use ab_data_retrieval::piece_getter::PieceGetter;
12use ab_erasure_coding::ErasureCoding;
13use ab_farmer_components::FarmerProtocolInfo;
14use ab_farmer_components::plotting::{
15 CpuRecordsEncoder, DownloadSectorOptions, EncodeSectorOptions, PlottingError, download_sector,
16 encode_sector, write_sector,
17};
18use ab_proof_of_space::Table;
19use async_lock::{Mutex as AsyncMutex, Semaphore, SemaphoreGuardArc};
20use async_trait::async_trait;
21use bytes::Bytes;
22use event_listener_primitives::{Bag, HandlerId};
23use futures::channel::mpsc;
24use futures::stream::FuturesUnordered;
25use futures::{FutureExt, Sink, SinkExt, StreamExt, select, stream};
26use prometheus_client::registry::Registry;
27use std::any::type_name;
28use std::error::Error;
29use std::fmt;
30use std::future::pending;
31use std::marker::PhantomData;
32use std::num::NonZeroUsize;
33use std::pin::pin;
34use std::sync::Arc;
35use std::sync::atomic::{AtomicBool, Ordering};
36use std::task::Poll;
37use std::time::Instant;
38use tokio::task::yield_now;
39use tracing::{Instrument, warn};
40
41pub type HandlerFn3<A, B, C> = Arc<dyn Fn(&A, &B, &C) + Send + Sync + 'static>;
43type Handler3<A, B, C> = Bag<HandlerFn3<A, B, C>, A, B, C>;
44
45#[derive(Default, Debug)]
46struct Handlers {
47 plotting_progress: Handler3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
48}
49
/// Plotter that downloads, encodes and writes sectors on the CPU, using thread pools provided by
/// a [`PlottingThreadPoolManager`].
///
/// NOTE: fields are dropped in declaration order (after `Drop::drop` runs), so keep the relative
/// order of `tasks_sender` and `_background_tasks` in mind when editing this struct.
pub struct CpuPlotter<PG, PosTable> {
    /// Source of pieces, passed to `download_sector` for every sector.
    piece_getter: PG,
    /// Bounds the number of sectors in flight; a permit is taken before downloading starts and is
    /// kept until the produced sector stream is dropped (or plotting of the sector ends early).
    downloading_semaphore: Arc<Semaphore>,
    /// Hands out the plotting/replotting thread pools used for sector encoding.
    plotting_thread_pool_manager: PlottingThreadPoolManager,
    /// Number of `PosTable` generators created for every sector encoding.
    record_encoding_concurrency: NonZeroUsize,
    /// Briefly locked before each download and also passed to the records encoder.
    global_mutex: Arc<AsyncMutex<()>>,
    /// Erasure coding instance shared by downloading and encoding.
    erasure_coding: ErasureCoding,
    /// Registered event handlers, shared with every in-flight plotting task.
    handlers: Arc<Handlers>,
    /// Sends spawned plotting tasks to the background task runner; closed on drop.
    tasks_sender: mpsc::Sender<AsyncJoinOnDrop<()>>,
    /// Background task runner that owns all spawned plotting tasks.
    _background_tasks: AsyncJoinOnDrop<()>,
    /// Set on drop so that in-flight encoding can stop early.
    abort_early: Arc<AtomicBool>,
    /// Plotter metrics, present only when a registry was given to `new`.
    metrics: Option<Arc<CpuPlotterMetrics>>,
    /// Ties the proof-of-space table type to the plotter without storing a value of it.
    _phantom: PhantomData<PosTable>,
}
65
66impl<PG, PosTable> fmt::Debug for CpuPlotter<PG, PosTable> {
67 #[inline]
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 f.debug_struct("CpuPlotter").finish_non_exhaustive()
70 }
71}
72
73impl<PG, PosTable> Drop for CpuPlotter<PG, PosTable> {
74 #[inline]
75 fn drop(&mut self) {
76 self.abort_early.store(true, Ordering::Release);
77 self.tasks_sender.close_channel();
78 }
79}
80
81#[async_trait]
82impl<PG, PosTable> Plotter for CpuPlotter<PG, PosTable>
83where
84 PG: PieceGetter + Clone + Send + Sync + 'static,
85 PosTable: Table,
86{
87 async fn has_free_capacity(&self) -> Result<bool, String> {
88 Ok(self.downloading_semaphore.try_acquire().is_some())
89 }
90
91 async fn plot_sector(
92 &self,
93 public_key: Ed25519PublicKey,
94 sector_index: SectorIndex,
95 farmer_protocol_info: FarmerProtocolInfo,
96 pieces_in_sector: u16,
97 replotting: bool,
98 progress_sender: mpsc::Sender<SectorPlottingProgress>,
99 ) {
100 let start = Instant::now();
101
102 let downloading_permit = self.downloading_semaphore.acquire_arc().await;
105
106 self.plot_sector_internal(
107 start,
108 downloading_permit,
109 public_key,
110 sector_index,
111 farmer_protocol_info,
112 pieces_in_sector,
113 replotting,
114 progress_sender,
115 )
116 .await
117 }
118
119 async fn try_plot_sector(
120 &self,
121 public_key: Ed25519PublicKey,
122 sector_index: SectorIndex,
123 farmer_protocol_info: FarmerProtocolInfo,
124 pieces_in_sector: u16,
125 replotting: bool,
126 progress_sender: mpsc::Sender<SectorPlottingProgress>,
127 ) -> bool {
128 let start = Instant::now();
129
130 let Some(downloading_permit) = self.downloading_semaphore.try_acquire_arc() else {
131 return false;
132 };
133
134 self.plot_sector_internal(
135 start,
136 downloading_permit,
137 public_key,
138 sector_index,
139 farmer_protocol_info,
140 pieces_in_sector,
141 replotting,
142 progress_sender,
143 )
144 .await;
145
146 true
147 }
148}
149
150impl<PG, PosTable> CpuPlotter<PG, PosTable>
151where
152 PG: PieceGetter + Clone + Send + Sync + 'static,
153 PosTable: Table,
154{
155 #[allow(clippy::too_many_arguments)]
157 pub fn new(
158 piece_getter: PG,
159 downloading_semaphore: Arc<Semaphore>,
160 plotting_thread_pool_manager: PlottingThreadPoolManager,
161 record_encoding_concurrency: NonZeroUsize,
162 global_mutex: Arc<AsyncMutex<()>>,
163 erasure_coding: ErasureCoding,
164 registry: Option<&mut Registry>,
165 ) -> Self {
166 let (tasks_sender, mut tasks_receiver) = mpsc::channel(1);
167
168 let background_tasks = AsyncJoinOnDrop::new(
170 tokio::spawn(async move {
171 let background_tasks = FuturesUnordered::new();
172 let mut background_tasks = pin!(background_tasks);
173 background_tasks.push(AsyncJoinOnDrop::new(tokio::spawn(pending::<()>()), true));
175
176 loop {
177 select! {
178 maybe_background_task = tasks_receiver.next().fuse() => {
179 let Some(background_task) = maybe_background_task else {
180 break;
181 };
182
183 background_tasks.push(background_task);
184 },
185 _ = background_tasks.select_next_some() => {
186 }
188 }
189 }
190 }),
191 true,
192 );
193
194 let abort_early = Arc::new(AtomicBool::new(false));
195 let metrics = registry.map(|registry| {
196 Arc::new(CpuPlotterMetrics::new(
197 registry,
198 type_name::<PosTable>(),
199 plotting_thread_pool_manager.thread_pool_pairs(),
200 ))
201 });
202
203 Self {
204 piece_getter,
205 downloading_semaphore,
206 plotting_thread_pool_manager,
207 record_encoding_concurrency,
208 global_mutex,
209 erasure_coding,
210 handlers: Arc::default(),
211 tasks_sender,
212 _background_tasks: background_tasks,
213 abort_early,
214 metrics,
215 _phantom: PhantomData,
216 }
217 }
218
219 pub fn on_plotting_progress(
221 &self,
222 callback: HandlerFn3<Ed25519PublicKey, SectorIndex, SectorPlottingProgress>,
223 ) -> HandlerId {
224 self.handlers.plotting_progress.add(callback)
225 }
226
227 #[allow(clippy::too_many_arguments)]
228 async fn plot_sector_internal<PS>(
229 &self,
230 start: Instant,
231 downloading_permit: SemaphoreGuardArc,
232 public_key: Ed25519PublicKey,
233 sector_index: SectorIndex,
234 farmer_protocol_info: FarmerProtocolInfo,
235 pieces_in_sector: u16,
236 replotting: bool,
237 mut progress_sender: PS,
238 ) where
239 PS: Sink<SectorPlottingProgress> + Unpin + Send + 'static,
240 PS::Error: Error,
241 {
242 if let Some(metrics) = &self.metrics {
243 metrics.sector_plotting.inc();
244 }
245
246 let progress_updater = ProgressUpdater {
247 public_key,
248 sector_index,
249 handlers: Arc::clone(&self.handlers),
250 metrics: self.metrics.clone(),
251 };
252
253 let plotting_fut = {
254 let piece_getter = self.piece_getter.clone();
255 let plotting_thread_pool_manager = self.plotting_thread_pool_manager.clone();
256 let record_encoding_concurrency = self.record_encoding_concurrency;
257 let global_mutex = Arc::clone(&self.global_mutex);
258 let erasure_coding = self.erasure_coding.clone();
259 let abort_early = Arc::clone(&self.abort_early);
260 let metrics = self.metrics.clone();
261
262 async move {
263 let downloaded_sector = {
265 if !progress_updater
266 .update_progress_and_events(
267 &mut progress_sender,
268 SectorPlottingProgress::Downloading,
269 )
270 .await
271 {
272 return;
273 }
274
275 global_mutex.lock().await;
277
278 let downloading_start = Instant::now();
279 let public_key_hash = &public_key.hash();
280
281 let downloaded_sector_fut = download_sector(DownloadSectorOptions {
282 public_key_hash,
283 sector_index,
284 piece_getter: &piece_getter,
285 farmer_protocol_info,
286 erasure_coding: &erasure_coding,
287 pieces_in_sector,
288 });
289
290 let downloaded_sector = match downloaded_sector_fut.await {
291 Ok(downloaded_sector) => downloaded_sector,
292 Err(error) => {
293 warn!(%error, "Failed to download sector");
294
295 progress_updater
296 .update_progress_and_events(
297 &mut progress_sender,
298 SectorPlottingProgress::Error {
299 error: format!("Failed to download sector: {error}"),
300 },
301 )
302 .await;
303
304 return;
305 }
306 };
307
308 if !progress_updater
309 .update_progress_and_events(
310 &mut progress_sender,
311 SectorPlottingProgress::Downloaded(downloading_start.elapsed()),
312 )
313 .await
314 {
315 return;
316 }
317
318 downloaded_sector
319 };
320
321 let (sector, plotted_sector) = {
323 let thread_pools = plotting_thread_pool_manager.get_thread_pools().await;
324 if let Some(metrics) = &metrics {
325 metrics.plotting_capacity_used.inc();
326 }
327
328 yield_now().await;
330
331 if !progress_updater
332 .update_progress_and_events(
333 &mut progress_sender,
334 SectorPlottingProgress::Encoding,
335 )
336 .await
337 {
338 if let Some(metrics) = &metrics {
339 metrics.plotting_capacity_used.dec();
340 }
341 return;
342 }
343
344 let encoding_start = Instant::now();
345
346 let plotting_result = tokio::task::block_in_place(move || {
347 let thread_pool = if replotting {
348 &thread_pools.replotting
349 } else {
350 &thread_pools.plotting
351 };
352
353 let encoded_sector = thread_pool.install(|| {
354 let generator = PosTable::generator();
356 let generators = (0..record_encoding_concurrency.get())
357 .map(|_| generator.clone())
358 .collect::<Vec<_>>();
359 let mut records_encoder = CpuRecordsEncoder::<PosTable>::new(
360 &generators,
361 &erasure_coding,
362 &global_mutex,
363 );
364
365 encode_sector(
366 downloaded_sector,
367 EncodeSectorOptions {
368 sector_index,
369 records_encoder: &mut records_encoder,
370 abort_early: &abort_early,
371 },
372 )
373 })?;
374
375 if abort_early.load(Ordering::Acquire) {
376 return Err(PlottingError::AbortEarly);
377 }
378
379 drop(thread_pools);
380
381 let mut sector = Vec::new();
382
383 write_sector(&encoded_sector, &mut sector)?;
384
385 Ok((sector, encoded_sector.plotted_sector))
386 });
387
388 if let Some(metrics) = &metrics {
389 metrics.plotting_capacity_used.dec();
390 }
391
392 match plotting_result {
393 Ok(plotting_result) => {
394 if !progress_updater
395 .update_progress_and_events(
396 &mut progress_sender,
397 SectorPlottingProgress::Encoded(encoding_start.elapsed()),
398 )
399 .await
400 {
401 return;
402 }
403
404 plotting_result
405 }
406 Err(PlottingError::AbortEarly) => {
407 return;
408 }
409 Err(error) => {
410 progress_updater
411 .update_progress_and_events(
412 &mut progress_sender,
413 SectorPlottingProgress::Error {
414 error: format!("Failed to encode sector: {error}"),
415 },
416 )
417 .await;
418
419 return;
420 }
421 }
422 };
423
424 progress_updater
425 .update_progress_and_events(
426 &mut progress_sender,
427 SectorPlottingProgress::Finished {
428 plotted_sector,
429 time: start.elapsed(),
430 sector: Box::pin({
431 let mut sector = Some(Ok(Bytes::from(sector)));
432
433 stream::poll_fn(move |_cx| {
434 let _downloading_permit = &downloading_permit;
436
437 Poll::Ready(sector.take())
438 })
439 }),
440 },
441 )
442 .await;
443 }
444 };
445
446 let plotting_task =
448 AsyncJoinOnDrop::new(tokio::spawn(plotting_fut.in_current_span()), true);
449 if let Err(error) = self.tasks_sender.clone().send(plotting_task).await {
450 warn!(%error, "Failed to send plotting task");
451
452 let progress = SectorPlottingProgress::Error {
453 error: format!("Failed to send plotting task: {error}"),
454 };
455
456 self.handlers
457 .plotting_progress
458 .call_simple(&public_key, §or_index, &progress);
459 }
460 }
461}
462
/// Reports plotting progress of a single sector to metrics, registered handlers and the caller's
/// progress sink.
struct ProgressUpdater {
    /// Public key passed to every handler invocation.
    public_key: Ed25519PublicKey,
    /// Index of the sector whose progress is being reported.
    sector_index: SectorIndex,
    /// Handlers shared with the plotter that created this updater.
    handlers: Arc<Handlers>,
    /// Plotter metrics, if enabled.
    metrics: Option<Arc<CpuPlotterMetrics>>,
}
469
470impl ProgressUpdater {
471 async fn update_progress_and_events<PS>(
473 &self,
474 progress_sender: &mut PS,
475 progress: SectorPlottingProgress,
476 ) -> bool
477 where
478 PS: Sink<SectorPlottingProgress> + Unpin,
479 PS::Error: Error,
480 {
481 if let Some(metrics) = &self.metrics {
482 match &progress {
483 SectorPlottingProgress::Downloading => {
484 metrics.sector_downloading.inc();
485 }
486 SectorPlottingProgress::Downloaded(time) => {
487 metrics.sector_downloading_time.observe(time.as_secs_f64());
488 metrics.sector_downloaded.inc();
489 }
490 SectorPlottingProgress::Encoding => {
491 metrics.sector_encoding.inc();
492 }
493 SectorPlottingProgress::Encoded(time) => {
494 metrics.sector_encoding_time.observe(time.as_secs_f64());
495 metrics.sector_encoded.inc();
496 }
497 SectorPlottingProgress::Finished { time, .. } => {
498 metrics.sector_plotting_time.observe(time.as_secs_f64());
499 metrics.sector_plotted.inc();
500 }
501 SectorPlottingProgress::Error { .. } => {
502 metrics.sector_plotting_error.inc();
503 }
504 }
505 }
506 self.handlers.plotting_progress.call_simple(
507 &self.public_key,
508 &self.sector_index,
509 &progress,
510 );
511
512 if let Err(error) = progress_sender.send(progress).await {
513 warn!(%error, "Failed to send progress update");
514
515 false
516 } else {
517 true
518 }
519 }
520}