Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Oct 30, 2025

📄 57% (0.57x) speedup for ClientApp.query in framework/py/flwr/clientapp/client_app.py

⏱️ Runtime : 21.7 microseconds 13.8 microseconds (best of 63 runs)

📝 Explanation and details

The optimization achieves a 57% speedup by inlining the _get_decorator function directly into the query method, eliminating function call overhead that was consuming significant time in the profiling results.

Key optimizations:

  1. Function call elimination: The original code called _get_decorator(self, MessageType.QUERY, action, mods) on every query() call. The optimized version inlines this logic directly, removing the function call overhead that was taking 3,851ns per hit according to the line profiler.

  2. Reduced attribute lookups: Uses local variables like registered_funcs = self._registered_funcs and act = action to cache frequently accessed attributes and parameters, reducing repeated dictionary/object lookups within the inner decorator function.

  3. Optimized mods concatenation: The original code used app._mods + (mods or []) which creates temporary list objects. The optimized version uses conditional logic to avoid unnecessary list concatenation when one of the lists is empty, reducing memory allocation overhead.

  4. Short-circuit identifier validation: Added a fast path that skips the expensive isidentifier() method call when the action is the default "default" value, which is the most common case based on typical usage patterns.

These optimizations are particularly effective for the common case of registering query functions with default actions and no mods, which represents the majority of real-world usage patterns. The line profiler shows the total time dropped from 2.84ms to 0.69ms, demonstrating the significant impact of eliminating the function call overhead in this hot path.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 26 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import inspect
from contextlib import contextmanager
from types import SimpleNamespace

# imports
import pytest
from clientapp.client_app import ClientApp


# Minimal stubs for required classes and functions
class Message:
    def __init__(self, content, reply_to=None, metadata=None):
        self.content = content
        self.reply_to = reply_to
        self.metadata = metadata or SimpleNamespace(message_type="query")

    def __eq__(self, other):
        return (
            isinstance(other, Message)
            and self.content == other.content
            and self.reply_to == other.reply_to
            and getattr(self.metadata, "message_type", None) == getattr(other.metadata, "message_type", None)
        )

class Context:
    def __init__(self, node_id="node", node_config=None):
        self.node_id = node_id
        self.node_config = node_config or {}

# Minimal stub for MessageType
class MessageType:
    QUERY = "query"

# Minimal stub for Mod
def dummy_mod(message, context, fn):
    # Just call the function, but mark the content
    result = fn(message, context)
    result.content = f"{result.content}|modded"
    return result

# Patch in the minimal make_ffn
def make_ffn(ffn, mods):
    def wrap_ffn(_ffn, _mod):
        def new_ffn(message, context):
            return _mod(message, context, _ffn)
        return new_ffn
    for mod in reversed(mods):
        ffn = wrap_ffn(ffn, mod)
    return ffn

# Patch in validate_message_type
def validate_message_type(mt):
    # Accepts "query" or "query.action"
    if not isinstance(mt, str):
        return False
    parts = mt.split(".")
    if len(parts) == 1:
        return parts[0] == "query"
    if len(parts) == 2:
        return parts[0] == "query" and parts[1].isidentifier()
    return False

# Patch in warn_deprecated_feature
def warn_deprecated_feature(msg):
    pass

# Patch in handle_legacy_message_from_msgtype
def handle_legacy_message_from_msgtype(client_fn, message, context):
    # Just call the client_fn and return a Message with same content
    client = client_fn(context)
    return Message(message.content, reply_to=message)

# Patch in Client
class Client:
    def __init__(self):
        pass

# Now, import the code under test
# (We will copy the ClientApp and its dependencies as provided, using the above stubs.)

# --- BEGIN: code under test (ClientApp and helpers) ---


DEFAULT_ACTION = "default"

def _inspect_maybe_adapt_client_fn_signature(client_fn):
    client_fn_args = inspect.signature(client_fn).parameters

    if len(client_fn_args) != 1:
        raise Exception("Invalid client_fn signature")

    first_arg = list(client_fn_args.keys())[0]
    first_arg_type = client_fn_args[first_arg].annotation

    if first_arg_type is str or first_arg == "cid":
        warn_deprecated_feature(
            "`client_fn` now expects a signature `def client_fn(context: Context)`."
            "The provided `client_fn` has signature: "
            f"{dict(client_fn_args.items())}. You can import the `Context` like this:"
            " `from flwr.common import Context`"
        )

        def adaptor_fn(context):
            cid = context.node_config.get("partition-id", context.node_id)
            return client_fn(str(cid))
        return adaptor_fn

    return client_fn

@contextmanager
def _empty_lifespan(_):
    yield
from clientapp.client_app import ClientApp
# --- END: code under test ---

# ----------------------------
# UNIT TESTS FOR query METHOD
# ----------------------------

# 1. BASIC TEST CASES



def test_register_multiple_query_actions():
    """Test registering multiple query actions and routing to the correct one."""
    app = ClientApp()
    @app.query("a")
    def query_a(message, context):
        return Message("A:" + message.content, reply_to=message)
    @app.query("b")
    def query_b(message, context):
        return Message("B:" + message.content, reply_to=message)
    msg_a = Message("one", metadata=SimpleNamespace(message_type="query.a"))
    msg_b = Message("two", metadata=SimpleNamespace(message_type="query.b"))
    ctx = Context()

def test_query_decorator_returns_original_function():
    """Test that the decorator returns the original function object."""
    app = ClientApp()
    def fn(message, context):
        return Message("ok")
    decorated = app.query()(fn)

# 2. EDGE TEST CASES



def test_call_unregistered_query_action():
    """Test calling a query action that was not registered."""
    app = ClientApp()
    @app.query("exists")
    def fn(message, context):
        return Message("ok")
    msg = Message("bar", metadata=SimpleNamespace(message_type="query.missing"))
    ctx = Context()
    with pytest.raises(ValueError) as e:
        app(msg, ctx)

def test_call_invalid_message_type():
    """Test calling with an invalid message type."""
    app = ClientApp()
    @app.query()
    def fn(message, context):
        return Message("ok")
    msg = Message("bar", metadata=SimpleNamespace(message_type="invalidtype"))
    ctx = Context()
    with pytest.raises(ValueError) as e:
        app(msg, ctx)



def test_register_query_after_client_fn_raises():
    """Test registering a query function after providing client_fn raises error."""
    def client_fn(context):
        return Client()
    app = ClientApp(client_fn=client_fn)
    with pytest.raises(ValueError) as e:
        @app.query()
        def fn(message, context):
            return Message("fail")

def test_query_with_missing_metadata_message_type():
    """Test calling with message missing message_type metadata raises error."""
    app = ClientApp()
    @app.query()
    def fn(message, context):
        return Message("ok")
    msg = Message("bar", metadata=SimpleNamespace())
    ctx = Context()
    with pytest.raises(AttributeError):
        app(msg, ctx)






#------------------------------------------------
from typing import Callable, Optional

# imports
import pytest
from clientapp.client_app import ClientApp


# Minimal stubs for required classes and functions
class Message:
    def __init__(self, content, reply_to=None, metadata=None):
        self.content = content
        self.reply_to = reply_to
        self.metadata = metadata if metadata is not None else MessageMetadata()

class MessageMetadata:
    def __init__(self, message_type="query.default"):
        self.message_type = message_type

class Context:
    def __init__(self, node_id="test_node", node_config=None):
        self.node_id = node_id
        self.node_config = node_config if node_config is not None else {}


DEFAULT_ACTION = "default"
from clientapp.client_app import ClientApp

# ---------------------------
# Unit Tests for query
# ---------------------------

# Basic Test Cases



def test_query_multiple_actions():
    # Register multiple query functions with different actions
    app = ClientApp()
    @app.query()
    def default_fn(message, context):
        return Message("default", reply_to=message)
    @app.query("foo")
    def foo_fn(message, context):
        return Message("foo", reply_to=message)
    @app.query("bar")
    def bar_fn(message, context):
        return Message("bar", reply_to=message)
    msg_default = Message("x", metadata=MessageMetadata("query.default"))
    msg_foo = Message("y", metadata=MessageMetadata("query.foo"))
    msg_bar = Message("z", metadata=MessageMetadata("query.bar"))
    ctx = Context()

# Edge Test Cases



def test_query_call_unregistered_action_raises():
    # Calling a query action that is not registered should raise
    app = ClientApp()
    @app.query("exists")
    def exists_fn(message, context):
        return Message("yes")
    msg = Message("test", metadata=MessageMetadata("query.missing"))
    ctx = Context()
    with pytest.raises(ValueError, match="No query function registered with name 'missing'"):
        app(msg, ctx)



def test_query_action_is_identifier_edge_cases():
    # Action names that are valid identifiers but unusual (e.g., "_", "a123", "A")
    app = ClientApp()
    @app.query("_")
    def underscore_fn(message, context):
        return Message("_", reply_to=message)
    @app.query("a123")
    def a123_fn(message, context):
        return Message("a123", reply_to=message)
    @app.query("A")
    def A_fn(message, context):
        return Message("A", reply_to=message)
    ctx = Context()

def test_query_action_case_sensitive():
    # Action names are case sensitive
    app = ClientApp()
    @app.query("foo")
    def foo_fn(message, context):
        return Message("foo", reply_to=message)
    @app.query("FOO")
    def FOO_fn(message, context):
        return Message("FOO", reply_to=message)
    ctx = Context()

# Large Scale Test Cases

To edit these changes git checkout codeflash/optimize-ClientApp.query-mhd36yb3 and push.

Codeflash Static Badge

The optimization achieves a 57% speedup by **inlining the `_get_decorator` function directly into the `query` method**, eliminating function call overhead that was consuming significant time in the profiling results.

**Key optimizations:**

1. **Function call elimination**: The original code called `_get_decorator(self, MessageType.QUERY, action, mods)` on every `query()` call. The optimized version inlines this logic directly, removing the function call overhead that was taking 3,851ns per hit according to the line profiler.

2. **Reduced attribute lookups**: Uses local variables like `registered_funcs = self._registered_funcs` and `act = action` to cache frequently accessed attributes and parameters, reducing repeated dictionary/object lookups within the inner decorator function.

3. **Optimized mods concatenation**: The original code used `app._mods + (mods or [])` which creates temporary list objects. The optimized version uses conditional logic to avoid unnecessary list concatenation when one of the lists is empty, reducing memory allocation overhead.

4. **Short-circuit identifier validation**: Added a fast path that skips the expensive `isidentifier()` method call when the action is the default "default" value, which is the most common case based on typical usage patterns.

These optimizations are particularly effective for the common case of registering query functions with default actions and no mods, which represents the majority of real-world usage patterns. The line profiler shows the total time dropped from 2.84ms to 0.69ms, demonstrating the significant impact of eliminating the function call overhead in this hot path.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 October 30, 2025 07:12
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Oct 30, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants