From fa2a193b0d55f0d7736e1d4992d12d10b6879946 Mon Sep 17 00:00:00 2001
From: Danilo Enoque Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.eu>
Date: Tue, 20 Sep 2022 13:16:09 +0200
Subject: [PATCH] Adapt load_crystfel_geometry to read from a StringIO object
 if it is given.

---
 cfelpyutils/geometry/crystfel_utils.py | 432 +++++++++++++------------
 1 file changed, 225 insertions(+), 207 deletions(-)

diff --git a/cfelpyutils/geometry/crystfel_utils.py b/cfelpyutils/geometry/crystfel_utils.py
index eb8115e..cf6561c 100644
--- a/cfelpyutils/geometry/crystfel_utils.py
+++ b/cfelpyutils/geometry/crystfel_utils.py
@@ -25,6 +25,8 @@ import math
 import re
 import sys
 from typing import Dict, List, NamedTuple, Tuple, Union
+from io import StringIO, TextIOWrapper
+from os import PathLike
 
 from mypy_extensions import TypedDict
 
@@ -473,7 +475,7 @@ def _find_min_max_d(detector):
         )
 
 
-def load_crystfel_geometry(filename: str) -> CrystFELGeometry:
+def load_crystfel_geometry(file: Union[str, PathLike, StringIO, TextIOWrapper]) -> CrystFELGeometry:
     """
     Loads a CrystFEL geometry file.
 
@@ -492,7 +494,9 @@ def load_crystfel_geometry(filename: str) -> CrystFELGeometry:
 
     Arguments:
 
-        filename (str): the absolute or relative path to a CrystFEL geometry file.
+        file (Union[str, PathLike, StringIO, TextIOWrapper]):
+                                     Either the path to a CrystFEL
+                                     geometry file, or a text file handler object.
 
     Returns:
 
@@ -785,236 +789,250 @@ def load_crystfel_geometry(filename: str) -> CrystFELGeometry:
     default_dim = ["ss", "fs"]  # type: List[Union[int, str, None]]
     hdf5_peak_path = ""  # type: str
     try:
-        with open(filename, mode="r") as file_handle:
-            file_lines = file_handle.readlines()  # type: List[str]
-            for line in file_lines:
-                if line.startswith(";"):
-                    continue
-                line_without_comments = line.strip().split(";")[0]  # type: str
-                line_items = re.split(
-                    pattern="([ \t])", string=line_without_comments
-                )  # type: List[str]
-                line_items = [
-                    item for item in line_items if item not in ("", " ", "\t")
-                ]
-                if len(line_items) < 3:
-                    continue
-                value = "".join(line_items[2:])  # type: str
-                if line_items[1] != "=":
-                    continue
-                path = re.split("(/)", line_items[0])  # type: List[str]
-                path = [item for item in path if item not in "/"]
-                if len(path) < 2:
-                    hdf5_peak_path = _parse_toplevel(
-                        key=line_items[0],
-                        value=value,
-                        detector=detector,
-                        beam=beam,
-                        panel=default_panel,
-                        hdf5_peak_path=hdf5_peak_path,
-                    )
-                    continue
-                if path[0].startswith("bad"):
-                    if path[0] in detector["bad"]:
-                        curr_bad = detector["bad"][path[0]]
-                    else:
-                        curr_bad = copy.deepcopy(default_bad_region)
-                        detector["bad"][path[0]] = curr_bad
-                    _parse_field_bad(key=path[1], value=value, bad=curr_bad)
-                else:
-                    if path[0] in detector["panels"]:
-                        curr_panel = detector["panels"][path[0]]
-                    else:
-                        curr_panel = copy.deepcopy(default_panel)
-                        detector["panels"][path[0]] = curr_panel
-                    _parse_field_for_panel(
-                        key=path[1],
-                        value=value,
-                        panel=curr_panel,
-                        panel_name=path[0],
-                        detector=detector,
-                    )
-            if not detector["panels"]:
-                raise RuntimeError("No panel descriptions in geometry file.")
-            num_placeholders_in_panels = -1  # type: int
-            for panel in detector["panels"].values():
-                if panel["dim_structure"] is not None:
-                    curr_num_placeholders = panel["dim_structure"].count(
-                        "%"
-                    )  # type: int
-                else:
-                    curr_num_placeholders = 0
-
-                if num_placeholders_in_panels == -1:
-                    num_placeholders_in_panels = curr_num_placeholders
-                else:
-                    if curr_num_placeholders != num_placeholders_in_panels:
-                        raise RuntimeError(
-                            "All panels' data and mask entries must have the same "
-                            "number of placeholders."
-                        )
-            num_placeholders_in_masks = -1  # type: int
-            for panel in detector["panels"].values():
-                if panel["mask"] is not None:
-                    curr_num_placeholders = panel["mask"].count("%")
+        file_lines: List[str] = list()
+        if isinstance(file, (str, PathLike)):
+            with open(file, mode="r") as file_handle:
+                file_lines += file_handle.readlines()
+        elif isinstance(file, StringIO) or isinstance(file, TextIOWrapper):
+            file_handle = file
+            file_lines += file_handle.readlines()
+        else:
+            raise NotImplementedError(f"Received an input file that is neither "
+                                      f"a path, nor a text file handler. Cannot read "
+                                      f"from it. This is the object received: {file}.")
+
+        for line in file_lines:
+            if line.startswith(";"):
+                continue
+            line_without_comments = line.strip().split(";")[0]  # type: str
+            line_items = re.split(
+                pattern="([ \t])", string=line_without_comments
+            )  # type: List[str]
+            line_items = [
+                item for item in line_items if item not in ("", " ", "\t")
+            ]
+            if len(line_items) < 3:
+                continue
+            value = "".join(line_items[2:])  # type: str
+            if line_items[1] != "=":
+                continue
+            path = re.split("(/)", line_items[0])  # type: List[str]
+            path = [item for item in path if item not in "/"]
+            if len(path) < 2:
+                hdf5_peak_path = _parse_toplevel(
+                    key=line_items[0],
+                    value=value,
+                    detector=detector,
+                    beam=beam,
+                    panel=default_panel,
+                    hdf5_peak_path=hdf5_peak_path,
+                )
+                continue
+            if path[0].startswith("bad"):
+                if path[0] in detector["bad"]:
+                    curr_bad = detector["bad"][path[0]]
                 else:
-                    curr_num_placeholders = 0
-
-                if num_placeholders_in_masks == -1:
-                    num_placeholders_in_masks = curr_num_placeholders
+                    curr_bad = copy.deepcopy(default_bad_region)
+                    detector["bad"][path[0]] = curr_bad
+                _parse_field_bad(key=path[1], value=value, bad=curr_bad)
+            else:
+                if path[0] in detector["panels"]:
+                    curr_panel = detector["panels"][path[0]]
                 else:
-                    if curr_num_placeholders != num_placeholders_in_masks:
-                        raise RuntimeError(
-                            "All panels' data and mask entries must have the same "
-                            "number of placeholders."
-                        )
-            if num_placeholders_in_masks > num_placeholders_in_panels:
-                raise RuntimeError(
-                    "Number of placeholders in mask cannot be larger the number than "
-                    "for data."
+                    curr_panel = copy.deepcopy(default_panel)
+                    detector["panels"][path[0]] = curr_panel
+                _parse_field_for_panel(
+                    key=path[1],
+                    value=value,
+                    panel=curr_panel,
+                    panel_name=path[0],
+                    detector=detector,
                 )
-            dim_length = -1  # type: int
-            for panel_name, panel in detector["panels"].items():
-                if len(panel["dim_structure"]) == 0:
-                    panel["dim_structure"] = copy.deepcopy(default_dim)
-                found_ss = 0  # type: int
-                found_fs = 0  # type: int
-                found_placeholder = 0  # type: int
-                for dim_index, entry in enumerate(panel["dim_structure"]):
-                    if entry is None:
-                        raise RuntimeError(
-                            "Dimension {} for panel {} is undefined.".format(
-                                dim_index, panel_name
-                            )
-                        )
-                    if entry == "ss":
-                        found_ss += 1
-                    elif entry == "fs":
-                        found_fs += 1
-                    elif entry == "%":
-                        found_placeholder += 1
-                if found_ss != 1:
-                    raise RuntimeError(
-                        "Exactly one slow scan dim coordinate is needed (found {} for "
-                        "panel {}).".format(found_ss, panel_name)
-                    )
-                if found_fs != 1:
+        if not detector["panels"]:
+            raise RuntimeError("No panel descriptions in geometry file.")
+        num_placeholders_in_panels = -1  # type: int
+        for panel in detector["panels"].values():
+            if panel["dim_structure"] is not None:
+                curr_num_placeholders = panel["dim_structure"].count(
+                    "%"
+                )  # type: int
+            else:
+                curr_num_placeholders = 0
+
+            if num_placeholders_in_panels == -1:
+                num_placeholders_in_panels = curr_num_placeholders
+            else:
+                if curr_num_placeholders != num_placeholders_in_panels:
                     raise RuntimeError(
-                        "Exactly one fast scan dim coordinate is needed (found {} for "
-                        "panel {}).".format(found_fs, panel_name)
+                        "All panels' data and mask entries must have the same "
+                        "number of placeholders. Found {} placeholders in a previous panel, "
+                        "but panel {} has {} placeholders.".format(num_placeholders_in_panels,
+                            panel, curr_num_placeholders)
                     )
-                if found_placeholder > 1:
+        num_placeholders_in_masks = -1  # type: int
+        for panel in detector["panels"].values():
+            if panel["mask"] is not None:
+                curr_num_placeholders = panel["mask"].count("%")
+            else:
+                curr_num_placeholders = 0
+
+            if num_placeholders_in_masks == -1:
+                num_placeholders_in_masks = curr_num_placeholders
+            else:
+                if curr_num_placeholders != num_placeholders_in_masks:
                     raise RuntimeError(
-                        "Only one placeholder dim coordinate is allowed. Maximum one "
-                        "placeholder dim coordinate is allowed "
-                        "(found {} for panel {})".format(found_placeholder, panel_name)
+                        "All panels' data and mask entries must have the same "
+                        "number of placeholders. Found {} placeholders in a previous mask, "
+                        "but mask for panel {} has {} placeholders.".format(num_placeholders_in_masks,
+                            panel, curr_num_placeholders)
                     )
-                if dim_length == -1:
-                    dim_length = len(panel["dim_structure"])
-                elif dim_length != len(panel["dim_structure"]):
-                    raise RuntimeError(
-                        "Number of dim coordinates must be the same for all panels."
-                    )
-                if dim_length == 1:
-                    raise RuntimeError(
-                        "Number of dim coordinates must be at least " "two."
-                    )
-            for panel_name, panel in detector["panels"].items():
-                if panel["orig_min_fs"] < 0:
-                    raise RuntimeError(
-                        "Please specify the minimum fs coordinate for panel {}.".format(
-                            panel_name
-                        )
-                    )
-                if panel["orig_max_fs"] < 0:
+        if num_placeholders_in_masks > num_placeholders_in_panels:
+            raise RuntimeError(
+                "Number of placeholders in mask ({}) cannot be larger the number than "
+                "for data ({}).".format(num_placeholders_in_masks, num_placeholders_in_panels)
+            )
+        dim_length = -1  # type: int
+        for panel_name, panel in detector["panels"].items():
+            if len(panel["dim_structure"]) == 0:
+                panel["dim_structure"] = copy.deepcopy(default_dim)
+            found_ss = 0  # type: int
+            found_fs = 0  # type: int
+            found_placeholder = 0  # type: int
+            for dim_index, entry in enumerate(panel["dim_structure"]):
+                if entry is None:
                     raise RuntimeError(
-                        "Please specify the maximum fs coordinate for panel {}.".format(
-                            panel_name
+                        "Dimension {} for panel {} is undefined.".format(
+                            dim_index, panel_name
                         )
                     )
-                if panel["orig_min_ss"] < 0:
-                    raise RuntimeError(
-                        "Please specify the minimum ss coordinate for panel {}.".format(
-                            panel_name
-                        )
+                if entry == "ss":
+                    found_ss += 1
+                elif entry == "fs":
+                    found_fs += 1
+                elif entry == "%":
+                    found_placeholder += 1
+            if found_ss != 1:
+                raise RuntimeError(
+                    "Exactly one slow scan dim coordinate is needed (found {} for "
+                    "panel {}).".format(found_ss, panel_name)
+                )
+            if found_fs != 1:
+                raise RuntimeError(
+                    "Exactly one fast scan dim coordinate is needed (found {} for "
+                    "panel {}).".format(found_fs, panel_name)
+                )
+            if found_placeholder > 1:
+                raise RuntimeError(
+                    "Only one placeholder dim coordinate is allowed. Maximum one "
+                    "placeholder dim coordinate is allowed "
+                    "(found {} for panel {})".format(found_placeholder, panel_name)
+                )
+            if dim_length == -1:
+                dim_length = len(panel["dim_structure"])
+            elif dim_length != len(panel["dim_structure"]):
+                raise RuntimeError(
+                    "Number of dim coordinates must be the same for all panels."
+                )
+            if dim_length == 1:
+                raise RuntimeError(
+                    "Number of dim coordinates must be at least " "two."
+                )
+        for panel_name, panel in detector["panels"].items():
+            if panel["orig_min_fs"] < 0:
+                raise RuntimeError(
+                    "Please specify the minimum fs coordinate for panel {}.".format(
+                        panel_name
                     )
-                if panel["orig_max_ss"] < 0:
-                    raise RuntimeError(
-                        "Please specify the maximum ss coordinate for panel {}.".format(
-                            panel_name
-                        )
+                )
+            if panel["orig_max_fs"] < 0:
+                raise RuntimeError(
+                    "Please specify the maximum fs coordinate for panel {}.".format(
+                        panel_name
                     )
-                if panel["cnx"] is None:
-                    raise RuntimeError(
-                        "Please specify the corner X coordinate for panel {}.".format(
-                            panel_name
-                        )
+                )
+            if panel["orig_min_ss"] < 0:
+                raise RuntimeError(
+                    "Please specify the minimum ss coordinate for panel {}.".format(
+                        panel_name
                     )
-                if panel["clen"] is None and panel["clen_from"] is None:
-                    raise RuntimeError(
-                        "Please specify the camera length for panel {}.".format(
-                            panel_name
-                        )
+                )
+            if panel["orig_max_ss"] < 0:
+                raise RuntimeError(
+                    "Please specify the maximum ss coordinate for panel {}.".format(
+                        panel_name
                     )
-                if panel["res"] < 0:
-                    raise RuntimeError(
-                        "Please specify the resolution or panel {}.".format(panel_name)
+                )
+            if panel["cnx"] is None:
+                raise RuntimeError(
+                    "Please specify the corner X coordinate for panel {}.".format(
+                        panel_name
                     )
-                if panel["adu_per_eV"] is None and panel["adu_per_photon"] is None:
-                    raise RuntimeError(
-                        "Please specify either adu_per_eV or adu_per_photon for panel "
-                        "{}.".format(panel_name)
+                )
+            if panel["clen"] is None and panel["clen_from"] is None:
+                raise RuntimeError(
+                    "Please specify the camera length for panel {}.".format(
+                        panel_name
                     )
-                if panel["clen_for_centering"] is None and panel["rail_x"] is not None:
+                )
+            if panel["res"] < 0:
+                raise RuntimeError(
+                    "Please specify the resolution or panel {}.".format(panel_name)
+                )
+            if panel["adu_per_eV"] is None and panel["adu_per_photon"] is None:
+                raise RuntimeError(
+                    "Please specify either adu_per_eV or adu_per_photon for panel "
+                    "{}.".format(panel_name)
+                )
+            if panel["clen_for_centering"] is None and panel["rail_x"] is not None:
+                raise RuntimeError(
+                    "You must specify clen_for_centering if you specify the rail "
+                    "direction (panel {})".format(panel_name)
+                )
+            if panel["rail_x"] is None:
+                panel["rail_x"] = 0.0
+                panel["rail_y"] = 0.0
+                panel["rail_z"] = 1.0
+            if panel["clen_for_centering"] is None:
+                panel["clen_for_centering"] = 0.0
+            panel["w"] = panel["orig_max_fs"] - panel["orig_min_fs"] + 1
+            panel["h"] = panel["orig_max_ss"] - panel["orig_min_ss"] + 1
+        for bad_region_name, bad_region in detector["bad"].items():
+            if bad_region["is_fsss"] == 99:
+                raise RuntimeError(
+                    "Please specify the coordinate ranges for bad "
+                    "region {}.".format(bad_region_name)
+                )
+        for group in detector["rigid_groups"]:
+            for name in detector["rigid_groups"][group]:
+                if name not in detector["panels"]:
                     raise RuntimeError(
-                        "You must specify clen_for_centering if you specify the rail "
-                        "direction (panel {})".format(panel_name)
+                        "Cannot add panel to rigid_group. Panel not "
+                        "found: {}".format(name)
                     )
-                if panel["rail_x"] is None:
-                    panel["rail_x"] = 0.0
-                    panel["rail_y"] = 0.0
-                    panel["rail_z"] = 1.0
-                if panel["clen_for_centering"] is None:
-                    panel["clen_for_centering"] = 0.0
-                panel["w"] = panel["orig_max_fs"] - panel["orig_min_fs"] + 1
-                panel["h"] = panel["orig_max_ss"] - panel["orig_min_ss"] + 1
-            for bad_region_name, bad_region in detector["bad"].items():
-                if bad_region["is_fsss"] == 99:
+        for group_collection in detector["rigid_group_collections"]:
+            for name in detector["rigid_group_collections"][group_collection]:
+                if name not in detector["rigid_groups"]:
                     raise RuntimeError(
-                        "Please specify the coordinate ranges for bad "
-                        "region {}.".format(bad_region_name)
+                        "Cannot add rigid_group to collection. Rigid group not "
+                        "found: {}".format(name)
                     )
-            for group in detector["rigid_groups"]:
-                for name in detector["rigid_groups"][group]:
-                    if name not in detector["panels"]:
-                        raise RuntimeError(
-                            "Cannot add panel to rigid_group. Panel not "
-                            "found: {}".format(name)
-                        )
-            for group_collection in detector["rigid_group_collections"]:
-                for name in detector["rigid_group_collections"][group_collection]:
-                    if name not in detector["rigid_groups"]:
-                        raise RuntimeError(
-                            "Cannot add rigid_group to collection. Rigid group not "
-                            "found: {}".format(name)
-                        )
-            for panel in detector["panels"].values():
-                d__ = (
-                    panel["fsx"] * panel["ssy"] - panel["ssx"] * panel["fsy"]
-                )  # type: float
-                if d__ == 0.0:
-                    raise RuntimeError("Panel {} transformation is singular.")
-                panel["xfs"] = panel["ssy"] / d__
-                panel["yfs"] = panel["ssx"] / d__
-                panel["xss"] = panel["fsy"] / d__
-                panel["yss"] = panel["fsx"] / d__
-            _find_min_max_d(detector)
+        for panel in detector["panels"].values():
+            d__ = (
+                panel["fsx"] * panel["ssy"] - panel["ssx"] * panel["fsy"]
+            )  # type: float
+            if d__ == 0.0:
+                raise RuntimeError("Panel {} transformation is singular.".format(panel))
+            panel["xfs"] = panel["ssy"] / d__
+            panel["yfs"] = panel["ssx"] / d__
+            panel["xss"] = panel["fsy"] / d__
+            panel["yss"] = panel["fsx"] / d__
+        _find_min_max_d(detector)
     except (IOError, OSError) as exc:
         exc_type, exc_value = sys.exc_info()[:2]
         raise RuntimeError(
             "The following error occurred while reading the {0} geometry"
             "file {1}: {2}".format(
-                filename,
+                file,
                 exc_type.__name__,  # type: ignore
                 exc_value,  # type: ignore
             )
-- 
GitLab