Move every UI related thing to ui.py

This commit is contained in:
henryruhs 2023-06-06 19:29:00 +02:00
parent 80aec1cb3a
commit 4194584854
4 changed files with 156 additions and 165 deletions

View File

@ -1,10 +1,11 @@
from typing import Any
import insightface
import roop.globals
FACE_ANALYSER = None
def get_face_analyser():
def get_face_analyser() -> Any:
global FACE_ANALYSER
if FACE_ANALYSER is None:
FACE_ANALYSER = insightface.app.FaceAnalysis(name='buffalo_l', providers=roop.globals.providers)
@ -12,16 +13,16 @@ def get_face_analyser():
return FACE_ANALYSER
def get_face_single(img_data):
face = get_face_analyser().get(img_data)
def get_face_single(image_data) -> Any:
face = get_face_analyser().get(image_data)
try:
return sorted(face, key=lambda x: x.bbox[0])[0]
except IndexError:
return min(face, key=lambda x: x.bbox[0])
except ValueError:
return None
def get_face_many(img_data):
def get_face_many(image_data) -> Any:
try:
return get_face_analyser().get(img_data)
return get_face_analyser().get(image_data)
except IndexError:
return None

View File

@ -2,10 +2,11 @@
import os
import sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# single thread doubles performance of gpu-mode - needs to be set before torch import
if any(arg.startswith('--gpu-vendor') for arg in sys.argv):
os.environ['OMP_NUM_THREADS'] = '1'
# reduce tensorflow log level
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import warnings
from typing import List
import platform
@ -20,15 +21,18 @@ from opennsfw2 import predict_video_frames, predict_image
import cv2
import roop.globals
from roop.swapper import process_video, process_img, process_faces
import roop.ui as ui
from roop.swapper import process_video, process_img
from roop.utilities import has_image_extention, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frames_paths, restore_audio, create_temp, move_temp, clean_temp
from roop.analyser import get_face_single
import roop.ui as ui
if 'ROCMExecutionProvider' in roop.globals.providers:
del torch
warnings.simplefilter(action='ignore', category=FutureWarning)
def handle_parse():
global args
def parse_args() -> None:
signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--face', help='use this face', dest='source_path')
@ -45,6 +49,9 @@ def handle_parse():
args = parser.parse_known_args()[0]
roop.globals.source_path = args.source_path
roop.globals.target_path = args.target_path
roop.globals.output_path = args.output_path
roop.globals.headless = args.source_path or args.target_path or args.output_path
roop.globals.keep_fps = args.keep_fps
roop.globals.keep_audio = args.keep_audio
@ -76,8 +83,8 @@ def limit_resources():
gpus = tensorflow.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tensorflow.config.experimental.set_memory_growth(gpu, True)
if args.max_memory:
memory = args.max_memory * 1024 * 1024 * 1024
if roop.globals.max_memory:
memory = roop.globals.max_memory * 1024 * 1024 * 1024
if str(platform.system()).lower() == 'windows':
import ctypes
kernel32 = ctypes.windll.kernel32
@ -102,58 +109,22 @@ def pre_check():
if 'ROCMExecutionProvider' not in roop.globals.providers:
quit('You are using --gpu=amd flag but ROCM is not available or properly installed on your system.')
if roop.globals.gpu_vendor == 'nvidia':
CUDA_VERSION = torch.version.cuda
CUDNN_VERSION = torch.backends.cudnn.version()
if not torch.cuda.is_available():
quit('You are using --gpu=nvidia flag but CUDA is not available or properly installed on your system.')
if CUDA_VERSION > '11.8':
quit(f'CUDA version {CUDA_VERSION} is not supported - please downgrade to 11.8')
if CUDA_VERSION < '11.4':
quit(f'CUDA version {CUDA_VERSION} is not supported - please upgrade to 11.8')
if CUDNN_VERSION < 8220:
quit(f'CUDNN version {CUDNN_VERSION} is not supported - please upgrade to 8.9.1')
if CUDNN_VERSION > 8910:
quit(f'CUDNN version {CUDNN_VERSION} is not supported - please downgrade to 8.9.1')
def get_video_frame(video_path, frame_number = 1):
cap = cv2.VideoCapture(video_path)
amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
cap.set(cv2.CAP_PROP_POS_FRAMES, min(amount_of_frames, frame_number-1))
if not cap.isOpened():
status('Error opening video file')
return
ret, frame = cap.read()
if ret:
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
cap.release()
def preview_video(video_path):
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
status('Error opening video file')
return 0
amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
ret, frame = cap.read()
if ret:
frame = get_video_frame(video_path)
cap.release()
return (amount_of_frames, frame)
def status(message: str):
value = 'Status: ' + message
print(value)
if not roop.globals.headless:
ui.update_status_label(value)
if torch.version.cuda > '11.8':
quit(f'CUDA version {torch.version.cuda} is not supported - please downgrade to 11.8')
if torch.version.cuda < '11.4':
quit(f'CUDA version {torch.version.cuda} is not supported - please upgrade to 11.8')
if torch.backends.cudnn.version() < 8220:
quit(f'CUDNN version { torch.backends.cudnn.version()} is not supported - please upgrade to 8.9.1')
if torch.backends.cudnn.version() > 8910:
quit(f'CUDNN version { torch.backends.cudnn.version()} is not supported - please downgrade to 8.9.1')
def conditional_process_video(source_path: str, frame_paths: List[str]) -> None:
pool_amount = len(frame_paths) // roop.globals.cpu_cores
if pool_amount > 2 and roop.globals.cpu_cores > 1 and roop.globals.gpu_vendor is None:
status('Pool-Swapping in progress...')
update_status('Pool-Swapping in progress...')
global POOL
POOL = multiprocessing.Pool(roop.globals.cpu_cores, maxtasksperchild=1)
pools = []
@ -162,129 +133,89 @@ def conditional_process_video(source_path: str, frame_paths: List[str]) -> None:
pools.append(pool)
for pool in pools:
pool.get()
POOL.join()
POOL.close()
POOL.join()
else:
status('Swapping in progress...')
process_video(args.source_path, frame_paths)
update_status('Swapping in progress...')
process_video(roop.globals.source_path, frame_paths)
def start(preview_callback = None) -> None:
if not args.source_path or not os.path.isfile(args.source_path):
status('Please select an image containing a face.')
def update_status(message: str):
value = 'Status: ' + message
print(value)
if not roop.globals.headless:
ui.update_status(value)
def start() -> None:
if not roop.globals.source_path or not os.path.isfile(roop.globals.source_path):
update_status('Please select an image containing a face.')
return
elif not args.target_path or not os.path.isfile(args.target_path):
status('Please select a video/image target!')
elif not roop.globals.target_path or not os.path.isfile(roop.globals.target_path):
update_status('Please select a video/image target!')
return
test_face = get_face_single(cv2.imread(args.source_path))
test_face = get_face_single(cv2.imread(roop.globals.source_path))
if not test_face:
status('No face detected in source image. Please try with another one!')
update_status('No face detected in source image. Please try with another one!')
return
# process image to image
if has_image_extention(args.target_path):
if predict_image(args.target_path) > 0.85:
if has_image_extention(roop.globals.target_path):
if predict_image(roop.globals.target_path) > 0.85:
destroy()
process_img(args.source_path, args.target_path, args.output_path)
if is_image(args.target_path):
status('Swapping to image succeed!')
process_img(roop.globals.source_path, roop.globals.target_path, roop.globals.output_path)
if is_image(roop.globals.target_path):
update_status('Swapping to image succeed!')
else:
status('Swapping to image failed!')
update_status('Swapping to image failed!')
return
# process image to videos
seconds, probabilities = predict_video_frames(video_path=args.target_path, frame_interval=100)
seconds, probabilities = predict_video_frames(video_path=roop.globals.target_path, frame_interval=100)
if any(probability > 0.85 for probability in probabilities):
destroy()
status('Creating temp resources...')
create_temp(args.target_path)
status('Extracting frames...')
extract_frames(args.target_path)
frame_paths = get_temp_frames_paths(args.target_path)
conditional_process_video(args.source_path, frame_paths)
update_status('Creating temp resources...')
create_temp(roop.globals.target_path)
update_status('Extracting frames...')
extract_frames(roop.globals.target_path)
frame_paths = get_temp_frames_paths(roop.globals.target_path)
conditional_process_video(roop.globals.source_path, frame_paths)
# prevent memory leak using ffmpeg with cuda
if args.gpu_vendor == 'nvidia':
if roop.globals.gpu_vendor == 'nvidia':
torch.cuda.empty_cache()
if roop.globals.keep_fps:
status('Detecting fps...')
fps = detect_fps(args.source_path)
status(f'Creating video with {fps} fps...')
create_video(args.target_path, fps)
update_status('Detecting fps...')
fps = detect_fps(roop.globals.source_path)
update_status(f'Creating video with {fps} fps...')
create_video(roop.globals.target_path, fps)
else:
status('Creating video with 30 fps...')
create_video(args.target_path, 30)
update_status('Creating video with 30 fps...')
create_video(roop.globals.target_path, 30)
if roop.globals.keep_audio:
if roop.globals.keep_fps:
status('Restoring audio...')
update_status('Restoring audio...')
else:
status('Restoring audio might cause issues as fps are not kept...')
restore_audio(args.target_path, args.output_path)
update_status('Restoring audio might cause issues as fps are not kept...')
restore_audio(roop.globals.target_path, roop.globals.output_path)
else:
move_temp(args.target_path, args.output_path)
clean_temp(args.target_path)
if is_video(args.target_path):
status('Swapping to video succeed!')
move_temp(roop.globals.target_path, roop.globals.output_path)
clean_temp(roop.globals.target_path)
if is_video(roop.globals.target_path):
update_status('Swapping to video succeed!')
else:
status('Swapping to video failed!')
def select_face_handler(path: str):
args.source_path = path
def select_target_handler(path: str):
args.target_path = path
return preview_video(args.target_path)
def toggle_all_faces_handler(value: int):
roop.globals.all_faces = True if value == 1 else False
def toggle_fps_limit_handler(value: int):
args.keep_fps = int(value != 1)
def toggle_keep_frames_handler(value: int):
args.keep_frames = value
def save_file_handler(path: str):
args.output_path = path
def create_test_preview(frame_number):
return process_faces(
get_face_single(cv2.imread(args.source_path)),
get_video_frame(args.target_path, frame_number)
)
update_status('Swapping to video failed!')
def destroy() -> None:
clean_temp(args.target_path)
if roop.globals.target_path:
clean_temp(roop.globals.target_path)
quit()
def run() -> None:
global all_faces, keep_frames, limit_fps
handle_parse()
parse_args()
pre_check()
limit_resources()
if roop.globals.headless:
start()
else:
window = ui.init(
{
'all_faces': args.all_faces,
'keep_fps': args.keep_fps,
'keep_frames': args.keep_frames
},
select_face_handler,
select_target_handler,
toggle_all_faces_handler,
toggle_fps_limit_handler,
toggle_keep_frames_handler,
save_file_handler,
start,
get_video_frame,
create_test_preview
)
window = ui.init(start)
window.mainloop()

View File

@ -1,5 +1,8 @@
import onnxruntime
source_path = None
target_path = None
output_path = None
keep_fps = None
keep_audio = None
keep_frames = None
@ -7,6 +10,7 @@ all_faces = None
cpu_cores = None
gpu_threads = None
gpu_vendor = None
max_memory = None
headless = None
log_level = 'error'
providers = onnxruntime.get_available_providers()

View File

@ -1,11 +1,16 @@
import tkinter as tk
from typing import Any, Callable, Tuple
import cv2
from PIL import Image, ImageTk, ImageOps
import webbrowser
from tkinter import filedialog
from tkinter.filedialog import asksaveasfilename
import threading
import roop.globals
from roop.analyser import get_face_single
from roop.swapper import process_faces
from roop.utilities import is_image
max_preview_size = 800
@ -213,23 +218,12 @@ def preview_target(frame):
target_label.image = photo_img
def update_status_label(value):
def update_status(value):
status_label["text"] = value
window.update()
def init(
initial_values: dict,
select_face_handler: Callable[[str], None],
select_target_handler: Callable[[str], Tuple[int, Any]],
toggle_all_faces_handler: Callable[[int], None],
toggle_fps_limit_handler: Callable[[int], None],
toggle_keep_frames_handler: Callable[[int], None],
save_file_handler: Callable[[str], None],
start: Callable[[], None],
get_video_frame: Callable[[str, int], None],
create_test_preview: Callable[[int], Any],
):
def init(start: Callable[[], None]):
global window, preview, preview_visible, face_label, target_label, status_label
window = tk.Tk()
@ -274,22 +268,23 @@ def init(
target_button.place(x=360,y=320,width=180,height=80)
# All faces checkbox
all_faces = tk.IntVar(None, initial_values['all_faces'])
all_faces = tk.IntVar(None, roop.globals.all_faces)
all_faces_checkbox = create_check(window, "Process all faces in frame", all_faces, toggle_all_faces(toggle_all_faces_handler, all_faces))
all_faces_checkbox.place(x=60,y=500,width=240,height=31)
# FPS limit checkbox
limit_fps = tk.IntVar(None, not initial_values['keep_fps'])
limit_fps = tk.IntVar(None, not roop.globals.keep_fps)
fps_checkbox = create_check(window, "Limit FPS to 30", limit_fps, toggle_fps_limit(toggle_fps_limit_handler, limit_fps))
fps_checkbox.place(x=60,y=475,width=240,height=31)
# Keep frames checkbox
keep_frames = tk.IntVar(None, initial_values['keep_frames'])
keep_frames = tk.IntVar(None, roop.globals.keep_frames)
frames_checkbox = create_check(window, "Keep frames dir", keep_frames, toggle_keep_frames(toggle_keep_frames_handler, keep_frames))
frames_checkbox.place(x=60,y=450,width=240,height=31)
# Start button
start_button = create_button(window, "Start", lambda: [save_file(save_file_handler, target_path.get()), preview_thread(lambda: start(update_preview))])
#start_button = create_button(window, "Start", lambda: [save_file(save_file_handler, target_path.get()), preview_thread(lambda: start(update_preview))])
start_button = create_button(window, "Start", lambda: [save_file(save_file_handler, target_path.get()), start])
start_button.place(x=170,y=560,width=120,height=49)
# Preview button
@ -301,3 +296,63 @@ def init(
status_label.place(x=10,y=640,width=580,height=30)
return window
def get_video_frame(video_path, frame_number = 1):
cap = cv2.VideoCapture(video_path)
amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
cap.set(cv2.CAP_PROP_POS_FRAMES, min(amount_of_frames, frame_number-1))
if not cap.isOpened():
update_status('Error opening video file')
return
ret, frame = cap.read()
if ret:
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
cap.release()
def preview_video(video_path):
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
update_status('Error opening video file')
return 0
amount_of_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
ret, frame = cap.read()
if ret:
frame = get_video_frame(video_path)
cap.release()
return (amount_of_frames, frame)
def select_face_handler(path: str):
roop.globals.source_path = path
def select_target_handler(target_path: str) -> None:
roop.globals.target_path = target_path
return preview_video(roop.globals.target_path)
def toggle_all_faces_handler(value: int):
roop.globals.all_faces = True if value == 1 else False
def toggle_fps_limit_handler(value: int):
roop.globals.keep_fps = int(value != 1)
def toggle_keep_frames_handler(value: int):
roop.globals.keep_frames = value
def save_file_handler(path: str):
roop.globals.output_path = path
def create_test_preview(frame_number):
return process_faces(
get_face_single(cv2.imread(roop.globals.source_path)),
get_video_frame(roop.globals.target_path, frame_number)
)