56 lines
1.4 KiB
Python
56 lines
1.4 KiB
Python
|
import os
|
||
|
import random
|
||
|
import unittest
|
||
|
from distutils.util import strtobool
|
||
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
# Module-level random generator; `floats_tensor` falls back to it when no `rng` is passed.
global_rng = random.Random()
|
||
|
# Pick the device name once at import time: "cuda" when torch reports CUDA as available, else "cpu".
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"
|
||
|
|
||
|
|
||
|
def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment variable `key`.

    Args:
        key: Name of the environment variable to look up.
        default: Value returned unchanged when the variable is not set.

    Returns:
        `default` if `key` is not set; otherwise `True` or `False` depending on the variable's value
        (compared case-insensitively).

    Raises:
        ValueError: If the variable is set to something that is neither a recognised truthy nor falsy string.
    """
    # The accepted spellings mirror `distutils.util.strtobool`. They are parsed locally because
    # `distutils` was removed from the standard library in Python 3.12.
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")

    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default

    # KEY is set, convert it to True or False.
    normalized = value.lower()
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    # More values are supported, but let's keep the message simple.
    raise ValueError(f"If set, {key} must be yes or no.")
|
||
|
|
||
|
|
||
|
# Evaluated once at import time: whether the RUN_SLOW environment variable enables slow tests (off by default).
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
||
|
|
||
|
|
||
|
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Build a float32 tensor of the given shape filled with random values.

    Each element is `rng.random() * scale`, drawn one at a time in flat order.

    Args:
        shape: Iterable of dimension sizes for the result.
        scale: Multiplier applied to every drawn value.
        rng: Object exposing `random()`; the module-level `global_rng` is used when omitted.
        name: Accepted but not used by this function.

    Returns:
        A contiguous `torch.float` tensor viewed as `shape`.
    """
    source = global_rng if rng is None else rng

    # Number of elements is the product of all dimension sizes (1 for an empty shape).
    element_count = 1
    for extent in shape:
        element_count *= extent

    samples = [source.random() * scale for _ in range(element_count)]

    flat = torch.tensor(data=samples, dtype=torch.float)
    return flat.view(shape).contiguous()
|
||
|
|
||
|
|
||
|
def slow(test_case):
    """
    Decorator that flags a test as slow.

    The decorated test is skipped with the reason "test is slow" unless the module-level `_run_slow_tests` flag
    is truthy. That flag is read from the RUN_SLOW environment variable, so slow tests do not run by default.

    """
    skip_unless_slow_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_slow_enabled(test_case)
|