Skip to content

Conversation

@jpli02
Copy link
Contributor

@jpli02 jpli02 commented Jan 12, 2026

Async implementation

  1. Async Implementation (tp_worker.py)
    Improved generate_async() to use async methods consistently
    Made step_async() thread-safe by running scheduler operations in a single-threaded executor
    Added proper async control flow with await asyncio.sleep(0)

  2. New Test File (test_async_inference.py)
    Test suite for async inference with fast_dllm_v2 on gsm8k
    Configurable parameters and performance metrics
    Error handling and cleanup

  3. Fix some small bugs

Summary by CodeRabbit

Release Notes

  • New Features

    • Added asynchronous inference support to enable non-blocking generation requests across worker components
    • Included example script demonstrating async inference workflow and performance measurement
  • Bug Fixes

    • Fixed cache key handling for improved logits retrieval consistency
    • Added block index validation to prevent potential index errors
    • Improved ROCm fallback for broader hardware compatibility
  • Chores

    • Updated example configurations with revised model paths and parameters

✏️ Tip: You can customize this high-level summary in your review settings.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the Diffulex project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link

coderabbitai bot commented Jan 12, 2026

📝 Walkthrough

Walkthrough

The PR introduces asynchronous execution capabilities to multiple engine components (dp_worker, tp_worker, model_runner) using ThreadPoolExecutor and asyncio, adds an async inference example script, fixes cache key handling from numeric to string-based, adds ROCm import fallback for compatibility, updates Triton imports, and modifies example scripts with new configurations.

Changes

Cohort / File(s) Summary
Async Engine Core
diffulex/engine/dp_worker.py, diffulex/engine/tp_worker.py, diffulex/engine/model_runner.py
Added async methods (_ask_async, add_request_async, step_async, is_finished_async, generate_async) with ThreadPoolExecutor lifecycle management. dp_worker parallelize across DP replicas; tp_worker wraps sync logic; model_runner provides single-threaded async delegation. Executor initialized lazily and cleaned up in exit().
Sampler & Cache Management
diffulex/sampler/base.py, diffulex/strategy/block_diffusion/engine/kvcache_manager.py
Changed _fetch_last_logits to use string-keyed cache (seq_id as str) with fallback to current batch's last logit when block uncached, adding error handling for empty logits. Added bounds validation (0 <= prev_block_idx) in may_append to prevent negative indices.
Kernel & Compatibility Updates
diffulex_kernel/python/dllm_flash_attn.py, diffulex_legacy/layers/attention/ops/triton_decode_attn_clm.py, diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py
Capitalized loop iterator T.parallel to T.Parallel. Normalized Triton imports (explicit import triton; import triton.language as tl). Added guarded ROCm import with fallback function use_rocm_custom_paged_attention() for environments lacking ROCm support.
Example Scripts
examples/test_async_inference.py, examples/test_dream_dvllm_gsm8k.py, examples/test_fastdllmv2_diffulex_gsm8k.py, examples/test_llada_dvllm_human_eval.py, examples/test_sdar_dvllm.py
Added new async inference example orchestrating concurrent generation with performance metrics. Updated dream/fastdllmv2 examples with dataset sources and parameters. Changed model path capitalization (LLaDA) and added SDAR model import. Modified dataset loading to use full test split or local paths with trust_remote_code.
Shell Script
scripts/test_dvllm_dream_gsm8k.sh
Reduced CUDA_VISIBLE_DEVICES from 0-7 to 0-1.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Worker
    participant Executor as ThreadPoolExecutor
    participant Replica
    
    User->>Worker: generate_async(prompts, params)
    activate Worker
    
    loop For each prompt
        Worker->>Worker: add_request_async(prompt)
        Worker->>Executor: _ask_async(replica, "add_request")
        activate Executor
        Executor->>Replica: add_request (blocking)
        Executor-->>Worker: ✓
        deactivate Executor
    end
    
    loop Until all finished
        Worker->>Executor: step_async()
        activate Executor
        par Parallel replicas
            Executor->>Replica: step() [replica 1]
            Executor->>Replica: step() [replica 2]
            Executor->>Replica: step() [replica N]
        and
        end
        Executor-->>Worker: outputs, tokens, metrics
        deactivate Executor
        
        Worker->>Executor: is_finished_async()
        activate Executor
        Executor->>Replica: is_finished() [all]
        Executor-->>Worker: done status
        deactivate Executor
        
        Worker->>Worker: remap local IDs → global IDs
        Worker->>Worker: aggregate results
    end
    
    Worker-->>User: results (text, token_ids, n_diff_steps)
    deactivate Worker
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • zhijie-group/Diffulex#18 — Modifies the same engine files (dp_worker.py, tp_worker.py, model_runner.py) with overlapping class changes for logging and config propagation.

Poem

🐰 Behold! Async threads now multiply,
No more sequential waits—requests fly!
Replicas dance in parallel grace,
Futures resolved at ThreadPool's pace,
Throughput soars high, as inference takes flight!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'Async support for inference engine' directly and accurately reflects the main focus of the changeset, which adds asynchronous APIs across multiple worker and engine components.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/test_dream_dvllm_gsm8k.py (1)

19-20: Fix incorrect method name: csv.dictReader should be csv.DictReader.

The csv module uses DictReader (with capital D and R). This will raise an AttributeError at runtime.

🐛 Proposed fix
-        reader = csv.dictReader(f)
+        reader = csv.DictReader(f)
examples/test_llada_dvllm_human_eval.py (1)

18-19: Fix incorrect method name: csv.dictReader should be csv.DictReader.

Same issue as in test_dream_dvllm_gsm8k.py. This will raise an AttributeError at runtime.

🐛 Proposed fix
-        reader = csv.dictReader(f)
+        reader = csv.DictReader(f)
🤖 Fix all issues with AI agents
In
@diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py:
- Around line 19-24: The fallback currently catches a broad Exception and
defines use_rocm_custom_paged_attention() with no parameters, but the call site
invokes use_rocm_custom_paged_attention(query.dtype, head_size, block_size,
num_queries_per_kv, max_seq_len, sliding_window, kv_cache_dtype, alibi_slopes);
change the except to catch ImportError only and redefine the fallback function
use_rocm_custom_paged_attention(...same 8 parameters...) -> bool to accept the
same parameter list and return False (match the type hints/signature used at
call site) so CUDA-only environments won't raise a TypeError.

In @diffulex/engine/dp_worker.py:
- Around line 421-424: In generate_async, remove the unused local variables
pending and conn_to_idx (they are assigned from slices.keys() and self.conns but
never referenced); edit the function to delete the lines that create pending and
conn_to_idx so only the used variable collected and slices remain, and run
tests/lint to ensure no other references to pending or conn_to_idx exist.

In @diffulex/engine/model_runner.py:
- Around line 118-126: Replace asyncio.get_event_loop() with
asyncio.get_running_loop() in call_async, and serialize executor creation using
an asyncio.Lock: add an asyncio.Lock instance (e.g., self._executor_lock) during
the ModelRunnerBase __init__, then in call_async do a double-checked pattern: if
getattr(self, '_executor', None) is None: await self._executor_lock; check
_executor again and only then create and assign
ThreadPoolExecutor(max_workers=1) to self._executor; finally call
loop.run_in_executor(self._executor, self.call, method_name, *args). This fixes
the deprecated API usage and prevents concurrent creation of multiple executors.

In @diffulex/engine/tp_worker.py:
- Around line 84-101: The ThreadPoolExecutor created lazily in step_async()
(stored as self._step_executor) is never shut down causing leaked threads;
update the worker's exit() method to check for self._step_executor, call its
shutdown(wait=True) (or shutdown(wait=False) if non-blocking termination is
preferred), handle any exceptions, and set self._step_executor = None so
resources are released and repeated exits are safe; ensure exit() cleans up even
if executor was never created.

In @examples/test_async_inference.py:
- Around line 134-138: The print statement inside the exception handler uses an
f-string with no placeholders; update the except block (the "except Exception as
e:" handler) to remove the unnecessary f prefix from the first print (change
print(f"\n[Error during async inference]") to a normal string print) so it is
not an f-string without placeholders while keeping the other prints
(print(f"Error: {e}") and traceback.print_exc()) unchanged.
- Line 95: The variable `tokenizer` created by calling
`AutoTokenizer.from_pretrained(args.model, trust_remote_code=True,
use_fast=True)` in examples/test_async_inference.py is unused; either remove
that line entirely or use `tokenizer` for prompt validation/token counting
before sending requests (e.g., call tokenizer.encode or tokenizer(model_input)
to compute token length) and replace any hardcoded token assumptions with the
computed values; update imports if removed to avoid unused-import warnings.
🧹 Nitpick comments (7)
examples/test_sdar_dvllm.py (2)

97-109: Consider using portable default paths for example scripts.

The default paths reference machine-specific locations (/data1/ckpts/...) and a user-specific home directory (/home/ljp/...). This makes the example less portable for other users.

Consider using environment variables or relative paths:

Suggested improvement
     parser.add_argument(
         "--model",
         type=str,
-        default="/data1/ckpts/SDAR/SDAR-1.7B-Chat",
+        default=os.environ.get("SDAR_MODEL_PATH", "./SDAR-1.7B-Chat"),
         help="SDAR HF model directory (contains config.json + model.safetensors).",
     )
     parser.add_argument("--device", type=int, default=0)
     parser.add_argument(
         "--converted-dir",
         type=str,
-        default="/home/ljp/tmp/diffulex_sdar_converted",
+        default=os.environ.get("SDAR_CONVERTED_DIR", "/tmp/diffulex_sdar_converted"),
         help="Output directory for converted checkpoint keys (Diffulex-native).",
     )

139-139: Consider adding a comment explaining the side-effect import.

This import registers the SDAR model class with AutoModelForDiffusionLM, enabling from_config() to instantiate it when model_name="sdar". A brief comment would clarify the intent for future readers.

+    # Register SDAR model class with AutoModelForDiffusionLM
     import diffulex.model.sdar
examples/test_fastdllmv2_diffulex_gsm8k.py (2)

42-43: Hardcoded local paths reduce portability.

The model path and local_data_path are hardcoded to /data1/... which won't work for other developers or CI environments. Consider using environment variables or command-line arguments for configurability.

♻️ Suggested approach
+import argparse
+# ... in main
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default=os.environ.get("MODEL_PATH", "/data1/ckpts/Efficient-Large-Model/Fast_dLLM_v2_7B"))
+    parser.add_argument("--data-path", type=str, default=os.environ.get("DATA_PATH", "/data1/LargeData/gsm8k"))
+    args = parser.parse_args()
-    model = "/data1/ckpts/Efficient-Large-Model/Fast_dLLM_v2_7B"
-    local_data_path = "/data1/LargeData/gsm8k"
+    model = args.model
+    local_data_path = args.data_path

62-63: Clean up commented code and consider fallback logic.

The commented-out line for hosted dataset loading should either be removed or converted to a fallback mechanism. The current approach requires local data to exist.

♻️ Suggested fallback pattern
-    # dataset = load_dataset("gsm8k", "main", split="test")["question"][:10]
-    dataset = load_dataset(local_data_path, "main", split="test", trust_remote_code=True)["question"][:10]
+    try:
+        dataset = load_dataset(local_data_path, "main", split="test", trust_remote_code=True)["question"][:10]
+    except Exception:
+        dataset = load_dataset("gsm8k", "main", split="test")["question"][:10]
diffulex/sampler/base.py (1)

88-102: Good defensive handling for cache lookups and fallback logic.

The string-based cache keying ensures consistency, and the fallback chain (cached block → existing cache → last logit → error) is well-structured. The error message at Line 102 provides useful debugging context.

One minor consideration: the logits[-1] fallback silently caches a potentially incorrect value if to_cache_last_token_id was expected but not available. Consider logging a warning in this fallback path to aid debugging.

♻️ Optional: Add warning for fallback path
         # Fallback: use last logit from current batch and cache it
+        import logging
+        logging.debug(f"Falling back to last logit for sequence {seq.seq_id}")
         last_logits = logits[-1] if logits.shape[0] > 0 else None
diffulex/engine/tp_worker.py (1)

165-220: Solid async implementation mirroring the synchronous version.

The generate_async correctly uses async methods throughout and includes await asyncio.sleep(0) to yield control during the inference loop. The output ordering logic is preserved correctly.

Consider adding strict=True to the zip() calls on Lines 180 and 217 to catch length mismatches early, though this is optional since the sampling_params normalization ensures consistent lengths.

♻️ Optional: Add strict=True to zip calls
-        for idx, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+        for idx, (prompt, sp) in enumerate(zip(prompts, sampling_params, strict=True)):
diffulex/engine/dp_worker.py (1)

444-454: Use exception chaining for better traceability.

When re-raising exceptions in an except block, use raise ... from err to preserve the original exception context.

♻️ Proposed fix
-            except EOFError:
+            except EOFError as err:
                 p = self.ps[replica_idx]
                 exitcode = p.exitcode
                 raise RuntimeError(
                     f"DP child #{replica_idx} terminated unexpectedly during generate (exitcode={exitcode}). "
                     f"Enable envs: PYTHONFAULTHANDLER=1 CUDA_LAUNCH_BLOCKING=1 TORCH_SHOW_CPP_STACKTRACES=1 for more info."
-                )
+                ) from err
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9d4ad6c and 9f96381.

📒 Files selected for processing (14)
  • diffulex/engine/dp_worker.py
  • diffulex/engine/model_runner.py
  • diffulex/engine/tp_worker.py
  • diffulex/sampler/base.py
  • diffulex/strategy/block_diffusion/engine/kvcache_manager.py
  • diffulex_kernel/python/dllm_flash_attn.py
  • diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py
  • diffulex_legacy/layers/attention/ops/triton_decode_attn_clm.py
  • examples/test_async_inference.py
  • examples/test_dream_dvllm_gsm8k.py
  • examples/test_fastdllmv2_diffulex_gsm8k.py
  • examples/test_llada_dvllm_human_eval.py
  • examples/test_sdar_dvllm.py
  • scripts/test_dvllm_dream_gsm8k.sh
🧰 Additional context used
🧬 Code graph analysis (4)
examples/test_async_inference.py (2)
diffulex/engine/tp_worker.py (2)
  • generate_async (165-220)
  • exit (40-57)
diffulex/engine/dp_worker.py (2)
  • generate_async (363-469)
  • exit (170-180)
diffulex/engine/tp_worker.py (1)
diffulex/engine/dp_worker.py (6)
  • add_request_async (192-201)
  • add_request (182-190)
  • step_async (231-268)
  • is_finished (270-271)
  • is_finished_async (273-277)
  • generate_async (363-469)
diffulex/strategy/block_diffusion/engine/kvcache_manager.py (2)
diffulex/engine/sequence.py (1)
  • num_blocks (61-62)
diffulex/strategy/d2f/engine/kvcache_manager.py (1)
  • may_append (20-39)
examples/test_sdar_dvllm.py (1)
diffulex/model/sdar.py (8)
  • SDARForDiffusionLM (181-200)
  • __init__ (184-189)
  • SDARModel (159-177)
  • SDARMLP (114-129)
  • forward (125-129)
  • SDARAttention (23-111)
  • __init__ (160-164)
  • __init__ (31-88)
🪛 Ruff (0.14.10)
examples/test_async_inference.py

95-95: Local variable tokenizer is assigned to but never used

Remove assignment to unused variable tokenizer

(F841)


134-134: Do not catch blind exception: Exception

(BLE001)


135-135: f-string without any placeholders

Remove extraneous f prefix

(F541)

diffulex/engine/tp_worker.py

180-180: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


217-217: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

diffulex/engine/dp_worker.py

251-251: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


423-423: Local variable pending is assigned to but never used

Remove assignment to unused variable pending

(F841)


424-424: Local variable conn_to_idx is assigned to but never used

Remove assignment to unused variable conn_to_idx

(F841)


447-450: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


447-450: Avoid specifying long messages outside the exception class

(TRY003)


454-454: Avoid specifying long messages outside the exception class

(TRY003)


461-461: Loop control variable e not used within loop body

Rename unused e to _e

(B007)

diffulex/sampler/base.py

102-102: Avoid specifying long messages outside the exception class

(TRY003)

diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py

21-21: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (16)
diffulex_kernel/python/dllm_flash_attn.py (1)

134-135: LGTM! This fixes an API capitalization bug.

The change from T.parallel to T.Parallel corrects an inconsistency with the rest of the file and aligns with the tilelang DSL API. All other parallel loops in this file use T.Parallel, so this was likely a typo that would have caused a runtime AttributeError.

diffulex_legacy/layers/attention/ops/triton_decode_attn_clm.py (1)

39-41: LGTM!

The switch to explicit Triton imports (import triton and import triton.language as tl) is cleaner and more portable than importing through vllm's wrapper. This aligns with standard Triton usage patterns and is consistent with the similar change in chunked_prefill_decoding_unified_kernel.py.

diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py (1)

26-27: LGTM!

Explicit Triton imports are consistent with the changes in triton_decode_attn_clm.py and follow standard Triton usage patterns.

scripts/test_dvllm_dream_gsm8k.sh (1)

2-2: LGTM!

The reduced GPU count (from 8 to 2) aligns with the data_parallel_size=1 and tensor_parallel_size=1 configuration in the test script.

examples/test_dream_dvllm_gsm8k.py (2)

13-13: LGTM!

The import registers the Dream model with the auto-loader framework.


67-68: LGTM!

The dataset loading now uses the full GSM8K test split for comprehensive benchmarking.

diffulex/engine/model_runner.py (2)

3-4: LGTM!

Imports for async support are appropriate.


77-79: LGTM!

Proper cleanup of the executor on exit prevents resource leaks.

examples/test_llada_dvllm_human_eval.py (1)

41-41: The model name on HuggingFace is confirmed as GSAI-ML/LLaDA-8B-Instruct with this exact capitalization. No changes needed.

diffulex/strategy/block_diffusion/engine/kvcache_manager.py (1)

29-31: LGTM!

Good defensive fix. When cached_or_caching_num_tokens == caching_num_tokens, prev_end_token becomes -1, making prev_block_idx = -1. Without this lower-bound check, seq.block(-1) would incorrectly access the last block via Python's negative indexing.

However, the D2F variant in diffulex/strategy/d2f/engine/kvcache_manager.py (line 31) has the same vulnerability. It currently checks only if prev_block_idx < seq.num_blocks: without the 0 <= lower-bound check. Apply the same fix there for correctness.

examples/test_fastdllmv2_diffulex_gsm8k.py (1)

53-53: Verify intent of reducing max_num_seqs to 1.

Reducing max_num_seqs from 20 to 1 significantly limits batching capability and throughput. If this is intentional for debugging or resource constraints, consider adding a comment explaining the rationale, or making it configurable.

examples/test_async_inference.py (1)

19-22: LGTM - Clean async inference test structure.

The async wrapper pattern with asyncio.run() and proper cleanup in the finally block follows best practices. Setting TOKENIZERS_PARALLELISM=false correctly prevents tokenizer parallelism conflicts with asyncio.

Also applies to: 104-143

diffulex/engine/tp_worker.py (1)

68-72: LGTM - Async wrappers for API consistency.

Using run_in_executor(None, ...) for lightweight operations like add_request and is_finished is appropriate. The dedicated single-threaded executor in step_async correctly serializes scheduler operations.

Also applies to: 106-109

diffulex/engine/dp_worker.py (3)

145-146: Good executor lifecycle management.

The _executor is properly initialized with max_workers=self.dp_size and correctly shut down in exit(). This ensures proper resource cleanup.

Also applies to: 170-173


231-268: Well-structured parallel execution across DP replicas.

The step_async correctly parallelizes both the is_finished checks and step calls across replicas using asyncio.gather. The result aggregation logic properly remaps local IDs to global IDs.


192-201: add_request_async is dead code and the race condition concern is not applicable.

The method is never called anywhere in the codebase—generate_async uses direct RPC commands to child processes rather than add_request_async. Additionally, even in a hypothetical concurrent scenario, asyncio's single-threaded event loop serializes execution: the read-modify-write operations on _rr occur before the await point, and the _gid_counter operations execute as atomic statements after the await, preventing any race condition. No synchronization primitives are needed.

Likely an incorrect or invalid review comment.

Comment on lines +19 to +24
try:
from vllm.platforms.rocm import use_rocm_custom_paged_attention # vLLM newer
except Exception:
# vLLM older / CUDA-only env: treat as disabled
def use_rocm_custom_paged_attention() -> bool:
return False
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: Fallback function signature mismatch and overly broad exception.

Two issues here:

  1. Signature mismatch (critical): The fallback function takes no parameters, but it's called at line 299 with 8 arguments: use_rocm_custom_paged_attention(query.dtype, head_size, block_size, num_queries_per_kv, max_seq_len, sliding_window, kv_cache_dtype, alibi_slopes). This will raise a TypeError in CUDA-only environments.

  2. Overly broad exception: Catching bare Exception can mask unrelated errors. Use ImportError instead.

🐛 Proposed fix
-try:
-    from vllm.platforms.rocm import use_rocm_custom_paged_attention  # vLLM newer
-except Exception:
-    # vLLM older / CUDA-only env: treat as disabled
-    def use_rocm_custom_paged_attention() -> bool:
-        return False
+try:
+    from vllm.platforms.rocm import use_rocm_custom_paged_attention  # vLLM newer
+except ImportError:
+    # vLLM older / CUDA-only env: treat as disabled
+    def use_rocm_custom_paged_attention(*args, **kwargs) -> bool:
+        return False
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try:
from vllm.platforms.rocm import use_rocm_custom_paged_attention # vLLM newer
except Exception:
# vLLM older / CUDA-only env: treat as disabled
def use_rocm_custom_paged_attention() -> bool:
return False
try:
from vllm.platforms.rocm import use_rocm_custom_paged_attention # vLLM newer
except ImportError:
# vLLM older / CUDA-only env: treat as disabled
def use_rocm_custom_paged_attention(*args, **kwargs) -> bool:
return False
🧰 Tools
🪛 Ruff (0.14.10)

21-21: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In
@diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py
around lines 19 - 24, The fallback currently catches a broad Exception and
defines use_rocm_custom_paged_attention() with no parameters, but the call site
invokes use_rocm_custom_paged_attention(query.dtype, head_size, block_size,
num_queries_per_kv, max_seq_len, sliding_window, kv_cache_dtype, alibi_slopes);
change the except to catch ImportError only and redefine the fallback function
use_rocm_custom_paged_attention(...same 8 parameters...) -> bool to accept the
same parameter list and return False (match the type hints/signature used at
call site) so CUDA-only environments won't raise a TypeError.

Comment on lines +421 to +424
# Collect results asynchronously
collected = {}
pending = set(slices.keys())
conn_to_idx = {self.conns[i]: i for i in slices.keys()}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Remove unused variables pending and conn_to_idx.

These variables are assigned but never used in the async version of generate_async.

🔧 Proposed fix
         # Collect results asynchronously
         collected = {}
-        pending = set(slices.keys())
-        conn_to_idx = {self.conns[i]: i for i in slices.keys()}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Collect results asynchronously
collected = {}
pending = set(slices.keys())
conn_to_idx = {self.conns[i]: i for i in slices.keys()}
# Collect results asynchronously
collected = {}
🧰 Tools
🪛 Ruff (0.14.10)

423-423: Local variable pending is assigned to but never used

Remove assignment to unused variable pending

(F841)


424-424: Local variable conn_to_idx is assigned to but never used

Remove assignment to unused variable conn_to_idx

(F841)

🤖 Prompt for AI Agents
In @diffulex/engine/dp_worker.py around lines 421 - 424, In generate_async,
remove the unused local variables pending and conn_to_idx (they are assigned
from slices.keys() and self.conns but never referenced); edit the function to
delete the lines that create pending and conn_to_idx so only the used variable
collected and slices remain, and run tests/lint to ensure no other references to
pending or conn_to_idx exist.

Comment on lines +118 to +126
async def call_async(self, method_name, *args):
"""Async version of call that runs in a thread pool executor."""
loop = asyncio.get_event_loop()
# Use default executor or create one if needed
executor = getattr(self, '_executor', None)
if executor is None:
executor = ThreadPoolExecutor(max_workers=1)
self._executor = executor
return await loop.run_in_executor(executor, self.call, method_name, *args)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Use asyncio.get_running_loop() and fix potential race condition in executor creation.

Two issues:

  1. asyncio.get_event_loop() is deprecated since Python 3.10 and may raise a DeprecationWarning. Use asyncio.get_running_loop() which is the recommended approach when called from within a coroutine.

  2. The executor creation is not thread-safe. If multiple coroutines call call_async concurrently before _executor is set, multiple executors could be created.

🔧 Proposed fix
     async def call_async(self, method_name, *args):
         """Async version of call that runs in a thread pool executor."""
-        loop = asyncio.get_event_loop()
-        # Use default executor or create one if needed
-        executor = getattr(self, '_executor', None)
-        if executor is None:
-            executor = ThreadPoolExecutor(max_workers=1)
-            self._executor = executor
+        loop = asyncio.get_running_loop()
+        # Lazily create executor on first use
+        if not hasattr(self, '_executor') or self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=1)
-        return await loop.run_in_executor(executor, self.call, method_name, *args)
+        return await loop.run_in_executor(self._executor, self.call, method_name, *args)

Note: The race condition is unlikely in practice since ModelRunnerBase instances are typically used from a single async context, but worth addressing for correctness.

🤖 Prompt for AI Agents
In @diffulex/engine/model_runner.py around lines 118 - 126, Replace
asyncio.get_event_loop() with asyncio.get_running_loop() in call_async, and
serialize executor creation using an asyncio.Lock: add an asyncio.Lock instance
(e.g., self._executor_lock) during the ModelRunnerBase __init__, then in
call_async do a double-checked pattern: if getattr(self, '_executor', None) is
None: await self._executor_lock; check _executor again and only then create and
assign ThreadPoolExecutor(max_workers=1) to self._executor; finally call
loop.run_in_executor(self._executor, self.call, method_name, *args). This fixes
the deprecated API usage and prevents concurrent creation of multiple executors.

Comment on lines +84 to +101
async def step_async(self):
"""Async version of step that runs model inference in a thread pool."""
loop = asyncio.get_event_loop()
executor = getattr(self, '_step_executor', None)
if executor is None:
executor = ThreadPoolExecutor(max_workers=1)
self._step_executor = executor

def _step():
seqs, is_prefill = self.scheduler.schedule()
sample_output = self.model_runner.call("run", seqs, is_prefill)
n_diff_steps = self.scheduler.postprocess(seqs, sample_output)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
num_tokens = sum(seq.num_tokens for seq in seqs) if is_prefill else sum(seq.new_tokens for seq in seqs)
deltas = []
return outputs, num_tokens, is_prefill, n_diff_steps, deltas

return await loop.run_in_executor(executor, _step)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Thread pool executor is not cleaned up in exit().

The _step_executor ThreadPoolExecutor is lazily created in step_async() but is never shut down when the worker exits. This can lead to resource leaks and hanging threads.

🔧 Proposed fix - add cleanup in exit()
     def exit(self):
         if getattr(self, "_exited", False):
             return
         self._exited = True
+        # Shutdown step executor if created
+        if hasattr(self, '_step_executor') and self._step_executor is not None:
+            self._step_executor.shutdown(wait=False)
         if hasattr(self, "model_runner") and self.model_runner is not None:
🤖 Prompt for AI Agents
In @diffulex/engine/tp_worker.py around lines 84 - 101, The ThreadPoolExecutor
created lazily in step_async() (stored as self._step_executor) is never shut
down causing leaked threads; update the worker's exit() method to check for
self._step_executor, call its shutdown(wait=True) (or shutdown(wait=False) if
non-blocking termination is preferred), handle any exceptions, and set
self._step_executor = None so resources are released and repeated exits are
safe; ensure exit() cleans up even if executor was never created.

master_port=2333,
)

tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True, use_fast=True)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Remove unused tokenizer variable.

The tokenizer is loaded but never used in this script. Either remove it or use it for prompt validation/token counting.

🔧 Proposed fix
-    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True, use_fast=True)
-
     prompts = [args.prompt] * args.num_prompts
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True, use_fast=True)
prompts = [args.prompt] * args.num_prompts
🧰 Tools
🪛 Ruff (0.14.10)

95-95: Local variable tokenizer is assigned to but never used

Remove assignment to unused variable tokenizer

(F841)

🤖 Prompt for AI Agents
In @examples/test_async_inference.py at line 95, The variable `tokenizer`
created by calling `AutoTokenizer.from_pretrained(args.model,
trust_remote_code=True, use_fast=True)` in examples/test_async_inference.py is
unused; either remove that line entirely or use `tokenizer` for prompt
validation/token counting before sending requests (e.g., call tokenizer.encode
or tokenizer(model_input) to compute token length) and replace any hardcoded
token assumptions with the computed values; update imports if removed to avoid
unused-import warnings.

Comment on lines +134 to +138
except Exception as e:
print(f"\n[Error during async inference]")
print(f"Error: {e}")
import traceback
traceback.print_exc()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix f-string without placeholders.

Line 135 uses an f-string but has no placeholders. Remove the f prefix.

🔧 Proposed fix
     except Exception as e:
-        print(f"\n[Error during async inference]")
+        print("\n[Error during async inference]")
         print(f"Error: {e}")
         import traceback
         traceback.print_exc()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
except Exception as e:
print(f"\n[Error during async inference]")
print(f"Error: {e}")
import traceback
traceback.print_exc()
except Exception as e:
print("\n[Error during async inference]")
print(f"Error: {e}")
import traceback
traceback.print_exc()
🧰 Tools
🪛 Ruff (0.14.10)

134-134: Do not catch blind exception: Exception

(BLE001)


135-135: f-string without any placeholders

Remove extraneous f prefix

(F541)

🤖 Prompt for AI Agents
In @examples/test_async_inference.py around lines 134 - 138, The print statement
inside the exception handler uses an f-string with no placeholders; update the
except block (the "except Exception as e:" handler) to remove the unnecessary f
prefix from the first print (change print(f"\n[Error during async inference]")
to a normal string print) so it is not an f-string without placeholders while
keeping the other prints (print(f"Error: {e}") and traceback.print_exc())
unchanged.

@drewjin drewjin merged commit 86714d1 into SJTU-DENG-Lab:main Jan 25, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants