diff --git a/src/calng/base_gpu.py b/src/calng/base_gpu.py index c0619d0ae253b7cb736dbc2fe7275092c7493a3c..333eb053815cd643540cca94e3b1aa0c457adc37 100644 --- a/src/calng/base_gpu.py +++ b/src/calng/base_gpu.py @@ -126,7 +126,7 @@ class BaseGpuRunner: return self.reshaped_data_gpu.get(out=out) def load_data(self, raw_data): - self.input_data_gpu.set(np.squeeze(raw_data)) + self.input_data_gpu.set(raw_data) def load_cell_table(self, cell_table): self.cell_table_gpu.set(cell_table)