diff --git a/run.py b/run.py old mode 100644 new mode 100755 index 26960dc..c74f7eb --- a/run.py +++ b/run.py @@ -8,24 +8,24 @@ import glob import argparse import multiprocessing as mp import os -import random +import torch from pathlib import Path import tkinter as tk from tkinter import filedialog from opennsfw2 import predict_image as face_check from tkinter.filedialog import asksaveasfilename -import core.globals -from core.processor import process_video, process_img -from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace -from core.config import get_face import webbrowser import psutil import cv2 import threading from PIL import Image, ImageTk +import core.globals +from core.processor import process_video, process_img +from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace +from core.config import get_face -if 'ROCMExecutionProvider' not in core.globals.providers: - import torch +if 'ROCMExecutionProvider' in core.globals.providers: + del torch pool = None args = {} @@ -69,8 +69,7 @@ def pre_check(): if not os.path.isfile(model_path): quit('File "inswapper_128.onnx" does not exist!') if '--gpu' in sys.argv: - NVIDIA_PROVIDERS = ['CUDAExecutionProvider', 'TensorrtExecutionProvider'] - if len(list(set(core.globals.providers) - set(NVIDIA_PROVIDERS))) == 1: + if 'ROCMExecutionProvider' not in core.globals.providers: CUDA_VERSION = torch.version.cuda CUDNN_VERSION = torch.backends.cudnn.version() if not torch.cuda.is_available() or not CUDA_VERSION: @@ -89,10 +88,6 @@ def pre_check(): def start_processing(): start_time = time.time() - threshold = len(['frame_args']) if len(args['frame_paths']) <= 10 else 10 - for i in range(threshold): - if face_check(random.choice(args['frame_paths'])) > 0.8: - quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.") if args['gpu']: process_video(args['source_img'], args["frame_paths"]) end_time = time.time()