don't import torch in AMD
This commit is contained in:
8
run.py
8
run.py
@@ -1,7 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import torch
|
|
||||||
import shutil
|
import shutil
|
||||||
import core.globals
|
import core.globals
|
||||||
|
|
||||||
@@ -11,6 +10,11 @@ if not shutil.which('ffmpeg'):
|
|||||||
if '--gpu' not in sys.argv:
|
if '--gpu' not in sys.argv:
|
||||||
core.globals.providers = ['CPUExecutionProvider']
|
core.globals.providers = ['CPUExecutionProvider']
|
||||||
|
|
||||||
|
if 'ROCMExecutionProvider' not in core.globals.providers:
|
||||||
|
import torch
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
quit("You are using --gpu flag but CUDA isn't available on your system.")
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
import argparse
|
import argparse
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
@@ -42,8 +46,6 @@ parser.add_argument('--keep-frames', help='keep frames directory', dest='keep_fr
|
|||||||
for name, value in vars(parser.parse_args()).items():
|
for name, value in vars(parser.parse_args()).items():
|
||||||
args[name] = value
|
args[name] = value
|
||||||
|
|
||||||
if not torch.cuda.is_available() and args['gpu']:
|
|
||||||
quit("You are using --gpu flag but CUDA isn't available on your system.")
|
|
||||||
|
|
||||||
sep = "/"
|
sep = "/"
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
|
|||||||
Reference in New Issue
Block a user