diff --git a/main/como/cluster_rnaseq.py b/main/como/cluster_rnaseq.py index 1078b926..214910d0 100644 --- a/main/como/cluster_rnaseq.py +++ b/main/como/cluster_rnaseq.py @@ -8,7 +8,7 @@ import numpy as np from como.data_types import LogLevel -from como.utils import _log_and_raise_error, stringlist_to_list +from como.utils import log_and_raise_error, stringlist_to_list @dataclass @@ -35,35 +35,35 @@ def __post_init__(self): # noqa: C901, ignore too complex self.seed = np.random.randint(0, 100_000) if (isdigit(self.min_active_count) and int(self.min_active_count) < 0) or self.min_active_count != "default": - _log_and_raise_error( + log_and_raise_error( "min_active_count must be either 'default' or an integer > 0", error=ValueError, level=LogLevel.ERROR, ) if (isdigit(self.quantile) and 0 > int(self.quantile) > 100) or self.quantile != "default": - _log_and_raise_error( + log_and_raise_error( "quantile must be either 'default' or an integer between 0 and 100", error=ValueError, level=LogLevel.ERROR, ) if (isdigit(self.replicate_ratio) and 0 > self.replicate_ratio > 1.0) or self.replicate_ratio != "default": - _log_and_raise_error( + log_and_raise_error( "--rep-ratio must be either 'default' or a float between 0 and 1", error=ValueError, level=LogLevel.ERROR, ) if (isdigit(self.batch_ratio) and 0 > self.batch_ratio > 1.0) or self.batch_ratio != "default": - _log_and_raise_error( + log_and_raise_error( "--batch-ratio must be either 'default' or a float between 0 and 1", error=ValueError, level=LogLevel.ERROR, ) if self.filtering_technique.lower() not in {"quantile", "tpm", "cpm", "zfpkm"}: - _log_and_raise_error( + log_and_raise_error( "--technique must be either 'quantile', 'tpm', 'cpm', 'zfpkm'", error=ValueError, level=LogLevel.ERROR, @@ -73,35 +73,35 @@ def __post_init__(self): # noqa: C901, ignore too complex self.filtering_technique = "quantile" if self.cluster_algorithm.lower() not in {"mca", "umap"}: - _log_and_raise_error( + log_and_raise_error( "--clust_algo must be either 'mca', 'umap'", error=ValueError, level=LogLevel.ERROR, ) if 0 > self.min_distance > 1.0: - _log_and_raise_error( + log_and_raise_error( "--min_dist must be a float between 0 and 1", error=ValueError, level=LogLevel.ERROR, ) if (isdigit(self.num_replicate_neighbors) and self.num_replicate_neighbors < 1) or self.num_replicate_neighbors != "default": - _log_and_raise_error( + log_and_raise_error( "--n-neighbors-rep must be either 'default' or an integer > 1", error=ValueError, level=LogLevel.ERROR, ) if (isdigit(self.num_batch_neighbors) and self.num_batch_neighbors < 1) or self.num_batch_neighbors != "default": - _log_and_raise_error( + log_and_raise_error( "--n-neighbors-batch must be either 'default' or an integer > 1", error=ValueError, level=LogLevel.ERROR, ) if (isdigit(self.num_context_neighbors) and self.num_context_neighbors < 1) or self.num_context_neighbors != "default": - _log_and_raise_error( + log_and_raise_error( "--n-neighbors-context must be either 'default' or an integer > 1", error=ValueError, level=LogLevel.ERROR, diff --git a/main/como/combine_distributions.py b/main/como/combine_distributions.py index 9b87ea47..2fb38462 100644 --- a/main/como/combine_distributions.py +++ b/main/como/combine_distributions.py @@ -18,7 +18,7 @@ _OutputCombinedSourceFilepath, _SourceWeights, ) -from como.utils import LogLevel, _log_and_raise_error, _num_columns +from como.utils import LogLevel, log_and_raise_error, _num_columns, get_missing_gene_data def _combine_z_distribution_for_batch( @@ -186,7 +186,7 @@ def 
_combine_z_distribution_for_context( for res in zscore_results: matrix = res.z_score_matrix.copy() if len(matrix.columns) > 1: -            _log_and_raise_error( +            log_and_raise_error( f"Expected a single column for combined z-score dataframe for data '{res.type.value.lower()}'. Got '{len(matrix.columns)}' columns", error=ValueError, level=LogLevel.ERROR, @@ -302,7 +302,7 @@ async def _begin_combining_distributions( else "" ) if not index_name: -            _log_and_raise_error( +            log_and_raise_error( f"Unable to find common gene identifier across batches for source '{source.value}' in context '{context_name}'", error=ValueError, level=LogLevel.ERROR, diff --git a/main/como/create_context_specific_model.py b/main/como/create_context_specific_model.py index cd771c4b..a0410e48 100644 --- a/main/como/create_context_specific_model.py +++ b/main/como/create_context_specific_model.py @@ -25,8 +25,23 @@ from troppo.methods.reconstruction.imat import IMAT, IMATProperties from troppo.methods.reconstruction.tINIT import tINIT, tINITProperties -from como.data_types import Algorithm, CobraCompartments, LogLevel, Solver, _BoundaryReactions, _BuildResults -from como.utils import _log_and_raise_error, _read_file, _set_up_logging, split_gene_expression_data +from como.data_types import ( +    Algorithm, +    BuildResults, +    CobraCompartments, +    LogLevel, +    ModelBuildSettings, +    Solver, +    _BoundaryReactions, +) +from como.utils import log_and_raise_error, set_up_logging, split_gene_expression_data + + +def _reaction_indices_to_ids( +    ref_model: cobra.Model, reaction_indices: Sequence[int] | npt.NDArray[np.integer] +) -> list[str]: +    rxns = list(ref_model.reactions) +    return [rxns[int(i)].id for i in reaction_indices] def _correct_bracket(rule: str, name: str) -> str: @@ -203,11 +218,28 @@ def _build_with_gimme( return model_reconstruction -def _build_with_fastcore(cobra_model, s_matrix, lower_bounds, upper_bounds, exp_idx_list, solver): +def _build_with_fastcore( +    reference_model: cobra.Model, +    lower_bounds: npt.NDArray[np.floating], +    upper_bounds: npt.NDArray[np.floating], +    exp_idx_list: Sequence[int], +    solver: str, +): # 'Vlassis, Pacheco, Sauter (2014). Fast reconstruction of compact # context-specific metabolic network models. PLoS Comput. Biol. 10, # e1003424.' -    logger.warning("Fastcore requires a flux consistant model is used as refererence, to achieve this fastcc is required which is NOT reproducible.") +    model = reference_model +    logger.warning( +        "Fastcore requires that a flux-consistent model be used as the reference. " +        "To achieve this, fastcc is required, which is NOT reproducible."
+ ) + s_matrix = cast(npt.NDArray[np.floating], cobra.util.create_stoichiometric_matrix(model=model)) + if lower_bounds.shape[0] != upper_bounds.shape[0] != s_matrix.shape[1]: + log_and_raise_error( + message="Lower bounds, upper bounds, and stoichiometric matrix must have the same number of reactions.", + error=ValueError, + level=LogLevel.ERROR, + ) logger.debug("Creating feasible model") _, cobra_model = _feasibility_test(cobra_model, "other") properties = FastcoreProperties(core=exp_idx_list, solver=solver) @@ -269,6 +301,8 @@ def _build_with_tinit( solver, idx_force, ) -> Model: + log_and_raise_error("tINIT is not yet implemented.", error=NotImplementedError, level=LogLevel.CRITICAL) + model = reference_model properties = tINITProperties( reactions_scores=expr_vector, solver=solver, @@ -284,10 +318,44 @@ def _build_with_tinit( algorithm.build_problem() _log_and_raise_error("tINIT is not yet implemented.", error=NotImplementedError, level=LogLevel.CRITICAL) +def _build_with_corda( + reference_model: cobra.Model, + neg_expression_threshold: float, + high_expression_threshold: float, + lower_bounds: npt.NDArray[np.floating], + upper_bounds: npt.NDArray[np.floating], + expression_vector: Sequence[float] | npt.NDArray[np.floating], +): + """Reconstruct a model using CORDA. + + :param neg_expression_threshold: Reactions expressed below this value will be placed in "negative" expression bin + :param high_expression_threshold: Reactions expressed above this value will be placed in the "high" expression bin + """ + log_and_raise_error("CORDA is not yet implemented", error=NotImplementedError, level=LogLevel.CRITICAL) + model = reference_model + properties = CORDAProperties( + high_conf_rx=[], + medium_conf_rx=[], + neg_conf_rx=[], + pr_to_np=2, + constraint=1, + constrainby="val", + om=1e4, + ntimes=5, + nl=1e-2, + solver="GUROBI", + threads=5, + ) + s_matrix: npt.NDArray[np.floating] = np.asarray( + cobra.util.array.create_stoichiometric_matrix(model=model), dtype=float + ) + algorithm = CORDA(S=s_matrix, lb=np.asarray(lower_bounds), ub=np.asarray(upper_bounds), properties=properties) + active_rxn_indices: npt.NDArray[np.integer] = algorithm.run() + -async def _map_expression_to_reaction( - reference_model, - gene_expression_file, +def _map_expression_to_reaction( + reference_model: cobra.Model, + gene_expression_file: Path, recon_algorithm: Algorithm, low_thresh: float, high_thresh: float, @@ -382,7 +450,7 @@ def _read_reference_model(filepath: Path) -> cobra.Model: case ".json": reference_model = cobra.io.load_json_model(filepath) case _: - _log_and_raise_error( + log_and_raise_error( f"Reference model format must be .xml, .mat, or .json; found '{filepath.suffix}'", error=ValueError, level=LogLevel.ERROR, @@ -449,12 +517,34 @@ async def _build_model( lower_bounds: list[int] = [] upper_bounds: list[int] = [] reaction_ids: list[str] = [] - for i, reaction in enumerate(reference_model.reactions): - # if reaction.id in boundary_reactions: - # lower_bounds.append() - lower_bounds.append(reaction.lower_bound) - upper_bounds.append(reaction.upper_bound) - reaction_ids.append(reaction.id) + for i, rxn in enumerate(reference_model.reactions): + rxn: cobra.Reaction + ref_lb[i] = float(rxn.lower_bound) + ref_ub[i] = float(rxn.upper_bound) + reaction_ids.append(rxn.id) + if ref_lb.shape[0] != ref_ub.shape[0] != len(reaction_ids): + log_and_raise_error( + message=( + "Lower bounds, upper bounds, and reaction IDs must have the same length.\n" + f"Number of reactions: {len(reaction_ids)}\n" + f"Number of 
upper bounds: {ref_ub.shape[0]}\n" + f"Number of lower bounds: {ref_lb.shape[0]}" + ), + error=ValueError, + level=LogLevel.ERROR, + ) + if np.isnan(ref_lb).any(): + log_and_raise_error( + message="Lower bounds contains unfilled values!", + error=ValueError, + level=LogLevel.ERROR, + ) + if np.isnan(ref_ub).any(): + log_and_raise_error( + message="Upper bounds contains unfilled values!", + error=ValueError, + level=LogLevel.ERROR, + ) # get expressed reactions reaction_expression: collections.OrderedDict[str, int] = await _map_expression_to_reaction( @@ -500,61 +590,60 @@ async def _build_model( expression_vector_indices = [i for (i, val) in enumerate(expression_vector) if val > 0] expression_threshold = (low_thresh, high_thresh) - match recon_algorithm: - case Algorithm.IMAT: - context_model_cobra: cobra.Model = _build_with_imat( - reference_model=reference_model, - lower_bounds=lower_bounds, - upper_bounds=upper_bounds, - expr_vector=expression_vector, - expr_thresh=expression_threshold, - force_reaction_indices=force_reaction_indices, - solver=solver, - ) - case Algorithm.GIMME: - context_model_cobra: cobra.Model = _build_with_gimme( - reference_model=reference_model, - lower_bounds=lower_bounds, - upper_bounds=upper_bounds, - idx_objective=objective_index, - expr_vector=expression_vector, - ) - case Algorithm.FASTCORE: - context_model_cobra: cobra.Model = _build_with_fastcore( - cobra_model=reference_model, - s_matrix=s_matrix, - lower_bounds=lower_bounds, - upper_bounds=upper_bounds, - exp_idx_list=expression_vector_indices, - solver=solver, - ) - context_model_cobra.objective = objective - flux_sol: cobra.Solution = context_model_cobra.optimize() - fluxes: pd.Series = flux_sol.fluxes - model_reactions: list[str] = [reaction.id for reaction in context_model_cobra.reactions] - reaction_intersections: set[str] = set(fluxes.index).intersection(model_reactions) - flux_df: pd.DataFrame = cast(pd.DataFrame, fluxes[~fluxes.index.isin(reaction_intersections)]) - flux_df.dropna(inplace=True) - flux_df.to_csv(output_flux_result_filepath) - case Algorithm.TINIT: - context_model_cobra: cobra.Model = _build_with_tinit( - reference_model=reference_model, - lower_bounds=lower_bounds, - upper_bounds=upper_bounds, - expr_vector=expression_vector, - solver=solver, - idx_force=force_reaction_indices, - ) - case _: - _log_and_raise_error( - ( - f"Reconstruction algorithm must be {Algorithm.GIMME.value}, " - f"{Algorithm.FASTCORE.value}, {Algorithm.IMAT.value}, or {Algorithm.TINIT.value}. 
" - f"Got: {recon_algorithm.value}" - ), - error=ValueError, - level=LogLevel.ERROR, - ) + if recon_algorithm == Algorithm.IMAT: + context_model_cobra: cobra.Model = _build_with_imat( + reference_model=reference_model, + lower_bounds=ref_lb, + upper_bounds=ref_ub, + expr_vector=expression_vector, + low_expression_threshold=updated_low_thresh, + high_expression_threshold=updated_high_thresh, + force_reaction_indices=force_reaction_indices, + solver=solver, + build_settings=build_settings, + ) + elif recon_algorithm == Algorithm.GIMME: + expressed_rxn_ids: list[str] = list(reaction_expression.keys()) + metabolite_ids: set[str] = set() + for rxn_id in expressed_rxn_ids: + cobra_rxn: cobra.Reaction = cast(cobra.Reaction, reference_model.reactions.get_by_id(rxn_id)) + metabolite_ids.update([m.id for m in cobra_rxn.metabolites]) + + context_model_cobra: cobra.Model = _build_with_gimme( + reference_model=reference_model, + expression_vector=expression_vector, + idx_objective=objective_index, + lower_bounds=ref_lb, + upper_bounds=ref_ub, + solver=solver, + ) + elif recon_algorithm == Algorithm.FASTCORE: + context_model_cobra: cobra.Model = _build_with_fastcore( + reference_model=reference_model, + lower_bounds=ref_lb, + upper_bounds=ref_ub, + exp_idx_list=expression_vector_indices, + solver=solver, + ) + elif recon_algorithm == Algorithm.TINIT: + context_model_cobra: cobra.Model = _build_with_tinit( + reference_model=reference_model, + lower_bounds=ref_lb, + upper_bounds=ref_ub, + expr_vector=expression_vector, + solver=solver, + idx_force=force_reaction_indices, + ) + else: + log_and_raise_error( + ( + f"Reconstruction algorithm must be {Algorithm.GIMME.value}, " + f"{Algorithm.FASTCORE.value}, {Algorithm.IMAT.value}, or {Algorithm.TINIT.value}. " + f"Got: {recon_algorithm.value}" + ), + error=ValueError, + level=LogLevel.ERROR, + ) inconsistent_and_infeasible_reactions: pd.DataFrame = pd.concat( [ @@ -601,7 +690,7 @@ async def _collect_boundary_reactions(path: Path) -> _BoundaryReactions: "minimum reaction rate", "maximum reaction rate", ]: - _log_and_raise_error( + log_and_raise_error( ( f"Boundary reactions file must have columns named 'Reaction', 'Abbreviation', 'Compartment', " f"'Minimum Reaction Rate', and 'Maximum Reaction Rate'. Found: {column}" @@ -618,7 +707,7 @@ async def _collect_boundary_reactions(path: Path) -> _BoundaryReactions: for i in range(len(boundary_type)): boundary: str = boundary_type[i].lower() if boundary not in boundary_map: - _log_and_raise_error( + log_and_raise_error( f"Boundary reaction type must be 'Exchange', 'Demand', or 'Sink'. Found: {boundary}", error=ValueError, level=LogLevel.ERROR, @@ -652,7 +741,7 @@ async def _write_model_to_disk( elif path.suffix in xml_suffix: tasks.add(asyncio.to_thread(cobra.io.write_sbml_model, model=model, filename=path)) else: - _log_and_raise_error( + log_and_raise_error( f"Invalid output model filetype. Should be one of .xml, .sbml, .mat, or .json. Got '{path.suffix}'", error=ValueError, level=LogLevel.ERROR, @@ -709,44 +798,50 @@ async def create_context_specific_model( # noqa: C901 Raises: ImportError: If Gurobi solver is selected but gurobipy is not installed. 
""" - _set_up_logging(level=log_level, location=log_location) + set_up_logging(level=log_level, location=log_location) + if low_percentile is None: + raise ValueError("low_percentile must be provided") + if high_percentile is None: + raise ValueError("high_percentile must be provided") + # TODO: set up zfpkm threshold defaults + boundary_rxns_filepath: Path | None = Path(boundary_rxns_filepath) if boundary_rxns_filepath else None output_model_filepaths = [output_model_filepaths] if isinstance(output_model_filepaths, Path) else output_model_filepaths if not reference_model.exists(): - _log_and_raise_error( + log_and_raise_error( f"Reference model not found at {reference_model}", error=FileNotFoundError, level=LogLevel.ERROR, ) if not active_genes_filepath.exists(): - _log_and_raise_error( + log_and_raise_error( f"Active genes file not found at {active_genes_filepath}", error=FileNotFoundError, level=LogLevel.ERROR, ) if algorithm == Algorithm.FASTCORE and not output_fastcore_expression_index_filepath: - _log_and_raise_error( + log_and_raise_error( "The fastcore expression index output filepath must be provided", error=ValueError, level=LogLevel.ERROR, ) if boundary_rxns_filepath and not boundary_rxns_filepath.exists(): - _log_and_raise_error( + log_and_raise_error( f"Boundary reactions file not found at {boundary_rxns_filepath}", error=FileNotFoundError, level=LogLevel.ERROR, ) if algorithm not in Algorithm: - _log_and_raise_error( + log_and_raise_error( f"Algorithm {algorithm} not supported. Use one of {', '.join(a.value for a in Algorithm)}", error=ValueError, level=LogLevel.ERROR, ) if solver not in Solver: - _log_and_raise_error( + log_and_raise_error( f"Solver '{solver}' not supported. Use one of {', '.join(s.value for s in Solver)}", error=ValueError, level=LogLevel.ERROR, @@ -754,8 +849,12 @@ async def create_context_specific_model( # noqa: C901 mat_suffix, json_suffix, xml_suffix = {".mat"}, {".json"}, {".sbml", ".xml"} if any(path.suffix not in {*mat_suffix, *json_suffix, *xml_suffix} for path in output_model_filepaths): - invalid_suffix = "\n".join(path for path in output_model_filepaths if path.suffix not in {*mat_suffix, *json_suffix, *xml_suffix}) - _log_and_raise_error( + invalid_suffix: str = "\n".join( + path.as_posix() + for path in output_model_filepaths + if path.suffix not in {*mat_suffix, *json_suffix, *xml_suffix} + ) + log_and_raise_error( f"Invalid output filetype. Should be 'xml', 'sbml', 'mat', or 'json'. 
Got:\n{invalid_suffix}'", error=ValueError, level=LogLevel.ERROR, @@ -770,7 +869,7 @@ async def create_context_specific_model( # noqa: C901 exclude_rxns_filepath: Path = Path(exclude_rxns_filepath) df = await _create_df(exclude_rxns_filepath) if "abbreviation" not in df.columns: - _log_and_raise_error( + log_and_raise_error( "The exclude reactions file should have a single column with a header named Abbreviation", error=ValueError, level=LogLevel.ERROR, @@ -782,7 +881,7 @@ async def create_context_specific_model( # noqa: C901 force_rxns_filepath: Path = Path(force_rxns_filepath) df = await _create_df(force_rxns_filepath, lowercase_col_names=True) if "abbreviation" not in df.columns: - _log_and_raise_error( + log_and_raise_error( "The force reactions file should have a single column with a header named Abbreviation", error=ValueError, level=LogLevel.ERROR, @@ -795,7 +894,7 @@ async def create_context_specific_model( # noqa: C901 gurobi_present = find_spec("gurobipy") if not gurobi_present: - _log_and_raise_error( + log_and_raise_error( message=( "The gurobi solver requires the gurobipy package to be installed. " "Please install gurobipy and try again. " diff --git a/main/como/merge_xomics.py b/main/como/merge_xomics.py index 87b5ba62..f1e04a81 100644 --- a/main/como/merge_xomics.py +++ b/main/como/merge_xomics.py @@ -24,7 +24,7 @@ _SourceWeights, ) from como.project import Config -from como.utils import _log_and_raise_error, _read_file, _set_up_logging, get_missing_gene_data, return_placeholder_data +from como.utils import log_and_raise_error, read_file, set_up_logging, get_missing_gene_data, return_placeholder_data class _MergedHeaderNames: @@ -72,7 +72,7 @@ def load_dummy_dict(): inquiry_full_path = Path(config.data_dir, "config_sheets", filename) if not inquiry_full_path.exists(): - _log_and_raise_error( + log_and_raise_error( f"Config file not found at {inquiry_full_path}", error=FileNotFoundError, level=LogLevel.ERROR, @@ -451,7 +451,7 @@ async def _process( elif adjust_method == AdjustmentMethod.FLAT: adjusted_expression_requirement = expression_requirement else: - _log_and_raise_error( + log_and_raise_error( message=f"Unknown `adjust_method`: {adjust_method}.", error=ValueError, level=LogLevel.ERROR, @@ -516,7 +516,7 @@ def _build_batches( for study in sorted(metadata["study"].unique()): batch_search = re.search(r"\d+", study) if not batch_search: - _log_and_raise_error( + log_and_raise_error( message=f"Unable to find batch number in study name. Expected a digit in the study value: {study}", error=ValueError, level=LogLevel.ERROR, @@ -545,7 +545,7 @@ def _validate_source_arguments( """ if any(i for i in args) and not all(i for i in args): - _log_and_raise_error( + log_and_raise_error( f"Must specify all or none of '{source.value}' arguments", error=ValueError, level=LogLevel.ERROR, @@ -614,7 +614,7 @@ async def merge_xomics( # noqa: C901 proteomic_matrix_or_filepath, ) ): - _log_and_raise_error("No data was passed!", error=ValueError, level=LogLevel.ERROR) + log_and_raise_error("No data was passed!", error=ValueError, level=LogLevel.ERROR) if expression_requirement and expression_requirement < 1: logger.warning(f"Expression requirement must be at least 1! Setting to the minimum of 1 now. 
Got: {expression_requirement}") @@ -651,52 +651,24 @@ async def merge_xomics( # noqa: C901 # Build trna items # `cast` helps type checkers know what types we are dealing with - costs no runtime performance - (trna_matrix, trna_boolean_matrix, trna_metadata) = cast( - typ=tuple[pd.DataFrame | None, pd.DataFrame | None, pd.DataFrame | None], - val=await asyncio.gather( - *[ - _read_file(trna_matrix_or_filepath, h5ad_as_df=True), - _read_file(trna_boolean_matrix_or_filepath, h5ad_as_df=True), - _read_file(trna_metadata_filepath_or_df, h5ad_as_df=True), - ] - ), - ) + trna_matrix = read_file(trna_matrix_or_filepath, h5ad_as_df=True) + trna_boolean_matrix = read_file(trna_boolean_matrix_or_filepath, h5ad_as_df=True) + trna_metadata = read_file(trna_metadata_filepath_or_df, h5ad_as_df=True) # Build mrna items - (mrna_matrix, mrna_boolean_matrix, mrna_metadata) = cast( - typ=tuple[pd.DataFrame | None, pd.DataFrame | None, pd.DataFrame | None], - val=await asyncio.gather( - *[ - _read_file(mrna_matrix_or_filepath), - _read_file(mrna_boolean_matrix_or_filepath), - _read_file(mrna_metadata_filepath_or_df), - ] - ), - ) + mrna_matrix = read_file(mrna_matrix_or_filepath, h5ad_as_df=True) + mrna_boolean_matrix = read_file(mrna_boolean_matrix_or_filepath, h5ad_as_df=True) + mrna_metadata = read_file(mrna_metadata_filepath_or_df, h5ad_as_df=True) # build scrna items - (scrna_matrix, scrna_boolean_matrix, scrna_metadata) = cast( - typ=tuple[pd.DataFrame | None, pd.DataFrame | None, pd.DataFrame | None], - val=await asyncio.gather( - *[ - _read_file(scrna_matrix_or_filepath), - _read_file(scrna_boolean_matrix_or_filepath), - _read_file(scrna_metadata_filepath_or_df), - ] - ), - ) + scrna_matrix = read_file(scrna_matrix_or_filepath, h5ad_as_df=True) + scrna_boolean_matrix = read_file(scrna_boolean_matrix_or_filepath, h5ad_as_df=True) + scrna_metadata = read_file(scrna_metadata_filepath_or_df, h5ad_as_df=True) # build proteomic items - (proteomic_matrix, proteomic_boolean_matrix, proteomic_metadata) = cast( - typ=tuple[pd.DataFrame | None, pd.DataFrame | None, pd.DataFrame | None], - val=await asyncio.gather( - *[ - _read_file(proteomic_matrix_or_filepath), - _read_file(proteomic_boolean_matrix_or_filepath), - _read_file(proteomic_metadata_filepath_or_df), - ] - ), - ) + proteomic_matrix = read_file(proteomic_matrix_or_filepath, h5ad_as_df=True) + proteomic_boolean_matrix = read_file(proteomic_boolean_matrix_or_filepath, h5ad_as_df=True) + proteomic_metadata = read_file(proteomic_metadata_filepath_or_df, h5ad_as_df=True) source_weights = _SourceWeights(trna=trna_weight, mrna=mrna_weight, scrna=scrna_weight, proteomics=proteomic_weight) input_matrices = _InputMatrices(trna=trna_matrix, mrna=mrna_matrix, scrna=scrna_matrix, proteomics=proteomic_matrix) diff --git a/main/como/proteomics/FTPManager.py b/main/como/proteomics/FTPManager.py index b3b96052..86b27a07 100644 --- a/main/como/proteomics/FTPManager.py +++ b/main/como/proteomics/FTPManager.py @@ -17,7 +17,7 @@ from loguru import logger from como.proteomics.FileInformation import FileInformation, clear_print -from como.utils import _log_and_raise_error +from como.utils import log_and_raise_error from como.data_types import LogLevel @@ -43,7 +43,7 @@ async def aioftp_client(host: str, username: str = "anonymous", password: str = attempt_num += 1 time.sleep(5) if not connection_successful: - _log_and_raise_error( + log_and_raise_error( "Could not connect to FTP server", error=ConnectionResetError, level=LogLevel.ERROR, @@ -97,7 +97,7 @@ async def 
_get_info(self) -> None: if url_parse.hostname is not None: host = url_parse.hostname else: - _log_and_raise_error( + log_and_raise_error( f"Unable to identify hostname from url: {self._root_link}", error=ValueError, level=LogLevel.ERROR, @@ -105,7 +105,7 @@ async def _get_info(self) -> None: if url_parse.path != "": folder = url_parse.path else: - _log_and_raise_error( + log_and_raise_error( f"Unable to identify folder or path from url: {self._root_link}", error=ValueError, level=LogLevel.ERROR, @@ -184,7 +184,7 @@ async def _aioftp_download_data(self, file_information: FileInformation, semapho if url_parse.hostname is not None: host = url_parse.hostname else: - _log_and_raise_error( + log_and_raise_error( f"Unable to identify hostname from url: {file_information.download_url}", error=ValueError, level=LogLevel.ERROR, @@ -192,7 +192,7 @@ async def _aioftp_download_data(self, file_information: FileInformation, semapho if url_parse.path != "": folder = url_parse.path else: - _log_and_raise_error( + log_and_raise_error( f"Unable to identify folder or path from url: {file_information.download_url}", error=ValueError, level=LogLevel.ERROR, diff --git a/main/como/proteomics/proteomics_preprocess.py b/main/como/proteomics/proteomics_preprocess.py index b2223da7..b3740ac9 100644 --- a/main/como/proteomics/proteomics_preprocess.py +++ b/main/como/proteomics/proteomics_preprocess.py @@ -9,7 +9,7 @@ from como.data_types import LogLevel from como.proteomics import Crux, FileInformation, FTPManager -from como.utils import _log_and_raise_error +from como.utils import log_and_raise_error class ArgParseFormatter(argparse.RawTextHelpFormatter, argparse.ArgumentDefaultsHelpFormatter): @@ -316,12 +316,12 @@ def parse_args() -> argparse.Namespace: # Validte the input file exists if not Path(args.input_csv).is_file(): - _log_and_raise_error(f"Input file {args.input} does not exist!", error=FileNotFoundError, level=LogLevel.ERROR) + log_and_raise_error(f"Input file {args.input} does not exist!", error=FileNotFoundError, level=LogLevel.ERROR) if args.core_count == "all": args.core_count = os.cpu_count() elif not str(args.core_count).isdigit(): - _log_and_raise_error( + log_and_raise_error( f"Invalid option '{args.core_count}' for option '--cores'. 
Enter an integer or 'all' to use all cores", error=ValueError, level=LogLevel.ERROR, diff --git a/main/como/proteomics_gen.py b/main/como/proteomics_gen.py index caeeb8a3..fa76c618 100644 --- a/main/como/proteomics_gen.py +++ b/main/como/proteomics_gen.py @@ -14,7 +14,7 @@ from como.data_types import LogLevel from como.project import Config from como.proteomics_preprocessing import protein_transform_main -from como.utils import _log_and_raise_error, _set_up_logging, return_placeholder_data +from como.utils import log_and_raise_error, set_up_logging, return_placeholder_data # Load Proteomics @@ -31,13 +31,13 @@ def process_proteomics_data(path: Path) -> pd.DataFrame: matrix: pd.DataFrame = pd.read_csv(path) gene_symbol_colname = [col for col in matrix.columns if "symbol" in col] if len(gene_symbol_colname) == 0: - _log_and_raise_error( + log_and_raise_error( "No gene_symbol column found in proteomics data.", error=ValueError, level=LogLevel.ERROR, ) if len(gene_symbol_colname) > 1: - _log_and_raise_error( + log_and_raise_error( "Multiple gene_symbol columns found in proteomics data.", error=ValueError, level=LogLevel.ERROR, @@ -179,7 +179,7 @@ def load_empty_dict(): inquiry_full_path = config.data_dir / "config_sheets" / filename if not inquiry_full_path.exists(): - _log_and_raise_error(f"Error: file not found {inquiry_full_path}", error=FileNotFoundError, level=LogLevel.ERROR) + log_and_raise_error(f"Error: file not found {inquiry_full_path}", error=FileNotFoundError, level=LogLevel.ERROR) filename = f"Proteomics_{context_name}.csv" full_save_filepath = config.result_dir / context_name / "proteomics" / filename @@ -211,36 +211,36 @@ async def proteomics_gen( log_location: str | TextIO = sys.stderr, ): """Generate proteomics data.""" - _set_up_logging(level=log_level, location=log_location) + set_up_logging(level=log_level, location=log_location) if not config_filepath.exists(): - _log_and_raise_error( + log_and_raise_error( f"Config file not found at {config_filepath}", error=FileNotFoundError, level=LogLevel.ERROR, ) if config_filepath.suffix not in (".xlsx", ".xls"): - _log_and_raise_error( + log_and_raise_error( f"Config file must be an xlsx or xls file at {config_filepath}", error=FileNotFoundError, level=LogLevel.ERROR, ) if not matrix_filepath.exists(): - _log_and_raise_error( + log_and_raise_error( f"Matrix file not found at {matrix_filepath}", error=FileNotFoundError, level=LogLevel.ERROR, ) if matrix_filepath.suffix != ".csv": - _log_and_raise_error( + log_and_raise_error( f"Matrix file must be a csv file at {matrix_filepath}", error=FileNotFoundError, level=LogLevel.ERROR, ) if quantile < 0 or quantile > 100: - _log_and_raise_error( + log_and_raise_error( "Quantile must be an integer from 0 to 100", error=ValueError, level=LogLevel.ERROR, diff --git a/main/como/rnaseq_gen.py b/main/como/rnaseq_gen.py index 327c6ba5..817b76d6 100644 --- a/main/como/rnaseq_gen.py +++ b/main/como/rnaseq_gen.py @@ -27,7 +27,7 @@ from como.data_types import FilteringTechnique, LogLevel, RNAType from como.migrations import gene_info_migrations from como.project import Config -from como.utils import _log_and_raise_error, _read_file, _set_up_logging +from como.utils import log_and_raise_error, read_file, set_up_logging class _FilteringOptions(NamedTuple): @@ -62,7 +62,7 @@ class _StudyMetrics: def __post_init__(self): for layout in self.layout: if layout not in LayoutMethod: - _log_and_raise_error( + log_and_raise_error( f"Layout must be 'paired-end' or 'single-end'; got: {layout}", error=ValueError, 
level=LogLevel.ERROR, @@ -140,7 +140,7 @@ def genefilter(data: pd.DataFrame | npt.NDArray, filter_func: Callable[[npt.NDAr A NumPy array of the filtered data. """ if not isinstance(data, pd.DataFrame | np.ndarray): - _log_and_raise_error( + log_and_raise_error( f"Unsupported data type. Must be a Pandas DataFrame or a NumPy array, got '{type(data)}'", error=TypeError, level=LogLevel.CRITICAL, @@ -176,7 +176,7 @@ async def _build_matrix_results( conversion = await gene_symbol_to_ensembl_and_gene_id(symbols=matrix.var["gene_symbol"].tolist(), taxon=taxon) else: if "ensembl_gene_id" not in matrix.columns: - _log_and_raise_error( + log_and_raise_error( message="'ensembl_gene_id' column not found in the provided DataFrame.", error=ValueError, level=LogLevel.CRITICAL, @@ -201,7 +201,7 @@ async def _build_matrix_results( set(matrix.columns if isinstance(matrix, pd.DataFrame) else matrix.var.columns) & set(conversion.columns) ) if not conversion_merge_on: - _log_and_raise_error( + log_and_raise_error( ( "No columns to merge on, unable to find at least one of `ensembl_gene_id`, `entrez_gene_id`, or `gene_symbol`. " "Please check your input files." @@ -273,7 +273,7 @@ async def _build_matrix_results( entrez_gene_ids = subset.var["entrez_gene_id"].to_numpy(dtype=int) gene_sizes = subset.var["size"].to_numpy(dtype=int) else: - _log_and_raise_error( + log_and_raise_error( message=f"Matrix must be a pandas DataFrame or scanpy AnnData object, got: '{type(matrix)}'.", error=TypeError, level=LogLevel.CRITICAL, @@ -341,7 +341,7 @@ def _calculate_fpkm(metrics: NamedMetrics, scale: float = 1e6) -> NamedMetrics: matrix_values: dict[str, npt.NDArray[np.floating]] = {} count_matrix = metrics[study].count_matrix if not isinstance(count_matrix, pd.DataFrame): - _log_and_raise_error( + log_and_raise_error( message="FPKM cannot be performed on scanpy.AnnData objects!", error=TypeError, level=LogLevel.CRITICAL, @@ -715,7 +715,7 @@ def filter_counts( perform_normalization=umi_perform_normalization, ) case _: - _log_and_raise_error( + log_and_raise_error( f"Technique must be one of {FilteringTechnique}, got '{technique.value}'", error=ValueError, level=LogLevel.ERROR, @@ -749,7 +749,7 @@ async def _process( """Save the results of the RNA-Seq tests to a CSV file.""" output_boolean_activity_filepath.parent.mkdir(parents=True, exist_ok=True) - rnaseq_matrix: pd.DataFrame | sc.AnnData = _read_file(rnaseq_matrix_filepath, h5ad_as_df=False) + rnaseq_matrix: pd.DataFrame | sc.AnnData = read_file(rnaseq_matrix_filepath, h5ad_as_df=False) filtering_options = _FilteringOptions( replicate_ratio=replicate_ratio, batch_ratio=batch_ratio, @@ -907,14 +907,14 @@ async def rnaseq_gen( # noqa: C901 :return: None """ - _set_up_logging(level=log_level, location=log_location) + set_up_logging(level=log_level, location=log_location) technique = FilteringTechnique(technique) if isinstance(technique, str) else technique match technique: case FilteringTechnique.TPM: cutoff: int | float = cutoff or 25 if cutoff < 1 or cutoff > 100: - _log_and_raise_error( + log_and_raise_error( "Quantile must be between 1 - 100", error=ValueError, level=LogLevel.ERROR, @@ -922,7 +922,7 @@ async def rnaseq_gen( # noqa: C901 case FilteringTechnique.CPM: if cutoff and cutoff < 0: - _log_and_raise_error( + log_and_raise_error( "Cutoff must be greater than or equal to 0", error=ValueError, level=LogLevel.ERROR, @@ -935,14 +935,14 @@ async def rnaseq_gen( # noqa: C901 case FilteringTechnique.UMI: cutoff: int = cutoff or 1 case _: - _log_and_raise_error( + 
log_and_raise_error( f"Technique must be one of {','.join(FilteringTechnique)}. Got: {technique.value}", error=ValueError, level=LogLevel.ERROR, ) if not input_rnaseq_filepath.exists(): - _log_and_raise_error( + log_and_raise_error( f"Input RNA-seq file not found! Searching for: '{input_rnaseq_filepath}'", error=FileNotFoundError, level=LogLevel.ERROR, @@ -960,13 +960,13 @@ async def rnaseq_gen( # noqa: C901 metadata_df = input_metadata_filepath_or_df elif isinstance(input_metadata_filepath_or_df, Path): if input_metadata_filepath_or_df.suffix not in {".xlsx", ".xls"}: - _log_and_raise_error( + log_and_raise_error( f"Expected an excel file with extension of '.xlsx' or '.xls', got '{input_metadata_filepath_or_df.suffix}'.", error=ValueError, level=LogLevel.ERROR, ) if not input_metadata_filepath_or_df.exists(): - _log_and_raise_error( + log_and_raise_error( f"Input metadata file not found! Searching for: '{input_metadata_filepath_or_df}'", error=FileNotFoundError, level=LogLevel.ERROR, @@ -974,7 +974,7 @@ async def rnaseq_gen( # noqa: C901 metadata_df = pd.read_excel(input_metadata_filepath_or_df) else: - _log_and_raise_error( + log_and_raise_error( f"Expected a pandas DataFrame or Path object as metadata, got '{type(input_metadata_filepath_or_df)}'", error=TypeError, level=LogLevel.ERROR, @@ -985,8 +985,8 @@ async def rnaseq_gen( # noqa: C901 context_name=context_name, rnaseq_matrix_filepath=input_rnaseq_filepath, metadata_df=metadata_df, - gene_info_df=_read_file(input_gene_info_filepath), - fragment_df=_read_file(input_fragment_lengths), + gene_info_df=read_file(input_gene_info_filepath), + fragment_df=read_file(input_fragment_lengths), prep=prep, taxon=taxon_id, replicate_ratio=replicate_ratio, diff --git a/main/como/rnaseq_preprocess.py b/main/como/rnaseq_preprocess.py index 03745f62..691d7d68 100644 --- a/main/como/rnaseq_preprocess.py +++ b/main/como/rnaseq_preprocess.py @@ -1,17 +1,14 @@ from __future__ import annotations import asyncio -import csv import functools -import io import json import re import sys -from collections.abc import Sequence -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field from itertools import chain from pathlib import Path -from typing import Final, Literal, cast +from typing import Final, Literal, TextIO, cast import numpy as np import numpy.typing as npt @@ -20,8 +17,8 @@ from fast_bioservices.pipeline import gene_symbol_to_ensembl_and_gene_id from loguru import logger -from como.data_types import PATH_TYPE, LogLevel, RNAType -from como.utils import _listify, _log_and_raise_error, _read_file, _set_up_logging +from como.data_types import LogLevel, RNAType +from como.utils import log_and_raise_error, read_file, set_up_logging @dataclass @@ -34,20 +31,20 @@ class _QuantInformation: @classmethod def build_from_sf(cls, filepath: Path) -> _QuantInformation: if filepath.suffix != ".sf": - _log_and_raise_error( + log_and_raise_error( f"Building quantification information requires a '.sf' file; received: '{filepath}'", error=ValueError, level=LogLevel.ERROR, ) if not filepath.exists(): - _log_and_raise_error( + log_and_raise_error( f"Unable to find the .sf file: {filepath}", error=FileNotFoundError, level=LogLevel.ERROR, ) sample_name = filepath.stem.removesuffix("_quant.genes") - df = _read_file( + df = read_file( filepath, sep="\t", names=["ensembl_gene_id", "length", "effective_length", "tpm", sample_name], @@ -82,7 +79,7 @@ def __post_init__(self): self.__sample_names = [f.stem for f in self.quant_files] if 
len(self.quant_files) != len(self.strand_files): - _log_and_raise_error( + log_and_raise_error( ( f"Unequal number of count files and strand files for study '{self.study_name}'. " f"Found {len(self.quant_files)} count files and {len(self.strand_files)} strand files." @@ -92,7 +89,7 @@ def __post_init__(self): ) if self.num_samples != len(self.quant_files): - _log_and_raise_error( + log_and_raise_error( ( f"Unequal number of samples and count files for study '{self.study_name}'. " f"Found {self.num_samples} samples and {len(self.quant_files)} count files." @@ -102,7 +99,7 @@ def __post_init__(self): ) if self.num_samples != len(self.strand_files): - _log_and_raise_error( + log_and_raise_error( ( f"Unequal number of samples and strand files for study '{self.study_name}'. " f"Found {self.num_samples} samples and {len(self.strand_files)} strand files." @@ -112,7 +109,7 @@ def __post_init__(self): ) if self.__num_samples == 1: - _log_and_raise_error( + log_and_raise_error( f"Only one sample exists for study {self.study_name}. Provide at least two samples", error=ValueError, level=LogLevel.ERROR, @@ -134,21 +131,25 @@ class SampleConfiguration: library_prep: str def __post_init__(self): + """Validate the effective lengths dataframe to ensure it has the expected structure and content.""" if len(self.effective_lengths.columns) > 2: - _log_and_raise_error( - message=f"Effective lengths dataframe for sample '{self.sample_name}' has more than 2 columns, expected 'name' and 'effective_length'", + log_and_raise_error( + message=( + f"Effective lengths dataframe for sample '{self.sample_name}' has more than 2 columns, " + f"expected 'name' and 'effective_length'" + ), error=ValueError, level=LogLevel.ERROR, ) if "name" not in self.effective_lengths.columns: - _log_and_raise_error( + log_and_raise_error( message=f"Effective lengths dataframe for sample '{self.sample_name}' is missing 'name' column", error=ValueError, level=LogLevel.ERROR, ) if "effective_length" not in self.effective_lengths.columns: - _log_and_raise_error( - message=f"Effective lengths dataframe for sample '{self.sample_name}' is missing 'effective_length' column", + log_and_raise_error( + message=f"Sample '{self.sample_name}' is missing 'effective_length' column", error=ValueError, level=LogLevel.ERROR, ) @@ -180,23 +181,37 @@ def to_dataframe(cls, samples: list[SampleConfiguration]) -> tuple[pd.DataFrame, def _sample_name_from_filepath(file: Path) -> str: - return re.search(r".+_S\d+R\d+(r\d+)?", file.stem).group() + group = re.search(r".+_S\d+R\d+(r\d+)?", file.stem) + if not group: + log_and_raise_error( + message=( + "Filename does not match expected pattern 'contextName_SXRYrZ' where " + "X is the study number, Y is the replicate number, and Z is the optional run number" + ), + error=ValueError, + level=LogLevel.ERROR, + ) + return group.group() def _require_one( - paths: list[Path], + paths: list[Path | None], kind: Literal["layout", "strand", "preparation", "fragment"], label: str, -) -> Path | None: - if len(paths) == 1: +) -> Path: + if len(paths) == 1 and isinstance(paths[0], Path): return paths[0] - if len(paths) == 0: - return None - _log_and_raise_error( - f"Multiple matching {kind} files for {label}, make sure there is only one copy for each replicate in COMO_input", - error=ValueError, - level=LogLevel.ERROR, - ) + if len(paths) > 1: + message = ( + f"Multiple matching {kind} files for {label}, " + f"make sure there is only one copy for each replicate in COMO_input" + ) + elif len(paths) > 1 and paths[0] is None: + 
message = f"No {kind} file found for {label}, make sure there is one copy for each replicate in COMO_input" + else: + message = f"No {kind} file found for {label}, make sure there is one copy for each replicate in COMO_input" + + log_and_raise_error(message=message, error=ValueError, level=LogLevel.ERROR) def _organize_gene_counts_files(data_dir: Path) -> list[_StudyMetrics]: @@ -213,7 +228,7 @@ def _organize_gene_counts_files(data_dir: Path) -> list[_StudyMetrics]: strandedness_directories: list[Path] = sorted([p for p in strand_dir.glob("*") if not p.name.startswith(".")]) if len(quantification_directories) != len(strandedness_directories): - _log_and_raise_error( + log_and_raise_error( ( f"Unequal number of quantification directories and strandedness directories. " f"Found {len(quantification_directories)} quantification directories and " @@ -230,9 +245,9 @@ def _organize_gene_counts_files(data_dir: Path) -> list[_StudyMetrics]: quant_files = list(quant.glob("*_quant.genes.sf")) strand_files = list(strand_dir.glob("*.txt")) if len(quant_files) == 0: - _log_and_raise_error(f"No quant found for study '{quant.stem}'.", error=ValueError, level=LogLevel.ERROR) + log_and_raise_error(f"No quant found for study '{quant.stem}'.", error=ValueError, level=LogLevel.ERROR) if len(strand_files) == 0: - _log_and_raise_error( + log_and_raise_error( f"No strandedness files found for study '{quant.stem}'.", error=ValueError, level=LogLevel.ERROR, @@ -262,7 +277,9 @@ def _process_first_multirun_sample(strand_file: Path, all_quant_files: list[Path sample_counts = sample_counts.fillna(value=0) sample_counts["counts"] = sample_counts["counts"].astype(float) - count_avg = sample_counts.groupby("ensembl_gene_id", as_index=False)["counts"].mean() + count_avg = cast( # type checkers think `.groupby(...).mean()` returns a pd.Series, force pd.DataFrame + pd.DataFrame, cast(object, sample_counts.groupby("ensembl_gene_id", as_index=False)["counts"].mean()) + ) count_avg["counts"] = np.ceil(count_avg["counts"].astype(int)) count_avg.columns = ["ensembl_gene_id", _sample_name_from_filepath(strand_file)] return count_avg @@ -309,7 +326,7 @@ def _create_sample_counts_matrix(metrics: _StudyMetrics) -> pd.DataFrame: continue assert isinstance(counts, pd.DataFrame) # noqa: S101 - counts: pd.DataFrame = counts.merge(new_counts, on="ensembl_gene_id", how="outer") + counts = counts.merge(new_counts, on="ensembl_gene_id", how="outer") counts = counts.fillna(value=0) # Remove run number "r\d+" from multi-run names @@ -336,7 +353,8 @@ def _write_counts_matrix( """Create a counts matrix file by reading gene counts table(s). :param config_df: Configuration DataFrame containing sample information. - :param fragment_lengths: DataFrame containing effective lengths for each gene and sample, used for zFPKM normalization. + :param fragment_lengths: DataFrame containing effective lengths for each gene and sample, + used for zFPKM normalization. :param como_context_dir: Path to the COMO_input directory containing gene count files. :param output_counts_matrix_filepath: Path where the output counts matrix CSV will be saved. :param output_fragment_lengths_filepath: Path where the output fragment lengths CSV will be saved. 
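For readers unfamiliar with the sample naming convention that the new validation in `_sample_name_from_filepath` enforces, the following standalone sketch (not part of this diff, and not COMO's own API) shows how a 'contextName_SXRYrZ' file stem can be decomposed. The `parse_sample_label` helper and the example stems are hypothetical; they only mirror the `.+_S\d+R\d+(r\d+)?` pattern used above, where X is the study number, Y the replicate number, and Z the optional run number.

import re

# Illustrative only: a named-group version of the contextName_SXRYrZ convention.
_LABEL_RE = re.compile(r"^(?P<context>.+)_S(?P<study>\d+)R(?P<replicate>\d+)(?:r(?P<run>\d+))?$")


def parse_sample_label(stem: str) -> dict[str, str | None]:
    """Split a sample file stem such as 'contextA_S1R2r1' into its parts."""
    match = _LABEL_RE.match(stem)
    if match is None:
        raise ValueError(f"'{stem}' does not match the contextName_SXRYrZ pattern")
    return match.groupdict()


if __name__ == "__main__":
    print(parse_sample_label("contextA_S1R2r1"))  # {'context': 'contextA', 'study': '1', 'replicate': '2', 'run': '1'}
    print(parse_sample_label("contextA_S1R2"))    # 'run' is None for single-run samples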
@@ -354,7 +372,7 @@ def _write_counts_matrix( ) final_matrix.fillna(value=0, inplace=True) final_matrix.iloc[:, 1:] = final_matrix.iloc[:, 1:].astype(int) - final_matrix = cast(pd.DataFrame, final_matrix[["ensembl_gene_id", *rna_specific_sample_names]]) + final_matrix = final_matrix[["ensembl_gene_id", *rna_specific_sample_names]] output_counts_matrix_filepath.parent.mkdir(parents=True, exist_ok=True) output_fragment_lengths_filepath.parent.mkdir(parents=True, exist_ok=True) @@ -400,7 +418,7 @@ def _create_config_df( # noqa: C901 quant_files: list[Path] = list((como_context_dir / quantification_dir).rglob("*.genes.sf")) # gene_counts: list[Path] = list((como_context_dir / gene_count_dirname).rglob("*.tab")) if not quant_files: - _log_and_raise_error( + log_and_raise_error( f"No gene count files found in '{gene_count_dirname}'", error=FileNotFoundError, level=LogLevel.ERROR, @@ -421,12 +439,14 @@ def _create_config_df( # noqa: C901 m = label_regex.search(p.stem) if m: aux_lookup[kind][m.group(0)] = p + if "layout" not in aux_lookup: + raise ValueError rows: list[SampleConfiguration] = [] for quant_file in sorted(quant_files): m = label_regex.search(quant_file.as_posix()) if m is None: - _log_and_raise_error( + log_and_raise_error( f"Filename '{quant_file.name}' does not match contextName_SXRYrZ.tab pattern", error=ValueError, level=LogLevel.ERROR, @@ -444,13 +464,13 @@ def _create_config_df( # noqa: C901 strand = strand_path.read_text().rstrip() prep = prep_path.read_text().rstrip() if prep not in {"total", "mrna"}: - _log_and_raise_error( + log_and_raise_error( f"Prep method must be 'total' or 'mrna' (got '{prep}') for {label}", error=ValueError, level=LogLevel.ERROR, ) if layout == "": - _log_and_raise_error( + log_and_raise_error( message=f"No layout file found for '{label}'.", error=FileNotFoundError, level=LogLevel.WARNING, @@ -462,7 +482,7 @@ def _create_config_df( # noqa: C901 and layout in ["paired-end", "", None] and prep.lower() in [RNAType.TRNA.value.lower(), RNAType.MRNA.value.lower()] ): - _log_and_raise_error( + log_and_raise_error( message=f"No quantification file found for '{label}'; defaulting to 100 bp (needed for zFPKM).", error=FileNotFoundError, level=LogLevel.WARNING, @@ -471,7 +491,7 @@ def _create_config_df( # noqa: C901 effective_len = pd.DataFrame({"Name": [], "EffectiveLength": []}) mean_effective_len = 0.0 # cannot compute FPKM for single-ended data else: - df = _read_file(quant_file) + df = read_file(quant_file) df.columns = [c.lower() for c in df.columns] df = df.rename(columns={"effectivelength": "effective_length"}) @@ -493,171 +513,8 @@ def _create_config_df( # noqa: C901 return SampleConfiguration.to_dataframe(rows) - # 6-3-25: Intentionally left commented-out code to test its replacement - # gene_counts_dir = como_context_dir / gene_count_dirname - # layout_dir = como_context_dir / layout_dirname - # strandedness_dir = como_context_dir / strandedness_dirname - # fragment_sizes_dir = como_context_dir / fragment_sizes_dirname - # prep_method_dir = como_context_dir / prep_method_dirname - # - # gene_counts_files = list(gene_counts_dir.rglob("*.tab")) - # sample_names: list[str] = [] - # fragment_lengths: list[int | float] = [] - # layouts: list[str] = [] - # strands: list[str] = [] - # groups: list[str] = [] - # preparation_method: list[str] = [] - # - # if len(gene_counts_files) == 0: - # _log_and_raise_error(f"No gene count files found in '{gene_counts_dir}'.", error=FileNotFoundError, level=LogLevel.ERROR) - # - # for gene_count_filename in 
sorted(gene_counts_files): - # # Match S___R___r___ - # # \d{1,3} matches 1-3 digits - # # (?:r\d{1,3})? optionally matches a "r" followed by three digits - # label = re.findall(r"S\d{1,3}R\d{1,3}(?:r\d{1,3})?", gene_count_filename.as_posix())[0] - # if not label: - # _log_and_raise_error( - # ( - # f"\n\nFilename of '{gene_count_filename}' is not valid. " - # f"Should be 'contextName_SXRYrZ.tab', " - # f"where X is the study/batch number, Y is the replicate number, " - # f"and Z is the run number." - # "\n\nIf not a multi-run sample, exclude 'rZ' from the filename." - # ), - # error=ValueError, - # level=LogLevel.ERROR, - # ) - # - # study_number = re.findall(r"S\d{1,3}", label)[0] - # rep_number = re.findall(r"R\d{1,3}", label)[0] - # run_number = re.findall(r"r\d{1,3}", label) - # - # multi_flag = 0 - # if len(run_number) > 0: - # if run_number[0] != "r1": - # continue - # label_glob = f"{study_number}{rep_number}r*" # S__R__r* - # runs = [run for run in gene_counts_files if re.search(label_glob, run.as_posix())] - # multi_flag = 1 - # frag_files = [] - # - # for run in runs: - # run_number = re.findall(r"R\d{1,3}", run.as_posix())[0] - # replicate = re.findall(r"r\d{1,3}", run.as_posix())[0] - # frag_filename = "".join([context_name, "_", study_number, run_number, replicate, "_fragment_size.txt"]) - # frag_files.append(como_context_dir / fragment_sizes_dirname / study_number / frag_filename) - # - # layout_files: list[Path] = list(layout_dir.rglob(f"{context_name}_{label}_layout.txt")) - # strand_files: list[Path] = list(strandedness_dir.rglob(f"{context_name}_{label}_strandedness.txt")) - # frag_files: list[Path] = list(fragment_sizes_dir.rglob(f"{context_name}_{label}_fragment_size.txt")) - # prep_files: list[Path] = list(prep_method_dir.rglob(f"{context_name}_{label}_prep_method.txt")) - # - # layout = "UNKNOWN" - # if len(layout_files) == 0: - # logger.warning( - # f"No layout file found for {label}, writing as 'UNKNOWN', " - # f"this should be defined if you are using zFPKM or downstream 'rnaseq_gen.py' will not run" - # ) - # elif len(layout_files) == 1: - # with layout_files[0].open("r") as file: - # layout = file.read().strip() - # elif len(layout_files) > 1: - # _log_and_raise_error( - # f"Multiple matching layout files for {label}, make sure there is only one copy for each replicate in COMO_input", - # error=ValueError, - # level=LogLevel.ERROR, - # ) - # - # strand = "UNKNOWN" - # if len(strand_files) == 0: - # logger.warning( - # f"No strandedness file found for {label}, writing as 'UNKNOWN'. 
" - # f"This will not interfere with the analysis since you have already set rnaseq_preprocess.py to " - # f"infer the strandedness when writing the counts matrix" - # ) - # elif len(strand_files) == 1: - # with strand_files[0].open("r") as file: - # strand = file.read().strip() - # elif len(strand_files) > 1: - # _log_and_raise_error( - # f"Multiple matching strandedness files for {label}, make sure there is only one copy for each replicate in COMO_input", - # error=ValueError, - # level=LogLevel.ERROR, - # ) - # - # prep = "total" - # if len(prep_files) == 0: - # logger.warning(f"No prep file found for {label}, assuming 'total', as in 'Total RNA' library preparation") - # elif len(prep_files) == 1: - # with prep_files[0].open("r") as file: - # prep = file.read().strip().lower() - # if prep not in ["total", "mrna"]: - # _log_and_raise_error( - # f"Prep method must be either 'total' or 'mrna' for {label}", - # error=ValueError, - # level=LogLevel.ERROR, - # ) - # elif len(prep_files) > 1: - # _log_and_raise_error( - # f"Multiple matching prep files for {label}, make sure there is only one copy for each replicate in COMO_input", - # error=ValueError, - # level=LogLevel.ERROR, - # ) - # - # mean_fragment_size = 100 - # if len(frag_files) == 0 and prep != RNAType.TRNA.value: - # logger.warning( - # f"No fragment file found for {label}, using '100'. You should define this if you are going to use downstream zFPKM normalization" - # ) - # elif len(frag_files) == 1: - # if layout == "single-end": - # mean_fragment_size = 0 - # else: - # if not multi_flag: - # frag_df = pd.read_table(frag_files[0], low_memory=False) - # frag_df["meanxcount"] = frag_df["frag_mean"] * frag_df["frag_count"] - # mean_fragment_size = sum(frag_df["meanxcount"] / sum(frag_df["frag_count"])) - # - # else: - # mean_fragment_sizes = np.array([]) - # library_sizes = np.array([]) - # for ff in frag_files: - # frag_df = pd.read_table(ff, low_memory=False, sep="\t", on_bad_lines="skip") - # frag_df["meanxcount"] = frag_df["frag_mean"] * frag_df["frag_count"] - # mean_fragment_size = sum(frag_df["meanxcount"] / sum(frag_df["frag_count"])) - # mean_fragment_sizes = np.append(mean_fragment_sizes, mean_fragment_size) - # library_sizes = np.append(library_sizes, sum(frag_df["frag_count"])) - # - # mean_fragment_size = sum(mean_fragment_sizes * library_sizes) / sum(library_sizes) - # elif len(frag_files) > 1: - # _log_and_raise_error( - # f"Multiple matching fragment files for {label}, make sure there is only one copy for each replicate in COMO_input", - # error=ValueError, - # level=LogLevel.ERROR, - # ) - # - # sample_names.append(f"{context_name}_{study_number}{rep_number}") - # fragment_lengths.append(mean_fragment_size) - # layouts.append(layout) - # strands.append(strand) - # groups.append(study_number) - # preparation_method.append(prep) - # - # out_df = pd.DataFrame( - # { - # "sample_name": sample_names, - # "fragment_length": fragment_lengths, - # "layout": layouts, - # "strand": strands, - # "study": groups, - # "library_prep": preparation_method, - # } - # ).sort_values("sample_name") - # return out_df - - -async def _create_gene_info_file( + +async def _create_gene_info_file( # noqa: C901 *, counts_matrix_filepaths: list[Path], output_filepath: Path, @@ -670,14 +527,13 @@ async def _create_gene_info_file( """ async def read_ensembl_gene_ids(file: Path) -> list[str]: - data = _read_file(file, h5ad_as_df=False) - if isinstance(data, pd.DataFrame): - data: pd.DataFrame - return data["ensembl_gene_id"].tolist() + data_ = 
read_file(file, h5ad_as_df=False) +        if isinstance(data_, pd.DataFrame): +            return data_["ensembl_gene_id"].tolist() try: -            conversion = await gene_symbol_to_ensembl_and_gene_id(symbols=data.var_names.tolist(), taxon=taxon) +            conversion = await gene_symbol_to_ensembl_and_gene_id(symbols=data_.var_names.tolist(), taxon=taxon) except json.JSONDecodeError as e: -            _log_and_raise_error( +            log_and_raise_error( f"Got a JSON decode error for file '{counts_matrix_filepaths}' ({e})", error=ValueError, level=LogLevel.CRITICAL, @@ -688,7 +544,8 @@ async def read_ensembl_gene_ids(file: Path) -> list[str]: return conversion["ensembl_gene_id"].tolist() logger.info( -        "Fetching gene info - this can take up to 5 minutes depending on the number of genes and your internet connection" +        "Fetching gene info - this can take up to 5 minutes " +        "depending on the number of genes and your internet connection" ) ensembl_ids: set[str] = set( @@ -715,18 +572,47 @@ def _avg_pos(value: int | list[int] | None) -> int: for i, data in enumerate(gene_data): data: dict[str, str | int | list[str] | list[int] | None] +        if "genomic_pos.start" not in data: +            log_and_raise_error( +                message="Unexpectedly missing key 'genomic_pos.start'", error=KeyError, level=LogLevel.WARNING +            ) +        if "genomic_pos.end" not in data: +            log_and_raise_error( +                message="Unexpectedly missing key 'genomic_pos.end'", error=KeyError, level=LogLevel.WARNING +            ) +        if "ensembl.gene" not in data: +            log_and_raise_error( +                message="Unexpectedly missing key 'ensembl.gene'", error=KeyError, level=LogLevel.WARNING +            ) -        start = _avg_pos(data.get("genomic_pos.start", 0)) -        end = _avg_pos(data.get("genomic_pos.end", 0)) -        size = end - start +        start = data["genomic_pos.start"] +        end = data["genomic_pos.end"] +        ensembl_id = data["ensembl.gene"] -        ensembl_id: int = data.get("ensembl.gene", "-") -        all_ensembl_ids[i] = ( -            ",".join(map(str, ensembl_id)) if isinstance(ensembl_id, list) and ensembl_id else ensembl_id -        ) +        if not isinstance(start, int): +            log_and_raise_error( +                message=f"Unexpected type for 'genomic_pos.start': expected int, got {type(start)}", +                error=TypeError, +                level=LogLevel.WARNING, +            ) +        if not isinstance(end, int): +            log_and_raise_error( +                message=f"Unexpected type for 'genomic_pos.end': expected int, got {type(end)}", +                error=TypeError, +                level=LogLevel.WARNING, +            ) +        if not isinstance(ensembl_id, str): +            log_and_raise_error( +                message=f"Unexpected type for 'ensembl.gene': expected str, got {type(ensembl_id)}", +                error=ValueError, +                level=LogLevel.WARNING, +            ) + +        size = end - start +        all_ensembl_ids[i] = ",".join(map(str, ensembl_id)) if isinstance(ensembl_id, list) else ensembl_id all_gene_symbols[i] = str(data.get("symbol", "-")) all_entrez_ids[i] = str(data.get("entrezgene", "-")) -        all_sizes[i] = size if size > 0 else -1 +        all_sizes[i] = max(size, -1)  # clamp negative sizes to -1 gene_info: pd.DataFrame = pd.DataFrame( { @@ -791,7 +677,7 @@ async def _process( create_gene_info_only: bool, ): rna_types: list[tuple[RNAType, Path, Path, Path]] = [] -    if output_trna_config_filepath is not None and output_trna_fragment_lengths_filepath is not None: +    if output_trna_config_filepath and output_trna_matrix_filepath and output_trna_fragment_lengths_filepath: rna_types.append( ( RNAType.TRNA, @@ -800,7 +686,7 @@ output_trna_fragment_lengths_filepath, ) ) -    if output_mrna_config_filepath is not None and output_mrna_fragment_lengths_filepath is not None: +    if output_mrna_config_filepath and output_mrna_matrix_filepath and 
output_mrna_fragment_lengths_filepath: rna_types.append( ( RNAType.MRNA, @@ -813,13 +699,13 @@ async def _process( # if provided, iterate through como-input specific directories if not create_gene_info_only: if como_context_dir is None: - _log_and_raise_error( + log_and_raise_error( message="como_context_dir must be provided if create_gene_info_only is False", error=ValueError, level=LogLevel.ERROR, ) if output_trna_fragment_lengths_filepath is None: - _log_and_raise_error( + log_and_raise_error( message="output_fragment_lengths_filepath must be provided if create_gene_info_only is False", error=ValueError, level=LogLevel.ERROR, @@ -852,21 +738,21 @@ async def _process( ) -async def rnaseq_preprocess( +async def rnaseq_preprocess( # noqa: C901 context_name: str, taxon: int, - output_gene_info_filepath: Path, - como_context_dir: Path | None = None, - input_matrix_filepath: Path | list[Path] | None = None, - output_trna_fragment_lengths_filepath: Path | None = None, - output_mrna_fragment_lengths_filepath: Path | None = None, - output_trna_metadata_filepath: Path | None = None, - output_mrna_metadata_filepath: Path | None = None, - output_trna_count_matrix_filepath: Path | None = None, - output_mrna_count_matrix_filepath: Path | None = None, + output_gene_info_filepath: str | Path, + como_context_dir: str | Path | None = None, + input_matrix_filepath: str | Path | list[str] | list[Path] | list[str | Path] | None = None, + output_trna_fragment_lengths_filepath: str | Path | None = None, + output_mrna_fragment_lengths_filepath: str | Path | None = None, + output_trna_metadata_filepath: str | Path | None = None, + output_mrna_metadata_filepath: str | Path | None = None, + output_trna_count_matrix_filepath: str | Path | None = None, + output_mrna_count_matrix_filepath: str | Path | None = None, cache: bool = True, log_level: LogLevel | str = LogLevel.INFO, - log_location: str | io.TextIOWrapper = sys.stderr, + log_location: str | TextIO = sys.stderr, *, create_gene_info_only: bool = False, ) -> None: @@ -878,8 +764,8 @@ async def rnaseq_preprocess( :param context_name: The context/cell type being processed :param taxon: The NCBI taxonomy ID :param output_gene_info_filepath: Path to the output gene information CSV file - :param output_trna_fragment_lengths_filepath: Path to the output tRNA fragment lengths CSV file (if in "create" mode) - :param output_mrna_fragment_lengths_filepath: Path to the output mRNA fragment lengths CSV file (if in "create" mode) + :param output_trna_fragment_lengths_filepath: Path to the output tRNA fragment lengths CSV file + :param output_mrna_fragment_lengths_filepath: Path to the output mRNA fragment lengths CSV file :param output_trna_metadata_filepath: Path to the output tRNA config file (if in "create" mode) :param output_mrna_metadata_filepath: Path to the output mRNA config file (if in "create" mode) :param output_trna_count_matrix_filepath: The path to write total RNA count matrices @@ -892,27 +778,47 @@ async def rnaseq_preprocess( :param log_location: The logging location :param create_gene_info_only: If True, only create the gene info file and skip general preprocessing steps """ - _set_up_logging(level=log_level, location=log_location) + set_up_logging(level=log_level, location=log_location) - output_gene_info_filepath = output_gene_info_filepath.resolve() + # ruff: disable[ASYNC240] + if not output_gene_info_filepath: + log_and_raise_error( + message="output_gene_info_filepath must be provided", + error=ValueError, + level=LogLevel.ERROR, + ) - if 
como_context_dir: - como_context_dir = como_context_dir.resolve() - input_matrix_filepath = [i.resolve() for i in _listify(input_matrix_filepath)] if input_matrix_filepath else None - output_trna_metadata_filepath = output_trna_metadata_filepath.resolve() if output_trna_metadata_filepath else None - output_mrna_metadata_filepath = output_mrna_metadata_filepath.resolve() if output_mrna_metadata_filepath else None - output_trna_count_matrix_filepath = ( - output_trna_count_matrix_filepath.resolve() if output_trna_count_matrix_filepath else None - ) - output_mrna_count_matrix_filepath = ( - output_mrna_count_matrix_filepath.resolve() if output_mrna_count_matrix_filepath else None - ) + output_gene_info_filepath = Path(output_gene_info_filepath).resolve() + + context_dir = None + if como_context_dir is not None: + context_dir = Path(como_context_dir).resolve() + + in_matrix = None + if isinstance(input_matrix_filepath, list): + in_matrix = [Path(i).resolve() for i in input_matrix_filepath] + else: + if isinstance(input_matrix_filepath, (str, Path)): + in_matrix = [Path(input_matrix_filepath).resolve()] + + if output_trna_metadata_filepath is not None: + output_trna_metadata_filepath = Path(output_trna_metadata_filepath).resolve() + if output_mrna_metadata_filepath is not None: + output_mrna_metadata_filepath = Path(output_mrna_metadata_filepath).resolve() + if output_trna_count_matrix_filepath is not None: + output_trna_count_matrix_filepath = Path(output_trna_count_matrix_filepath).resolve() + if output_mrna_count_matrix_filepath is not None: + output_mrna_count_matrix_filepath = Path(output_mrna_count_matrix_filepath).resolve() + if output_trna_fragment_lengths_filepath is not None: + output_trna_fragment_lengths_filepath = Path(output_trna_fragment_lengths_filepath).resolve() + if output_mrna_fragment_lengths_filepath is not None: + output_mrna_fragment_lengths_filepath = Path(output_mrna_fragment_lengths_filepath).resolve() await _process( context_name=context_name, taxon=taxon, - como_context_dir=como_context_dir, - input_matrix_filepath=input_matrix_filepath, + como_context_dir=context_dir, + input_matrix_filepath=in_matrix, output_gene_info_filepath=output_gene_info_filepath, output_trna_config_filepath=output_trna_metadata_filepath, output_trna_matrix_filepath=output_trna_count_matrix_filepath, @@ -923,3 +829,4 @@ async def rnaseq_preprocess( cache=cache, create_gene_info_only=create_gene_info_only, ) + # ruff: enable[ASYNC240] diff --git a/main/como/utils.py b/main/como/utils.py index 262183e3..38db7d72 100644 --- a/main/como/utils.py +++ b/main/como/utils.py @@ -2,13 +2,11 @@ import contextlib import io -import itertools import sys from collections.abc import Iterator from pathlib import Path -from typing import Literal, NoReturn, TextIO, TypeVar, cast, overload +from typing import Any, Literal, NoReturn, TextIO, TypeVar, overload -import aiofiles import numpy.typing as npt import pandas as pd import scanpy as sc @@ -24,7 +22,14 @@ from como.data_types import LOG_FORMAT, Algorithm, LogLevel T = TypeVar("T") -__all__ = ["split_gene_expression_data", "stringlist_to_list", "suppress_stdout"] +__all__ = [ + "log_and_raise_error", + "read_file", + "set_up_logging", + "split_gene_expression_data", + "stringlist_to_list", + "suppress_stdout", +] def stringlist_to_list(stringlist: str | list[str]) -> list[str]: @@ -52,14 +57,20 @@ def stringlist_to_list(stringlist: str | list[str]) -> list[str]: new_list: list[str] = stringlist.strip("[]").replace("'", "").replace(" ", "").split(",") # Show 
a warning if more than one item is present in the list (this means we are using the old method) - logger.critical("DeprecationWarning: Please use the new method of providing context names, i.e. --output-filetypes 'type1 type2 type3'.") + logger.critical( + "DeprecationWarning: Please use the new method of providing context names, " + "i.e. --output-filetypes 'type1 type2 type3'." + ) logger.critical( "If you are using COMO, this can be done by setting the 'context_names' variable to a " "simple string separated by spaces. Here are a few examples!" ) logger.critical("context_names = 'cellType1 cellType2 cellType3'") logger.critical("output_filetypes = 'output1 output2 output3'") - logger.critical("\nYour current method of passing context names will be removed in the future. Update your variables above accordingly!\n\n") + logger.critical( + "\nYour current method of passing context names will be removed in the future. " + "Update your variables above accordingly!\n\n" + ) return new_list @@ -70,7 +81,7 @@ def split_gene_expression_data( recon_algorithm: Algorithm | None = None, *, ensembl_as_index: bool = True, -): +) -> pd.DataFrame: """Split the gene expression data into single-gene and multiple-gene names. Arg: @@ -83,15 +94,15 @@ def split_gene_expression_data( A pandas DataFrame with the split gene expression data """ expression_data.columns = [c.lower() for c in expression_data.columns] - if recon_algorithm in {Algorithm.IMAT, Algorithm.TINIT}: + if "combine_z" in expression_data.columns: expression_data.rename(columns={"combine_z": "active"}, inplace=True) - expression_data = cast(typ=pd.DataFrame, val=expression_data[[identifier_column, "active"]]) + expression_data = expression_data[[identifier_column, "active"]] single_gene_names = expression_data[~expression_data[identifier_column].astype(str).str.contains("//")] multiple_gene_names = expression_data[expression_data[identifier_column].astype(str).str.contains("//")] - split_gene_names = multiple_gene_names.assign(ensembl_gene_id=multiple_gene_names[identifier_column].astype(str).str.split("///")).explode( - identifier_column - ) + split_gene_names = multiple_gene_names.assign( + ensembl_gene_id=multiple_gene_names[identifier_column].astype(str).str.split("///") + ).explode(identifier_column) gene_expressions = pd.concat([single_gene_names, split_gene_names], axis=0, ignore_index=True) if ensembl_as_index: @@ -129,13 +140,19 @@ async def _format_determination( """ requested_output = [requested_output] if isinstance(requested_output, Output) else requested_output - coercion = (await biodbnet.db_find(values=input_values, output_db=requested_output, taxon=taxon)).drop(columns=["Input Type"]) + coercion = (await biodbnet.db_find(values=input_values, output_db=requested_output, taxon=taxon)).drop( + columns=["Input Type"] + ) coercion.columns = pd.Index(["input_value", *[o.value.replace(" ", "_").lower() for o in requested_output]]) return coercion -async def get_missing_gene_data(values: list[str] | pd.DataFrame, taxon_id: int | str | Taxon) -> pd.DataFrame: - if isinstance(values, list) and not isinstance(values, pd.DataFrame): # second isinstance required for static type check to be happy +async def get_missing_gene_data( # noqa: C901 + values: list[str] | pd.DataFrame | sc.AnnData, taxon_id: int | str | Taxon +) -> pd.DataFrame: + if isinstance(values, list) and not isinstance( + values, pd.DataFrame + ): # second isinstance required for static type check to be happy gene_type = await determine_gene_type(values) if all(v == 
"gene_symbol" for v in gene_type.values()): return await gene_symbol_to_ensembl_and_gene_id(values, taxon=taxon_id) @@ -144,7 +161,7 @@ async def get_missing_gene_data(values: list[str] | pd.DataFrame, taxon_id: int elif all(v == "entrez_gene_id" for v in gene_type.values()): return await gene_id_to_ensembl_and_gene_symbol(ids=values, taxon=taxon_id) else: - _log_and_raise_error( + log_and_raise_error( message="Gene data must be of the same type (i.e., all Ensembl, Entrez, or Gene Symbols)", error=ValueError, level=LogLevel.CRITICAL, @@ -153,34 +170,41 @@ async def get_missing_gene_data(values: list[str] | pd.DataFrame, taxon_id: int # raise error if duplicate column names exist if any(values.columns.duplicated(keep=False)): duplicate_cols = values.columns[values.columns.duplicated(keep=False)].unique().tolist() - _log_and_raise_error( - message=f"Duplicate column names exist! This will result in an error processing data. Duplicates: {','.join(duplicate_cols)}", + log_and_raise_error( + message=( + f"Duplicate column names exist! This will result in an error processing data. " + f"Duplicates: {','.join(duplicate_cols)}" + ), error=ValueError, level=LogLevel.CRITICAL, ) - if "gene_symbol" in itertools.chain(values.columns, [values.index.name]): + + names: list[str] = values.columns.tolist() + if values.index.name is not None: + names.append(str(values.index.name)) + if "gene_symbol" in names: return await get_missing_gene_data( values["gene_symbol"].tolist() if "gene_symbol" in values.columns else values.index.tolist(), taxon_id=taxon_id, ) - elif "entrez_gene_id" in itertools.chain(values.columns, [values.index.name]): + elif "entrez_gene_id" in names: return await get_missing_gene_data( values["entrez_gene_id"].tolist() if "entrez_gene_id" in values.columns else values.index.tolist(), taxon_id=taxon_id, ) - elif "ensembl_gene_id" in itertools.chain(values.columns, [values.index.name]): + elif "ensembl_gene_id" in names: return await get_missing_gene_data( values["ensembl_gene_id"].tolist() if "ensembl_gene_id" in values.columns else values.index.tolist(), taxon_id=taxon_id, ) else: - _log_and_raise_error( + log_and_raise_error( message="Unable to find 'gene_symbol', 'entrez_gene_id', or 'ensembl_gene_id' in the input matrix.", error=ValueError, level=LogLevel.CRITICAL, ) else: - _log_and_raise_error( + log_and_raise_error( message=f"Values must be a list of strings or a pandas DataFrame, got: {type(values)}", error=TypeError, level=LogLevel.CRITICAL, @@ -188,34 +212,34 @@ async def get_missing_gene_data(values: list[str] | pd.DataFrame, taxon_id: int @overload -def _read_file(path: None, h5ad_as_df: bool = True, **kwargs: Any) -> None: ... +def read_file(path: None, h5ad_as_df: bool = True, **kwargs: Any) -> None: ... @overload -def _read_file(path: pd.DataFrame, h5ad_as_df: bool = True, **kwargs: Any) -> pd.DataFrame: ... +def read_file(path: pd.DataFrame, h5ad_as_df: bool = True, **kwargs: Any) -> pd.DataFrame: ... @overload -def _read_file(path: io.StringIO, h5ad_as_df: bool = True, **kwargs: Any) -> pd.DataFrame: ... +def read_file(path: io.StringIO, h5ad_as_df: bool = True, **kwargs: Any) -> pd.DataFrame: ... @overload -def _read_file(path: sc.AnnData, h5ad_as_df: Literal[False], **kwargs: Any) -> sc.AnnData: ... +def read_file(path: sc.AnnData, h5ad_as_df: Literal[False], **kwargs: Any) -> sc.AnnData: ... @overload -def _read_file(path: sc.AnnData, h5ad_as_df: Literal[True] = True, **kwargs: Any) -> pd.DataFrame: ... 
+def read_file(path: sc.AnnData, h5ad_as_df: Literal[True] = True, **kwargs: Any) -> pd.DataFrame: ... @overload -def _read_file(path: Path, h5ad_as_df: Literal[False], **kwargs: Any) -> pd.DataFrame | sc.AnnData: ... +def read_file(path: Path, h5ad_as_df: Literal[False], **kwargs: Any) -> pd.DataFrame | sc.AnnData: ... @overload -def _read_file(path: Path, h5ad_as_df: Literal[True] = True, **kwargs: Any) -> pd.DataFrame: ... +def read_file(path: Path, h5ad_as_df: Literal[True] = True, **kwargs: Any) -> pd.DataFrame: ... -def _read_file( +def read_file( # noqa: C901 path: Path | io.StringIO | pd.DataFrame | sc.AnnData | None, h5ad_as_df: bool = True, **kwargs: Any, @@ -228,8 +252,8 @@ def _read_file( Args: path: The path to read from - h5ad_as_df: If True and the file is an h5ad, return a pandas DataFrame of the .X matrix instead of an AnnData object - kwargs: Additional arguments to pass to pandas.read_csv, pandas.read_excel, or scanpy.read_h5ad, depending on the filepath provided + h5ad_as_df: If True and the file is an h5ad, return a pandas DataFrame of the .X matrix + kwargs: Additional arguments to pass to pandas.read_csv, pandas.read_excel, or scanpy.read_h5ad Returns: None, or a pandas DataFrame or AnnData @@ -245,7 +269,7 @@ def _read_file( return None if isinstance(path, Path) and not path.exists(): - _log_and_raise_error(f"File {path} does not exist", error=FileNotFoundError, level=LogLevel.CRITICAL) + log_and_raise_error(f"File {path} does not exist", error=FileNotFoundError, level=LogLevel.CRITICAL) match path.suffix: case ".csv" | ".tsv" | ".txt" | ".tab" | ".sf": @@ -263,8 +287,9 @@ def _read_file( return df return adata case _: - _log_and_raise_error( - f"Unknown file extension '{path.suffix}'. Valid options are '.tsv', '.csv', '.xlsx', '.xls', or '.h5ad'", + log_and_raise_error( + f"Unknown file extension '{path.suffix}'. " + "Valid options are '.tsv', '.csv', '.xlsx', '.xls', or '.h5ad'", error=ValueError, level=LogLevel.CRITICAL, ) @@ -286,11 +311,8 @@ def _listify(value: T | list[T]) -> list[T]: Returns: A list of the provided value - """ - if isinstance(value, list): - return cast(list[T], value) # does not actually do anything; signifies to type checker that return value is of type list[T] - return [value] + return value if isinstance(value, list) else [value] def _num_rows(item: pd.DataFrame | npt.NDArray) -> int: @@ -305,11 +327,17 @@ def return_placeholder_data() -> pd.DataFrame: return pd.DataFrame(data=0, index=pd.Index(data=[0], name="entrez_gene_id"), columns=["expressed", "top"]) -def _set_up_logging( +def set_up_logging( level: LogLevel | str, location: str | TextIO, formatting: str = LOG_FORMAT, ): + """Set up logging for the application. + + :param level: The default logging level to use (e.g., LogLevel.INFO, LogLevel.DEBUG, etc.) + :param location: The location to log to (e.g., a file path or sys.stdout) + :param formatting: The log message format to use (default is LOG_FORMAT) + """ if isinstance(level, str): level = LogLevel[level.upper()] with contextlib.suppress(ValueError): @@ -317,19 +345,24 @@ def _set_up_logging( logger.add(sink=location, level=level.value, format=formatting) -def _log_and_raise_error( +def log_and_raise_error( message: str, *, error: type[BaseException], level: LogLevel, ) -> NoReturn: + """Log an error message and raise an exception. + + :param message: The error message to log and include in the raised exception + :param error: The type of exception to raise (e.g., ValueError, File NotFoundError, etc.) 
+ :param level: The LogLevel at which to log the error message (e.g., LogLevel.ERROR, LogLevel.CRITICAL) + """ caller = logger.opt(depth=1) - match level: - case LogLevel.ERROR: - caller.error(message) - raise error(message) - case LogLevel.CRITICAL: - caller.critical(message) - raise error(message) - case _: - raise ValueError(f"When raising an error, LogLevel.ERROR or LogLevel.CRITICAL must be used. Got: {level}") + if level == LogLevel.ERROR: + caller.error(message) + raise error(message) + if level == LogLevel.CRITICAL: + caller.critical(message) + raise error(message) + + raise ValueError(f"When raising an error, LogLevel.ERROR or LogLevel.CRITICAL must be used. Got: {level}")
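A minimal usage sketch of the helpers this patch makes public in como.utils (set_up_logging, read_file, log_and_raise_error). The counts.csv path is a placeholder rather than a file shipped with COMO, and the snippet is illustrative only, not part of the patch:

import sys
from pathlib import Path

from como.data_types import LogLevel
from como.utils import log_and_raise_error, read_file, set_up_logging

set_up_logging(level=LogLevel.INFO, location=sys.stderr)

# read_file dispatches on the file suffix (.csv/.tsv/.txt/.tab/.sf, .xlsx/.xls, .h5ad)
# and returns a pandas DataFrame by default; "counts.csv" is a placeholder path.
counts = read_file(Path("counts.csv"))

if counts.empty:
    # log_and_raise_error logs the message and then raises the given exception type.
    # Only LogLevel.ERROR and LogLevel.CRITICAL are accepted here; any other level
    # makes the helper itself raise a ValueError instead of the requested error.
    log_and_raise_error("Counts matrix contained no rows", error=ValueError, level=LogLevel.ERROR)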
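Similarly, a sketch of calling the reworked rnaseq_preprocess entry point with plain strings, which the patch now coerces to resolved Path objects internally. The context name, taxon, and every path below are placeholders, and the module path como.rnaseq_preprocess is an assumption based on the file names referenced in the diff:

import asyncio

from como.rnaseq_preprocess import rnaseq_preprocess  # assumed module path

asyncio.run(
    rnaseq_preprocess(
        context_name="myContext",
        taxon=9606,  # placeholder NCBI taxonomy ID
        output_gene_info_filepath="results/myContext/gene_info.csv",  # str is coerced to a resolved Path
        input_matrix_filepath="data/myContext_counts.csv",  # str, Path, or a list of either
        create_gene_info_only=True,  # skip the COMO_input walk and only build the gene info file
    )
)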