reworked tests
This commit is contained in:
parent
a8f5f1bfe3
commit
0d77ea4a0b
2 changed files with 372 additions and 297 deletions
|
|
@ -183,14 +183,16 @@ def extract_code_string_context_from_files(
|
|||
helpers_of_helpers_qualified_names = {
|
||||
func.qualified_name for func in helpers_of_helpers.get(file_path, set())
|
||||
}
|
||||
code_without_unused_defs = remove_unused_definitions_by_function_names(
|
||||
original_code, qualified_function_names | helpers_of_helpers_qualified_names
|
||||
)
|
||||
code_context = parse_code_and_prune_cst(
|
||||
original_code,
|
||||
code_without_unused_defs,
|
||||
code_context_type,
|
||||
qualified_function_names,
|
||||
helpers_of_helpers_qualified_names,
|
||||
remove_docstrings,
|
||||
)
|
||||
code_context = remove_unused_definitions_by_function_names(code_context, qualified_function_names | helpers_of_helpers_qualified_names)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
|
@ -215,10 +217,10 @@ def extract_code_string_context_from_files(
|
|||
continue
|
||||
try:
|
||||
qualified_helper_function_names = {func.qualified_name for func in helper_function_sources}
|
||||
code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, qualified_helper_function_names)
|
||||
code_context = parse_code_and_prune_cst(
|
||||
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
|
||||
code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings
|
||||
)
|
||||
code_context = remove_unused_definitions_by_function_names(code_context, qualified_helper_function_names)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
|
@ -285,16 +287,16 @@ def extract_code_markdown_context_from_files(
|
|||
helpers_of_helpers_qualified_names = {
|
||||
func.qualified_name for func in helpers_of_helpers.get(file_path, set())
|
||||
}
|
||||
code_without_unused_defs = remove_unused_definitions_by_function_names(
|
||||
original_code, qualified_function_names | helpers_of_helpers_qualified_names
|
||||
)
|
||||
code_context = parse_code_and_prune_cst(
|
||||
original_code,
|
||||
code_without_unused_defs,
|
||||
code_context_type,
|
||||
qualified_function_names,
|
||||
helpers_of_helpers_qualified_names,
|
||||
remove_docstrings,
|
||||
)
|
||||
code_context = remove_unused_definitions_by_function_names(
|
||||
code_context, qualified_function_names | helpers_of_helpers_qualified_names
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
|
|
@ -323,11 +325,9 @@ def extract_code_markdown_context_from_files(
|
|||
continue
|
||||
try:
|
||||
qualified_helper_function_names = {func.qualified_name for func in helper_function_sources}
|
||||
code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, qualified_helper_function_names)
|
||||
code_context = parse_code_and_prune_cst(
|
||||
original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings
|
||||
)
|
||||
code_context = remove_unused_definitions_by_function_names(
|
||||
code_context, qualified_helper_function_names
|
||||
code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
|
|
|
|||
|
|
@ -1363,254 +1363,133 @@ def function_to_optimize():
|
|||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
|
||||
def test_comfy_module_import() -> None:
|
||||
code = '''
|
||||
import model_management
|
||||
def test_module_import_optimization() -> None:
|
||||
main_code = '''
|
||||
import utility_module
|
||||
|
||||
class HunyuanVideoClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_llama = model_management.pick_weight_dtype(dtype_llama, dtype, device)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_llama])
|
||||
class Calculator:
|
||||
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
||||
# This is where we use the imported module
|
||||
self.precision = utility_module.select_precision(precision, fallback_precision)
|
||||
self.mode = mode
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.clip_l.set_clip_options(options)
|
||||
self.llama.set_clip_options(options)
|
||||
# Using variables from the utility module
|
||||
self.backend = utility_module.CALCULATION_BACKEND
|
||||
self.system = utility_module.SYSTEM_TYPE
|
||||
self.default_precision = utility_module.DEFAULT_PRECISION
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.clip_l.reset_clip_options()
|
||||
self.llama.reset_clip_options()
|
||||
def add(self, a, b):
|
||||
return a + b
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
token_weight_pairs_llama = token_weight_pairs["llama"]
|
||||
def subtract(self, a, b):
|
||||
return a - b
|
||||
|
||||
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||
|
||||
template_end = 0
|
||||
extra_template_end = 0
|
||||
extra_sizes = 0
|
||||
user_end = 9999999999999
|
||||
images = []
|
||||
|
||||
tok_pairs = token_weight_pairs_llama[0]
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem):
|
||||
if isinstance(elem, numbers.Integral):
|
||||
if elem == 128006:
|
||||
if tok_pairs[i + 1][0] == 882:
|
||||
if tok_pairs[i + 2][0] == 128007:
|
||||
template_end = i + 2
|
||||
user_end = -1
|
||||
if elem == 128009 and user_end == -1:
|
||||
user_end = i + 1
|
||||
else:
|
||||
if elem.get("original_type") == "image":
|
||||
elem_size = elem.get("data").shape[0]
|
||||
if template_end > 0:
|
||||
if user_end == -1:
|
||||
extra_template_end += elem_size - 1
|
||||
else:
|
||||
image_start = i + extra_sizes
|
||||
image_end = i + elem_size + extra_sizes
|
||||
images.append((image_start, image_end, elem.get("image_interleave", 1)))
|
||||
extra_sizes += elem_size - 1
|
||||
|
||||
if llama_out.shape[1] > (template_end + 2):
|
||||
if tok_pairs[template_end + 1][0] == 271:
|
||||
template_end += 2
|
||||
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
||||
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
|
||||
|
||||
if len(images) > 0:
|
||||
out = []
|
||||
for i in images:
|
||||
out.append(llama_out[:, i[0]: i[1]: i[2]])
|
||||
llama_output = torch.cat(out + [llama_output], dim=1)
|
||||
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return llama_output, l_pooled, llama_extra_out
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
return self.clip_l.load_sd(sd)
|
||||
def calculate(self, operation, x, y):
|
||||
if operation == "add":
|
||||
return self.add(x, y)
|
||||
elif operation == "subtract":
|
||||
return self.subtract(x, y)
|
||||
else:
|
||||
return self.llama.load_sd(sd)
|
||||
return None
|
||||
'''
|
||||
model_management_code = '''
|
||||
import psutil
|
||||
import logging
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import torch
|
||||
|
||||
utility_module_code = '''
|
||||
import sys
|
||||
import platform
|
||||
import weakref
|
||||
import gc
|
||||
import logging
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
NO_VRAM = 1 #Very low vram: enable all the options to save vram
|
||||
LOW_VRAM = 2
|
||||
NORMAL_VRAM = 3
|
||||
HIGH_VRAM = 4
|
||||
SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
|
||||
DEFAULT_PRECISION = "medium"
|
||||
DEFAULT_MODE = "standard"
|
||||
|
||||
class CPUState(Enum):
|
||||
GPU = 0
|
||||
CPU = 1
|
||||
MPS = 2
|
||||
|
||||
# Determine VRAM State
|
||||
vram_state = VRAMState.NORMAL_VRAM
|
||||
set_vram_to = VRAMState.NORMAL_VRAM
|
||||
cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
def get_supported_float8_types():
|
||||
float8_types = []
|
||||
try:
|
||||
float8_types.append(torch.float8_e4m3fn)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
float8_types.append(torch.float8_e4m3fnuz)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
float8_types.append(torch.float8_e5m2)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
float8_types.append(torch.float8_e5m2fnuz)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
float8_types.append(torch.float8_e8m0fnu)
|
||||
except:
|
||||
pass
|
||||
return float8_types
|
||||
|
||||
FLOAT8_TYPES = get_supported_float8_types()
|
||||
|
||||
xpu_available = False
|
||||
torch_version = ""
|
||||
# Try-except block with variable definitions
|
||||
try:
|
||||
torch_version = torch.version.__version__
|
||||
temp = torch_version.split(".")
|
||||
torch_version_numeric = (int(temp[0]), int(temp[1]))
|
||||
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
import numpy as np
|
||||
# Used variable in try block
|
||||
CALCULATION_BACKEND = "numpy"
|
||||
# Unused variable in try block
|
||||
VECTOR_DIMENSIONS = 3
|
||||
except ImportError:
|
||||
# Used variable in except block
|
||||
CALCULATION_BACKEND = "python"
|
||||
# Unused variable in except block
|
||||
FALLBACK_WARNING = "NumPy not available, using slower Python implementation"
|
||||
|
||||
lowvram_available = True
|
||||
if args.deterministic:
|
||||
logging.info("Using deterministic algorithms for pytorch")
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
import torch_directml
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
if device_index < 0:
|
||||
directml_device = torch_directml.device()
|
||||
# Nested if-else with variable definitions
|
||||
if sys.platform.startswith('win'):
|
||||
# Used variable in outer if
|
||||
SYSTEM_TYPE = "windows"
|
||||
if platform.architecture()[0] == '64bit':
|
||||
# Unused variable in nested if
|
||||
MEMORY_MODEL = "x64"
|
||||
else:
|
||||
directml_device = torch_directml.device(device_index)
|
||||
logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
# Unused variable in nested else
|
||||
MEMORY_MODEL = "x86"
|
||||
elif sys.platform.startswith('linux'):
|
||||
# Used variable in outer elif
|
||||
SYSTEM_TYPE = "linux"
|
||||
# Unused variable in outer elif
|
||||
KERNEL_VERSION = platform.release()
|
||||
else:
|
||||
# Used variable in outer else
|
||||
SYSTEM_TYPE = "other"
|
||||
# Unused variable in outer else
|
||||
UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = xpu_available or torch.xpu.is_available()
|
||||
except:
|
||||
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
# Function that will be used in the main code
|
||||
def select_precision(precision, fallback_precision):
|
||||
if precision is None:
|
||||
return fallback_precision or DEFAULT_PRECISION
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
cpu_state = CPUState.MPS
|
||||
import torch.mps
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
import torch_npu # noqa: F401
|
||||
_ = torch.npu.device_count()
|
||||
npu_available = torch.npu.is_available()
|
||||
except:
|
||||
npu_available = False
|
||||
|
||||
try:
|
||||
import torch_mlu # noqa: F401
|
||||
_ = torch.mlu.device_count()
|
||||
mlu_available = torch.mlu.is_available()
|
||||
except:
|
||||
mlu_available = False
|
||||
|
||||
if args.cpu:
|
||||
cpu_state = CPUState.CPU
|
||||
|
||||
def supports_cast(device, dtype): #TODO
|
||||
if dtype == torch.float32:
|
||||
return True
|
||||
if dtype == torch.float16:
|
||||
return True
|
||||
if directml_enabled: #TODO: test this
|
||||
return False
|
||||
if dtype == torch.bfloat16:
|
||||
return True
|
||||
if is_device_mps(device):
|
||||
return False
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return True
|
||||
if dtype == torch.float8_e5m2:
|
||||
return True
|
||||
return False
|
||||
|
||||
def pick_weight_dtype(dtype, fallback_dtype, device=None):
|
||||
if dtype is None:
|
||||
dtype = fallback_dtype
|
||||
elif dtype_size(dtype) > dtype_size(fallback_dtype):
|
||||
dtype = fallback_dtype
|
||||
|
||||
if not supports_cast(device, dtype):
|
||||
dtype = fallback_dtype
|
||||
|
||||
return dtype
|
||||
# Using the variables defined above
|
||||
if CALCULATION_BACKEND == "numpy":
|
||||
# Higher precision available with NumPy
|
||||
precision_options = ["low", "medium", "high", "ultra"]
|
||||
else:
|
||||
# Limited precision without NumPy
|
||||
precision_options = ["low", "medium", "high"]
|
||||
|
||||
if isinstance(precision, str):
|
||||
if precision.lower() not in precision_options:
|
||||
if fallback_precision:
|
||||
return fallback_precision
|
||||
else:
|
||||
return DEFAULT_PRECISION
|
||||
return precision.lower()
|
||||
else:
|
||||
return DEFAULT_PRECISION
|
||||
|
||||
# Function that won't be used
|
||||
def get_system_details():
|
||||
return {
|
||||
"system": SYSTEM_TYPE,
|
||||
"backend": CALCULATION_BACKEND,
|
||||
"default_precision": DEFAULT_PRECISION,
|
||||
"python_version": sys.version
|
||||
}
|
||||
'''
|
||||
|
||||
# Create a temporary directory instead of a single file
|
||||
# Create a temporary directory for the test
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create a package structure
|
||||
# Set up the package structure
|
||||
package_dir = Path(temp_dir) / "package"
|
||||
package_dir.mkdir()
|
||||
|
||||
# Create the __init__.py file to make it a proper package
|
||||
# Create the __init__.py file
|
||||
with open(package_dir / "__init__.py", "w") as init_file:
|
||||
init_file.write("")
|
||||
|
||||
# Write the model_management.py file
|
||||
with open(package_dir / "model_management.py", "w") as model_file:
|
||||
model_file.write(model_management_code)
|
||||
model_file.flush()
|
||||
# Write the utility_module.py file
|
||||
with open(package_dir / "utility_module.py", "w") as utility_file:
|
||||
utility_file.write(utility_module_code)
|
||||
utility_file.flush()
|
||||
|
||||
# Write the main code file that imports from model_management
|
||||
# Write the main code file
|
||||
main_file_path = package_dir / "main_module.py"
|
||||
with open(main_file_path, "w") as main_file:
|
||||
main_file.write(code)
|
||||
main_file.write(main_code)
|
||||
main_file.flush()
|
||||
|
||||
# Now set up the optimizer with the path to the main file
|
||||
# Set up the optimizer
|
||||
file_path = main_file_path.resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
|
|
@ -1624,95 +1503,291 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
|
|||
)
|
||||
)
|
||||
|
||||
# Define the function to optimize
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="encode_token_weights",
|
||||
function_name="calculate",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="HunyuanVideoClipModel", type="ClassDef")],
|
||||
parents=[FunctionParent(name="Calculator", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
# Get the code optimization context
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
# The expected contexts
|
||||
expected_read_write_context = """
|
||||
import model_management
|
||||
import utility_module
|
||||
|
||||
class HunyuanVideoClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_llama = model_management.pick_weight_dtype(dtype_llama, dtype, device)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_llama])
|
||||
class Calculator:
|
||||
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
||||
# This is where we use the imported module
|
||||
self.precision = utility_module.select_precision(precision, fallback_precision)
|
||||
self.mode = mode
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_l = token_weight_pairs["l"]
|
||||
token_weight_pairs_llama = token_weight_pairs["llama"]
|
||||
# Using variables from the utility module
|
||||
self.backend = utility_module.CALCULATION_BACKEND
|
||||
self.system = utility_module.SYSTEM_TYPE
|
||||
self.default_precision = utility_module.DEFAULT_PRECISION
|
||||
|
||||
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
|
||||
def add(self, a, b):
|
||||
return a + b
|
||||
|
||||
template_end = 0
|
||||
extra_template_end = 0
|
||||
extra_sizes = 0
|
||||
user_end = 9999999999999
|
||||
images = []
|
||||
def subtract(self, a, b):
|
||||
return a - b
|
||||
|
||||
tok_pairs = token_weight_pairs_llama[0]
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem):
|
||||
if isinstance(elem, numbers.Integral):
|
||||
if elem == 128006:
|
||||
if tok_pairs[i + 1][0] == 882:
|
||||
if tok_pairs[i + 2][0] == 128007:
|
||||
template_end = i + 2
|
||||
user_end = -1
|
||||
if elem == 128009 and user_end == -1:
|
||||
user_end = i + 1
|
||||
else:
|
||||
if elem.get("original_type") == "image":
|
||||
elem_size = elem.get("data").shape[0]
|
||||
if template_end > 0:
|
||||
if user_end == -1:
|
||||
extra_template_end += elem_size - 1
|
||||
else:
|
||||
image_start = i + extra_sizes
|
||||
image_end = i + elem_size + extra_sizes
|
||||
images.append((image_start, image_end, elem.get("image_interleave", 1)))
|
||||
extra_sizes += elem_size - 1
|
||||
|
||||
if llama_out.shape[1] > (template_end + 2):
|
||||
if tok_pairs[template_end + 1][0] == 271:
|
||||
template_end += 2
|
||||
llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
|
||||
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
|
||||
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
|
||||
|
||||
if len(images) > 0:
|
||||
out = []
|
||||
for i in images:
|
||||
out.append(llama_out[:, i[0]: i[1]: i[2]])
|
||||
llama_output = torch.cat(out + [llama_output], dim=1)
|
||||
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||
return llama_output, l_pooled, llama_extra_out
|
||||
def calculate(self, operation, x, y):
|
||||
if operation == "add":
|
||||
return self.add(x, y)
|
||||
elif operation == "subtract":
|
||||
return self.subtract(x, y)
|
||||
else:
|
||||
return None
|
||||
"""
|
||||
expected_read_only_context = """
|
||||
```python:model_management.py
|
||||
# Determine VRAM State
|
||||
```python:utility_module.py
|
||||
DEFAULT_PRECISION = "medium"
|
||||
|
||||
# Try-except block with variable definitions
|
||||
try:
|
||||
# Used variable in try block
|
||||
CALCULATION_BACKEND = "numpy"
|
||||
except ImportError:
|
||||
# Used variable in except block
|
||||
CALCULATION_BACKEND = "python"
|
||||
|
||||
# Function that will be used in the main code
|
||||
def select_precision(precision, fallback_precision):
|
||||
if precision is None:
|
||||
return fallback_precision or DEFAULT_PRECISION
|
||||
|
||||
# Using the variables defined above
|
||||
if CALCULATION_BACKEND == "numpy":
|
||||
# Higher precision available with NumPy
|
||||
precision_options = ["low", "medium", "high", "ultra"]
|
||||
else:
|
||||
# Limited precision without NumPy
|
||||
precision_options = ["low", "medium", "high"]
|
||||
|
||||
if isinstance(precision, str):
|
||||
if precision.lower() not in precision_options:
|
||||
if fallback_precision:
|
||||
return fallback_precision
|
||||
else:
|
||||
return DEFAULT_PRECISION
|
||||
return precision.lower()
|
||||
else:
|
||||
return DEFAULT_PRECISION
|
||||
```
|
||||
"""
|
||||
# Verify the contexts match the expected values
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
|
||||
def test_module_import_init_fto() -> None:
|
||||
main_code = '''
|
||||
import utility_module
|
||||
|
||||
class Calculator:
|
||||
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
||||
# This is where we use the imported module
|
||||
self.precision = utility_module.select_precision(precision, fallback_precision)
|
||||
self.mode = mode
|
||||
|
||||
# Using variables from the utility module
|
||||
self.backend = utility_module.CALCULATION_BACKEND
|
||||
self.system = utility_module.SYSTEM_TYPE
|
||||
self.default_precision = utility_module.DEFAULT_PRECISION
|
||||
|
||||
def add(self, a, b):
|
||||
return a + b
|
||||
|
||||
def subtract(self, a, b):
|
||||
return a - b
|
||||
|
||||
def calculate(self, operation, x, y):
|
||||
if operation == "add":
|
||||
return self.add(x, y)
|
||||
elif operation == "subtract":
|
||||
return self.subtract(x, y)
|
||||
else:
|
||||
return None
|
||||
'''
|
||||
|
||||
utility_module_code = '''
|
||||
import sys
|
||||
import platform
|
||||
import logging
|
||||
|
||||
DEFAULT_PRECISION = "medium"
|
||||
DEFAULT_MODE = "standard"
|
||||
|
||||
# Try-except block with variable definitions
|
||||
try:
|
||||
import numpy as np
|
||||
# Used variable in try block
|
||||
CALCULATION_BACKEND = "numpy"
|
||||
# Unused variable in try block
|
||||
VECTOR_DIMENSIONS = 3
|
||||
except ImportError:
|
||||
# Used variable in except block
|
||||
CALCULATION_BACKEND = "python"
|
||||
# Unused variable in except block
|
||||
FALLBACK_WARNING = "NumPy not available, using slower Python implementation"
|
||||
|
||||
# Nested if-else with variable definitions
|
||||
if sys.platform.startswith('win'):
|
||||
# Used variable in outer if
|
||||
SYSTEM_TYPE = "windows"
|
||||
if platform.architecture()[0] == '64bit':
|
||||
# Unused variable in nested if
|
||||
MEMORY_MODEL = "x64"
|
||||
else:
|
||||
# Unused variable in nested else
|
||||
MEMORY_MODEL = "x86"
|
||||
elif sys.platform.startswith('linux'):
|
||||
# Used variable in outer elif
|
||||
SYSTEM_TYPE = "linux"
|
||||
# Unused variable in outer elif
|
||||
KERNEL_VERSION = platform.release()
|
||||
else:
|
||||
# Used variable in outer else
|
||||
SYSTEM_TYPE = "other"
|
||||
# Unused variable in outer else
|
||||
UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"
|
||||
|
||||
# Function that will be used in the main code
|
||||
def select_precision(precision, fallback_precision):
|
||||
if precision is None:
|
||||
return fallback_precision or DEFAULT_PRECISION
|
||||
|
||||
# Using the variables defined above
|
||||
if CALCULATION_BACKEND == "numpy":
|
||||
# Higher precision available with NumPy
|
||||
precision_options = ["low", "medium", "high", "ultra"]
|
||||
else:
|
||||
# Limited precision without NumPy
|
||||
precision_options = ["low", "medium", "high"]
|
||||
|
||||
if isinstance(precision, str):
|
||||
if precision.lower() not in precision_options:
|
||||
if fallback_precision:
|
||||
return fallback_precision
|
||||
else:
|
||||
return DEFAULT_PRECISION
|
||||
return precision.lower()
|
||||
else:
|
||||
return DEFAULT_PRECISION
|
||||
|
||||
# Function that won't be used
|
||||
def get_system_details():
|
||||
return {
|
||||
"system": SYSTEM_TYPE,
|
||||
"backend": CALCULATION_BACKEND,
|
||||
"default_precision": DEFAULT_PRECISION,
|
||||
"python_version": sys.version
|
||||
}
|
||||
'''
|
||||
|
||||
# Create a temporary directory for the test
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Set up the package structure
|
||||
package_dir = Path(temp_dir) / "package"
|
||||
package_dir.mkdir()
|
||||
|
||||
# Create the __init__.py file
|
||||
with open(package_dir / "__init__.py", "w") as init_file:
|
||||
init_file.write("")
|
||||
|
||||
# Write the utility_module.py file
|
||||
with open(package_dir / "utility_module.py", "w") as utility_file:
|
||||
utility_file.write(utility_module_code)
|
||||
utility_file.flush()
|
||||
|
||||
# Write the main code file
|
||||
main_file_path = package_dir / "main_module.py"
|
||||
with open(main_file_path, "w") as main_file:
|
||||
main_file.write(main_code)
|
||||
main_file.flush()
|
||||
|
||||
# Set up the optimizer
|
||||
file_path = main_file_path.resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=package_dir.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
|
||||
# Define the function to optimize
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="__init__",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="Calculator", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
# Get the code optimization context
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
# The expected contexts
|
||||
expected_read_write_context = """
|
||||
# Function that will be used in the main code
|
||||
|
||||
import utility_module
|
||||
|
||||
def select_precision(precision, fallback_precision):
|
||||
if precision is None:
|
||||
return fallback_precision or DEFAULT_PRECISION
|
||||
|
||||
# Using the variables defined above
|
||||
if CALCULATION_BACKEND == "numpy":
|
||||
# Higher precision available with NumPy
|
||||
precision_options = ["low", "medium", "high", "ultra"]
|
||||
else:
|
||||
# Limited precision without NumPy
|
||||
precision_options = ["low", "medium", "high"]
|
||||
|
||||
if isinstance(precision, str):
|
||||
if precision.lower() not in precision_options:
|
||||
if fallback_precision:
|
||||
return fallback_precision
|
||||
else:
|
||||
return DEFAULT_PRECISION
|
||||
return precision.lower()
|
||||
else:
|
||||
return DEFAULT_PRECISION
|
||||
|
||||
|
||||
def pick_weight_dtype(dtype, fallback_dtype, device=None):
|
||||
if dtype is None:
|
||||
dtype = fallback_dtype
|
||||
elif dtype_size(dtype) > dtype_size(fallback_dtype):
|
||||
dtype = fallback_dtype
|
||||
|
||||
if not supports_cast(device, dtype):
|
||||
dtype = fallback_dtype
|
||||
class Calculator:
|
||||
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
||||
# This is where we use the imported module
|
||||
self.precision = utility_module.select_precision(precision, fallback_precision)
|
||||
self.mode = mode
|
||||
|
||||
return dtype
|
||||
# Using variables from the utility module
|
||||
self.backend = utility_module.CALCULATION_BACKEND
|
||||
self.system = utility_module.SYSTEM_TYPE
|
||||
self.default_precision = utility_module.DEFAULT_PRECISION
|
||||
"""
|
||||
expected_read_only_context = """
|
||||
```python:utility_module.py
|
||||
DEFAULT_PRECISION = "medium"
|
||||
|
||||
# Try-except block with variable definitions
|
||||
try:
|
||||
# Used variable in try block
|
||||
CALCULATION_BACKEND = "numpy"
|
||||
except ImportError:
|
||||
# Used variable in except block
|
||||
CALCULATION_BACKEND = "python"
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
|
|
|
|||
Loading…
Reference in a new issue