diff --git a/pes_to_spec/test/prepare_plots.py b/pes_to_spec/test/prepare_plots.py
index 12ca07cd4d98a448665ee6a1d876d0adc2d294ed..75ba648acf981744ba1d4e7033d07e6cde185cb5 100755
--- a/pes_to_spec/test/prepare_plots.py
+++ b/pes_to_spec/test/prepare_plots.py
@@ -3,6 +3,8 @@
 import os
 import re
 
+from typing import Optional, Tuple, Dict
+
 import matplotlib
 matplotlib.use('Agg')
 import pandas as pd
@@ -238,7 +240,7 @@ def plot_wiener(df: pd.DataFrame, filename: str):
     fig.savefig(filename)
     plt.close(fig)
 
-def plot_pes(df: pd.DataFrame, channel:str, filename: str):
+def plot_pes(df: pd.DataFrame, channel: Dict[str, int], filename: str, fast_range: Optional[Tuple[int, int]]=None, Ne1s: Optional[Tuple[int, int]]=None, label: Optional[Dict[str, str]]=None, refs: Optional[Dict[str, Dict[int, float]]]=None, counts_to_mv: Optional[float]=None):
     """
     Plot low-resolution spectrum.
 
@@ -255,31 +257,92 @@ def plot_pes(df: pd.DataFrame, channel:str, filename: str):
     last = last-270
     print("Range:", first, last)
     sel = (df.bin >= first) & (df.bin < last)
-    x = df.loc[sel, "bin"]
-    if channel == "sum":
-        y = df.loc[sel, [k for k in df.columns if "channel_" in k]].sum(axis=1)
-        ax.plot(x, y, c='b', lw=5)
-    elif isinstance(channel, list):
-        for ch in channel:
-            sch = ch.replace('_', ' ')
-            y = df.loc[sel, ch]
-            ax.plot(x, y, lw=5, label=sch)
-    else:
-        y = df.loc[sel, channel]
-        ax.plot(x, y, c='b', lw=5)
-    ax.legend(frameon=False)
+    x = df.loc[sel, "bin"].to_numpy()
+    col = dict()
+    colors = ["tab:red", "tab:blue"]
+    p = list()
+    # plot each channel
+    for ich, ch in enumerate(channel.keys()):
+        if label is None:
+            sch = ch.replace('_', '')[-2:]
+        else:
+            sch = label[ch]
+        y = df.loc[sel, ch].to_numpy().astype(np.float32)
+        if counts_to_mv is not None:
+            y *= counts_to_mv
+        c = colors[ich]
+        col[ch] = c
+        p += [ax.plot(x, y, lw=2, c=c, label=sch)]
     ax.set(title=f"",
-           xlabel="Time-of-flight index",
-           ylabel="Counts [a.u.]")
+           ylim=(0, None),
+           #xlabel="Time-of-flight index",
+           xlabel="Samples",
+           ylabel="Counts [a.u.]" if counts_to_mv is None else "Digitizer reading [mV]")
     ax.spines['top'].set_visible(False)
     ax.spines['right'].set_visible(False)
+    minY, maxY = ax.get_ylim()
+    # show reference energy lines
+    if refs is not None:
+        for ich, ch in enumerate(channel.keys()):
+            for tof, energy in refs[ch].items():
+                ax.axvline(tof, 0, 0.5 + ich*0.17, ls='-.', lw=1, c=col[ch])
+                ax.text(tof-1, (0.51 + ich*0.18)*maxY, f"{energy} eV", fontsize=14, rotation="vertical", color=col[ch])
+    # show prompt line
+    for ch, prompt in channel.items():
+        ax.axvline(x=prompt, ls='--', lw=1, c=col[ch])
+        ax.text(prompt-3, 0.5*maxY, "Prompt", fontsize=16, rotation="vertical", color=col[ch])
+    # show the fast electrons range
+    if fast_range is not None:
+        x1, x2 = fast_range
+        xtext = int(x1 + (x2 - x1)*0.3)
+        ytext = 0.9*maxY
+        ax.fill_between([x1, x2], minY, maxY, alpha=0.2, facecolor="tab:olive")
+        ax.text(xtext, ytext, "Valence", fontsize=18, fontweight='bold')
+        ax.text(xtext, ytext-0.05*maxY, "Auger", fontsize=18, fontweight='bold')
+    # show the Ne 1s range
+    if Ne1s is not None:
+        x1, x2 = Ne1s
+        xtext = int(x1 + (x2 - x1)*0.3)
+        ytext = 0.9*maxY
+        ax.fill_between([x1, x2], minY, maxY, alpha=0.2, facecolor="tab:cyan")
+        ax.text(xtext, ytext, "Ne 1s", fontsize=22, fontweight='bold')
+    ns_per_sample = 0.5
+    cax = dict()
+    def f_(ch):
+        return (lambda kk: (np.array(kk) - int(channel[ch]))*ns_per_sample)
+    def i_(ch):
+        return (lambda kk: np.array(kk)/ns_per_sample + int(channel[ch]))
+    forward_ = {ch: f_(ch) for ch in channel}
+    inverse_ = {ch: i_(ch) for ch in channel}
+    for ich, (ch, prompt) in enumerate(channel.items()):
+        cax[ch] = ax.secondary_xaxis(1.0+0.07*ich, functions=(forward_[ch], inverse_[ch]))
+        #cax[ch].spines['left'].set_visible(False)
+        cax[ch].spines['top'].set_position(('outward', 10))
+        cax[ch].spines['top'].set_color(col[ch])
+        cax[ch].tick_params(axis='x', colors=col[ch], labelsize=16)
+        if ich == len(channel)-1:
+            cax[ch].set_xlabel('Time-of-flight [ns]', fontsize=16)
+            #cax[ch].xaxis.label.set_color(col[ch])
+            #cax[ch].title.set_color(col[ch])
+    ax.legend(frameon=False, loc='center')
     plt.tight_layout()
     fig.savefig(filename)
     plt.close(fig)
 
 if __name__ == '__main__':
     indir = 'p900331r69t70'
-    channel = ['channel_1_A', 'channel_4_A', 'channel_3_B']
+    channel = {'channel_4_A': 2639,
+               'channel_3_B': 2646,
+              }
+    label = {'channel_4_A': r'22.5$^\circ$',
+               'channel_3_B': r'225$^\circ$',
+              }
+    Ne1s = (2710, 2742)
+    fast_range = (2650, 2670)
+    refs={'channel_4_A': {2716:1002.5, 2722:997.5},
+          'channel_3_B': {2723:1002.5, 2729:997.5}
+          }
+    counts_to_mv = 40.0/100.0
     #channel = 'sum'
     #for fname in os.listdir(indir):
     #    if re.match(r'test_q100_[0-9]*\.csv', fname):
@@ -290,7 +353,9 @@ if __name__ == '__main__':
 
     for fname in ('test_q100_1724098413', 'test_q100_1724098596', 'test_q50_1724099445'):
         plot_final(pd.read_csv(f'{indir}/{fname}.csv'), f'{fname}.pdf')
-        plot_pes(pd.read_csv(f'{indir}/{fname}_pes.csv'), channel, f'{fname}_pes.pdf')
+        plot_pes(pd.read_csv(f'{indir}/{fname}_pes.csv'), channel, f'{fname}_pes.pdf',
+                 fast_range=fast_range, Ne1s=Ne1s, label=label, refs=refs,
+                 counts_to_mv=counts_to_mv)
 
     plot_chi2(pd.read_csv(f'{indir}/quality.csv'), f'chi2_prepca.pdf')
     plot_chi2_intensity(pd.read_csv(f'{indir}/quality.csv'), f'intensity_vs_chi2_prepca.pdf')