Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,17 +342,7 @@ where
params: JsonValue,
options: SpawnOptions,
) -> DurableResult<SpawnResult> {
// Validate that the task is registered
{
let registry = self.registry.read().await;
if !registry.contains_key(task_name) {
return Err(DurableError::TaskNotRegistered {
task_name: task_name.to_string(),
});
}
}

self.spawn_by_name_internal(&self.pool, task_name, params, options)
self.spawn_by_name_with(&self.pool, task_name, params, options)
.await
}

Expand Down Expand Up @@ -432,11 +422,16 @@ where
// Validate that the task is registered
{
let registry = self.registry.read().await;
if !registry.contains_key(task_name) {
let Some(task) = registry.get(task_name) else {
return Err(DurableError::TaskNotRegistered {
task_name: task_name.to_string(),
});
}
};
task.validate_params(params.clone())
.map_err(|e| DurableError::InvalidTaskParams {
task_name: task_name.to_string(),
message: e.to_string(),
})?;
}

self.spawn_by_name_internal(executor, task_name, params, options)
Expand Down
12 changes: 12 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,18 @@ pub enum DurableError {
task_name: String,
},

//// Task params validation failed.
///
/// Returned when the task definition in the registry fails to validate the params
/// (before we attempt to spawn the task in Postgres).
#[error("invalid task parameters for '{task_name}': {message}")]
InvalidTaskParams {
/// The name of the task being spawned
task_name: String,
/// The error message from the task.
message: String,
},

/// Header key uses a reserved prefix.
///
/// User-provided headers cannot start with "durable::" as this prefix
Expand Down
8 changes: 8 additions & 0 deletions src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ where
State: Clone + Send + Sync + 'static,
{
fn name(&self) -> Cow<'static, str>;
/// Called before spawning, to check that the `params` are valid for this task.
fn validate_params(&self, params: JsonValue) -> Result<(), TaskError>;
async fn execute(
&self,
params: JsonValue,
Expand All @@ -127,6 +129,12 @@ where
T::name()
}

fn validate_params(&self, params: JsonValue) -> Result<(), TaskError> {
// For now, just deserialize
let _typed_params: T::Params = serde_json::from_value(params)?;
Ok(())
}

async fn execute(
&self,
params: JsonValue,
Expand Down
39 changes: 34 additions & 5 deletions tests/spawn_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
mod common;

use common::tasks::{EchoParams, EchoTask, FailingParams, FailingTask};
use durable::{CancellationPolicy, Durable, MIGRATOR, RetryStrategy, SpawnOptions};
use durable::{CancellationPolicy, Durable, DurableError, MIGRATOR, RetryStrategy, SpawnOptions};
use sqlx::PgPool;
use std::collections::HashMap;
use std::time::Duration;
Expand Down Expand Up @@ -270,6 +270,33 @@ async fn test_spawn_by_name(pool: PgPool) -> sqlx::Result<()> {
Ok(())
}

#[sqlx::test(migrator = "MIGRATOR")]
async fn test_spawn_by_name_invalid_params(pool: PgPool) -> sqlx::Result<()> {
let client = create_client(pool.clone(), "spawn_by_name").await;
client.create_queue(None).await.unwrap();
client.register::<EchoTask>().await.unwrap();

let params = serde_json::json!({
"message": 12345
});

let result = client
.spawn_by_name("echo", params, SpawnOptions::default())
.await
.expect_err("Spawning task by name with invalid params should fail");

let DurableError::InvalidTaskParams { task_name, message } = result else {
panic!("Unexpected error: {}", result);
};
assert_eq!(task_name, "echo");
assert_eq!(
message,
"serialization error: invalid type: integer `12345`, expected a string"
);

Ok(())
}

#[sqlx::test(migrator = "MIGRATOR")]
async fn test_spawn_by_name_with_options(pool: PgPool) -> sqlx::Result<()> {
let client = create_client(pool.clone(), "spawn_by_name_opts").await;
Expand Down Expand Up @@ -308,9 +335,10 @@ async fn test_spawn_with_empty_params(pool: PgPool) -> sqlx::Result<()> {
client.create_queue(None).await.unwrap();
client.register::<EchoTask>().await.unwrap();

// Empty object is valid JSON params for EchoTask (message will be missing but that's ok for this test)
// Empty object is not valid JSON params for EchoTask,
// but spawn_by_name_unchecked does not validate the JSON
let result = client
.spawn_by_name("echo", serde_json::json!({}), SpawnOptions::default())
.spawn_by_name_unchecked("echo", serde_json::json!({}), SpawnOptions::default())
.await
.expect("Failed to spawn task with empty params");

Expand All @@ -326,7 +354,8 @@ async fn test_spawn_with_complex_params(pool: PgPool) -> sqlx::Result<()> {
client.register::<EchoTask>().await.unwrap();

// Complex nested JSON structure - the params don't need to match the task's Params type
// because spawn_by_name accepts arbitrary JSON
// because spawn_by_name_unchecked does not validate the JSON
// (unlike `spawn_by_name`)
let params = serde_json::json!({
"nested": {
"array": [1, 2, 3],
Expand All @@ -341,7 +370,7 @@ async fn test_spawn_with_complex_params(pool: PgPool) -> sqlx::Result<()> {
});

let result = client
.spawn_by_name("echo", params, SpawnOptions::default())
.spawn_by_name_unchecked("echo", params, SpawnOptions::default())
.await
.expect("Failed to spawn task with complex params");

Expand Down