diff --git a/core/config.py b/core/config.py index 151ad33..8a3d908 100644 --- a/core/config.py +++ b/core/config.py @@ -1,13 +1,20 @@ import insightface import core.globals -face_analyser = insightface.app.FaceAnalysis(name='buffalo_l', providers=core.globals.providers) -face_analyser.prepare(ctx_id=0, det_size=(640, 640)) +FACE_ANALYSER = None + + +def get_face_analyser(): + global FACE_ANALYSER + if FACE_ANALYSER is None: + FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=core.globals.providers) + FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640)) + return FACE_ANALYSER def get_face(img_data): - analysed = face_analyser.get(img_data) + face = get_face_analyser().get(img_data) try: - return sorted(analysed, key=lambda x: x.bbox[0])[0] + return sorted(face, key=lambda x: x.bbox[0])[0] except IndexError: return None diff --git a/core/processor.py b/core/processor.py index 82d61c1..f09ba57 100644 --- a/core/processor.py +++ b/core/processor.py @@ -1,11 +1,16 @@ -import os import cv2 import insightface -import core.globals from core.config import get_face from core.utils import rreplace -face_swapper = insightface.model_zoo.get_model('inswapper_128.onnx', providers=core.globals.providers) +FACE_SWAPPER = None + + +def get_face_swapper(): + global FACE_SWAPPER + if FACE_SWAPPER is None: + FACE_SWAPPER = insightface.model_zoo.get_model('inswapper_128.onnx') + return FACE_SWAPPER def process_video(source_img, frame_paths): @@ -15,12 +20,13 @@ def process_video(source_img, frame_paths): try: face = get_face(frame) if face: - result = face_swapper.get(frame, face, source_face, paste_back=True) + result = get_face_swapper().get(frame, face, source_face, paste_back=True) cv2.imwrite(frame_path, result) print('.', end='', flush=True) else: print('S', end='', flush=True) except Exception as e: + print(e, flush=True) print('E', end='', flush=True) pass @@ -30,6 +36,6 @@ def process_img(source_img, target_path): face = get_face(frame) source_face = get_face(cv2.imread(source_img)) result = face_swapper.get(frame, face, source_face, paste_back=True) - target_path = rreplace(target_path, "/", "/swapped-", 1) if "/" in target_path else "swapped-"+target_path + target_path = rreplace(target_path, "/", "/swapped-", 1) if "/" in target_path else "swapped-" + target_path print(target_path) cv2.imwrite(target_path, result)