diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index a64bd382..4ad6383a 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -583,7 +583,7 @@ class FileDataSource(DataSource): def __init__( self, - read_file_fn: Callable[[tf.data.Dataset], tf.data.Dataset], + read_file_fn: Callable[[str], tf.data.Dataset], split_to_filepattern: Mapping[str, Union[str, Iterable[str]]], num_input_examples: Optional[Mapping[str, int]] = None, caching_permitted: bool = True,