From c19f8125f28d22d9ffd59731460482929ea38397 Mon Sep 17 00:00:00 2001 From: Somdev Sangwan Date: Sun, 4 Jun 2023 17:00:09 +0530 Subject: [PATCH] add threading --- roop/swapper.py | 71 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/roop/swapper.py b/roop/swapper.py index 9f25b46..3e3eca8 100644 --- a/roop/swapper.py +++ b/roop/swapper.py @@ -4,24 +4,20 @@ import torch import onnxruntime import cv2 import insightface - +import threading import roop.globals from roop.analyser import get_face_single, get_face_many +from roop.globals import gpu_vendor FACE_SWAPPER = None - +lock = threading.Lock() def get_face_swapper(): global FACE_SWAPPER - if FACE_SWAPPER is None: - session_options = onnxruntime.SessionOptions() - if roop.globals.gpu_vendor is not None: - session_options.intra_op_num_threads = roop.globals.gpu_threads - else: - session_options.intra_op_num_threads = roop.globals.cpu_threads - session_options.execution_mode = onnxruntime.ExecutionMode.ORT_PARALLEL - model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../inswapper_128.onnx') - FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.providers, session_options=session_options) + with lock: + if FACE_SWAPPER is None: + model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), '../inswapper_128.onnx') + FACE_SWAPPER = insightface.model_zoo.get_model(model_path, providers=roop.globals.providers) return FACE_SWAPPER @@ -31,7 +27,20 @@ def swap_face_in_frame(source_face, target_face, frame): return frame -def process_faces(source_face, target_frame, progress): +def process_frames(source_img, frame_paths, progress=None): + source_face = get_face_single(cv2.imread(source_img)) + for frame_path in frame_paths: + frame = cv2.imread(frame_path) + try: + result = process_faces(source_face, frame) + cv2.imwrite(frame_path, result) + except Exception as e: + print(">>>>", e) + pass + if 
progress: + progress.update(1) + +def process_faces(source_face, target_frame): if roop.globals.all_faces: many_faces = get_face_many(target_frame) if many_faces: @@ -45,22 +54,34 @@ def process_faces(source_face, target_frame, progress): def process_video(source_img, frame_paths, preview_callback): - source_face = get_face_single(cv2.imread(source_img)) progress_bar_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]' with tqdm(total=len(frame_paths), desc="Processing", unit="frame", dynamic_ncols=True, bar_format=progress_bar_format) as progress: - for frame_path in frame_paths: - if roop.globals.gpu_vendor == 'nvidia': - progress.set_postfix(cuda_utilization="{:02d}%".format(torch.cuda.utilization()), cuda_memory="{:02d}GB".format(torch.cuda.memory_usage())) - frame = cv2.imread(frame_path) - try: - result = process_faces(source_face, frame, progress) - cv2.imwrite(frame_path, result) - if preview_callback: - preview_callback(cv2.cvtColor(result, cv2.COLOR_BGR2RGB)) - except Exception: - pass - progress.update(1) + # nvidia multi-threading + if roop.globals.gpu_vendor == 'nvidia': + #progress.set_postfix(cuda_utilization="{:02d}%".format(torch.cuda.utilization()), cuda_memory="{:02d}GB".format(torch.cuda.memory_usage())) + # calculate the number of frames each thread processes + num_threads = roop.globals.gpu_threads + num_frames_per_thread = len(frame_paths) // num_threads + remaining_frames = len(frame_paths) % num_threads + # create thread list + threads = [] + start_index = 0 + # create thread and launch + for _ in range(num_threads): + end_index = start_index + num_frames_per_thread + if remaining_frames > 0: + end_index += 1 + remaining_frames -= 1 + thread_frame_paths = frame_paths[start_index:end_index] + thread = threading.Thread(target=process_frames, args=(source_img, thread_frame_paths, progress)) + threads.append(thread) + thread.start() + start_index = end_index + for thread in threads: + thread.join() + else: + 
process_frames(source_img, frame_paths, progress) def process_img(source_img, target_path, output_file):