From a2b7394d8eb9a52c4dc9c32489f18482df1540ed Mon Sep 17 00:00:00 2001 From: Somdev Sangwan Date: Tue, 30 May 2023 07:24:10 +0530 Subject: [PATCH] don't import torch in AMD --- run.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/run.py b/run.py index d94ff6a..9ad2a13 100755 --- a/run.py +++ b/run.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import sys import time -import torch import shutil import core.globals @@ -11,6 +10,11 @@ if not shutil.which('ffmpeg'): if '--gpu' not in sys.argv: core.globals.providers = ['CPUExecutionProvider'] +if 'ROCMExecutionProvider' not in core.globals.providers: + import torch + if not torch.cuda.is_available(): + quit("You are using --gpu flag but CUDA isn't available on your system.") + import glob import argparse import multiprocessing as mp @@ -42,8 +46,6 @@ parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_fr for name, value in vars(parser.parse_args()).items(): args[name] = value -if not torch.cuda.is_available() and args['gpu']: - quit("You are using --gpu flag but CUDA isn't available on your system.") sep = "/" if os.name == "nt":