Move every UI-related thing to ui.py

This commit is contained in:
henryruhs 2023-06-06 19:29:00 +02:00
parent 80aec1cb3a
commit 4194584854
4 changed files with 156 additions and 165 deletions

View File

@@ -1,10 +1,11 @@
from typing import Any
import insightface import insightface
import roop.globals import roop.globals
FACE_ANALYSER = None FACE_ANALYSER = None
def get_face_analyser(): def get_face_analyser() -> Any:
global FACE_ANALYSER global FACE_ANALYSER
if FACE_ANALYSER is None: if FACE_ANALYSER is None:
FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.providers) FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.providers)
@@ -12,16 +13,16 @@ def get_face_analyser():
return FACE_ANALYSER return FACE_ANALYSER
def get_face_single(img_data): def get_face_single(image_data) -> Any:
face = get_face_analyser().get(img_data) face = get_face_analyser().get(image_data)
try: try:
return sorted(face, key=lambda x: x.bbox[0])[0] return min(face, key=lambda x: x.bbox[0])
except IndexError: except ValueError:
return None return None
def get_face_many(img_data): def get_face_many(image_data) -> Any:
try: try:
return get_face_analyser().get(img_data) return get_face_analyser().get(image_data)
except IndexError: except IndexError:
return None return None

View File

@@ -2,10 +2,11 @@
import os import os
import sys import sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# single thread doubles performance of gpu-mode - needs to be set before torch import # single thread doubles performance of gpu-mode - needs to be set before torch import
if any(arg.startswith('--gpu-vendor') for arg in sys.argv): if any(arg.startswith('--gpu-vendor') for arg in sys.argv):
os.environ['OMP_NUM_THREADS'] = '1' os.environ['OMP_NUM_THREADS'] = '1'
# reduce tensorflow log level
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import warnings import warnings
from typing import List from typing import List
import platform import platform
@@ -20,15 +21,18 @@ from opennsfw2 import predict_video_frames, predict_image
import cv2 import cv2
import roop.globals import roop.globals
from roop.swapper import process_video, process_img, process_faces import roop.ui as ui
from roop.swapper import process_video, process_img
from roop.utilities import has_image_extention, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frames_paths, restore_audio, create_temp, move_temp, clean_temp from roop.utilities import has_image_extention, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frames_paths, restore_audio, create_temp, move_temp, clean_temp
from roop.analyser import get_face_single from roop.analyser import get_face_single
import roop.ui as ui
if 'ROCMExecutionProvider' in roop.globals.providers:
del torch
warnings.simplefilter(action='ignore', category=FutureWarning) warnings.simplefilter(action='ignore', category=FutureWarning)
def handle_parse():
global args def parse_args() -> None:
signal.signal(signal.SIGINT, lambda signal_number, frame: destroy()) signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-f', '--face', help='use this face', dest='source_path') parser.add_argument('-f', '--face', help='use this face', dest='source_path')
@@ -45,6 +49,9 @@ def handle_parse():
args = parser.parse_known_args()[0] args = parser.parse_known_args()[0]
roop.globals.source_path = args.source_path
roop.globals.target_path = args.target_path
roop.globals.output_path = args.output_path
roop.globals.headless = args.source_path or args.target_path or args.output_path roop.globals.headless = args.source_path or args.target_path or args.output_path
roop.globals.keep_fps = args.keep_fps roop.globals.keep_fps = args.keep_fps
roop.globals.keep_audio = args.keep_audio roop.globals.keep_audio = args.keep_audio
@@ -76,8 +83,8 @@ def limit_resources():
gpus = tensorflow.config.experimental.list_physical_devices('GPU') gpus = tensorflow.config.experimental.list_physical_devices('GPU')
for gpu in gpus: for gpu in gpus:
tensorflow.config.experimental.set_memory_growth(gpu, True) tensorflow.config.experimental.set_memory_growth(gpu, True)
if args.max_memory: if roop.globals.max_memory:
memory = args.max_memory * 1024 * 1024 * 1024 memory = roop.globals.max_memory * 1024 * 1024 * 1024
if str(platform.system()).lower() == 'windows': if str(platform.system()).lower() == 'windows':
import ctypes import ctypes
kernel32 = ctypes.windll.kernel32 kernel32 = ctypes.windll.kernel32
@@ -102,58 +109,22 @@ def pre_check():
if 'ROCMExecutionProvider' not in roop.globals.providers: if 'ROCMExecutionProvider' not in roop.globals.providers:
quit('You are using --gpu=amd flag but ROCM is not available or properly installed on your system.') quit('You are using --gpu=amd flag but ROCM is not available or properly installed on your system.')
if roop.globals.gpu_vendor == 'nvidia': if roop.globals.gpu_vendor == 'nvidia':
CUDA_VERSION = torch.version.cuda
CUDNN_VERSION = torch.backends.cudnn.version()
if not torch.cuda.is_available(): if not torch.cuda.is_available():
quit('You are using --gpu=nvidia flag but CUDA is not available or properly installed on your system.') quit('You are using --gpu=nvidia flag but CUDA is not available or properly installed on your system.')
if CUDA_VERSION > '11.8': if torch.version.cuda > '11.8':
quit(f'CUDA version {CUDA_VERSION} is not supported - please downgrade to 11.8') quit(f'CUDA version {torch.version.cuda} is not supported - please downgrade to 11.8')
if CUDA_VERSION < '11.4': if torch.version.cuda < '11.4':
quit(f'CUDA version {CUDA_VERSION} is not supported - please upgrade to 11.8') quit(f'CUDA version {torch.version.cuda} is not supported - please upgrade to 11.8')
if CUDNN_VERSION < 8220: if torch.backends.cudnn.version() < 8220:
quit(f'CUDNN version {CUDNN_VERSION} is not supported - please upgrade to 8.9.1') quit(f'CUDNN version { torch.backends.cudnn.version()} is not supported - please upgrade to 8.9.1')
if CUDNN_VERSION > 8910: if torch.backends.cudnn.version() > 8910:
quit(f'CUDNN version {CUDNN_VERSION} is not supported - please downgrade to 8.9.1') quit(f'CUDNN version { torch.backends.cudnn.version()} is not supported - please downgrade to 8.9.1')
def get_video_frame(video_path: str, frame_number: int = 1):
    """Return the requested frame of a video as an RGB array.

    Frame numbers are 1-based; out-of-range requests are clamped to the
    last frame. Returns None when the video cannot be opened or the
    frame cannot be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    # Check the capture opened before querying it -- the original probed
    # the frame count first, which silently fails on a bad path.
    if not cap.isOpened():
        status('Error opening video file')
        return None
    try:
        amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        cap.set(cv2.CAP_PROP_POS_FRAMES, min(amount_of_frames, frame_number - 1))
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes as BGR; callers (the UI preview) expect RGB.
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return None
    finally:
        # Always release the capture; the original leaked it whenever a
        # frame was returned (release came after the return statement).
        cap.release()
def preview_video(video_path: str):
    """Return (frame_count, first_frame_rgb) for a video, or 0 when it cannot be opened."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        status('Error opening video file')
        return 0
    amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    cap.release()
    # Delegate frame decoding to get_video_frame (which handles the RGB
    # conversion); the original read a frame here only to discard it and
    # decode the same frame again through get_video_frame.
    frame = get_video_frame(video_path)
    return (amount_of_frames, frame)
def status(message: str) -> None:
    """Print a status message and mirror it to the UI when not headless."""
    value = 'Status: ' + message
    print(value)
    if not roop.globals.headless:
        # keep the Tk status label in sync with the console output
        ui.update_status_label(value)
def conditional_process_video(source_path: str, frame_paths: List[str]) -> None: def conditional_process_video(source_path: str, frame_paths: List[str]) -> None:
pool_amount = len(frame_paths) // roop.globals.cpu_cores pool_amount = len(frame_paths) // roop.globals.cpu_cores
if pool_amount > 2 and roop.globals.cpu_cores > 1 and roop.globals.gpu_vendor is None: if pool_amount > 2 and roop.globals.cpu_cores > 1 and roop.globals.gpu_vendor is None:
status('Pool-Swapping in progress...') update_status('Pool-Swapping in progress...')
global POOL global POOL
POOL = multiprocessing.Pool(roop.globals.cpu_cores, maxtasksperchild=1) POOL = multiprocessing.Pool(roop.globals.cpu_cores, maxtasksperchild=1)
pools = [] pools = []
@@ -162,129 +133,89 @@ def conditional_process_video(source_path: str, frame_paths: List[str]) -> None:
pools.append(pool) pools.append(pool)
for pool in pools: for pool in pools:
pool.get() pool.get()
POOL.join()
POOL.close() POOL.close()
POOL.join()
else: else:
status('Swapping in progress...') update_status('Swapping in progress...')
process_video(args.source_path, frame_paths) process_video(roop.globals.source_path, frame_paths)
def start(preview_callback = None) -> None: def update_status(message: str):
if not args.source_path or not os.path.isfile(args.source_path): value = 'Status: ' + message
status('Please select an image containing a face.') print(value)
if not roop.globals.headless:
ui.update_status(value)
def start() -> None:
if not roop.globals.source_path or not os.path.isfile(roop.globals.source_path):
update_status('Please select an image containing a face.')
return return
elif not args.target_path or not os.path.isfile(args.target_path): elif not roop.globals.target_path or not os.path.isfile(roop.globals.target_path):
status('Please select a video/image target!') update_status('Please select a video/image target!')
return return
test_face = get_face_single(cv2.imread(args.source_path)) test_face = get_face_single(cv2.imread(roop.globals.source_path))
if not test_face: if not test_face:
status('No face detected in source image. Please try with another one!') update_status('No face detected in source image. Please try with another one!')
return return
# process image to image # process image to image
if has_image_extention(args.target_path): if has_image_extention(roop.globals.target_path):
if predict_image(args.target_path) > 0.85: if predict_image(roop.globals.target_path) > 0.85:
destroy() destroy()
process_img(args.source_path, args.target_path, args.output_path) process_img(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
if is_image(args.target_path): if is_image(roop.globals.target_path):
status('Swapping to image succeed!') update_status('Swapping to image succeed!')
else: else:
status('Swapping to image failed!') update_status('Swapping to image failed!')
return return
# process image to videos # process image to videos
seconds, probabilities = predict_video_frames(video_path=args.target_path, frame_interval=100) seconds, probabilities = predict_video_frames(video_path=roop.globals.target_path, frame_interval=100)
if any(probability > 0.85 for probability in probabilities): if any(probability > 0.85 for probability in probabilities):
destroy() destroy()
status('Creating temp resources...') update_status('Creating temp resources...')
create_temp(args.target_path) create_temp(roop.globals.target_path)
status('Extracting frames...') update_status('Extracting frames...')
extract_frames(args.target_path) extract_frames(roop.globals.target_path)
frame_paths = get_temp_frames_paths(args.target_path) frame_paths = get_temp_frames_paths(roop.globals.target_path)
conditional_process_video(args.source_path, frame_paths) conditional_process_video(roop.globals.source_path, frame_paths)
# prevent memory leak using ffmpeg with cuda # prevent memory leak using ffmpeg with cuda
if args.gpu_vendor == 'nvidia': if roop.globals.gpu_vendor == 'nvidia':
torch.cuda.empty_cache() torch.cuda.empty_cache()
if roop.globals.keep_fps: if roop.globals.keep_fps:
status('Detecting fps...') update_status('Detecting fps...')
fps = detect_fps(args.source_path) fps = detect_fps(roop.globals.source_path)
status(f'Creating video with {fps} fps...') update_status(f'Creating video with {fps} fps...')
create_video(args.target_path, fps) create_video(roop.globals.target_path, fps)
else: else:
status('Creating video with 30 fps...') update_status('Creating video with 30 fps...')
create_video(args.target_path, 30) create_video(roop.globals.target_path, 30)
if roop.globals.keep_audio: if roop.globals.keep_audio:
if roop.globals.keep_fps: if roop.globals.keep_fps:
status('Restoring audio...') update_status('Restoring audio...')
else: else:
status('Restoring audio might cause issues as fps are not kept...') update_status('Restoring audio might cause issues as fps are not kept...')
restore_audio(args.target_path, args.output_path) restore_audio(roop.globals.target_path, roop.globals.output_path)
else: else:
move_temp(args.target_path, args.output_path) move_temp(roop.globals.target_path, roop.globals.output_path)
clean_temp(args.target_path) clean_temp(roop.globals.target_path)
if is_video(args.target_path): if is_video(roop.globals.target_path):
status('Swapping to video succeed!') update_status('Swapping to video succeed!')
else: else:
status('Swapping to video failed!') update_status('Swapping to video failed!')
def select_face_handler(path: str) -> None:
    # UI callback: remember the chosen source face image path.
    args.source_path = path
def select_target_handler(path: str):
    # UI callback: store the chosen target path and return its preview
    # data (frame count and first frame) for the UI to display.
    args.target_path = path
    return preview_video(args.target_path)
def toggle_all_faces_handler(value: int) -> None:
    """Checkbox callback: process every face in the frame when checked (value == 1)."""
    # Direct comparison replaces the redundant `True if value == 1 else False`.
    roop.globals.all_faces = value == 1
def toggle_fps_limit_handler(value: int) -> None:
    # The checkbox means "limit FPS to 30", so keep_fps is its inverse.
    args.keep_fps = int(value != 1)
def toggle_keep_frames_handler(value: int) -> None:
    # UI callback: keep the extracted frames directory after processing.
    args.keep_frames = value
def save_file_handler(path: str) -> None:
    # UI callback: remember where the swapped output should be written.
    args.output_path = path
def create_test_preview(frame_number):
    # Swap the source face onto a single target-video frame so the UI
    # can show a one-frame preview before a full run.
    return process_faces(
        get_face_single(cv2.imread(args.source_path)),
        get_video_frame(args.target_path, frame_number)
    )
def destroy() -> None: def destroy() -> None:
clean_temp(args.target_path) if roop.globals.target_path:
clean_temp(roop.globals.target_path)
quit() quit()
def run() -> None: def run() -> None:
global all_faces, keep_frames, limit_fps parse_args()
handle_parse()
pre_check() pre_check()
limit_resources() limit_resources()
if roop.globals.headless: if roop.globals.headless:
start() start()
else: else:
window = ui.init( window = ui.init(start)
{
'all_faces': args.all_faces,
'keep_fps': args.keep_fps,
'keep_frames': args.keep_frames
},
select_face_handler,
select_target_handler,
toggle_all_faces_handler,
toggle_fps_limit_handler,
toggle_keep_frames_handler,
save_file_handler,
start,
get_video_frame,
create_test_preview
)
window.mainloop() window.mainloop()

View File

@@ -1,5 +1,8 @@
import onnxruntime import onnxruntime
source_path = None
target_path = None
output_path = None
keep_fps = None keep_fps = None
keep_audio = None keep_audio = None
keep_frames = None keep_frames = None
@@ -7,6 +10,7 @@ all_faces = None
cpu_cores = None cpu_cores = None
gpu_threads = None gpu_threads = None
gpu_vendor = None gpu_vendor = None
max_memory = None
headless = None headless = None
log_level = 'error' log_level = 'error'
providers = onnxruntime.get_available_providers() providers = onnxruntime.get_available_providers()

View File

@@ -1,11 +1,16 @@
import tkinter as tk import tkinter as tk
from typing import Any, Callable, Tuple from typing import Any, Callable, Tuple
import cv2
from PIL import Image, ImageTk, ImageOps from PIL import Image, ImageTk, ImageOps
import webbrowser import webbrowser
from tkinter import filedialog from tkinter import filedialog
from tkinter.filedialog import asksaveasfilename from tkinter.filedialog import asksaveasfilename
import threading import threading
import roop.globals
from roop.analyser import get_face_single
from roop.swapper import process_faces
from roop.utilities import is_image from roop.utilities import is_image
max_preview_size = 800 max_preview_size = 800
@@ -213,23 +218,12 @@ def preview_target(frame):
target_label.image = photo_img target_label.image = photo_img
def update_status_label(value): def update_status(value):
status_label["text"] = value status_label["text"] = value
window.update() window.update()
def init( def init(start: Callable[[], None]):
initial_values: dict,
select_face_handler: Callable[[str], None],
select_target_handler: Callable[[str], Tuple[int, Any]],
toggle_all_faces_handler: Callable[[int], None],
toggle_fps_limit_handler: Callable[[int], None],
toggle_keep_frames_handler: Callable[[int], None],
save_file_handler: Callable[[str], None],
start: Callable[[], None],
get_video_frame: Callable[[str, int], None],
create_test_preview: Callable[[int], Any],
):
global window, preview, preview_visible, face_label, target_label, status_label global window, preview, preview_visible, face_label, target_label, status_label
window = tk.Tk() window = tk.Tk()
@@ -274,22 +268,23 @@ def init(
target_button.place(x=360,y=320,width=180,height=80) target_button.place(x=360,y=320,width=180,height=80)
# All faces checkbox # All faces checkbox
all_faces = tk.IntVar(None, initial_values['all_faces']) all_faces = tk.IntVar(None, roop.globals.all_faces)
all_faces_checkbox = create_check(window, "Process all faces in frame", all_faces, toggle_all_faces(toggle_all_faces_handler, all_faces)) all_faces_checkbox = create_check(window, "Process all faces in frame", all_faces, toggle_all_faces(toggle_all_faces_handler, all_faces))
all_faces_checkbox.place(x=60,y=500,width=240,height=31) all_faces_checkbox.place(x=60,y=500,width=240,height=31)
# FPS limit checkbox # FPS limit checkbox
limit_fps = tk.IntVar(None, not initial_values['keep_fps']) limit_fps = tk.IntVar(None, not roop.globals.keep_fps)
fps_checkbox = create_check(window, "Limit FPS to 30", limit_fps, toggle_fps_limit(toggle_fps_limit_handler, limit_fps)) fps_checkbox = create_check(window, "Limit FPS to 30", limit_fps, toggle_fps_limit(toggle_fps_limit_handler, limit_fps))
fps_checkbox.place(x=60,y=475,width=240,height=31) fps_checkbox.place(x=60,y=475,width=240,height=31)
# Keep frames checkbox # Keep frames checkbox
keep_frames = tk.IntVar(None, initial_values['keep_frames']) keep_frames = tk.IntVar(None, roop.globals.keep_frames)
frames_checkbox = create_check(window, "Keep frames dir", keep_frames, toggle_keep_frames(toggle_keep_frames_handler, keep_frames)) frames_checkbox = create_check(window, "Keep frames dir", keep_frames, toggle_keep_frames(toggle_keep_frames_handler, keep_frames))
frames_checkbox.place(x=60,y=450,width=240,height=31) frames_checkbox.place(x=60,y=450,width=240,height=31)
# Start button # Start button
start_button = create_button(window, "Start", lambda: [save_file(save_file_handler, target_path.get()), preview_thread(lambda: start(update_preview))]) #start_button = create_button(window, "Start", lambda: [save_file(save_file_handler, target_path.get()), preview_thread(lambda: start(update_preview))])
start_button = create_button(window, "Start", lambda: [save_file(save_file_handler, target_path.get()), start])
start_button.place(x=170,y=560,width=120,height=49) start_button.place(x=170,y=560,width=120,height=49)
# Preview button # Preview button
@@ -301,3 +296,63 @@ def init(
status_label.place(x=10,y=640,width=580,height=30) status_label.place(x=10,y=640,width=580,height=30)
return window return window
def get_video_frame(video_path: str, frame_number: int = 1):
    """Return one frame of the video as an RGB array (1-based frame number).

    Out-of-range frame numbers are clamped to the last frame. Returns
    None when the video cannot be opened or the frame cannot be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    # Verify the capture opened before querying it -- the original probed
    # the frame count first, which silently fails on a bad path.
    if not cap.isOpened():
        update_status('Error opening video file')
        return None
    try:
        amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        cap.set(cv2.CAP_PROP_POS_FRAMES, min(amount_of_frames, frame_number - 1))
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes BGR; the Tk preview expects RGB.
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return None
    finally:
        # Always release the capture; the original leaked it whenever a
        # frame was returned (release came after the return statement).
        cap.release()
def preview_video(video_path: str):
    """Return (frame_count, first_frame_rgb), or 0 when the video cannot be opened."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        update_status('Error opening video file')
        return 0
    amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    cap.release()
    # get_video_frame handles decoding and the RGB conversion; the
    # original read a frame here only to throw it away and decode the
    # same frame again.
    frame = get_video_frame(video_path)
    return (amount_of_frames, frame)
def select_face_handler(path: str) -> None:
    # UI callback: remember the chosen source face image path.
    roop.globals.source_path = path
def select_target_handler(target_path: str) -> Any:
    # UI callback: store the target path and return the preview data
    # (frame count and first frame) from preview_video. NOTE(review):
    # the previous `-> None` annotation contradicted the return below.
    roop.globals.target_path = target_path
    return preview_video(roop.globals.target_path)
def toggle_all_faces_handler(value: int) -> None:
    """Checkbox callback: process every face in the frame when checked (value == 1)."""
    # Direct comparison replaces the redundant `True if value == 1 else False`.
    roop.globals.all_faces = value == 1
def toggle_fps_limit_handler(value: int) -> None:
    # The checkbox means "limit FPS to 30", so keep_fps is its inverse.
    roop.globals.keep_fps = int(value != 1)
def toggle_keep_frames_handler(value: int) -> None:
    # UI callback: keep the extracted frames directory after processing.
    roop.globals.keep_frames = value
def save_file_handler(path: str) -> None:
    # UI callback: remember where the swapped output should be written.
    roop.globals.output_path = path
def create_test_preview(frame_number):
    # Swap the source face onto a single target-video frame so the UI
    # can show a one-frame preview before a full run.
    return process_faces(
        get_face_single(cv2.imread(roop.globals.source_path)),
        get_video_frame(roop.globals.target_path, frame_number)
    )