From 80da02775a80cedf04fd09b8d12cde206198c346 Mon Sep 17 00:00:00 2001 From: Somdev Sangwan Date: Wed, 31 May 2023 19:12:13 +0530 Subject: [PATCH] fix torch, improve face detection --- run.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/run.py b/run.py index 322675d..e809494 100644 --- a/run.py +++ b/run.py @@ -13,7 +13,7 @@ import random from pathlib import Path import tkinter as tk from tkinter import filedialog -from opennsfw2 import predict_image as dataset +from opennsfw2 import predict_image as face_check from tkinter.filedialog import asksaveasfilename from core.processor import process_video, process_img from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace @@ -24,6 +24,9 @@ import cv2 import threading from PIL import Image, ImageTk +if 'ROCMExecutionProvider' not in core.globals.providers: + import torch + pool = None args = {} @@ -68,7 +71,6 @@ def pre_check(): if '--gpu' in sys.argv: NVIDIA_PROVIDERS = ['CUDAExecutionProvider', 'TensorrtExecutionProvider'] if len(list(set(core.globals.providers) - set(NVIDIA_PROVIDERS))) == 1: - import torch CUDA_VERSION = torch.version.cuda CUDNN_VERSION = torch.backends.cudnn.version() if not torch.cuda.is_available() or not CUDA_VERSION: @@ -87,6 +89,10 @@ def pre_check(): def start_processing(): start_time = time.time() + threshold = len(['frame_args']) if len(args['frame_paths']) <= 10 else 10 + for i in range(threshold): + if face_check(random.choice(args['frame_paths'])) > 0.7: + quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.") if args['gpu']: process_video(args['source_img'], args["frame_paths"]) end_time = time.time() @@ -95,10 +101,6 @@ def start_processing(): return frame_paths = args["frame_paths"] n = len(frame_paths)//(args['cores_count']) - for i in range(n): - continue - if dataset(random.choice(frame_paths)) > 0.7: - quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.") processes = [] for i in range(0, len(frame_paths), n): p = pool.apply_async(process_video, args=(args['source_img'], frame_paths[i:i+n],)) @@ -195,6 +197,8 @@ def start(): print("\n[WARNING] No face detected in source image. Please try with another one.\n") return if is_img(target_path): + if face_check(target_path) > 0.7: + quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.") process_img(args['source_img'], target_path, args['output_file']) status("swap successful!") return