From 54f800dd0e1dcec145601b49e7f3931c910f1166 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Tue, 30 May 2023 00:40:02 +0200 Subject: [PATCH] Add GPU support, Quit on missing model, Remove globals (sorry) --- core/config.py | 3 +-- core/globals.py | 4 ---- core/processor.py | 9 +++++++-- core/utils.py | 2 ++ requirements.txt | 3 ++- run.py | 10 +++++----- 6 files changed, 17 insertions(+), 14 deletions(-) delete mode 100644 core/globals.py mode change 100644 => 100755 run.py diff --git a/core/config.py b/core/config.py index bbaad5e..df1c3ca 100644 --- a/core/config.py +++ b/core/config.py @@ -1,8 +1,7 @@ import insightface import onnxruntime -import core.globals -face_analyser = insightface.app.FaceAnalysis(name='buffalo_l', providers=core.globals.providers) +face_analyser = insightface.app.FaceAnalysis(name='buffalo_l', providers=onnxruntime.get_available_providers()) face_analyser.prepare(ctx_id=0, det_size=(640, 640)) diff --git a/core/globals.py b/core/globals.py deleted file mode 100644 index 83d7277..0000000 --- a/core/globals.py +++ /dev/null @@ -1,4 +0,0 @@ -import onnxruntime - -use_gpu = False -providers = onnxruntime.get_available_providers() diff --git a/core/processor.py b/core/processor.py index eb75975..29dc082 100644 --- a/core/processor.py +++ b/core/processor.py @@ -1,11 +1,15 @@ +import os import cv2 import insightface import onnxruntime -import core.globals from core.config import get_face from core.utils import rreplace -face_swapper = insightface.model_zoo.get_model('inswapper_128.onnx', providers=core.globals.providers) +if os.path.isfile('inswapper_128.onnx'): + face_swapper = insightface.model_zoo.get_model('inswapper_128.onnx', providers=onnxruntime.get_available_providers()) +else: + quit('File "inswapper_128.onnx" does not exist!') + def process_video(source_img, frame_paths): @@ -25,6 +29,7 @@ def process_video(source_img, frame_paths): pass print(flush=True) + def process_img(source_img, target_path): frame = cv2.imread(target_path) face = get_face(frame) diff --git a/core/utils.py b/core/utils.py index 9c5cd75..228d83b 100644 --- a/core/utils.py +++ b/core/utils.py @@ -17,6 +17,7 @@ def run_command(command, mode="silent"): return os.system(command) return os.popen(command).read() + def detect_fps(input_path): input_path = path(input_path) output = os.popen(f'ffprobe -v error -select_streams v -of default=noprint_wrappers=1:nokey=1 -show_entries stream=r_frame_rate "{input_path}"').read() @@ -58,6 +59,7 @@ def add_audio(output_dir, target_path, keep_frames, output_file): def is_img(path): return path.lower().endswith(("png", "jpg", "jpeg", "bmp")) + def rreplace(s, old, new, occurrence): li = s.rsplit(old, occurrence) return new.join(li) diff --git a/requirements.txt b/requirements.txt index b4d8baa..96153a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ numpy==1.24.3 opencv-python==4.7.0.72 onnx==1.14.0 -onnxruntime==1.15.0 +onnxruntime-gpu==1.15.0 insightface==0.7.3 psutil==5.9.5 tk==0.1.0 +torch==2.0.1 \ No newline at end of file diff --git a/run.py b/run.py old mode 100644 new mode 100755 index 45badee..a47916d --- a/run.py +++ b/run.py @@ -1,12 +1,11 @@ -import sys +#!/usr/bin/env python3 + +import torch import shutil -import core.globals if not shutil.which('ffmpeg'): print('ffmpeg is not installed. Read the docs: https://github.com/s0md3v/roop#installation.\n' * 10) quit() -if '--gpu' not in sys.argv: - core.globals.providers = ['CPUExecutionProvider'] import glob import argparse @@ -31,7 +30,8 @@ parser.add_argument('-f', '--face', help='use this face', dest='source_img') parser.add_argument('-t', '--target', help='replace this face', dest='target_path') parser.add_argument('-o', '--output', help='save output to this file', dest='output_file') parser.add_argument('--keep-fps', help='maintain original fps', dest='keep_fps', action='store_true', default=False) -parser.add_argument('--gpu', help='use gpu', dest='gpu', action='store_true', default=False) +if torch.cuda.is_available(): + parser.add_argument('--gpu', help='use gpu', dest='gpu', action='store_true', default=False) parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_frames', action='store_true', default=False) for name, value in vars(parser.parse_args()).items():