diff --git a/dlclive/__init__.py b/dlclive/__init__.py index 3209cbe..2a7a3a9 100644 --- a/dlclive/__init__.py +++ b/dlclive/__init__.py @@ -17,6 +17,21 @@ from dlclive.processor.processor import Processor from dlclive.version import VERSION, __version__ + +def benchmark_videos(*args, **kwargs): + """Lazy import of benchmark_videos from dlclive.benchmark""" + from dlclive.benchmark import benchmark_videos as _benchmark_videos + + return _benchmark_videos(*args, **kwargs) + + +def download_benchmarking_data(*args, **kwargs): + """Lazy import of download_benchmarking_data from dlclive.benchmark""" + from dlclive.benchmark import download_benchmarking_data as _download_benchmarking_data + + return _download_benchmarking_data(*args, **kwargs) + + __all__ = [ "DLCLive", "Display", diff --git a/dlclive/benchmark.py b/dlclive/benchmark.py index 8c71b1b..0227f20 100644 --- a/dlclive/benchmark.py +++ b/dlclive/benchmark.py @@ -19,14 +19,13 @@ import colorcet as cc import cv2 import numpy as np -import torch from PIL import ImageColor from pip._internal.operations import freeze from tqdm import tqdm from dlclive import VERSION, DLCLive from dlclive.engine import Engine -from dlclive.utils import decode_fourcc +from dlclive.utils import decode_fourcc, get_torch if TYPE_CHECKING: import tensorflow # type: ignore @@ -266,8 +265,9 @@ def get_system_info() -> dict: # Not installed from git repo, e.g., pypi pass - # Get device info (GPU or CPU) - if torch.cuda.is_available(): + # Get device info (GPU or CPU). Torch is optional. 
+ torch = get_torch(required=False) + if torch is not None and torch.cuda.is_available(): dev_type = "GPU" dev = [torch.cuda.get_device_name(torch.cuda.current_device())] else: diff --git a/dlclive/factory.py b/dlclive/factory.py index 9cf6462..ed43ba7 100644 --- a/dlclive/factory.py +++ b/dlclive/factory.py @@ -7,6 +7,7 @@ from dlclive.core.runner import BaseRunner from dlclive.engine import Engine +from dlclive.utils import get_tensorflow, get_torch def build_runner( @@ -36,12 +37,14 @@ def build_runner( """ if Engine.from_model_type(model_type) == Engine.PYTORCH: + get_torch(required=True, feature="PyTorch inference") from dlclive.pose_estimation_pytorch.runner import PyTorchRunner valid = {"device", "precision", "single_animal", "dynamic", "top_down_config"} return PyTorchRunner(model_path, **filter_keys(valid, kwargs)) elif Engine.from_model_type(model_type) == Engine.TENSORFLOW: + get_tensorflow(required=True, feature="TensorFlow inference") from dlclive.pose_estimation_tensorflow.runner import TensorFlowRunner if model_type.lower() == "tensorflow": diff --git a/dlclive/live_inference.py b/dlclive/live_inference.py index 7a4ffd1..fac0805 100644 --- a/dlclive/live_inference.py +++ b/dlclive/live_inference.py @@ -15,11 +15,11 @@ import colorcet as cc import cv2 import h5py -import torch from PIL import ImageColor from pip._internal.operations import freeze from dlclive import VERSION, DLCLive +from dlclive.utils import get_torch def get_system_info() -> dict: @@ -60,8 +60,9 @@ def get_system_info() -> dict: # Not installed from git repo, e.g., pypi pass - # Get device info (GPU or CPU) - if torch.cuda.is_available(): + # Get device info (GPU or CPU). Torch is optional. 
+ torch = get_torch(required=False) + if torch is not None and torch.cuda.is_available(): dev_type = "GPU" dev = [torch.cuda.get_device_name(torch.cuda.current_device())] else: @@ -309,6 +310,9 @@ def save_poses_to_files(experiment_name, save_dir, bodyparts, poses, timestamp): csv_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.csv") h5_save_path = os.path.join(save_dir, f"{base_filename}_poses_{timestamp}.h5") + torch = get_torch(required=False) + tensor_type = torch.Tensor if torch is not None else () + # Save to CSV with open(csv_save_path, mode="w", newline="") as file: writer = csv.writer(file) @@ -319,7 +323,7 @@ def save_poses_to_files(experiment_name, save_dir, bodyparts, poses, timestamp): pose_data = entry["pose"]["poses"][0][0] # Convert tensor data to numeric values row = [frame_num] + [ - item.item() if isinstance(item, torch.Tensor) else item for kp in pose_data for item in kp + item.item() if isinstance(item, tensor_type) else item for kp in pose_data for item in kp ] writer.writerow(row) @@ -332,7 +336,7 @@ def save_poses_to_files(experiment_name, save_dir, bodyparts, poses, timestamp): data=[ ( entry["pose"]["poses"][0][0][i, 0].item() - if isinstance(entry["pose"]["poses"][0][0][i, 0], torch.Tensor) + if isinstance(entry["pose"]["poses"][0][0][i, 0], tensor_type) else entry["pose"]["poses"][0][0][i, 0] ) for entry in poses @@ -343,7 +347,7 @@ def save_poses_to_files(experiment_name, save_dir, bodyparts, poses, timestamp): data=[ ( entry["pose"]["poses"][0][0][i, 1].item() - if isinstance(entry["pose"]["poses"][0][0][i, 1], torch.Tensor) + if isinstance(entry["pose"]["poses"][0][0][i, 1], tensor_type) else entry["pose"]["poses"][0][0][i, 1] ) for entry in poses @@ -354,7 +358,7 @@ def save_poses_to_files(experiment_name, save_dir, bodyparts, poses, timestamp): data=[ ( entry["pose"]["poses"][0][0][i, 2].item() - if isinstance(entry["pose"]["poses"][0][0][i, 2], torch.Tensor) + if isinstance(entry["pose"]["poses"][0][0][i, 2], 
tensor_type) else entry["pose"]["poses"][0][0][i, 2] ) for entry in poses diff --git a/dlclive/utils.py b/dlclive/utils.py index 8a6acde..98fe45f 100644 --- a/dlclive/utils.py +++ b/dlclive/utils.py @@ -18,6 +18,54 @@ from dlclive.exceptions import DLCLiveWarning +def get_torch(required: bool = False, feature: str | None = None): + """Lazily import torch. + + Args: + required: If True, raise a clear error when torch is unavailable. + feature: Optional feature name to include in error messages. + + Returns: + The imported torch module, or None when unavailable and not required. + """ + try: + import torch + + return torch + except (ImportError, ModuleNotFoundError) as exc: + if required: + context = f" for {feature}" if feature else "" + raise ModuleNotFoundError( + f"PyTorch is required{context} but is not installed. " + "Install it with: pip install deeplabcut-live[pytorch]" + ) from exc + return None + + +def get_tensorflow(required: bool = False, feature: str | None = None): + """Lazily import tensorflow. + + Args: + required: If True, raise a clear error when tensorflow is unavailable. + feature: Optional feature name to include in error messages. + + Returns: + The imported tensorflow module, or None when unavailable and not required. + """ + try: + import tensorflow as tf + + return tf + except (ImportError, ModuleNotFoundError) as exc: + if required: + context = f" for {feature}" if feature else "" + raise ModuleNotFoundError( + f"TensorFlow is required{context} but is not installed. " + "Install it with: pip install deeplabcut-live[tf]" + ) from exc + return None + + def convert_to_ubyte(frame: np.ndarray) -> np.ndarray: """Converts an image to unsigned 8-bit integer numpy array. If scikit-image is installed, uses skimage.img_as_ubyte, otherwise, uses a similar custom function. 
@@ -180,19 +228,11 @@ def get_available_backends() -> list[Engine]: """ backends = [] - try: - import tensorflow - + if get_tensorflow(required=False) is not None: backends.append(Engine.TENSORFLOW) - except (ImportError, ModuleNotFoundError): - pass - - try: - import torch + if get_torch(required=False) is not None: backends.append(Engine.PYTORCH) - except (ImportError, ModuleNotFoundError): - pass if not backends: warnings.warn(