diff --git a/README.md b/README.md index 9d8a3e0..36861db 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,10 @@ pip install flask-ml Refer simple_server.py, more_server_examples.py, and simple_cli.py +``` +python3 -m simple_server +``` + #### Client Refer client_example.py @@ -33,3 +37,6 @@ To re-generate the model classes, run ``` make generate-models ``` + +See this PR for a full-walkthrough of how to contribute to Flask-ML. +https://github.com/UMass-Rescue/Flask-ML/pull/45 diff --git a/src/flask_ml/flask_ml_cli/MLCli.py b/src/flask_ml/flask_ml_cli/MLCli.py index 99b3f2c..aafb51d 100644 --- a/src/flask_ml/flask_ml_cli/MLCli.py +++ b/src/flask_ml/flask_ml_cli/MLCli.py @@ -14,6 +14,7 @@ BatchDirectoryInput, BatchFileInput, BatchTextInput, + BoolParameterDescriptor, DirectoryInput, FileInput, FloatParameterDescriptor, @@ -64,6 +65,8 @@ def get_parameter_argument_validator_func(parameter_schema: ParameterSchema): return get_int_range_check_func_arg_parser(parameter_schema.value.range) case IntParameterDescriptor(): return int + case BoolParameterDescriptor(): + return bool case _: # pragma: no cover assert_never(parameter_schema.value) diff --git a/src/flask_ml/flask_ml_server/models.py b/src/flask_ml/flask_ml_server/models.py index e2e4d27..2dac937 100644 --- a/src/flask_ml/flask_ml_server/models.py +++ b/src/flask_ml/flask_ml_server/models.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2025-01-15T12:55:32+00:00 +# timestamp: 2025-01-15T13:07:24+00:00 from __future__ import annotations @@ -110,6 +110,7 @@ class ParameterType(Enum): TEXT = 'text' RANGED_INT = 'ranged_int' INT = 'int' + BOOLEAN = 'boolean' class FloatParameterDescriptor(BaseModel): @@ -170,6 +171,14 @@ class FloatRangeDescriptor(BaseModel): max: float +class BoolParameterDescriptor(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + parameter_type: Annotated[Optional[ParameterType], Field(alias='parameterType')] = ParameterType.BOOLEAN + default: bool + + class FileType(Enum): IMG = 'img' CSV = 'csv' @@ -344,6 +353,7 @@ class ParameterSchema(BaseModel): TextParameterDescriptor, RangedIntParameterDescriptor, IntParameterDescriptor, + BoolParameterDescriptor, ] diff --git a/src/flask_ml/flask_ml_server/openapi.yaml b/src/flask_ml/flask_ml_server/openapi.yaml index 294e5bf..6c5b4f2 100644 --- a/src/flask_ml/flask_ml_server/openapi.yaml +++ b/src/flask_ml/flask_ml_server/openapi.yaml @@ -231,10 +231,11 @@ components: - $ref: '#/components/schemas/TextParameterDescriptor' - $ref: '#/components/schemas/RangedIntParameterDescriptor' - $ref: '#/components/schemas/IntParameterDescriptor' + - $ref: '#/components/schemas/BoolParameterDescriptor' ParameterType: type: string - enum: ["ranged_float", "float", "enum", "text", "ranged_int", "int"] + enum: ["ranged_float", "float", "enum", "text", "ranged_int", "int", "boolean"] RangedFloatParameterDescriptor: type: object @@ -348,6 +349,19 @@ components: type: number max: type: number + + BoolParameterDescriptor: + type: object + required: [parameterType, default] + discriminator: + propertyName: parameterType + properties: + parameterType: + type: string + $ref: "#/components/schemas/ParameterType" + default: "boolean" + default: + type: boolean # Request Models RequestBody: diff --git a/src/flask_ml/flask_ml_server/utils.py b/src/flask_ml/flask_ml_server/utils.py index c871979..de45017 100644 --- a/src/flask_ml/flask_ml_server/utils.py +++ b/src/flask_ml/flask_ml_server/utils.py @@ -10,6 +10,7 @@ BatchDirectoryInput, BatchFileInput, BatchTextInput, + BoolParameterDescriptor, DirectoryInput, EnumParameterDescriptor, FileInput, @@ -97,7 +98,6 @@ def schema_get_sample_payload(schema: TaskSchema) -> RequestBody: parameter_schema = schema.parameters inputs: Dict[str, Input] = {} - parameters = {} for input_schema in input_schema: input_type = input_schema.input_type match input_type: @@ -144,6 +144,8 @@ def schema_get_sample_payload(schema: TaskSchema) -> RequestBody: ) case _: # pragma: no cover assert_never(input_type) + + parameters = {} for parameter_schema in parameter_schema: match parameter_schema.value: case RangedFloatParameterDescriptor(): @@ -158,6 +160,8 @@ def schema_get_sample_payload(schema: TaskSchema) -> RequestBody: parameters[parameter_schema.key] = parameter_schema.value.range.min case IntParameterDescriptor(): parameters[parameter_schema.key] = parameter_schema.value.default + case BoolParameterDescriptor(): + parameters[parameter_schema.key] = parameter_schema.value.default case _: # pragma: no cover assert_never(parameter_schema.value) return RequestBody(inputs=inputs, parameters=parameters) @@ -303,6 +307,10 @@ def ensure_ml_func_hinting_and_task_schemas_are_valid( assert ( parameter_type_hint is int ), f"For key {key}, the parameter type is ParameterType.INT, but the TypeDict hint is {parameter_type_hint}. Change to int." + case ParameterType.BOOLEAN: + assert ( + parameter_type_hint is bool + ), f"For key {key}, the parameter type is ParameterType.BOOLEAN, but the TypeDict hint is {parameter_type_hint}. Change to bool." case _: # pragma: no cover assert_never(parameter_type) diff --git a/tests/conftest.py b/tests/conftest.py index 7c73254..ef89523 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -104,6 +104,9 @@ class FloatParameters(TypedDict): class EnumParameters(TypedDict): param1: str + + class BoolParameters(TypedDict): + param1: bool @server.route("/process_text") def server_process_text(inputs: SingleTextInput, parameters: TextParameters) -> ResponseBody: @@ -181,6 +184,15 @@ def server_process_directories_and_ranged_int_parameter_with_schema( ) -> ResponseBody: return ResponseBody(root=process_directories(inputs["dir_inputs"].directories, parameters)) + @server.route( + "/process_directories_and_boolean_parameter_with_schema", + get_task_schema(BATCHDIRECTORY_INPUT_SCHEMA, BOOL_PARAM_SCHEMA), + ) + def server_process_directories_and_boolean_parameter_with_schema( + inputs: DirectoryInputs, parameters: BoolParameters + ) -> ResponseBody: + return ResponseBody(root=process_directories(inputs["dir_inputs"].directories, parameters)) + @server.route( "/process_text_input_with_text_area_schema", get_task_schema(TEXTAREA_INPUT_SCHEMA, TEXT_PARAM_SCHEMA), diff --git a/tests/constants.py b/tests/constants.py index 2929b31..39bb646 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -61,3 +61,8 @@ range=IntRangeDescriptor(min=0, max=10), ), ) +BOOL_PARAM_SCHEMA = ParameterSchema( + key="param1", + label="Boolean Parameter", + value=BoolParameterDescriptor(parameter_type=ParameterType.BOOLEAN, default=False), +) diff --git a/tests/test_cli.py b/tests/test_cli.py index 61b1370..7e0a2c4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -42,6 +42,7 @@ def test_arg_parser_has_all_subcommands(ml_cli: MLCli): "process_newfile_with_schema", "process_directory_and_enum_parameter_with_schema", "process_directories_and_ranged_int_parameter_with_schema", + "process_directories_and_boolean_parameter_with_schema", "process_text_input_with_text_area_schema" } diff --git a/tests/test_ml_server_and_client.py b/tests/test_ml_server_and_client.py index a315f4f..11dd912 100644 --- a/tests/test_ml_server_and_client.py +++ b/tests/test_ml_server_and_client.py @@ -110,6 +110,14 @@ def test_list_routes(app): "short_title": "", "task_schema": "/process_directories_and_ranged_int_parameter_with_schema/task_schema", }, + { + "order": 0, + "payload_schema": "/process_directories_and_boolean_parameter_with_schema/payload_schema", + "run_task": "/process_directories_and_boolean_parameter_with_schema", + "sample_payload": "/process_directories_and_boolean_parameter_with_schema/sample_payload", + "short_title": "", + "task_schema": "/process_directories_and_boolean_parameter_with_schema/task_schema", + }, { "order": 0, "payload_schema": "/process_text_input_with_text_area_schema/payload_schema", diff --git a/tests/test_utils.py b/tests/test_utils.py index 876a7be..aead89b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -213,6 +213,7 @@ def test_schema_get_sample_payload_on_valid_input(input_schema, expected_inputs) (INT_PARAM_SCHEMA, {"param1": 1}), (RANGED_FLOAT_PARAM_SCHEMA, {"param1": 0.0}), (RANGED_INT_PARAM_SCHEMA, {"param1": 0}), + (BOOL_PARAM_SCHEMA, {"param1": False}), ], ) def test_schema_get_sample_payload_on_valid_parameters(parameter_schema, expected_parameters): diff --git a/website/materials/guides/getting-started.md b/website/materials/guides/getting-started.md index a7b5dac..8fe8877 100644 --- a/website/materials/guides/getting-started.md +++ b/website/materials/guides/getting-started.md @@ -300,6 +300,8 @@ Next, we will write a similar schema for our parameters by building an `Paramete - `RangedFloatParameterDescriptor` - `default`: default value - `range`: a `FloatRangeDescriptor` containing a `min` value and `max` value +- `BoolParameterDescriptor` + - `default`: default value (True or False) Let's write a parameter schema for our function's parameters: