diff --git a/src/grpc/server.rs b/src/grpc/server.rs index fb522839..0fdd9cb6 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -25,11 +25,8 @@ impl ConsumerService for TaskbrokerServer { request: Request, ) -> Result, Status> { let start_time = Instant::now(); - let namespace = &request.get_ref().namespace; - let inflight = self - .store - .get_pending_activation(namespace.as_deref()) - .await; + let namespace = request.get_ref().namespace.as_deref(); + let inflight = self.store.get_pending_activation(namespace).await; match inflight { Ok(Some(inflight)) => { @@ -123,7 +120,6 @@ impl ConsumerService for TaskbrokerServer { }; let start_time = Instant::now(); - let res = match self .store .get_pending_activation(namespace.as_deref()) diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index b1c6ad1a..44e6f62c 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -92,7 +92,9 @@ async fn test_set_task_status_success() { let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete - fetch_next_task: Some(FetchNextTask { namespace: None }), + fetch_next_task: Some(FetchNextTask { + namespace: Some("namespace".to_string()), + }), }; let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); diff --git a/src/store/inflight_activation.rs b/src/store/inflight_activation.rs index 0f2e9b3b..7180e9f6 100644 --- a/src/store/inflight_activation.rs +++ b/src/store/inflight_activation.rs @@ -326,11 +326,13 @@ impl InflightActivationStore { } #[instrument(skip_all)] - pub async fn get_pending_activation( + pub async fn get_pending_activations( &self, - namespace: Option<&str>, - ) -> Result, Error> { + namespaces: Option>, + limit: Option, + ) -> Result>, Error> { let now = Utc::now(); + let to_return = limit.unwrap_or(1); let mut query_builder = QueryBuilder::new( " @@ -344,7 +346,7 @@ impl InflightActivationStore { query_builder.push_bind(InflightActivationStatus::Processing); query_builder.push( " - WHERE id = ( + WHERE id IN ( SELECT id FROM inflight_taskactivations WHERE status = ", @@ -354,19 +356,54 @@ impl InflightActivationStore { query_builder.push_bind(now.timestamp()); query_builder.push(")"); - if let Some(namespace) = namespace { - query_builder.push(" AND namespace = "); - query_builder.push_bind(namespace); + let namespaces_vec: Vec; + if let Some(namespaces) = namespaces { + query_builder.push(" AND namespace IN ("); + let mut separated = query_builder.separated(", "); + namespaces_vec = namespaces.iter().map(|ns| ns.to_string()).collect(); + for namespace in namespaces_vec.iter() { + separated.push_bind(namespace); + } + separated.push_unseparated(")"); } - query_builder.push(" ORDER BY added_at LIMIT 1) RETURNING *"); + query_builder.push(" ORDER BY added_at LIMIT "); + query_builder.push_bind(to_return); + query_builder.push(") RETURNING *"); + + let rows: Vec = query_builder + .build_query_as() + .fetch_all(&self.write_pool) + .await? + .into_iter() + .map(|row: TableRow| row.into()) + .collect(); - let result: Option = query_builder - .build_query_as::() - .fetch_optional(&self.write_pool) - .await?; - let Some(row) = result else { return Ok(None) }; + if rows.is_empty() { + return Ok(None); + } - Ok(Some(row.into())) + Ok(Some(rows)) + } + + #[instrument(skip_all)] + pub async fn get_pending_activation( + &self, + namespace: Option<&str>, + ) -> Result, Error> { + if let Some(namespace) = namespace { + match self + .get_pending_activations(Some(vec![namespace]), Some(1)) + .await? + { + Some(rows) => Ok(Some(rows[0].clone())), + None => Ok(None), + } + } else { + match self.get_pending_activations(None, Some(1)).await? { + Some(rows) => Ok(Some(rows[0].clone())), + None => Ok(None), + } + } } #[instrument(skip_all)] diff --git a/src/store/inflight_activation_tests.rs b/src/store/inflight_activation_tests.rs index d6dbcc57..2d3d3d47 100644 --- a/src/store/inflight_activation_tests.rs +++ b/src/store/inflight_activation_tests.rs @@ -147,11 +147,7 @@ async fn test_get_pending_activation_with_race() { let store = store.clone(); join_set.spawn(async move { rx.recv().await.unwrap(); - store - .get_pending_activation(Some("namespace")) - .await - .unwrap() - .unwrap() + store.get_pending_activation(None).await.unwrap().unwrap() }); } @@ -218,6 +214,26 @@ async fn test_get_pending_activation_earliest() { ); } +#[tokio::test] +async fn test_get_pending_activations() { + let store = create_test_store().await; + + let mut batch = make_activations(5); + batch[1].namespace = "other_namespace".into(); + assert!(store.store(batch.clone()).await.is_ok()); + + let result = store + .get_pending_activations(Some(vec!["namespace", "other_namespace"]), Some(2)) + .await + .unwrap() + .unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result[0].activation.id, "id_0"); + assert_eq!(result[1].activation.id, "id_1"); + assert_count_by_status(&store, InflightActivationStatus::Pending, 3).await; +} + #[tokio::test] async fn test_count_pending_activations() { let store = create_test_store().await; @@ -266,7 +282,13 @@ async fn set_activation_status() { .is_ok() ); assert_eq!(store.count_pending_activations().await.unwrap(), 0); - assert!(store.get_pending_activation(None).await.unwrap().is_none()); + assert!( + store + .get_pending_activation(Some("namespace")) + .await + .unwrap() + .is_none() + ); } #[tokio::test] diff --git a/src/test_utils.rs b/src/test_utils.rs index 1a456804..7cfcf4c3 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -49,7 +49,7 @@ pub fn make_activations(count: u32) -> Vec { status: InflightActivationStatus::Pending, partition: 0, offset: i as i64, - added_at: Utc::now(), + added_at: Utc::now() + chrono::Duration::seconds(i as i64), processing_attempts: 0, expires_at: None, processing_deadline: None,