End-to-end done.

This commit is contained in:
RD 2024-10-31 19:08:20 -07:00
parent cafc1ecbaa
commit 5c3a2e18d0
6 changed files with 228 additions and 210 deletions

View file

@ -11,6 +11,7 @@
</inspection_tool>
<inspection_tool class="Mypy" enabled="true" level="SERVER PROBLEM" enabled_by_default="true" editorAttributes="GENERIC_SERVER_ERROR_OR_WARNING" />
<inspection_tool class="PyMissingTypeHintsInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true" />
<inspection_tool class="PyNestedDecoratorsInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>

View file

@ -1,154 +1,104 @@
from __future__ import annotations
import ast
import importlib.machinery
import importlib.util
import site
import sys
import sysconfig
from pathlib import Path
from typing import Iterator
from pydantic import BaseModel, field_validator
class ImportedModuleAnalysis(BaseModel, frozen=True):
    """Immutable record describing one module imported by analyzed code.

    Instances are frozen (hashable), so they can be deduplicated in sets.
    """

    name: str  # TODO Crosshair: Validate that this is the basename of full_name.
    origin: str  # TODO Crosshair: Make this an enum. Validate what the prefix of file_path must be depending on origin.
    full_name: str  # TODO Crosshair: Validate that if file_path exists, it can transform into its suffix.
    file_path: Path | None  # TODO Crosshair: Validate that it transforms into full_name, can be None only for std lib.
    # TODO CROSSHAIR Add clone of libcst path to qualified name and package function, add package info to model.
class ImportedInternalModuleAnalysis(BaseModel, frozen=True):
    """Immutable record of an internal (project-local) module imported by analyzed code.

    Frozen so instances are hashable and can be deduplicated in sets.
    """

    name: str  # Basename of the module (last dotted component of ``full_name``).
    full_name: str  # Absolute dotted module path, e.g. ``pkg.sub.module``.
    file_path: Path  # Path of the module's source file; must exist.

    @field_validator("name")
    @classmethod
    def name_is_identifier(cls, v: str) -> str:
        """Reject names that are not valid Python identifiers."""
        # NOTE: the pre-fix code had an unconditional ``raise ValueError("...")``
        # followed by dead ``msg = ...; raise`` lines; only the message-variable
        # form (ruff EM101/TRY003 style) is kept.
        if not v.isidentifier():
            msg = "must be an identifier"
            raise ValueError(msg)
        return v

    # TODO CROSSHAIR Make this into a standalone function.
    @field_validator("full_name")
    @classmethod
    def full_name_is_dotted_identifier(cls, v: str) -> str:
        """Require every dot-separated segment to be a non-empty identifier."""
        if any(not s or not s.isidentifier() for s in v.split(".")):
            msg = "must be a dotted identifier"
            raise ValueError(msg)
        return v

    @field_validator("file_path")
    @classmethod
    def file_path_exists(cls, v: Path | None) -> Path | None:
        """Ensure the provided path, when not None, points to an existing file."""
        if v and not v.exists():
            msg = "must be an existing path"
            raise ValueError(msg)
        return v
def parse_imports(code: str) -> Iterator[ast.AST]:
    """Yield every ``ast.Import`` / ``ast.ImportFrom`` node found in *code*."""
    return (node for node in ast.walk(ast.parse(code)) if isinstance(node, (ast.Import, ast.ImportFrom)))


def resolve_relative_name(module: str | None, level: int, current_module: str) -> str | None:
    """Resolve a (possibly relative) import to an absolute dotted module name.

    Args:
        module: Module named in the import statement (``None`` for ``from . import x``).
        level: Number of leading dots in the import (0 means absolute).
        current_module: Dotted name of the module that contains the import.

    Returns:
        The absolute dotted name, or ``None`` when the relative import climbs
        above the top-level package.
    """
    # BUG FIX: the previous text interleaved the deleted ``resolve_module_name``
    # body with this function, leaving references to an undefined ``name``
    # parameter and unreachable return statements; only the new logic remains.
    if level == 0:
        # Absolute import: the statement's module name is already absolute.
        return module
    current_parts = current_module.split(".")
    if level > len(current_parts):
        return None  # Relative import escapes the package root.
    base_parts = current_parts[:-level]
    if module:
        base_parts.extend(module.split("."))
    return ".".join(base_parts)
def get_module_full_name(node: ast.Import | ast.ImportFrom, current_module: str) -> Iterator[str]:
    """Yield the absolute dotted name(s) of the module(s) imported by *node*.

    Args:
        node: An import AST node (any other node type yields nothing).
        current_module: Dotted name of the module containing the import, used
            to resolve relative imports.
    """
    # BUG FIX: the previous text interleaved the deleted recursive
    # ``collect_imports`` body here, referencing an undefined ``imports`` set
    # and the removed ``resolve_module_name`` helper; only the new logic remains.
    if isinstance(node, ast.Import):
        return (alias.name for alias in node.names)
    if isinstance(node, ast.ImportFrom):
        base_module = resolve_relative_name(node.module, node.level, current_module)
        if base_module is None:
            return iter(())  # Unresolvable relative import.
        if node.module is None and node.level > 0:
            # Relative import with no module specified, e.g., from . import mymodule:
            # each imported name is itself a submodule of the resolved base.
            return (f"{base_module}.{alias.name}" for alias in node.names)
        # For absolute imports or relative imports with module specified.
        return iter([base_module])
    return iter(())
def categorize_module(module_name: str, project_root: Path, module_file_path: Path) -> ImportedModuleAnalysis | None:
    """Classify an imported module as stdlib / third party / internal / unknown.

    Locates the module via importlib specs, then buckets its resolved file
    path against the interpreter's stdlib/site-packages directories and the
    project root. Returns None when *module_name* is empty.
    """
    if not module_name:
        return None
    # Default module search paths
    try:
        spec = importlib.util.find_spec(module_name)
    except (ModuleNotFoundError, ImportError):
        spec = None
    if spec is None:
        # Internal modules: retry with the file's directory and project root
        # prepended to a copy of sys.path.
        custom_path_list = [str(module_file_path.parent), str(project_root), *sys.path.copy()]
        try:
            spec = importlib.machinery.PathFinder.find_spec(module_name, path=custom_path_list)
        except (ModuleNotFoundError, ImportError):
            spec = None
    if spec is None:
        # Module not found
        origin = "unknown"
        file_path = None
    else:
        spec_origin = spec.origin
        if spec_origin is None or spec_origin in ("built-in", "frozen"):
            # Built-in/frozen modules have no backing file.
            origin = "standard library"
            file_path = None
        else:
            file_path_path: Path = Path(spec_origin).resolve()
            # NOTE(review): "purelib" is treated as stdlib here even though it is
            # the site-packages dir — presumably intentional; confirm.
            stdlib_paths: list[Path] = [
                Path(p).resolve() for key, p in sysconfig.get_paths().items() if "stdlib" in key or key == "purelib"
            ]
            site_packages_paths: list[Path] = [Path(p).resolve() for p in site.getsitepackages()]
            if user_site_packages := site.getusersitepackages():
                site_packages_paths.append(Path(user_site_packages).resolve())
            # Determine the origin based on the file path
            if any(file_path_path.is_relative_to(p) for p in stdlib_paths):
                origin = "standard library"
            elif any(file_path_path.is_relative_to(p) for p in site_packages_paths):
                origin = "third party"
            elif file_path_path.is_relative_to(project_root.resolve()):
                origin = "internal"
            else:
                origin = "unknown"
            file_path = file_path_path
    return ImportedModuleAnalysis(
        name=module_name.split(".")[-1], origin=origin, full_name=module_name, file_path=file_path
    )
def is_internal_module(module_name: str, project_root: Path) -> bool:
    """Return True when *module_name* resolves to a source file under *project_root*.

    Accepts either a plain module (``pkg/mod.py``) or a package
    (``pkg/mod/__init__.py``).
    """
    relative = module_name.replace(".", "/")
    candidates = (project_root / f"{relative}.py", project_root / relative / "__init__.py")
    return any(candidate.exists() for candidate in candidates)
def get_module_file_path(module_name: str, project_root: Path) -> Path | None:
    """Map a dotted module name to its source file under *project_root*.

    Checks for a plain module file (``pkg/mod.py``) first, then a package
    (``pkg/mod/__init__.py``).

    Returns:
        The resolved path of the first match, or ``None`` when neither exists.
    """
    # NOTE: the deleted ``analyze_imported_modules`` body was interleaved with
    # this definition in the diff residue (it called the removed
    # ``collect_imports``/``categorize_module`` helpers); only the new
    # function is kept.
    module_path = module_name.replace(".", "/")
    possible_paths = [project_root / f"{module_path}.py", project_root / module_path / "__init__.py"]
    for path in possible_paths:
        if path.exists():
            return path.resolve()
    return None
def analyze_imported_internal_modules(
    code_str: str, module_file_path: Path, project_root: Path
) -> list[ImportedInternalModuleAnalysis]:
    """Analyzes a Python module's code to find all imported internal modules.

    Args:
        code_str: Source text of the module to analyze.
        module_file_path: Path of that module's file (must live under *project_root*).
        project_root: Root directory of the project.

    Returns:
        One ImportedInternalModuleAnalysis per distinct internal module whose
        file can be located under *project_root*.
    """
    module_rel_path = module_file_path.relative_to(project_root).with_suffix("")
    current_module = ".".join(module_rel_path.parts)
    module_names: set[str] = set()
    for node in parse_imports(code_str):
        module_names.update(get_module_full_name(node, current_module))
    # BUG FIX: the previous code ran filter(is_internal_module, (module_names,
    # project_root)), which iterates over the *tuple* — calling
    # is_internal_module() once with the whole name set and once with the
    # project root, each missing its second argument (TypeError). Filter each
    # candidate name explicitly instead.
    internal_modules = (name for name in module_names if is_internal_module(name, project_root))
    return [
        ImportedInternalModuleAnalysis(name=mod_name.split(".")[-1], full_name=mod_name, file_path=file_path)
        for mod_name in internal_modules
        if (file_path := get_module_file_path(mod_name, project_root)) is not None
    ]

View file

@ -236,7 +236,6 @@ def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:
functions: dict[str, list[FunctionToOptimize]] = {}
module_root_path = Path(module_root_path)
for file_path in module_root_path.rglob("*.py"):
# Find all the functions in the file
functions.update(find_all_functions_in_file(file_path).items())

View file

@ -18,12 +18,17 @@ from rich.console import Group
from rich.panel import Panel
from rich.syntax import Syntax
from rich.tree import Tree
from sqlalchemy import false
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module, replace_functions_and_add_imports
from codeflash.code_utils.code_replacer import (
normalize_code,
replace_function_definitions_in_module,
replace_functions_and_add_imports,
)
from codeflash.code_utils.code_utils import (
file_name_from_test_module_name,
get_run_tmp_file,
@ -155,17 +160,19 @@ class Optimizer:
# TODO CROSSHAIR Parse/validate ast, norm. ast, libcst. New pydantic model for code.
for function_to_optimize in file_to_funcs_to_optimize[path]:
# TODO CROSSHAIR Factor out in git tools.
# TODO CROSSHAIR Factor out in git tools. Handle no git case.
# TODO CROSSHAIR Pydantic model for worktree info: work_tree_root, work_trees. work_tree_branches
worktree_root = Path(tempfile.mkdtemp()) # TODO CROSSHAIR Check for randomness side effect issues
worktree_root: Path = Path(
tempfile.mkdtemp()
) # TODO CROSSHAIR Check for randomness side effect issues
# TODO CROSSHAIR Make more race condition collision proof. Handle 11 magic number.
worktrees = [Path(tempfile.mkdtemp(dir=worktree_root)) for _ in range(11)]
worktrees: list[Path] = [Path(tempfile.mkdtemp(dir=worktree_root)) for _ in range(N_CANDIDATES + 1)]
for worktree in worktrees:
# TODO CROSSHAIR Check for IO errors, collisions (race conditions).
worktree.mkdir()
# TODO CROSSHAIR Check for IO errors, collisions, add more chaos.
subprocess.run(
["git", "worktree", "add", "-d", worktree], cwd=self.args.project_root, check=True
["git", "worktree", "add", "-d", worktree], cwd=self.args.module_root, check=True
)
function_iterator_count += 1
@ -175,7 +182,12 @@ class Optimizer:
)
best_optimization = self.optimize_function(
function_to_optimize, function_to_tests, original_code, imported_internal_module_information
function_to_optimize,
function_to_tests,
original_code,
imported_internal_module_information,
worktree_root,
worktrees,
)
self.test_files = TestFiles(test_files=[])
@ -213,6 +225,8 @@ class Optimizer:
function_to_tests: dict[str, list[FunctionCalledInTest]],
original_code: str,
imported_internal_module_information: dict[Path, dict[str, str]],
worktree_root: Path,
worktrees: list[Path],
) -> Result[BestOptimization, str]:
should_run_experiment = self.experiment_id is not None
function_trace_id: str = str(uuid.uuid4())
@ -306,8 +320,8 @@ class Optimizer:
if candidates is None:
continue
optimized_module_code_strings = [
replace_functions_and_add_imports(
optimized_module_code_strings = {
candidate.optimization_id: replace_functions_and_add_imports(
original_code,
[function_to_optimize_qualified_name],
candidate.source_code,
@ -318,18 +332,17 @@ class Optimizer:
self.args.project_root,
)
for candidate in candidates
]
}
callee_module_paths = {callee.file_path for callee in code_context.helper_functions}
optimized_callee_modules_code_strings = [
{
optimized_callee_modules_code_strings = {
candidate.optimization_id: {
callee_module_path: replace_functions_and_add_imports(
imported_internal_module_information[callee_module_path]["code"],
(
[
callee.qualified_name
for callee in code_context.helper_functions
if callee.file_path == callee_module_path
if callee.file_path == callee_module_path and callee.jedi_definition.type != "class"
]
),
candidate.source_code,
@ -342,18 +355,108 @@ class Optimizer:
for callee_module_path in callee_module_paths
}
for candidate in candidates
}
normalized_original_code = normalize_code(original_code)
are_optimized_module_code_strings_zero_diff = {
candidate.optimization_id: normalize_code(optimized_module_code_strings[candidate.optimization_id])
!= normalized_original_code
for candidate in candidates
}
normalized_callee_module_code_strings = {
callee_module_path: normalize_code(imported_internal_module_information[callee_module_path]["code"])
for callee_module_path in callee_module_paths
}
are_optimized_callee_module_code_strings_zero_diff = {
candidate.optimization_id: {
callee_module_path: normalize_code(
optimized_callee_modules_code_strings[candidate.optimization_id][callee_module_path]
)
!= normalized_callee_module_code_strings[callee_module_path]
for callee_module_path in callee_module_paths
}
for candidate in candidates
}
candidates_with_diffs = [
candidate
for candidate in candidates
if not (
are_optimized_module_code_strings_zero_diff[candidate.optimization_id]
and all(are_optimized_callee_module_code_strings_zero_diff[candidate.optimization_id].values())
)
]
# TODO Crosshair: Write optimized modules into 10 git work trees, filter out identical normalized ast.
# TODO Crosshair: It's concolic testing time, using Python API of Crosshair.
# TODO Crosshair: Collect passes and failures, to compare later with regression verification results.
# TODO CROSSHAIR: refactor filtering + write as single function. Put under try/except block.
# TODO Crosshair: Precalculate or factor out repeated function_to_optimize.file_path.relative_to(self.args.module_root)
for candidate, worktree in zip(candidates_with_diffs, worktrees[1:]):
if are_optimized_module_code_strings_zero_diff[candidate.optimization_id]:
(worktree / function_to_optimize.file_path.relative_to(self.args.module_root)).write_text(
optimized_module_code_strings[i], encoding="utf8"
)
for callee_module_path in optimized_callee_modules_code_strings[candidate.optimization_id]:
if are_optimized_callee_module_code_strings_zero_diff[candidate.optimization_id][
callee_module_path
]:
(worktree / callee_module_path.relative_to(self.args.module_root)).write_text(
optimized_callee_modules_code_strings[i][callee_module_path], encoding="utf8"
)
# TODO Crosshair: Factor out relative path munging code, repeated.
function_to_optimize_original_worktree_fqn = str(
worktrees[0].name / function_to_optimize.file_path.relative_to(self.args.module_root).with_suffix("")
).replace("/", ".")
# TODO CROSSHAIR: Turn into enum.
diffbehavior_results: dict[str, int] = {}
for candidate_index, candidate in enumerate(candidates_with_diffs, start=1):
logger.info(f"Optimization candidate {candidate_index}/{len(candidates_with_diffs)}:")
code_print(candidate.source_code)
result = subprocess.run(
[
"crosshair",
"diffbehavior",
"--max_uninteresting_iterations",
"64",
str(
worktrees[candidate_index].name
/ function_to_optimize.file_path.relative_to(self.args.module_root).with_suffix("")
).replace("/", "."),
function_to_optimize_original_worktree_fqn,
],
capture_output=True,
text=True,
cwd=worktree_root,
check=False,
)
diffbehavior_results[candidate.optimization_id] = result.returncode
if result.returncode == 2:
logger.info("Inconclusive results from concolic behavior correctness check.")
logger.warning(
f"Error running crosshair diffbehavior{': '+ result.stderr if result.stderr else '.'}"
)
elif result.returncode == 1:
logger.info(f"Optimization candidate failed concolic behavior correctness check:\n{result.stdout}")
elif result.returncode == 0:
logger.info(
f"Optimization candidate passed concolic behavior correctness check"
f"{': \n' + result.stdout.split('\n', 1)[0] if '\n' in result.stdout else '.'}"
)
if result.stdout.endswith("All paths exhausted, functions are likely the same!\n"):
logger.info("All paths exhausted, functions are likely the same!")
else:
logger.warning("Consider increasing the --max_uninteresting_iterations option.")
else:
logger.info("Inconclusive results from concolic behavior correctness check.")
logger.error("Unknown return code running crosshair diffbehavior.")
tests_in_file: list[FunctionCalledInTest] = function_to_tests.get(
function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root), []
)
best_optimization = self.determine_best_candidate(
candidates=candidates,
candidates=candidates_with_diffs,
code_context=code_context,
function_to_optimize=function_to_optimize,
original_code=original_code,
@ -361,6 +464,7 @@ class Optimizer:
original_helper_code=original_helper_code,
function_trace_id=function_trace_id[:-4] + f"EXP{u}" if should_run_experiment else function_trace_id,
only_run_this_test_function=tests_in_file,
diffbehavior_results=diffbehavior_results,
)
ph("cli-optimize-function-finished", {"function_trace_id": function_trace_id})
@ -420,10 +524,6 @@ class Optimizer:
function_trace_id=function_trace_id,
)
if self.args.all or env_utils.get_pr_number():
# Reverting to original code, because optimizing functions in a sequence can lead to
# a) Error propagation, where error in one function can cause the next optimization to fail
# b) Performance estimates become unstable, as the runtime of an optimization might be
# dependent on the runtime of the previous optimization
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
for generated_test_path in generated_tests_paths:
generated_test_path.unlink(missing_ok=True)
@ -445,6 +545,7 @@ class Optimizer:
original_helper_code: dict[Path, str],
function_trace_id: str,
only_run_this_test_function: list[FunctionCalledInTest] | None = None,
diffbehavior_results: dict[str, int],
) -> BestOptimization | None:
best_optimization: BestOptimization | None = None
best_runtime_until_now = original_code_baseline.runtime # The fastest code runtime until now
@ -454,14 +555,14 @@ class Optimizer:
is_correct = {}
logger.info(
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ..."
f"Determining best optimization candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ..."
)
console.rule()
try:
for candidate_index, candidate in enumerate(candidates, start=1):
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"Optimized candidate {candidate_index}/{len(candidates)}:")
logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:")
code_print(candidate.source_code)
try:
did_update = self.replace_function_and_helpers_with_optimized_code(
@ -480,18 +581,11 @@ class Optimizer:
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
continue
# Run generated tests if at least one of them passed
run_generated_tests = False
if original_code_baseline.generated_test_results:
for test_result in original_code_baseline.generated_test_results.test_results:
if test_result.did_pass:
run_generated_tests = True
break
run_results = self.run_optimized_candidate(
optimization_candidate_index=candidate_index,
original_test_results=original_code_baseline.overall_test_results,
tests_in_file=only_run_this_test_function,
diffbehavior_result=diffbehavior_results[candidate_index],
)
if not is_successful(run_results):
optimized_runtimes[candidate.optimization_id] = None
@ -968,6 +1062,7 @@ class Optimizer:
optimization_candidate_index: int,
original_test_results: TestResults | None,
tests_in_file: list[FunctionCalledInTest] | None,
diffbehavior_result: int,
) -> Result[OptimizedCandidateResult, str]:
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
@ -1054,17 +1149,40 @@ class Optimizer:
test_results=[result for result in original_test_results.test_results if result.loop_index == 1]
)
# TODO Crosshair: Compare to Concolic Testing results (false negative, match). Collect all comparisons.
# TODO Crosshair: Report failure if regression OR concolic failure.
if compare_test_results(initial_loop_original_test_results, initial_loop_candidate_results):
logger.info("Test results matched!")
console.rule()
equal_results = True
else:
logger.info("Test results did not match the test results of the original code.")
console.rule()
success = False
equal_results = False
console.rule()
if diffbehavior_result == 0:
logger.info("Concolic behavior correctness check successful!")
console.rule()
if equal_results:
logger.info("True negative: Concolic behavior correctness check successful and test results matched.")
else:
logger.warning(
"False negative for concolic testing: Concolic behavior correctness check successful but test results did not match."
)
console.rule()
elif diffbehavior_result == 1:
logger.warning("Concolic behavior correctness check failed.")
console.rule()
if equal_results:
logger.warning(
"False negative for regression testing: Concolic behavior correctness check failed but test results matched."
)
success = false()
equal_results = False
else:
logger.info("True positive: Concolic behavior correctness check failed and test results did not match.")
console.rule()
else:
logger.warning("Concolic behavior correctness check inconclusive.")
console.rule()
if (total_candidate_timing := candidate_results.total_passed_runtime()) == 0:
logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
@ -1076,7 +1194,7 @@ class Optimizer:
success = False
if not success:
return Failure("Failed to run the optimized candidate.")
return Failure("Failed to run the optimization candidate.")
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
return Success(
OptimizedCandidateResult(

View file

@ -39,6 +39,7 @@ isort = ">=5.11.0"
dill = "^0.3.8"
rich = "^13.8.1"
lxml = "^5.3.0"
crosshair-tool = ">=0.0.76"
[tool.poetry.group.dev]

View file

@ -1,6 +1,6 @@
from pathlib import Path
from codeflash.code_utils.static_analysis import ImportedModuleAnalysis, analyze_imported_modules
from codeflash.code_utils.static_analysis import ImportedInternalModuleAnalysis, analyze_imported_internal_modules
def test_analyze_imported_modules() -> None:
@ -12,14 +12,13 @@ from . import mymodule
from datetime import datetime
from pandas import DataFrame
from pathlib import *
from codeflash.code_utils.static_analysis import analyze_imported_modules
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from argparse import Namespace
from codeflash.code_utils.static_analysis import ImportedInternalModuleAnalysis
def afunction():
import datetime
def a_function():
from codeflash.code_utils.static_analysis import analyze_imported_modules
from returns.result import Failure, Success
pass
"""
@ -28,64 +27,14 @@ def afunction():
project_root = (Path(__file__).parent.resolve() / "../").resolve()
expected_imported_module_analysis = [
ImportedModuleAnalysis(
name="result",
origin="standard library",
full_name="returns.result",
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/site-packages/returns/result.py"),
),
ImportedModuleAnalysis(
name="pandas",
origin="standard library",
full_name="pandas",
file_path=Path(
"/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/site-packages/pandas/__init__.py"
),
),
ImportedModuleAnalysis(name="sys", origin="standard library", full_name="sys", file_path=None),
ImportedModuleAnalysis(
name="typing",
origin="standard library",
full_name="typing",
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/typing.py"),
),
ImportedModuleAnalysis(
name="numpy",
origin="standard library",
full_name="numpy",
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/site-packages/numpy/__init__.py"),
),
ImportedModuleAnalysis(
name="datetime",
origin="standard library",
full_name="datetime",
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/datetime.py"),
),
ImportedModuleAnalysis(
name="argparse",
origin="standard library",
full_name="argparse",
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/argparse.py"),
),
ImportedModuleAnalysis(name="os", origin="standard library", full_name="os", file_path=None),
ImportedModuleAnalysis(
ImportedInternalModuleAnalysis(
name="static_analysis",
origin="internal",
full_name="codeflash.code_utils.static_analysis",
file_path=Path("/Users/renaud/repos/codeflash/cli/codeflash/code_utils/static_analysis.py"),
file_path=project_root / Path("codeflash/code_utils/static_analysis.py"),
),
ImportedModuleAnalysis(
name="mymodule",
origin="internal",
full_name="tests.mymodule",
file_path=Path("/Users/renaud/repos/codeflash/cli/tests/mymodule.py"),
),
ImportedModuleAnalysis(
name="pathlib",
origin="standard library",
full_name="pathlib",
file_path=Path("/Users/renaud/miniforge3/envs/codeflash312/lib/python3.12/pathlib.py"),
ImportedInternalModuleAnalysis(
name="mymodule", full_name="tests.mymodule", file_path=project_root / Path("tests/mymodule.py")
),
]
actual_imported_module_analysis = analyze_imported_modules(code_str, module_file_path, project_root)
actual_imported_module_analysis = analyze_imported_internal_modules(code_str, module_file_path, project_root)
assert set(actual_imported_module_analysis) == set(expected_imported_module_analysis)