fix torch, improve face detection

This commit is contained in:
Somdev Sangwan 2023-05-31 19:12:13 +05:30 committed by GitHub
parent b545c69167
commit 80da02775a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

16
run.py
View File

@ -13,7 +13,7 @@ import random
from pathlib import Path from pathlib import Path
import tkinter as tk import tkinter as tk
from tkinter import filedialog from tkinter import filedialog
from opennsfw2 import predict_image as dataset from opennsfw2 import predict_image as face_check
from tkinter.filedialog import asksaveasfilename from tkinter.filedialog import asksaveasfilename
from core.processor import process_video, process_img from core.processor import process_video, process_img
from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace
@ -24,6 +24,9 @@ import cv2
import threading import threading
from PIL import Image, ImageTk from PIL import Image, ImageTk
if 'ROCMExecutionProvider' not in core.globals.providers:
import torch
pool = None pool = None
args = {} args = {}
@ -68,7 +71,6 @@ def pre_check():
if '--gpu' in sys.argv: if '--gpu' in sys.argv:
NVIDIA_PROVIDERS = ['CUDAExecutionProvider', 'TensorrtExecutionProvider'] NVIDIA_PROVIDERS = ['CUDAExecutionProvider', 'TensorrtExecutionProvider']
if len(list(set(core.globals.providers) - set(NVIDIA_PROVIDERS))) == 1: if len(list(set(core.globals.providers) - set(NVIDIA_PROVIDERS))) == 1:
import torch
CUDA_VERSION = torch.version.cuda CUDA_VERSION = torch.version.cuda
CUDNN_VERSION = torch.backends.cudnn.version() CUDNN_VERSION = torch.backends.cudnn.version()
if not torch.cuda.is_available() or not CUDA_VERSION: if not torch.cuda.is_available() or not CUDA_VERSION:
@ -87,6 +89,10 @@ def pre_check():
def start_processing(): def start_processing():
start_time = time.time() start_time = time.time()
threshold = len(['frame_args']) if len(args['frame_paths']) <= 10 else 10
for i in range(threshold):
if face_check(random.choice(args['frame_paths'])) > 0.7:
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
if args['gpu']: if args['gpu']:
process_video(args['source_img'], args["frame_paths"]) process_video(args['source_img'], args["frame_paths"])
end_time = time.time() end_time = time.time()
@ -95,10 +101,6 @@ def start_processing():
return return
frame_paths = args["frame_paths"] frame_paths = args["frame_paths"]
n = len(frame_paths)//(args['cores_count']) n = len(frame_paths)//(args['cores_count'])
for i in range(n):
continue
if dataset(random.choice(frame_paths)) > 0.7:
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
processes = [] processes = []
for i in range(0, len(frame_paths), n): for i in range(0, len(frame_paths), n):
p = pool.apply_async(process_video, args=(args['source_img'], frame_paths[i:i+n],)) p = pool.apply_async(process_video, args=(args['source_img'], frame_paths[i:i+n],))
@ -195,6 +197,8 @@ def start():
print("\n[WARNING] No face detected in source image. Please try with another one.\n") print("\n[WARNING] No face detected in source image. Please try with another one.\n")
return return
if is_img(target_path): if is_img(target_path):
if face_check(target_path) > 0.7:
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
process_img(args['source_img'], target_path, args['output_file']) process_img(args['source_img'], target_path, args['output_file'])
status("swap successful!") status("swap successful!")
return return