diff --git a/roop/analyser.py b/roop/analyser.py index 804f7a8..c2899e7 100644 --- a/roop/analyser.py +++ b/roop/analyser.py @@ -1,10 +1,11 @@ +from typing import Any import insightface import roop.globals FACE_ANALYSER = None -def get_face_analyser(): +def get_face_analyser() -> Any: global FACE_ANALYSER if FACE_ANALYSER is None: FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.providers) @@ -12,16 +13,16 @@ def get_face_analyser(): return FACE_ANALYSER -def get_face_single(img_data): - face = get_face_analyser().get(img_data) +def get_face_single(image_data) -> Any: + face = get_face_analyser().get(image_data) try: - return sorted(face, key=lambda x: x.bbox[0])[0] - except IndexError: + return min(face, key=lambda x: x.bbox[0]) + except ValueError: return None -def get_face_many(img_data): +def get_face_many(image_data) -> Any: try: - return get_face_analyser().get(img_data) + return get_face_analyser().get(image_data) except IndexError: return None diff --git a/roop/core.py b/roop/core.py index fb86d54..48bd2d7 100755 --- a/roop/core.py +++ b/roop/core.py @@ -2,10 +2,11 @@ import os import sys -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # single thread doubles performance of gpu-mode - needs to be set before torch import if any(arg.startswith('--gpu-vendor') for arg in sys.argv): os.environ['OMP_NUM_THREADS'] = '1' +# reduce tensorflow log level +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import warnings from typing import List import platform @@ -20,15 +21,18 @@ from opennsfw2 import predict_video_frames, predict_image import cv2 import roop.globals -from roop.swapper import process_video, process_img, process_faces +import roop.ui as ui +from roop.swapper import process_video, process_img from roop.utilities import has_image_extention, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frames_paths, restore_audio, create_temp, move_temp, clean_temp from roop.analyser import get_face_single -import roop.ui as ui + +if 'ROCMExecutionProvider' in roop.globals.providers: + del torch warnings.simplefilter(action='ignore', category=FutureWarning) -def handle_parse(): - global args + +def parse_args() -> None: signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) parser = argparse.ArgumentParser() parser.add_argument('-f', '--face', help='use this face', dest='source_path') @@ -45,6 +49,9 @@ def handle_parse(): args = parser.parse_known_args()[0] + roop.globals.source_path = args.source_path + roop.globals.target_path = args.target_path + roop.globals.output_path = args.output_path roop.globals.headless = args.source_path or args.target_path or args.output_path roop.globals.keep_fps = args.keep_fps roop.globals.keep_audio = args.keep_audio @@ -76,8 +83,8 @@ def limit_resources(): gpus = tensorflow.config.experimental.list_physical_devices('GPU') for gpu in gpus: tensorflow.config.experimental.set_memory_growth(gpu, True) - if args.max_memory: - memory = args.max_memory * 1024 * 1024 * 1024 + if roop.globals.max_memory: + memory = roop.globals.max_memory * 1024 * 1024 * 1024 if str(platform.system()).lower() == 'windows': import ctypes kernel32 = ctypes.windll.kernel32 @@ -102,58 +109,22 @@ def pre_check(): if 'ROCMExecutionProvider' not in roop.globals.providers: quit('You are using --gpu=amd flag but ROCM is not available or properly installed on your system.') if roop.globals.gpu_vendor == 'nvidia': - CUDA_VERSION = torch.version.cuda - CUDNN_VERSION = torch.backends.cudnn.version() if not torch.cuda.is_available(): quit('You are using --gpu=nvidia flag but CUDA is not available or properly installed on your system.') - if CUDA_VERSION > '11.8': - quit(f'CUDA version {CUDA_VERSION} is not supported - please downgrade to 11.8') - if CUDA_VERSION < '11.4': - quit(f'CUDA version {CUDA_VERSION} is not supported - please upgrade to 11.8') - if CUDNN_VERSION < 8220: - quit(f'CUDNN version {CUDNN_VERSION} is not supported - please upgrade to 8.9.1') - if CUDNN_VERSION > 8910: - quit(f'CUDNN version {CUDNN_VERSION} is not supported - please downgrade to 8.9.1') - - -def get_video_frame(video_path, frame_number = 1): - cap = cv2.VideoCapture(video_path) - amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) - cap.set(cv2.CAP_PROP_POS_FRAMES, min(amount_of_frames, frame_number-1)) - if not cap.isOpened(): - status('Error opening video file') - return - ret, frame = cap.read() - if ret: - return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - cap.release() - - -def preview_video(video_path): - cap = cv2.VideoCapture(video_path) - if not cap.isOpened(): - status('Error opening video file') - return 0 - amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) - ret, frame = cap.read() - if ret: - frame = get_video_frame(video_path) - - cap.release() - return (amount_of_frames, frame) - - -def status(message: str): - value = 'Status: ' + message - print(value) - if not roop.globals.headless: - ui.update_status_label(value) + if torch.version.cuda > '11.8': + quit(f'CUDA version {torch.version.cuda} is not supported - please downgrade to 11.8') + if torch.version.cuda < '11.4': + quit(f'CUDA version {torch.version.cuda} is not supported - please upgrade to 11.8') + if torch.backends.cudnn.version() < 8220: + quit(f'CUDNN version { torch.backends.cudnn.version()} is not supported - please upgrade to 8.9.1') + if torch.backends.cudnn.version() > 8910: + quit(f'CUDNN version { torch.backends.cudnn.version()} is not supported - please downgrade to 8.9.1') def conditional_process_video(source_path: str, frame_paths: List[str]) -> None: pool_amount = len(frame_paths) // roop.globals.cpu_cores if pool_amount > 2 and roop.globals.cpu_cores > 1 and roop.globals.gpu_vendor is None: - status('Pool-Swapping in progress...') + update_status('Pool-Swapping in progress...') global POOL POOL = multiprocessing.Pool(roop.globals.cpu_cores, maxtasksperchild=1) pools = [] @@ -162,129 +133,89 @@ def conditional_process_video(source_path: str, frame_paths: List[str]) -> None: pools.append(pool) for pool in pools: pool.get() - POOL.join() POOL.close() + POOL.join() else: - status('Swapping in progress...') - process_video(args.source_path, frame_paths) + update_status('Swapping in progress...') + process_video(roop.globals.source_path, frame_paths) -def start(preview_callback = None) -> None: - if not args.source_path or not os.path.isfile(args.source_path): - status('Please select an image containing a face.') +def update_status(message: str): + value = 'Status: ' + message + print(value) + if not roop.globals.headless: + ui.update_status(value) + + +def start() -> None: + if not roop.globals.source_path or not os.path.isfile(roop.globals.source_path): + update_status('Please select an image containing a face.') return - elif not args.target_path or not os.path.isfile(args.target_path): - status('Please select a video/image target!') + elif not roop.globals.target_path or not os.path.isfile(roop.globals.target_path): + update_status('Please select a video/image target!') return - test_face = get_face_single(cv2.imread(args.source_path)) + test_face = get_face_single(cv2.imread(roop.globals.source_path)) if not test_face: - status('No face detected in source image. Please try with another one!') + update_status('No face detected in source image. Please try with another one!') return # process image to image - if has_image_extention(args.target_path): - if predict_image(args.target_path) > 0.85: + if has_image_extention(roop.globals.target_path): + if predict_image(roop.globals.target_path) > 0.85: destroy() - process_img(args.source_path, args.target_path, args.output_path) - if is_image(args.target_path): - status('Swapping to image succeed!') + process_img(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path) + if is_image(roop.globals.target_path): + update_status('Swapping to image succeed!') else: - status('Swapping to image failed!') + update_status('Swapping to image failed!') return # process image to videos - seconds, probabilities = predict_video_frames(video_path=args.target_path, frame_interval=100) + seconds, probabilities = predict_video_frames(video_path=roop.globals.target_path, frame_interval=100) if any(probability > 0.85 for probability in probabilities): destroy() - status('Creating temp resources...') - create_temp(args.target_path) - status('Extracting frames...') - extract_frames(args.target_path) - frame_paths = get_temp_frames_paths(args.target_path) - conditional_process_video(args.source_path, frame_paths) + update_status('Creating temp resources...') + create_temp(roop.globals.target_path) + update_status('Extracting frames...') + extract_frames(roop.globals.target_path) + frame_paths = get_temp_frames_paths(roop.globals.target_path) + conditional_process_video(roop.globals.source_path, frame_paths) # prevent memory leak using ffmpeg with cuda - if args.gpu_vendor == 'nvidia': + if roop.globals.gpu_vendor == 'nvidia': torch.cuda.empty_cache() if roop.globals.keep_fps: - status('Detecting fps...') - fps = detect_fps(args.source_path) - status(f'Creating video with {fps} fps...') - create_video(args.target_path, fps) + update_status('Detecting fps...') + fps = detect_fps(roop.globals.source_path) + update_status(f'Creating video with {fps} fps...') + create_video(roop.globals.target_path, fps) else: - status('Creating video with 30 fps...') - create_video(args.target_path, 30) + update_status('Creating video with 30 fps...') + create_video(roop.globals.target_path, 30) if roop.globals.keep_audio: if roop.globals.keep_fps: - status('Restoring audio...') + update_status('Restoring audio...') else: - status('Restoring audio might cause issues as fps are not kept...') - restore_audio(args.target_path, args.output_path) + update_status('Restoring audio might cause issues as fps are not kept...') + restore_audio(roop.globals.target_path, roop.globals.output_path) else: - move_temp(args.target_path, args.output_path) - clean_temp(args.target_path) - if is_video(args.target_path): - status('Swapping to video succeed!') + move_temp(roop.globals.target_path, roop.globals.output_path) + clean_temp(roop.globals.target_path) + if is_video(roop.globals.target_path): + update_status('Swapping to video succeed!') else: - status('Swapping to video failed!') - - -def select_face_handler(path: str): - args.source_path = path - - -def select_target_handler(path: str): - args.target_path = path - return preview_video(args.target_path) - - -def toggle_all_faces_handler(value: int): - roop.globals.all_faces = True if value == 1 else False - - -def toggle_fps_limit_handler(value: int): - args.keep_fps = int(value != 1) - - -def toggle_keep_frames_handler(value: int): - args.keep_frames = value - - -def save_file_handler(path: str): - args.output_path = path - - -def create_test_preview(frame_number): - return process_faces( - get_face_single(cv2.imread(args.source_path)), - get_video_frame(args.target_path, frame_number) - ) + update_status('Swapping to video failed!') def destroy() -> None: - clean_temp(args.target_path) + if roop.globals.target_path: + clean_temp(roop.globals.target_path) quit() def run() -> None: - global all_faces, keep_frames, limit_fps - handle_parse() + parse_args() pre_check() limit_resources() if roop.globals.headless: start() else: - window = ui.init( - { - 'all_faces': args.all_faces, - 'keep_fps': args.keep_fps, - 'keep_frames': args.keep_frames - }, - select_face_handler, - select_target_handler, - toggle_all_faces_handler, - toggle_fps_limit_handler, - toggle_keep_frames_handler, - save_file_handler, - start, - get_video_frame, - create_test_preview - ) + window = ui.init(start) window.mainloop() diff --git a/roop/globals.py b/roop/globals.py index 100c193..c872571 100644 --- a/roop/globals.py +++ b/roop/globals.py @@ -1,5 +1,8 @@ import onnxruntime +source_path = None +target_path = None +output_path = None keep_fps = None keep_audio = None keep_frames = None @@ -7,6 +10,7 @@ all_faces = None cpu_cores = None gpu_threads = None gpu_vendor = None +max_memory = None headless = None log_level = 'error' providers = onnxruntime.get_available_providers() diff --git a/roop/ui.py b/roop/ui.py index bbca8bf..b83678e 100644 --- a/roop/ui.py +++ b/roop/ui.py @@ -1,11 +1,16 @@ import tkinter as tk from typing import Any, Callable, Tuple + +import cv2 from PIL import Image, ImageTk, ImageOps import webbrowser from tkinter import filedialog from tkinter.filedialog import asksaveasfilename import threading +import roop.globals +from roop.analyser import get_face_single +from roop.swapper import process_faces from roop.utilities import is_image max_preview_size = 800 @@ -213,23 +218,12 @@ def preview_target(frame): target_label.image = photo_img -def update_status_label(value): +def update_status(value): status_label["text"] = value window.update() -def init( - initial_values: dict, - select_face_handler: Callable[[str], None], - select_target_handler: Callable[[str], Tuple[int, Any]], - toggle_all_faces_handler: Callable[[int], None], - toggle_fps_limit_handler: Callable[[int], None], - toggle_keep_frames_handler: Callable[[int], None], - save_file_handler: Callable[[str], None], - start: Callable[[], None], - get_video_frame: Callable[[str, int], None], - create_test_preview: Callable[[int], Any], -): +def init(start: Callable[[], None]): global window, preview, preview_visible, face_label, target_label, status_label window = tk.Tk() @@ -274,22 +268,23 @@ def init( target_button.place(x=360,y=320,width=180,height=80) # All faces checkbox - all_faces = tk.IntVar(None, initial_values['all_faces']) + all_faces = tk.IntVar(None, roop.globals.all_faces) all_faces_checkbox = create_check(window, "Process all faces in frame", all_faces, toggle_all_faces(toggle_all_faces_handler, all_faces)) all_faces_checkbox.place(x=60,y=500,width=240,height=31) # FPS limit checkbox - limit_fps = tk.IntVar(None, not initial_values['keep_fps']) + limit_fps = tk.IntVar(None, not roop.globals.keep_fps) fps_checkbox = create_check(window, "Limit FPS to 30", limit_fps, toggle_fps_limit(toggle_fps_limit_handler, limit_fps)) fps_checkbox.place(x=60,y=475,width=240,height=31) # Keep frames checkbox - keep_frames = tk.IntVar(None, initial_values['keep_frames']) + keep_frames = tk.IntVar(None, roop.globals.keep_frames) frames_checkbox = create_check(window, "Keep frames dir", keep_frames, toggle_keep_frames(toggle_keep_frames_handler, keep_frames)) frames_checkbox.place(x=60,y=450,width=240,height=31) # Start button - start_button = create_button(window, "Start", lambda: [save_file(save_file_handler, target_path.get()), preview_thread(lambda: start(update_preview))]) + #start_button = create_button(window, "Start", lambda: [save_file(save_file_handler, target_path.get()), preview_thread(lambda: start(update_preview))]) + start_button = create_button(window, "Start", lambda: [save_file(save_file_handler, target_path.get()), start]) start_button.place(x=170,y=560,width=120,height=49) # Preview button @@ -301,3 +296,63 @@ def init( status_label.place(x=10,y=640,width=580,height=30) return window + + +def get_video_frame(video_path, frame_number = 1): + cap = cv2.VideoCapture(video_path) + amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) + cap.set(cv2.CAP_PROP_POS_FRAMES, min(amount_of_frames, frame_number-1)) + if not cap.isOpened(): + update_status('Error opening video file') + return + ret, frame = cap.read() + if ret: + return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + cap.release() + + +def preview_video(video_path): + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + update_status('Error opening video file') + return 0 + amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) + ret, frame = cap.read() + if ret: + frame = get_video_frame(video_path) + + cap.release() + return (amount_of_frames, frame) + + +def select_face_handler(path: str): + roop.globals.source_path = path + + +def select_target_handler(target_path: str) -> None: + roop.globals.target_path = target_path + return preview_video(roop.globals.target_path) + + +def toggle_all_faces_handler(value: int): + roop.globals.all_faces = True if value == 1 else False + + +def toggle_fps_limit_handler(value: int): + roop.globals.keep_fps = int(value != 1) + + +def toggle_keep_frames_handler(value: int): + roop.globals.keep_frames = value + + +def save_file_handler(path: str): + roop.globals.output_path = path + + +def create_test_preview(frame_number): + return process_faces( + get_face_single(cv2.imread(roop.globals.source_path)), + get_video_frame(roop.globals.target_path, frame_number) + ) +