From 97870d3e4b3cde534dd4018621d8e601b309f2de Mon Sep 17 00:00:00 2001 From: MaYuhang <2902139028@qq.com> Date: Wed, 14 Jan 2026 23:41:43 +0800 Subject: [PATCH] issue/189: add inference server support to InfiniLM --- README.md | 22 + python/infinilm/__init__.py | 22 +- python/infinilm/llm/__init__.py | 43 ++ python/infinilm/llm/cache_manager.py | 268 +++++++++ python/infinilm/llm/llm.py | 646 +++++++++++++++++++++ python/infinilm/llm/request.py | 231 ++++++++ python/infinilm/llm/sampling_params.py | 35 ++ python/infinilm/llm/scheduler.py | 248 ++++++++ python/infinilm/server/inference_server.py | 467 +++++++++++++++ 9 files changed, 1981 insertions(+), 1 deletion(-) create mode 100644 python/infinilm/llm/__init__.py create mode 100644 python/infinilm/llm/cache_manager.py create mode 100644 python/infinilm/llm/llm.py create mode 100644 python/infinilm/llm/request.py create mode 100644 python/infinilm/llm/sampling_params.py create mode 100644 python/infinilm/llm/scheduler.py create mode 100644 python/infinilm/server/inference_server.py diff --git a/README.md b/README.md index 350d2d9e..db68fc96 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,28 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA python examples/jiuge.py --nvidia --model_path=/models/9G7B_MHA/ --backend=cpp --tp=4 --batch_size=16 ``` + + - 推理服务测试 + - 启动推理服务 + ```bash + python python/infinilm/server/inference_server.py [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] --model_path= --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH --tp=NDEV --temperature=TEMP --top_p=TOP_P --top_k=TOP_K --host=HOST --port=PORT + ``` + + - 单卡示例: + ```bash + CUDA_VISIBLE_DEVICES=0 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1 + ``` + + - 多卡分布式示例: + ```bash + CUDA_VISIBLE_DEVICES=0,1,2,3 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=4 --temperature=1.0 --top_p=0.8 --top_k=1 + ``` + + - 测试推理服务性能: + ```bash + python scripts/test_perf.py --verbose + ``` + - 运行推理基准测试(C-Eval/MMLU) ```bash diff --git a/python/infinilm/__init__.py b/python/infinilm/__init__.py index 0fbee2ca..e34514a7 100644 --- a/python/infinilm/__init__.py +++ b/python/infinilm/__init__.py @@ -1,5 +1,25 @@ from .models import AutoLlamaModel from . import distributed from . import cache +from . import llm -__all__ = ["AutoLlamaModel", "distributed", "cache"] +from .llm import ( + LLM, + AsyncLLMEngine, + SamplingParams, + RequestOutput, + TokenOutput, +) + +__all__ = [ + "AutoLlamaModel", + "distributed", + "cache", + "llm", + # LLM classes + "LLM", + "AsyncLLMEngine", + "SamplingParams", + "RequestOutput", + "TokenOutput", +] diff --git a/python/infinilm/llm/__init__.py b/python/infinilm/llm/__init__.py new file mode 100644 index 00000000..6af8a5a3 --- /dev/null +++ b/python/infinilm/llm/__init__.py @@ -0,0 +1,43 @@ +""" +InfiniLM Engine - High-performance llm inference engine with batch generation and streaming support. 
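+
+Example (illustrative sketch; the model path and sampling values below are placeholders):
+
+    from infinilm.llm import LLM, SamplingParams
+
+    llm = LLM(model_path="/path/to/model", device="cuda", tensor_parallel_size=1)
+    params = SamplingParams(max_tokens=100, temperature=1.0, top_p=0.8, top_k=1)
+    outputs = llm.generate(["What is paged attention?"], sampling_params=params)
+    print(outputs[0].outputs[0].text)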
+""" + +from infinilm.llm.sampling_params import SamplingParams +from infinilm.llm.request import ( + RequestStatus, + FinishReason, + RequestOutput, + CompletionOutput, + TokenOutput, + InferenceRequest, +) +from infinilm.llm.llm import ( + LLM, + LLMEngine, + AsyncLLMEngine, + EngineConfig, +) +from infinilm.llm.scheduler import Scheduler, SchedulerOutput +from infinilm.llm.cache_manager import BlockManager, Block + +__all__ = [ + # Main classes + "LLM", + "AsyncLLMEngine", + "LLMEngine", + "EngineConfig", + # Parameters + "SamplingParams", + # Request and Output + "InferenceRequest", + "RequestOutput", + "CompletionOutput", + "TokenOutput", + "RequestStatus", + "FinishReason", + # Internal (for advanced use) + "Scheduler", + "SchedulerOutput", + "BlockManager", + "Block", +] diff --git a/python/infinilm/llm/cache_manager.py b/python/infinilm/llm/cache_manager.py new file mode 100644 index 00000000..1f1e48c8 --- /dev/null +++ b/python/infinilm/llm/cache_manager.py @@ -0,0 +1,268 @@ +""" +KV Cache Manager - Paged Attention block-based cache allocation and management. +""" + +from collections import deque +from typing import List, Dict, Set +import xxhash +import numpy as np + + +class Block: + """KV Cache Block with reference counting and hash-based reuse support.""" + + def __init__(self, block_id: int): + self.block_id = block_id + self.ref_count = 0 + self.hash = -1 + self.token_ids: List[int] = [] + + def update(self, hash_value: int, token_ids: List[int]) -> None: + self.hash = hash_value + self.token_ids = token_ids.copy() + + def reset(self) -> None: + self.ref_count = 1 + self.hash = -1 + self.token_ids = [] + + def free(self) -> None: + self.ref_count = 0 + self.hash = -1 + self.token_ids = [] + + def __repr__(self) -> str: + return f"Block(id={self.block_id}, ref={self.ref_count}, hash={self.hash})" + + +class BlockManager: + """Manages Paged KV Cache allocation with prefix caching support. 
+ + Features: + - Block allocation/deallocation with reference counting + - Hash-based prefix caching for token sequence reuse + - Slot mapping generation for physical-to-logical position mapping + """ + + def __init__(self, num_blocks: int, block_size: int): + assert ( + num_blocks > 0 and block_size > 0 + ), "num_blocks and block_size must be positive" + self.num_blocks = num_blocks + self.block_size = block_size + + self.blocks: List[Block] = [Block(i) for i in range(num_blocks)] + self.hash_to_block_id: Dict[int, int] = {} + self.free_block_ids: deque = deque(range(num_blocks)) + self.used_block_ids: Set[int] = set() + self.req_block_ids: Set[int] = set() + + def reset_req_blocks(self) -> None: + """Move blocks from prefill stage to used blocks and update hash mappings.""" + for block_id in self.req_block_ids: + self.used_block_ids.add(block_id) + block = self.blocks[block_id] + prefix_hash = block.hash + self.hash_to_block_id[prefix_hash] = block_id + self.req_block_ids.clear() + + @classmethod + def compute_hash(cls, token_ids: List[int], prefix_hash: int = -1) -> int: + """Compute hash for token sequence with optional prefix chaining.""" + h = xxhash.xxh64() + if prefix_hash != -1: + h.update(prefix_hash.to_bytes(8, "little")) + h.update(np.array(token_ids, dtype=np.int32).tobytes()) + return h.intdigest() + + def _allocate_partial_block(self, block_id: int) -> Block: + """Allocate an incomplete block and add to used blocks.""" + assert block_id in self.free_block_ids, f"Block {block_id} not in free list" + block = self.blocks[block_id] + assert block.ref_count == 0, f"Block {block_id} ref_count not zero" + + block.reset() + self.free_block_ids.remove(block_id) + self.used_block_ids.add(block_id) + return block + + def _allocate_full_block(self, block_id: int) -> Block: + """Allocate a complete block and add to request blocks.""" + assert block_id in self.free_block_ids, f"Block {block_id} not in free list" + block = self.blocks[block_id] + assert block.ref_count == 0, f"Block {block_id} ref_count not zero" + + block.reset() + self.free_block_ids.remove(block_id) + self.req_block_ids.add(block_id) + return block + + def _deallocate_block(self, block_id: int): + """Deallocate a block and return it to free list.""" + block = self.blocks[block_id] + assert ( + block.ref_count == 0 + ), f"Block {block_id} ref_count not zero, cannot deallocate" + + if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id: + del self.hash_to_block_id[block.hash] + + block.free() + self.used_block_ids.remove(block_id) + self.free_block_ids.append(block_id) + + def can_allocate(self, num_required_blocks: int) -> bool: + return len(self.free_block_ids) >= num_required_blocks + + def allocate_blocks( + self, token_ids: List[int], block_table: List[int] = None + ) -> tuple[List[int], List[int], int]: + """Allocate cache blocks for new request with prefix caching support. 
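+        For example (illustrative), with block_size=16 a 40-token prompt spans two full
+        blocks (tokens 0-15 and 16-31) that are hashed and may be reused from the prefix
+        cache, plus one partial block holding the remaining 8 tokens, which is always
+        freshly allocated.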
+ + Args: + token_ids: Input token sequence + block_table: Existing block_table (for decode phase) + + Returns: + Tuple of (block_table, slot_mapping, num_cached_tokens) + """ + if block_table is None: + block_table = [] + + num_tokens = len(token_ids) + num_blocks = (num_tokens + self.block_size - 1) // self.block_size + slot_mapping = [] + num_cached_tokens = 0 + prefix_hash = -1 + cache_miss = False + + for block_idx in range(num_blocks): + start_idx = block_idx * self.block_size + end_idx = min(start_idx + self.block_size, num_tokens) + block_tokens = token_ids[start_idx:end_idx] + + # Only full blocks can be hashed for reuse + if len(block_tokens) == self.block_size: + prefix_hash = self.compute_hash(block_tokens, prefix_hash) + + # Try to reuse existing block + if not cache_miss: + cached_block_id = self.hash_to_block_id.get(prefix_hash, -1) + if ( + cached_block_id != -1 + and self.blocks[cached_block_id].token_ids == block_tokens + ): + # Check if all tokens are cached + if num_cached_tokens + self.block_size == len(token_ids): + cache_miss = True + else: + # Reuse successful + block = self.blocks[cached_block_id] + block.ref_count += 1 + block_table.append(cached_block_id) + num_cached_tokens += self.block_size + continue + else: + cache_miss = True + else: + prefix_hash = -1 + + # Cannot reuse, allocate new block + if not self.free_block_ids: + raise RuntimeError("No available cache blocks") + + new_block_id = self.free_block_ids[0] + if prefix_hash != -1: + block = self._allocate_full_block(new_block_id) + block.update(prefix_hash, block_tokens) + else: + block = self._allocate_partial_block(new_block_id) + block_table.append(new_block_id) + + # Generate slot_mapping + for i in range(len(block_tokens)): + slot_mapping.append(new_block_id * self.block_size + i) + + return block_table, slot_mapping, num_cached_tokens + + def append_slot( + self, block_table: List[int], num_tokens: int, total_token_ids: List[int] = None + ) -> tuple[List[int], int]: + """Append slot for decode phase (generate one new token). 
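+        For example (illustrative), with block_size=16 and num_tokens=17 the previous
+        block has just been filled, so its hash is recorded for future prefix caching,
+        a new block is appended to the block_table, and the returned slot is offset 0
+        of that new block.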
+ + Args: + block_table: Current block_table + num_tokens: Current total token count (including newly generated token) + total_token_ids: All token sequence (for updating block hash) + + Returns: + Tuple of (block_table, slot_id) + """ + assert len(block_table) > 0, "block_table cannot be empty" + assert num_tokens > 0, "num_tokens must be greater than 0" + + if num_tokens % self.block_size == 1: + # Previous block is full, update its hash for future prefix caching + last_block_id = block_table[-1] + last_block = self.blocks[last_block_id] + + # Only update if block's token_ids is empty (avoid duplicate updates) + if len(last_block.token_ids) == 0: + block_start_idx = num_tokens - self.block_size - 1 + block_end_idx = num_tokens - 1 + block_tokens = total_token_ids[block_start_idx:block_end_idx] + + # Compute prefix_hash using previous block's hash if available + if len(block_table) > 1: + prev_block = self.blocks[block_table[-2]] + prefix_hash = prev_block.hash + else: + prefix_hash = -1 + + current_hash = self.compute_hash(block_tokens, prefix_hash) + last_block.update(current_hash, block_tokens) + self.hash_to_block_id[current_hash] = last_block_id + + # Need new block + if not self.free_block_ids: + if not self.try_free_blocks(1): + raise RuntimeError("No available cache blocks") + new_block_id = self.free_block_ids[0] + self._allocate_partial_block(new_block_id) + block_table.append(new_block_id) + + # Calculate slot + last_block_id = block_table[-1] + offset = (num_tokens - 1) % self.block_size + slot_id = last_block_id * self.block_size + offset + + return block_table, slot_id + + def free_blocks(self, block_table: List[int]): + """Decrease reference count for all blocks. Blocks with ref_count=0 are not + immediately freed to allow reuse.""" + for block_id in reversed(block_table): + block = self.blocks[block_id] + block.ref_count -= 1 + + def try_free_blocks(self, num_required: int) -> bool: + """Try to free blocks with ref_count=0.""" + to_free = [ + bid for bid in self.used_block_ids if self.blocks[bid].ref_count == 0 + ] + + for block_id in to_free: + self._deallocate_block(block_id) + if self.can_allocate(num_required): + return True + + return self.can_allocate(num_required) + + def get_num_free_blocks(self) -> int: + return len(self.free_block_ids) + + def __repr__(self): + return ( + f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, " + f"free={len(self.free_block_ids)}, used={len(self.used_block_ids)})" + ) diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py new file mode 100644 index 00000000..c152d6e4 --- /dev/null +++ b/python/infinilm/llm/llm.py @@ -0,0 +1,646 @@ +""" +LLM Engine - Main interface for LLM inference. 
+ +This module provides: +- LLM class for batch generation (offline use) +- AsyncLLM class for asynchronous streaming (server use) +""" + +import time +import uuid +import logging +import threading +from typing import List, Optional, Union, AsyncIterator +from dataclasses import dataclass + +import infinicore + +from infinilm.llm.request import ( + InferenceRequest, + RequestOutput, + TokenOutput, + FinishReason, +) +from infinilm.llm.sampling_params import SamplingParams +from infinilm.llm.scheduler import Scheduler + +from infinilm.distributed import DistConfig +from infinilm.infer_engine import InferEngine +from infinilm.cache.cache import PagedKVCacheConfig +from infinilm.modeling_utils import load_model_state_dict_by_file +from transformers import AutoTokenizer +from tokenizers import decoders as _dec + +logger = logging.getLogger(__name__) + + +@dataclass +class EngineConfig: + """Configuration for LLM Engine. + + Attributes: + model_path: Path to the model directory. + device: Device type string ('cpu', 'cuda', 'mlu', etc.). + dtype: Data type string ('float16', 'bfloat16', 'float32'). + tensor_parallel_size: Number of devices for tensor parallelism. + max_batch_size: Maximum batch size for inference. + max_tokens: Default maximum tokens to generate. + num_blocks: Number of KV cache blocks. + block_size: Size of each KV cache block. + temperature: Default sampling temperature. + top_p: Default top-p sampling parameter. + top_k: Default top-k sampling parameter. + """ + + model_path: str + device: str = "cuda" + dtype: str = "float16" + tensor_parallel_size: int = 1 + max_batch_size: int = 16 + max_tokens: int = 4096 + num_blocks: int = 8 * 1024 + block_size: int = 16 + temperature: float = 1.0 + top_p: float = 0.8 + top_k: int = 1 + + +class LLMEngine: + """Low-level LLM engine that handles inference execution.""" + + def __init__(self, config: EngineConfig): + self.config = config + + # Initialize device and dtype + self._init_device() + + # Initialize model engine + self.model_engine = InferEngine( + model_path=config.model_path, + device=self.device, + distributed_config=DistConfig(config.tensor_parallel_size), + ) + + # Load model weights + load_model_state_dict_by_file( + self.model_engine, config.model_path, dtype=self.model_engine.config.dtype + ) + + # Initialize tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + config.model_path, trust_remote_code=True + ) + self._fix_tokenizer_decoder() + + # Initialize KV cache + cache_config = PagedKVCacheConfig( + num_blocks=config.num_blocks, block_size=config.block_size + ) + self.model_engine.reset_cache(cache_config) + + # Initialize scheduler + self.scheduler = Scheduler( + max_batch_size=config.max_batch_size, + num_blocks=config.num_blocks, + block_size=config.block_size, + ) + + # Get EOS token IDs from model config + self.eos_token_ids = self.model_engine.config.eos_token_id or [] + if isinstance(self.eos_token_ids, int): + self.eos_token_ids = [self.eos_token_ids] + + logger.info( + f"LLMEngine initialized with model at {config.model_path} " + f"on device {config.device}" + ) + + def _init_device(self): + """Initialize infinicore device and dtype.""" + supported_devices = ["cpu", "cuda", "mlu", "moore"] + device_str = self.config.device + if device_str not in supported_devices: + raise ValueError( + f"Unsupported device: '{device_str}'. 
" + f"Supported devices: {supported_devices}" + ) + self.device = infinicore.device(device_str, 0) + + dtype_map = { + "float32": infinicore.float32, + "float16": infinicore.float16, + "bfloat16": infinicore.bfloat16, + } + + if self.config.dtype not in dtype_map: + raise ValueError( + f"Unsupported dtype: '{self.config.dtype}'. " + f"Supported dtypes: {list(dtype_map.keys())}" + ) + + self.dtype = dtype_map[self.config.dtype] + + def _fix_tokenizer_decoder(self): + """Fix tokenizer decoder for llama models.""" + if "llama" in self.model_engine.config.model_type.lower(): + backend = getattr(self.tokenizer, "backend_tokenizer", None) + target = getattr(backend, "_tokenizer", backend) + norm = getattr(target, "normalizer", None) + dec = getattr(target, "decoder", None) + sn = repr(norm)[:800] if norm is not None else "" + sd = repr(dec)[:800] if dec is not None else "" + has_prepend = "Prepend" in sn + has_strip = "Strip" in sd + if has_prepend and has_strip: + target.decoder = _dec.Sequence( + [ + _dec.Replace("▁", " "), + _dec.ByteFallback(), + _dec.Fuse(), + ] + ) + + def add_request(self, request: InferenceRequest): + """Add a request to the scheduler.""" + self.scheduler.add_request(request) + + def step(self) -> List[InferenceRequest]: + """Run one inference step. + + Returns: + List of requests that were processed in this step. + """ + # Schedule requests + scheduler_output = self.scheduler.schedule() + if scheduler_output is None or not scheduler_output.scheduled_requests: + return [] + + # Build model inputs + model_input_dict = scheduler_output.build_model_inputs( + self.config.temperature, self.config.top_p, self.config.top_k + ) + model_input = self._prepare_model_input(model_input_dict) + + # Run inference + sampled_tokens = self.model_engine.forward(**model_input) + sampled_tokens_list = sampled_tokens.to_numpy().tolist() + + # Update request status + self._update_requests( + scheduler_output.is_prefill, + scheduler_output.scheduled_requests, + sampled_tokens_list, + ) + + return scheduler_output.scheduled_requests + + def _prepare_model_input(self, model_input_dict: dict) -> dict: + """Convert model input dict to infinicore tensors.""" + model_input = {} + for key, value in model_input_dict.items(): + if key == "input_ids": + model_input[key] = infinicore.from_list([value], dtype=infinicore.int64) + elif key in [ + "position_ids", + "past_kv_lengths", + "total_kv_lengths", + "input_offsets", + "slot_mapping", + ]: + model_input[key] = infinicore.from_list(value, dtype=infinicore.int64) + elif key == "block_tables": + model_input[key] = infinicore.from_list(value, dtype=infinicore.int64) + else: + model_input[key] = value + return model_input + + def _update_requests( + self, + is_prefill: bool, + requests: List[InferenceRequest], + sampled_tokens: List[int], + ): + """Update request status after inference step.""" + if is_prefill: + self.scheduler.cache_manager.reset_req_blocks() + + for req, token_id in zip(requests, sampled_tokens): + req.generated_token_ids.append(token_id) + if req.is_prefill: + req.is_prefill = False + + token_text = self.tokenizer.decode(token_id) + req.generated_text += token_text + + if self._check_request_finished(req, token_id): + req.mark_finished(req.finish_reason) + + # Put output in queue if it exists (for async streaming) + if req._output_queue is not None: + output = TokenOutput( + request_id=req.request_id, + token_id=token_id, + token_text=token_text, + finished=req.is_finished(), + finish_reason=req.finish_reason, + 
generated_text=req.generated_text, + ) + req.output_queue.sync_q.put(output) + + self.scheduler.complete_requests(requests) + + def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool: + """Check if request generation is finished.""" + max_tokens = req.sampling_params.max_tokens + if max_tokens and req.get_num_generated_tokens() >= max_tokens: + req.finish_reason = FinishReason.LENGTH + return True + + # Check EOS token + eos_ids = req.eos_token_ids or self.eos_token_ids + if eos_ids and token_id in eos_ids: + req.finish_reason = FinishReason.EOS_TOKEN + return True + + # Check stop strings + stop_strings = req.sampling_params.stop or [] + for stop_str in stop_strings: + if req.generated_text.endswith(stop_str): + req.finish_reason = FinishReason.STOP_STRING + return True + + return False + + def tokenize(self, text: str) -> List[int]: + """Tokenize text to token IDs.""" + return self.tokenizer.encode(text) + + def detokenize(self, token_ids: List[int]) -> str: + """Detokenize token IDs to text.""" + return self.tokenizer.decode(token_ids) + + def apply_chat_template( + self, + messages: List[dict], + add_generation_prompt: bool = True, + ) -> str: + """Apply chat template to messages.""" + return self.tokenizer.apply_chat_template( + conversation=messages, + add_generation_prompt=add_generation_prompt, + tokenize=False, + ) + + +class LLM: + """High-level LLM interface for batch generation.""" + + def __init__( + self, + model_path: str, + device: str = "cuda", + dtype: str = "float16", + tensor_parallel_size: int = 1, + max_batch_size: int = 16, + max_tokens: int = 4096, + num_blocks: int = 8 * 1024, + block_size: int = 16, + temperature: float = 1.0, + top_p: float = 0.8, + top_k: int = 1, + ): + """Initialize LLM. + + Args: + model_path: Path to the model directory. + device: Device type ('cpu', 'cuda', 'mlu', 'moore'). + dtype: Data type ('float16', 'bfloat16', 'float32'). + tensor_parallel_size: Number of devices for tensor parallelism. + max_batch_size: Maximum batch size for inference. + max_tokens: Default maximum tokens to generate. + num_blocks: Number of KV cache blocks. + block_size: Size of each KV cache block. + temperature: Default sampling temperature. + top_p: Default top-p sampling parameter. + top_k: Default top-k sampling parameter. + """ + config = EngineConfig( + model_path=model_path, + device=device, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + max_batch_size=max_batch_size, + max_tokens=max_tokens, + num_blocks=num_blocks, + block_size=block_size, + temperature=temperature, + top_p=top_p, + top_k=top_k, + ) + self.engine = LLMEngine(config) + self.config = config + + def generate( + self, + prompts: Union[str, List[str]], + sampling_params: Optional[SamplingParams] = None, + use_tqdm: bool = True, + ) -> List[RequestOutput]: + """Generate completions for the given prompts. + + Args: + prompts: A single prompt string or list of prompt strings. + sampling_params: Sampling parameters for generation. + use_tqdm: Whether to show progress bar. + + Returns: + List of RequestOutput objects containing generated text. 
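+
+        Example (illustrative sketch; prompts and parameters are placeholders):
+
+            params = SamplingParams(max_tokens=64)
+            outputs = llm.generate(["Hello", "What is paged attention?"], sampling_params=params)
+            for out in outputs:
+                print(out.outputs[0].text)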
+ """ + if isinstance(prompts, str): + prompts = [prompts] + + if sampling_params is None: + sampling_params = SamplingParams(max_tokens=self.config.max_tokens) + elif sampling_params.max_tokens is None: + sampling_params = sampling_params.clone() + sampling_params.max_tokens = self.config.max_tokens + + requests = [] + for prompt in prompts: + request_id = f"cmpl-{uuid.uuid4().hex}" + token_ids = self.engine.tokenize(prompt) + req = InferenceRequest( + request_id=request_id, + prompt=prompt, + prompt_token_ids=token_ids, + sampling_params=sampling_params, + eos_token_ids=self.engine.eos_token_ids, + ) + requests.append(req) + self.engine.add_request(req) + + # Run inference until all requests are finished + if use_tqdm: + try: + from tqdm import tqdm + + pbar = tqdm(total=len(requests), desc="Generating") + except ImportError: + pbar = None + use_tqdm = False + else: + pbar = None + + finished_count = 0 + while finished_count < len(requests): + self.engine.step() + + new_finished = sum(1 for req in requests if req.is_finished()) + if use_tqdm and pbar and new_finished > finished_count: + pbar.update(new_finished - finished_count) + finished_count = new_finished + + if pbar: + pbar.close() + + outputs = [req.to_request_output() for req in requests] + return outputs + + def chat( + self, + messages: Union[List[dict], List[List[dict]]], + sampling_params: Optional[SamplingParams] = None, + use_tqdm: bool = True, + ) -> List[RequestOutput]: + """Generate chat completions for the given messages. + + Args: + messages: A single conversation (list of message dicts) or + a list of conversations. + sampling_params: Sampling parameters for generation. + use_tqdm: Whether to show progress bar. + + Returns: + List of RequestOutput objects containing generated responses. + """ + if messages and isinstance(messages[0], dict): + messages = [messages] + + prompts = [] + for conversation in messages: + prompt = self.engine.apply_chat_template( + conversation, add_generation_prompt=True + ) + prompts.append(prompt) + + return self.generate(prompts, sampling_params, use_tqdm) + + +class AsyncLLMEngine: + """Asynchronous LLM engine for server use with streaming support.""" + + def __init__( + self, + model_path: str, + device: str = "cuda", + dtype: str = "float16", + tensor_parallel_size: int = 1, + max_batch_size: int = 16, + max_tokens: int = 512, + num_blocks: int = 8 * 1024, + block_size: int = 16, + temperature: float = 1.0, + top_p: float = 0.8, + top_k: int = 1, + ): + """Initialize AsyncLLMEngine. + + Args: + model_path: Path to the model directory. + device: Device type ('cpu', 'cuda', 'mlu', 'moore'). + dtype: Data type ('float16', 'bfloat16', 'float32'). + tensor_parallel_size: Number of devices for tensor parallelism. + max_batch_size: Maximum batch size for inference. + max_tokens: Default maximum tokens to generate. + num_blocks: Number of KV cache blocks. + block_size: Size of each KV cache block. + temperature: Default sampling temperature. + top_p: Default top-p sampling parameter. + top_k: Default top-k sampling parameter. 
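+
+        Example (illustrative sketch; the model path is a placeholder):
+
+            engine = AsyncLLMEngine(model_path="/path/to/model", device="cuda")
+            engine.start()
+            req = engine.add_request(prompt="Hello", sampling_params=SamplingParams(max_tokens=32))
+            # Generated tokens are then consumed via `engine.stream_request(req)` in an event loop.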
+ """ + config = EngineConfig( + model_path=model_path, + device=device, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + max_batch_size=max_batch_size, + max_tokens=max_tokens, + num_blocks=num_blocks, + block_size=block_size, + temperature=temperature, + top_p=top_p, + top_k=top_k, + ) + self.engine = LLMEngine(config) + self.config = config + + self._running = False + self._step_thread: Optional[threading.Thread] = None + + def start(self): + """Start the background inference loop.""" + if self._running: + logger.warning("AsyncLLMEngine is already running") + return + + self._running = True + self._step_thread = threading.Thread( + target=self._step_loop, daemon=True, name="AsyncLLMEngineStepThread" + ) + self._step_thread.start() + logger.info("AsyncLLMEngine started") + + def stop(self): + """Stop the background inference loop.""" + if not self._running: + logger.warning("AsyncLLMEngine is not running") + return + + self._running = False + if self._step_thread: + self._step_thread.join(timeout=5) + logger.info("AsyncLLMEngine stopped") + + def _step_loop(self): + """Background loop that runs inference steps.""" + while self._running: + try: + requests = self.engine.step() + if not requests: + time.sleep(0.01) + except Exception as e: + logger.error(f"Error in step loop: {e}", exc_info=True) + self._running = False + break + + def add_request( + self, + prompt: Optional[str] = None, + prompt_token_ids: Optional[List[int]] = None, + sampling_params: Optional[SamplingParams] = None, + request_id: Optional[str] = None, + # For server use + request_data: Optional[dict] = None, + http_request: Optional[any] = None, + ) -> InferenceRequest: + """Add a request to the engine. + + Args: + prompt: Text prompt for generation. + prompt_token_ids: Pre-tokenized prompt. + sampling_params: Sampling parameters. + request_id: Optional request ID. + request_data: Optional request data dict (for server use). + http_request: Optional HTTP request object (for server use). + + Returns: + The created InferenceRequest object. + """ + if request_id is None: + request_id = f"cmpl-{uuid.uuid4().hex}" + + if prompt_token_ids is None and prompt is not None: + prompt_token_ids = self.engine.tokenize(prompt) + + if sampling_params is None: + sampling_params = SamplingParams(max_tokens=self.config.max_tokens) + elif sampling_params.max_tokens is None: + sampling_params = sampling_params.clone() + sampling_params.max_tokens = self.config.max_tokens + + request = InferenceRequest( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + eos_token_ids=self.engine.eos_token_ids, + request_data=request_data, + http_request=http_request, + ) + + # Initialize output queue for streaming + _ = request.output_queue + + self.engine.add_request(request) + return request + + def add_chat_request( + self, + messages: List[dict], + sampling_params: Optional[SamplingParams] = None, + request_id: Optional[str] = None, + request_data: Optional[dict] = None, + http_request: Optional[any] = None, + ) -> InferenceRequest: + """Add a chat request to the engine. + + Args: + messages: List of message dicts (chat conversation). + sampling_params: Sampling parameters. + request_id: Optional request ID. + request_data: Optional request data dict. + http_request: Optional HTTP request object. + + Returns: + The created InferenceRequest object. 
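+
+        Example (illustrative sketch):
+
+            req = engine.add_chat_request(
+                messages=[{"role": "user", "content": "Hello"}],
+                sampling_params=SamplingParams(max_tokens=32),
+            )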
+ """ + prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True) + return self.add_request( + prompt=prompt, + sampling_params=sampling_params, + request_id=request_id, + request_data=request_data, + http_request=http_request, + ) + + async def stream_request( + self, + request: InferenceRequest, + timeout: float = 100.0, + ) -> AsyncIterator[TokenOutput]: + """Stream tokens from a request. + + Args: + request: The inference request to stream from. + timeout: Timeout for waiting on each token. + + Yields: + TokenOutput objects for each generated token. + """ + import asyncio + + while True: + if request.is_finished() and request.output_queue.async_q.empty(): + break + + try: + token_output = await asyncio.wait_for( + request.output_queue.async_q.get(), timeout=timeout + ) + + request.output_queue.async_q.task_done() + + yield token_output + + if token_output.finished: + break + except asyncio.TimeoutError: + if request.is_finished(): + break + continue + except asyncio.CancelledError: + request.mark_canceled() + break + except Exception as e: + logger.error(f"Error streaming request {request.request_id}: {e}") + await asyncio.sleep(0.01) diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py new file mode 100644 index 00000000..d6e08aef --- /dev/null +++ b/python/infinilm/llm/request.py @@ -0,0 +1,231 @@ +""" +Request and Output - Data structures for inference requests and outputs. +""" + +from enum import Enum +from dataclasses import dataclass, field +from typing import List, Optional, Any +import time +import janus + +from infinilm.llm.sampling_params import SamplingParams + + +class RequestStatus(Enum): + """Status of an inference request.""" + + WAITING = "waiting" + RUNNING = "running" + FINISHED = "finished" + CANCELED = "canceled" + FAILED = "failed" + TIMEOUT = "timeout" + + +class FinishReason(Enum): + """Reason for finishing generation.""" + + STOP = "stop" + LENGTH = "length" + EOS_TOKEN = "eos_token" + STOP_STRING = "stop_string" + TIMEOUT = "timeout" + CANCELED = "canceled" + ERROR = "error" + + +@dataclass +class RequestOutput: + """Output from a single generation request. + + Attributes: + request_id: Unique identifier for the request. + prompt: Original prompt text. + prompt_token_ids: Token IDs of the prompt. + outputs: List of generated outputs (for beam search, multiple outputs possible). + finished: Whether generation is complete. + finish_reason: Reason for finishing. + """ + + request_id: str + prompt: Optional[str] = None + prompt_token_ids: Optional[List[int]] = None + outputs: List["CompletionOutput"] = field(default_factory=list) + finished: bool = False + finish_reason: Optional[FinishReason] = None + + +@dataclass +class CompletionOutput: + """Single completion output. + + Attributes: + index: Index of this output (for beam search). + text: Generated text. + token_ids: Generated token IDs. + finish_reason: Reason for finishing. + """ + + index: int = 0 + text: str = "" + token_ids: List[int] = field(default_factory=list) + finish_reason: Optional[FinishReason] = None + + +@dataclass +class TokenOutput: + """Output for a single generated token. + + Attributes: + request_id: Unique identifier for the request. + token_id: Generated token ID. + token_text: Decoded text of the token. + finished: Whether generation is complete. + finish_reason: Reason for finishing. + generated_text: Full generated text so far. 
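+
+    Example (illustrative; assumes an AsyncLLMEngine named `engine` and a pending request `req`):
+
+        async for tok in engine.stream_request(req):
+            print(tok.token_text, end="", flush=True)
+            if tok.finished:
+                break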
+ """ + + request_id: str + token_id: int + token_text: str + finished: bool = False + finish_reason: Optional[FinishReason] = None + generated_text: str = "" + + +class InferenceRequest: + """Internal inference request object for managing generation state and resources.""" + + def __init__( + self, + request_id: str, + prompt: Optional[str] = None, + prompt_token_ids: Optional[List[int]] = None, + sampling_params: Optional[SamplingParams] = None, + eos_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + # For server use + request_data: Optional[dict] = None, + http_request: Optional[Any] = None, + ): + # Request metadata + self.request_id: str = request_id + self.prompt: Optional[str] = prompt + self.prompt_token_ids: List[int] = prompt_token_ids or [] + self.prompt_length: int = len(self.prompt_token_ids) + self.arrival_time: float = arrival_time or time.time() + self.finished_time: Optional[float] = None + + # Sampling parameters + self.sampling_params: SamplingParams = sampling_params or SamplingParams() + + # EOS token IDs (from model config) + self.eos_token_ids: List[int] = eos_token_ids or [] + + # Generation state + self.generated_token_ids: List[int] = [] + self.generated_text: str = "" + self.is_prefill: bool = True + self.status: RequestStatus = RequestStatus.WAITING + self.finish_reason: Optional[FinishReason] = None + self.priority: int = 0 + + # KV cache management + self.cache_id: Optional[int] = None + self.block_table: List[int] = [] + self.slot_mapping: List[int] = [] + self.num_cached_tokens: int = 0 + self.num_blocks: int = 0 + + # For server use + self.request_data: Optional[dict] = request_data + self.http_request: Optional[Any] = http_request + + # Output management (for async streaming) + self._output_queue: Optional[janus.Queue] = None + + @property + def output_queue(self) -> janus.Queue: + """Lazy initialization of output queue.""" + if self._output_queue is None: + self._output_queue = janus.Queue() + return self._output_queue + + def get_prompt_length(self) -> int: + return self.prompt_length + + def get_input_tokens(self) -> List[int]: + return self.prompt_token_ids + + def get_num_generated_tokens(self) -> int: + return len(self.generated_token_ids) + + def get_total_length(self) -> int: + return self.prompt_length + len(self.generated_token_ids) + + def get_all_token_ids(self) -> List[int]: + return self.prompt_token_ids + self.generated_token_ids + + def get_num_blocks_required(self, block_size: int) -> int: + total_tokens = self.get_total_length() + return (total_tokens + block_size - 1) // block_size + + def get_max_tokens(self) -> Optional[int]: + return self.sampling_params.max_tokens + + def is_finished(self) -> bool: + return self.status in [ + RequestStatus.FINISHED, + RequestStatus.CANCELED, + RequestStatus.FAILED, + RequestStatus.TIMEOUT, + ] + + def mark_finished(self, reason: FinishReason): + """Mark the request as finished with the given reason.""" + self.status = RequestStatus.FINISHED + self.finish_reason = reason + self.finished_time = time.time() + + def mark_failed(self, reason: FinishReason = FinishReason.ERROR): + """Mark the request as failed.""" + self.status = RequestStatus.FAILED + self.finish_reason = reason + self.finished_time = time.time() + + def mark_canceled(self): + """Mark the request as canceled.""" + self.status = RequestStatus.CANCELED + self.finish_reason = FinishReason.CANCELED + self.finished_time = time.time() + + def mark_timeout(self): + """Mark the request as timed out.""" + self.status 
= RequestStatus.TIMEOUT + self.finish_reason = FinishReason.TIMEOUT + self.finished_time = time.time() + + async def close(self): + """Close the output queue and clean up resources.""" + if self._output_queue is not None: + await self._output_queue.async_q.join() + self._output_queue.close() + await self._output_queue.wait_closed() + + def to_request_output(self) -> RequestOutput: + """Convert to RequestOutput for external use.""" + return RequestOutput( + request_id=self.request_id, + prompt=self.prompt, + prompt_token_ids=self.prompt_token_ids, + outputs=[ + CompletionOutput( + index=0, + text=self.generated_text, + token_ids=self.generated_token_ids.copy(), + finish_reason=self.finish_reason, + ) + ], + finished=self.is_finished(), + finish_reason=self.finish_reason, + ) diff --git a/python/infinilm/llm/sampling_params.py b/python/infinilm/llm/sampling_params.py new file mode 100644 index 00000000..6a0aed58 --- /dev/null +++ b/python/infinilm/llm/sampling_params.py @@ -0,0 +1,35 @@ +""" +Sampling Parameters - Configuration for text generation sampling. +""" + +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class SamplingParams: + """Sampling parameters for text generation.""" + + temperature: float = 1.0 + top_p: float = 0.8 + top_k: int = 1 + max_tokens: Optional[int] = None + stop: Optional[List[str]] = None + stop_token_ids: Optional[List[int]] = None + + def __post_init__(self): + if self.stop is None: + self.stop = [] + if self.stop_token_ids is None: + self.stop_token_ids = [] + + def clone(self) -> "SamplingParams": + """Create a copy of this SamplingParams instance.""" + return SamplingParams( + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + max_tokens=self.max_tokens, + stop=self.stop.copy() if self.stop else None, + stop_token_ids=self.stop_token_ids.copy() if self.stop_token_ids else None, + ) diff --git a/python/infinilm/llm/scheduler.py b/python/infinilm/llm/scheduler.py new file mode 100644 index 00000000..b1853292 --- /dev/null +++ b/python/infinilm/llm/scheduler.py @@ -0,0 +1,248 @@ +""" +Scheduler - Request scheduling and batch management with Paged Attention KV Cache. +""" + +import queue +import janus +import logging +from typing import List, Optional +from infinilm.llm.request import RequestStatus, InferenceRequest +from infinilm.llm.cache_manager import BlockManager + +logger = logging.getLogger(__name__) + + +class SchedulerOutput: + """Scheduler output containing scheduled requests and execution phase info.""" + + def __init__( + self, + scheduled_requests: List[InferenceRequest], + is_prefill: bool = False, + ): + self.scheduled_requests = scheduled_requests + self.num_requests = len(scheduled_requests) + self.is_prefill = is_prefill + + def build_model_inputs( + self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1 + ): + """Construct model inputs for prefill or decode phase. 
+ + Prefill phase: + - input_ids: Flattened token list (excluding cached tokens) + - position_ids: Position IDs for new tokens in complete sequence + - past_kv_lengths: Number of cached tokens per request + - total_kv_lengths: Total tokens (cached + new) per request + - input_offsets: Start position of each request in flattened array + - block_tables: Padded block_table for each request + - slot_mapping: Token to slot mappings + + Decode phase: + - input_ids: Only last generated token per request + - position_ids: Position of last token in complete sequence + - past_kv_lengths: Number of cached tokens per request + - total_kv_lengths: Total sequence length per request + - input_offsets: Offsets for each request + - block_tables: Padded block_table for each request + - slot_mapping: Single slot per request + """ + if not self.scheduled_requests: + raise RuntimeError( + "build_model_inputs called with empty scheduled_requests" + ) + + tokens = [] + seq_lens = [] + seq_offsets = [0] + block_tables = [] + slot_mapping = [] + cached_lens = [] + position_ids = [] + + max_block_table_len = max( + len(req.block_table) for req in self.scheduled_requests + ) + current_offset = 0 + + for req in self.scheduled_requests: + num_cached = req.num_cached_tokens + if self.is_prefill: + # Prefill phase + req_tokens = req.get_input_tokens() + tokens_to_compute = req_tokens[num_cached:] + tokens.extend(tokens_to_compute) + + seq_len = len(tokens_to_compute) + seq_lens.append(len(req_tokens)) + + current_offset += seq_len + seq_offsets.append(current_offset) + + slot_mapping.extend(req.slot_mapping) + cached_lens.append(num_cached) + position_ids.extend(range(num_cached, num_cached + seq_len)) + + else: + # Decode phase + last_token = req.generated_token_ids[-1] + tokens.append(last_token) + seq_lens.append(req.get_total_length()) + + current_offset += 1 + seq_offsets.append(current_offset) + + slot_mapping.extend(req.slot_mapping) + cached_lens.append(num_cached) + position_ids.append(req.get_total_length() - 1) + + # Pad block_table to same length + padded_block_table = req.block_table + [-1] * ( + max_block_table_len - len(req.block_table) + ) + block_tables.append(padded_block_table) + + return { + "input_ids": tokens, + "position_ids": position_ids, + "past_kv_lengths": cached_lens, + "total_kv_lengths": seq_lens, + "input_offsets": seq_offsets, + "block_tables": block_tables, + "slot_mapping": slot_mapping, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + } + + +class Scheduler: + """Request scheduler with integrated BlockManager for KV cache management. + + Scheduling logic: + 1. Running queue: Check for new blocks needed, update slot_mapping + 2. Waiting queue: Try block reuse (prefix caching), allocate new blocks + 3. 
Reference counting: Free blocks when requests complete + """ + + def __init__( + self, + max_batch_size: int = 16, + num_blocks: int = 8 * 1024, + block_size: int = 16, + ): + self.waiting_queue = janus.Queue() + self.running_queue = janus.Queue() + self.max_batch_size = max_batch_size + + self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size) + self.block_size = block_size + + def add_request(self, request: InferenceRequest): + if request is not None: + request.status = RequestStatus.WAITING + self.waiting_queue.sync_q.put(request) + + def schedule(self) -> Optional[SchedulerOutput]: + """Schedule and return batch of requests to execute.""" + scheduled_requests = [] + is_prefill = False + + # Process Waiting queue (prefill phase) + while len(scheduled_requests) < self.max_batch_size: + try: + req = self.waiting_queue.sync_q.get_nowait() + except queue.Empty: + break + + req_tokens = req.get_input_tokens() + num_required_blocks = req.get_num_blocks_required(self.block_size) + + if not self.cache_manager.can_allocate(num_required_blocks): + if not self.cache_manager.try_free_blocks(num_required_blocks): + raise RuntimeError("No available cache blocks") + + # Allocate blocks with automatic prefix caching support + req.block_table, req.slot_mapping, req.num_cached_tokens = ( + self.cache_manager.allocate_blocks(req_tokens, req.block_table) + ) + + req.num_blocks = len(req.block_table) + req.status = RequestStatus.RUNNING + scheduled_requests.append(req) + + # Return prefill batch if any waiting requests were scheduled + if scheduled_requests: + is_prefill = True + return SchedulerOutput( + scheduled_requests=scheduled_requests, + is_prefill=is_prefill, + ) + + # Process Running queue (decode phase) + while len(scheduled_requests) < self.max_batch_size: + try: + req = self.running_queue.sync_q.get_nowait() + except queue.Empty: + break + + # Decode phase: allocate slot for newly generated token + try: + req.block_table, new_slot = self.cache_manager.append_slot( + req.block_table, req.get_total_length(), req.get_all_token_ids() + ) + req.slot_mapping = [new_slot] + req.num_blocks = len(req.block_table) + req.num_cached_tokens = req.get_total_length() - 1 + scheduled_requests.append(req) + + except RuntimeError as e: + raise RuntimeError("No available cache blocks") from e + + # Return decode batch if any running requests were scheduled + if scheduled_requests: + is_prefill = False + return SchedulerOutput( + scheduled_requests=scheduled_requests, + is_prefill=is_prefill, + ) + + return None + + def complete_requests(self, requests: List[InferenceRequest]): + """Handle completed requests and free their blocks.""" + for req in requests: + if req.status in [ + RequestStatus.FINISHED, + RequestStatus.CANCELED, + RequestStatus.FAILED, + RequestStatus.TIMEOUT, + ]: + if req.block_table: + self.cache_manager.free_blocks(req.block_table) + + if req.status == RequestStatus.CANCELED: + logger.info( + f"Request {req.request_id[:8]}... canceled: {req.finish_reason}" + ) + elif req.status == RequestStatus.FAILED: + logger.error( + f"Request {req.request_id[:8]}... failed: {req.finish_reason}" + ) + elif req.status == RequestStatus.TIMEOUT: + logger.error( + f"Request {req.request_id[:8]}... 
timed out: {req.finish_reason}" + ) + else: + # Still running, put back in running queue + self.running_queue.sync_q.put(req) + + def get_cache_stats(self) -> dict: + """Get cache statistics.""" + return { + "num_blocks": self.cache_manager.num_blocks, + "block_size": self.cache_manager.block_size, + "num_free_blocks": self.cache_manager.get_num_free_blocks(), + "num_req_blocks": len(self.cache_manager.req_block_ids), + "num_used_blocks": len(self.cache_manager.used_block_ids), + } diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py new file mode 100644 index 00000000..99e1988d --- /dev/null +++ b/python/infinilm/server/inference_server.py @@ -0,0 +1,467 @@ +""" +Inference Server - HTTP API server for LLM inference. +""" + +from contextlib import asynccontextmanager +import sys +import time +import json +import uuid +import argparse +import uvicorn +import logging + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from infinilm.llm import AsyncLLMEngine, SamplingParams, FinishReason + +logger = logging.getLogger(__name__) + +DEFAULT_STREAM_TIMEOUT = 100.0 +DEFAULT_REQUEST_TIMEOUT = 1000.0 + + +def chunk_json(id_, content=None, role=None, finish_reason=None): + """Generate JSON chunk for streaming response.""" + delta = {} + if content: + delta["content"] = content + if role: + delta["role"] = role + return { + "id": id_, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "jiuge", + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "text": content, + "delta": delta, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + } + + +class InferenceServer: + """HTTP server for LLM inference.""" + + def __init__( + self, + model_path: str, + device: str = "cuda", + dtype: str = "float16", + tensor_parallel_size: int = 1, + max_tokens: int = 4096, + max_batch_size: int = 16, + num_blocks: int = 8 * 1024, + block_size: int = 16, + temperature: float = 1.0, + top_p: float = 0.8, + top_k: int = 1, + host: str = "0.0.0.0", + port: int = 8000, + ): + """Initialize inference server. + + Args: + model_path: Path to the model directory. + device: Device type ('cpu', 'cuda', 'mlu', 'moore'). + dtype: Data type ('float16', 'bfloat16', 'float32'). + tensor_parallel_size: Number of devices for tensor parallelism. + max_tokens: Default maximum tokens to generate. + max_batch_size: Maximum batch size for inference. + num_blocks: Number of KV cache blocks. + block_size: Size of each KV cache block. + temperature: Default sampling temperature. + top_p: Default top-p sampling parameter. + top_k: Default top-k sampling parameter. + host: Server host address. + port: Server port number. 
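+
+        Example request (illustrative sketch; assumes the default port and that the
+        `requests` package is installed):
+
+            import requests
+
+            resp = requests.post(
+                "http://localhost:8000/chat/completions",
+                json={"messages": [{"role": "user", "content": "Hello"}], "stream": False},
+            )
+            print(resp.json()["choices"][0]["delta"]["content"])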
+ """ + self.model_path = model_path + self.device = device + self.dtype = dtype + self.tensor_parallel_size = tensor_parallel_size + self.max_tokens = max_tokens + self.max_batch_size = max_batch_size + self.num_blocks = num_blocks + self.block_size = block_size + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.host = host + self.port = port + + self.engine: AsyncLLMEngine = None + + def start(self): + """Start the HTTP server.""" + app = self._create_app() + logger.info(f"Starting API Server at {self.host}:{self.port}...") + uvicorn.run(app, host=self.host, port=self.port) + logger.info("Inference Server stopped") + + def _create_app(self): + """Create FastAPI application.""" + + @asynccontextmanager + async def lifespan(app: FastAPI): + self.engine = AsyncLLMEngine( + model_path=self.model_path, + device=self.device, + dtype=self.dtype, + tensor_parallel_size=self.tensor_parallel_size, + max_batch_size=self.max_batch_size, + max_tokens=self.max_tokens, + num_blocks=self.num_blocks, + block_size=self.block_size, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + ) + self.engine.start() + logger.info(f"Engine initialized with model at {self.model_path}") + yield + self.engine.stop() + + app = FastAPI(lifespan=lifespan) + self._register_routes(app) + return app + + def _register_routes(self, app: FastAPI): + """Register API routes.""" + + @app.post("/chat/completions") + async def chat_completions(request: Request): + try: + data = await request.json() + logger.debug(f"Received request data: {data}") + except Exception as e: + logger.error(f"Failed to parse request JSON: {e}") + return JSONResponse(content={"error": "Invalid JSON"}, status_code=400) + + if not data.get("messages"): + if not data.get("prompt"): + return JSONResponse( + content={"error": "No message provided"}, status_code=400 + ) + else: + data["messages"] = [{"role": "user", "content": data.get("prompt")}] + + stream = data.get("stream", False) + request_id = f"cmpl-{uuid.uuid4().hex}" + + if stream: + return StreamingResponse( + self._stream_chat(request_id, data, request), + media_type="text/event-stream", + ) + else: + response = await self._chat(request_id, data, request) + if isinstance(response, JSONResponse): + return response + return JSONResponse(content=response) + + @app.get("/health") + async def health(): + return {"status": "healthy"} + + @app.get("/v1/models") + async def list_models(): + return { + "object": "list", + "data": [ + { + "id": "jiuge", + "object": "model", + "created": int(time.time()), + "owned_by": "infinilm", + } + ], + } + + def _build_sampling_params(self, data: dict) -> SamplingParams: + """Build SamplingParams from request data.""" + return SamplingParams( + temperature=data.get("temperature", self.temperature), + top_p=data.get("top_p", self.top_p), + top_k=data.get("top_k", self.top_k), + max_tokens=data.get("max_tokens", self.max_tokens), + stop=data.get("stop"), + ) + + async def _stream_chat(self, request_id: str, data: dict, http_request: Request): + """Handle streaming chat request.""" + req = None + start_time = time.time() + + try: + messages = data.get("messages", []) + sampling_params = self._build_sampling_params(data) + + req = self.engine.add_chat_request( + messages=messages, + sampling_params=sampling_params, + request_id=request_id, + request_data=data, + http_request=http_request, + ) + + async for token_output in self.engine.stream_request( + req, timeout=DEFAULT_STREAM_TIMEOUT + ): + # Check timeout + if 
time.time() - start_time > DEFAULT_REQUEST_TIMEOUT: + logger.warning( + f"Request {request_id} timed out after {DEFAULT_REQUEST_TIMEOUT}s" + ) + req.mark_timeout() + error_chunk = json.dumps( + chunk_json( + request_id, + content="[Request timeout]", + finish_reason="timeout", + ), + ensure_ascii=False, + ) + yield f"data: {error_chunk}\n\n" + break + + # Check client disconnect + if await http_request.is_disconnected(): + logger.info(f"Client disconnected for request {request_id}") + req.mark_canceled() + break + + # Send token + chunk = json.dumps( + chunk_json(request_id, content=token_output.token_text), + ensure_ascii=False, + ) + yield f"data: {chunk}\n\n" + + if token_output.finished: + finish_reason = self._convert_finish_reason( + token_output.finish_reason + ) + chunk = json.dumps( + chunk_json(request_id, finish_reason=finish_reason), + ensure_ascii=False, + ) + yield f"data: {chunk}\n\n" + break + + except Exception as e: + logger.error(f"Stream error for {request_id}: {e}", exc_info=True) + if req: + req.mark_failed() + error_chunk = json.dumps( + chunk_json( + request_id, content=f"[Error: {str(e)}]", finish_reason="error" + ), + ensure_ascii=False, + ) + yield f"data: {error_chunk}\n\n" + + finally: + if req and not req.is_finished(): + req.mark_canceled() + if req: + await req.close() + yield "data: [DONE]\n\n" + + async def _chat(self, request_id: str, data: dict, http_request: Request): + """Handle non-streaming chat request.""" + req = None + start_time = time.time() + + try: + messages = data.get("messages", []) + sampling_params = self._build_sampling_params(data) + + req = self.engine.add_chat_request( + messages=messages, + sampling_params=sampling_params, + request_id=request_id, + request_data=data, + http_request=http_request, + ) + + # Collect all generated tokens + output_text = "" + async for token_output in self.engine.stream_request( + req, timeout=DEFAULT_STREAM_TIMEOUT + ): + # Check timeout + if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT: + logger.warning(f"Request {request_id} timed out") + req.mark_timeout() + break + + # Check client disconnect + if await http_request.is_disconnected(): + logger.info(f"Client disconnected for request {request_id}") + req.mark_canceled() + break + + output_text += token_output.token_text + + if token_output.finished: + break + + output_text = output_text.strip() + finish_reason = self._convert_finish_reason(req.finish_reason) + + response = chunk_json( + request_id, + content=output_text, + role="assistant", + finish_reason=finish_reason or "stop", + ) + return response + + except Exception as e: + logger.error(f"Chat error for {request_id}: {e}", exc_info=True) + if req: + req.mark_failed() + return JSONResponse(content={"error": str(e)}, status_code=500) + + finally: + if req and not req.is_finished(): + req.mark_canceled() + if req: + await req.close() + + def _convert_finish_reason(self, reason: FinishReason) -> str: + """Convert FinishReason enum to string.""" + if reason is None: + return None + if reason in (FinishReason.EOS_TOKEN, FinishReason.STOP_STRING): + return "stop" + + return reason.value + + +def setup_logging(log_level: str = "INFO"): + """Configure logging system with proper formatting and handlers.""" + log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + date_format = "%Y-%m-%d %H:%M:%S" + + logging.basicConfig( + level=getattr(logging, log_level.upper(), logging.INFO), + format=log_format, + datefmt=date_format, + handlers=[ + logging.StreamHandler(sys.stdout), + ], + 
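+        # force=True (Python 3.8+) removes handlers already attached to the root logger,
+        # so this configuration takes effect even if logging was configured earlier.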
force=True, + ) + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="InfiniLM Inference Server") + parser.add_argument( + "--model_path", type=str, required=True, help="Path to model directory" + ) + parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism degree") + parser.add_argument( + "--max_tokens", + type=int, + default=512, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--max_batch_size", type=int, default=8, help="Maximum batch size" + ) + parser.add_argument( + "--num_blocks", type=int, default=8 * 1024, help="Number of blocks for KV cache" + ) + parser.add_argument( + "--block_size", type=int, default=16, help="Block size for KV cache" + ) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=["float32", "float16", "bfloat16"], + help="Data type", + ) + parser.add_argument( + "--temperature", type=float, default=1.0, help="Sampling temperature" + ) + parser.add_argument( + "--top_p", type=float, default=0.8, help="Top-p sampling parameter" + ) + parser.add_argument("--top_k", type=int, default=1, help="Top-k sampling parameter") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host") + parser.add_argument("--port", type=int, default=8000, help="Server port") + parser.add_argument("--cpu", action="store_true", help="Use CPU") + parser.add_argument("--nvidia", action="store_true", help="Use NVIDIA GPU") + parser.add_argument("--metax", action="store_true", help="Use MetaX device") + parser.add_argument("--moore", action="store_true", help="Use Moore device") + parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device") + parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device") + parser.add_argument( + "--log_level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + + setup_logging(args.log_level) + + if args.cpu: + device = "cpu" + elif args.nvidia: + device = "cuda" + elif args.metax: + device = "cuda" + elif args.moore: + device = "moore" + elif args.iluvatar: + device = "cuda" + elif args.cambricon: + device = "mlu" + else: + print( + "Usage: python infinilm.server.inference_server [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] " + "--model_path= --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH_SIZE" + "\n" + "Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ " + "--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1" + ) + sys.exit(1) + + server = InferenceServer( + model_path=args.model_path, + device=device, + dtype=args.dtype, + tensor_parallel_size=args.tp, + max_tokens=args.max_tokens, + max_batch_size=args.max_batch_size, + num_blocks=args.num_blocks, + block_size=args.block_size, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + host=args.host, + port=args.port, + ) + server.start() + + +if __name__ == "__main__": + main()