diff --git a/run.py b/run.py index d94ff6a..9ad2a13 100755 --- a/run.py +++ b/run.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import sys import time -import torch import shutil import core.globals @@ -11,6 +10,11 @@ if not shutil.which('ffmpeg'): if '--gpu' not in sys.argv: core.globals.providers = ['CPUExecutionProvider'] +if 'ROCMExecutionProvider' not in core.globals.providers: + import torch + if not torch.cuda.is_available(): + quit("You are using --gpu flag but CUDA isn't available on your system.") + import glob import argparse import multiprocessing as mp @@ -42,8 +46,6 @@ parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_fr for name, value in vars(parser.parse_args()).items(): args[name] = value -if not torch.cuda.is_available() and args['gpu']: - quit("You are using --gpu flag but CUDA isn't available on your system.") sep = "/" if os.name == "nt":