diff --git a/src/petab_gui/controllers/mother_controller.py b/src/petab_gui/controllers/mother_controller.py index 1e97de9..04b98e5 100644 --- a/src/petab_gui/controllers/mother_controller.py +++ b/src/petab_gui/controllers/mother_controller.py @@ -151,6 +151,9 @@ def __init__(self, view, model: PEtabModel): } self.sbml_checkbox_states = {"sbml": False, "antimony": False} self.unsaved_changes = False + # Selection synchronization flags to prevent redundant updates + self._updating_from_plot = False + self._updating_from_table = False # Next Steps Panel self.next_steps_panel = NextStepsPanel(self.view) self.next_steps_panel.dont_show_again_changed.connect( @@ -1411,8 +1414,25 @@ def init_plotter(self): self.plotter = self.view.plot_dock self.plotter.highlighter.click_callback = self._on_plot_point_clicked + def _floats_match(self, a, b, epsilon=1e-9): + """Check if two floats match within epsilon tolerance.""" + return abs(a - b) < epsilon + def _on_plot_point_clicked(self, x, y, label, data_type): - # Extract observable ID from label, if formatted like 'obsId (label)' + """Handle plot point clicks and select corresponding table row. + + Uses epsilon tolerance for floating-point comparison to avoid + precision issues. + """ + # Check for None label + if label is None: + self.logger.log_message( + "Cannot select table row: plot point has no label.", + color="orange", + ) + return + + # Extract observable ID from label proxy = self.measurement_controller.proxy_model view = self.measurement_controller.view.table_view if data_type == "simulation": @@ -1424,16 +1444,26 @@ def _on_plot_point_clicked(self, x, y, label, data_type): y_axis_col = data_type observable_col = "observableId" + # Get column indices with error handling def column_index(name): for col in range(proxy.columnCount()): if proxy.headerData(col, Qt.Horizontal) == name: return col raise ValueError(f"Column '{name}' not found.") - x_col = column_index(x_axis_col) - y_col = column_index(y_axis_col) - obs_col = column_index(observable_col) + try: + x_col = column_index(x_axis_col) + y_col = column_index(y_axis_col) + obs_col = column_index(observable_col) + except ValueError as e: + self.logger.log_message( + f"Table selection failed: {e}", + color="red", + ) + return + # Search for matching row using epsilon tolerance for floats + matched = False for row in range(proxy.rowCount()): row_obs = proxy.index(row, obs_col).data() row_x = proxy.index(row, x_col).data() @@ -1442,23 +1472,80 @@ def column_index(name): row_x, row_y = float(row_x), float(row_y) except ValueError: continue - if row_obs == obs and row_x == x and row_y == y: - view.selectRow(row) + + # Use epsilon tolerance for float comparison + if ( + row_obs == obs + and self._floats_match(row_x, x) + and self._floats_match(row_y, y) + ): + # Manually update highlight BEFORE selecting row + # This ensures the circle appears even though we skip the signal handler + if data_type == "measurement": + self.plotter.highlight_from_selection([row]) + else: + self.plotter.highlight_from_selection( + [row], + proxy=self.simulation_controller.proxy_model, + y_axis_col="simulation", + ) + + # Set flag to prevent redundant highlight update from signal + self._updating_from_plot = True + try: + view.selectRow(row) + matched = True + finally: + self._updating_from_plot = False break + # Provide feedback if no match found + if not matched: + self.logger.log_message( + f"No matching row found for plot point (obs={obs}, x={x:.4g}, y={y:.4g})", + color="orange", + ) + + def _handle_table_selection_changed( + self, table_view, proxy=None, y_axis_col="measurement" + ): + """Common handler for table selection changes. + + Skips update if selection was triggered by plot click to prevent + redundant highlight updates. + + Args: + table_view: The table view with selection to highlight + proxy: Optional proxy model for simulation data + y_axis_col: Column name for y-axis data (default: "measurement") + """ + # Skip if selection was triggered by plot point click + if self._updating_from_plot: + return + + # Set flag to prevent infinite loop if highlight triggers selection + self._updating_from_table = True + try: + selected_rows = get_selected(table_view) + if proxy: + self.plotter.highlight_from_selection( + selected_rows, proxy=proxy, y_axis_col=y_axis_col + ) + else: + self.plotter.highlight_from_selection(selected_rows) + finally: + self._updating_from_table = False + def _on_table_selection_changed(self, selected, deselected): """Highlight the cells selected in measurement table.""" - selected_rows = get_selected( + self._handle_table_selection_changed( self.measurement_controller.view.table_view ) - self.plotter.highlight_from_selection(selected_rows) def _on_simulation_selection_changed(self, selected, deselected): - selected_rows = get_selected( - self.simulation_controller.view.table_view - ) - self.plotter.highlight_from_selection( - selected_rows, + """Highlight the cells selected in simulation table.""" + self._handle_table_selection_changed( + self.simulation_controller.view.table_view, proxy=self.simulation_controller.proxy_model, y_axis_col="simulation", ) diff --git a/src/petab_gui/views/simple_plot_view.py b/src/petab_gui/views/simple_plot_view.py index 1e41a27..28a23d9 100644 --- a/src/petab_gui/views/simple_plot_view.py +++ b/src/petab_gui/views/simple_plot_view.py @@ -204,8 +204,25 @@ def _update_tabs(self, fig: plt.Figure): self.tab_widget.addTab(tab, "All Plots") return - # Full figure tab - create_plot_tab(fig, self, plot_title="All Plots") + # Full figure tab - capture canvas and connect picking for all axes + main_canvas = create_plot_tab(fig, self, plot_title="All Plots") + + # Enable picker on all lines and containers in the original figure + for ax in fig.axes: + # Handle regular lines (simulations, etc.) + for line in ax.get_lines(): + line.set_picker(True) + line.set_pickradius(5) # 5 pixels tolerance for clicking + + # Handle error bar containers (measurements, etc.) + for container in ax.containers: + if isinstance(container, ErrorbarContainer) and ( + len(container.lines) > 0 and container.lines[0] is not None + ): + container.lines[0].set_picker(True) + container.lines[0].set_pickradius(5) + + self.highlighter.connect_picking(main_canvas) # One tab per Axes for idx, ax in enumerate(fig.axes): @@ -219,7 +236,7 @@ def _update_tabs(self, fig: plt.Figure): line = handle else: continue - sub_ax.plot( + new_line = sub_ax.plot( line.get_xdata(), line.get_ydata(), label=label, @@ -228,7 +245,8 @@ def _update_tabs(self, fig: plt.Figure): color=line.get_color(), alpha=line.get_alpha(), picker=True, - ) + )[0] + new_line.set_pickradius(5) # 5 pixels tolerance for clicking sub_ax.set_title(ax.get_title()) sub_ax.set_xlabel(ax.get_xlabel()) sub_ax.set_ylabel(ax.get_ylabel()) @@ -241,15 +259,34 @@ def _update_tabs(self, fig: plt.Figure): plot_title=f"Subplot {idx + 1}", ) - if ax.get_title(): - obs_id = ax.get_title() - elif ax.get_legend_handles_labels()[1]: - obs_id = ax.get_legend_handles_labels()[1][0] - obs_id = obs_id.split(" ")[-1] + # Map subplot to observable IDs + # When grouped by condition/dataset, one subplot can have multiple observables + # Extract all observable IDs from legend labels + subplot_title = ( + ax.get_title() if ax.get_title() else f"subplot_{idx}" + ) + _, legend_labels = ax.get_legend_handles_labels() + + if legend_labels: + # Extract observable ID from each legend label + for legend_label in legend_labels: + label_parts = legend_label.split() + if len(label_parts) == 0: + continue + # Extract observable ID (last part before "simulation" if present) + if label_parts[-1] == "simulation": + obs_id = ( + label_parts[-2] + if len(label_parts) >= 2 + else label_parts[0] + ) + else: + obs_id = label_parts[-1] + # Map this observable to this subplot index + self.observable_to_subplot[obs_id] = idx else: - obs_id = f"subplot_{idx}" - - self.observable_to_subplot[obs_id] = idx + # No legend, use title as fallback + self.observable_to_subplot[subplot_title] = idx self.highlighter.register_subplot(ax, idx) # Register subplot canvas self.highlighter.register_subplot(sub_ax, idx) @@ -393,17 +430,34 @@ def _on_pick(self, event): # Try to recover the label from the legend (handle → label mapping) handles, labels = ax.get_legend_handles_labels() label = None + data_type = "measurement" # Default to measurement + for h, l in zip(handles, labels, strict=False): if h is artist: + # Extract observable ID and data type from legend label + # Format can be: "observableId", "datasetId observableId", or "datasetId observableId simulation" label_parts = l.split() + if len(label_parts) == 0: + continue + if label_parts[-1] == "simulation": data_type = "simulation" - label = label_parts[-2] + # Label is second-to-last: "cond obs simulation" -> "obs" + label = ( + label_parts[-2] + if len(label_parts) >= 2 + else label_parts[0] + ) else: data_type = "measurement" + # Label is last: "dataset obs" -> "obs" or just "obs" -> "obs" label = label_parts[-1] break + # If no label found, skip this click + if label is None: + return + for i in ind: x = xdata[i] y = ydata[i]