From e492f0f78d293dfe510cc04e973e62994d9ae17d Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Mon, 26 Jan 2026 12:02:56 -0500 Subject: [PATCH 01/19] Added label_thermal_quench_onset_time() method --- disruption_py/machine/cmod/physics.py | 140 ++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index e35b8081b..6828b37aa 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2128,3 +2128,143 @@ def _is_on_blacklist(shot_id: int) -> bool: or 1150000000 < shot_id < 1150610000 or 1160000000 < shot_id < 1160303000 ) + +@physics_method(columns=["thermal_quench_time"], tokamak=Tokamak.CMOD) +def label_thermal_quench_onset_time(params: PhysicsMethodParams): + """ + Labels the onset time of the thermal quench for a given shot (NaN for non-disruptive shots) + using a vertical SXR array due to its off-axis views and robustness across shots, + as opposed to ECE. The labeling method is non-causal (i.e. post-shot processing). + The TQ is found by searching for min(dSXR/dt) in a time window prior to the CQ. + There is a tension between using longer windows to find the first TQ in a multi-stage TQ + versus using a shorter window to avoid labeling sawtooth crashes. + Thus, for shots with multi-stage thermal quenches, (see shots 1050830034 and 1120717002), + this algorithm struggles to select the first thermal quench. Based on manual testing + of 120 shots, about 5% of flattop disruptions on C-Mod feature multi-stage thermal quenches. + This algorithm has only been tested on flattop disruptions. + ---------- + params: PhysicsMethodParams + The parameters storing the requested time base, shot id, etc + Returns + ---------- + thermal_quench_time : array_like + time of thermal quench onset for the shot, identical values at each time-slice + + Last Major Update: Henry Wietfeldt (1/26/26) + """ + from scipy.signal import butter, filtfilt, resample_poly + thermal_quench_time = np.full(len(params.times), np.nan) + + # Get current data for obtaining start of current quench + ip, magtime = params.mds_conn.get_data_with_dims( + r"\ip", tree_name="magnetics" + ) + ip = np.abs(ip) + + # Get SXR chords + n_chords = 38 + array_path = r"\top.brightnesses.array_1" + try: + chord_01, t_sxr = params.mds_conn.get_data_with_dims( + array_path + ":chord_01", + tree_name="xtomo", + ) # Units: W/m^2, s + except mdsExceptions.MdsException: + print(params.shot_id) + params.logger.debug("Failed to get SXR " + array_path + " data") + return {"thermal_quench_time": np.full(len(params.times), np.nan)} + valid_times = (t_sxr > 0) & (t_sxr < 2.0) + t_sxr = t_sxr[valid_times] + sxr = np.zeros(shape=(n_chords, len(t_sxr))) + sxr[0] = chord_01[valid_times] + # Get all other SXR chords + for i in range(1, n_chords): + try: + chord, t_chord = params.mds_conn.get_data_with_dims( + array_path + ":chord_" + f"{i+1:02}", + tree_name="xtomo", + ) + except mdsExceptions.MdsException: + params.logger.debug("Failed to get SXR " + array_path + " chord " + str(i+1) + " data") + sxr[i] = 0. + continue + # Subtract constant background + chord = chord - np.mean(chord[t_chord < 0.]) + # Occasionally the time bases of a chord are of a different length + # Usually one timebase is just cut off early after shot is over + valid_times = (t_chord > 0) & (t_chord < 2.) + # Goods chords should be of the same shape + if len(chord[valid_times]) == sxr.shape[1]: + sxr[i] = chord[valid_times] + sample_time = t_sxr[1] - t_sxr[0] + sample_freq = 1 / sample_time + + # Remove bad chords by checking each chord's autocorrelation. + # Bad chords often have significant white noise, meaning low autocorrelation (< 10 ms) + # Good chords should have an autocorrelation of 100s of ms + # See shot 1050311013 as an example with some bad chords + noise_autorr_cutoff = 0.01 # [s] + for i, chord in enumerate(sxr): + # Use 300 ms prior to current quench for speed-up during autocorr O(N^2) + idx_start = np.argmin(np.abs(t_sxr - (params.disruption_time - 0.3))) + idx_end = np.argmin(np.abs(t_sxr - (params.disruption_time))) + chord = chord[idx_start:idx_end] + sample_freq_5kHz = 5000 # [Hz] + if sample_freq > sample_freq_5kHz: + # 2012-2016 has 250 kHz sampling frequency. Resample to 5 kHz frequency + # (native SXR sample frequency of earlier campaigns) for speed-up + chord = resample_poly(chord, up=1, down=sample_freq//sample_freq_5kHz) + autocorr = np.correlate(chord, chord, mode='full') + max_autocorr = np.max(autocorr) + if max_autocorr > 0: + autocorr = autocorr / np.max(autocorr) # Normalize + else: + sxr[i] = 0. + continue + index_no_lag = np.argmax(autocorr) + index_decay = np.argmax(autocorr[index_no_lag:] < 0) + if (index_decay*(1/sample_freq_5kHz) < noise_autorr_cutoff): + sxr[i] = 0. + + # Noncausal Butterworth low pass filter to smooth transient SXR spikes during TQ. + # Cutoff of 1.0 kHz and order 2 seems to filter recombination SXR spikes + # while maintaining decent resolution of TQ based on scan from 0.25 kHz - 2 kHz + # Results were fairly insensitive within these windows on the 100 shots checked + # See shot 1120913013 as example of large recombination spike + bworth_cutoff = 1000 # [Hz] + bworth_order = 2 + normalized_cutoff = bworth_cutoff / (0.5*sample_freq) + b, a = butter(bworth_order, normalized_cutoff, btype='low', analog=False) + core_sxr_raw = np.max(sxr, axis=0) + sxr = filtfilt(b, a, sxr, axis=1) + core_sxr = np.max(sxr, axis=0) + dcore_sxr_dt = np.diff(core_sxr, prepend=0)/sample_time + + # Search for the onset of the CQ so that we can search for the TQ in a small time window + # to avoid labeleing sawtooth crashes as the thermal quench + # Some current quenches can be long (see shots 1050311013, 1050802017). + # Set Ip prior to disruption as minimum in prior time window (not median due to case of ramp-down) + idx_start = np.argmin(np.abs(magtime - (params.disruption_time - 0.04))) + idx_end = np.argmin(np.abs(magtime - (params.disruption_time - 0.02))) + ip_prior = np.min(ip[idx_start:idx_end]) + # CQ onset is last moment Ip is >90% Ip prior to disruption + idx_cq_onset = np.where(ip > 0.9*ip_prior)[0][-1] + cq_onset_time = magtime[idx_cq_onset] + + # Search for TQ midpoint as min(dSXR/dt) in window of 5 ms prior to current quench onset + wndw_before_cq = 0.005 # [s] + idx_start = np.argmin(np.abs(t_sxr - (cq_onset_time - wndw_before_cq))) + idx_end = np.argmin(np.abs(t_sxr - (cq_onset_time))) + t_max_sxr_drop = t_sxr[idx_start + np.argmin(dcore_sxr_dt[idx_start:idx_end])] + + # Find onset of thermal quench in 0.5 ms window prior to midpoint of TQ + # Thermal quenches on C-Mod are almost always shorter than 1 ms + # Label onset as max of SXR signal on 0.5 ms window preceding max drop in SXR + # Use raw signal bc smoothed signal as a longer crash time. + # Note this sometimes picks up on recombination spikes + wndw_before_tq_mid = 0.0005 # [s] + idx_start = np.argmin(np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_mid))) + idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) + tq_time_scalar = t_sxr[idx_start + np.argmax(core_sxr_raw[idx_start:idx_end])] + thermal_quench_time = tq_time_scalar * np.ones(len(params.times)) + return {"thermal_quench_time": thermal_quench_time} \ No newline at end of file From 949d0c8f7f01c6ca6162ed19611e2bf45300852c Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Mon, 26 Jan 2026 12:16:36 -0500 Subject: [PATCH 02/19] minor reformatting --- disruption_py/machine/cmod/physics.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 6828b37aa..72141e093 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2133,14 +2133,14 @@ def _is_on_blacklist(shot_id: int) -> bool: def label_thermal_quench_onset_time(params: PhysicsMethodParams): """ Labels the onset time of the thermal quench for a given shot (NaN for non-disruptive shots) - using a vertical SXR array due to its off-axis views and robustness across shots, - as opposed to ECE. The labeling method is non-causal (i.e. post-shot processing). + using a vertical SXR array due to its off-axis views and robustness across shots, + as opposed to ECE. The labeling method is non-causal (i.e. post-shot processing). The TQ is found by searching for min(dSXR/dt) in a time window prior to the CQ. There is a tension between using longer windows to find the first TQ in a multi-stage TQ versus using a shorter window to avoid labeling sawtooth crashes. Thus, for shots with multi-stage thermal quenches, (see shots 1050830034 and 1120717002), - this algorithm struggles to select the first thermal quench. Based on manual testing - of 120 shots, about 5% of flattop disruptions on C-Mod feature multi-stage thermal quenches. + this algorithm struggles to select the first thermal quench. Based on manual testing + of 120 shots, about 5% of flattop disruptions on C-Mod feature multi-stage thermal quenches. This algorithm has only been tested on flattop disruptions. ---------- params: PhysicsMethodParams @@ -2209,11 +2209,11 @@ def label_thermal_quench_onset_time(params: PhysicsMethodParams): idx_start = np.argmin(np.abs(t_sxr - (params.disruption_time - 0.3))) idx_end = np.argmin(np.abs(t_sxr - (params.disruption_time))) chord = chord[idx_start:idx_end] - sample_freq_5kHz = 5000 # [Hz] - if sample_freq > sample_freq_5kHz: - # 2012-2016 has 250 kHz sampling frequency. Resample to 5 kHz frequency + sample_freq_5khz = 5000 # [Hz] + if sample_freq > 5000: + # 2012-2016 has 250 kHz sampling frequency. Resample to 5 kHz frequency # (native SXR sample frequency of earlier campaigns) for speed-up - chord = resample_poly(chord, up=1, down=sample_freq//sample_freq_5kHz) + chord = resample_poly(chord, up=1, down=sample_freq//sample_freq_5khz) autocorr = np.correlate(chord, chord, mode='full') max_autocorr = np.max(autocorr) if max_autocorr > 0: @@ -2223,7 +2223,7 @@ def label_thermal_quench_onset_time(params: PhysicsMethodParams): continue index_no_lag = np.argmax(autocorr) index_decay = np.argmax(autocorr[index_no_lag:] < 0) - if (index_decay*(1/sample_freq_5kHz) < noise_autorr_cutoff): + if index_decay*(1/sample_freq_5khz) < noise_autorr_cutoff: sxr[i] = 0. # Noncausal Butterworth low pass filter to smooth transient SXR spikes during TQ. @@ -2242,8 +2242,8 @@ def label_thermal_quench_onset_time(params: PhysicsMethodParams): # Search for the onset of the CQ so that we can search for the TQ in a small time window # to avoid labeleing sawtooth crashes as the thermal quench - # Some current quenches can be long (see shots 1050311013, 1050802017). - # Set Ip prior to disruption as minimum in prior time window (not median due to case of ramp-down) + # Some current quenches can be long (see shots 1050311013, 1050802017). + # Set Ip prior to disruption as minimum in prior time window (not median for ramp-down) idx_start = np.argmin(np.abs(magtime - (params.disruption_time - 0.04))) idx_end = np.argmin(np.abs(magtime - (params.disruption_time - 0.02))) ip_prior = np.min(ip[idx_start:idx_end]) @@ -2267,4 +2267,4 @@ def label_thermal_quench_onset_time(params: PhysicsMethodParams): idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) tq_time_scalar = t_sxr[idx_start + np.argmax(core_sxr_raw[idx_start:idx_end])] thermal_quench_time = tq_time_scalar * np.ones(len(params.times)) - return {"thermal_quench_time": thermal_quench_time} \ No newline at end of file + return {"thermal_quench_time": thermal_quench_time} From dacce4f5ba7cea3a6d96b6a5acce441ccca000f7 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Mon, 26 Jan 2026 15:18:54 -0500 Subject: [PATCH 03/19] rename and reformat --- disruption_py/machine/cmod/physics.py | 281 +++++++++++++------------- 1 file changed, 141 insertions(+), 140 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 72141e093..4d495b7f5 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2115,6 +2115,147 @@ def get_surface_voltage(params: PhysicsMethodParams): v_surf = interp1(efit_time, v_surf, params.times) return {"v_surf": v_surf} + + @staticmethod + @physics_method(columns=["thermal_quench_time"], tokamak=Tokamak.CMOD) + def get_thermal_quench_onset_time(params: PhysicsMethodParams): + """ + Labels the onset time of the thermal quench for a given shot (NaN for non-disruptive shots) + using a vertical SXR array due to its off-axis views and robustness across shots, + as opposed to ECE. The labeling method is non-causal (i.e. post-shot processing). + The TQ is found by searching for min(dSXR/dt) in a time window prior to the CQ. + There is a tension between using longer windows to find the first TQ in a multi-stage TQ + versus using a shorter window to avoid labeling sawtooth crashes. + Thus, for shots with multi-stage thermal quenches, (see shots 1050830034 and 1120717002), + this algorithm struggles to select the first thermal quench. Based on manual testing + of 120 shots, about 5% of flattop disruptions on C-Mod feature multi-stage thermal quenches. + This algorithm has only been tested on flattop disruptions. + ---------- + params: PhysicsMethodParams + The parameters storing the requested time base, shot id, etc + Returns + ---------- + thermal_quench_time : array_like + time of thermal quench onset for the shot, identical values at each time-slice + + Last Major Update: Henry Wietfeldt (1/26/26) + """ + from scipy.signal import butter, filtfilt, resample_poly + thermal_quench_time = np.full(len(params.times), np.nan) + + # Get current data for obtaining start of current quench + ip, magtime = params.mds_conn.get_data_with_dims( + r"\ip", tree_name="magnetics" + ) + ip = np.abs(ip) + + # Get SXR chords + n_chords = 38 + array_path = r"\top.brightnesses.array_1" + try: + chord_01, t_sxr = params.mds_conn.get_data_with_dims( + array_path + ":chord_01", + tree_name="xtomo", + ) # Units: W/m^2, s + except mdsExceptions.MdsException: + print(params.shot_id) + params.logger.debug("Failed to get SXR " + array_path + " data") + return {"thermal_quench_time": np.full(len(params.times), np.nan)} + valid_times = (t_sxr > 0) & (t_sxr < 2.0) + t_sxr = t_sxr[valid_times] + sxr = np.zeros(shape=(n_chords, len(t_sxr))) + sxr[0] = chord_01[valid_times] + # Get all other SXR chords + for i in range(1, n_chords): + try: + chord, t_chord = params.mds_conn.get_data_with_dims( + array_path + ":chord_" + f"{i+1:02}", + tree_name="xtomo", + ) + except mdsExceptions.MdsException: + params.logger.debug("Failed to get SXR " + array_path + " chord " + str(i+1) + " data") + sxr[i] = 0. + continue + # Subtract constant background + chord = chord - np.mean(chord[t_chord < 0.]) + # Occasionally the time bases of a chord are of a different length + # Usually one timebase is just cut off early after shot is over + valid_times = (t_chord > 0) & (t_chord < 2.) + # Goods chords should be of the same shape + if len(chord[valid_times]) == sxr.shape[1]: + sxr[i] = chord[valid_times] + sample_time = t_sxr[1] - t_sxr[0] + sample_freq = 1 / sample_time + + # Remove bad chords by checking each chord's autocorrelation. + # Bad chords often have significant white noise, meaning low autocorrelation (< 10 ms) + # Good chords should have an autocorrelation of 100s of ms + # See shot 1050311013 as an example with some bad chords + noise_autorr_cutoff = 0.01 # [s] + for i, chord in enumerate(sxr): + # Use 300 ms prior to current quench for speed-up during autocorr O(N^2) + idx_start = np.argmin(np.abs(t_sxr - (params.disruption_time - 0.3))) + idx_end = np.argmin(np.abs(t_sxr - (params.disruption_time))) + chord = chord[idx_start:idx_end] + sample_freq_5khz = 5000 # [Hz] + if sample_freq > 5000: + # 2012-2016 has 250 kHz sampling frequency. Resample to 5 kHz frequency + # (native SXR sample frequency of earlier campaigns) for speed-up + chord = resample_poly(chord, up=1, down=sample_freq//sample_freq_5khz) + autocorr = np.correlate(chord, chord, mode='full') + max_autocorr = np.max(autocorr) + if max_autocorr > 0: + autocorr = autocorr / np.max(autocorr) # Normalize + else: + sxr[i] = 0. + continue + index_no_lag = np.argmax(autocorr) + index_decay = np.argmax(autocorr[index_no_lag:] < 0) + if index_decay*(1/sample_freq_5khz) < noise_autorr_cutoff: + sxr[i] = 0. + + # Noncausal Butterworth low pass filter to smooth transient SXR spikes during TQ. + # Cutoff of 1.0 kHz and order 2 seems to filter recombination SXR spikes + # while maintaining decent resolution of TQ based on scan from 0.25 kHz - 2 kHz + # Results were fairly insensitive within these windows on the 100 shots checked + # See shot 1120913013 as example of large recombination spike + bworth_cutoff = 1000 # [Hz] + bworth_order = 2 + normalized_cutoff = bworth_cutoff / (0.5*sample_freq) + b, a = butter(bworth_order, normalized_cutoff, btype='low', analog=False) + core_sxr_raw = np.max(sxr, axis=0) + sxr = filtfilt(b, a, sxr, axis=1) + core_sxr = np.max(sxr, axis=0) + dcore_sxr_dt = np.diff(core_sxr, prepend=0)/sample_time + + # Search for the onset of the CQ so that we can search for the TQ in a small time window + # to avoid labeleing sawtooth crashes as the thermal quench + # Some current quenches can be long (see shots 1050311013, 1050802017). + # Set Ip prior to disruption as minimum in prior time window (not median for ramp-down) + idx_start = np.argmin(np.abs(magtime - (params.disruption_time - 0.04))) + idx_end = np.argmin(np.abs(magtime - (params.disruption_time - 0.02))) + ip_prior = np.min(ip[idx_start:idx_end]) + # CQ onset is last moment Ip is >90% Ip prior to disruption + idx_cq_onset = np.where(ip > 0.9*ip_prior)[0][-1] + cq_onset_time = magtime[idx_cq_onset] + + # Search for TQ midpoint as min(dSXR/dt) in window of 5 ms prior to current quench onset + wndw_before_cq = 0.005 # [s] + idx_start = np.argmin(np.abs(t_sxr - (cq_onset_time - wndw_before_cq))) + idx_end = np.argmin(np.abs(t_sxr - (cq_onset_time))) + t_max_sxr_drop = t_sxr[idx_start + np.argmin(dcore_sxr_dt[idx_start:idx_end])] + + # Find onset of thermal quench in 0.5 ms window prior to midpoint of TQ + # Thermal quenches on C-Mod are almost always shorter than 1 ms + # Label onset as max of SXR signal on 0.5 ms window preceding max drop in SXR + # Use raw signal bc smoothed signal as a longer crash time. + # Note this sometimes picks up on recombination spikes + wndw_before_tq_mid = 0.0005 # [s] + idx_start = np.argmin(np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_mid))) + idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) + tq_time_scalar = t_sxr[idx_start + np.argmax(core_sxr_raw[idx_start:idx_end])] + thermal_quench_time = tq_time_scalar * np.ones(len(params.times)) + return {"thermal_quench_time": thermal_quench_time} @staticmethod def _is_on_blacklist(shot_id: int) -> bool: @@ -2128,143 +2269,3 @@ def _is_on_blacklist(shot_id: int) -> bool: or 1150000000 < shot_id < 1150610000 or 1160000000 < shot_id < 1160303000 ) - -@physics_method(columns=["thermal_quench_time"], tokamak=Tokamak.CMOD) -def label_thermal_quench_onset_time(params: PhysicsMethodParams): - """ - Labels the onset time of the thermal quench for a given shot (NaN for non-disruptive shots) - using a vertical SXR array due to its off-axis views and robustness across shots, - as opposed to ECE. The labeling method is non-causal (i.e. post-shot processing). - The TQ is found by searching for min(dSXR/dt) in a time window prior to the CQ. - There is a tension between using longer windows to find the first TQ in a multi-stage TQ - versus using a shorter window to avoid labeling sawtooth crashes. - Thus, for shots with multi-stage thermal quenches, (see shots 1050830034 and 1120717002), - this algorithm struggles to select the first thermal quench. Based on manual testing - of 120 shots, about 5% of flattop disruptions on C-Mod feature multi-stage thermal quenches. - This algorithm has only been tested on flattop disruptions. - ---------- - params: PhysicsMethodParams - The parameters storing the requested time base, shot id, etc - Returns - ---------- - thermal_quench_time : array_like - time of thermal quench onset for the shot, identical values at each time-slice - - Last Major Update: Henry Wietfeldt (1/26/26) - """ - from scipy.signal import butter, filtfilt, resample_poly - thermal_quench_time = np.full(len(params.times), np.nan) - - # Get current data for obtaining start of current quench - ip, magtime = params.mds_conn.get_data_with_dims( - r"\ip", tree_name="magnetics" - ) - ip = np.abs(ip) - - # Get SXR chords - n_chords = 38 - array_path = r"\top.brightnesses.array_1" - try: - chord_01, t_sxr = params.mds_conn.get_data_with_dims( - array_path + ":chord_01", - tree_name="xtomo", - ) # Units: W/m^2, s - except mdsExceptions.MdsException: - print(params.shot_id) - params.logger.debug("Failed to get SXR " + array_path + " data") - return {"thermal_quench_time": np.full(len(params.times), np.nan)} - valid_times = (t_sxr > 0) & (t_sxr < 2.0) - t_sxr = t_sxr[valid_times] - sxr = np.zeros(shape=(n_chords, len(t_sxr))) - sxr[0] = chord_01[valid_times] - # Get all other SXR chords - for i in range(1, n_chords): - try: - chord, t_chord = params.mds_conn.get_data_with_dims( - array_path + ":chord_" + f"{i+1:02}", - tree_name="xtomo", - ) - except mdsExceptions.MdsException: - params.logger.debug("Failed to get SXR " + array_path + " chord " + str(i+1) + " data") - sxr[i] = 0. - continue - # Subtract constant background - chord = chord - np.mean(chord[t_chord < 0.]) - # Occasionally the time bases of a chord are of a different length - # Usually one timebase is just cut off early after shot is over - valid_times = (t_chord > 0) & (t_chord < 2.) - # Goods chords should be of the same shape - if len(chord[valid_times]) == sxr.shape[1]: - sxr[i] = chord[valid_times] - sample_time = t_sxr[1] - t_sxr[0] - sample_freq = 1 / sample_time - - # Remove bad chords by checking each chord's autocorrelation. - # Bad chords often have significant white noise, meaning low autocorrelation (< 10 ms) - # Good chords should have an autocorrelation of 100s of ms - # See shot 1050311013 as an example with some bad chords - noise_autorr_cutoff = 0.01 # [s] - for i, chord in enumerate(sxr): - # Use 300 ms prior to current quench for speed-up during autocorr O(N^2) - idx_start = np.argmin(np.abs(t_sxr - (params.disruption_time - 0.3))) - idx_end = np.argmin(np.abs(t_sxr - (params.disruption_time))) - chord = chord[idx_start:idx_end] - sample_freq_5khz = 5000 # [Hz] - if sample_freq > 5000: - # 2012-2016 has 250 kHz sampling frequency. Resample to 5 kHz frequency - # (native SXR sample frequency of earlier campaigns) for speed-up - chord = resample_poly(chord, up=1, down=sample_freq//sample_freq_5khz) - autocorr = np.correlate(chord, chord, mode='full') - max_autocorr = np.max(autocorr) - if max_autocorr > 0: - autocorr = autocorr / np.max(autocorr) # Normalize - else: - sxr[i] = 0. - continue - index_no_lag = np.argmax(autocorr) - index_decay = np.argmax(autocorr[index_no_lag:] < 0) - if index_decay*(1/sample_freq_5khz) < noise_autorr_cutoff: - sxr[i] = 0. - - # Noncausal Butterworth low pass filter to smooth transient SXR spikes during TQ. - # Cutoff of 1.0 kHz and order 2 seems to filter recombination SXR spikes - # while maintaining decent resolution of TQ based on scan from 0.25 kHz - 2 kHz - # Results were fairly insensitive within these windows on the 100 shots checked - # See shot 1120913013 as example of large recombination spike - bworth_cutoff = 1000 # [Hz] - bworth_order = 2 - normalized_cutoff = bworth_cutoff / (0.5*sample_freq) - b, a = butter(bworth_order, normalized_cutoff, btype='low', analog=False) - core_sxr_raw = np.max(sxr, axis=0) - sxr = filtfilt(b, a, sxr, axis=1) - core_sxr = np.max(sxr, axis=0) - dcore_sxr_dt = np.diff(core_sxr, prepend=0)/sample_time - - # Search for the onset of the CQ so that we can search for the TQ in a small time window - # to avoid labeleing sawtooth crashes as the thermal quench - # Some current quenches can be long (see shots 1050311013, 1050802017). - # Set Ip prior to disruption as minimum in prior time window (not median for ramp-down) - idx_start = np.argmin(np.abs(magtime - (params.disruption_time - 0.04))) - idx_end = np.argmin(np.abs(magtime - (params.disruption_time - 0.02))) - ip_prior = np.min(ip[idx_start:idx_end]) - # CQ onset is last moment Ip is >90% Ip prior to disruption - idx_cq_onset = np.where(ip > 0.9*ip_prior)[0][-1] - cq_onset_time = magtime[idx_cq_onset] - - # Search for TQ midpoint as min(dSXR/dt) in window of 5 ms prior to current quench onset - wndw_before_cq = 0.005 # [s] - idx_start = np.argmin(np.abs(t_sxr - (cq_onset_time - wndw_before_cq))) - idx_end = np.argmin(np.abs(t_sxr - (cq_onset_time))) - t_max_sxr_drop = t_sxr[idx_start + np.argmin(dcore_sxr_dt[idx_start:idx_end])] - - # Find onset of thermal quench in 0.5 ms window prior to midpoint of TQ - # Thermal quenches on C-Mod are almost always shorter than 1 ms - # Label onset as max of SXR signal on 0.5 ms window preceding max drop in SXR - # Use raw signal bc smoothed signal as a longer crash time. - # Note this sometimes picks up on recombination spikes - wndw_before_tq_mid = 0.0005 # [s] - idx_start = np.argmin(np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_mid))) - idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) - tq_time_scalar = t_sxr[idx_start + np.argmax(core_sxr_raw[idx_start:idx_end])] - thermal_quench_time = tq_time_scalar * np.ones(len(params.times)) - return {"thermal_quench_time": thermal_quench_time} From 9ec25e990e0bcadee28e06e1914a6466b7d15e82 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Tue, 27 Jan 2026 16:55:04 -0500 Subject: [PATCH 04/19] Gracefully return NaN for non-disruptive shots --- disruption_py/machine/cmod/physics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 4d495b7f5..4cb121f9a 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2142,6 +2142,8 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): """ from scipy.signal import butter, filtfilt, resample_poly thermal_quench_time = np.full(len(params.times), np.nan) + if params.disruption_time is None: + return {"thermal_quench_time": thermal_quench_time} # Get current data for obtaining start of current quench ip, magtime = params.mds_conn.get_data_with_dims( From 1af34d4102302747104fda67429d2ed005301b20 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Tue, 7 Apr 2026 11:23:08 -0400 Subject: [PATCH 05/19] mds_conn -> data_conn in tq labeling method + fix for shots w/ CQ > 2 s --- disruption_py/machine/cmod/physics.py | 11 ++++-- scripts/get_thermal_quench_times.py | 48 +++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) create mode 100644 scripts/get_thermal_quench_times.py diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index b5a530d95..cd9dbd4e8 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2144,9 +2144,10 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): thermal_quench_time = np.full(len(params.times), np.nan) if params.disruption_time is None: return {"thermal_quench_time": thermal_quench_time} + params.logger.debug(params.disruption_time) # Get current data for obtaining start of current quench - ip, magtime = params.mds_conn.get_data_with_dims( + ip, magtime = params.data_conn.get_data_with_dims( r"\ip", tree_name="magnetics" ) ip = np.abs(ip) @@ -2155,7 +2156,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): n_chords = 38 array_path = r"\top.brightnesses.array_1" try: - chord_01, t_sxr = params.mds_conn.get_data_with_dims( + chord_01, t_sxr = params.data_conn.get_data_with_dims( array_path + ":chord_01", tree_name="xtomo", ) # Units: W/m^2, s @@ -2170,7 +2171,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # Get all other SXR chords for i in range(1, n_chords): try: - chord, t_chord = params.mds_conn.get_data_with_dims( + chord, t_chord = params.data_conn.get_data_with_dims( array_path + ":chord_" + f"{i+1:02}", tree_name="xtomo", ) @@ -2245,6 +2246,10 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): wndw_before_cq = 0.005 # [s] idx_start = np.argmin(np.abs(t_sxr - (cq_onset_time - wndw_before_cq))) idx_end = np.argmin(np.abs(t_sxr - (cq_onset_time))) + # When params.disruption_time > 2 s, the SXR data stops prior to the labeled CQ + if idx_start == len(t_sxr) - 1: + params.logger.warning(f"No SXR data at time of CQ. params.disruption_time = {params.disruption_time:.3f}") + return {"thermal_quench_time": np.full(len(params.times), np.nan)} t_max_sxr_drop = t_sxr[idx_start + np.argmin(dcore_sxr_dt[idx_start:idx_end])] # Find onset of thermal quench in 0.5 ms window prior to midpoint of TQ diff --git a/scripts/get_thermal_quench_times.py b/scripts/get_thermal_quench_times.py new file mode 100644 index 000000000..24e1955af --- /dev/null +++ b/scripts/get_thermal_quench_times.py @@ -0,0 +1,48 @@ +""" +This script gets the times of the thermal quench onset for a specified shotlist. +Useful for testing the thermal quench labler. +We can delete this when we merge the thermal quench labeler into dev +Author: Henry Wietfeldt +""" + +import os +import logging + +import numpy as np + +from disruption_py.settings import LogSettings, RetrievalSettings +from disruption_py.workflow import get_shots_data + + +# Shotlist of all C-Mod shots 2012-2016 +SHOTLIST_FN = '/home/henrycw/projects/ufo-characterization/shotlists/cmod_shots_2012_to_2016.txt' +NUM_SUBSET = 500 # Number of shots to randomly select from Shotlist + +shotlist = np.genfromtxt(SHOTLIST_FN, dtype=int) +rng = np.random.default_rng(seed=42) +rng.shuffle(shotlist) +if len(shotlist) <= NUM_SUBSET: + shots_to_use = shotlist +else: + shots_to_use = shotlist[:NUM_SUBSET] + +# shots_to_use = [1140821020, 1140520016 ,1150710007] + +# default method for pulling disruption-py data +retrieval_settings = RetrievalSettings( + time_setting="disruption_warning", # use the set efit's timebase + efit_nickname_setting="analysis", # set the efit + run_methods=[], + run_columns=["ip", "thermal_quench_time"], + only_requested_columns=True, +) + +results = get_shots_data( + shotlist_setting=shots_to_use, + retrieval_settings=retrieval_settings, + log_settings=LogSettings(console_level=logging.WARNING), + num_processes=os.cpu_count(), +) + +# Write contents to csv for easy inspection +results.to_dataframe().to_csv('tq_labels.csv') \ No newline at end of file From e7efea823e856099514ea7c127fa0622f7dd39da Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Thu, 9 Apr 2026 12:09:12 -0400 Subject: [PATCH 06/19] Scripts for testing --- disruption_py/machine/cmod/physics.py | 29 ++++++- scripts/plot_thermal_quench_time.py | 115 ++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 3 deletions(-) create mode 100644 scripts/plot_thermal_quench_time.py diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index cd9dbd4e8..9073f6de5 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -554,7 +554,7 @@ def get_ohmic_parameters(params: PhysicsMethodParams): ) v_loop, v_loop_time = params.data_conn.get_data_with_dims( r"\efit_aeqdsk:vloopt", tree_name="_efit_tree" - ) # [V], [s] + ) # [V], [s] if len(v_loop_time) <= 1: raise CalculationError("No data for v_loop_time") @@ -2261,8 +2261,31 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): idx_start = np.argmin(np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_mid))) idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) tq_time_scalar = t_sxr[idx_start + np.argmax(core_sxr_raw[idx_start:idx_end])] - thermal_quench_time = tq_time_scalar * np.ones(len(params.times)) - return {"thermal_quench_time": thermal_quench_time} + + # TODO: Delete this block during clean-up + # TODO: Comment this out when running over many shots + #Write some signals for plotting + #Get magnetic axis data from EFIT for testing purposes + # z0, efit_time = params.data_conn.get_data_with_dims( + # r"\efit_aeqdsk:zmagx", tree_name="_efit_tree" + # ) # [cm], [s] + # z0 *= 0.01 # [cm] -> [m] + import pickle + plot_df = {"magtime":magtime, + "ip": ip, + "t_sxr": t_sxr, + "core_sxr_raw": core_sxr_raw, + "core_sxr": core_sxr, + "core_sxr_growth_rate": dcore_sxr_dt, + "t_disrupt": params.disruption_time, + "cq_onset_time": cq_onset_time, + "t_max_sxr_drop": t_max_sxr_drop, + "thermal_quench_time_scalar": tq_time_scalar, + } + with open('sxr.pkl', 'wb') as f: + pickle.dump(plot_df, f) + + return {"thermal_quench_time": tq_time_scalar*np.ones(len(params.times))} @staticmethod def _is_on_blacklist(shot_id: int) -> bool: diff --git a/scripts/plot_thermal_quench_time.py b/scripts/plot_thermal_quench_time.py new file mode 100644 index 000000000..0af54c85a --- /dev/null +++ b/scripts/plot_thermal_quench_time.py @@ -0,0 +1,115 @@ +""" +Program to plot various quantities used in calculating the thermal quench time +for a particular shot to compare various methods. +Author: Henry Wietfeldt +""" + +import logging +import numpy as np +import matplotlib +matplotlib.use('TkAgg') +import matplotlib.pyplot as plt +import pandas as pd +import pickle + +from disruption_py.settings import LogSettings, RetrievalSettings +from disruption_py.workflow import get_shots_data + +MAN_LABEL = False +#SHOT_ID = 1140515015 # Hot VDE +#SHOT_ID = 1140827029 +#SHOT_ID = 1120717002 +#SHOT_ID = 1051206029 +#SHOT_ID = 1160714006 +#TODO: Shot 1160714006 having issues (low SXR signal). What do we do about ramp-up? +# TODO: Shot 11405522001 has significant SXR spike when plasma hits wall, after main TQ +# TODO: Search for first time at which dSXR/dt is w/in factor of 2 from max? +SHOT_ID = 1120927010 # Doesn't have current spike, not sure if this is a hot VDE +signals = [ + "ip", + "zcur", + "thermal_quench_time" +] + +# default method for pulling disruption-py data +retrieval_settings = RetrievalSettings( + time_setting="disruption_warning", # use the set efit's timebase + efit_nickname_setting="efit21", # set the efit + run_methods=[], + run_columns=signals, + only_requested_columns=True, +) + +data = get_shots_data( + shotlist_setting=[SHOT_ID], + retrieval_settings=retrieval_settings, + log_settings=LogSettings(console_level=logging.DEBUG), + output_setting="dataframe", + num_processes=1, +) +print(data) +print("Got data") + +with open('sxr.pkl', 'rb') as f: + df = pickle.load(f) +df['ip'] = np.abs(df['ip']/1e6) +print(df['cq_onset_time']) + +plt.rcParams['font.size'] = 14 +fig, axs = plt.subplots(4, 1, sharex=True, figsize=(14,7)) +#axs[0].set_xlim(0.6, 0.64) +axs[0].scatter(df['magtime'], df['ip'], marker='.', s=10, c='k') +axs[1].scatter(df['t_sxr'], df['core_sxr_raw'], marker='.', s=5, c='k') +axs[2].scatter(df['t_sxr'], df['core_sxr'], marker='.', s=5, c='k') +axs[3].scatter(df['t_sxr'], df['core_sxr_growth_rate'], marker='.', s=5, c='k') +# axs[4].scatter(df['efit_time'], df['z0'], marker='o', s=10, c='k') +print("Plotting labeled times") +for ax in axs: + ax.axvline(df['t_disrupt'], linestyle='-', c='k', label='t_disrupt') + ax.axvline(df['cq_onset_time'], linestyle='--', c='k', label='CQ Onset') + #ax.axvline(df['t_start'], linestyle='--', c='k', label='tstart') + if not MAN_LABEL: + ax.axvline(df['thermal_quench_time_scalar'], linestyle='-', c='r', label='TQ Onset') + # ax.axvspan(df['cq_onset_time']-0.005, df['cq_onset_time'], alpha=0.15, color='tab:green', label='TQ Midpoint Search Window') + ax.axvline(df['t_max_sxr_drop'], linestyle='--', c='g', label='TQ Midpoint') + # for i, t_tq in enumerate(df['thermal_quench_times']): + # if i == 0: + # ax.axvline(t_tq, linestyle='-', c='r', label='TQ') + # else: + # ax.axvline(t_tq, linestyle='-', c='r') + # for i, t_warn in enumerate(df['thermal_quench_warnings']): + # if i == 0: + # ax.axvline(t_warn, linestyle='--', c='b', label='TQ warn') + # else: + # ax.axvline(t_warn, linestyle='--', c='b') +axs[0].set_title('C-Mod Shot: ' + str(SHOT_ID)) +axs[0].set_ylabel('Ip [MA]') +axs[1].set_ylabel('SXR raw') +axs[2].set_ylabel('SXR filt') +axs[3].set_ylabel(r"$dSXR/dt$ [Hz]") +# axs[3].set_ylim(-8e3, 2e3) +# axs[4].set_ylabel('Z0 [m]') +# axs[4].set_xlabel("Time [s]") +axs[0].legend(fontsize=12) + +# fig, axs = plt.subplots(4, 1, sharex=True, figsize=(14,7)) +# axs[0].plot(df['magtime'], df['ip'], marker='.', ms=10, c='k') +# axs[1].plot(df['t_sxr'], df['core_sxr_raw'], marker='.', ms=5, c='k') +# axs[2].plot(data['time'], data['p_rad']/1e6, marker='o', ms=10, c='k') +# axs[3].plot(df['efit_time'], df['z0'], marker='o', ms=10, c='k') + +# for ax in axs: +# ax.axvline(df['t_disrupt'], linestyle='--', c='b', label='t_disrupt (DisruptionPy)') +# ax.axvline(df['thermal_quench_time_scalar'], linestyle='--', c='r', label='TQ Onset (auto)') + +# axs[0].set_title('C-Mod Shot: ' + str(SHOT_ID)) +# axs[0].set_ylabel('Ip [MA]') +# axs[1].set_ylabel('SXR raw [a.u.]') +# axs[2].set_ylabel('Prad [MW]') +# axs[3].set_ylabel('Z0 [m]') +# axs[3].set_xlabel("Time [s]") +# # axs[-1].set_xlim(0.68, 0.725) +# axs[0].legend() +# plt.show() + +plt.show() \ No newline at end of file From c988a945aa7bf1d361f9601904fc88258075b1c6 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Thu, 9 Apr 2026 13:28:55 -0400 Subject: [PATCH 07/19] Fix bug when autocorr[index_no_lag:] > 0 always --- disruption_py/machine/cmod/physics.py | 12 +++++++++++- scripts/plot_thermal_quench_time.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 9073f6de5..ee8bda300 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2189,6 +2189,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): sxr[i] = chord[valid_times] sample_time = t_sxr[1] - t_sxr[0] sample_freq = 1 / sample_time + params.logger.warning(f"Max of SXR Raw: {np.max(sxr)}") # Remove bad chords by checking each chord's autocorrelation. # Bad chords often have significant white noise, meaning low autocorrelation (< 10 ms) @@ -2213,8 +2214,16 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): sxr[i] = 0. continue index_no_lag = np.argmax(autocorr) - index_decay = np.argmax(autocorr[index_no_lag:] < 0) + params.logger.debug(f"Chord {i+1}. Autocorr[index_no_lag]: {autocorr[index_no_lag:index_no_lag + 8]}") + crosses_zero = autocorr[index_no_lag:] < 0 + if np.any(crosses_zero): + index_decay = np.argmax(crosses_zero) + else: + # See shot 1120223007 for example of why this if-else logic is necessary + index_decay = len(crosses_zero) + params.logger.debug(f"Chord {i+1} index_decay: {index_decay}") if index_decay*(1/sample_freq_5khz) < noise_autorr_cutoff: + params.logger.debug(f"Removing chord {i+1}. Norm. Autocorr: {index_decay*(1/sample_freq_5khz)}") sxr[i] = 0. # Noncausal Butterworth low pass filter to smooth transient SXR spikes during TQ. @@ -2257,6 +2266,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # Label onset as max of SXR signal on 0.5 ms window preceding max drop in SXR # Use raw signal bc smoothed signal as a longer crash time. # Note this sometimes picks up on recombination spikes + # TODO: Use np.where to take last maximum in case the SXR has saturated and there are multiple maxima wndw_before_tq_mid = 0.0005 # [s] idx_start = np.argmin(np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_mid))) idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) diff --git a/scripts/plot_thermal_quench_time.py b/scripts/plot_thermal_quench_time.py index 0af54c85a..9bb151bf1 100644 --- a/scripts/plot_thermal_quench_time.py +++ b/scripts/plot_thermal_quench_time.py @@ -24,7 +24,7 @@ #TODO: Shot 1160714006 having issues (low SXR signal). What do we do about ramp-up? # TODO: Shot 11405522001 has significant SXR spike when plasma hits wall, after main TQ # TODO: Search for first time at which dSXR/dt is w/in factor of 2 from max? -SHOT_ID = 1120927010 # Doesn't have current spike, not sure if this is a hot VDE +SHOT_ID = 1120223007 # Doesn't have current spike, not sure if this is a hot VDE signals = [ "ip", "zcur", From bb009881edfc1fbc29bedbf602d75321476c7799 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Thu, 9 Apr 2026 13:56:34 -0400 Subject: [PATCH 08/19] Fix bug if core sxr raw has saturated --- disruption_py/machine/cmod/physics.py | 10 ++++++---- scripts/plot_thermal_quench_time.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index ee8bda300..1b9074b0b 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2266,11 +2266,13 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # Label onset as max of SXR signal on 0.5 ms window preceding max drop in SXR # Use raw signal bc smoothed signal as a longer crash time. # Note this sometimes picks up on recombination spikes - # TODO: Use np.where to take last maximum in case the SXR has saturated and there are multiple maxima - wndw_before_tq_mid = 0.0005 # [s] - idx_start = np.argmin(np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_mid))) + wndw_before_tq_midpoint = 0.0005 # [s] + idx_start = np.argmin(np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_midpoint))) idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) - tq_time_scalar = t_sxr[idx_start + np.argmax(core_sxr_raw[idx_start:idx_end])] + window = core_sxr_raw[idx_start:idx_end] + # Want last maximum in case the SXR has saturated and there are multiple maxima + max_sxr_indx = np.nonzero(window == np.max(window))[-1] + tq_time_scalar = t_sxr[idx_start + max_sxr_indx] # TODO: Delete this block during clean-up # TODO: Comment this out when running over many shots diff --git a/scripts/plot_thermal_quench_time.py b/scripts/plot_thermal_quench_time.py index 9bb151bf1..0af54c85a 100644 --- a/scripts/plot_thermal_quench_time.py +++ b/scripts/plot_thermal_quench_time.py @@ -24,7 +24,7 @@ #TODO: Shot 1160714006 having issues (low SXR signal). What do we do about ramp-up? # TODO: Shot 11405522001 has significant SXR spike when plasma hits wall, after main TQ # TODO: Search for first time at which dSXR/dt is w/in factor of 2 from max? -SHOT_ID = 1120223007 # Doesn't have current spike, not sure if this is a hot VDE +SHOT_ID = 1120927010 # Doesn't have current spike, not sure if this is a hot VDE signals = [ "ip", "zcur", From b7966160453704fcd6abebf2d14b2f3032f70623 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Thu, 9 Apr 2026 14:45:43 -0400 Subject: [PATCH 09/19] scripts for testing TQ labeling --- disruption_py/machine/cmod/physics.py | 3 +- scripts/plot_test_results_tq_time.py | 24 +++++++++ scripts/plot_thermal_quench_time.py | 2 +- tests/test_thermal_quench_times.py | 74 +++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 3 deletions(-) create mode 100644 scripts/plot_test_results_tq_time.py create mode 100644 tests/test_thermal_quench_times.py diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 1b9074b0b..91c0eddf8 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2189,7 +2189,6 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): sxr[i] = chord[valid_times] sample_time = t_sxr[1] - t_sxr[0] sample_freq = 1 / sample_time - params.logger.warning(f"Max of SXR Raw: {np.max(sxr)}") # Remove bad chords by checking each chord's autocorrelation. # Bad chords often have significant white noise, meaning low autocorrelation (< 10 ms) @@ -2271,7 +2270,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) window = core_sxr_raw[idx_start:idx_end] # Want last maximum in case the SXR has saturated and there are multiple maxima - max_sxr_indx = np.nonzero(window == np.max(window))[-1] + max_sxr_indx = np.nonzero(window == np.max(window))[0][-1] tq_time_scalar = t_sxr[idx_start + max_sxr_indx] # TODO: Delete this block during clean-up diff --git a/scripts/plot_test_results_tq_time.py b/scripts/plot_test_results_tq_time.py new file mode 100644 index 000000000..338d1c49e --- /dev/null +++ b/scripts/plot_test_results_tq_time.py @@ -0,0 +1,24 @@ +""" +Script for plotting test results from tests/test_thermal_quench_times.py +Author: Henry Wietfeldt +""" + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +df = pd.read_csv('test_thermal_quench_results.csv') + +# Print summary statistics +error = df['onset_error_s'].to_numpy() +print(f"Mean |Error| = {1e3*np.mean(np.abs(error)):.3f} ms") +print(f"Median |Error| = {1e3*np.median(np.abs(error)):.3f} ms") +print(f"Std Dev |Error| = {1e3*np.std(np.abs(error)):.3f} ms") +print(f"Min Error = {1e3*np.min(error):.3f} ms") +print(f"Max Error = {1e3*np.max(error):.3f} ms") +print(f"Num Outliers (|error| > 1 ms) = {np.sum(~df['within_tq_range'])} out of {len(df)} shots") + +plt.hist(df['onset_error_s'], bins=50) +plt.xlabel('Error in TQ Onset Time (Auto - Manual) [s]', fontsize=16) +plt.ylabel('Count', fontsize=16) +plt.show() \ No newline at end of file diff --git a/scripts/plot_thermal_quench_time.py b/scripts/plot_thermal_quench_time.py index 0af54c85a..6e5281f79 100644 --- a/scripts/plot_thermal_quench_time.py +++ b/scripts/plot_thermal_quench_time.py @@ -24,7 +24,7 @@ #TODO: Shot 1160714006 having issues (low SXR signal). What do we do about ramp-up? # TODO: Shot 11405522001 has significant SXR spike when plasma hits wall, after main TQ # TODO: Search for first time at which dSXR/dt is w/in factor of 2 from max? -SHOT_ID = 1120927010 # Doesn't have current spike, not sure if this is a hot VDE +SHOT_ID = 1140826003 signals = [ "ip", "zcur", diff --git a/tests/test_thermal_quench_times.py b/tests/test_thermal_quench_times.py new file mode 100644 index 000000000..65faccb98 --- /dev/null +++ b/tests/test_thermal_quench_times.py @@ -0,0 +1,74 @@ +""" +Test automated TQ labels compared to manual labels +Author: Henry Wietfeldt +Source of Manual labels: Henry Wietfeldt +""" + +import logging +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from disruption_py.settings import LogSettings, RetrievalSettings +from disruption_py.workflow import get_shots_data + +db_man = pd.read_csv('manual_tq_labels.csv') +db_man['shot'] = db_man['shot'].astype(int) +shotlist = db_man['shot'].to_list() +print(f"Number of shots: len(shotlist)") + +signals = [ + "ip", + "thermal_quench_time" +] + +# default method for pulling disruption-py data +retrieval_settings = RetrievalSettings( + time_setting="disruption_warning", # use the set efit's timebase + efit_nickname_setting="analysis", # set the efit + run_methods=[], + run_columns=signals, + only_requested_columns=True, +) + +db_auto = get_shots_data( + shotlist_setting=shotlist, + retrieval_settings=retrieval_settings, + log_settings=LogSettings(console_level=logging.WARNING), + output_setting="dataframe", + num_processes=os.cpu_count(), +) + +# Using pandas because this testing script has not been updated to xarray + +# Output test results +db_auto = db_auto.drop_duplicates(subset='shot').drop(columns='time') +db_auto.to_csv('test_thermal_quench_auto_labels.csv') +db_test = db_man.copy() +db_test = db_test.sort_values(by='shot') +db_auto = db_auto.sort_values(by='shot') +db_test = pd.merge(db_test, db_auto, how='outer', on='shot') +db_test = db_test.rename(columns={'thermal_quench_time': 'tq_auto'}) +print(db_test) + +db_test['within_tq_range'] = (db_test['tq_auto'] > db_test['tq_onset_manual']-0.001) & (db_test['tq_auto'] < db_test['tq_end_manual']+0.001) +db_test['onset_error_s'] = db_test['tq_auto'] - db_test['tq_onset_manual'] +db_test = db_test.drop(columns=['notes']) +db_test.to_csv('test_thermal_quench_results.csv') + +# Print summary statistics +error = db_test['onset_error_s'].to_numpy() +print(f"Mean |Error| = {1e3*np.mean(np.abs(error)):.3f} ms") +print(f"Median |Error| = {1e3*np.median(np.abs(error)):.3f} ms") +print(f"Std Dev |Error| = {1e3*np.std(np.abs(error)):.3f} ms") +print(f"Min Error = {1e3*np.min(error):.3f} ms") +print(f"Max Error = {1e3*np.max(error):.3f} ms") +print(f"Num Outliers (|error| > 1 ms) = {np.sum(~db_test['within_tq_range'])} out of {len(shotlist)} shots") + +# Plot onset errors +plt.hist(db_test['onset_error_s'], bins=50) +plt.xlabel('Error in TQ Onset Time (Auto - Manual) [s]', fontsize=16) +plt.ylabel('Count', fontsize=16) +plt.show() \ No newline at end of file From 760eb5eeb7fba3b04af3c30558fa5fcb8d0a35f4 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Thu, 9 Apr 2026 15:26:18 -0400 Subject: [PATCH 10/19] Label time at 90% of max SXR prior to crash --- disruption_py/machine/cmod/physics.py | 37 ++++++++++++++------------- scripts/plot_test_results_tq_time.py | 30 ++++++++++++++-------- scripts/plot_thermal_quench_time.py | 2 +- tests/test_thermal_quench_times.py | 2 +- 4 files changed, 41 insertions(+), 30 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 91c0eddf8..6c9204006 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2261,16 +2261,17 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): t_max_sxr_drop = t_sxr[idx_start + np.argmin(dcore_sxr_dt[idx_start:idx_end])] # Find onset of thermal quench in 0.5 ms window prior to midpoint of TQ - # Thermal quenches on C-Mod are almost always shorter than 1 ms - # Label onset as max of SXR signal on 0.5 ms window preceding max drop in SXR - # Use raw signal bc smoothed signal as a longer crash time. + # Thermal quenches on C-Mod are almost always shorter than 1 ms, hence the 0.5 ms window + # Find max of SXR signal on 0.5 ms window preceding max drop in SXR and label onset as + # last timestep with SXR > 90% of that max value + # Use raw signal bc smoothed signal has a longer crash time. # Note this sometimes picks up on recombination spikes wndw_before_tq_midpoint = 0.0005 # [s] idx_start = np.argmin(np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_midpoint))) idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) window = core_sxr_raw[idx_start:idx_end] # Want last maximum in case the SXR has saturated and there are multiple maxima - max_sxr_indx = np.nonzero(window == np.max(window))[0][-1] + max_sxr_indx = np.nonzero(window >= 0.9*np.max(window))[0][-1] tq_time_scalar = t_sxr[idx_start + max_sxr_indx] # TODO: Delete this block during clean-up @@ -2281,20 +2282,20 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # r"\efit_aeqdsk:zmagx", tree_name="_efit_tree" # ) # [cm], [s] # z0 *= 0.01 # [cm] -> [m] - import pickle - plot_df = {"magtime":magtime, - "ip": ip, - "t_sxr": t_sxr, - "core_sxr_raw": core_sxr_raw, - "core_sxr": core_sxr, - "core_sxr_growth_rate": dcore_sxr_dt, - "t_disrupt": params.disruption_time, - "cq_onset_time": cq_onset_time, - "t_max_sxr_drop": t_max_sxr_drop, - "thermal_quench_time_scalar": tq_time_scalar, - } - with open('sxr.pkl', 'wb') as f: - pickle.dump(plot_df, f) + # import pickle + # plot_df = {"magtime":magtime, + # "ip": ip, + # "t_sxr": t_sxr, + # "core_sxr_raw": core_sxr_raw, + # "core_sxr": core_sxr, + # "core_sxr_growth_rate": dcore_sxr_dt, + # "t_disrupt": params.disruption_time, + # "cq_onset_time": cq_onset_time, + # "t_max_sxr_drop": t_max_sxr_drop, + # "thermal_quench_time_scalar": tq_time_scalar, + # } + # with open('sxr.pkl', 'wb') as f: + # pickle.dump(plot_df, f) return {"thermal_quench_time": tq_time_scalar*np.ones(len(params.times))} diff --git a/scripts/plot_test_results_tq_time.py b/scripts/plot_test_results_tq_time.py index 338d1c49e..c5edc4c46 100644 --- a/scripts/plot_test_results_tq_time.py +++ b/scripts/plot_test_results_tq_time.py @@ -7,18 +7,28 @@ import pandas as pd import matplotlib.pyplot as plt -df = pd.read_csv('test_thermal_quench_results.csv') +df_new = pd.read_csv('test_thermal_quench_results.csv') +df_old = pd.read_csv('test_thermal_quench_results_max.csv') +df_90 = pd.read_csv('test_thermal_quench_results_max90.csv') -# Print summary statistics -error = df['onset_error_s'].to_numpy() -print(f"Mean |Error| = {1e3*np.mean(np.abs(error)):.3f} ms") -print(f"Median |Error| = {1e3*np.median(np.abs(error)):.3f} ms") -print(f"Std Dev |Error| = {1e3*np.std(np.abs(error)):.3f} ms") -print(f"Min Error = {1e3*np.min(error):.3f} ms") -print(f"Max Error = {1e3*np.max(error):.3f} ms") -print(f"Num Outliers (|error| > 1 ms) = {np.sum(~df['within_tq_range'])} out of {len(df)} shots") +dfs = [df_new, df_90, df_old] +labels = ['new', '0.9*max', 'old'] +steps = np.linspace(-0.004, 0.004, 60) + +for i, df in enumerate(dfs): + # Print summary statistics + error = df['onset_error_s'].to_numpy() + print(f"Stats for {labels[i]}") + print(f"Mean |Error| = {1e3*np.mean(np.abs(error)):.3f} ms") + print(f"Median |Error| = {1e3*np.median(np.abs(error)):.3f} ms") + print(f"Std Dev |Error| = {1e3*np.std(np.abs(error)):.3f} ms") + print(f"Min Error = {1e3*np.min(error):.3f} ms") + print(f"Max Error = {1e3*np.max(error):.3f} ms") + print(f"Num Outliers (outisde TQ [start, end] by >1 ms) = {np.sum(~df['within_tq_range'])} out of {len(df)} shots\n") + + plt.hist(df['onset_error_s'], histtype='step', linewidth=2, bins=50, label=labels[i]) -plt.hist(df['onset_error_s'], bins=50) plt.xlabel('Error in TQ Onset Time (Auto - Manual) [s]', fontsize=16) plt.ylabel('Count', fontsize=16) +plt.legend() plt.show() \ No newline at end of file diff --git a/scripts/plot_thermal_quench_time.py b/scripts/plot_thermal_quench_time.py index 6e5281f79..32d4f528d 100644 --- a/scripts/plot_thermal_quench_time.py +++ b/scripts/plot_thermal_quench_time.py @@ -24,7 +24,7 @@ #TODO: Shot 1160714006 having issues (low SXR signal). What do we do about ramp-up? # TODO: Shot 11405522001 has significant SXR spike when plasma hits wall, after main TQ # TODO: Search for first time at which dSXR/dt is w/in factor of 2 from max? -SHOT_ID = 1140826003 +SHOT_ID = 1120927010 signals = [ "ip", "zcur", diff --git a/tests/test_thermal_quench_times.py b/tests/test_thermal_quench_times.py index 65faccb98..195694ea4 100644 --- a/tests/test_thermal_quench_times.py +++ b/tests/test_thermal_quench_times.py @@ -65,7 +65,7 @@ print(f"Std Dev |Error| = {1e3*np.std(np.abs(error)):.3f} ms") print(f"Min Error = {1e3*np.min(error):.3f} ms") print(f"Max Error = {1e3*np.max(error):.3f} ms") -print(f"Num Outliers (|error| > 1 ms) = {np.sum(~db_test['within_tq_range'])} out of {len(shotlist)} shots") +print(f"Num Outliers (outisde TQ [start, end] by >1 ms) = {np.sum(~db_test['within_tq_range'])} out of {len(shotlist)} shots") # Plot onset errors plt.hist(db_test['onset_error_s'], bins=50) From 3091a1a917eae6c2f52e6fdca4c5da65c7b79741 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Wed, 3 Jun 2026 17:16:09 -0400 Subject: [PATCH 11/19] More testing of TQ labeling --- disruption_py/machine/cmod/physics.py | 11 +++-- scripts/get_thermal_quench_times.py | 1 + scripts/plot_test_results_tq_time.py | 17 ++++--- scripts/plot_thermal_quench_time.py | 22 +++++---- scripts/plot_tq_timetraces.py | 68 +++++++++++++++++++++++++++ tests/test_thermal_quench_times.py | 3 +- 6 files changed, 101 insertions(+), 21 deletions(-) create mode 100644 scripts/plot_tq_timetraces.py diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 6c9204006..538f4b642 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2117,7 +2117,8 @@ def get_surface_voltage(params: PhysicsMethodParams): return {"v_surf": v_surf} @staticmethod - @physics_method(columns=["thermal_quench_time"], tokamak=Tokamak.CMOD) + # TODO: Remove t_disrupt, core_sxr + @physics_method(columns=["thermal_quench_time", "t_disrupt", "core_sxr"], tokamak=Tokamak.CMOD) def get_thermal_quench_onset_time(params: PhysicsMethodParams): """ Labels the onset time of the thermal quench for a given shot (NaN for non-disruptive shots) @@ -2282,10 +2283,13 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # r"\efit_aeqdsk:zmagx", tree_name="_efit_tree" # ) # [cm], [s] # z0 *= 0.01 # [cm] -> [m] + # te0_ece, t_ece = params.data_conn.get_data_with_dims(r"\gpc2_te0", tree_name="electrons") # import pickle # plot_df = {"magtime":magtime, # "ip": ip, # "t_sxr": t_sxr, + # "t_ece": t_ece, + # "te0_ece": te0_ece, # "core_sxr_raw": core_sxr_raw, # "core_sxr": core_sxr, # "core_sxr_growth_rate": dcore_sxr_dt, @@ -2296,8 +2300,9 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # } # with open('sxr.pkl', 'wb') as f: # pickle.dump(plot_df, f) - - return {"thermal_quench_time": tq_time_scalar*np.ones(len(params.times))} + # TODO: Remove t_disrupt, core_sxr + core_sxr = interp1(t_sxr, core_sxr_raw, params.times) + return {"thermal_quench_time": tq_time_scalar*np.ones(len(params.times)), "t_disrupt": params.disruption_time*np.ones(len(params.times)), "core_sxr": core_sxr} @staticmethod def _is_on_blacklist(shot_id: int) -> bool: diff --git a/scripts/get_thermal_quench_times.py b/scripts/get_thermal_quench_times.py index 24e1955af..f21e2f231 100644 --- a/scripts/get_thermal_quench_times.py +++ b/scripts/get_thermal_quench_times.py @@ -9,6 +9,7 @@ import logging import numpy as np +import pandas as pd from disruption_py.settings import LogSettings, RetrievalSettings from disruption_py.workflow import get_shots_data diff --git a/scripts/plot_test_results_tq_time.py b/scripts/plot_test_results_tq_time.py index c5edc4c46..4bf8e45f6 100644 --- a/scripts/plot_test_results_tq_time.py +++ b/scripts/plot_test_results_tq_time.py @@ -7,13 +7,13 @@ import pandas as pd import matplotlib.pyplot as plt -df_new = pd.read_csv('test_thermal_quench_results.csv') -df_old = pd.read_csv('test_thermal_quench_results_max.csv') -df_90 = pd.read_csv('test_thermal_quench_results_max90.csv') +df_old = pd.read_csv('test_thermal_quench_results_cq_onset90.csv') +df_new = pd.read_csv('test_thermal_quench_results_cq_onset95.csv') +#df_90 = pd.read_csv('test_thermal_quench_results_max90.csv') -dfs = [df_new, df_90, df_old] -labels = ['new', '0.9*max', 'old'] -steps = np.linspace(-0.004, 0.004, 60) +dfs = [df_old] +labels = ['0.9*max'] +bins = np.linspace(-0.004, 0.004, 60) for i, df in enumerate(dfs): # Print summary statistics @@ -25,10 +25,13 @@ print(f"Min Error = {1e3*np.min(error):.3f} ms") print(f"Max Error = {1e3*np.max(error):.3f} ms") print(f"Num Outliers (outisde TQ [start, end] by >1 ms) = {np.sum(~df['within_tq_range'])} out of {len(df)} shots\n") + print(df[~df['within_tq_range']].sort_values(by='onset_error_s')) - plt.hist(df['onset_error_s'], histtype='step', linewidth=2, bins=50, label=labels[i]) + plt.hist(df['onset_error_s'], edgecolor='k', bins=bins, label=labels[i], alpha=0.5) +plt.title("Test vs 100 Manually Labeled Shots", fontsize=18) plt.xlabel('Error in TQ Onset Time (Auto - Manual) [s]', fontsize=16) +plt.xlim(-0.003, 0.00325) plt.ylabel('Count', fontsize=16) plt.legend() plt.show() \ No newline at end of file diff --git a/scripts/plot_thermal_quench_time.py b/scripts/plot_thermal_quench_time.py index 32d4f528d..c3ee66fec 100644 --- a/scripts/plot_thermal_quench_time.py +++ b/scripts/plot_thermal_quench_time.py @@ -24,7 +24,7 @@ #TODO: Shot 1160714006 having issues (low SXR signal). What do we do about ramp-up? # TODO: Shot 11405522001 has significant SXR spike when plasma hits wall, after main TQ # TODO: Search for first time at which dSXR/dt is w/in factor of 2 from max? -SHOT_ID = 1120927010 +SHOT_ID = 1140826003 signals = [ "ip", "zcur", @@ -56,17 +56,18 @@ print(df['cq_onset_time']) plt.rcParams['font.size'] = 14 -fig, axs = plt.subplots(4, 1, sharex=True, figsize=(14,7)) +fig, axs = plt.subplots(5, 1, sharex=True, figsize=(14,7)) #axs[0].set_xlim(0.6, 0.64) -axs[0].scatter(df['magtime'], df['ip'], marker='.', s=10, c='k') -axs[1].scatter(df['t_sxr'], df['core_sxr_raw'], marker='.', s=5, c='k') -axs[2].scatter(df['t_sxr'], df['core_sxr'], marker='.', s=5, c='k') -axs[3].scatter(df['t_sxr'], df['core_sxr_growth_rate'], marker='.', s=5, c='k') +axs[0].plot(df['magtime'], df['ip'], marker='.', ms=10, c='k', linestyle='-') +axs[1].plot(df['t_ece'], df['te0_ece'], marker='.', ms=5, c='k', linestyle='-') +axs[2].scatter(df['t_sxr'], df['core_sxr_raw'], marker='.', s=5, c='k') +axs[3].scatter(df['t_sxr'], df['core_sxr'], marker='.', s=5, c='k') +axs[4].scatter(df['t_sxr'], df['core_sxr_growth_rate'], marker='.', s=5, c='k') # axs[4].scatter(df['efit_time'], df['z0'], marker='o', s=10, c='k') print("Plotting labeled times") for ax in axs: ax.axvline(df['t_disrupt'], linestyle='-', c='k', label='t_disrupt') - ax.axvline(df['cq_onset_time'], linestyle='--', c='k', label='CQ Onset') + ax.axvline(df['cq_onset_time'], linestyle='--', c='k', label='CQ Onset', lw=2) #ax.axvline(df['t_start'], linestyle='--', c='k', label='tstart') if not MAN_LABEL: ax.axvline(df['thermal_quench_time_scalar'], linestyle='-', c='r', label='TQ Onset') @@ -84,9 +85,10 @@ # ax.axvline(t_warn, linestyle='--', c='b') axs[0].set_title('C-Mod Shot: ' + str(SHOT_ID)) axs[0].set_ylabel('Ip [MA]') -axs[1].set_ylabel('SXR raw') -axs[2].set_ylabel('SXR filt') -axs[3].set_ylabel(r"$dSXR/dt$ [Hz]") +axs[1].set_ylabel('Te0 [keV]') +axs[2].set_ylabel('SXR raw') +axs[3].set_ylabel('SXR filt') +axs[4].set_ylabel(r"$dSXR/dt$ [Hz]") # axs[3].set_ylim(-8e3, 2e3) # axs[4].set_ylabel('Z0 [m]') # axs[4].set_xlabel("Time [s]") diff --git a/scripts/plot_tq_timetraces.py b/scripts/plot_tq_timetraces.py new file mode 100644 index 000000000..55845b020 --- /dev/null +++ b/scripts/plot_tq_timetraces.py @@ -0,0 +1,68 @@ +""" +Plot timetraces of shots showing different disruption times +for testing thermal quench time labeler +""" + +import logging +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from disruption_py.settings import LogSettings, RetrievalSettings +from disruption_py.workflow import get_shots_data + +USE_PREV_DPY_RESULTS = True +N_SHOTS = 30 + +df_man = pd.read_csv('manual_tq_labels.csv') +df_man['shot'] = df_man['shot'].astype(int) +shotlist = df_man['shot'].to_list() +print(f"Number of shots: {len(shotlist)}") + +signals = [ + "ip", + "thermal_quench_time", + "t_disrupt", + "core_sxr" +] + +if USE_PREV_DPY_RESULTS: + df = pd.read_csv('tq_df.csv') +else: + # default method for pulling disruption-py data + retrieval_settings = RetrievalSettings( + time_setting="disruption_warning", # use the set efit's timebase + efit_nickname_setting="analysis", # set the efit + run_methods=[], + run_columns=signals, + only_requested_columns=True, + ) + + df = get_shots_data( + shotlist_setting=shotlist, + retrieval_settings=retrieval_settings, + log_settings=LogSettings(console_level=logging.WARNING), + output_setting="dataframe", + num_processes=os.cpu_count(), + ) + df.to_csv('tq_df.csv') + +df = df.merge(df_man[['shot', 'tq_onset_manual']], on='shot', how='left') + +# Select 30 random shots +rng = np.random.default_rng(seed=42) +rng.shuffle(shotlist) +shot_subset = shotlist[:N_SHOTS] + +fig, axs = plt.subplots(N_SHOTS, 1, figsize=(5, 14)) + +for i, s in enumerate(shot_subset): + df_s = df[df['shot']==s] + axs[i].plot(df_s['time'], df_s['ip'], c='g') + axs[i].set_yticks([]) + axs[i].set_xticks([]) + axs[i].set_xlim(0, 2) + +plt.show() \ No newline at end of file diff --git a/tests/test_thermal_quench_times.py b/tests/test_thermal_quench_times.py index 195694ea4..a173092d6 100644 --- a/tests/test_thermal_quench_times.py +++ b/tests/test_thermal_quench_times.py @@ -21,7 +21,8 @@ signals = [ "ip", - "thermal_quench_time" + "thermal_quench_time", + "t_disrupt" ] # default method for pulling disruption-py data From 5e24c22e7a60451b4f23a1a01352d6a645e457ab Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Wed, 3 Jun 2026 17:26:22 -0400 Subject: [PATCH 12/19] update get_data() syntax for tq labeling --- disruption_py/machine/cmod/physics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 5a2f3c9b0..e1f22be22 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2126,7 +2126,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): params.logger.debug(params.disruption_time) # Get current data for obtaining start of current quench - ip, magtime = params.data_conn.get_data_with_dims( + ip, magtime = params.get_data_with_dims( r"\ip", tree_name="magnetics" ) ip = np.abs(ip) @@ -2135,7 +2135,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): n_chords = 38 array_path = r"\top.brightnesses.array_1" try: - chord_01, t_sxr = params.data_conn.get_data_with_dims( + chord_01, t_sxr = params.get_data_with_dims( array_path + ":chord_01", tree_name="xtomo", ) # Units: W/m^2, s @@ -2150,7 +2150,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # Get all other SXR chords for i in range(1, n_chords): try: - chord, t_chord = params.data_conn.get_data_with_dims( + chord, t_chord = params.get_data_with_dims( array_path + ":chord_" + f"{i+1:02}", tree_name="xtomo", ) @@ -2257,11 +2257,11 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # TODO: Comment this out when running over many shots #Write some signals for plotting #Get magnetic axis data from EFIT for testing purposes - # z0, efit_time = params.data_conn.get_data_with_dims( + # z0, efit_time = params.get_data_with_dims( # r"\efit_aeqdsk:zmagx", tree_name="_efit_tree" # ) # [cm], [s] # z0 *= 0.01 # [cm] -> [m] - # te0_ece, t_ece = params.data_conn.get_data_with_dims(r"\gpc2_te0", tree_name="electrons") + # te0_ece, t_ece = params.get_data_with_dims(r"\gpc2_te0", tree_name="electrons") # import pickle # plot_df = {"magtime":magtime, # "ip": ip, From e4e29018e236d39d288bc5793056761ca23a2ced Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Wed, 3 Jun 2026 17:42:11 -0400 Subject: [PATCH 13/19] Evaluate timing bottlenecks in tq labeling --- disruption_py/machine/cmod/physics.py | 18 ++++++++++++++++++ tests/test_thermal_quench_times.py | 10 +++++----- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index e1f22be22..69cdb4c98 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2119,7 +2119,11 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): Last Major Update: Henry Wietfeldt (1/26/26) """ + import os + import pandas as pd from scipy.signal import butter, filtfilt, resample_poly + from time import perf_counter + t0 = perf_counter() thermal_quench_time = np.full(len(params.times), np.nan) if params.disruption_time is None: return {"thermal_quench_time": thermal_quench_time} @@ -2132,6 +2136,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): ip = np.abs(ip) # Get SXR chords + t0_read = perf_counter() n_chords = 38 array_path = r"\top.brightnesses.array_1" try: @@ -2168,11 +2173,13 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): sxr[i] = chord[valid_times] sample_time = t_sxr[1] - t_sxr[0] sample_freq = 1 / sample_time + timings = {'read_mds': perf_counter() - t0_read} # Remove bad chords by checking each chord's autocorrelation. # Bad chords often have significant white noise, meaning low autocorrelation (< 10 ms) # Good chords should have an autocorrelation of 100s of ms # See shot 1050311013 as an example with some bad chords + t0_autocorr = perf_counter() noise_autorr_cutoff = 0.01 # [s] for i, chord in enumerate(sxr): # Use 300 ms prior to current quench for speed-up during autocorr O(N^2) @@ -2203,12 +2210,14 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): if index_decay*(1/sample_freq_5khz) < noise_autorr_cutoff: params.logger.debug(f"Removing chord {i+1}. Norm. Autocorr: {index_decay*(1/sample_freq_5khz)}") sxr[i] = 0. + timings['autocorr'] = perf_counter() - t0_autocorr # Noncausal Butterworth low pass filter to smooth transient SXR spikes during TQ. # Cutoff of 1.0 kHz and order 2 seems to filter recombination SXR spikes # while maintaining decent resolution of TQ based on scan from 0.25 kHz - 2 kHz # Results were fairly insensitive within these windows on the 100 shots checked # See shot 1120913013 as example of large recombination spike + t0_bworth = perf_counter() bworth_cutoff = 1000 # [Hz] bworth_order = 2 normalized_cutoff = bworth_cutoff / (0.5*sample_freq) @@ -2217,7 +2226,9 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): sxr = filtfilt(b, a, sxr, axis=1) core_sxr = np.max(sxr, axis=0) dcore_sxr_dt = np.diff(core_sxr, prepend=0)/sample_time + timings['bworth'] = perf_counter() - t0_bworth + t0_find_tq = perf_counter() # Search for the onset of the CQ so that we can search for the TQ in a small time window # to avoid labeleing sawtooth crashes as the thermal quench # Some current quenches can be long (see shots 1050311013, 1050802017). @@ -2252,6 +2263,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # Want last maximum in case the SXR has saturated and there are multiple maxima max_sxr_indx = np.nonzero(window >= 0.9*np.max(window))[0][-1] tq_time_scalar = t_sxr[idx_start + max_sxr_indx] + timings['find_tq'] = perf_counter() - t0_find_tq # TODO: Delete this block during clean-up # TODO: Comment this out when running over many shots @@ -2278,6 +2290,12 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # } # with open('sxr.pkl', 'wb') as f: # pickle.dump(plot_df, f) + + # Output timings (TODO: Remove this eventually) + timings['total'] = perf_counter() - t0 + fn = os.path.join('/home/henrycw/projects/label-thermal-quench/disruption-py-label-thermal-quench/disruption-py/tq_timing', f'timing_{params.shot_id}.csv') + pd.DataFrame([timings]).to_csv(fn, index=False) + # TODO: Remove t_disrupt, core_sxr core_sxr = interp1(t_sxr, core_sxr_raw, params.times) return {"thermal_quench_time": tq_time_scalar*np.ones(len(params.times)), "t_disrupt": params.disruption_time*np.ones(len(params.times)), "core_sxr": core_sxr} diff --git a/tests/test_thermal_quench_times.py b/tests/test_thermal_quench_times.py index a173092d6..b0777f13f 100644 --- a/tests/test_thermal_quench_times.py +++ b/tests/test_thermal_quench_times.py @@ -61,11 +61,11 @@ # Print summary statistics error = db_test['onset_error_s'].to_numpy() -print(f"Mean |Error| = {1e3*np.mean(np.abs(error)):.3f} ms") -print(f"Median |Error| = {1e3*np.median(np.abs(error)):.3f} ms") -print(f"Std Dev |Error| = {1e3*np.std(np.abs(error)):.3f} ms") -print(f"Min Error = {1e3*np.min(error):.3f} ms") -print(f"Max Error = {1e3*np.max(error):.3f} ms") +print(f"Mean |Error| = {1e3*np.nanmean(np.abs(error)):.3f} ms") +print(f"Median |Error| = {1e3*np.nanmedian(np.abs(error)):.3f} ms") +print(f"Std Dev |Error| = {1e3*np.nanstd(np.abs(error)):.3f} ms") +print(f"Min Error = {1e3*np.nanmin(error):.3f} ms") +print(f"Max Error = {1e3*np.nanmax(error):.3f} ms") print(f"Num Outliers (outisde TQ [start, end] by >1 ms) = {np.sum(~db_test['within_tq_range'])} out of {len(shotlist)} shots") # Plot onset errors From 1877a1bbedba42abac29e795b98518ee0fc9a42e Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Wed, 3 Jun 2026 18:47:06 -0400 Subject: [PATCH 14/19] Get SXR chords with MDSplus getMany() for potential speedup --- disruption_py/machine/cmod/physics.py | 75 +++++++++++++++++++++------ 1 file changed, 60 insertions(+), 15 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 69cdb4c98..88f4801ba 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2139,29 +2139,73 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): t0_read = perf_counter() n_chords = 38 array_path = r"\top.brightnesses.array_1" + # try: + # chord_01, t_sxr = params.get_data_with_dims( + # array_path + ":chord_01", + # tree_name="xtomo", + # ) # Units: W/m^2, s + # except mdsExceptions.MdsException: + # print(params.shot_id) + # params.logger.debug("Failed to get SXR " + array_path + " data") + # return {"thermal_quench_time": np.full(len(params.times), np.nan)} + # valid_times = (t_sxr > 0) & (t_sxr < 2.0) + # t_sxr = t_sxr[valid_times] + # sxr = np.zeros(shape=(n_chords, len(t_sxr))) + # sxr[0] = chord_01[valid_times] + # # Get all other SXR chords + # for i in range(1, n_chords): + # try: + # chord, t_chord = params.get_data_with_dims( + # array_path + ":chord_" + f"{i+1:02}", + # tree_name="xtomo", + # ) + # except mdsExceptions.MdsException: + # params.logger.debug("Failed to get SXR " + array_path + " chord " + str(i+1) + " data") + # sxr[i] = 0. + # continue + # # Subtract constant background + # chord = chord - np.mean(chord[t_chord < 0.]) + # # Occasionally the time bases of a chord are of a different length + # # Usually one timebase is just cut off early after shot is over + # valid_times = (t_chord > 0) & (t_chord < 2.) + # # Goods chords should be of the same shape + # if len(chord[valid_times]) == sxr.shape[1]: + # sxr[i] = chord[valid_times] + + + # Use getMany to read 38 SXR chords and their time arrays in one batch + params.data_conn.conn.openTree("xtomo", params.shot_id) + gm = params.data_conn.conn.getMany() + for i in range(n_chords): + chord_key = f"chord_{i+1:02}" + node_path = f"{array_path}:{chord_key}" + # Append data expression + gm.append(f"data_{chord_key}", node_path) + gm.append(f"time_{chord_key}", f"data(dim_of({node_path}))") + sxr_results = gm.execute() + + # Unpack the first SXR chord to get the shape of the SXR data + chord_key = "chord_01" try: - chord_01, t_sxr = params.get_data_with_dims( - array_path + ":chord_01", - tree_name="xtomo", - ) # Units: W/m^2, s - except mdsExceptions.MdsException: - print(params.shot_id) + chord = np.array(sxr_results.get(f"data_{chord_key}")["value"].data()) + t_sxr = np.array(sxr_results.get(f"time_{chord_key}")["value"].data()) + except mdsExceptions.MdsException as e: params.logger.debug("Failed to get SXR " + array_path + " data") return {"thermal_quench_time": np.full(len(params.times), np.nan)} + chord = chord - np.mean(chord[t_sxr < 0.]) valid_times = (t_sxr > 0) & (t_sxr < 2.0) t_sxr = t_sxr[valid_times] sxr = np.zeros(shape=(n_chords, len(t_sxr))) - sxr[0] = chord_01[valid_times] - # Get all other SXR chords + sxr[0] = chord[valid_times] + # Unpack the rest of the chords for i in range(1, n_chords): + chord_key = f"chord_{i+1:02}" try: - chord, t_chord = params.get_data_with_dims( - array_path + ":chord_" + f"{i+1:02}", - tree_name="xtomo", - ) - except mdsExceptions.MdsException: - params.logger.debug("Failed to get SXR " + array_path + " chord " + str(i+1) + " data") - sxr[i] = 0. + # TODO: Remove np.array? + chord = np.array(sxr_results.get(f"data_{chord_key}")["value"].data()) + t_chord = np.array(sxr_results.get(f"time_{chord_key}")["value"].data()) + except mdsExceptions.MdsException as e: + params.logger.debug("Failed to get SXR " + array_path + " data") continue # Subtract constant background chord = chord - np.mean(chord[t_chord < 0.]) @@ -2171,6 +2215,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # Goods chords should be of the same shape if len(chord[valid_times]) == sxr.shape[1]: sxr[i] = chord[valid_times] + sample_time = t_sxr[1] - t_sxr[0] sample_freq = 1 / sample_time timings = {'read_mds': perf_counter() - t0_read} From b21916a7fb1b091e8486311e74aa66d40ece093f Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Wed, 3 Jun 2026 18:49:10 -0400 Subject: [PATCH 15/19] Back to for loop for reading SXR data, which is faster than getMany() --- disruption_py/machine/cmod/physics.py | 74 ++++++--------------------- 1 file changed, 15 insertions(+), 59 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 88f4801ba..dca09594e 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2139,73 +2139,29 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): t0_read = perf_counter() n_chords = 38 array_path = r"\top.brightnesses.array_1" - # try: - # chord_01, t_sxr = params.get_data_with_dims( - # array_path + ":chord_01", - # tree_name="xtomo", - # ) # Units: W/m^2, s - # except mdsExceptions.MdsException: - # print(params.shot_id) - # params.logger.debug("Failed to get SXR " + array_path + " data") - # return {"thermal_quench_time": np.full(len(params.times), np.nan)} - # valid_times = (t_sxr > 0) & (t_sxr < 2.0) - # t_sxr = t_sxr[valid_times] - # sxr = np.zeros(shape=(n_chords, len(t_sxr))) - # sxr[0] = chord_01[valid_times] - # # Get all other SXR chords - # for i in range(1, n_chords): - # try: - # chord, t_chord = params.get_data_with_dims( - # array_path + ":chord_" + f"{i+1:02}", - # tree_name="xtomo", - # ) - # except mdsExceptions.MdsException: - # params.logger.debug("Failed to get SXR " + array_path + " chord " + str(i+1) + " data") - # sxr[i] = 0. - # continue - # # Subtract constant background - # chord = chord - np.mean(chord[t_chord < 0.]) - # # Occasionally the time bases of a chord are of a different length - # # Usually one timebase is just cut off early after shot is over - # valid_times = (t_chord > 0) & (t_chord < 2.) - # # Goods chords should be of the same shape - # if len(chord[valid_times]) == sxr.shape[1]: - # sxr[i] = chord[valid_times] - - - # Use getMany to read 38 SXR chords and their time arrays in one batch - params.data_conn.conn.openTree("xtomo", params.shot_id) - gm = params.data_conn.conn.getMany() - for i in range(n_chords): - chord_key = f"chord_{i+1:02}" - node_path = f"{array_path}:{chord_key}" - # Append data expression - gm.append(f"data_{chord_key}", node_path) - gm.append(f"time_{chord_key}", f"data(dim_of({node_path}))") - sxr_results = gm.execute() - - # Unpack the first SXR chord to get the shape of the SXR data - chord_key = "chord_01" try: - chord = np.array(sxr_results.get(f"data_{chord_key}")["value"].data()) - t_sxr = np.array(sxr_results.get(f"time_{chord_key}")["value"].data()) - except mdsExceptions.MdsException as e: + chord_01, t_sxr = params.get_data_with_dims( + array_path + ":chord_01", + tree_name="xtomo", + ) # Units: W/m^2, s + except mdsExceptions.MdsException: + print(params.shot_id) params.logger.debug("Failed to get SXR " + array_path + " data") return {"thermal_quench_time": np.full(len(params.times), np.nan)} - chord = chord - np.mean(chord[t_sxr < 0.]) valid_times = (t_sxr > 0) & (t_sxr < 2.0) t_sxr = t_sxr[valid_times] sxr = np.zeros(shape=(n_chords, len(t_sxr))) - sxr[0] = chord[valid_times] - # Unpack the rest of the chords + sxr[0] = chord_01[valid_times] + # Get all other SXR chords for i in range(1, n_chords): - chord_key = f"chord_{i+1:02}" try: - # TODO: Remove np.array? - chord = np.array(sxr_results.get(f"data_{chord_key}")["value"].data()) - t_chord = np.array(sxr_results.get(f"time_{chord_key}")["value"].data()) - except mdsExceptions.MdsException as e: - params.logger.debug("Failed to get SXR " + array_path + " data") + chord, t_chord = params.get_data_with_dims( + array_path + ":chord_" + f"{i+1:02}", + tree_name="xtomo", + ) + except mdsExceptions.MdsException: + params.logger.debug("Failed to get SXR " + array_path + " chord " + str(i+1) + " data") + sxr[i] = 0. continue # Subtract constant background chord = chord - np.mean(chord[t_chord < 0.]) From 1ec7fbea95d76eb27f5e591863f35c1e1f1899f8 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Thu, 4 Jun 2026 10:57:59 -0400 Subject: [PATCH 16/19] Script to calculate timing results --- scripts/get_tq_timing_results.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 scripts/get_tq_timing_results.py diff --git a/scripts/get_tq_timing_results.py b/scripts/get_tq_timing_results.py new file mode 100644 index 000000000..0606f41b9 --- /dev/null +++ b/scripts/get_tq_timing_results.py @@ -0,0 +1,29 @@ +import os +from glob import glob +import numpy as np +import pandas as pd + +def read_csv_files(dir, file_pattern): + csv_files = glob(os.path.join(dir, file_pattern)) + dfs = [pd.read_csv(file) for file in csv_files] + return pd.concat(dfs, ignore_index=True) + +def write_timing_results(df): + print(f"Timing Avgs:\n---------------") + print(f"Reading MDSplus: {np.nanmean(df['read_mds']):.6} s") + print(f"Calculating autocorrelation: {np.nanmean(df['autocorr']):.6} s") + print(f"Butterworth Filter: {np.nanmean(df['bworth']):.6} s") + print(f"Finding TQ: {np.nanmean(df['find_tq']):.6} s") + print(f"Total: {np.nanmean(df['total']):.6} s") + print("---------------") + + +if __name__=='__main__': + timing_dir = '/home/henrycw/projects/label-thermal-quench/disruption-py-label-thermal-quench/disruption-py/tq_timing' + df_05 = read_csv_files(timing_dir, 'timing_105*.csv') + print(f"\nTiming results for 2005 shots") + write_timing_results(df_05) + df_12_16 = read_csv_files(timing_dir, 'timing_11*.csv') + print(f"\nTiming results for 2012-2016 shots") + write_timing_results(df_12_16) + From 9c18fe7e9e91f4f1100cdec5ddb06fca73ae84bb Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Thu, 4 Jun 2026 16:13:16 -0400 Subject: [PATCH 17/19] Bug fix chord 1 background subtraction; initialize sxr lazily --- disruption_py/machine/cmod/physics.py | 37 +++++++++---------- scripts/get_thermal_quench_times.py | 2 +- scripts/plot_test_results_tq_time.py | 51 ++++++++++++++++++--------- scripts/plot_thermal_quench_time.py | 5 ++- 4 files changed, 55 insertions(+), 40 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index dca09594e..bfc530271 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2127,7 +2127,6 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): thermal_quench_time = np.full(len(params.times), np.nan) if params.disruption_time is None: return {"thermal_quench_time": thermal_quench_time} - params.logger.debug(params.disruption_time) # Get current data for obtaining start of current quench ip, magtime = params.get_data_with_dims( @@ -2139,21 +2138,11 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): t0_read = perf_counter() n_chords = 38 array_path = r"\top.brightnesses.array_1" - try: - chord_01, t_sxr = params.get_data_with_dims( - array_path + ":chord_01", - tree_name="xtomo", - ) # Units: W/m^2, s - except mdsExceptions.MdsException: - print(params.shot_id) - params.logger.debug("Failed to get SXR " + array_path + " data") - return {"thermal_quench_time": np.full(len(params.times), np.nan)} - valid_times = (t_sxr > 0) & (t_sxr < 2.0) - t_sxr = t_sxr[valid_times] - sxr = np.zeros(shape=(n_chords, len(t_sxr))) - sxr[0] = chord_01[valid_times] - # Get all other SXR chords - for i in range(1, n_chords): + # Initialize sxr, etc after first successful read of a chord so we know the size + sxr = None + t_sxr = None + valid_times = None + for i in range(n_chords): try: chord, t_chord = params.get_data_with_dims( array_path + ":chord_" + f"{i+1:02}", @@ -2161,15 +2150,22 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): ) except mdsExceptions.MdsException: params.logger.debug("Failed to get SXR " + array_path + " chord " + str(i+1) + " data") - sxr[i] = 0. + if sxr is not None: + sxr[i] = 0. continue # Subtract constant background chord = chord - np.mean(chord[t_chord < 0.]) + if sxr is None: + valid_times = (t_chord > 0) & (t_chord < params.disruption_time + 0.05) + t_sxr = t_chord[valid_times] + sxr = np.zeros(shape=(n_chords, len(t_sxr))) + sxr[i] = chord[valid_times] + continue # Occasionally the time bases of a chord are of a different length # Usually one timebase is just cut off early after shot is over - valid_times = (t_chord > 0) & (t_chord < 2.) + valid_times = (t_chord > 0) & (t_chord < params.disruption_time + 0.05) # Goods chords should be of the same shape - if len(chord[valid_times]) == sxr.shape[1]: + if np.sum(valid_times) == sxr.shape[1]: sxr[i] = chord[valid_times] sample_time = t_sxr[1] - t_sxr[0] @@ -2245,9 +2241,8 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): wndw_before_cq = 0.005 # [s] idx_start = np.argmin(np.abs(t_sxr - (cq_onset_time - wndw_before_cq))) idx_end = np.argmin(np.abs(t_sxr - (cq_onset_time))) - # When params.disruption_time > 2 s, the SXR data stops prior to the labeled CQ if idx_start == len(t_sxr) - 1: - params.logger.warning(f"No SXR data at time of CQ. params.disruption_time = {params.disruption_time:.3f}") + params.logger.debug(f"No SXR data at time of CQ. params.disruption_time = {params.disruption_time:.3f}") return {"thermal_quench_time": np.full(len(params.times), np.nan)} t_max_sxr_drop = t_sxr[idx_start + np.argmin(dcore_sxr_dt[idx_start:idx_end])] diff --git a/scripts/get_thermal_quench_times.py b/scripts/get_thermal_quench_times.py index f21e2f231..d49b92773 100644 --- a/scripts/get_thermal_quench_times.py +++ b/scripts/get_thermal_quench_times.py @@ -17,7 +17,7 @@ # Shotlist of all C-Mod shots 2012-2016 SHOTLIST_FN = '/home/henrycw/projects/ufo-characterization/shotlists/cmod_shots_2012_to_2016.txt' -NUM_SUBSET = 500 # Number of shots to randomly select from Shotlist +NUM_SUBSET = 1000 # Number of shots to randomly select from Shotlist shotlist = np.genfromtxt(SHOTLIST_FN, dtype=int) rng = np.random.default_rng(seed=42) diff --git a/scripts/plot_test_results_tq_time.py b/scripts/plot_test_results_tq_time.py index 4bf8e45f6..28161328d 100644 --- a/scripts/plot_test_results_tq_time.py +++ b/scripts/plot_test_results_tq_time.py @@ -7,6 +7,8 @@ import pandas as pd import matplotlib.pyplot as plt +plt.rcParams['font.size'] = 14 + df_old = pd.read_csv('test_thermal_quench_results_cq_onset90.csv') df_new = pd.read_csv('test_thermal_quench_results_cq_onset95.csv') #df_90 = pd.read_csv('test_thermal_quench_results_max90.csv') @@ -15,23 +17,38 @@ labels = ['0.9*max'] bins = np.linspace(-0.004, 0.004, 60) -for i, df in enumerate(dfs): - # Print summary statistics - error = df['onset_error_s'].to_numpy() - print(f"Stats for {labels[i]}") - print(f"Mean |Error| = {1e3*np.mean(np.abs(error)):.3f} ms") - print(f"Median |Error| = {1e3*np.median(np.abs(error)):.3f} ms") - print(f"Std Dev |Error| = {1e3*np.std(np.abs(error)):.3f} ms") - print(f"Min Error = {1e3*np.min(error):.3f} ms") - print(f"Max Error = {1e3*np.max(error):.3f} ms") - print(f"Num Outliers (outisde TQ [start, end] by >1 ms) = {np.sum(~df['within_tq_range'])} out of {len(df)} shots\n") - print(df[~df['within_tq_range']].sort_values(by='onset_error_s')) +# for i, df in enumerate(dfs): +# # Print summary statistics +# error = df['onset_error_s'].to_numpy() +# print(f"Stats for {labels[i]}") +# print(f"Mean |Error| = {1e3*np.mean(np.abs(error)):.3f} ms") +# print(f"Median |Error| = {1e3*np.median(np.abs(error)):.3f} ms") +# print(f"Std Dev |Error| = {1e3*np.std(np.abs(error)):.3f} ms") +# print(f"Min Error = {1e3*np.min(error):.3f} ms") +# print(f"Max Error = {1e3*np.max(error):.3f} ms") +# print(f"Num Outliers (outisde TQ [start, end] by >1 ms) = {np.sum(~df['within_tq_range'])} out of {len(df)} shots\n") +# print(df[~df['within_tq_range']].sort_values(by='onset_error_s')) - plt.hist(df['onset_error_s'], edgecolor='k', bins=bins, label=labels[i], alpha=0.5) +# plt.hist(df['onset_error_s'], edgecolor='k', bins=bins, label=labels[i], alpha=0.5) -plt.title("Test vs 100 Manually Labeled Shots", fontsize=18) -plt.xlabel('Error in TQ Onset Time (Auto - Manual) [s]', fontsize=16) -plt.xlim(-0.003, 0.00325) -plt.ylabel('Count', fontsize=16) -plt.legend() +# plt.title("Test vs 100 Manually Labeled Shots", fontsize=18) +# plt.xlabel('Error in TQ Onset Time (Auto - Manual) [s]', fontsize=16) +# plt.xlim(-0.003, 0.00325) +# plt.ylabel('Count', fontsize=16) +# plt.show() +df = dfs[0] +plt.figure(figsize=(8,5)) +df['t_disrupt_err'] = df['t_disrupt'] - df['tq_onset_manual'] +print(df.sort_values(by='t_disrupt_err', ascending=False)) +plt.scatter(df['t_disrupt_err']*1e3, df['onset_error_s']*1e3, marker='o', alpha=0.5, edgecolors='k', label='shot') +plt.xlabel('$t_{disrupt} - t_{TQ, man}$ [ms]', fontsize=18) +plt.ylabel('$t_{TQ,auto} - t_{TQ,man}$ [ms]', fontsize=18) +x_sample = np.linspace(0, 20, 10) +plt.plot(x_sample, x_sample, c='r') +plt.ylim(-3.5, 3.5) +plt.xlim(0, 16) +plt.text(0.1, 0.7, '$t_{TQ,auto} > t_{disrupt}$', color='r', rotation=52 ) +plt.title('Test Results on 100 Manually Labeled Shots') +plt.legend(fontsize=18) +plt.tight_layout() plt.show() \ No newline at end of file diff --git a/scripts/plot_thermal_quench_time.py b/scripts/plot_thermal_quench_time.py index c3ee66fec..7ced50a6b 100644 --- a/scripts/plot_thermal_quench_time.py +++ b/scripts/plot_thermal_quench_time.py @@ -24,7 +24,8 @@ #TODO: Shot 1160714006 having issues (low SXR signal). What do we do about ramp-up? # TODO: Shot 11405522001 has significant SXR spike when plasma hits wall, after main TQ # TODO: Search for first time at which dSXR/dt is w/in factor of 2 from max? -SHOT_ID = 1140826003 +# 1050421007 +SHOT_ID = 1050311013 signals = [ "ip", "zcur", @@ -89,6 +90,8 @@ axs[2].set_ylabel('SXR raw') axs[3].set_ylabel('SXR filt') axs[4].set_ylabel(r"$dSXR/dt$ [Hz]") +axs[-1].set_xlabel('Time [s]') +#axs[-1].set_xlim(0.49, 0.508) # axs[3].set_ylim(-8e3, 2e3) # axs[4].set_ylabel('Z0 [m]') # axs[4].set_xlabel("Time [s]") From 4174397076225d2e2a7343890655a199a10a1921 Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Mon, 8 Jun 2026 17:22:29 -0400 Subject: [PATCH 18/19] Clean up --- disruption_py/machine/cmod/physics.py | 121 +++++++++----------------- scripts/get_thermal_quench_times.py | 49 ----------- scripts/get_tq_timing_results.py | 29 ------ scripts/plot_test_results_tq_time.py | 54 ------------ scripts/plot_thermal_quench_time.py | 120 ------------------------- scripts/plot_tq_timetraces.py | 68 --------------- tests/test_thermal_quench_times.py | 75 ---------------- 7 files changed, 42 insertions(+), 474 deletions(-) delete mode 100644 scripts/get_thermal_quench_times.py delete mode 100644 scripts/get_tq_timing_results.py delete mode 100644 scripts/plot_test_results_tq_time.py delete mode 100644 scripts/plot_thermal_quench_time.py delete mode 100644 scripts/plot_tq_timetraces.py delete mode 100644 tests/test_thermal_quench_times.py diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index bfc530271..0f9690b11 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -8,6 +8,7 @@ import numpy as np import scipy.constants as const +from scipy.signal import butter, filtfilt, resample_poly from disruption_py.core.physics_method.caching import cache_method from disruption_py.core.physics_method.decorator import physics_method @@ -2093,16 +2094,16 @@ def get_surface_voltage(params: PhysicsMethodParams): v_surf = interp1(efit_time, v_surf, params.times) return {"v_surf": v_surf} - + @staticmethod - # TODO: Remove t_disrupt, core_sxr - @physics_method(columns=["thermal_quench_time", "t_disrupt", "core_sxr"], tokamak=Tokamak.CMOD) + @physics_method(columns=["thermal_quench_time"], tokamak=Tokamak.CMOD) def get_thermal_quench_onset_time(params: PhysicsMethodParams): """ Labels the onset time of the thermal quench for a given shot (NaN for non-disruptive shots) using a vertical SXR array due to its off-axis views and robustness across shots, as opposed to ECE. The labeling method is non-causal (i.e. post-shot processing). - The TQ is found by searching for min(dSXR/dt) in a time window prior to the CQ. + The TQ is found by finding min(dSXR/dt) in a time window prior to the CQ + and then searching backwards for the onset of the TQ. There is a tension between using longer windows to find the first TQ in a multi-stage TQ versus using a shorter window to avoid labeling sawtooth crashes. Thus, for shots with multi-stage thermal quenches, (see shots 1050830034 and 1120717002), @@ -2117,25 +2118,18 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): thermal_quench_time : array_like time of thermal quench onset for the shot, identical values at each time-slice - Last Major Update: Henry Wietfeldt (1/26/26) + Last Major Update: Henry Wietfeldt (06/08/26) """ - import os - import pandas as pd - from scipy.signal import butter, filtfilt, resample_poly - from time import perf_counter - t0 = perf_counter() + thermal_quench_time = np.full(len(params.times), np.nan) if params.disruption_time is None: return {"thermal_quench_time": thermal_quench_time} # Get current data for obtaining start of current quench - ip, magtime = params.get_data_with_dims( - r"\ip", tree_name="magnetics" - ) + ip, magtime = params.get_data_with_dims(r"\ip", tree_name="magnetics") ip = np.abs(ip) # Get SXR chords - t0_read = perf_counter() n_chords = 38 array_path = r"\top.brightnesses.array_1" # Initialize sxr, etc after first successful read of a chord so we know the size @@ -2149,12 +2143,14 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): tree_name="xtomo", ) except mdsExceptions.MdsException: - params.logger.debug("Failed to get SXR " + array_path + " chord " + str(i+1) + " data") + params.logger.debug( + "Failed to get SXR " + array_path + " chord " + str(i + 1) + " data" + ) if sxr is not None: - sxr[i] = 0. + sxr[i] = 0.0 continue # Subtract constant background - chord = chord - np.mean(chord[t_chord < 0.]) + chord = chord - np.mean(chord[t_chord < 0.0]) if sxr is None: valid_times = (t_chord > 0) & (t_chord < params.disruption_time + 0.05) t_sxr = t_chord[valid_times] @@ -2170,62 +2166,56 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): sample_time = t_sxr[1] - t_sxr[0] sample_freq = 1 / sample_time - timings = {'read_mds': perf_counter() - t0_read} # Remove bad chords by checking each chord's autocorrelation. # Bad chords often have significant white noise, meaning low autocorrelation (< 10 ms) # Good chords should have an autocorrelation of 100s of ms # See shot 1050311013 as an example with some bad chords - t0_autocorr = perf_counter() - noise_autorr_cutoff = 0.01 # [s] + noise_autorr_cutoff = 0.01 # [s] for i, chord in enumerate(sxr): # Use 300 ms prior to current quench for speed-up during autocorr O(N^2) idx_start = np.argmin(np.abs(t_sxr - (params.disruption_time - 0.3))) idx_end = np.argmin(np.abs(t_sxr - (params.disruption_time))) chord = chord[idx_start:idx_end] - sample_freq_5khz = 5000 # [Hz] + sample_freq_5khz = 5000 # [Hz] if sample_freq > 5000: # 2012-2016 has 250 kHz sampling frequency. Resample to 5 kHz frequency # (native SXR sample frequency of earlier campaigns) for speed-up - chord = resample_poly(chord, up=1, down=sample_freq//sample_freq_5khz) - autocorr = np.correlate(chord, chord, mode='full') + chord = resample_poly(chord, up=1, down=sample_freq // sample_freq_5khz) + autocorr = np.correlate(chord, chord, mode="full") max_autocorr = np.max(autocorr) if max_autocorr > 0: autocorr = autocorr / np.max(autocorr) # Normalize else: - sxr[i] = 0. + sxr[i] = 0.0 continue index_no_lag = np.argmax(autocorr) - params.logger.debug(f"Chord {i+1}. Autocorr[index_no_lag]: {autocorr[index_no_lag:index_no_lag + 8]}") crosses_zero = autocorr[index_no_lag:] < 0 if np.any(crosses_zero): index_decay = np.argmax(crosses_zero) else: # See shot 1120223007 for example of why this if-else logic is necessary index_decay = len(crosses_zero) - params.logger.debug(f"Chord {i+1} index_decay: {index_decay}") - if index_decay*(1/sample_freq_5khz) < noise_autorr_cutoff: - params.logger.debug(f"Removing chord {i+1}. Norm. Autocorr: {index_decay*(1/sample_freq_5khz)}") - sxr[i] = 0. - timings['autocorr'] = perf_counter() - t0_autocorr + if index_decay * (1 / sample_freq_5khz) < noise_autorr_cutoff: + params.logger.debug( + f"Removing chord {i+1}. Norm. Autocorr: {index_decay*(1/sample_freq_5khz)}" + ) + sxr[i] = 0.0 # Noncausal Butterworth low pass filter to smooth transient SXR spikes during TQ. # Cutoff of 1.0 kHz and order 2 seems to filter recombination SXR spikes # while maintaining decent resolution of TQ based on scan from 0.25 kHz - 2 kHz # Results were fairly insensitive within these windows on the 100 shots checked # See shot 1120913013 as example of large recombination spike - t0_bworth = perf_counter() - bworth_cutoff = 1000 # [Hz] + bworth_cutoff = 1000 # [Hz] bworth_order = 2 - normalized_cutoff = bworth_cutoff / (0.5*sample_freq) - b, a = butter(bworth_order, normalized_cutoff, btype='low', analog=False) + normalized_cutoff = bworth_cutoff / (0.5 * sample_freq) + b, a = butter(bworth_order, normalized_cutoff, btype="low", analog=False) core_sxr_raw = np.max(sxr, axis=0) sxr = filtfilt(b, a, sxr, axis=1) core_sxr = np.max(sxr, axis=0) - dcore_sxr_dt = np.diff(core_sxr, prepend=0)/sample_time - timings['bworth'] = perf_counter() - t0_bworth + dcore_sxr_dt = np.diff(core_sxr, prepend=0) / sample_time - t0_find_tq = perf_counter() # Search for the onset of the CQ so that we can search for the TQ in a small time window # to avoid labeleing sawtooth crashes as the thermal quench # Some current quenches can be long (see shots 1050311013, 1050802017). @@ -2234,15 +2224,17 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): idx_end = np.argmin(np.abs(magtime - (params.disruption_time - 0.02))) ip_prior = np.min(ip[idx_start:idx_end]) # CQ onset is last moment Ip is >90% Ip prior to disruption - idx_cq_onset = np.where(ip > 0.9*ip_prior)[0][-1] + idx_cq_onset = np.where(ip > 0.9 * ip_prior)[0][-1] cq_onset_time = magtime[idx_cq_onset] # Search for TQ midpoint as min(dSXR/dt) in window of 5 ms prior to current quench onset - wndw_before_cq = 0.005 # [s] + wndw_before_cq = 0.005 # [s] idx_start = np.argmin(np.abs(t_sxr - (cq_onset_time - wndw_before_cq))) idx_end = np.argmin(np.abs(t_sxr - (cq_onset_time))) if idx_start == len(t_sxr) - 1: - params.logger.debug(f"No SXR data at time of CQ. params.disruption_time = {params.disruption_time:.3f}") + params.logger.debug( + f"No SXR data at time of CQ. params.disruption_time = {params.disruption_time:.3f}" + ) return {"thermal_quench_time": np.full(len(params.times), np.nan)} t_max_sxr_drop = t_sxr[idx_start + np.argmin(dcore_sxr_dt[idx_start:idx_end])] @@ -2252,49 +2244,20 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # last timestep with SXR > 90% of that max value # Use raw signal bc smoothed signal has a longer crash time. # Note this sometimes picks up on recombination spikes - wndw_before_tq_midpoint = 0.0005 # [s] - idx_start = np.argmin(np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_midpoint))) + wndw_before_tq_midpoint = 0.0005 # [s] + idx_start = np.argmin( + np.abs(t_sxr - (t_max_sxr_drop - wndw_before_tq_midpoint)) + ) idx_end = np.argmin(np.abs(t_sxr - (t_max_sxr_drop))) window = core_sxr_raw[idx_start:idx_end] # Want last maximum in case the SXR has saturated and there are multiple maxima - max_sxr_indx = np.nonzero(window >= 0.9*np.max(window))[0][-1] + max_sxr_indx = np.nonzero(window >= 0.9 * np.max(window))[0][-1] tq_time_scalar = t_sxr[idx_start + max_sxr_indx] - timings['find_tq'] = perf_counter() - t0_find_tq - - # TODO: Delete this block during clean-up - # TODO: Comment this out when running over many shots - #Write some signals for plotting - #Get magnetic axis data from EFIT for testing purposes - # z0, efit_time = params.get_data_with_dims( - # r"\efit_aeqdsk:zmagx", tree_name="_efit_tree" - # ) # [cm], [s] - # z0 *= 0.01 # [cm] -> [m] - # te0_ece, t_ece = params.get_data_with_dims(r"\gpc2_te0", tree_name="electrons") - # import pickle - # plot_df = {"magtime":magtime, - # "ip": ip, - # "t_sxr": t_sxr, - # "t_ece": t_ece, - # "te0_ece": te0_ece, - # "core_sxr_raw": core_sxr_raw, - # "core_sxr": core_sxr, - # "core_sxr_growth_rate": dcore_sxr_dt, - # "t_disrupt": params.disruption_time, - # "cq_onset_time": cq_onset_time, - # "t_max_sxr_drop": t_max_sxr_drop, - # "thermal_quench_time_scalar": tq_time_scalar, - # } - # with open('sxr.pkl', 'wb') as f: - # pickle.dump(plot_df, f) - - # Output timings (TODO: Remove this eventually) - timings['total'] = perf_counter() - t0 - fn = os.path.join('/home/henrycw/projects/label-thermal-quench/disruption-py-label-thermal-quench/disruption-py/tq_timing', f'timing_{params.shot_id}.csv') - pd.DataFrame([timings]).to_csv(fn, index=False) - - # TODO: Remove t_disrupt, core_sxr - core_sxr = interp1(t_sxr, core_sxr_raw, params.times) - return {"thermal_quench_time": tq_time_scalar*np.ones(len(params.times)), "t_disrupt": params.disruption_time*np.ones(len(params.times)), "core_sxr": core_sxr} + return { + "thermal_quench_time": tq_time_scalar * np.ones(len(params.times)), + "t_disrupt": params.disruption_time * np.ones(len(params.times)), + "core_sxr": core_sxr, + } @staticmethod def _is_on_blacklist(shot_id: int) -> bool: diff --git a/scripts/get_thermal_quench_times.py b/scripts/get_thermal_quench_times.py deleted file mode 100644 index d49b92773..000000000 --- a/scripts/get_thermal_quench_times.py +++ /dev/null @@ -1,49 +0,0 @@ -""" -This script gets the times of the thermal quench onset for a specified shotlist. -Useful for testing the thermal quench labler. -We can delete this when we merge the thermal quench labeler into dev -Author: Henry Wietfeldt -""" - -import os -import logging - -import numpy as np -import pandas as pd - -from disruption_py.settings import LogSettings, RetrievalSettings -from disruption_py.workflow import get_shots_data - - -# Shotlist of all C-Mod shots 2012-2016 -SHOTLIST_FN = '/home/henrycw/projects/ufo-characterization/shotlists/cmod_shots_2012_to_2016.txt' -NUM_SUBSET = 1000 # Number of shots to randomly select from Shotlist - -shotlist = np.genfromtxt(SHOTLIST_FN, dtype=int) -rng = np.random.default_rng(seed=42) -rng.shuffle(shotlist) -if len(shotlist) <= NUM_SUBSET: - shots_to_use = shotlist -else: - shots_to_use = shotlist[:NUM_SUBSET] - -# shots_to_use = [1140821020, 1140520016 ,1150710007] - -# default method for pulling disruption-py data -retrieval_settings = RetrievalSettings( - time_setting="disruption_warning", # use the set efit's timebase - efit_nickname_setting="analysis", # set the efit - run_methods=[], - run_columns=["ip", "thermal_quench_time"], - only_requested_columns=True, -) - -results = get_shots_data( - shotlist_setting=shots_to_use, - retrieval_settings=retrieval_settings, - log_settings=LogSettings(console_level=logging.WARNING), - num_processes=os.cpu_count(), -) - -# Write contents to csv for easy inspection -results.to_dataframe().to_csv('tq_labels.csv') \ No newline at end of file diff --git a/scripts/get_tq_timing_results.py b/scripts/get_tq_timing_results.py deleted file mode 100644 index 0606f41b9..000000000 --- a/scripts/get_tq_timing_results.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -from glob import glob -import numpy as np -import pandas as pd - -def read_csv_files(dir, file_pattern): - csv_files = glob(os.path.join(dir, file_pattern)) - dfs = [pd.read_csv(file) for file in csv_files] - return pd.concat(dfs, ignore_index=True) - -def write_timing_results(df): - print(f"Timing Avgs:\n---------------") - print(f"Reading MDSplus: {np.nanmean(df['read_mds']):.6} s") - print(f"Calculating autocorrelation: {np.nanmean(df['autocorr']):.6} s") - print(f"Butterworth Filter: {np.nanmean(df['bworth']):.6} s") - print(f"Finding TQ: {np.nanmean(df['find_tq']):.6} s") - print(f"Total: {np.nanmean(df['total']):.6} s") - print("---------------") - - -if __name__=='__main__': - timing_dir = '/home/henrycw/projects/label-thermal-quench/disruption-py-label-thermal-quench/disruption-py/tq_timing' - df_05 = read_csv_files(timing_dir, 'timing_105*.csv') - print(f"\nTiming results for 2005 shots") - write_timing_results(df_05) - df_12_16 = read_csv_files(timing_dir, 'timing_11*.csv') - print(f"\nTiming results for 2012-2016 shots") - write_timing_results(df_12_16) - diff --git a/scripts/plot_test_results_tq_time.py b/scripts/plot_test_results_tq_time.py deleted file mode 100644 index 28161328d..000000000 --- a/scripts/plot_test_results_tq_time.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Script for plotting test results from tests/test_thermal_quench_times.py -Author: Henry Wietfeldt -""" - -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt - -plt.rcParams['font.size'] = 14 - -df_old = pd.read_csv('test_thermal_quench_results_cq_onset90.csv') -df_new = pd.read_csv('test_thermal_quench_results_cq_onset95.csv') -#df_90 = pd.read_csv('test_thermal_quench_results_max90.csv') - -dfs = [df_old] -labels = ['0.9*max'] -bins = np.linspace(-0.004, 0.004, 60) - -# for i, df in enumerate(dfs): -# # Print summary statistics -# error = df['onset_error_s'].to_numpy() -# print(f"Stats for {labels[i]}") -# print(f"Mean |Error| = {1e3*np.mean(np.abs(error)):.3f} ms") -# print(f"Median |Error| = {1e3*np.median(np.abs(error)):.3f} ms") -# print(f"Std Dev |Error| = {1e3*np.std(np.abs(error)):.3f} ms") -# print(f"Min Error = {1e3*np.min(error):.3f} ms") -# print(f"Max Error = {1e3*np.max(error):.3f} ms") -# print(f"Num Outliers (outisde TQ [start, end] by >1 ms) = {np.sum(~df['within_tq_range'])} out of {len(df)} shots\n") -# print(df[~df['within_tq_range']].sort_values(by='onset_error_s')) - -# plt.hist(df['onset_error_s'], edgecolor='k', bins=bins, label=labels[i], alpha=0.5) - -# plt.title("Test vs 100 Manually Labeled Shots", fontsize=18) -# plt.xlabel('Error in TQ Onset Time (Auto - Manual) [s]', fontsize=16) -# plt.xlim(-0.003, 0.00325) -# plt.ylabel('Count', fontsize=16) -# plt.show() -df = dfs[0] -plt.figure(figsize=(8,5)) -df['t_disrupt_err'] = df['t_disrupt'] - df['tq_onset_manual'] -print(df.sort_values(by='t_disrupt_err', ascending=False)) -plt.scatter(df['t_disrupt_err']*1e3, df['onset_error_s']*1e3, marker='o', alpha=0.5, edgecolors='k', label='shot') -plt.xlabel('$t_{disrupt} - t_{TQ, man}$ [ms]', fontsize=18) -plt.ylabel('$t_{TQ,auto} - t_{TQ,man}$ [ms]', fontsize=18) -x_sample = np.linspace(0, 20, 10) -plt.plot(x_sample, x_sample, c='r') -plt.ylim(-3.5, 3.5) -plt.xlim(0, 16) -plt.text(0.1, 0.7, '$t_{TQ,auto} > t_{disrupt}$', color='r', rotation=52 ) -plt.title('Test Results on 100 Manually Labeled Shots') -plt.legend(fontsize=18) -plt.tight_layout() -plt.show() \ No newline at end of file diff --git a/scripts/plot_thermal_quench_time.py b/scripts/plot_thermal_quench_time.py deleted file mode 100644 index 7ced50a6b..000000000 --- a/scripts/plot_thermal_quench_time.py +++ /dev/null @@ -1,120 +0,0 @@ -""" -Program to plot various quantities used in calculating the thermal quench time -for a particular shot to compare various methods. -Author: Henry Wietfeldt -""" - -import logging -import numpy as np -import matplotlib -matplotlib.use('TkAgg') -import matplotlib.pyplot as plt -import pandas as pd -import pickle - -from disruption_py.settings import LogSettings, RetrievalSettings -from disruption_py.workflow import get_shots_data - -MAN_LABEL = False -#SHOT_ID = 1140515015 # Hot VDE -#SHOT_ID = 1140827029 -#SHOT_ID = 1120717002 -#SHOT_ID = 1051206029 -#SHOT_ID = 1160714006 -#TODO: Shot 1160714006 having issues (low SXR signal). What do we do about ramp-up? -# TODO: Shot 11405522001 has significant SXR spike when plasma hits wall, after main TQ -# TODO: Search for first time at which dSXR/dt is w/in factor of 2 from max? -# 1050421007 -SHOT_ID = 1050311013 -signals = [ - "ip", - "zcur", - "thermal_quench_time" -] - -# default method for pulling disruption-py data -retrieval_settings = RetrievalSettings( - time_setting="disruption_warning", # use the set efit's timebase - efit_nickname_setting="efit21", # set the efit - run_methods=[], - run_columns=signals, - only_requested_columns=True, -) - -data = get_shots_data( - shotlist_setting=[SHOT_ID], - retrieval_settings=retrieval_settings, - log_settings=LogSettings(console_level=logging.DEBUG), - output_setting="dataframe", - num_processes=1, -) -print(data) -print("Got data") - -with open('sxr.pkl', 'rb') as f: - df = pickle.load(f) -df['ip'] = np.abs(df['ip']/1e6) -print(df['cq_onset_time']) - -plt.rcParams['font.size'] = 14 -fig, axs = plt.subplots(5, 1, sharex=True, figsize=(14,7)) -#axs[0].set_xlim(0.6, 0.64) -axs[0].plot(df['magtime'], df['ip'], marker='.', ms=10, c='k', linestyle='-') -axs[1].plot(df['t_ece'], df['te0_ece'], marker='.', ms=5, c='k', linestyle='-') -axs[2].scatter(df['t_sxr'], df['core_sxr_raw'], marker='.', s=5, c='k') -axs[3].scatter(df['t_sxr'], df['core_sxr'], marker='.', s=5, c='k') -axs[4].scatter(df['t_sxr'], df['core_sxr_growth_rate'], marker='.', s=5, c='k') -# axs[4].scatter(df['efit_time'], df['z0'], marker='o', s=10, c='k') -print("Plotting labeled times") -for ax in axs: - ax.axvline(df['t_disrupt'], linestyle='-', c='k', label='t_disrupt') - ax.axvline(df['cq_onset_time'], linestyle='--', c='k', label='CQ Onset', lw=2) - #ax.axvline(df['t_start'], linestyle='--', c='k', label='tstart') - if not MAN_LABEL: - ax.axvline(df['thermal_quench_time_scalar'], linestyle='-', c='r', label='TQ Onset') - # ax.axvspan(df['cq_onset_time']-0.005, df['cq_onset_time'], alpha=0.15, color='tab:green', label='TQ Midpoint Search Window') - ax.axvline(df['t_max_sxr_drop'], linestyle='--', c='g', label='TQ Midpoint') - # for i, t_tq in enumerate(df['thermal_quench_times']): - # if i == 0: - # ax.axvline(t_tq, linestyle='-', c='r', label='TQ') - # else: - # ax.axvline(t_tq, linestyle='-', c='r') - # for i, t_warn in enumerate(df['thermal_quench_warnings']): - # if i == 0: - # ax.axvline(t_warn, linestyle='--', c='b', label='TQ warn') - # else: - # ax.axvline(t_warn, linestyle='--', c='b') -axs[0].set_title('C-Mod Shot: ' + str(SHOT_ID)) -axs[0].set_ylabel('Ip [MA]') -axs[1].set_ylabel('Te0 [keV]') -axs[2].set_ylabel('SXR raw') -axs[3].set_ylabel('SXR filt') -axs[4].set_ylabel(r"$dSXR/dt$ [Hz]") -axs[-1].set_xlabel('Time [s]') -#axs[-1].set_xlim(0.49, 0.508) -# axs[3].set_ylim(-8e3, 2e3) -# axs[4].set_ylabel('Z0 [m]') -# axs[4].set_xlabel("Time [s]") -axs[0].legend(fontsize=12) - -# fig, axs = plt.subplots(4, 1, sharex=True, figsize=(14,7)) -# axs[0].plot(df['magtime'], df['ip'], marker='.', ms=10, c='k') -# axs[1].plot(df['t_sxr'], df['core_sxr_raw'], marker='.', ms=5, c='k') -# axs[2].plot(data['time'], data['p_rad']/1e6, marker='o', ms=10, c='k') -# axs[3].plot(df['efit_time'], df['z0'], marker='o', ms=10, c='k') - -# for ax in axs: -# ax.axvline(df['t_disrupt'], linestyle='--', c='b', label='t_disrupt (DisruptionPy)') -# ax.axvline(df['thermal_quench_time_scalar'], linestyle='--', c='r', label='TQ Onset (auto)') - -# axs[0].set_title('C-Mod Shot: ' + str(SHOT_ID)) -# axs[0].set_ylabel('Ip [MA]') -# axs[1].set_ylabel('SXR raw [a.u.]') -# axs[2].set_ylabel('Prad [MW]') -# axs[3].set_ylabel('Z0 [m]') -# axs[3].set_xlabel("Time [s]") -# # axs[-1].set_xlim(0.68, 0.725) -# axs[0].legend() -# plt.show() - -plt.show() \ No newline at end of file diff --git a/scripts/plot_tq_timetraces.py b/scripts/plot_tq_timetraces.py deleted file mode 100644 index 55845b020..000000000 --- a/scripts/plot_tq_timetraces.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Plot timetraces of shots showing different disruption times -for testing thermal quench time labeler -""" - -import logging -import os - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from disruption_py.settings import LogSettings, RetrievalSettings -from disruption_py.workflow import get_shots_data - -USE_PREV_DPY_RESULTS = True -N_SHOTS = 30 - -df_man = pd.read_csv('manual_tq_labels.csv') -df_man['shot'] = df_man['shot'].astype(int) -shotlist = df_man['shot'].to_list() -print(f"Number of shots: {len(shotlist)}") - -signals = [ - "ip", - "thermal_quench_time", - "t_disrupt", - "core_sxr" -] - -if USE_PREV_DPY_RESULTS: - df = pd.read_csv('tq_df.csv') -else: - # default method for pulling disruption-py data - retrieval_settings = RetrievalSettings( - time_setting="disruption_warning", # use the set efit's timebase - efit_nickname_setting="analysis", # set the efit - run_methods=[], - run_columns=signals, - only_requested_columns=True, - ) - - df = get_shots_data( - shotlist_setting=shotlist, - retrieval_settings=retrieval_settings, - log_settings=LogSettings(console_level=logging.WARNING), - output_setting="dataframe", - num_processes=os.cpu_count(), - ) - df.to_csv('tq_df.csv') - -df = df.merge(df_man[['shot', 'tq_onset_manual']], on='shot', how='left') - -# Select 30 random shots -rng = np.random.default_rng(seed=42) -rng.shuffle(shotlist) -shot_subset = shotlist[:N_SHOTS] - -fig, axs = plt.subplots(N_SHOTS, 1, figsize=(5, 14)) - -for i, s in enumerate(shot_subset): - df_s = df[df['shot']==s] - axs[i].plot(df_s['time'], df_s['ip'], c='g') - axs[i].set_yticks([]) - axs[i].set_xticks([]) - axs[i].set_xlim(0, 2) - -plt.show() \ No newline at end of file diff --git a/tests/test_thermal_quench_times.py b/tests/test_thermal_quench_times.py deleted file mode 100644 index b0777f13f..000000000 --- a/tests/test_thermal_quench_times.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -Test automated TQ labels compared to manual labels -Author: Henry Wietfeldt -Source of Manual labels: Henry Wietfeldt -""" - -import logging -import os - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from disruption_py.settings import LogSettings, RetrievalSettings -from disruption_py.workflow import get_shots_data - -db_man = pd.read_csv('manual_tq_labels.csv') -db_man['shot'] = db_man['shot'].astype(int) -shotlist = db_man['shot'].to_list() -print(f"Number of shots: len(shotlist)") - -signals = [ - "ip", - "thermal_quench_time", - "t_disrupt" -] - -# default method for pulling disruption-py data -retrieval_settings = RetrievalSettings( - time_setting="disruption_warning", # use the set efit's timebase - efit_nickname_setting="analysis", # set the efit - run_methods=[], - run_columns=signals, - only_requested_columns=True, -) - -db_auto = get_shots_data( - shotlist_setting=shotlist, - retrieval_settings=retrieval_settings, - log_settings=LogSettings(console_level=logging.WARNING), - output_setting="dataframe", - num_processes=os.cpu_count(), -) - -# Using pandas because this testing script has not been updated to xarray - -# Output test results -db_auto = db_auto.drop_duplicates(subset='shot').drop(columns='time') -db_auto.to_csv('test_thermal_quench_auto_labels.csv') -db_test = db_man.copy() -db_test = db_test.sort_values(by='shot') -db_auto = db_auto.sort_values(by='shot') -db_test = pd.merge(db_test, db_auto, how='outer', on='shot') -db_test = db_test.rename(columns={'thermal_quench_time': 'tq_auto'}) -print(db_test) - -db_test['within_tq_range'] = (db_test['tq_auto'] > db_test['tq_onset_manual']-0.001) & (db_test['tq_auto'] < db_test['tq_end_manual']+0.001) -db_test['onset_error_s'] = db_test['tq_auto'] - db_test['tq_onset_manual'] -db_test = db_test.drop(columns=['notes']) -db_test.to_csv('test_thermal_quench_results.csv') - -# Print summary statistics -error = db_test['onset_error_s'].to_numpy() -print(f"Mean |Error| = {1e3*np.nanmean(np.abs(error)):.3f} ms") -print(f"Median |Error| = {1e3*np.nanmedian(np.abs(error)):.3f} ms") -print(f"Std Dev |Error| = {1e3*np.nanstd(np.abs(error)):.3f} ms") -print(f"Min Error = {1e3*np.nanmin(error):.3f} ms") -print(f"Max Error = {1e3*np.nanmax(error):.3f} ms") -print(f"Num Outliers (outisde TQ [start, end] by >1 ms) = {np.sum(~db_test['within_tq_range'])} out of {len(shotlist)} shots") - -# Plot onset errors -plt.hist(db_test['onset_error_s'], bins=50) -plt.xlabel('Error in TQ Onset Time (Auto - Manual) [s]', fontsize=16) -plt.ylabel('Count', fontsize=16) -plt.show() \ No newline at end of file From ae698bea46f2aad341389bfd4235ae5ed48c884a Mon Sep 17 00:00:00 2001 From: hwietfeldt Date: Mon, 8 Jun 2026 17:25:42 -0400 Subject: [PATCH 19/19] Fix return dict --- disruption_py/machine/cmod/physics.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/disruption_py/machine/cmod/physics.py b/disruption_py/machine/cmod/physics.py index 0f9690b11..e46031fb7 100644 --- a/disruption_py/machine/cmod/physics.py +++ b/disruption_py/machine/cmod/physics.py @@ -2253,11 +2253,7 @@ def get_thermal_quench_onset_time(params: PhysicsMethodParams): # Want last maximum in case the SXR has saturated and there are multiple maxima max_sxr_indx = np.nonzero(window >= 0.9 * np.max(window))[0][-1] tq_time_scalar = t_sxr[idx_start + max_sxr_indx] - return { - "thermal_quench_time": tq_time_scalar * np.ones(len(params.times)), - "t_disrupt": params.disruption_time * np.ones(len(params.times)), - "core_sxr": core_sxr, - } + return {"thermal_quench_time": tq_time_scalar * np.ones(len(params.times))} @staticmethod def _is_on_blacklist(shot_id: int) -> bool: