Source code for batch_multi_plot_spectrogram

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Provides batch spectrogram plotting utilities.
Should work with CDFs like those from FAST (see batch_multi_plot_FAST_spectrograms.py),
but is intended to be flexible enough for other data as well.

Assumed folder layout is::
    {CDF_DATA_DIRECTORY}/year/month

Filenames in the month folders assumed to be in the following formats::
    {??}_{??}_{??}_{instrument}_{timestamp}_{orbit}_v02.cdf      (known "instruments" are ees, eeb, ies, or ieb)
    {??}_{??}_orb_{orbit}_{??}.cdf

Examples::
    FAST_data/2000/01/fa_esa_l2_eeb_20000101001737_13312_v02.cdf
    FAST_data/2000/01/fa_k0_orb_13312_v01.cdf
"""

__authors__: list[str] = ["Ev Hansen"]
__contact__: str = "ephansen+gh@terpmail.umd.edu"

__credits__: list[list[str]] = [
    ["Ev Hansen", "Python code"],
    ["Emma Mirizio", "Co-Mentor"],
    ["Marilia Samara", "Co-Mentor"],
]

__date__: str = "2025-08-13"
__status__: str = "Development"
__version__: str = "0.0.1"
__license__: str = "GPL-3.0"

# Main imports for CDF data and plotting
import pandas as pd
import cdflib
import numpy as np
import matplotlib

matplotlib.use("Agg")  # Use non-interactive backend for batch
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.dates import date2num
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
from matplotlib import _pylab_helpers
from datetime import datetime, timezone
import os
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict, deque
import json
import time as _time
import concurrent.futures

# garbage collection, worker dispatch helpers, and signal handling
import gc
import functools  # used to build picklable partials for process-pool workers
import signal
import sys


# Section: Constants and Configuration
# Directory containing CDF data files
CDF_DATA_DIRECTORY = "./FAST_data/"
# List of variable names expected in CDF files
CDF_VARIABLE_NAMES = ["time_unix", "data", "energy", "pitch_angle"]
# Function to collapse 3D data arrays to 2D (e.g., sum over axis)
COLLAPSE_FUNCTION = np.nansum

# Colormaps for different axis scaling combinations (colorblind-friendly and visually distinct)
COLORMAP_LINEAR_Y_LINEAR_Z = "viridis"
COLORMAP_LINEAR_Y_LOG_Z = "cividis"
COLORMAP_LOG_Y_LINEAR_Z = "plasma"
COLORMAP_LOG_Y_LOG_Z = "inferno"

# Plot configuration
PLOT_FIGURE_WIDTH_INCHES = 6.25
PLOT_FIGURE_HEIGHT_INCHES = 2.0
TICK_LABEL_FONT_SIZE = 15
AXIS_LABEL_FONT_SIZE = 18
DEFAULT_ZOOM_WINDOW_MINUTES = 6  # Default zoom window duration in minutes
FILTERED_ORBITS_CSV_PATH = "./FAST_Cusp_Indices.csv"  # Path to filtered cusp orbits CSV
PLOTTING_PROGRESS_JSON_PATH = "./batch_multi_plot_progress.json"  # Path to JSON for tracking plotting progress across sessions
OUTPUT_BASE_DIRECTORY = "./plots/"  # Parent directory to save plots

# Logfile configuration for batch restarts
LOGFILE_DATETIME_PATH = "./batch_multi_plot_logfile_datetime.txt"
if os.path.exists(LOGFILE_DATETIME_PATH):
    with open(LOGFILE_DATETIME_PATH, "r") as f:
        LOGFILE_DATETIME_STRING = f.read().strip()
    if not LOGFILE_DATETIME_STRING:
        LOGFILE_DATETIME_STRING = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        with open(LOGFILE_DATETIME_PATH, "w") as f:
            f.write(LOGFILE_DATETIME_STRING)
else:
    LOGFILE_DATETIME_STRING = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    with open(LOGFILE_DATETIME_PATH, "w") as f:
        f.write(LOGFILE_DATETIME_STRING)
LOGFILE_PATH = f"./batch_multi_plot_log_{LOGFILE_DATETIME_STRING}.log"

# Pitch angle category definitions for spectrogram plots
PITCH_ANGLE_CATEGORY_RANGES = {
    "downgoing": [(0, 30), (330, 360)],
    "upgoing": [(150, 210)],
    "perpendicular": [(40, 140), (210, 330)],
    "all": [(0, 360)],
}
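
# A minimal usage sketch (illustrative; this helper is not referenced anywhere
# else in this module): turning PITCH_ANGLE_CATEGORY_RANGES into a boolean
# mask over an array of pitch angles in degrees.
def _example_pitch_angle_mask(pitch_angles_deg, category="downgoing"):
    """Return a boolean mask selecting angles inside the category's ranges."""
    angles = np.asarray(pitch_angles_deg)
    mask = np.zeros(angles.shape, dtype=bool)
    for low_deg, high_deg in PITCH_ANGLE_CATEGORY_RANGES[category]:
        mask |= (angles >= low_deg) & (angles <= high_deg)
    return mask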

# Global caches for batch optimization
filtered_orbits_cache = {}
orbit_column_cache = {}
cdf_type_cache = {}


# Section: Functions
# Section: Utility Functions
def load_filtered_orbits(csv_path=FILTERED_ORBITS_CSV_PATH):
    """Load the filtered orbits CSV with a simple cache.

    Parameters
    ----------
    csv_path : str, default FILTERED_ORBITS_CSV_PATH
        Path to the filtered orbits TSV/CSV file.

    Returns
    -------
    pandas.DataFrame or None
        DataFrame of filtered orbits, or ``None`` if loading fails.

    Notes
    -----
    A module-level dictionary caches previously loaded DataFrames keyed by
    the path string passed in, to avoid repeated disk I/O in batch routines.
    """
    global filtered_orbits_cache
    if csv_path in filtered_orbits_cache:
        return filtered_orbits_cache[csv_path]
    try:
        dataframe = pd.read_csv(csv_path, sep="\t")
        filtered_orbits_cache[csv_path] = dataframe
        return dataframe
    except Exception as exc:
        log_error(f"Error loading CSV {csv_path}: {exc}")
        return None

# Section: SIGINT Handling
def _terminate_all_child_processes():
    """Attempt to terminate all child processes of this process.

    Returns
    -------
    None

    Notes
    -----
    Uses :mod:`psutil` (imported lazily) to enumerate child processes
    recursively and invoke ``terminate()`` on each. Exceptions during
    termination are suppressed because this function is used during
    best-effort shutdown handling.
    """
    import psutil

    current_process = psutil.Process()
    for child in current_process.children(recursive=True):
        try:
            child.terminate()
        except Exception:
            # Suppress individual child termination issues during shutdown.
            pass


def _sigint_handler(signum, frame):
    """SIGINT handler to terminate children and exit promptly.

    Parameters
    ----------
    signum : int
        Signal number.
    frame : FrameType or None
        Current execution frame (unused).

    Returns
    -------
    None
    """
    log_message("[INFO] SIGINT received. Terminating all child processes and exiting.")
    _terminate_all_child_processes()
    sys.exit(1)


# Section: Logging
_LOG_BUFFER = []  # list[tuple[str, str]]: buffered (level, message) entries
_LOG_BATCH_SIZE = 10  # default batch size for buffered logging; configurable
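
# Usage sketch (illustrative): long-running scripts can opt into the handler
# above directly; generic_batch_plot() further below installs and restores it
# automatically around batch runs.
#
#     signal.signal(signal.SIGINT, _sigint_handler)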
def configure_log_batch(batch_size: int):
    """Configure buffered logging batch size.

    Parameters
    ----------
    batch_size : int
        Desired number of log records to accumulate before an automatic
        flush. Values less than 1 are coerced to 1.
    """
    global _LOG_BATCH_SIZE
    _LOG_BATCH_SIZE = max(1, int(batch_size))

def _flush_log_buffer(force: bool = False):
    """Flush buffered log messages to disk.

    Parameters
    ----------
    force : bool, default False
        If True, flush even if the current buffer length is below the
        configured batch size threshold.
    """
    if not _LOG_BUFFER:
        return
    if (len(_LOG_BUFFER) >= _LOG_BATCH_SIZE) or force:
        try:
            with open(LOGFILE_PATH, "a") as logfile_out:
                for level, msg in _LOG_BUFFER:
                    if level == "error":
                        logfile_out.write(f"[ERROR] {msg}\n")
                    else:
                        logfile_out.write(msg + "\n")
        except Exception as log_flush_exception:
            # Last-resort console output
            tqdm.write(f"[ERROR] Failed flushing log buffer: {log_flush_exception}")
        finally:
            _LOG_BUFFER.clear()

def log_message(message: str, force_flush: bool = False):
    """Queue an informational log message.

    Messages are appended to an in-memory buffer; a flush occurs
    automatically once the configured batch size is reached or
    ``force_flush`` is True.
    """
    _LOG_BUFFER.append(("info", message))
    _flush_log_buffer(force=force_flush)

def log_error(message: str, force_flush: bool = False):
    """Queue an error log message and echo to console immediately."""
    tqdm.write("[ERROR] " + message)
    _LOG_BUFFER.append(("error", message))
    _flush_log_buffer(force=force_flush)

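
# A small usage sketch for the buffered logger (illustrative only): batch ten
# records per disk write, then force a final flush so the partial batch is
# not lost at shutdown.
def _example_logging_usage():
    configure_log_batch(10)
    for item_index in range(25):
        log_message(f"[INFO] processed item {item_index}")
    log_message("[INFO] done", force_flush=True)  # flushes the partial batch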
def get_timestamps_for_orbit(
    filtered_orbits_dataframe, orbit_number, instrument_type, time_unix_array
):
    """Compute orbit boundary UNIX timestamps from filtered indices.

    Parameters
    ----------
    filtered_orbits_dataframe : pandas.DataFrame
        DataFrame containing filtered orbits and min/max indices per
        instrument.
    orbit_number : int
        Orbit number to look up.
    instrument_type : str
        Instrument type identifier (e.g., ``'ees'``, ``'ies'``).
    time_unix_array : numpy.ndarray
        1D array of UNIX timestamps for the instrument.

    Returns
    -------
    list of float
        Boundary UNIX timestamps for the orbit (one value for a degenerate
        span or two values for start/end). Returns an empty list on invalid
        indices.
    """
    global orbit_column_cache
    dataframe = filtered_orbits_dataframe
    cache = orbit_column_cache
    if dataframe is None or instrument_type is None or time_unix_array is None:
        return []
    key = (id(dataframe), instrument_type)
    if key not in cache:
        orbit_column = next(col for col in dataframe.columns if "orbit" in col.lower())
        min_index_column = next(
            col
            for col in dataframe.columns
            if instrument_type in col.lower() and "min index" in col.lower()
        )
        max_index_column = next(
            col
            for col in dataframe.columns
            if instrument_type in col.lower() and "max index" in col.lower()
        )
        cache[key] = (orbit_column, min_index_column, max_index_column)
    orbit_column, min_index_column, max_index_column = cache[key]
    row = dataframe[dataframe[orbit_column] == orbit_number]
    if row.empty:
        return []
    min_index = row.iloc[0][min_index_column]
    max_index = row.iloc[0][max_index_column]
    try:
        min_index = int(min_index)
        max_index = int(max_index)
    except Exception:
        log_message("[WARN] Non-integer indices found in orbit row; skipping orbit.")
        return []
    min_index = max(0, min(min_index, len(time_unix_array) - 1))
    max_index = max(0, min(max_index, len(time_unix_array) - 1))
    if min_index == max_index:
        return [float(time_unix_array[min_index])]
    return [float(time_unix_array[min_index]), float(time_unix_array[max_index])]

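
# Illustrative sketch of get_timestamps_for_orbit with a synthetic DataFrame.
# The column names below are assumptions that merely satisfy the substring
# matching used above ("orbit", "<instrument> min index", "<instrument> max
# index"); real column names come from FILTERED_ORBITS_CSV_PATH.
def _example_orbit_boundaries():
    frame = pd.DataFrame(
        {"Orbit": [13312], "ees min index": [10], "ees max index": [250]}
    )
    synthetic_times = np.linspace(946684800.0, 946686000.0, 300)
    # Returns [synthetic_times[10], synthetic_times[250]]
    return get_timestamps_for_orbit(frame, 13312, "ees", synthetic_times)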
def get_cdf_file_type(cdf_file_path: str):
    """Infer instrument type from a CDF file path.

    Parameters
    ----------
    cdf_file_path : str
        Path to the CDF file.

    Returns
    -------
    str or None
        Instrument type string (e.g., ``'ees'``), ``'orb'`` for orbit files,
        or ``None`` if not recognized.
    """
    path_lower = cdf_file_path.lower()
    instrument_tags = ["ees", "eeb", "ies", "ieb"]
    if "_orb_" in path_lower:
        return "orb"
    for tag in instrument_tags:
        if f"_{tag}_" in path_lower:
            return tag
    log_error(f"Unknown CDF file type for path: {cdf_file_path}")
    return None

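
# Example classifications, per the filename conventions in the module
# docstring:
#
#     get_cdf_file_type("fa_esa_l2_eeb_20000101001737_13312_v02.cdf")  # "eeb"
#     get_cdf_file_type("fa_k0_orb_13312_v01.cdf")                     # "orb"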
def get_variable_shape(cdf_path, variable_name):
    """Return the shape of a variable in a CDF file.

    Parameters
    ----------
    cdf_path : str
        Path to the CDF file.
    variable_name : str
        Variable name to inspect.

    Returns
    -------
    tuple or None
        Variable shape tuple, or ``None`` if the variable is absent, is not
        an array, or an error occurs.
    """
    global cdf_type_cache
    instrument_type = cdf_type_cache.get(cdf_path)
    if instrument_type is None:
        instrument_type = get_cdf_file_type(cdf_path)
        cdf_type_cache[cdf_path] = instrument_type
    if instrument_type is None or instrument_type == "orb":
        return None
    try:
        with cdflib.CDF(cdf_path) as cdf:
            variable_data = cdf.varget(variable_name)
            return (
                variable_data.shape if isinstance(variable_data, np.ndarray) else None
            )
    except Exception as exc:
        log_error(f"Error reading {cdf_path} for variable {variable_name}: {exc}")
        return None

def get_cdf_var_shapes(
    cdf_folder_path=CDF_DATA_DIRECTORY, variable_names=CDF_VARIABLE_NAMES
):
    """Collect shapes of variables across CDF files in a folder.

    Parameters
    ----------
    cdf_folder_path : str, default CDF_DATA_DIRECTORY
        Directory containing CDF files.
    variable_names : list of str, default CDF_VARIABLE_NAMES
        Variable names to inspect.

    Returns
    -------
    dict
        Mapping from variable name (str) to list of shape tuples (or None)
        per file.
    """
    cdf_file_paths = [str(p) for p in Path(cdf_folder_path).rglob("*.[cC][dD][fF]")]
    shapes_by_variable = {}
    for variable_name in variable_names:
        shapes_by_variable[variable_name] = []
        for cdf_path in tqdm(
            cdf_file_paths,
            desc=f"Processing CDF files ({variable_name})",
            unit="file",
            total=len(cdf_file_paths),
        ):
            shapes_by_variable[variable_name].append(
                get_variable_shape(cdf_path, variable_name)
            )
    return shapes_by_variable

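
# Example invocation (assumes CDF files are present under CDF_DATA_DIRECTORY;
# the printed slice is illustrative only):
#
#     shapes = get_cdf_var_shapes()
#     print(shapes["data"][:5])  # first five shape tuples for 'data'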
def close_all_axes_and_clear(fig):
    """Close axes/subplots and clear a figure to free memory.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure instance to clear and dispose.

    Returns
    -------
    None

    Notes
    -----
    Ensures axes are deleted, the canvas is closed/detached, and removes the
    figure from the global Gcf registry when possible to mitigate memory
    growth during large batch operations.
    """
    for axis in list(fig.axes):
        try:
            fig.delaxes(axis)
        except Exception as axis_close_error:
            log_error(f"Error closing axis: {axis_close_error}")
    fig.clf()
    if hasattr(fig, "canvas") and fig.canvas is not None:
        try:
            # Not every backend exposes close(); failures are logged and ignored.
            fig.canvas.close()
        except Exception as canvas_close_error:
            log_message(f"[WARN] Error closing canvas: {canvas_close_error}")
        try:
            fig.canvas.figure = None
        except Exception as canvas_figure_clear_error:
            log_message(
                f"[WARN] Error clearing canvas figure: {canvas_figure_clear_error}"
            )
        fig.canvas = None
    try:
        if hasattr(fig, "number") and fig.number is not None:
            _pylab_helpers.Gcf.destroy(fig.number)
    except Exception as gcf_registry_error:
        log_error(f"Error removing figure from Gcf registry: {gcf_registry_error}")

# Section: Spectrogram Plotting
def make_spectrogram(
    x_axis_values,
    y_axis_values,
    data_array_3d,
    x_axis_min=None,
    x_axis_max=None,
    x_axis_is_unix=True,
    x_axis_label=None,
    center_timestamp=None,
    window_duration_seconds=None,
    y_axis_scale_function=None,
    y_axis_label=None,
    y_axis_min=0,
    y_axis_max=4000,
    z_axis_scale_function=None,
    z_axis_min=None,
    z_axis_max=None,
    z_axis_label=None,
    collapse_axis=1,
    colormap="viridis",
    axis_object=None,
    instrument_label=None,
    vertical_lines_unix=None,  # list of unix timestamps to mark
):
    """Plot a spectrogram by collapsing a 3D data array along an axis.

    Parameters
    ----------
    x_axis_values : array-like
        1D array for x (horizontal) axis (e.g., time sequence).
    y_axis_values : array-like
        1D array for y (vertical) axis (e.g., energy bins).
    data_array_3d : numpy.ndarray
        3D data array, e.g. ``(time, angle/pitch, energy)``.
    x_axis_min, x_axis_max : float, optional
        Explicit x-axis clipping bounds before plotting.
    x_axis_is_unix : bool, default True
        If ``True``, x-axis is treated as UNIX seconds and converted to dates.
    x_axis_label : str, optional
        Custom x-axis label (default depends on ``x_axis_is_unix``).
    center_timestamp : float, optional
        Center of requested zoom window (UNIX seconds).
    window_duration_seconds : float, optional
        Duration of zoom window; both must be provided for zoom to apply.
    y_axis_scale_function : {'linear', 'log'}, optional
        Y-axis scaling; ``None`` behaves as ``'linear'``.
    y_axis_label : str, optional
        Y-axis label text.
    y_axis_min, y_axis_max : float, default 0 and 4000
        Y-axis clipping range applied before filtering / plotting.
    z_axis_scale_function : {'linear', 'log'}, optional
        Color scale mode; ``None`` behaves as ``'linear'``.
    z_axis_min, z_axis_max : float, optional
        Optional color scale bounds (percentiles chosen if omitted).
    z_axis_label : str, optional
        Colorbar label text.
    collapse_axis : int, default 1
        Axis index along which to collapse the 3D data array.
    colormap : str, default 'viridis'
        Matplotlib colormap name.
    axis_object : matplotlib.axes.Axes, optional
        Existing axes to draw into; if ``None``, a new figure/axes is created.
    instrument_label : str, optional
        Title string applied to the axes.
    vertical_lines_unix : list of float, optional
        UNIX timestamps to annotate with vertical lines.

    Returns
    -------
    axis_object : matplotlib.axes.Axes or None
        The axis object used for plotting (``None`` if no data plotted).
    x_axis_plot : numpy.ndarray or None
        X values actually used (possibly filtered / converted), or ``None``
        if plotting was skipped.
    """
    # Log the function call and key parameters for debugging
    log_message(
        f"[DEBUG] make_spectrogram: y_axis_scale_function={y_axis_scale_function}, "
        f"z_axis_scale_function={z_axis_scale_function}, z_axis_min={z_axis_min}, "
        f"z_axis_max={z_axis_max}, colormap={colormap}"
    )
    # Convert input arrays to numpy arrays for consistency
    x_axis = np.asarray(x_axis_values)
    y_axis = np.asarray(y_axis_values)
    data_array = np.asarray(data_array_3d)

    # Collapse the 3D data array along the specified axis (e.g., sum over pitch angle)
    collapsed_matrix = COLLAPSE_FUNCTION(data_array, axis=collapse_axis)

    # Mask out columns that are all NaN and restrict to valid energy range
    nan_column_mask = ~np.all(np.isnan(collapsed_matrix), axis=0)
    valid_energy_mask = (y_axis >= y_axis_min) & (y_axis <= y_axis_max)
    combined_mask = nan_column_mask & valid_energy_mask
    collapsed_matrix = collapsed_matrix[:, combined_mask]
    y_axis = y_axis[combined_mask]
    if collapsed_matrix.size == 0 or y_axis.size == 0:
        log_message("[WARNING] All energy bins were filtered out. No data to plot.")
        return None, None

    # Ensure y-axis is increasing (for plotting)
    if y_axis[0] > y_axis[-1]:
        y_axis = y_axis[::-1]
        collapsed_matrix = collapsed_matrix[:, ::-1]

    # If a zoom window is specified, restrict to that window
    if center_timestamp is not None and window_duration_seconds is not None:
        half_window = window_duration_seconds / 2
        left_bound = center_timestamp - half_window
        right_bound = center_timestamp + half_window
        zoom_mask = (x_axis >= left_bound) & (x_axis <= right_bound)
        x_axis = x_axis[zoom_mask]
        collapsed_matrix = collapsed_matrix[zoom_mask, :]

    # Restrict to specified x-axis min/max if provided
    if x_axis_min is not None or x_axis_max is not None:
        x_mask = np.ones_like(x_axis, dtype=bool)
        if x_axis_min is not None:
            x_mask &= x_axis >= x_axis_min
        if x_axis_max is not None:
            x_mask &= x_axis <= x_axis_max
        x_axis = x_axis[x_mask]
        collapsed_matrix = collapsed_matrix[x_mask, :]

    # Convert x-axis to matplotlib date format if using unix timestamps
    if x_axis_is_unix:
        x_axis_datetime = np.array(
            [datetime.fromtimestamp(x, tz=timezone.utc) for x in x_axis]
        )
        x_axis_plot = date2num(x_axis_datetime)
        x_label = x_axis_label if x_axis_label is not None else "Time (UTC)"
    else:
        x_axis_plot = x_axis
        x_label = x_axis_label if x_axis_label is not None else "X"

    # Create a new figure and axis if not provided
    if axis_object is None:
        fig = Figure(figsize=(PLOT_FIGURE_WIDTH_INCHES, PLOT_FIGURE_HEIGHT_INCHES))
        canvas = FigureCanvas(fig)
        axis_object = fig.add_subplot(1, 1, 1)
    else:
        fig = axis_object.figure

    # Transpose matrix for plotting (so y-axis is vertical)
    matrix_plot = collapsed_matrix.T

    # If no data remains after filtering, skip plotting. This is checked
    # before indexing x_axis_plot[0]/[-1] below, which would raise on an
    # empty array.
    if matrix_plot.size == 0 or len(x_axis_plot) == 0:
        log_message("[WARNING] No data to plot after filtering. Skipping plot.")
        return None, None

    # Set x-axis limits to zoom window if specified, otherwise to full range
    if center_timestamp is not None and window_duration_seconds is not None:
        if x_axis_is_unix:
            left_num = float(
                date2num(
                    datetime.fromtimestamp(
                        center_timestamp - window_duration_seconds / 2, tz=timezone.utc
                    )
                )
            )
            right_num = float(
                date2num(
                    datetime.fromtimestamp(
                        center_timestamp + window_duration_seconds / 2, tz=timezone.utc
                    )
                )
            )
            axis_object.set_xlim(left_num, right_num)
        else:
            axis_object.set_xlim(
                center_timestamp - window_duration_seconds / 2,
                center_timestamp + window_duration_seconds / 2,
            )
    else:
        axis_object.set_xlim(x_axis_plot[0], x_axis_plot[-1])

    # Set colorbar min/max if not provided
    if z_axis_min is None:
        z_axis_min = np.nanpercentile(matrix_plot, 1)
    if z_axis_max is None:
        z_axis_max = np.nanpercentile(matrix_plot, 99)

    # Find the smallest positive value for safe log scaling
    finite_positive = matrix_plot[np.isfinite(matrix_plot) & (matrix_plot > 0)]
    safe_vmin = np.nanmin(finite_positive) if finite_positive.size > 0 else 1e-10

    # Plot with log colorbar if requested, masking non-positive values
    if z_axis_scale_function == "log":
        if np.any(matrix_plot <= 0) or not (
            np.isfinite(z_axis_min)
            and np.isfinite(z_axis_max)
            and z_axis_min > 0
            and z_axis_max > 0
            and z_axis_max > z_axis_min
        ):
            log_message(
                "[WARNING] Non-positive values found in matrix for log colorbar. "
                "Masking to z_axis_min and enforcing log scale."
            )
        z_axis_min = float(max(z_axis_min, safe_vmin, 1e-10))
        z_axis_max = float(z_axis_max)
        # Mask all non-positive and non-finite values for log scale
        matrix_plot = np.where(
            ~np.isfinite(matrix_plot) | (matrix_plot <= 0), z_axis_min, matrix_plot
        )
        norm = mcolors.LogNorm(vmin=z_axis_min, vmax=z_axis_max)
        im = axis_object.imshow(
            matrix_plot,
            aspect="auto",
            origin="lower",
            extent=(x_axis_plot[0], x_axis_plot[-1], y_axis[0], y_axis[-1]),
            cmap=colormap,
            norm=norm,
        )
        # Compute tick marks for every integer power of 10 in range
        min_exponent = int(np.floor(np.log10(z_axis_min)))
        max_exponent = int(np.ceil(np.log10(z_axis_max)))
        ticks = [
            10**i
            for i in range(min_exponent, max_exponent + 1)
            if z_axis_min <= 10**i <= z_axis_max
        ]
        log_message(f"[DEBUG] make_spectrogram: log colorbar ticks: {ticks}")

        def log_tick_formatter(value, position=None):
            if value <= 0:
                return ""
            exponent = int(np.log10(value))
            if np.isclose(value, 10**exponent):
                return f"$10^{{{exponent}}}$"
            return ""

        # Create the colorbar with custom ticks and formatter
        colorbar = fig.colorbar(
            im,
            ax=axis_object,
            label=z_axis_label if z_axis_label is not None else "Counts",
            ticks=ticks,
            format=log_tick_formatter,
        )
    else:
        # Linear colorbar: mask NaN and inf values, set vmin/vmax
        z_axis_min = float(z_axis_min)
        z_axis_max = float(z_axis_max)
        matrix_plot = np.where(np.isnan(matrix_plot), z_axis_min, matrix_plot)
        matrix_plot = np.where(np.isneginf(matrix_plot), z_axis_min, matrix_plot)
        matrix_plot = np.where(np.isposinf(matrix_plot), z_axis_max, matrix_plot)
        if not (
            np.isfinite(z_axis_min)
            and np.isfinite(z_axis_max)
            and z_axis_max > z_axis_min
        ):
            z_axis_min = float(np.nanmin(matrix_plot))
            z_axis_max = float(np.nanmax(matrix_plot))
        im = axis_object.imshow(
            matrix_plot,
            aspect="auto",
            origin="lower",
            extent=(x_axis_plot[0], x_axis_plot[-1], y_axis[0], y_axis[-1]),
            cmap=colormap,
            vmin=z_axis_min,
            vmax=z_axis_max,
        )
        # Create the colorbar for linear scale
        colorbar = fig.colorbar(
            im,
            ax=axis_object,
            label=z_axis_label if z_axis_label is not None else "Counts",
        )

    # Set axis labels and title
    axis_object.set_xlabel(x_label)
    axis_object.set_ylabel(y_axis_label if y_axis_label is not None else "Energy (eV)")
    if instrument_label is not None:
        axis_object.set_title(instrument_label)

    # Configure y-axis ticks and scale
    if len(y_axis) >= 2:
        if y_axis_scale_function != "log":
            # For linear y-axis, set ticks at reasonable intervals
            y_max_str = str(y_axis_max)
            y_max_digits = len(y_max_str)
            y_first_digit = int(y_max_str[0])
            y_second_digit = int(y_max_str[1])
            if y_second_digit >= 5:
                step_size = 10**y_max_digits
                y_max_tick = y_first_digit * 10 ** (y_max_digits - 1)
            else:
                step_size = 10 ** (y_max_digits - 1)
                y_max_tick = (y_first_digit + 0.5) * 10 ** (y_max_digits - 1)
            yticks = [
                i
                for i in range(y_axis_min, int(y_max_tick) + 1, step_size)
                if (i / y_max_tick) <= 1.1
            ]
            if len(yticks) > 0:
                axis_object.set_yticks(yticks)
                axis_object.set_yticklabels([f"{int(e)}" for e in yticks])
        else:
            # For log y-axis, set scale to log
            axis_object.set_yscale("log")

    # Format x-axis as time if using unix timestamps
    if x_axis_is_unix:
        x_limits = axis_object.get_xlim()
        left_datetime = mdates.num2date(x_limits[0], tz=timezone.utc)
        right_datetime = mdates.num2date(x_limits[1], tz=timezone.utc)
        displayed_time_range_seconds = (right_datetime - left_datetime).total_seconds()
        if displayed_time_range_seconds < 120:
            axis_object.xaxis.set_major_formatter(
                mdates.DateFormatter("%H:%M:%S", tz=timezone.utc)
            )
        else:
            axis_object.xaxis.set_major_formatter(
                mdates.DateFormatter("%H:%M", tz=timezone.utc)
            )

    # Draw vertical lines for orbit boundaries or other events if provided
    if vertical_lines_unix is not None and len(vertical_lines_unix) > 0:
        if x_axis_is_unix:
            vertical_lines_plot = date2num(
                [
                    datetime.fromtimestamp(timestamp, tz=timezone.utc)
                    for timestamp in vertical_lines_unix
                ]
            )
            x_min_plot = x_axis_plot[0]
            x_max_plot = x_axis_plot[-1]
            vertical_lines_plot = [
                v for v in vertical_lines_plot if x_min_plot <= v <= x_max_plot
            ]
        else:
            vertical_lines_plot = [
                v
                for v in vertical_lines_unix
                if x_axis_plot[0] <= v <= x_axis_plot[-1]
            ]
        for vertical_line in vertical_lines_plot:
            # Draw a thick black line under a thinner red line for visibility
            axis_object.axvline(
                vertical_line,
                color="black",
                linestyle="-",
                linewidth=4,
                alpha=1.0,
                zorder=10,
            )
            axis_object.axvline(
                vertical_line,
                color="red",
                linestyle="-",
                linewidth=2,
                alpha=1.0,
                zorder=11,
            )

    # Set tick parameters for better readability
    axis_object.tick_params(
        axis="both", which="major", labelsize=TICK_LABEL_FONT_SIZE, length=8, width=1
    )
    axis_object.tick_params(
        axis="both", which="minor", labelsize=TICK_LABEL_FONT_SIZE, length=5, width=1
    )
    colorbar.ax.tick_params(labelsize=TICK_LABEL_FONT_SIZE, length=6, width=1)
    colorbar.ax.tick_params(
        which="minor", labelsize=TICK_LABEL_FONT_SIZE, length=3, width=1
    )

    # Set axis label font sizes (reuse the configured colorbar label text)
    axis_object.xaxis.label.set_fontsize(AXIS_LABEL_FONT_SIZE)
    axis_object.yaxis.label.set_fontsize(AXIS_LABEL_FONT_SIZE)
    colorbar.ax.set_ylabel(
        z_axis_label if z_axis_label is not None else "Counts",
        fontsize=AXIS_LABEL_FONT_SIZE,
    )

    # Return the axis and the x-axis values used for plotting
    return axis_object, x_axis_plot

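
# A self-contained smoke test of make_spectrogram using synthetic data (all
# values made up; the output path is an assumption). Useful for verifying the
# plotting path without any CDF files on disk.
def _example_make_spectrogram(output_path="./example_spectrogram.png"):
    rng = np.random.default_rng(0)
    times = np.arange(946684800.0, 946685400.0, 2.0)  # 10 min of fake data
    energies = np.linspace(10.0, 4000.0, 48)
    # (time, pitch, energy) counts; the default collapse_axis=1 sums over pitch
    data = rng.poisson(5.0, size=(times.size, 32, energies.size)).astype(float)
    axis, _x_used = make_spectrogram(
        times,
        energies,
        data,
        z_axis_scale_function="log",
        colormap=COLORMAP_LINEAR_Y_LOG_Z,
    )
    if axis is not None:
        axis.figure.savefig(output_path, dpi=150)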
def generic_plot_spectrogram_set(
    datasets,
    collapse_axis=1,
    zoom_center=None,
    zoom_window_seconds=None,
    vertical_lines=None,
    x_is_unix=True,
    y_scale="linear",
    z_scale="linear",
    colormap="viridis",
    figure_title=None,
    show=False,
    y_min=None,
    y_max=None,
    z_min=None,
    z_max=None,
):
    """Plot a vertical stack of generic spectrograms.

    Parameters
    ----------
    datasets : list of dict
        Each dict requires keys ``'x'``, ``'y'``, ``'data'`` and may include
        optional keys: ``'label'``, ``'y_label'``, ``'z_label'``, ``'y_min'``,
        ``'y_max'``, ``'z_min'``, ``'z_max'``.
    collapse_axis : int, default 1
        Axis index of the 3D array collapsed prior to plotting.
    zoom_center : float, optional
        Center (UNIX time) for zoom column when used.
    zoom_window_seconds : float, optional
        Duration of zoom window (seconds) when ``zoom_center`` provided.
    vertical_lines : list of float, optional
        UNIX timestamps to annotate with vertical lines.
    x_is_unix : bool, default True
        If ``True``, x values are treated as UNIX seconds and formatted.
    y_scale : {'linear', 'log'}, default 'linear'
        Y-axis scaling mode.
    z_scale : {'linear', 'log'}, default 'linear'
        Color (intensity) scale mode.
    colormap : str, default 'viridis'
        Matplotlib colormap name.
    figure_title : str, optional
        Figure-level title (sup-title).
    show : bool, default False
        If ``True``, display interactively (requires GUI backend).
    y_min : float, optional
        Global Y min fallback when per-row not supplied. Defaults to 0 if
        omitted and per-row missing.
    y_max : float, optional
        Global Y max fallback when per-row not supplied. If both global and
        per-row values are absent, it is inferred from the data.
    z_min : float, optional
        Global colorbar lower bound fallback.
    z_max : float, optional
        Global colorbar upper bound fallback.

    Returns
    -------
    tuple
        ``(fig, canvas)`` or ``(None, None)`` if ``datasets`` is empty.
    """
    if not datasets:
        return None, None
    fig = Figure(figsize=(10, 3 * len(datasets)))
    canvas = FigureCanvas(fig)
    axes = []
    for row_index, dataset in enumerate(datasets):
        axis_obj = fig.add_subplot(len(datasets), 1, row_index + 1)
        axes.append(axis_obj)
        # Resolve per-dataset ranges with global fallback (row-specific wins)
        dataset_y_min = dataset.get("y_min", y_min)
        dataset_y_max = dataset.get("y_max", y_max)
        dataset_z_min = dataset.get("z_min", z_min)
        dataset_z_max = dataset.get("z_max", z_max)
        # Compute fallback y max from provided y array if not given
        inferred_y_max = (
            dataset["y"].max()
            if dataset_y_max is None and dataset.get("y") is not None
            else dataset_y_max
        )
        make_spectrogram(
            x_axis_values=dataset["x"],
            y_axis_values=dataset["y"],
            data_array_3d=dataset["data"],
            collapse_axis=collapse_axis,
            center_timestamp=zoom_center,
            window_duration_seconds=zoom_window_seconds,
            x_axis_is_unix=x_is_unix,
            y_axis_scale_function=y_scale,
            z_axis_scale_function=z_scale,
            y_axis_min=dataset_y_min if dataset_y_min is not None else 0,
            y_axis_max=inferred_y_max if inferred_y_max is not None else 4000,
            z_axis_min=dataset_z_min,
            z_axis_max=dataset_z_max,
            colormap=colormap,
            y_axis_label=dataset.get("y_label", "Energy (eV)"),
            z_axis_label=dataset.get("z_label", "Counts"),
            x_axis_label="Time (UTC)" if x_is_unix else dataset.get("x_label"),
            vertical_lines_unix=vertical_lines,
            axis_object=axis_obj,
        )
        if dataset.get("label"):
            axis_obj.set_title(dataset["label"])
    if figure_title:
        fig.suptitle(figure_title)
    fig.tight_layout(rect=(0, 0, 1, 0.97))
    if show:
        import matplotlib.pyplot as plt

        plt.show()
    return fig, canvas

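
# Sketch of the dataset-dict contract consumed by generic_plot_spectrogram_set
# (synthetic values only; keys mirror the docstring above, and the output path
# is an assumption):
def _example_plot_spectrogram_set(output_path="./example_set.png"):
    rng = np.random.default_rng(1)
    times = np.arange(946684800.0, 946685400.0, 2.0)
    energies = np.linspace(10.0, 4000.0, 48)
    datasets = [
        {
            "x": times,
            "y": energies,
            "data": rng.poisson(5.0, (times.size, 16, energies.size)).astype(float),
            "label": "synthetic electrons",
        }
    ]
    fig, _canvas = generic_plot_spectrogram_set(
        datasets, z_scale="log", colormap=COLORMAP_LINEAR_Y_LOG_Z
    )
    if fig is not None:
        fig.savefig(output_path, dpi=150)
        close_all_axes_and_clear(fig)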
def _generic_batch_worker(
    item,
    output_dir,
    build_datasets_fn,
    zoom_center_fn,
    vertical_lines_fn,
    zoom_window_seconds,
    y_scale,
    z_scale,
    colormap,
):
    """Process one item; module scope keeps this callable picklable.

    ``ProcessPoolExecutor`` pickles submitted callables, so this cannot be a
    closure defined inside :func:`generic_batch_plot`. The user-supplied
    callables (``build_datasets_fn`` etc.) must likewise be picklable, e.g.
    module-level functions.
    """
    try:
        datasets = build_datasets_fn(item)
        if not datasets:
            return (item, "no_data")
        center = zoom_center_fn(item) if zoom_center_fn else None
        vlines = vertical_lines_fn(item) if vertical_lines_fn else None
        fig, canvas = generic_plot_spectrogram_set(
            datasets,
            zoom_center=center,
            zoom_window_seconds=zoom_window_seconds,
            vertical_lines=vlines,
            y_scale=y_scale,
            z_scale=z_scale,
            colormap=colormap,
            show=False,
        )
        if fig is not None:
            od = os.path.join(output_dir, str(item))
            os.makedirs(od, exist_ok=True)
            out_path = os.path.join(od, "generic.png")
            fig.savefig(out_path, dpi=150)
            close_all_axes_and_clear(fig)
        return (item, "ok")
    except Exception as generic_exception:
        log_error(f"[GENERIC-FAIL] Item {item}: {generic_exception}")
        return (item, "error")


def generic_batch_plot(
    items,
    output_dir,
    build_datasets_fn,
    zoom_center_fn=None,
    zoom_window_seconds=None,
    vertical_lines_fn=None,
    y_scale="linear",
    z_scale="linear",
    colormap="viridis",
    max_workers=2,
    progress_json_path: str = PLOTTING_PROGRESS_JSON_PATH,
    ignore_progress_json: bool = False,
    flush_batch_size: int = 10,
    log_flush_batch_size: int | None = None,
    install_signal_handlers: bool = True,
):
    """Generic batch runner for plotting datasets.

    Parameters
    ----------
    items : iterable
        Iterable of item identifiers (any ``repr``-able objects).
    output_dir : str
        Base output directory; plots saved under
        ``output_dir/<item>/generic.png``.
    build_datasets_fn : callable
        Picklable callable returning ``list[dict]`` describing datasets for
        an item.
    zoom_center_fn : callable, optional
        Callable mapping item -> center UNIX time (or ``None``) for zoom.
    zoom_window_seconds : float, optional
        Duration of zoom window in seconds.
    vertical_lines_fn : callable, optional
        Callable mapping item -> list[float] UNIX timestamps (or ``None``).
    y_scale : {'linear', 'log'}, default 'linear'
        Y-axis scaling for all rows.
    z_scale : {'linear', 'log'}, default 'linear'
        Color scaling for all rows.
    colormap : str, default 'viridis'
        Matplotlib colormap name.
    max_workers : int, default 2
        Number of parallel worker processes.
    progress_json_path : str, default PLOTTING_PROGRESS_JSON_PATH
        Path to progress JSON (resumable state). Created/updated as needed.
    ignore_progress_json : bool, default False
        If ``True``, skip reading existing progress prior to execution.
    flush_batch_size : int, default 10
        Progress/log batch size; values < 1 coerced to 1. The final partial
        batch is always flushed.
    log_flush_batch_size : int, optional
        Explicit log batch size; if ``None``, reuse ``flush_batch_size``.
    install_signal_handlers : bool, default True
        When ``True``, a temporary SIGINT handler is installed (restored on
        exit) to enable graceful interruption (progress & log flush). Set
        ``False`` in embedded environments if altering the global handler
        causes side effects.

    Returns
    -------
    list of tuple
        Sequence of ``(item, status)`` with ``status`` in {``'ok'``,
        ``'no_data'``, ``'error'``}.

    Notes
    -----
    * Logging is buffered and force-flushed at completion.
    * Progress JSON contains simple lists of completed, error, and no-data
      items.
    * Items are identified via ``repr(item)`` for data-agnostic persistence.
    * Work is dispatched to the module-level :func:`_generic_batch_worker`
      (via ``functools.partial``) so the submitted callable can be pickled
      by ``ProcessPoolExecutor``.
    """
    os.makedirs(output_dir, exist_ok=True)
    previous_sigint = None
    if install_signal_handlers:
        try:
            previous_sigint = signal.getsignal(signal.SIGINT)
            signal.signal(signal.SIGINT, _sigint_handler)
        except Exception as _gbp_sig_setup_exc:
            log_message(
                f"[WARN] Could not install temporary SIGINT handler: {_gbp_sig_setup_exc}"
            )
    flush_batch_size = max(1, int(flush_batch_size))
    configure_log_batch(log_flush_batch_size or flush_batch_size)

    # Step: Load prior progress (if any)
    progress_state = {
        # Ordered list of repr(item) for successfully completed items.
        "completed_items": [],
        # Items that raised exceptions inside the worker.
        "errors": [],
        # Items that returned logically empty datasets (no plots generated).
        "no_data": [],
        # Sequential index counter for processed items (0-based).
        "last_index": -1,
        # For future schema migrations.
        "schema_version": 1,
    }
    if (not ignore_progress_json) and os.path.exists(progress_json_path):
        try:
            with open(progress_json_path, "r") as progress_in:
                loaded = json.load(progress_in)
            if isinstance(loaded, dict):
                # Merge known keys
                for k in progress_state.keys():
                    if k in loaded:
                        progress_state[k] = loaded[k]
        except Exception as progress_json_read_exception:
            log_error(
                f"[PROGRESS] Failed to read existing progress JSON "
                f"'{progress_json_path}': {progress_json_read_exception}"
            )

    # Step: Determine pending items (skip already completed)
    item_list = list(items)
    completed_set = set(progress_state.get("completed_items", []))
    pending_items = [it for it in item_list if repr(it) not in completed_set]
    total_pending = len(pending_items)
    log_message(
        f"[BATCH] Starting generic batch plot with {total_pending} pending / "
        f"{len(item_list)} total items; flush_batch_size={flush_batch_size} "
        f"log_flush_batch_size={log_flush_batch_size or flush_batch_size}"
    )

    # Step: Define progress JSON flush helper
    pending_progress_write_count = 0  # in-memory updates since last JSON flush

    def _flush_progress(force: bool = False):
        """Flush progress JSON to disk using batch semantics.

        Parameters
        ----------
        force : bool, default False
            If True, flush even if the number of pending updates is below
            the configured batch size.
        """
        nonlocal pending_progress_write_count
        if pending_progress_write_count == 0 and not force:
            return
        if (pending_progress_write_count >= flush_batch_size) or force:
            try:
                with open(progress_json_path, "w") as progress_out:
                    json.dump(progress_state, progress_out, indent=2)
                pending_progress_write_count = 0
            except Exception as progress_json_write_exception:
                log_error(
                    f"[PROGRESS] Failed writing progress JSON "
                    f"'{progress_json_path}': {progress_json_write_exception}"
                )

    # Step: Dispatch items to the picklable module-level worker
    worker = functools.partial(
        _generic_batch_worker,
        output_dir=output_dir,
        build_datasets_fn=build_datasets_fn,
        zoom_center_fn=zoom_center_fn,
        vertical_lines_fn=vertical_lines_fn,
        zoom_window_seconds=zoom_window_seconds,
        y_scale=y_scale,
        z_scale=z_scale,
        colormap=colormap,
    )
    results = []
    processed_item_count = 0
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=max_workers
    ) as process_pool:
        future_map = {
            process_pool.submit(worker, item_identifier): item_identifier
            for item_identifier in pending_items
        }
        for finished_future in concurrent.futures.as_completed(future_map):
            original_item_identifier = future_map[finished_future]
            try:
                item_identifier, status = finished_future.result()
            except Exception as generic_batch_future_exception:
                status = "error"
                item_identifier = original_item_identifier
                log_error(
                    f"[GENERIC-FAIL] Item {original_item_identifier} outer "
                    f"exception: {generic_batch_future_exception}"
                )
            results.append((item_identifier, status))
            # Progress classification & state update
            item_repr = repr(item_identifier)
            if status == "ok":
                progress_state["completed_items"].append(item_repr)
            elif status == "no_data":
                progress_state["no_data"].append(item_repr)
            else:
                progress_state["errors"].append(item_repr)
            processed_item_count += 1
            progress_state["last_index"] = processed_item_count - 1
            pending_progress_write_count += 1
            _flush_progress(force=False)

    # Step: Final flushes
    _flush_progress(force=True)
    _flush_log_buffer(force=True)
    log_message(
        (
            "[BATCH] Completed generic batch plot: "
            f"{processed_item_count} processed "
            f"(ok={sum(1 for _, s in results if s == 'ok')} "
            f"no_data={sum(1 for _, s in results if s == 'no_data')} "
            f"error={sum(1 for _, s in results if s == 'error')})"
        ),
        force_flush=True,
    )
    # Restore prior handler if we replaced it
    if install_signal_handlers and previous_sigint is not None:
        try:
            signal.signal(signal.SIGINT, previous_sigint)
        except Exception as _gbp_sig_restore_exc:
            log_message(
                f"[WARN] Could not restore original SIGINT handler: {_gbp_sig_restore_exc}"
            )
    return results

def generic_plot_multirow_optional_zoom(
    datasets,
    vertical_lines=None,
    zoom_duration_minutes=6.25,
    y_scale="linear",
    z_scale="linear",
    colormap="viridis",
    show=False,
    title=None,
    row_label_pad=50,
    row_label_rotation=90,
    y_min=None,
    y_max=None,
    z_min=None,
    z_max=None,
):
    """Render a multi-row spectrogram grid with an optional zoom column.

    Parameters
    ----------
    datasets : list of dict
        Each dict must contain keys:

        * ``'x'`` – 1D UNIX epoch seconds (float) array
        * ``'y'`` – 1D energy (eV) array (unfiltered, 0–4000 typical)
        * ``'data'`` – 3D ndarray that can be collapsed
          (time, pitch/angle, energy)

        Optional per-row keys (all honored when present):

        * ``'label'`` – Row label placed on the left (rotated)
        * ``'y_label'`` – Units label for y-axis (default: ``'Energy (eV)'``)
        * ``'z_label'`` – Color scale label (default: ``'Counts'``)
        * ``'y_min'`` / ``'y_max'`` – Energy bounds (override global
          ``y_min`` / ``y_max`` args)
        * ``'z_min'`` / ``'z_max'`` – Color bounds (override global
          ``z_min`` / ``z_max`` args)
        * ``'vmin'`` / ``'vmax'`` – Precomputed percentile (or fixed) color
          bounds used when ``z_min`` / ``z_max`` not provided.
          (``vmin``/``vmax`` are interpreted as the *row's* native bounds;
          global ``z_min`` / ``z_max`` will still clamp if supplied.)
    vertical_lines : list of float, optional
        UNIX timestamps defining event/selection markers and the potential
        zoom window.
    zoom_duration_minutes : float, default 6.25
        Desired zoom window length in minutes (may auto-expand to include
        the full marked span).
    y_scale : {'linear', 'log'}, default 'linear'
        Y-axis scaling.
    z_scale : {'linear', 'log'}, default 'linear'
        Color (intensity) scale.
    colormap : str, default 'viridis'
        Matplotlib colormap.
    show : bool, default False
        If ``True``, display interactively.
    title : str, optional
        Figure suptitle.
    row_label_pad : int, default 50
        Padding for row labels.
    row_label_rotation : int, default 90
        Rotation angle (degrees) for row labels.
    y_min, y_max, z_min, z_max : float, optional
        Global override bounds applied uniformly when provided. Any per-row
        ``y_min`` / ``y_max`` / ``z_min`` / ``z_max`` in a dataset dict take
        precedence. When neither global nor per-row color bounds are
        supplied, the function relies on each dataset's ``vmin`` / ``vmax``
        (if present), else falls back to internal percentile selection in
        lower-level calls.

    Returns
    -------
    tuple
        ``(fig, canvas)`` or ``(None, None)`` if ``datasets`` is empty.

    Notes
    -----
    * Determines need for a zoom column dynamically: only rendered if at
      least one dataset contains non-NaN values inside the computed zoom
      window.
    * Y-axis and colorbar labels default to "Energy (eV)" and "Counts" when
      a dataset omits explicit ``y_label`` / ``z_label`` (defaults originate
      in ``generic_plot_spectrogram_set`` / dataset assembly).
    """
    if not datasets:
        return None, None

    # Determine zoom window & whether it is needed
    zoom_needed = False
    center_value = None
    duration = None
    if vertical_lines and len(vertical_lines) > 0:
        if len(vertical_lines) == 1:
            center_value = vertical_lines[0]
            duration = zoom_duration_minutes * 60
        else:
            center_value = 0.5 * (vertical_lines[0] + vertical_lines[1])
            min_window = abs(vertical_lines[1] - vertical_lines[0]) * 1.5
            requested_window = zoom_duration_minutes * 60
            duration = max(requested_window, min_window)
        left = center_value - duration / 2
        right = center_value + duration / 2
        for ds in datasets:
            t = ds["x"]
            d = ds["data"]
            mask_zoom = (t >= left) & (t <= right)
            # Require some non-NaN data inside window
            if np.any(~np.isnan(d[mask_zoom])):
                zoom_needed = True
                break

    number_rows = len(datasets)
    number_columns = 2 if zoom_needed else 1
    fig = Figure(figsize=(12 * number_columns, 3 * number_rows))
    canvas = FigureCanvas(fig)
    axes = np.empty((number_rows, number_columns), dtype=object)
    for i in range(number_rows):
        for j in range(number_columns):
            axes[i, j] = fig.add_subplot(
                number_rows, number_columns, i * number_columns + j + 1
            )

    # Plot rows
    for i, ds in enumerate(datasets):
        times = ds["x"]
        energy = ds["y"]
        data3d = ds["data"]
        vmin = ds.get("vmin")
        vmax = ds.get("vmax")
        # Full column
        make_spectrogram(
            x_axis_values=times,
            y_axis_values=energy,
            data_array_3d=data3d,
            collapse_axis=1,
            x_axis_min=times[0],
            x_axis_max=times[-1],
            x_axis_is_unix=True,
            instrument_label=None,
            y_axis_scale_function=y_scale,
            z_axis_scale_function=z_scale,
            vertical_lines_unix=vertical_lines,
            z_axis_min=vmin if z_min is None else z_min,
            z_axis_max=vmax if z_max is None else z_max,
            axis_object=axes[i, 0],
            colormap=colormap,
        )
        # Zoom column
        if number_columns == 2:
            make_spectrogram(
                x_axis_values=times,
                y_axis_values=energy,
                data_array_3d=data3d,
                collapse_axis=1,
                center_timestamp=center_value,
                window_duration_seconds=duration,
                x_axis_is_unix=True,
                instrument_label=None,
                y_axis_scale_function=y_scale,
                z_axis_scale_function=z_scale,
                vertical_lines_unix=vertical_lines,
                z_axis_min=vmin if z_min is None else z_min,
                z_axis_max=vmax if z_max is None else z_max,
                axis_object=axes[i, 1],
                colormap=colormap,
            )

    # Row labels
    for i, ds in enumerate(datasets):
        axes[i, 0].set_ylabel(
            ds.get("label", ""),
            fontsize=AXIS_LABEL_FONT_SIZE,
            rotation=row_label_rotation,
            labelpad=row_label_pad,
            va="center",
        )

    # Column headers
    axes[0, 0].set_title("Full", fontsize=AXIS_LABEL_FONT_SIZE)
    if number_columns == 2:
        axes[0, 1].set_title("Zoomed", fontsize=AXIS_LABEL_FONT_SIZE)

    # Title
    if title:
        fig.suptitle(title, fontsize=AXIS_LABEL_FONT_SIZE + 2)

    # Timespan annotation (use first dataset times)
    base_times = datasets[0]["x"]
    t0 = datetime.fromtimestamp(base_times[0], tz=timezone.utc)
    t1 = datetime.fromtimestamp(base_times[-1], tz=timezone.utc)
    data_timespan_str = (
        f"Data timespan: {t0.strftime('%Y-%m-%d %H:%M:%S')} to "
        f"{t1.strftime('%Y-%m-%d %H:%M:%S')} UTC"
    )
    marked_str = ""
    if vertical_lines and len(vertical_lines) > 0:
        v0 = datetime.fromtimestamp(min(vertical_lines), tz=timezone.utc)
        v1 = datetime.fromtimestamp(max(vertical_lines), tz=timezone.utc)
        marked_str = (
            f"\nMarked range: {v0.strftime('%Y-%m-%d %H:%M:%S')} to "
            f"{v1.strftime('%Y-%m-%d %H:%M:%S')} UTC"
        )
    fig.subplots_adjust(bottom=0.18)
    fig.text(0.5, 0.01, data_timespan_str, ha="center", va="bottom", fontsize=13)
    if marked_str:
        fig.text(
            0.5,
            0.045,
            marked_str.strip(),
            ha="center",
            va="bottom",
            fontsize=13,
            color="red",
        )
    fig.tight_layout(rect=(0, 0.08, 1, 0.95))
    if show:
        import matplotlib.pyplot as plt

        plt.show()
    return fig, canvas
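
# Sketch of the multi-row grid with its optional zoom column (synthetic data;
# the two vertical-line timestamps are arbitrary points inside the data span,
# chosen so the zoom column is rendered, and the output path is an assumption):
def _example_multirow_plot(output_path="./example_multirow.png"):
    rng = np.random.default_rng(2)
    times = np.arange(946684800.0, 946686600.0, 2.0)  # 30 min of fake data
    energies = np.linspace(10.0, 4000.0, 48)
    datasets = [
        {
            "x": times,
            "y": energies,
            "data": rng.poisson(5.0, (times.size, 16, energies.size)).astype(float),
            "label": "synthetic row",
        }
    ]
    fig, _canvas = generic_plot_multirow_optional_zoom(
        datasets,
        vertical_lines=[float(times[300]), float(times[420])],
        z_scale="log",
        colormap=COLORMAP_LINEAR_Y_LOG_Z,
        title="Synthetic example",
    )
    if fig is not None:
        fig.savefig(output_path, dpi=150)
        close_all_axes_and_clear(fig)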