"""
Matplotlib-based visualization renderer for HBAT.
This module implements the VisualizationRenderer protocol using NetworkX
and matplotlib for backward compatibility with existing visualizations.
"""
import itertools as it
import logging
import math
import tkinter as tk
from typing import Any, Dict, List, Optional
import networkx as nx
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
from matplotlib.patches import Ellipse
from hbat.core.app_config import HBATConfig
from hbat.gui.visualization_renderer import BaseVisualizationRenderer
# Module-level logger shared by the renderer.
logger = logging.getLogger(__name__)

# Availability probe consulted by MatplotlibRenderer.is_available().
# NOTE(review): matplotlib is also imported unconditionally at the top of
# this module (FigureCanvasTkAgg, Figure, Ellipse), so a missing matplotlib
# already fails at import time and this flag cannot be False once the module
# has loaded -- confirm whether the probe is still needed.
try:
    import matplotlib.pyplot as plt  # noqa: F401 - imported only as a probe
except ImportError:
    MATPLOTLIB_AVAILABLE = False
else:
    MATPLOTLIB_AVAILABLE = True
class MatplotlibRenderer(BaseVisualizationRenderer):
    """Matplotlib-based visualization renderer.

    Provides NetworkX/matplotlib rendering with existing functionality
    and styling, refactored to use the VisualizationRenderer interface.
    """

    def __init__(self, parent_widget: tk.Widget, config: HBATConfig) -> None:
        """Initialize matplotlib renderer.

        Creates the matplotlib figure/axes and, when a parent widget is
        available, the Tk canvas that embeds them.

        :param parent_widget: Parent tkinter widget
        :type parent_widget: tk.Widget
        :param config: HBAT configuration instance
        :type config: HBATConfig
        """
        super().__init__(parent_widget, config)
        self.fig: Optional[Figure] = None
        self.ax: Optional[Any] = None
        self.canvas: Optional[FigureCanvasTkAgg] = None
        # Create matplotlib figure and canvas
        self._create_figure()

    def render(self, graph: nx.Graph, layout_type: str) -> None:
        """Render the graph using matplotlib.

        Prepares node/edge data via the base class, remembers the layout,
        redraws the axes and refreshes the Tk canvas if one exists.

        :param graph: NetworkX graph to render
        :type graph: nx.Graph
        :param layout_type: Layout algorithm name
        :type layout_type: str
        :raises Exception: Any error raised while preparing or drawing is
            logged and then re-raised.
        """
        if not self.is_available():
            logger.error("Matplotlib renderer is not available")
            return
        try:
            self.prepare_graph_data(graph)
            self.current_layout = layout_type
            self._draw_graph(layout_type)
            self._refresh_canvas()
            logger.debug(f"Successfully rendered graph with {layout_type} layout")
        except Exception as e:
            logger.error(f"Failed to render graph with matplotlib: {e}")
            raise

    def export(self, format: str, filename: str) -> bool:
        """Export visualization to file.

        :param format: Export format (png, svg, pdf); matched case-insensitively
            against ``get_supported_formats()``
        :type format: str
        :param filename: Output filename
        :type filename: str
        :returns: True if export successful, False if there is no figure, the
            format is unsupported, or saving fails
        :rtype: bool
        """
        if self.fig is None:
            logger.error("No figure to export")
            return False
        fmt = format.lower()
        if fmt not in self.get_supported_formats():
            logger.error(f"Unsupported export format: {format}")
            return False
        try:
            # DPI is configurable; 300 is used when no preference is set.
            dpi = self.config.get_preference("export_dpi", 300)
            self.fig.savefig(
                filename,
                format=fmt,
                dpi=dpi,
                bbox_inches="tight",
                pad_inches=0.1,
            )
            self.set_last_export_path(filename)
            logger.info(f"Successfully exported to {filename}")
            return True
        except Exception as e:
            # Export is best-effort: report failure to the caller, don't raise.
            logger.error(f"Failed to export visualization: {e}")
            return False

    def is_available(self) -> bool:
        """Check if matplotlib renderer is available.

        :returns: True if renderer can be used
        :rtype: bool
        """
        return MATPLOTLIB_AVAILABLE

    def get_renderer_name(self) -> str:
        """Get human-readable name of the renderer.

        :returns: Renderer name
        :rtype: str
        """
        return "NetworkX/Matplotlib"

    def get_canvas(self) -> Optional[FigureCanvasTkAgg]:
        """Get the matplotlib canvas widget.

        :returns: Canvas widget for embedding in GUI, or None if no parent
            widget was available when the figure was created
        :rtype: Optional[FigureCanvasTkAgg]
        """
        return self.canvas

    def _refresh_canvas(self) -> None:
        """Redraw the Tk canvas if one was created."""
        if self.canvas is not None:
            self.canvas.draw()

    def _create_figure(self) -> None:
        """Create matplotlib figure and canvas (no-op without matplotlib)."""
        if not MATPLOTLIB_AVAILABLE:
            return
        self.fig = Figure(figsize=(10, 8), dpi=100)
        self.ax = self.fig.add_subplot(111)
        # Only embed in Tk when there is a parent to attach the canvas to.
        if self.parent:
            self.canvas = FigureCanvasTkAgg(self.fig, self.parent)

    def _draw_graph(self, layout_type: str = "circular") -> None:
        """Draw the graph with the specified layout.

        An empty graph clears the axes and shows a placeholder message.

        :param layout_type: Layout algorithm name
        :type layout_type: str
        """
        # Use identity checks: a NetworkX graph is falsy when it has no
        # nodes (its __len__ is the node count), and a truthiness test here
        # would skip the empty-graph message below and leave a stale plot.
        if self.ax is None or self.graph is None:
            return
        self.ax.clear()
        if self.graph.number_of_nodes() == 0:
            self.ax.text(
                0.5,
                0.5,
                "No interactions to display",
                ha="center",
                va="center",
                transform=self.ax.transAxes,
            )
            return

        pos = self._get_layout(layout_type)

        # Data produced by the base class's prepare_graph_data().
        node_labels = self.node_data.get("labels", {})
        node_colors = self.node_data.get("colors", [])
        node_sizes = self.node_data.get("sizes", [])
        edge_labels = self.edge_data.get("labels", {})

        self._draw_ellipse_nodes(pos, node_colors, node_sizes)
        self._draw_edges(pos)
        self._draw_labels(pos, node_labels, edge_labels)

        chain_length = len(self.graph.edges())
        self.ax.set_title(
            f"Cooperativity Chain\n"
            f"Length: {chain_length} interactions ({layout_type.title()} Layout)"
        )
        self.ax.axis("off")

    def _get_layout(self, layout_type: str = "circular") -> Dict[Any, Any]:
        """Get node positions using the specified layout algorithm.

        Unknown layout names, non-planar graphs requested with ``"planar"``,
        and any layout failure all fall back to the circular layout.

        :param layout_type: Layout algorithm name
        :type layout_type: str
        :returns: Dictionary mapping nodes to positions
        :rtype: Dict[Any, Any]
        """
        layouts = {
            "circular": nx.circular_layout,
            "shell": nx.shell_layout,
            "kamada_kawai": nx.kamada_kawai_layout,
            "spring": nx.spring_layout,
        }
        try:
            if layout_type == "planar":
                if nx.is_planar(self.graph):
                    return nx.planar_layout(self.graph)
                return nx.circular_layout(self.graph)
            return layouts.get(layout_type, nx.circular_layout)(self.graph)
        except Exception as e:
            # Best-effort: keep rendering with a safe layout, but don't hide
            # why the requested one failed.
            logger.warning(
                f"Layout '{layout_type}' failed ({e}); falling back to circular"
            )
            return nx.circular_layout(self.graph)

    def _draw_ellipse_nodes(
        self, pos: Dict[Any, Any], node_colors: List[str], node_sizes: List[int]
    ) -> None:
        """Draw ellipse-shaped nodes.

        Colors and sizes are matched to nodes by position in
        ``self.graph.nodes()`` iteration order; missing entries default to
        ``"lightgray"`` and ``1000``. Nodes without a position are skipped.

        :param pos: Node positions dictionary
        :type pos: Dict[Any, Any]
        :param node_colors: List of node colors
        :type node_colors: List[str]
        :param node_sizes: List of node sizes
        :type node_sizes: List[int]
        """
        if self.graph is None:
            return
        for index, node in enumerate(self.graph.nodes()):
            if node not in pos:
                continue
            x, y = pos[node]
            color = node_colors[index] if index < len(node_colors) else "lightgray"
            size = node_sizes[index] if index < len(node_sizes) else 1000

            # Ellipse dimensions scale with the node size; every node gets
            # the same extra 1.2x horizontal stretch.
            # NOTE(review): earlier comments implied residue nodes should be
            # rounder than atom nodes, but both kinds receive the identical
            # stretch -- confirm the intended shapes.
            scale = size / 3000
            width = scale * 1.8 * 1.2
            height = scale * 1.0

            # Atom-specific nodes (labels containing "(") get a dotted
            # outline; residue nodes get a dashed one.
            edge_style = "dotted" if "(" in str(node) else "dashed"

            self.ax.add_patch(
                Ellipse(
                    (x, y),
                    width,
                    height,
                    facecolor=color,
                    edgecolor="black",
                    linewidth=2.0,
                    linestyle=edge_style,
                    alpha=0.85,
                )
            )

    @staticmethod
    def _connection_styles() -> List[str]:
        """Return the arc3 connection styles used for edges and edge labels.

        Radii increase in steps of 0.15 (0.15 ... 0.9) so that parallel
        edges curve apart instead of overlapping.
        """
        return [f"arc3,rad={r}" for r in it.accumulate([0.15] * 6)]

    def _draw_edges(self, pos: Dict[Any, Any]) -> None:
        """Draw dashed, curved edges.

        :param pos: Node positions dictionary
        :type pos: Dict[Any, Any]
        """
        if self.graph is None:
            return
        nx.draw_networkx_edges(
            self.graph,
            pos,
            edge_color="black",
            style="dashed",
            connectionstyle=self._connection_styles(),
            # Curved connection styles require FancyArrowPatch drawing.
            arrows=True,
            arrowsize=10,
            ax=self.ax,
        )

    def _draw_labels(
        self,
        pos: Dict[Any, Any],
        node_labels: Dict[Any, str],
        edge_labels: Dict[Any, str],
    ) -> None:
        """Draw node and edge labels.

        Edge-label keys may be ``(u, v)`` or ``(u, v, key)`` tuples; they are
        normalized to ``(u, v)``. Non-tuple keys are ignored.

        :param pos: Node positions dictionary
        :type pos: Dict[Any, Any]
        :param node_labels: Node labels dictionary
        :type node_labels: Dict[Any, str]
        :param edge_labels: Edge labels dictionary
        :type edge_labels: Dict[Any, str]
        """
        if self.graph is None:
            return
        nx.draw_networkx_labels(
            self.graph, pos, labels=node_labels, font_size=8, ax=self.ax
        )

        # NOTE(review): for (u, v, key) labels, parallel edges collapse onto
        # the same (u, v) entry and only the last label survives -- confirm
        # this is intended for multigraphs.
        formatted_labels = {
            (edge_key[0], edge_key[1]): label
            for edge_key, label in edge_labels.items()
            if isinstance(edge_key, tuple) and len(edge_key) >= 2
        }
        if formatted_labels:
            nx.draw_networkx_edge_labels(
                self.graph,
                pos,
                edge_labels=formatted_labels,
                connectionstyle=self._connection_styles(),
                label_pos=0.5,
                font_size=8,
                bbox={"boxstyle": "round,pad=0.2", "facecolor": "white", "alpha": 0.8},
                ax=self.ax,
            )

    def update_layout(self, layout_type: str) -> None:
        """Update visualization with new layout.

        Does nothing if no graph has been rendered yet.

        :param layout_type: New layout algorithm name
        :type layout_type: str
        """
        if self.graph is not None:
            self.current_layout = layout_type
            self._draw_graph(layout_type)
            self._refresh_canvas()

    def clear(self) -> None:
        """Clear the current visualization."""
        if self.ax is not None:
            self.ax.clear()
            self._refresh_canvas()

    def set_title(self, title: str) -> None:
        """Set the plot title.

        :param title: Title text
        :type title: str
        """
        if self.ax is not None:
            self.ax.set_title(title)
            self._refresh_canvas()