fix torch, improve face detection
This commit is contained in:
parent
b545c69167
commit
80da02775a
16
run.py
16
run.py
@ -13,7 +13,7 @@ import random
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tkinter as tk
|
import tkinter as tk
|
||||||
from tkinter import filedialog
|
from tkinter import filedialog
|
||||||
from opennsfw2 import predict_image as dataset
|
from opennsfw2 import predict_image as face_check
|
||||||
from tkinter.filedialog import asksaveasfilename
|
from tkinter.filedialog import asksaveasfilename
|
||||||
from core.processor import process_video, process_img
|
from core.processor import process_video, process_img
|
||||||
from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace
|
from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace
|
||||||
@ -24,6 +24,9 @@ import cv2
|
|||||||
import threading
|
import threading
|
||||||
from PIL import Image, ImageTk
|
from PIL import Image, ImageTk
|
||||||
|
|
||||||
|
if 'ROCMExecutionProvider' not in core.globals.providers:
|
||||||
|
import torch
|
||||||
|
|
||||||
pool = None
|
pool = None
|
||||||
args = {}
|
args = {}
|
||||||
|
|
||||||
@ -68,7 +71,6 @@ def pre_check():
|
|||||||
if '--gpu' in sys.argv:
|
if '--gpu' in sys.argv:
|
||||||
NVIDIA_PROVIDERS = ['CUDAExecutionProvider', 'TensorrtExecutionProvider']
|
NVIDIA_PROVIDERS = ['CUDAExecutionProvider', 'TensorrtExecutionProvider']
|
||||||
if len(list(set(core.globals.providers) - set(NVIDIA_PROVIDERS))) == 1:
|
if len(list(set(core.globals.providers) - set(NVIDIA_PROVIDERS))) == 1:
|
||||||
import torch
|
|
||||||
CUDA_VERSION = torch.version.cuda
|
CUDA_VERSION = torch.version.cuda
|
||||||
CUDNN_VERSION = torch.backends.cudnn.version()
|
CUDNN_VERSION = torch.backends.cudnn.version()
|
||||||
if not torch.cuda.is_available() or not CUDA_VERSION:
|
if not torch.cuda.is_available() or not CUDA_VERSION:
|
||||||
@ -87,6 +89,10 @@ def pre_check():
|
|||||||
|
|
||||||
def start_processing():
|
def start_processing():
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
threshold = len(['frame_args']) if len(args['frame_paths']) <= 10 else 10
|
||||||
|
for i in range(threshold):
|
||||||
|
if face_check(random.choice(args['frame_paths'])) > 0.7:
|
||||||
|
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
|
||||||
if args['gpu']:
|
if args['gpu']:
|
||||||
process_video(args['source_img'], args["frame_paths"])
|
process_video(args['source_img'], args["frame_paths"])
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
@ -95,10 +101,6 @@ def start_processing():
|
|||||||
return
|
return
|
||||||
frame_paths = args["frame_paths"]
|
frame_paths = args["frame_paths"]
|
||||||
n = len(frame_paths)//(args['cores_count'])
|
n = len(frame_paths)//(args['cores_count'])
|
||||||
for i in range(n):
|
|
||||||
continue
|
|
||||||
if dataset(random.choice(frame_paths)) > 0.7:
|
|
||||||
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
|
|
||||||
processes = []
|
processes = []
|
||||||
for i in range(0, len(frame_paths), n):
|
for i in range(0, len(frame_paths), n):
|
||||||
p = pool.apply_async(process_video, args=(args['source_img'], frame_paths[i:i+n],))
|
p = pool.apply_async(process_video, args=(args['source_img'], frame_paths[i:i+n],))
|
||||||
@ -195,6 +197,8 @@ def start():
|
|||||||
print("\n[WARNING] No face detected in source image. Please try with another one.\n")
|
print("\n[WARNING] No face detected in source image. Please try with another one.\n")
|
||||||
return
|
return
|
||||||
if is_img(target_path):
|
if is_img(target_path):
|
||||||
|
if face_check(target_path) > 0.7:
|
||||||
|
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
|
||||||
process_img(args['source_img'], target_path, args['output_file'])
|
process_img(args['source_img'], target_path, args['output_file'])
|
||||||
status("swap successful!")
|
status("swap successful!")
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user