mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
tensorflow, jax, pytorch now working on mac metal
This commit is contained in:
parent
4d28c1779f
commit
ee6872c317
3 changed files with 24 additions and 8 deletions
|
|
@ -1,13 +1,13 @@
|
|||
"""
|
||||
Unit tests for JAX implementations of JIT-suitable functions.
|
||||
|
||||
Tests run on CPU and CUDA devices.
|
||||
Tests run on CPU, CUDA, and Metal (Mac) devices.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
jax = pytest.importorskip("jax")
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from code_to_optimize.sample_code import (
|
||||
|
|
@ -32,6 +32,14 @@ def get_available_devices():
|
|||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# Check for Metal (Mac)
|
||||
try:
|
||||
metal_devices = jax.devices("METAL")
|
||||
if metal_devices:
|
||||
devices.append("metal")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
return devices
|
||||
|
||||
|
||||
|
|
@ -44,6 +52,8 @@ def to_device(arr, device):
|
|||
return jax.device_put(arr, jax.devices("cpu")[0])
|
||||
elif device == "cuda":
|
||||
return jax.device_put(arr, jax.devices("gpu")[0])
|
||||
elif device == "metal":
|
||||
return jax.device_put(arr, jax.devices("METAL")[0])
|
||||
return arr
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
"""
|
||||
Unit tests for TensorFlow implementations of JIT-suitable functions.
|
||||
|
||||
Tests run on CPU and CUDA devices.
|
||||
Tests run on CPU, CUDA, and Metal (Mac) devices.
|
||||
"""
|
||||
|
||||
import platform
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
|
@ -20,10 +22,14 @@ def get_available_devices():
|
|||
"""Return list of available TensorFlow devices for testing."""
|
||||
devices = ["cpu"]
|
||||
|
||||
# Check for CUDA/GPU
|
||||
# Check for GPU devices
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
if gpus:
|
||||
devices.append("cuda")
|
||||
# On macOS, GPUs are Metal devices; on other platforms, they're CUDA
|
||||
if platform.system() == "Darwin":
|
||||
devices.append("metal")
|
||||
else:
|
||||
devices.append("cuda")
|
||||
|
||||
return devices
|
||||
|
||||
|
|
@ -35,7 +41,7 @@ def run_on_device(func, device, *args, **kwargs):
|
|||
"""Run a function on the specified device."""
|
||||
if device == "cpu":
|
||||
device_name = "/CPU:0"
|
||||
elif device == "cuda":
|
||||
elif device in ("cuda", "metal"):
|
||||
device_name = "/GPU:0"
|
||||
else:
|
||||
device_name = "/CPU:0"
|
||||
|
|
@ -48,7 +54,7 @@ def to_tensor(arr, device, dtype=tf.float64):
|
|||
"""Create a tensor on the specified device."""
|
||||
if device == "cpu":
|
||||
device_name = "/CPU:0"
|
||||
elif device == "cuda":
|
||||
elif device in ("cuda", "metal"):
|
||||
device_name = "/GPU:0"
|
||||
else:
|
||||
device_name = "/CPU:0"
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ Tests run on CPU, CUDA, and MPS devices.
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
torch = pytest.importorskip("torch")
|
||||
import torch
|
||||
|
||||
from code_to_optimize.sample_code import (
|
||||
leapfrog_integration_torch,
|
||||
|
|
|
|||
Loading…
Reference in a new issue