Async support for inference engine #24
Conversation
👋 Hi! Thank you for contributing to the Diffulex project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

The PR introduces asynchronous execution capabilities to multiple engine components (dp_worker, tp_worker, model_runner) using ThreadPoolExecutor and asyncio, adds an async inference example script, fixes cache key handling from numeric to string-based, adds a ROCm import fallback for compatibility, updates Triton imports, and modifies example scripts with new configurations.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Worker
    participant Executor as ThreadPoolExecutor
    participant Replica
    User->>Worker: generate_async(prompts, params)
    activate Worker
    loop For each prompt
        Worker->>Worker: add_request_async(prompt)
        Worker->>Executor: _ask_async(replica, "add_request")
        activate Executor
        Executor->>Replica: add_request (blocking)
        Executor-->>Worker: ✓
        deactivate Executor
    end
    loop Until all finished
        Worker->>Executor: step_async()
        activate Executor
        par Parallel replicas
            Executor->>Replica: step() [replica 1]
            Executor->>Replica: step() [replica 2]
            Executor->>Replica: step() [replica N]
        end
        Executor-->>Worker: outputs, tokens, metrics
        deactivate Executor
        Worker->>Executor: is_finished_async()
        activate Executor
        Executor->>Replica: is_finished() [all]
        Executor-->>Worker: done status
        deactivate Executor
        Worker->>Worker: remap local IDs → global IDs
        Worker->>Worker: aggregate results
    end
    Worker-->>User: results (text, token_ids, n_diff_steps)
    deactivate Worker
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/test_dream_dvllm_gsm8k.py (1)
19-20: Fix incorrect method name: `csv.dictReader` should be `csv.DictReader`.

The `csv` module uses `DictReader` (with capital D and R). This will raise an `AttributeError` at runtime.

🐛 Proposed fix

```diff
- reader = csv.dictReader(f)
+ reader = csv.DictReader(f)
```

examples/test_llada_dvllm_human_eval.py (1)

18-19: Fix incorrect method name: `csv.dictReader` should be `csv.DictReader`.

Same issue as in `test_dream_dvllm_gsm8k.py`. This will raise an `AttributeError` at runtime.

🐛 Proposed fix

```diff
- reader = csv.dictReader(f)
+ reader = csv.DictReader(f)
```
🤖 Fix all issues with AI agents
In
@diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py:
- Around line 19-24: The fallback currently catches a broad Exception and
defines use_rocm_custom_paged_attention() with no parameters, but the call site
invokes use_rocm_custom_paged_attention(query.dtype, head_size, block_size,
num_queries_per_kv, max_seq_len, sliding_window, kv_cache_dtype, alibi_slopes);
change the except to catch ImportError only and redefine the fallback function
use_rocm_custom_paged_attention(...same 8 parameters...) -> bool to accept the
same parameter list and return False (match the type hints/signature used at
call site) so CUDA-only environments won't raise a TypeError.
In @diffulex/engine/dp_worker.py:
- Around line 421-424: In generate_async, remove the unused local variables
pending and conn_to_idx (they are assigned from slices.keys() and self.conns but
never referenced); edit the function to delete the lines that create pending and
conn_to_idx so only the used variable collected and slices remain, and run
tests/lint to ensure no other references to pending or conn_to_idx exist.
In @diffulex/engine/model_runner.py:
- Around line 118-126: Replace asyncio.get_event_loop() with
asyncio.get_running_loop() in call_async, and serialize executor creation using
an asyncio.Lock: add an asyncio.Lock instance (e.g., self._executor_lock) during
the ModelRunnerBase __init__, then in call_async do a double-checked pattern: if
getattr(self, '_executor', None) is None: await self._executor_lock; check
_executor again and only then create and assign
ThreadPoolExecutor(max_workers=1) to self._executor; finally call
loop.run_in_executor(self._executor, self.call, method_name, *args). This fixes
the deprecated API usage and prevents concurrent creation of multiple executors.
In @diffulex/engine/tp_worker.py:
- Around line 84-101: The ThreadPoolExecutor created lazily in step_async()
(stored as self._step_executor) is never shut down causing leaked threads;
update the worker's exit() method to check for self._step_executor, call its
shutdown(wait=True) (or shutdown(wait=False) if non-blocking termination is
preferred), handle any exceptions, and set self._step_executor = None so
resources are released and repeated exits are safe; ensure exit() cleans up even
if executor was never created.
In @examples/test_async_inference.py:
- Around line 134-138: The print statement inside the exception handler uses an
f-string with no placeholders; update the except block (the "except Exception as
e:" handler) to remove the unnecessary f prefix from the first print (change
print(f"\n[Error during async inference]") to a normal string print) so it is
not an f-string without placeholders while keeping the other prints
(print(f"Error: {e}") and traceback.print_exc()) unchanged.
- Line 95: The variable `tokenizer` created by calling
`AutoTokenizer.from_pretrained(args.model, trust_remote_code=True,
use_fast=True)` in examples/test_async_inference.py is unused; either remove
that line entirely or use `tokenizer` for prompt validation/token counting
before sending requests (e.g., call tokenizer.encode or tokenizer(model_input)
to compute token length) and replace any hardcoded token assumptions with the
computed values; update imports if removed to avoid unused-import warnings.
🧹 Nitpick comments (7)
examples/test_sdar_dvllm.py (2)
97-109: Consider using portable default paths for example scripts.

The default paths reference machine-specific locations (`/data1/ckpts/...`) and a user-specific home directory (`/home/ljp/...`). This makes the example less portable for other users. Consider using environment variables or relative paths:

Suggested improvement

```diff
 parser.add_argument(
     "--model",
     type=str,
-    default="/data1/ckpts/SDAR/SDAR-1.7B-Chat",
+    default=os.environ.get("SDAR_MODEL_PATH", "./SDAR-1.7B-Chat"),
     help="SDAR HF model directory (contains config.json + model.safetensors).",
 )
 parser.add_argument("--device", type=int, default=0)
 parser.add_argument(
     "--converted-dir",
     type=str,
-    default="/home/ljp/tmp/diffulex_sdar_converted",
+    default=os.environ.get("SDAR_CONVERTED_DIR", "/tmp/diffulex_sdar_converted"),
     help="Output directory for converted checkpoint keys (Diffulex-native).",
 )
```
139-139: Consider adding a comment explaining the side-effect import.

This import registers the SDAR model class with `AutoModelForDiffusionLM`, enabling `from_config()` to instantiate it when `model_name="sdar"`. A brief comment would clarify the intent for future readers.

```diff
+ # Register SDAR model class with AutoModelForDiffusionLM
  import diffulex.model.sdar
```

examples/test_fastdllmv2_diffulex_gsm8k.py (2)
42-43: Hardcoded local paths reduce portability.

The model path and `local_data_path` are hardcoded to `/data1/...`, which won't work for other developers or CI environments. Consider using environment variables or command-line arguments for configurability.

♻️ Suggested approach

```diff
+import argparse
+# ... in main
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default=os.environ.get("MODEL_PATH", "/data1/ckpts/Efficient-Large-Model/Fast_dLLM_v2_7B"))
+    parser.add_argument("--data-path", type=str, default=os.environ.get("DATA_PATH", "/data1/LargeData/gsm8k"))
+    args = parser.parse_args()
-    model = "/data1/ckpts/Efficient-Large-Model/Fast_dLLM_v2_7B"
-    local_data_path = "/data1/LargeData/gsm8k"
+    model = args.model
+    local_data_path = args.data_path
```
62-63: Clean up commented code and consider fallback logic.

The commented-out line for hosted dataset loading should either be removed or converted to a fallback mechanism. The current approach requires local data to exist.

♻️ Suggested fallback pattern

```diff
-    # dataset = load_dataset("gsm8k", "main", split="test")["question"][:10]
-    dataset = load_dataset(local_data_path, "main", split="test", trust_remote_code=True)["question"][:10]
+    try:
+        dataset = load_dataset(local_data_path, "main", split="test", trust_remote_code=True)["question"][:10]
+    except Exception:
+        dataset = load_dataset("gsm8k", "main", split="test")["question"][:10]
```

diffulex/sampler/base.py (1)
88-102: Good defensive handling for cache lookups and fallback logic.

The string-based cache keying ensures consistency, and the fallback chain (cached block → existing cache → last logit → error) is well-structured. The error message at line 102 provides useful debugging context.

One minor consideration: the `logits[-1]` fallback silently caches a potentially incorrect value if `to_cache_last_token_id` was expected but not available. Consider logging a warning in this fallback path to aid debugging.

♻️ Optional: Add warning for fallback path

```diff
 # Fallback: use last logit from current batch and cache it
+import logging
+logging.debug(f"Falling back to last logit for sequence {seq.seq_id}")
 last_logits = logits[-1] if logits.shape[0] > 0 else None
```

diffulex/engine/tp_worker.py (1)
165-220: Solid async implementation mirroring the synchronous version.

`generate_async` correctly uses async methods throughout and includes `await asyncio.sleep(0)` to yield control during the inference loop. The output ordering logic is preserved correctly.

Consider adding `strict=True` to the `zip()` calls on lines 180 and 217 to catch length mismatches early, though this is optional since the sampling_params normalization ensures consistent lengths.

♻️ Optional: Add strict=True to zip calls

```diff
- for idx, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
+ for idx, (prompt, sp) in enumerate(zip(prompts, sampling_params, strict=True)):
```
diffulex/engine/dp_worker.py (1)

444-454: Use exception chaining for better traceability.

When re-raising exceptions in an `except` block, use `raise ... from err` to preserve the original exception context.

♻️ Proposed fix

```diff
-        except EOFError:
+        except EOFError as err:
             p = self.ps[replica_idx]
             exitcode = p.exitcode
             raise RuntimeError(
                 f"DP child #{replica_idx} terminated unexpectedly during generate (exitcode={exitcode}). "
                 f"Enable envs: PYTHONFAULTHANDLER=1 CUDA_LAUNCH_BLOCKING=1 TORCH_SHOW_CPP_STACKTRACES=1 for more info."
-            )
+            ) from err
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
- diffulex/engine/dp_worker.py
- diffulex/engine/model_runner.py
- diffulex/engine/tp_worker.py
- diffulex/sampler/base.py
- diffulex/strategy/block_diffusion/engine/kvcache_manager.py
- diffulex_kernel/python/dllm_flash_attn.py
- diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py
- diffulex_legacy/layers/attention/ops/triton_decode_attn_clm.py
- examples/test_async_inference.py
- examples/test_dream_dvllm_gsm8k.py
- examples/test_fastdllmv2_diffulex_gsm8k.py
- examples/test_llada_dvllm_human_eval.py
- examples/test_sdar_dvllm.py
- scripts/test_dvllm_dream_gsm8k.sh
🧰 Additional context used
🧬 Code graph analysis (4)
examples/test_async_inference.py (2)
diffulex/engine/tp_worker.py (2)
- generate_async (165-220)
- exit (40-57)

diffulex/engine/dp_worker.py (2)

- generate_async (363-469)
- exit (170-180)
diffulex/engine/tp_worker.py (1)
diffulex/engine/dp_worker.py (6)
- add_request_async (192-201)
- add_request (182-190)
- step_async (231-268)
- is_finished (270-271)
- is_finished_async (273-277)
- generate_async (363-469)
diffulex/strategy/block_diffusion/engine/kvcache_manager.py (2)
diffulex/engine/sequence.py (1)
- num_blocks (61-62)

diffulex/strategy/d2f/engine/kvcache_manager.py (1)

- may_append (20-39)
examples/test_sdar_dvllm.py (1)
diffulex/model/sdar.py (8)
- SDARForDiffusionLM (181-200)
- __init__ (184-189)
- SDARModel (159-177)
- SDARMLP (114-129)
- forward (125-129)
- SDARAttention (23-111)
- __init__ (160-164)
- __init__ (31-88)
🪛 Ruff (0.14.10)
examples/test_async_inference.py
95-95: Local variable tokenizer is assigned to but never used
Remove assignment to unused variable tokenizer
(F841)
134-134: Do not catch blind exception: Exception
(BLE001)
135-135: f-string without any placeholders
Remove extraneous f prefix
(F541)
diffulex/engine/tp_worker.py
180-180: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
217-217: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
diffulex/engine/dp_worker.py
251-251: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
423-423: Local variable pending is assigned to but never used
Remove assignment to unused variable pending
(F841)
424-424: Local variable conn_to_idx is assigned to but never used
Remove assignment to unused variable conn_to_idx
(F841)
447-450: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
447-450: Avoid specifying long messages outside the exception class
(TRY003)
454-454: Avoid specifying long messages outside the exception class
(TRY003)
461-461: Loop control variable e not used within loop body
Rename unused e to _e
(B007)
diffulex/sampler/base.py
102-102: Avoid specifying long messages outside the exception class
(TRY003)
diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py
21-21: Do not catch blind exception: Exception
(BLE001)
🔇 Additional comments (16)
diffulex_kernel/python/dllm_flash_attn.py (1)
134-135: LGTM! This fixes an API capitalization bug.

The change from `T.parallel` to `T.Parallel` corrects an inconsistency with the rest of the file and aligns with the tilelang DSL API. All other parallel loops in this file use `T.Parallel`, so this was likely a typo that would have caused a runtime `AttributeError`.

diffulex_legacy/layers/attention/ops/triton_decode_attn_clm.py (1)
39-41: LGTM!

The switch to explicit Triton imports (`import triton` and `import triton.language as tl`) is cleaner and more portable than importing through vLLM's wrapper. This aligns with standard Triton usage patterns and is consistent with the similar change in `chunked_prefill_decoding_unified_kernel.py`.

diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py (1)
26-27: LGTM!

Explicit Triton imports are consistent with the changes in `triton_decode_attn_clm.py` and follow standard Triton usage patterns.

scripts/test_dvllm_dream_gsm8k.sh (1)
2-2: LGTM!

The reduced GPU count (from 8 to 2) aligns with the `data_parallel_size=1` and `tensor_parallel_size=1` configuration in the test script.

examples/test_dream_dvllm_gsm8k.py (2)
13-13: LGTM!

The import registers the Dream model with the auto-loader framework.

67-68: LGTM!

The dataset loading now uses the full GSM8K test split for comprehensive benchmarking.
diffulex/engine/model_runner.py (2)
3-4: LGTM!

Imports for async support are appropriate.

77-79: LGTM!

Proper cleanup of the executor on exit prevents resource leaks.
examples/test_llada_dvllm_human_eval.py (1)
41-41: The model name on HuggingFace is confirmed as `GSAI-ML/LLaDA-8B-Instruct` with this exact capitalization. No changes needed.
29-31: LGTM!

Good defensive fix. When `cached_or_caching_num_tokens == caching_num_tokens`, `prev_end_token` becomes `-1`, making `prev_block_idx = -1`. Without this lower-bound check, `seq.block(-1)` would incorrectly access the last block via Python's negative indexing.

However, the D2F variant in `diffulex/strategy/d2f/engine/kvcache_manager.py` (line 31) has the same vulnerability. It currently checks only `if prev_block_idx < seq.num_blocks:` without the `0 <=` lower-bound check. Apply the same fix there for correctness.
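For reference, a minimal sketch of the guard being requested for the D2F variant (a hypothetical helper; the real check lives inline in may_append and uses the surrounding variables):

```python
def is_valid_prev_block(prev_block_idx: int, num_blocks: int) -> bool:
    """Illustrative only: reject -1 so it cannot select the last block via negative indexing."""
    return 0 <= prev_block_idx < num_blocks

assert not is_valid_prev_block(-1, 4)  # the case the lower bound protects against
assert is_valid_prev_block(0, 4)
```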
examples/test_fastdllmv2_diffulex_gsm8k.py (1)

53-53: Verify intent of reducing `max_num_seqs` to 1.

Reducing `max_num_seqs` from 20 to 1 significantly limits batching capability and throughput. If this is intentional for debugging or resource constraints, consider adding a comment explaining the rationale, or making it configurable (see the sketch below).
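If configurability is preferred, a minimal sketch (the MAX_NUM_SEQS environment variable and its default of 1 are illustrative, not part of the PR):

```python
import os

# Hypothetical: read batching capacity from the environment, keeping the PR's value of 1 as default.
max_num_seqs = int(os.environ.get("MAX_NUM_SEQS", "1"))
```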
19-22: LGTM - Clean async inference test structure.

The async wrapper pattern with `asyncio.run()` and proper cleanup in the `finally` block follows best practices. Setting `TOKENIZERS_PARALLELISM=false` correctly prevents tokenizer parallelism conflicts with asyncio.

Also applies to: 104-143
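For reference, a minimal sketch of that wrapper pattern (run_inference and the worker/prompt names are placeholders, not the script's actual symbols):

```python
import asyncio
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # avoid tokenizer thread pools fighting asyncio

async def run_inference(worker, prompts, params):
    try:
        # All awaits happen inside one coroutine owned by asyncio.run().
        return await worker.generate_async(prompts, params)
    finally:
        worker.exit()  # cleanup runs even if generation raises

# results = asyncio.run(run_inference(worker, prompts, params))
```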
diffulex/engine/tp_worker.py (1)
68-72: LGTM - Async wrappers for API consistency.

Using `run_in_executor(None, ...)` for lightweight operations like `add_request` and `is_finished` is appropriate. The dedicated single-threaded executor in `step_async` correctly serializes scheduler operations.

Also applies to: 106-109
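A minimal sketch of the wrapper shape being praised here (simplified; the real methods live on the worker class and return engine-specific types):

```python
import asyncio

class WorkerSketch:
    """Illustrative only: delegate blocking calls to the default executor."""

    def add_request(self, prompt, sampling_params):
        ...  # blocking implementation

    def is_finished(self) -> bool:
        ...  # blocking implementation

    async def add_request_async(self, prompt, sampling_params):
        loop = asyncio.get_running_loop()
        # None selects the default executor; fine for short, lightweight calls.
        return await loop.run_in_executor(None, self.add_request, prompt, sampling_params)

    async def is_finished_async(self) -> bool:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.is_finished)
```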
diffulex/engine/dp_worker.py (3)
145-146: Good executor lifecycle management.

The `_executor` is properly initialized with `max_workers=self.dp_size` and correctly shut down in `exit()`. This ensures proper resource cleanup.

Also applies to: 170-173
231-268: Well-structured parallel execution across DP replicas.

`step_async` correctly parallelizes both the `is_finished` checks and `step` calls across replicas using `asyncio.gather`. The result aggregation logic properly remaps local IDs to global IDs.
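A sketch of the gather pattern described above (assuming a per-replica async helper like the _ask_async shown in the sequence diagram; names are illustrative, not the exact source):

```python
import asyncio

async def step_all_replicas(worker, replica_indices):
    """Illustrative only: run one step on every DP replica concurrently."""
    per_replica = await asyncio.gather(
        *(worker._ask_async(i, "step") for i in replica_indices)
    )
    # Each element holds that replica's outputs; the real step_async then remaps
    # local sequence IDs to global IDs before aggregating them into one result.
    return per_replica
```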
192-201: `add_request_async` is dead code, and the race condition concern is not applicable.

The method is never called anywhere in the codebase—`generate_async` uses direct RPC commands to child processes rather than `add_request_async`. Additionally, even in a hypothetical concurrent scenario, asyncio's single-threaded event loop serializes execution: the read-modify-write operations on `_rr` occur before the `await` point, and the `_gid_counter` operations execute as atomic statements after the `await`, preventing any race condition. No synchronization primitives are needed.

Likely an incorrect or invalid review comment.
```python
try:
    from vllm.platforms.rocm import use_rocm_custom_paged_attention  # vLLM newer
except Exception:
    # vLLM older / CUDA-only env: treat as disabled
    def use_rocm_custom_paged_attention() -> bool:
        return False
```
Critical: Fallback function signature mismatch and overly broad exception.

Two issues here:

- Signature mismatch (critical): The fallback function takes no parameters, but it's called at line 299 with 8 arguments: `use_rocm_custom_paged_attention(query.dtype, head_size, block_size, num_queries_per_kv, max_seq_len, sliding_window, kv_cache_dtype, alibi_slopes)`. This will raise a `TypeError` in CUDA-only environments.
- Overly broad exception: Catching bare `Exception` can mask unrelated errors. Use `ImportError` instead.
🐛 Proposed fix
```diff
-try:
-    from vllm.platforms.rocm import use_rocm_custom_paged_attention  # vLLM newer
-except Exception:
-    # vLLM older / CUDA-only env: treat as disabled
-    def use_rocm_custom_paged_attention() -> bool:
-        return False
+try:
+    from vllm.platforms.rocm import use_rocm_custom_paged_attention  # vLLM newer
+except ImportError:
+    # vLLM older / CUDA-only env: treat as disabled
+    def use_rocm_custom_paged_attention(*args, **kwargs) -> bool:
+        return False
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-try:
-    from vllm.platforms.rocm import use_rocm_custom_paged_attention  # vLLM newer
-except Exception:
-    # vLLM older / CUDA-only env: treat as disabled
-    def use_rocm_custom_paged_attention() -> bool:
-        return False
+try:
+    from vllm.platforms.rocm import use_rocm_custom_paged_attention  # vLLM newer
+except ImportError:
+    # vLLM older / CUDA-only env: treat as disabled
+    def use_rocm_custom_paged_attention(*args, **kwargs) -> bool:
+        return False
```
🧰 Tools
🪛 Ruff (0.14.10)
21-21: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
In
@diffulex_legacy/layers/attention/ops/chunked_prefill_decoding_unified_kernel.py
around lines 19 - 24, The fallback currently catches a broad Exception and
defines use_rocm_custom_paged_attention() with no parameters, but the call site
invokes use_rocm_custom_paged_attention(query.dtype, head_size, block_size,
num_queries_per_kv, max_seq_len, sliding_window, kv_cache_dtype, alibi_slopes);
change the except to catch ImportError only and redefine the fallback function
use_rocm_custom_paged_attention(...same 8 parameters...) -> bool to accept the
same parameter list and return False (match the type hints/signature used at
call site) so CUDA-only environments won't raise a TypeError.
```python
# Collect results asynchronously
collected = {}
pending = set(slices.keys())
conn_to_idx = {self.conns[i]: i for i in slices.keys()}
```
Remove unused variables pending and conn_to_idx.
These variables are assigned but never used in the async version of generate_async.
🔧 Proposed fix
```diff
 # Collect results asynchronously
 collected = {}
-pending = set(slices.keys())
-conn_to_idx = {self.conns[i]: i for i in slices.keys()}
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-# Collect results asynchronously
-collected = {}
-pending = set(slices.keys())
-conn_to_idx = {self.conns[i]: i for i in slices.keys()}
+# Collect results asynchronously
+collected = {}
```
🧰 Tools
🪛 Ruff (0.14.10)
423-423: Local variable pending is assigned to but never used
Remove assignment to unused variable pending
(F841)
424-424: Local variable conn_to_idx is assigned to but never used
Remove assignment to unused variable conn_to_idx
(F841)
🤖 Prompt for AI Agents
In @diffulex/engine/dp_worker.py around lines 421 - 424, In generate_async,
remove the unused local variables pending and conn_to_idx (they are assigned
from slices.keys() and self.conns but never referenced); edit the function to
delete the lines that create pending and conn_to_idx so only the used variable
collected and slices remain, and run tests/lint to ensure no other references to
pending or conn_to_idx exist.
```python
async def call_async(self, method_name, *args):
    """Async version of call that runs in a thread pool executor."""
    loop = asyncio.get_event_loop()
    # Use default executor or create one if needed
    executor = getattr(self, '_executor', None)
    if executor is None:
        executor = ThreadPoolExecutor(max_workers=1)
        self._executor = executor
    return await loop.run_in_executor(executor, self.call, method_name, *args)
```
Use asyncio.get_running_loop() and fix a potential race condition in executor creation.

Two issues:

- `asyncio.get_event_loop()` is deprecated since Python 3.10 and may raise a `DeprecationWarning`. Use `asyncio.get_running_loop()`, which is the recommended approach when called from within a coroutine.
- The executor creation is not thread-safe. If multiple coroutines call `call_async` concurrently before `_executor` is set, multiple executors could be created.
🔧 Proposed fix
```diff
 async def call_async(self, method_name, *args):
     """Async version of call that runs in a thread pool executor."""
-    loop = asyncio.get_event_loop()
-    # Use default executor or create one if needed
-    executor = getattr(self, '_executor', None)
-    if executor is None:
-        executor = ThreadPoolExecutor(max_workers=1)
-        self._executor = executor
+    loop = asyncio.get_running_loop()
+    # Lazily create executor on first use
+    if not hasattr(self, '_executor') or self._executor is None:
+        self._executor = ThreadPoolExecutor(max_workers=1)
-    return await loop.run_in_executor(executor, self.call, method_name, *args)
+    return await loop.run_in_executor(self._executor, self.call, method_name, *args)
```

Note: The race condition is unlikely in practice since ModelRunnerBase instances are typically used from a single async context, but worth addressing for correctness.
🤖 Prompt for AI Agents
In @diffulex/engine/model_runner.py around lines 118 - 126, Replace
asyncio.get_event_loop() with asyncio.get_running_loop() in call_async, and
serialize executor creation using an asyncio.Lock: add an asyncio.Lock instance
(e.g., self._executor_lock) during the ModelRunnerBase __init__, then in
call_async do a double-checked pattern: if getattr(self, '_executor', None) is
None: await self._executor_lock; check _executor again and only then create and
assign ThreadPoolExecutor(max_workers=1) to self._executor; finally call
loop.run_in_executor(self._executor, self.call, method_name, *args). This fixes
the deprecated API usage and prevents concurrent creation of multiple executors.
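For reference, a sketch of the locked, double-checked creation this prompt describes (assuming the lock is created in __init__ and self.call is the existing blocking dispatch; this is not the PR's actual code):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

class ModelRunnerSketch:
    """Illustrative only: one executor, created at most once under an asyncio.Lock."""

    def __init__(self):
        self._executor = None
        self._executor_lock = asyncio.Lock()

    def call(self, method_name, *args):
        ...  # existing blocking dispatch

    async def call_async(self, method_name, *args):
        loop = asyncio.get_running_loop()
        if self._executor is None:
            async with self._executor_lock:
                # Re-check after acquiring the lock so concurrent callers share one executor.
                if self._executor is None:
                    self._executor = ThreadPoolExecutor(max_workers=1)
        return await loop.run_in_executor(self._executor, self.call, method_name, *args)
```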
```python
async def step_async(self):
    """Async version of step that runs model inference in a thread pool."""
    loop = asyncio.get_event_loop()
    executor = getattr(self, '_step_executor', None)
    if executor is None:
        executor = ThreadPoolExecutor(max_workers=1)
        self._step_executor = executor

    def _step():
        seqs, is_prefill = self.scheduler.schedule()
        sample_output = self.model_runner.call("run", seqs, is_prefill)
        n_diff_steps = self.scheduler.postprocess(seqs, sample_output)
        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
        num_tokens = sum(seq.num_tokens for seq in seqs) if is_prefill else sum(seq.new_tokens for seq in seqs)
        deltas = []
        return outputs, num_tokens, is_prefill, n_diff_steps, deltas

    return await loop.run_in_executor(executor, _step)
```
Thread pool executor is not cleaned up in exit().
The _step_executor ThreadPoolExecutor is lazily created in step_async() but is never shut down when the worker exits. This can lead to resource leaks and hanging threads.
🔧 Proposed fix - add cleanup in exit()
```diff
 def exit(self):
     if getattr(self, "_exited", False):
         return
     self._exited = True
+    # Shutdown step executor if created
+    if hasattr(self, '_step_executor') and self._step_executor is not None:
+        self._step_executor.shutdown(wait=False)
     if hasattr(self, "model_runner") and self.model_runner is not None:
```

🤖 Prompt for AI Agents
In @diffulex/engine/tp_worker.py around lines 84 - 101, The ThreadPoolExecutor
created lazily in step_async() (stored as self._step_executor) is never shut
down causing leaked threads; update the worker's exit() method to check for
self._step_executor, call its shutdown(wait=True) (or shutdown(wait=False) if
non-blocking termination is preferred), handle any exceptions, and set
self._step_executor = None so resources are released and repeated exits are
safe; ensure exit() cleans up even if executor was never created.
```python
    master_port=2333,
)

tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True, use_fast=True)
```
Remove unused tokenizer variable.
The tokenizer is loaded but never used in this script. Either remove it or use it for prompt validation/token counting.
🔧 Proposed fix
```diff
-    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True, use_fast=True)
-
     prompts = [args.prompt] * args.num_prompts
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True, use_fast=True)
 prompts = [args.prompt] * args.num_prompts
```
🧰 Tools
🪛 Ruff (0.14.10)
95-95: Local variable tokenizer is assigned to but never used
Remove assignment to unused variable tokenizer
(F841)
🤖 Prompt for AI Agents
In @examples/test_async_inference.py at line 95, The variable `tokenizer`
created by calling `AutoTokenizer.from_pretrained(args.model,
trust_remote_code=True, use_fast=True)` in examples/test_async_inference.py is
unused; either remove that line entirely or use `tokenizer` for prompt
validation/token counting before sending requests (e.g., call tokenizer.encode
or tokenizer(model_input) to compute token length) and replace any hardcoded
token assumptions with the computed values; update imports if removed to avoid
unused-import warnings.
```python
except Exception as e:
    print(f"\n[Error during async inference]")
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()
```
Fix f-string without placeholders.
Line 135 uses an f-string but has no placeholders. Remove the f prefix.
🔧 Proposed fix
```diff
 except Exception as e:
-    print(f"\n[Error during async inference]")
+    print("\n[Error during async inference]")
     print(f"Error: {e}")
     import traceback
     traceback.print_exc()
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 except Exception as e:
-    print(f"\n[Error during async inference]")
+    print("\n[Error during async inference]")
     print(f"Error: {e}")
     import traceback
     traceback.print_exc()
```
🧰 Tools
🪛 Ruff (0.14.10)
134-134: Do not catch blind exception: Exception
(BLE001)
135-135: f-string without any placeholders
Remove extraneous f prefix
(F541)
🤖 Prompt for AI Agents
In @examples/test_async_inference.py around lines 134 - 138, The print statement
inside the exception handler uses an f-string with no placeholders; update the
except block (the "except Exception as e:" handler) to remove the unnecessary f
prefix from the first print (change print(f"\n[Error during async inference]")
to a normal string print) so it is not an f-string without placeholders while
keeping the other prints (print(f"Error: {e}") and traceback.print_exc())
unchanged.
Async implementation

Async Implementation (tp_worker.py)
- Improved generate_async() to use async methods consistently
- Made step_async() thread-safe by running scheduler operations in a single-threaded executor
- Added proper async control flow with await asyncio.sleep(0) (see the sketch after this list)

New Test File (test_async_inference.py)
- Test suite for async inference with fast_dllm_v2 on gsm8k
- Configurable parameters and performance metrics
- Error handling and cleanup

Fix some small bugs
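A rough sketch of the control flow those bullets describe (simplified and illustrative; method names follow the review above, and the return shape of step_async mirrors the reviewed code, but this is not the PR's actual implementation):

```python
import asyncio

async def generate_async_sketch(worker, prompts, sampling_params):
    """Illustrative only: enqueue prompts, then step until all sequences finish."""
    for prompt, sp in zip(prompts, sampling_params, strict=True):
        await worker.add_request_async(prompt, sp)

    results = []
    while not await worker.is_finished_async():
        outputs, num_tokens, is_prefill, n_diff_steps, deltas = await worker.step_async()
        results.extend(outputs)
        # Yield control to the event loop between steps so other coroutines can run.
        await asyncio.sleep(0)
    return results
```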
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Chores
✏️ Tip: You can customize this high-level summary in your review settings.