Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ pip install flask-ml

Refer simple_server.py, more_server_examples.py, and simple_cli.py

```
python3 -m simple_server
```

#### Client

Refer client_example.py
Expand All @@ -33,3 +37,6 @@ To re-generate the model classes, run
```
make generate-models
```

See this PR for a full-walkthrough of how to contribute to Flask-ML.
https://github.com/UMass-Rescue/Flask-ML/pull/45
3 changes: 3 additions & 0 deletions src/flask_ml/flask_ml_cli/MLCli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
BatchDirectoryInput,
BatchFileInput,
BatchTextInput,
BoolParameterDescriptor,
DirectoryInput,
FileInput,
FloatParameterDescriptor,
Expand Down Expand Up @@ -64,6 +65,8 @@ def get_parameter_argument_validator_func(parameter_schema: ParameterSchema):
return get_int_range_check_func_arg_parser(parameter_schema.value.range)
case IntParameterDescriptor():
return int
case BoolParameterDescriptor():
return bool
case _: # pragma: no cover
assert_never(parameter_schema.value)

Expand Down
12 changes: 11 additions & 1 deletion src/flask_ml/flask_ml_server/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2025-01-15T12:55:32+00:00
# timestamp: 2025-01-15T13:07:24+00:00

from __future__ import annotations

Expand Down Expand Up @@ -110,6 +110,7 @@ class ParameterType(Enum):
TEXT = 'text'
RANGED_INT = 'ranged_int'
INT = 'int'
BOOLEAN = 'boolean'


class FloatParameterDescriptor(BaseModel):
Expand Down Expand Up @@ -170,6 +171,14 @@ class FloatRangeDescriptor(BaseModel):
max: float


class BoolParameterDescriptor(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
parameter_type: Annotated[Optional[ParameterType], Field(alias='parameterType')] = ParameterType.BOOLEAN
default: bool


class FileType(Enum):
IMG = 'img'
CSV = 'csv'
Expand Down Expand Up @@ -344,6 +353,7 @@ class ParameterSchema(BaseModel):
TextParameterDescriptor,
RangedIntParameterDescriptor,
IntParameterDescriptor,
BoolParameterDescriptor,
]


Expand Down
16 changes: 15 additions & 1 deletion src/flask_ml/flask_ml_server/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,11 @@ components:
- $ref: '#/components/schemas/TextParameterDescriptor'
- $ref: '#/components/schemas/RangedIntParameterDescriptor'
- $ref: '#/components/schemas/IntParameterDescriptor'
- $ref: '#/components/schemas/BoolParameterDescriptor'

ParameterType:
type: string
enum: ["ranged_float", "float", "enum", "text", "ranged_int", "int"]
enum: ["ranged_float", "float", "enum", "text", "ranged_int", "int", "boolean"]

RangedFloatParameterDescriptor:
type: object
Expand Down Expand Up @@ -348,6 +349,19 @@ components:
type: number
max:
type: number

BoolParameterDescriptor:
type: object
required: [parameterType, default]
discriminator:
propertyName: parameterType
properties:
parameterType:
type: string
$ref: "#/components/schemas/ParameterType"
default: "boolean"
default:
type: boolean

# Request Models
RequestBody:
Expand Down
10 changes: 9 additions & 1 deletion src/flask_ml/flask_ml_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BatchDirectoryInput,
BatchFileInput,
BatchTextInput,
BoolParameterDescriptor,
DirectoryInput,
EnumParameterDescriptor,
FileInput,
Expand Down Expand Up @@ -97,7 +98,6 @@ def schema_get_sample_payload(schema: TaskSchema) -> RequestBody:
parameter_schema = schema.parameters

inputs: Dict[str, Input] = {}
parameters = {}
for input_schema in input_schema:
input_type = input_schema.input_type
match input_type:
Expand Down Expand Up @@ -144,6 +144,8 @@ def schema_get_sample_payload(schema: TaskSchema) -> RequestBody:
)
case _: # pragma: no cover
assert_never(input_type)

parameters = {}
for parameter_schema in parameter_schema:
match parameter_schema.value:
case RangedFloatParameterDescriptor():
Expand All @@ -158,6 +160,8 @@ def schema_get_sample_payload(schema: TaskSchema) -> RequestBody:
parameters[parameter_schema.key] = parameter_schema.value.range.min
case IntParameterDescriptor():
parameters[parameter_schema.key] = parameter_schema.value.default
case BoolParameterDescriptor():
parameters[parameter_schema.key] = parameter_schema.value.default
case _: # pragma: no cover
assert_never(parameter_schema.value)
return RequestBody(inputs=inputs, parameters=parameters)
Expand Down Expand Up @@ -303,6 +307,10 @@ def ensure_ml_func_hinting_and_task_schemas_are_valid(
assert (
parameter_type_hint is int
), f"For key {key}, the parameter type is ParameterType.INT, but the TypeDict hint is {parameter_type_hint}. Change to int."
case ParameterType.BOOLEAN:
assert (
parameter_type_hint is bool
), f"For key {key}, the parameter type is ParameterType.BOOLEAN, but the TypeDict hint is {parameter_type_hint}. Change to bool."
case _: # pragma: no cover
assert_never(parameter_type)

Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ class FloatParameters(TypedDict):

class EnumParameters(TypedDict):
param1: str

class BoolParameters(TypedDict):
param1: bool

@server.route("/process_text")
def server_process_text(inputs: SingleTextInput, parameters: TextParameters) -> ResponseBody:
Expand Down Expand Up @@ -181,6 +184,15 @@ def server_process_directories_and_ranged_int_parameter_with_schema(
) -> ResponseBody:
return ResponseBody(root=process_directories(inputs["dir_inputs"].directories, parameters))

@server.route(
"/process_directories_and_boolean_parameter_with_schema",
get_task_schema(BATCHDIRECTORY_INPUT_SCHEMA, BOOL_PARAM_SCHEMA),
)
def server_process_directories_and_boolean_parameter_with_schema(
inputs: DirectoryInputs, parameters: BoolParameters
) -> ResponseBody:
return ResponseBody(root=process_directories(inputs["dir_inputs"].directories, parameters))

@server.route(
"/process_text_input_with_text_area_schema",
get_task_schema(TEXTAREA_INPUT_SCHEMA, TEXT_PARAM_SCHEMA),
Expand Down
5 changes: 5 additions & 0 deletions tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,8 @@
range=IntRangeDescriptor(min=0, max=10),
),
)
BOOL_PARAM_SCHEMA = ParameterSchema(
key="param1",
label="Boolean Parameter",
value=BoolParameterDescriptor(parameter_type=ParameterType.BOOLEAN, default=False),
)
1 change: 1 addition & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_arg_parser_has_all_subcommands(ml_cli: MLCli):
"process_newfile_with_schema",
"process_directory_and_enum_parameter_with_schema",
"process_directories_and_ranged_int_parameter_with_schema",
"process_directories_and_boolean_parameter_with_schema",
"process_text_input_with_text_area_schema"
}

Expand Down
8 changes: 8 additions & 0 deletions tests/test_ml_server_and_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def test_list_routes(app):
"short_title": "",
"task_schema": "/process_directories_and_ranged_int_parameter_with_schema/task_schema",
},
{
"order": 0,
"payload_schema": "/process_directories_and_boolean_parameter_with_schema/payload_schema",
"run_task": "/process_directories_and_boolean_parameter_with_schema",
"sample_payload": "/process_directories_and_boolean_parameter_with_schema/sample_payload",
"short_title": "",
"task_schema": "/process_directories_and_boolean_parameter_with_schema/task_schema",
},
{
"order": 0,
"payload_schema": "/process_text_input_with_text_area_schema/payload_schema",
Expand Down
1 change: 1 addition & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def test_schema_get_sample_payload_on_valid_input(input_schema, expected_inputs)
(INT_PARAM_SCHEMA, {"param1": 1}),
(RANGED_FLOAT_PARAM_SCHEMA, {"param1": 0.0}),
(RANGED_INT_PARAM_SCHEMA, {"param1": 0}),
(BOOL_PARAM_SCHEMA, {"param1": False}),
],
)
def test_schema_get_sample_payload_on_valid_parameters(parameter_schema, expected_parameters):
Expand Down
2 changes: 2 additions & 0 deletions website/materials/guides/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ Next, we will write a similar schema for our parameters by building an `Paramete
- `RangedFloatParameterDescriptor`
- `default`: default value
- `range`: a `FloatRangeDescriptor` containing a `min` value and `max` value
- `BoolParameterDescriptor`
- `default`: default value (True or False)

Let's write a parameter schema for our function's parameters:

Expand Down
Loading