ab_networking/protocols/request_response/handlers/
generic_request_handler.rs1use crate::protocols::request_response::request_response_factory::{
5 IncomingRequest, OutgoingResponse, ProtocolConfig, RequestHandler,
6};
7use async_trait::async_trait;
8use futures::channel::mpsc;
9use futures::prelude::*;
10use libp2p::PeerId;
11use parity_scale_codec::{Decode, Encode};
12use std::fmt;
13use std::pin::Pin;
14use std::sync::Arc;
15use tracing::{debug, trace};
16
/// Capacity of the bounded `mpsc` channel that queues inbound requests for a handler.
const REQUESTS_BUFFER_SIZE: usize = 50;
19
20pub trait GenericRequest: Encode + Decode + Send + Sync + 'static {
22 const PROTOCOL_NAME: &'static str;
24 const LOG_TARGET: &'static str;
26 type Response: Encode + Decode + Send + Sync + 'static;
28}
29
30type RequestHandlerFn<Request> = Arc<
31 dyn (Fn(
32 PeerId,
33 Request,
34 )
35 -> Pin<Box<dyn Future<Output = Option<<Request as GenericRequest>::Response>> + Send>>)
36 + Send
37 + Sync
38 + 'static,
39>;
40
41pub struct GenericRequestHandler<Request>
43where
44 Request: GenericRequest,
45{
46 request_receiver: mpsc::Receiver<IncomingRequest>,
47 request_handler: RequestHandlerFn<Request>,
48 protocol_config: ProtocolConfig,
49}
50
51impl<Request> fmt::Debug for GenericRequestHandler<Request>
52where
53 Request: GenericRequest,
54{
55 #[inline]
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.debug_struct("GenericRequestHandler")
58 .field("protocol_name", &Request::PROTOCOL_NAME)
59 .finish_non_exhaustive()
60 }
61}
62
63impl<Request> GenericRequestHandler<Request>
64where
65 Request: GenericRequest,
66{
67 pub fn create<RH, Fut>(request_handler: RH) -> Box<dyn RequestHandler>
69 where
70 RH: (Fn(PeerId, Request) -> Fut) + Send + Sync + 'static,
71 Fut: Future<Output = Option<Request::Response>> + Send + 'static,
72 {
73 let (request_sender, request_receiver) = mpsc::channel(REQUESTS_BUFFER_SIZE);
74
75 let mut protocol_config = ProtocolConfig::new(Request::PROTOCOL_NAME);
76 protocol_config.inbound_queue = Some(request_sender);
77
78 Box::new(Self {
79 request_receiver,
80 request_handler: Arc::new(move |peer_id, request| {
81 Box::pin(request_handler(peer_id, request))
82 }),
83 protocol_config,
84 })
85 }
86
87 async fn handle_request(
89 &self,
90 peer: PeerId,
91 payload: Vec<u8>,
92 ) -> Result<Vec<u8>, RequestHandlerError> {
93 trace!(%peer, protocol=Request::LOG_TARGET, "Handling request...");
94 let request = Request::decode(&mut payload.as_slice())
95 .map_err(|_error| RequestHandlerError::InvalidRequestFormat)?;
96 let response = (self.request_handler)(peer, request).await;
97
98 Ok(response.ok_or(RequestHandlerError::NoResponse)?.encode())
99 }
100}
101
102#[async_trait]
103impl<Request> RequestHandler for GenericRequestHandler<Request>
104where
105 Request: GenericRequest,
106{
107 async fn run(&mut self) {
109 while let Some(request) = self.request_receiver.next().await {
110 let IncomingRequest {
111 peer,
112 payload,
113 pending_response,
114 } = request;
115
116 match self.handle_request(peer, payload).await {
117 Ok(response_data) => {
118 let response = OutgoingResponse {
119 result: Ok(response_data),
120 sent_feedback: None,
121 };
122
123 if pending_response.send(response).is_ok() {
124 trace!(target = Request::LOG_TARGET, %peer, "Handled request");
125 } else {
126 debug!(
127 target = Request::LOG_TARGET,
128 protocol = Request::PROTOCOL_NAME,
129 %peer,
130 "Failed to handle request: {}",
131 RequestHandlerError::SendResponse
132 );
133 }
134 }
135 Err(e) => {
136 debug!(
137 target = Request::LOG_TARGET,
138 protocol = Request::PROTOCOL_NAME,
139 %e,
140 "Failed to handle request.",
141 );
142
143 let response = OutgoingResponse {
144 result: Err(()),
145 sent_feedback: None,
146 };
147
148 if pending_response.send(response).is_err() {
149 debug!(
150 target = Request::LOG_TARGET,
151 protocol = Request::PROTOCOL_NAME,
152 %peer,
153 "Failed to handle request: {}", RequestHandlerError::SendResponse
154 );
155 }
156 }
157 }
158 }
159 }
160
161 fn protocol_config(&self) -> ProtocolConfig {
162 self.protocol_config.clone()
163 }
164
165 fn protocol_name(&self) -> &'static str {
166 Request::PROTOCOL_NAME
167 }
168
169 fn clone_box(&self) -> Box<dyn RequestHandler> {
170 let (request_sender, request_receiver) = mpsc::channel(REQUESTS_BUFFER_SIZE);
171
172 let mut protocol_config = ProtocolConfig::new(Request::PROTOCOL_NAME);
173 protocol_config.inbound_queue = Some(request_sender);
174
175 Box::new(Self {
176 request_receiver,
177 request_handler: Arc::clone(&self.request_handler),
178 protocol_config,
179 })
180 }
181}
182
183#[derive(Debug, thiserror::Error)]
184enum RequestHandlerError {
185 #[error("Failed to send response.")]
186 SendResponse,
187
188 #[error("Incorrect request format.")]
189 InvalidRequestFormat,
190
191 #[error("No response.")]
192 NoResponse,
193}