tensorflow, jax, pytorch now working on mac metal

This commit is contained in:
aseembits93 2026-01-15 19:04:31 -08:00
parent 4d28c1779f
commit ee6872c317
3 changed files with 24 additions and 8 deletions

View file

@ -1,13 +1,13 @@
"""
Unit tests for JAX implementations of JIT-suitable functions.
Tests run on CPU and CUDA devices.
Tests run on CPU, CUDA, and Metal (Mac) devices.
"""
import numpy as np
import pytest
jax = pytest.importorskip("jax")
import jax
import jax.numpy as jnp
from code_to_optimize.sample_code import (
@ -32,6 +32,14 @@ def get_available_devices():
except RuntimeError:
pass
# Check for Metal (Mac)
try:
metal_devices = jax.devices("METAL")
if metal_devices:
devices.append("metal")
except RuntimeError:
pass
return devices
@ -44,6 +52,8 @@ def to_device(arr, device):
return jax.device_put(arr, jax.devices("cpu")[0])
elif device == "cuda":
return jax.device_put(arr, jax.devices("gpu")[0])
elif device == "metal":
return jax.device_put(arr, jax.devices("METAL")[0])
return arr

View file

@ -1,9 +1,11 @@
"""
Unit tests for TensorFlow implementations of JIT-suitable functions.
Tests run on CPU and CUDA devices.
Tests run on CPU, CUDA, and Metal (Mac) devices.
"""
import platform
import numpy as np
import pytest
@ -20,10 +22,14 @@ def get_available_devices():
"""Return list of available TensorFlow devices for testing."""
devices = ["cpu"]
# Check for CUDA/GPU
# Check for GPU devices
gpus = tf.config.list_physical_devices("GPU")
if gpus:
devices.append("cuda")
# On macOS, GPUs are Metal devices; on other platforms, they're CUDA
if platform.system() == "Darwin":
devices.append("metal")
else:
devices.append("cuda")
return devices
@ -35,7 +41,7 @@ def run_on_device(func, device, *args, **kwargs):
"""Run a function on the specified device."""
if device == "cpu":
device_name = "/CPU:0"
elif device == "cuda":
elif device in ("cuda", "metal"):
device_name = "/GPU:0"
else:
device_name = "/CPU:0"
@ -48,7 +54,7 @@ def to_tensor(arr, device, dtype=tf.float64):
"""Create a tensor on the specified device."""
if device == "cpu":
device_name = "/CPU:0"
elif device == "cuda":
elif device in ("cuda", "metal"):
device_name = "/GPU:0"
else:
device_name = "/CPU:0"

View file

@ -7,7 +7,7 @@ Tests run on CPU, CUDA, and MPS devices.
import numpy as np
import pytest
torch = pytest.importorskip("torch")
import torch
from code_to_optimize.sample_code import (
leapfrog_integration_torch,