tensorflow, jax, pytorch now working on mac metal
This commit is contained in:
parent
4d28c1779f
commit
ee6872c317
3 changed files with 24 additions and 8 deletions
|
|
@ -1,13 +1,13 @@
|
||||||
"""
|
"""
|
||||||
Unit tests for JAX implementations of JIT-suitable functions.
|
Unit tests for JAX implementations of JIT-suitable functions.
|
||||||
|
|
||||||
Tests run on CPU and CUDA devices.
|
Tests run on CPU, CUDA, and Metal (Mac) devices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
jax = pytest.importorskip("jax")
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
from code_to_optimize.sample_code import (
|
from code_to_optimize.sample_code import (
|
||||||
|
|
@ -32,6 +32,14 @@ def get_available_devices():
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Check for Metal (Mac)
|
||||||
|
try:
|
||||||
|
metal_devices = jax.devices("METAL")
|
||||||
|
if metal_devices:
|
||||||
|
devices.append("metal")
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -44,6 +52,8 @@ def to_device(arr, device):
|
||||||
return jax.device_put(arr, jax.devices("cpu")[0])
|
return jax.device_put(arr, jax.devices("cpu")[0])
|
||||||
elif device == "cuda":
|
elif device == "cuda":
|
||||||
return jax.device_put(arr, jax.devices("gpu")[0])
|
return jax.device_put(arr, jax.devices("gpu")[0])
|
||||||
|
elif device == "metal":
|
||||||
|
return jax.device_put(arr, jax.devices("METAL")[0])
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,11 @@
|
||||||
"""
|
"""
|
||||||
Unit tests for TensorFlow implementations of JIT-suitable functions.
|
Unit tests for TensorFlow implementations of JIT-suitable functions.
|
||||||
|
|
||||||
Tests run on CPU and CUDA devices.
|
Tests run on CPU, CUDA, and Metal (Mac) devices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import platform
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -20,10 +22,14 @@ def get_available_devices():
|
||||||
"""Return list of available TensorFlow devices for testing."""
|
"""Return list of available TensorFlow devices for testing."""
|
||||||
devices = ["cpu"]
|
devices = ["cpu"]
|
||||||
|
|
||||||
# Check for CUDA/GPU
|
# Check for GPU devices
|
||||||
gpus = tf.config.list_physical_devices("GPU")
|
gpus = tf.config.list_physical_devices("GPU")
|
||||||
if gpus:
|
if gpus:
|
||||||
devices.append("cuda")
|
# On macOS, GPUs are Metal devices; on other platforms, they're CUDA
|
||||||
|
if platform.system() == "Darwin":
|
||||||
|
devices.append("metal")
|
||||||
|
else:
|
||||||
|
devices.append("cuda")
|
||||||
|
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
|
@ -35,7 +41,7 @@ def run_on_device(func, device, *args, **kwargs):
|
||||||
"""Run a function on the specified device."""
|
"""Run a function on the specified device."""
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
device_name = "/CPU:0"
|
device_name = "/CPU:0"
|
||||||
elif device == "cuda":
|
elif device in ("cuda", "metal"):
|
||||||
device_name = "/GPU:0"
|
device_name = "/GPU:0"
|
||||||
else:
|
else:
|
||||||
device_name = "/CPU:0"
|
device_name = "/CPU:0"
|
||||||
|
|
@ -48,7 +54,7 @@ def to_tensor(arr, device, dtype=tf.float64):
|
||||||
"""Create a tensor on the specified device."""
|
"""Create a tensor on the specified device."""
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
device_name = "/CPU:0"
|
device_name = "/CPU:0"
|
||||||
elif device == "cuda":
|
elif device in ("cuda", "metal"):
|
||||||
device_name = "/GPU:0"
|
device_name = "/GPU:0"
|
||||||
else:
|
else:
|
||||||
device_name = "/CPU:0"
|
device_name = "/CPU:0"
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ Tests run on CPU, CUDA, and MPS devices.
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
torch = pytest.importorskip("torch")
|
import torch
|
||||||
|
|
||||||
from code_to_optimize.sample_code import (
|
from code_to_optimize.sample_code import (
|
||||||
leapfrog_integration_torch,
|
leapfrog_integration_torch,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue