add warning for nvidia compute_capability

This commit is contained in:
bojiang
2024-07-02 14:49:55 +08:00
parent 7baac978fe
commit d3d08f20d6

View File

@@ -78,6 +78,7 @@ def get_local_machine_spec():
from pynvml import (
nvmlDeviceGetCount,
nvmlDeviceGetCudaComputeCapability,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlDeviceGetName,
@@ -98,6 +99,14 @@ def get_local_machine_spec():
model=name, memory_size=math.ceil(int(memory_info.total) / 1024**3)
)
)
compute_capability = nvmlDeviceGetCudaComputeCapability(handle)
if compute_capability < (7, 5):
output(
f"GPU {name.decode()} with compute capability {compute_capability} "
"may not be supported, 7.5 or higher is recommended. check "
"https://developer.nvidia.com/cuda-gpus for more information",
style="yellow",
)
nvmlShutdown()
return DeploymentTarget(
accelerators=accelerators,