Fix GPU for NVIDIA
This commit is contained in:
parent
38fb60efca
commit
4e2be506ce
21
run.py
Normal file → Executable file
21
run.py
Normal file → Executable file
@ -8,24 +8,24 @@ import glob
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog
|
||||
from opennsfw2 import predict_image as face_check
|
||||
from tkinter.filedialog import asksaveasfilename
|
||||
import core.globals
|
||||
from core.processor import process_video, process_img
|
||||
from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace
|
||||
from core.config import get_face
|
||||
import webbrowser
|
||||
import psutil
|
||||
import cv2
|
||||
import threading
|
||||
from PIL import Image, ImageTk
|
||||
import core.globals
|
||||
from core.processor import process_video, process_img
|
||||
from core.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace
|
||||
from core.config import get_face
|
||||
|
||||
if 'ROCMExecutionProvider' not in core.globals.providers:
|
||||
import torch
|
||||
if 'ROCMExecutionProvider' in core.globals.providers:
|
||||
del torch
|
||||
|
||||
pool = None
|
||||
args = {}
|
||||
@ -69,8 +69,7 @@ def pre_check():
|
||||
if not os.path.isfile(model_path):
|
||||
quit('File "inswapper_128.onnx" does not exist!')
|
||||
if '--gpu' in sys.argv:
|
||||
NVIDIA_PROVIDERS = ['CUDAExecutionProvider', 'TensorrtExecutionProvider']
|
||||
if len(list(set(core.globals.providers) - set(NVIDIA_PROVIDERS))) == 1:
|
||||
if 'ROCMExecutionProvider' not in core.globals.providers:
|
||||
CUDA_VERSION = torch.version.cuda
|
||||
CUDNN_VERSION = torch.backends.cudnn.version()
|
||||
if not torch.cuda.is_available() or not CUDA_VERSION:
|
||||
@ -89,10 +88,6 @@ def pre_check():
|
||||
|
||||
def start_processing():
|
||||
start_time = time.time()
|
||||
threshold = len(['frame_args']) if len(args['frame_paths']) <= 10 else 10
|
||||
for i in range(threshold):
|
||||
if face_check(random.choice(args['frame_paths'])) > 0.8:
|
||||
quit("[WARNING] Unable to determine location of the face in the target. Please make sure the target isn't wearing clothes matching to their skin.")
|
||||
if args['gpu']:
|
||||
process_video(args['source_img'], args["frame_paths"])
|
||||
end_time = time.time()
|
||||
|
Loading…
x
Reference in New Issue
Block a user