multi-face support

This commit is contained in:
Somdev Sangwan 2023-06-02 14:40:32 +05:30
parent 6a694138de
commit 94e26909aa

8
run.py
View File

@ -87,8 +87,6 @@ def pre_check():
quit(f"CUDNN version {CUDNN_VERSION} is not supported - please downgrade to 8.9.1") quit(f"CUDNN version {CUDNN_VERSION} is not supported - please downgrade to 8.9.1")
else: else:
core.globals.providers = ['CPUExecutionProvider'] core.globals.providers = ['CPUExecutionProvider']
if '--all-faces' in sys.argv or '-a' in sys.argv:
core.globals.all_faces = True
def start_processing(): def start_processing():
@ -193,17 +191,19 @@ def start():
print("\n[WARNING] No face detected in source image. Please try with another one.\n") print("\n[WARNING] No face detected in source image. Please try with another one.\n")
return return
if is_img(target_path): if is_img(target_path):
if predict_image(target_path) > 0.7: if predict_image(target_path) > 0.85:
quit() quit()
process_img(args['source_img'], target_path, args['output_file']) process_img(args['source_img'], target_path, args['output_file'])
status("swap successful!") status("swap successful!")
return return
seconds, probabilities = predict_video_frames(video_path=args['target_path'], frame_interval=100) seconds, probabilities = predict_video_frames(video_path=args['target_path'], frame_interval=100)
if any(probability > 0.7 for probability in probabilities): if any(probability > 0.85 for probability in probabilities):
quit() quit()
video_name_full = target_path.split("/")[-1] video_name_full = target_path.split("/")[-1]
video_name = os.path.splitext(video_name_full)[0] video_name = os.path.splitext(video_name_full)[0]
output_dir = os.path.dirname(target_path) + "/" + video_name output_dir = os.path.dirname(target_path) + "/" + video_name
if output_dir.startswith("/"):
output_dir = "." + output_dir
Path(output_dir).mkdir(exist_ok=True) Path(output_dir).mkdir(exist_ok=True)
status("detecting video's FPS...") status("detecting video's FPS...")
fps, exact_fps = detect_fps(target_path) fps, exact_fps = detect_fps(target_path)