diff --git a/__init__.py b/__init__.py index b9858b2..13711d0 100644 --- a/__init__.py +++ b/__init__.py @@ -26,12 +26,14 @@ def INPUT_TYPES(s): def main(self, images, engine): # setup tensorrt engine - engine = Engine(os.path.join(ENGINE_DIR,engine)) - engine.load() - engine.activate() - engine.allocate_buffers() - cudaStream = torch.cuda.current_stream().cuda_stream + if (not hasattr(self, 'engine') or self.engine_label != engine): + self.engine = Engine(os.path.join(ENGINE_DIR,engine)) + self.engine.load() + self.engine.activate() + self.engine.allocate_buffers() + self.engine_label = engine + cudaStream = torch.cuda.current_stream().cuda_stream pbar = ProgressBar(images.shape[0]) images = images.permute(0, 3, 1, 2) images_resized = F.interpolate(images, size=(518,518), mode='bilinear', align_corners=False) @@ -40,7 +42,7 @@ def main(self, images, engine): depth_frames = [] for img in images_list: - result = engine.infer({"input": img},cudaStream) + result = self.engine.infer({"input": img},cudaStream) depth = result['output'] # Process the depth output @@ -56,7 +58,6 @@ def main(self, images, engine): depth_frames_np = np.array(depth_frames).astype(np.float32) / 255.0 return (torch.from_numpy(depth_frames_np),) - NODE_CLASS_MAPPINGS = { "DepthAnythingTensorrt" : DepthAnythingTensorrt, }