Description
Disclaimer: This is less an issue, and more an opportunity to discuss/document a little rabbit-hole that I went down.
I was recently prompted by the deprecation of the class interface to start transitioning to the function wrappers that are encouraged in the deprecation warnings, but in doing so lost many of my IDE (VS Code) code intelligence features (tooltips, autocomplete, syntax highlighting, etc.) because of the lack of type information propagated by the wrappers. Plus, the smarter LLM agents that would ordinarily try to check call signatures by making cheap tool calls (to a type checker/LSP) instead burn tokens reading files to get that information.
The "easy" problem: define_function_from_class()
For example, in `spikeinterface/extractors/neoextractors/spikeglx.py`, `read_spikeglx` is dynamically generated using `define_function_from_class`, which means that static analysis tools can't easily infer the return type:

```python
read_spikeglx = define_function_from_class(source_class=SpikeGLXRecordingExtractor, name="read_spikeglx")
```

One consequence is that if I call `se.read_spikeglx()` and hover over it in my IDE, instead of getting arguments, return types, etc., I just get a tooltip that says "Unknown". Of course, I could just keep using the classes directly from the `spikeinterface.extractors.extractor_classes` module, but then I started thinking about what it would take to help out the users who choose to stick with the function wrappers.
I can see 4 options:
1. Just assign the class directly:

```python
read_spikeglx = SpikeGLXRecordingExtractor
```

I am not sure why the no-op `define_function_from_class()` is there, but I'm sure there's a good reason. Possibly a historical one, or consistency with `define_function_handling_dict_from_class()`, basically "reserving the right" to do something like actually wrap the class rather than just "softly" rename it.
2. Explicitly type-annotate the wrapper:

```python
read_spikeglx: type[SpikeGLXRecordingExtractor] = define_function_from_class(...)
```

But this won't propagate information about arguments. It tells the static analyzer that `read_spikeglx` is the class `SpikeGLXRecordingExtractor`, but it doesn't expose the constructor signature or docstring when hovering over the function name itself. It also requires edits in all 61 (by my count) places where `define_function_from_class()` is used.
Along these lines, you could of course give up `define_function_from_class()` and make actual wrappers:

```python
def read_spikeglx(
    folder_path: Path | str,
    load_sync_channel: bool = False,
    stream_id: str | None = None,
    ...
) -> SpikeGLXRecordingExtractor:
    return SpikeGLXRecordingExtractor(folder_path, load_sync_channel, stream_id, ...)
```

This is much like option 1 (direct assignment) and comes with most of its benefits, but it actually changes the name (e.g. in tooltips), at the cost of increased maintenance (it duplicates the signature).
3. Use generic types from `typing` to annotate `define_function_from_class()` itself:

```python
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def define_function_from_class(source_class: Callable[_P, _T], name: str) -> Callable[_P, _T]:
    return source_class
```

This has the advantage of propagating both type and signature, and the change only has to be made in 1 place. It has the minor limitation that it doesn't help with syntax highlighting, and `__name__` isn't set, but in order to fix those I think you'd need a real wrapper function, and then you'd lose the actual class identity (you can't change the name on the source class and have it be the same class), which maybe you care about.
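For illustration, a real wrapper along those lines might look like this rough sketch (hypothetical, not from the codebase), which shows the trade-off: `__name__` comes out right, but the result is a plain function rather than the class itself:

```python
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def define_function_from_class(source_class: Callable[_P, _T], name: str) -> Callable[_P, _T]:
    # A genuine wrapper: __name__ and __doc__ are set correctly, but the
    # returned object is a function, not the class, so class identity
    # (isinstance checks, subclassing) is lost.
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        return source_class(*args, **kwargs)

    wrapper.__name__ = name
    wrapper.__doc__ = source_class.__doc__
    return wrapper
```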
4. Full-on `.pyi` stubs:
```python
# spikeinterface/src/spikeinterface/extractors/neoextractors/spikeglx.pyi
from pathlib import Path

from spikeinterface.extractors.neoextractors.neobaseextractor import (
    NeoBaseRecordingExtractor,
    NeoBaseEventExtractor,
)

class SpikeGLXRecordingExtractor(NeoBaseRecordingExtractor):
    NeoRawIOClass: str
    def __init__(
        self,
        folder_path: str | Path,
        load_sync_channel: bool = False,
        stream_id: str | None = None,
        stream_name: str | None = None,
        all_annotations: bool = False,
        use_names_as_ids: bool = False,
    ) -> None: ...
    @classmethod
    def map_to_neo_kwargs(
        cls, folder_path: str | Path, load_sync_channel: bool = False
    ) -> dict[str, str | bool]: ...

def read_spikeglx(
    folder_path: str | Path,
    load_sync_channel: bool = False,
    stream_id: str | None = None,
    stream_name: str | None = None,
    all_annotations: bool = False,
    use_names_as_ids: bool = False,
) -> SpikeGLXRecordingExtractor:
    """
    Class for reading data saved by SpikeGLX software.
    See https://billkarsh.github.io/SpikeGLX/

    Based on :py:class:`neo.rawio.SpikeGLXRawIO`

    Contrary to older versions, this reader is folder-based.
    If the folder contains several streams (e.g., "imec0.ap", "nidq", "imec0.lf"),
    then the stream has to be specified with "stream_id" or "stream_name".

    Parameters
    ----------
    folder_path : str
        The folder path to load the recordings from.
    load_sync_channel : bool, default: False
        Whether or not to load the last channel in the stream, which is typically used for synchronization.
        If True, then the probe is not loaded.
    stream_id : str or None, default: None
        If there are several streams, specify the stream id you want to load.
        For example, "imec0.ap", "nidq", or "imec0.lf".
    stream_name : str or None, default: None
        If there are several streams, specify the stream name you want to load.
    all_annotations : bool, default: False
        Load exhaustively all annotations from neo.
    use_names_as_ids : bool, default: False
        Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the
        names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO.

    Examples
    --------
    >>> from spikeinterface.extractors import read_spikeglx
    >>> recording = read_spikeglx(folder_path=r'path_to_folder_with_data', load_sync_channel=False)
    # we can load the sync channel, but then the probe is not loaded
    >>> recording = read_spikeglx(folder_path=r'path_to_folder_with_data', load_sync_channel=True)
    """
    ...
```

Full support, zero change to runtime code (`.py`) needed, but obnoxiously burdensome. Lots more files, and to preserve a single point of truth about function arguments, docstrings, etc. you would probably want a pre-commit step to generate these stubs automatically.
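For what it's worth, that generation step wouldn't have to be fancy. Here's a rough sketch (`stub_for` is a hypothetical helper, not existing tooling, and a real generator would also need to handle imports, annotation rendering, overloads, etc.):

```python
import inspect


def stub_for(func_name: str, source_class: type) -> str:
    # Render a .pyi function entry whose signature mirrors the class
    # constructor, keeping the class as the single point of truth.
    sig = inspect.signature(source_class.__init__)
    params = [str(p) for name, p in sig.parameters.items() if name != "self"]
    return (
        f"def {func_name}({', '.join(params)}) -> {source_class.__name__}:\n"
        f'    """{inspect.getdoc(source_class) or ""}"""\n'
        "    ...\n"
    )


# e.g. stub_for("read_spikeglx", SpikeGLXRecordingExtractor) would emit an
# entry like the read_spikeglx(...) stub above (modulo rendering details).
```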
Personally, I'd favor option 3: It's like 2 additional lines of code and gets you most of the benefit.
The hard problem: `define_function_handling_dict_from_class()`
Clever, and evil (from a static analysis perspective). It creates wrappers that accept either:

- a single `BaseRecording`, returning a single preprocessor instance
- a `dict[str, BaseRecording]`, returning a `dict[str, preprocessor]`
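Roughly speaking (this is my mental model of it, not a claim about the actual implementation), the factory returns something like:

```python
from typing import Any


def define_function_handling_dict_from_class(source_class, name: str):
    def wrapper(recording, *args: Any, **kwargs: Any):
        # Dispatch on the runtime type of the first argument: a dict of
        # recordings is processed value-by-value; anything else is passed
        # straight to the class constructor.
        if isinstance(recording, dict):
            return {key: source_class(rec, *args, **kwargs) for key, rec in recording.items()}
        return source_class(recording, *args, **kwargs)

    wrapper.__name__ = name
    return wrapper
```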
So it's polymorphic based on input type. We are basically overloading functions like `bandpass_filter()`, but refusing to use the language's tools for overloading (`typing.overload`). The options are kind of analogous to the options from the easy problem:
1. Except there is no solution directly analogous to "direct assignment".
2. Overload the wrapper functions where they are declared. There are lots of ways to do this, but the one that I think results in the least amount of duplication is:
```python
from typing import overload

@overload
def bandpass_filter(
    recording: BaseRecording,
    freq_min: float = ...,
    freq_max: float = ...,
    margin_ms: float | str = ...,
) -> BandpassFilterRecording: ...
@overload
def bandpass_filter(
    recording: dict[str, BaseRecording],
    freq_min: float = ...,
    freq_max: float = ...,
    margin_ms: float | str = ...,
) -> dict[str, BandpassFilterRecording]: ...
def bandpass_filter(recording, freq_min=300.0, freq_max=6000.0, margin_ms="auto", **filter_kwargs):
    """Docstring here..."""
    return _bandpass_filter_impl(recording, freq_min, freq_max, margin_ms, **filter_kwargs)
```

Except that's still a lot of duplication/boilerplate, and now your default keyword argument values are duplicated across the wrapper function definition and the actual class definition (or you remove the defaults from the class constructor, but I don't think anyone wants that).
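With those overloads in place, checkers can resolve the return type from the argument type at each call site. For example (assuming `rec` is some `BaseRecording` instance):

```python
from typing import reveal_type  # use typing_extensions on Python < 3.11

reveal_type(bandpass_filter(rec))             # revealed: BandpassFilterRecording
reveal_type(bandpass_filter({"imec0": rec}))  # revealed: dict[str, BandpassFilterRecording]
```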
3. Use generic types from `typing` to annotate `define_function_handling_dict_from_class()` itself:
```python
from typing import Callable, ParamSpec, TypeVar

_P = ParamSpec("_P")
_T = TypeVar("_T")


def define_function_handling_dict_from_class(
    source_class: Callable[_P, _T], name: str
) -> Callable[..., _T | dict[str, _T]]:
    "Docstring here..."
    from spikeinterface.core import BaseRecording

    def source_class_or_dict_of_sources_classes(*args, **kwargs) -> _T | dict[str, _T]:
        ...  # no change to the inner function
```

Unfortunately, this is no longer a good option, because it won't give you meaningful parameter hints, and there is no way to express to the type checker that the return type depends on the input type, so you can never narrow the union type down. That latter limitation means that if you declare in your own code that, say, a function which calls `bandpass_filter()` returns a `BaseRecording`, type checkers will go nuts when they can't reconcile this with the union type.
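Concretely, perfectly reasonable user code like this (hypothetical) sketch gets flagged:

```python
def my_preprocessing(recording: BaseRecording) -> BaseRecording:
    # Checker error: the inferred return type is
    # `BandpassFilterRecording | dict[str, BandpassFilterRecording]`,
    # which is not assignable to `BaseRecording`.
    return bandpass_filter(recording)
```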
I tried to overload `define_function_handling_dict_from_class()` and combine it with this approach:
```python
from typing import Callable, ParamSpec, TypeVar, overload

_P = ParamSpec("_P")
_T = TypeVar("_T", bound="BaseRecording")


@overload
def define_function_handling_dict_from_class(
    source_class: Callable[..., _T], name: str
) -> Callable[..., _T | dict[str, _T]]: ...
def define_function_handling_dict_from_class(
    source_class: Callable[_P, _T], name: str
) -> Callable[..., _T | dict[str, _T]]:
    ...
```

...but realized that this cannot work, because both overloads have the same input signature (`source_class: Callable[_P, _T], name: str`), so the type checker has no way to distinguish which one to use. Another way of saying this is that the polymorphism happens at the returned function's call site, not at the factory's call site. So I really don't think you can put the overloads anywhere other than on the wrapper function itself, like in option 2.
4. Full-on stubs again. You can put the `@overload` definitions in the stubs, but not the `.py`, if you generate the stubs automatically (a trimmed sketch follows below). Yay, everything works, but now you've got thousands of lines of stub code.
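For concreteness, such a stub entry might look like this (trimmed to two parameters; where `BandpassFilterRecording` is actually declared is a detail I'm glossing over):

```python
# bandpass_filter.pyi -- trimmed illustration, not real generated output
from typing import overload

from spikeinterface.core import BaseRecording

class BandpassFilterRecording(BaseRecording): ...

@overload
def bandpass_filter(
    recording: BaseRecording, freq_min: float = ..., freq_max: float = ...
) -> BandpassFilterRecording: ...
@overload
def bandpass_filter(
    recording: dict[str, BaseRecording], freq_min: float = ..., freq_max: float = ...
) -> dict[str, BandpassFilterRecording]: ...
```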
There really is no winning with the hard problem! Of the options, I personally think that option 2 (and duplicating the default keyword argument values across the wrapper and the class constructor) is best.
Summary
The use of dynamically created function wrappers limits the ability of type checkers and LSPs to perform static analysis, which in turn limits users' access to code intelligence features. For wrappers created with `define_function_from_class()`, most of this can be alleviated with a couple of lines of code (I can submit a PR if you want) by using generics from `typing`. Unfortunately, for wrappers created with `define_function_handling_dict_from_class()`, I can't see any easy solution: it's either overloading the Pythonic way or making your own type stubs. I'm not suggesting any action, because every option is disruptive to some degree. But I thought it might be helpful to have this all spelled out somewhere, in case there is appetite to tackle it in the future.