tensorflow, jax, pytorch now working on mac metal

This commit is contained in:
aseembits93 2026-01-15 19:04:31 -08:00
parent 4d28c1779f
commit ee6872c317
3 changed files with 24 additions and 8 deletions

View file

@@ -1,13 +1,13 @@
""" """
Unit tests for JAX implementations of JIT-suitable functions. Unit tests for JAX implementations of JIT-suitable functions.
Tests run on CPU and CUDA devices. Tests run on CPU, CUDA, and Metal (Mac) devices.
""" """
import numpy as np import numpy as np
import pytest import pytest
jax = pytest.importorskip("jax") import jax
import jax.numpy as jnp import jax.numpy as jnp
from code_to_optimize.sample_code import ( from code_to_optimize.sample_code import (
@@ -32,6 +32,14 @@ def get_available_devices():
except RuntimeError: except RuntimeError:
pass pass
# Check for Metal (Mac)
try:
metal_devices = jax.devices("METAL")
if metal_devices:
devices.append("metal")
except RuntimeError:
pass
return devices return devices
@@ -44,6 +52,8 @@ def to_device(arr, device):
return jax.device_put(arr, jax.devices("cpu")[0]) return jax.device_put(arr, jax.devices("cpu")[0])
elif device == "cuda": elif device == "cuda":
return jax.device_put(arr, jax.devices("gpu")[0]) return jax.device_put(arr, jax.devices("gpu")[0])
elif device == "metal":
return jax.device_put(arr, jax.devices("METAL")[0])
return arr return arr

View file

@@ -1,9 +1,11 @@
""" """
Unit tests for TensorFlow implementations of JIT-suitable functions. Unit tests for TensorFlow implementations of JIT-suitable functions.
Tests run on CPU and CUDA devices. Tests run on CPU, CUDA, and Metal (Mac) devices.
""" """
import platform
import numpy as np import numpy as np
import pytest import pytest
@@ -20,10 +22,14 @@ def get_available_devices():
"""Return list of available TensorFlow devices for testing.""" """Return list of available TensorFlow devices for testing."""
devices = ["cpu"] devices = ["cpu"]
# Check for CUDA/GPU # Check for GPU devices
gpus = tf.config.list_physical_devices("GPU") gpus = tf.config.list_physical_devices("GPU")
if gpus: if gpus:
devices.append("cuda") # On macOS, GPUs are Metal devices; on other platforms, they're CUDA
if platform.system() == "Darwin":
devices.append("metal")
else:
devices.append("cuda")
return devices return devices
@@ -35,7 +41,7 @@ def run_on_device(func, device, *args, **kwargs):
"""Run a function on the specified device.""" """Run a function on the specified device."""
if device == "cpu": if device == "cpu":
device_name = "/CPU:0" device_name = "/CPU:0"
elif device == "cuda": elif device in ("cuda", "metal"):
device_name = "/GPU:0" device_name = "/GPU:0"
else: else:
device_name = "/CPU:0" device_name = "/CPU:0"
@@ -48,7 +54,7 @@ def to_tensor(arr, device, dtype=tf.float64):
"""Create a tensor on the specified device.""" """Create a tensor on the specified device."""
if device == "cpu": if device == "cpu":
device_name = "/CPU:0" device_name = "/CPU:0"
elif device == "cuda": elif device in ("cuda", "metal"):
device_name = "/GPU:0" device_name = "/GPU:0"
else: else:
device_name = "/CPU:0" device_name = "/CPU:0"

View file

@@ -7,7 +7,7 @@ Tests run on CPU, CUDA, and MPS devices.
import numpy as np import numpy as np
import pytest import pytest
torch = pytest.importorskip("torch") import torch
from code_to_optimize.sample_code import ( from code_to_optimize.sample_code import (
leapfrog_integration_torch, leapfrog_integration_torch,