diff --git a/src/BOZcalc/BOZcalc.py b/src/BOZcalc/BOZcalc.py
index 775cd761aecf4b184264d34d834e8eb3016ffb5f..631b0945a74bdd8a1b64d35aeded8ee3ae01fe02 100644
--- a/src/BOZcalc/BOZcalc.py
+++ b/src/BOZcalc/BOZcalc.py
@@ -154,10 +154,10 @@ BOZ_CHEM_db = {
 }
 
 class BOZcalc():
-    def __init__(self):
+    def __init__(self, fig=None):
         self.geo_beams = GeoBeams()
         self.initWidgets()
-        self.initFig()
+        self.initFig(fig)
         self.init_beam_transport()
 
         # spot sizes of all beams
@@ -187,19 +187,27 @@ class BOZcalc():
                 scale = 1e6
             self.widgets[v].value = scale*temp.elems[v]
 
-    def initFig(self):
+    def initFig(self, fig=None):
         "Creates a figure for the sample plane and detector plane images."
 
-        plt.close('BOZcalc')
-        fig, (self.ax_sam, self.ax_det) = plt.subplots(
-            1, 2, num='BOZcalc', figsize=(6, 3))
+        if fig is None:
+            plt.close('BOZcalc')
+            self.fig, (self.ax_sam, self.ax_det) = plt.subplots(
+                1, 2, num='BOZcalc', figsize=(6, 3))
+        else:
+            self.fig = fig
+            self.ax_sam, self.ax_det = self.fig.subplots(1, 2)
 
         # display scale
         self.scale = 1e3  # displayed distances in [mm]
 
         self.ax_sam.set_title('Sample plane')
+        self.ax_sam.set_xlabel('x (mm)')
+        self.ax_sam.set_ylabel('y (mm)')
         self.ax_det.set_title('Detector plane')
-
+        self.ax_det.set_xlabel('x (mm)')
+        self.ax_det.set_ylabel('y (mm)')
+ 
         self.ax_sam.set_aspect('equal')
         self.ax_det.set_aspect('equal')
         self.ax_sam.set_xlim([-2, 2])
diff --git a/src/BOZcalc/GeoBeams.py b/src/BOZcalc/GeoBeams.py
index a78183994eab9f142ed0c68075dc72159bf28143..337857b80be0e2374e5bdc9e2b6bf55d8664b205 100644
--- a/src/BOZcalc/GeoBeams.py
+++ b/src/BOZcalc/GeoBeams.py
@@ -157,11 +157,14 @@ class GeoBeams:
 
         return z_x, x, z_y, y
 
-    def plot(self, conf1=None, conf2=None):
-        fig, ax = plt.subplots(2, 1, figsize=(6, 4), sharex=True)
+    def plot(self, conf1=None, conf2=None, fig=None):
+        if fig is None:
+            fig, ax = plt.subplots(2, 1, figsize=(6, 4), sharex=True)
+        else:
+            ax = fig.subplots(2, 1, sharex=True)
 
-        ax0in = ax[0].inset_axes([0.1, 0.65, 0.2, 0.25])
-        ax1in = ax[1].inset_axes([0.1, 0.65, 0.2, 0.25])
+        ax0in = ax[0].inset_axes([0.2, 0.55, 0.2, 0.35])
+        ax1in = ax[1].inset_axes([0.2, 0.55, 0.2, 0.35])
         axs = [ax, ax0in, ax1in]
 
         if conf1 is not None: