diff --git a/src/toolbox_scs/detectors/digitizers.py b/src/toolbox_scs/detectors/digitizers.py index 4790ac1fdab03af5b4a72461d0878b31dcdc79ba..544344544f41b42fe361fdd2938776141cc6bdf4 100644 --- a/src/toolbox_scs/detectors/digitizers.py +++ b/src/toolbox_scs/detectors/digitizers.py @@ -661,8 +661,183 @@ def get_peak_params(run, mnemonic, raw_trace=None, ntrains=200): return params -def check_peak_params(run, mnemonic, raw_trace=None, ntrains=200, params=None, - plot=True, show_all=False, bunchPattern='sase3'): +def find_peaks_in_raw_trace(trace, height=None, width=1, distance=24): + """ + Find integration parameters for peak integration of a raw + digitizer trace. Based on scipy find_peaks(). + + Parameters + ---------- + trace: numpy array or xarray DataArray + The digitier raw trace used to find peaks + height: int + minimum height for peak determination + width: int + minimum width of peak + distance: int + minimum distance between two peaks + + Returns + ------- + dict with keys 'pulseStart', 'pulseStop', 'baseStart', 'baseStop', + 'period', 'npulses' and values in number of samples. + """ + if isinstance(trace, xr.DataArray): + trace = trace.values + trace_norm = trace - np.median(trace) + trace_norm = -trace_norm if np.mean(trace_norm) < 0 else trace_norm + SNR = np.max(np.abs(trace_norm)) / np.std(trace_norm[:100]) + if SNR < 10: + log.warning('signal-to-noise ratio too low: cannot ' + 'automatically find peaks.') + return {'pulseStart': 2, 'pulseStop': 3, + 'baseStart': 0, 'baseStop': 1, + 'period': 0, 'npulses': 1} + min_height = min(3 * SNR, np.max(np.abs(trace_norm)) / 3) + peaks, prop = find_peaks(trace_norm, distance=distance, + height=min_height, + width=2) + params = {} + start = int(prop['left_ips'][0]) + if len(peaks) == 1: + params['pulseStart'] = start + params['period'] = 0 + else: + pulse_period = int(np.median(np.diff(peaks))) + pulse_ids = np.digitize(peaks, + np.arange(peaks[0] - pulse_period/2, + peaks[-1] + pulse_period/2, + pulse_period)) - 1 + if len(np.unique(np.diff(pulse_ids))) == 1: + # Regular pattern + params['pulseStart'] = start + params['period'] = pulse_period + else: + # Irregular pattern + params['pulseStart'] = start + pulse_ids * pulse_period + params['period'] = 0 + params['pulseStop'] = int(prop['right_ips'][0]) + params['baseStop'] = (3 * start - peaks[0]) / 2 + params['baseStop'] = np.min([params['baseStop'], + int(prop['right_ips'][0])]).astype(int) + params['baseStart'] = params['baseStop'] - np.mean(prop['widths'])/2 + params['baseStart'] = np.max([params['baseStart'], 0]).astype(int) + params['npulses'] = len(peaks) + return params + + +def find_peak_integration_parameters(run, mnemonic, raw_trace=None, + integParams=None, pattern=None, + ntrains=None): + ''' + Finds peak integration parameters. + ''' + run_mnemonics = mnemonics_for_run(run) + digitizer = digitizer_type(run, run_mnemonics[mnemonic]['source']) + pulse_period = 24 if digitizer == "FastADC" else 440 + + if integParams is None: + # load raw trace and find peaks + autoFind = True + if raw_trace is None: + raw_trace = get_dig_avg_trace(run, mnemonic, ntrains) + params = find_peaks_in_raw_trace(raw_trace, distance=pulse_period) + else: + # inspect provided parameters + autoFind = False + required_keys = ['pulseStart', 'pulseStop', 'baseStart', + 'baseStop'] + add_text = '' + if pattern is None and not hasattr(integParams['pulseStart'], + '__len__'): + required_keys += ['period', 'npulses'] + add_text = 'Bunch pattern not provided. ' + if not all(name in integParams for name in required_keys): + raise ValueError(add_text + 'All keys of integParams argument ' + f'{required_keys} are required.') + params = integParams.copy() + + # extract pulse ids from the parameters (starting at 0) + if hasattr(params['pulseStart'], '__len__'): + if params.get('npulses') is not None and ( + params.get('npulses') != len(params['pulseStart'])): + log.warning('The number of pulses does not match the length ' + 'of pulseStart. Using length of pulseStart as ' + 'the number of pulses.') + params['npulses'] = len(params['pulseStart']) + pulse_ids_params = ((np.array(params['pulseStart']) - + params['pulseStart'][0]) / pulse_period).astype(int) + else: + pulse_ids_params = np.arange(0, + params['npulses'] * params['period'] / pulse_period, + params['period'] / pulse_period).astype(int) + + # Extract pulse_ids, period and npulses from bunch pattern + pulse_ids_bp, npulses_bp, period_bp = None, None, 0 + regular = True + if pattern is not None: + bunchPattern = 'sase3' if hasattr(pattern, 'sase') else 'scs_ppl' + if pattern.is_constant_pattern() is False: + log.warning('The number of pulses changed during the run.') + pulse_ids_bp = np.unique(pattern.pulse_ids(labelled=False, + copy=False)) + npulses_bp, period_bp = None, 0 + regular = False + else: + pulse_ids_bp = pattern.peek_pulse_ids(labelled=False) + npulses_bp = len(pulse_ids_bp) + if npulses_bp > 1: + periods = np.diff(pulse_ids_bp) + if len(np.unique(periods)) > 1: + regular = False + else: + period_bp = np.unique(periods)[0] * pulse_period + # Compare parameters with bunch pattern + if len(pulse_ids_params) == len(pulse_ids_bp): + if not (pulse_ids_params == pulse_ids_bp - pulse_ids_bp[0]).all(): + log.warning('The provided pulseStart parameters do not match ' + f'the {bunchPattern} bunch pattern pulse ids. ' + 'Using bunch pattern parameters.') + pulse_ids_params = pulse_ids_bp + + if (npulses_bp != params.get('npulses') or + period_bp != params.get('period')): + if autoFind: + add_text = 'Automatically found ' + else: + add_text = 'Provided ' + log.warning(add_text + 'integration parameters ' + f'(npulses={params.get("npulses")}, ' + + f'period={params.get("period")}) do not match the ' + f'{bunchPattern} bunch pattern (npulses=' + f'{npulses_bp}, period={period_bp}). Using bunch ' + 'pattern parameters.') + pulse_ids_params = pulse_ids_bp + params['npulses'] = npulses_bp + params['period'] = period_bp + + if regular == False: + # Irregular pattern + if hasattr(params['pulseStart'], '__len__'): + start = params['pulseStart'][0] + else: + start = params['pulseStart'] + params['pulseStart'] = np.array( + [int(start + (pid - pulse_ids_params[0]) * pulse_period) + for pid in pulse_ids_params]) + params['period'] = 0 + else: + # Regular pattern + if hasattr(params['pulseStart'], '__len__'): + params['pulseStart'] = params['pulseStart'][0] + if len(pulse_ids_params) == 1: + params['period'] = 0 + return params, pulse_ids_params, regular, raw_trace + + +def check_peak_params(proposal, runNB, mnemonic, raw_trace=None, + ntrains=200, integParams=None, bunchPattern='sase3', + plot=True, show_all=False): """ Checks and plots the peak parameters (pulse window and baseline window of a raw digitizer trace) used to compute the peak integration. These @@ -674,8 +849,10 @@ def check_peak_params(run, mnemonic, raw_trace=None, ntrains=200, params=None, Parameters ---------- - run: extra_data.DataCollection - DataCollection containing the digitizer data. + proposal: int + the proposal number + runNB: int + the run number mnemonic: str ToolBox mnemonic of the digitizer data, e.g. 'MCP2apd'. raw_trace: optional, 1D numpy array or xarray DataArray @@ -699,121 +876,120 @@ def check_peak_params(run, mnemonic, raw_trace=None, ntrains=200, params=None, ------- dictionnary of peak integration parameters """ + run = open_run(proposal, runNB) run_mnemonics = mnemonics_for_run(run) - if "raw" in mnemonic: - mnemo_raw = mnemonic - title = 'Auto-find peak params' - else: - mnemo_raw = mnemonic.replace('peaks', 'raw').replace('apd', 'raw') - title = 'Digitizer peak params' - if raw_trace is None: - raw_trace = get_dig_avg_trace(run, mnemonic, ntrains) - if params is None: - params = get_peak_params(run, mnemonic, raw_trace) - if 'enable' in params and params['enable'] == 0: - log.warning('The digitizer did not record peak-integrated data.') - if not plot: - return params digitizer = digitizer_type(run, run_mnemonics[mnemonic]['source']) - min_distance = 24 if digitizer == "FastADC" else 440 - if 'bunchPatternTable' in run_mnemonics and bunchPattern != 'None': - sel = run.select_trains(np.s_[:ntrains]) - bp_params = {} - m = run_mnemonics['bunchPatternTable'] - bpt = sel.get_array(m['source'], m['key'], m['dim']) - mask = is_pulse_at(bpt, bunchPattern) - pid = np.sort(np.unique(np.where(mask)[1])) - bp_params['npulses'] = len(pid) - if bp_params['npulses'] == 1: - bp_params['period'] = 0 + pulse_period = 24 if digitizer == "FastADC" else 440 + pattern = None + try: + if bunchPattern == 'sase3': + pattern = XrayPulses(run) + if bunchPattern == 'scs_ppl': + pattern = OpticalLaserPulses(run) + except Exception as e: + log.warning(e) + bunchPattern = None + params, pulse_ids, regular, raw_trace = find_peak_integration_parameters( + run, mnemonic, raw_trace, integParams, pattern) + if bunchPattern: + if regular: + add_text = '' + if len(pulse_ids) > 1: + add_text = f's, {(pulse_ids[1]-pulse_ids[0]) * pulse_period}' +\ + ' samples between two pulses' + print(f'Bunch pattern {bunchPattern}: {len(pulse_ids)} pulse' + + add_text) else: - bp_params['period'] = np.diff(pid)[0] * min_distance - print(f'bunch pattern {bunchPattern}: {bp_params["npulses"]} pulses,' - f' {bp_params["period"]} samples between two pulses') + print(f'Bunch pattern {bunchPattern}: Not a regular pattern. ' + f'{len(pulse_ids)} pulses, pulse_ids=' + f'{pulse_ids}.') + if integParams is None: + title = 'Auto-find peak parameters' else: - bp_params = None - print(f'{title}: {params["npulses"]} pulses, {params["period"]}' - ' samples between two pulses') - fig, ax = plotPeakIntegrationWindow(raw_trace, params, bp_params, show_all) - ax[0].set_ylabel(mnemo_raw) - fig.suptitle(title, size=12) + title = 'Provided peak parameters' + no_change = True + for k, v in integParams.items(): + no_change = no_change & (v == params[k]) + if hasattr(no_change, '__len__'): + no_change = no_change.all() + if no_change == False: + print('The provided parameters did not match the bunch ' + 'pattern and were adjusted.') + title += ' (adjusted)' + if regular: + add_text = '' + if params['npulses'] > 1: + add_text = f's, {params["period"]} samples between two pulses' + print(title + ': ' + f' {params["npulses"]} pulse' + add_text) + else: + print(f'{title}: Not a regular pattern. ' + f'{len(pulse_ids)} pulses, pulse_ids=' + f'{pulse_ids}.') + if plot: + if raw_trace is None: + raw_trace = get_dig_avg_trace(run, mnemonic) + fig, ax = plotPeakIntegrationWindow(raw_trace, params, + show_all=show_all) + fig.suptitle(f'p{proposal} r{runNB} '+ title, size=12) return params -def plotPeakIntegrationWindow(raw_trace, params, bp_params=None, - show_all=False): - if show_all: - fig, ax = plt.subplots(figsize=(6, 3), constrained_layout=True) +def plotPeakIntegrationWindow(raw_trace, params, show_all=False): + if hasattr(params['pulseStart'], '__len__'): + starts = np.array(params['pulseStart']) + stops = params['pulseStop'] + starts - params['pulseStart'][0] + baseStarts = params['baseStart'] + starts - params['pulseStart'][0] + baseStops = params['baseStop'] + starts - params['pulseStart'][0] + else: n = params['npulses'] p = params['period'] - for i in range(n): + starts = [params['pulseStart'] + i*p for i in range(n)] + stops = [params['pulseStop'] + i*p for i in range(n)] + baseStarts = [params['baseStart'] + i*p for i in range(n)] + baseStops = [params['baseStop'] + i*p for i in range(n)] + + if show_all: + fig, ax = plt.subplots(figsize=(6, 3), constrained_layout=True) + for i in range(len(starts)): lbl = 'baseline' if i == 0 else None lp = 'peak' if i == 0 else None - ax.axvline(params['baseStart'] + i*p, ls='--', color='k') - ax.axvline(params['baseStop'] + i*p, ls='--', color='k') - ax.axvspan(params['baseStart'] + i*p, params['baseStop'] + i*p, + ax.axvline(baseStarts[i], ls='--', color='k') + ax.axvline(baseStops[i], ls='--', color='k') + ax.axvspan(baseStarts[i], baseStops[i], alpha=0.5, color='grey', label=lbl) - ax.axvline(params['pulseStart'] + i*p, ls='--', color='r') - ax.axvline(params['pulseStop'] + i*p, ls='--', color='r') - ax.axvspan(params['pulseStart'] + i*p, params['pulseStop'] + i*p, + ax.axvline(starts[i], ls='--', color='r') + ax.axvline(stops[i], ls='--', color='r') + ax.axvspan(starts[i], stops[i], alpha=0.2, color='r', label=lp) ax.plot(raw_trace, color='C0', label='raw trace') ax.legend(fontsize=8) return fig, [ax] - if bp_params is not None: - npulses = bp_params['npulses'] - period = bp_params['period'] - else: - npulses = params['npulses'] - period = params['period'] - xmin = np.max([0, params['baseStart']-100]) - xmax = np.min([params['pulseStop']+100, raw_trace.size]) fig, ax = plt.subplots(1, 2, figsize=(6, 3), constrained_layout=True) - ax[0].axvline(params['baseStart'], ls='--', color='k') - ax[0].axvline(params['baseStop'], ls='--', color='k') - ax[0].axvspan(params['baseStart'], params['baseStop'], - alpha=0.5, color='grey', label='baseline') - ax[0].axvline(params['pulseStart'], ls='--', color='r') - ax[0].axvline(params['pulseStop'], ls='--', color='r') - ax[0].axvspan(params['pulseStart'], params['pulseStop'], - alpha=0.2, color='r', label='peak') - ax[0].plot(np.arange(xmin, xmax), raw_trace[xmin:xmax], color='C0', - label='1st pulse') - ax[0].legend(fontsize=8) - ax[0].set_xlim(xmin, xmax) - ax[0].set_xlabel('digitizer samples') - ax[0].set_title('First pulse', size=10) - - xmin2 = xmin + (npulses-1) * period - xmax2 = xmax + (npulses-1) * period - p = params['period'] - lbl = 'baseline' - lp = 'peak' - for i in range(params['npulses']): - mi = params['baseStart'] + i*p - if not xmin2 < mi < xmax2: - continue - ax[1].axvline(params['baseStart'] + i*p, ls='--', color='k') - ax[1].axvline(params['baseStop'] + i*p, ls='--', color='k') - ax[1].axvspan(params['baseStart'] + i*p, params['baseStop'] + i*p, - alpha=0.5, color='grey', label=lbl) - ax[1].axvline(params['pulseStart'] + i*p, ls='--', color='r') - ax[1].axvline(params['pulseStop'] + i*p, ls='--', color='r') - ax[1].axvspan(params['pulseStart'] + i*p, params['pulseStop'] + i*p, - alpha=0.2, color='r', label=lp) - lbl = None - lp = None - if xmax2 < raw_trace.size: - ax[1].plot(np.arange(xmin2, xmax2), raw_trace[xmin2:xmax2], color='C0', - label='last pulse') - else: - log.warning('The digitizer raw trace is too short to contain ' + - 'all the pulses.') - ax[1].legend(fontsize=8) - ax[1].set_xlabel('digitizer samples') - ax[1].set_xlim(xmin2, xmax2) - ax[1].set_title('Last pulse', size=10) + for plot in range(2): + title = 'First pulse' if plot == 0 else 'Last pulse' + i = 0 if plot == 0 else -1 + ax[plot].axvline(baseStarts[i], ls='--', color='k') + ax[plot].axvline(baseStops[i], ls='--', color='k') + ax[plot].axvspan(baseStarts[i], baseStops[i], + alpha=0.5, color='grey', label='baseline') + ax[plot].axvline(starts[i], ls='--', color='r') + ax[plot].axvline(stops[i], ls='--', color='r') + ax[plot].axvspan(starts[i], stops[i], + alpha=0.2, color='r', label='peak') + if len(starts) > 1: + period = starts[1] - starts[0] + xmin = np.max([0, baseStarts[i] - int(1.5*period)]) + xmax = np.min([stops[i] + int(1.5*period), raw_trace.size]) + else: + xmin = np.max([0, baseStarts[i] - 200]) + xmax = np.min([stops[i] + 200, raw_trace.size]) + ax[plot].plot(np.arange(xmin, xmax), + raw_trace[xmin:xmax], color='C0', label=title) + ax[plot].legend(fontsize=8) + ax[plot].set_xlim(xmin, xmax) + ax[plot].set_xlabel('digitizer samples') + ax[plot].set_title(title, size=10) return fig, ax