From e2ec1ca72ddeb50e6a9dc6e2bc1308566e495e7e Mon Sep 17 00:00:00 2001
From: Egor Sobolev <egor.sobolev@xfel.eu>
Date: Sat, 17 Jun 2023 00:07:48 +0200
Subject: [PATCH] Use the minimal inter-Bragg distance to look for reflexes
 mathching peaks

---
 src/geomtools/sfx/crystfelio.py | 57 +++++++++++++++++++++++----------
 src/geomtools/sfx/draw.py       |  8 ++---
 src/geomtools/sfx/misc.py       |  2 +-
 src/geomtools/sfx/report.ipynb  |  4 +--
 4 files changed, 47 insertions(+), 24 deletions(-)

diff --git a/src/geomtools/sfx/crystfelio.py b/src/geomtools/sfx/crystfelio.py
index 54a8613..a635b73 100644
--- a/src/geomtools/sfx/crystfelio.py
+++ b/src/geomtools/sfx/crystfelio.py
@@ -5,7 +5,7 @@ import numpy as np
 import pandas as pd
 import multiprocessing as mp
 
-from .lattice import spacing
+from .lattice import spacing, ph_en_to_lambda, get_min_bragg_dist
 from .misc import get_peak_position
 from cfelpyutils.crystfel_stream import (
     parse_chunk, CHUNK_START_MARKER, CHUNK_END_MARKER
@@ -42,14 +42,12 @@ def _buffer_to_line(f, marker):
     return s
 
 
-def match_reflexes_to_peaks(reflexes, peaks, lattices, panels, rmax):
-    
+def match_reflexes_to_peaks(reflexes, peaks, lattices, panels):
+    """Looks for matches of reflexes to peaks"""
     # 1. distance
     # 2. one reflex to one peak
     # 3. one peak to one reflex in the same crystal
-    #    but to can also match refelex in other crystals    
-    r2max = rmax * rmax
-
+    #    but to can also match refelex in other crystals
     re = reflexes[['fs', 'ss', 'panel', 'frame', 'cryst']].copy()
     re['reflno'] = re.index
     re = re.join(get_peak_position(re, panels))
@@ -61,13 +59,14 @@ def match_reflexes_to_peaks(reflexes, peaks, lattices, panels, rmax):
     match = pe.join(re.set_index(['frame', 'panel']), on=['frame', 'panel'],
                     lsuffix='_p', rsuffix='_r', how='inner')
 
-    match = match.join(lattices[['xc', 'yc']], on='cryst')
+    match = match.join(lattices[['xc', 'yc', 'rmin']], on='cryst')
 
     match['dx'] = match.x_p - match.x_r - match.xc
     match['dy'] = match.y_p - match.y_r - match.yc
 
     match['r2'] = match.dx * match.dx + match.dy * match.dy
 
+    r2max = 0.25 * match.rmin * match.rmin
     match = match[match.r2 < r2max]
     match = match.loc[match.groupby(['frame', 'cryst', 'reflno']).r2.idxmin()]
 
@@ -85,9 +84,13 @@ def parse_crystfel_streamfile(stream_filename, panels, begin=0, end=None):
     peaks = []
     lattices = []
     reflexes = []
+    frames = []
     frame_ix = 0
     cryst_ix = 0
 
+    clen_avg = panels.clen.mean()
+    clen_str = f"{clen_avg} m"
+    res_avg = panels.res.mean()
     with open(stream_filename, 'r') as f:
         if end is None:
             end = f.seek(0, 2)
@@ -101,13 +104,23 @@ def parse_crystfel_streamfile(stream_filename, panels, begin=0, end=None):
 
             chunk = parse_chunk(buffer, peak_tbl=True, refl_tbl=True)
 
+            # camera length
+            ln = chunk.get("average_camera_length", clen_str)
+            tk = ln.partition(' ')
+            clen = float(tk[0]) / (1000 if tk[2] == 'mm' else 1)
+            chunk["average_camera_length"] = clen
+
+            # photon energy
+            lmd = ph_en_to_lambda(float(chunk["photon_energy_eV"]))
+
             # append peaks
-            pe = chunk['peaks']
-            pe['frame'] = frame_ix
-            peaks.append(pe)
+            pe = chunk.pop('peaks')
+            if pe is not None:
+                pe['frame'] = frame_ix
+                peaks.append(pe)
 
             # append lattice
-            for crystal in chunk['crystals']:
+            for crystal in chunk.pop('crystals', []):
                 la = {'frame': frame_ix}
                 la.update(dict(
                     [(param, crystal[param]) for param in lattice_params]))
@@ -115,18 +128,25 @@ def parse_crystfel_streamfile(stream_filename, panels, begin=0, end=None):
                     la.update(dict(zip(col_names, crystal[arr_name].tolist())))
                 for param in ['a', 'b', 'c']:
                     la[param] *= 10
+
+                cell = [la[param] * 1e-10 for param in ['a', 'b', 'c']]
+                la['rmin'] = get_min_bragg_dist(
+                    1. / res_avg, clen_avg, lmd, cell)
+
                 lattices.append(la)
                 la_kwargs = dict((name, la[name]) for name in cell_columns)
 
                 re = crystal['reflections']
                 re['res'] = np.sqrt(spacing(
                     re.h.values, re.k.values, re.l.values, **la_kwargs))
+
                 re['frame'] = frame_ix
                 re['cryst'] = cryst_ix
                 reflexes.append(re)
 
                 cryst_ix += 1
 
+            frames.append(chunk)
             frame_ix += 1
             _read_to_line(f, CHUNK_START_MARKER, end)
 
@@ -139,11 +159,11 @@ def parse_crystfel_streamfile(stream_filename, panels, begin=0, end=None):
     reflexes['reflno'] = reflexes.index
 
     lattices = pd.DataFrame(lattices)
+    frames = pd.DataFrame(frames)
 
-    match = match_reflexes_to_peaks(
-        reflexes, peaks, lattices, panels, 8)
+    match = match_reflexes_to_peaks(reflexes, peaks, lattices, panels)
 
-    return peaks, lattices, reflexes, match, frame_ix
+    return frames, peaks, lattices, reflexes, match
 
 
 def extract_geometry(stream_filename):
@@ -227,13 +247,14 @@ def read_crystfel_streamfile(stream_filename, panels, disp=False):
         peak = []
         reflex = []
         lattice = []
+        frame = []
         match = []
 
         nframe = 0
         ncryst = 0
         npeak = 0
         nrefl = 0
-        for i, (pe, la, re, ma, n) in enumerate(res):
+        for i, (fr, pe, la, re, ma) in enumerate(res):
             pe.frame += nframe
             pe.peakno += npeak
             peak.append(pe)
@@ -250,8 +271,9 @@ def read_crystfel_streamfile(stream_filename, panels, disp=False):
             match.append(ma)
 
             lattice.append(la)
+            frame.append(fr)
 
-            nframe += n
+            nframe += len(fr)
             ncryst += len(la)
             npeak += len(pe)
             nrefl += len(re)
@@ -272,8 +294,9 @@ def read_crystfel_streamfile(stream_filename, panels, disp=False):
         bar = '=' * 40
         print(f"[{bar:<40s}] 100.0% elapsed:{elapsed:6.0f}s, remained:     0s")
 
+    frame = pd.concat(frame, ignore_index=True)
     peak = pd.concat(peak, ignore_index=True)
     lattice = pd.concat(lattice, ignore_index=True)
     reflex = pd.concat(reflex, ignore_index=True)
     match = pd.concat(match, ignore_index=True)
-    return peak, lattice, reflex, match, nframe
+    return frame, peak, lattice, reflex, match
diff --git a/src/geomtools/sfx/draw.py b/src/geomtools/sfx/draw.py
index 632936c..9d67ac9 100644
--- a/src/geomtools/sfx/draw.py
+++ b/src/geomtools/sfx/draw.py
@@ -5,7 +5,7 @@ from matplotlib import patches as mpatch
 from mpl_toolkits.axes_grid1 import make_axes_locatable
 from matplotlib.colors import LogNorm
 
-from .misc import gauss2d_fit, ellipse, avg_pixel_displacement
+from .misc import gauss2d_fit, ellipse
 from ..detector import plot_data_on_detector, get_pixel_positions
 
 
@@ -129,16 +129,16 @@ def plot_powder(peaks, counts=True, figwidth=14, frameon=False, **kwargs):
     return fig, ax
 
 
-def plot_geoptimiser_errormap(pxdispl, panels, figwidth=16, **kwargs):
+def plot_geoptimiser_errormap(pxdispl, panels, figwidth=14, **kwargs):
     if isinstance(panels, np.ndarray):
         pos = panels
     else:
         pos = get_pixel_positions(panels)
-        
+
     shape = pos.shape[:-1]
     data = np.zeros(shape, float)
     data[pxdispl.modno, pxdispl.ssi, pxdispl.fsi] = pxdispl.r_avg
-    
+
     if 'cmap' not in kwargs:
         kwargs['cmap'] = plt.cm.magma
 
diff --git a/src/geomtools/sfx/misc.py b/src/geomtools/sfx/misc.py
index 7b2cc2b..8d5e9c3 100644
--- a/src/geomtools/sfx/misc.py
+++ b/src/geomtools/sfx/misc.py
@@ -65,7 +65,7 @@ def avg_pixel_displacement(match, panels):
     pxdispl['r_avg'] = np.sqrt(pxdispl.dx_avg * pxdispl.dx_avg +
                                pxdispl.dy_avg * pxdispl.dy_avg)
     pxdispl = pxdispl.join(panels[['modno']], on='panel')
-    
+
     return pxdispl.reset_index()
 
 
diff --git a/src/geomtools/sfx/report.ipynb b/src/geomtools/sfx/report.ipynb
index c26b95b..9724be3 100644
--- a/src/geomtools/sfx/report.ipynb
+++ b/src/geomtools/sfx/report.ipynb
@@ -57,7 +57,7 @@
     "if len(clen) == 1:\n",
     "    clen = clen[0]\n",
     "\n",
-    "pe, la, re, ma, nfrm = read_crystfel_streamfile(stream_file, panels, disp=False)"
+    "fr, pe, la, re, ma = read_crystfel_streamfile(stream_file, panels, disp=False)"
    ]
   },
   {
@@ -84,7 +84,7 @@
     "print(summary)\n",
     "print()\n",
     "print(\"From Crystfel stream file\")\n",
-    "print(f\"read: {nfrm:9d} chunks, {len(la):9d} crystalls\")"
+    "print(f\"read: {len(fr):9d} chunks, {len(la):9d} crystalls\")"
    ]
   },
   {
-- 
GitLab