From 157fe57d4921ce39d88236eec9813d77cf64fb7d Mon Sep 17 00:00:00 2001
From: Egor Sobolev <egor.sobolev@xfel.eu>
Date: Sat, 17 Jun 2023 03:45:24 +0200
Subject: [PATCH] Look for the closest reflex over connected group

---
 src/geomtools/sfx/crystfelio.py | 41 ++++++++++++++++++++++-----------
 src/geomtools/sfx/misc.py       |  6 ++++-
 src/geomtools/sfx/report.ipynb  | 11 +++++++--
 src/geomtools/sfx/report.py     |  4 ++++
 4 files changed, 45 insertions(+), 17 deletions(-)

diff --git a/src/geomtools/sfx/crystfelio.py b/src/geomtools/sfx/crystfelio.py
index a635b73..2adc6fe 100644
--- a/src/geomtools/sfx/crystfelio.py
+++ b/src/geomtools/sfx/crystfelio.py
@@ -42,22 +42,27 @@ def _buffer_to_line(f, marker):
     return s
 
 
-def match_reflexes_to_peaks(reflexes, peaks, lattices, panels):
+def match_reflexes_to_peaks(reflexes, peaks, lattices,
+                            panels, group_name='panel'):
     """Looks for matches of reflexes to peaks"""
     # 1. distance
     # 2. one reflex to one peak
     # 3. one peak to one reflex in the same crystal
     #    but to can also match refelex in other crystals
-    re = reflexes[['fs', 'ss', 'panel', 'frame', 'cryst']].copy()
-    re['reflno'] = re.index
+    re = reflexes[['fs', 'ss', 'panel', 'frame', 'cryst', 'reflno']].copy()
+    if group_name != 'panel':
+        re = re.join(panels[[group_name]], on='panel')
     re = re.join(get_peak_position(re, panels))
 
-    pe = peaks[['fs', 'ss', 'panel', 'frame']].copy()
-    pe['peakno'] = pe.index
+    pe = peaks[['fs', 'ss', 'panel', 'frame', 'peakno']]
+    if group_name != 'panel':
+        pe = pe.join(panels[[group_name]], on='panel')
     pe = pe.join(get_peak_position(pe, panels))
 
-    match = pe.join(re.set_index(['frame', 'panel']), on=['frame', 'panel'],
-                    lsuffix='_p', rsuffix='_r', how='inner')
+    match = pe.join(re.set_index(['frame', group_name]),
+                    on=['frame', group_name],
+                    lsuffix='_p', rsuffix='_r',
+                    how='inner')
 
     match = match.join(lattices[['xc', 'yc', 'rmin']], on='cryst')
 
@@ -68,12 +73,16 @@ def match_reflexes_to_peaks(reflexes, peaks, lattices, panels):
 
     r2max = 0.25 * match.rmin * match.rmin
     match = match[match.r2 < r2max]
-    match = match.loc[match.groupby(['frame', 'cryst', 'reflno']).r2.idxmin()]
+
+    # the selection below applies condition 3, remove it to be like geoptimiser
+    # match = match.loc[match[['reflno', 'r2']].groupby('reflno').r2.idxmin()]
+    match = match.sort_values("r2").groupby("reflno", as_index=False).first()
 
     return match
 
 
-def parse_crystfel_streamfile(stream_filename, panels, begin=0, end=None):
+def parse_crystfel_streamfile(stream_filename, panels, connected_groups,
+                              begin=0, end=None):
     lattice_params = ['lattice_type', 'centering', 'unique_axis']
     lattice_arrays = {
         'Cell parameters/lengths': ['a', 'b', 'c'],
@@ -161,7 +170,8 @@ def parse_crystfel_streamfile(stream_filename, panels, begin=0, end=None):
     lattices = pd.DataFrame(lattices)
     frames = pd.DataFrame(frames)
 
-    match = match_reflexes_to_peaks(reflexes, peaks, lattices, panels)
+    match = match_reflexes_to_peaks(reflexes, peaks, lattices,
+                                    panels, connected_groups)
 
     return frames, peaks, lattices, reflexes, match
 
@@ -208,17 +218,20 @@ def _split_file(filename, nproc=20, partsize=None, nbytes=None):
 
 
 class StreamReader:
-    def __init__(self, filename, panels):
+    def __init__(self, filename, panels, connected_groups):
         self.filename = filename
         self.panels = panels
+        self.connected_groups = connected_groups
 
     def read_part(self, args):
         begin, end = args
         return parse_crystfel_streamfile(
-            self.filename, self.panels, begin=begin, end=end)
+            self.filename, self.panels, self.connected_groups,
+            begin=begin, end=end)
 
 
-def read_crystfel_streamfile(stream_filename, panels, disp=False):
+def read_crystfel_streamfile(stream_filename, panels,
+                             connected_groups, disp=False):
     """Read Crystfel stream file in parallel.
 
     Input
@@ -240,7 +253,7 @@ def read_crystfel_streamfile(stream_filename, panels, disp=False):
     if disp:
         print(f"nproc: {nproc} nslice: {nslice}")
 
-    rdr = StreamReader(stream_filename, panels)
+    rdr = StreamReader(stream_filename, panels, connected_groups)
     with mp.Pool(min(nproc, nslice)) as pool:
         res = pool.imap(rdr.read_part, slices)
 
diff --git a/src/geomtools/sfx/misc.py b/src/geomtools/sfx/misc.py
index 8d5e9c3..39f968f 100644
--- a/src/geomtools/sfx/misc.py
+++ b/src/geomtools/sfx/misc.py
@@ -53,7 +53,11 @@ def get_peak_position(peaks, panels):
 
 
 def avg_pixel_displacement(match, panels):
-    pxdispl = match[['dx', 'dy', 'peakno', 'panel']].copy()
+    if 'panel_p' in match.columns:
+        pxdispl = match[['dx', 'dy', 'peakno', 'panel_p']].copy()
+        pxdispl = pxdispl.rename(columns={'panel_p': 'panel'})
+    else:
+        pxdispl = match[['dx', 'dy', 'peakno', 'panel']].copy()
     pxdispl['fsi'] = np.floor(match['fs_p']).astype(int)
     pxdispl['ssi'] = np.floor(match['ss_p']).astype(int)
 
diff --git a/src/geomtools/sfx/report.ipynb b/src/geomtools/sfx/report.ipynb
index 9724be3..078cbe7 100644
--- a/src/geomtools/sfx/report.ipynb
+++ b/src/geomtools/sfx/report.ipynb
@@ -16,7 +16,8 @@
     "summary_file = \"\"\n",
     "prefix=prefix = \"\"\n",
     "output_dir = \"\"\n",
-    "geometry_file = \"\""
+    "geometry_file = \"\"\n",
+    "connected_groups = \"\""
    ]
   },
   {
@@ -52,12 +53,18 @@
     "\n",
     "panels, beam = read_crystfel_geom(geometry_file, indexes={'modno': 1})\n",
     "\n",
+    "panel_columns = panels.columns.tolist() + ['panel']\n",
+    "if connected_groups not in panel_columns:\n",
+    "    raise ValueError(\n",
+    "        f\"Connected groups '{connected_groups}' are not defined\")\n",
+    "\n",
     "photon_energy = beam['photon_energy']\n",
     "clen = panels.clen.unique()\n",
     "if len(clen) == 1:\n",
     "    clen = clen[0]\n",
     "\n",
-    "fr, pe, la, re, ma = read_crystfel_streamfile(stream_file, panels, disp=False)"
+    "fr, pe, la, re, ma = read_crystfel_streamfile(\n",
+    "    stream_file, panels, connected_groups, disp=False)"
    ]
   },
   {
diff --git a/src/geomtools/sfx/report.py b/src/geomtools/sfx/report.py
index 14e35bb..3e8a35d 100644
--- a/src/geomtools/sfx/report.py
+++ b/src/geomtools/sfx/report.py
@@ -53,6 +53,7 @@ def push_geometry():
                         help="EXtra-xwiz summary file")
     parser.add_argument('-i', '--stream', type=pathlib.Path,
                         help="Crystfel stream file")
+    parser.add_argument('-c', '--connected', default="modules")
     parser.add_argument('-r', '--report-only', action="store_true")
     parser.add_argument('detector', choices=DET.keys())
     parser.add_argument('tag')
@@ -105,6 +106,7 @@ def push_geometry():
     date = dc.run_metadata()['creationDate']
     day, _, tm = date.partition('T')
     print("Data: ", day)
+    print("Connected groups:", args.connected)
 
     geom_id = '_'.join([args.detector, day, args.tag])
     if not args.report_only:
@@ -130,6 +132,7 @@ def push_geometry():
             'runs': runs,
             'method': 'sfx',
             'tool': 'geoptimiser',
+            'connected_groups': args.connected,
             'sample': args.sample,
             'date': date,
             'motors': motor_pos,
@@ -158,6 +161,7 @@ def push_geometry():
         stream_file=str(stream_file.absolute()),
         summary_file=str(summary_file.absolute()),
         geometry_file=str(geom_file.absolute()),
+        connected_groups=args.connected,
         prefix=prefix,
         output_dir=str(output_dir.absolute()),
     )
-- 
GitLab