fix multiprocessing

This commit is contained in:
Somdev Sangwan 2023-06-04 17:00:57 +05:30 committed by GitHub
parent c19f8125f2
commit c36d3a2009
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,27 +2,24 @@
import os import os
import sys import sys
# single thread doubles performance of gpu-mode - needs to be set before torch import
if any(arg.startswith('--gpu-vendor=') for arg in sys.argv):
os.environ['OMP_NUM_THREADS'] = '1'
import platform import platform
import signal import signal
import shutil import shutil
import psutil
import glob import glob
import argparse import argparse
import torch import torch
from pathlib import Path from pathlib import Path
import multiprocessing as mp
from opennsfw2 import predict_video_frames, predict_image from opennsfw2 import predict_video_frames, predict_image
import cv2 import cv2
import roop.globals import roop.globals
from roop.swapper import process_video, process_img, process_faces from roop.swapper import process_video, process_img, process_faces, process_frames
from roop.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace from roop.utils import is_img, detect_fps, set_fps, create_video, add_audio, extract_frames, rreplace
from roop.analyser import get_face_single from roop.analyser import get_face_single
import roop.ui as ui import roop.ui as ui
if 'ROCMExecutionProvider' in roop.globals.providers:
del torch
signal.signal(signal.SIGINT, lambda signal_number, frame: quit()) signal.signal(signal.SIGINT, lambda signal_number, frame: quit())
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -33,7 +30,7 @@ parser.add_argument('--keep-fps', help='maintain original fps', dest='keep_fps',
parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_frames', action='store_true', default=False) parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_frames', action='store_true', default=False)
parser.add_argument('--all-faces', help='swap all faces in frame', dest='all_faces', action='store_true', default=False) parser.add_argument('--all-faces', help='swap all faces in frame', dest='all_faces', action='store_true', default=False)
parser.add_argument('--max-memory', help='maximum amount of RAM in GB to be used', dest='max_memory', type=int) parser.add_argument('--max-memory', help='maximum amount of RAM in GB to be used', dest='max_memory', type=int)
parser.add_argument('--cpu-threads', help='number of threads to be use for CPU mode', dest='cpu_threads', type=int) parser.add_argument('--max-cores', help='number of cores to use at max', dest='max_cores', type=int)
parser.add_argument('--gpu-threads', help='number of threads to be use for GPU mode', dest='gpu_threads', type=int) parser.add_argument('--gpu-threads', help='number of threads to be use for GPU mode', dest='gpu_threads', type=int)
parser.add_argument('--gpu-vendor', help='choice your gpu vendor', dest='gpu_vendor', choices=['apple', 'amd', 'intel', 'nvidia']) parser.add_argument('--gpu-vendor', help='choice your gpu vendor', dest='gpu_vendor', choices=['apple', 'amd', 'intel', 'nvidia'])
@ -44,14 +41,18 @@ for name, value in vars(parser.parse_args()).items():
if 'all_faces' in args: if 'all_faces' in args:
roop.globals.all_faces = True roop.globals.all_faces = True
if 'cpu_threads' in args and args['cpu_threads']: if args['max_cores']:
roop.globals.cpu_threads = args['cpu_threads'] roop.globals.max_cores = args['max_cores']
else:
roop.globals.max_cores = psutil.cpu_count() - 1
if 'gpu_threads' in args and args['gpu_threads']: if args['gpu_threads']:
roop.globals.gpu_threads = args['gpu_threads'] roop.globals.gpu_threads = args['gpu_threads']
if 'gpu_vendor' in args and args['gpu_vendor']: if args['gpu_vendor']:
roop.globals.gpu_vendor = args['gpu_vendor'] roop.globals.gpu_vendor = args['gpu_vendor']
else:
roop.globals.providers = ['CPUExecutionProvider']
sep = "/" sep = "/"
if os.name == "nt": if os.name == "nt":
@ -137,6 +138,19 @@ def status(string):
ui.update_status_label(value) ui.update_status_label(value)
def process_video_multi_cores(source_img, frame_paths):
n = len(frame_paths) // roop.globals.max_cores
if n > 2:
processes = []
for i in range(0, len(frame_paths), n):
p = pool.apply_async(process_frames, args=(source_img, frame_paths[i:i+n],))
processes.append(p)
for p in processes:
p.get()
pool.close()
pool.join()
def start(preview_callback=None): def start(preview_callback=None):
if not args['source_img'] or not os.path.isfile(args['source_img']): if not args['source_img'] or not os.path.isfile(args['source_img']):
print("\n[WARNING] Please select an image containing a face.") print("\n[WARNING] Please select an image containing a face.")
@ -163,7 +177,7 @@ def start(preview_callback = None):
quit() quit()
video_name_full = target_path.split("/")[-1] video_name_full = target_path.split("/")[-1]
video_name = os.path.splitext(video_name_full)[0] video_name = os.path.splitext(video_name_full)[0]
output_dir = os.path.dirname(target_path) + "/" + video_name output_dir = os.path.dirname(target_path) + "/" + video_name if os.path.dirname(target_path) else video_name
Path(output_dir).mkdir(exist_ok=True) Path(output_dir).mkdir(exist_ok=True)
status("detecting video's FPS...") status("detecting video's FPS...")
fps, exact_fps = detect_fps(target_path) fps, exact_fps = detect_fps(target_path)
@ -180,6 +194,11 @@ def start(preview_callback = None):
key=lambda x: int(x.split(sep)[-1].replace(".png", "")) key=lambda x: int(x.split(sep)[-1].replace(".png", ""))
)) ))
status("swapping in progress...") status("swapping in progress...")
if sys.platform != 'darwin' and not args['gpu_vendor']:
global pool
pool = mp.Pool(roop.globals.max_cores)
process_video_multi_cores(args['source_img'], args['frame_paths'])
else:
process_video(args['source_img'], args["frame_paths"], preview_callback) process_video(args['source_img'], args["frame_paths"], preview_callback)
status("creating video...") status("creating video...")
create_video(video_name, exact_fps, output_dir) create_video(video_name, exact_fps, output_dir)