mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
temp
This commit is contained in:
parent
26989b2602
commit
d480061064
2 changed files with 172 additions and 8 deletions
|
|
@ -752,17 +752,30 @@ def extract_init_stub_from_class(class_name: str, module_source: str, module_tre
|
|||
if class_node is None:
|
||||
return None
|
||||
|
||||
init_node = None
|
||||
lines = module_source.splitlines()
|
||||
relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
|
||||
for item in class_node.body:
|
||||
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__":
|
||||
init_node = item
|
||||
break
|
||||
if init_node is None:
|
||||
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
if item.name in ("__init__", "__post_init__"):
|
||||
relevant_nodes.append(item)
|
||||
elif any(
|
||||
isinstance(d, ast.Name) and d.id == "property"
|
||||
or isinstance(d, ast.Attribute) and d.attr == "property"
|
||||
for d in item.decorator_list
|
||||
):
|
||||
relevant_nodes.append(item)
|
||||
|
||||
if not relevant_nodes:
|
||||
return None
|
||||
|
||||
lines = module_source.splitlines()
|
||||
init_source = "\n".join(lines[init_node.lineno - 1 : init_node.end_lineno])
|
||||
return f"class {class_name}:\n{init_source}"
|
||||
snippets: list[str] = []
|
||||
for node in relevant_nodes:
|
||||
start = node.lineno
|
||||
if node.decorator_list:
|
||||
start = min(d.lineno for d in node.decorator_list)
|
||||
snippets.append("\n".join(lines[start - 1 : node.end_lineno]))
|
||||
|
||||
return f"class {class_name}:\n" + "\n".join(snippets)
|
||||
|
||||
|
||||
def extract_parameter_type_constructors(
|
||||
|
|
@ -844,6 +857,27 @@ def extract_parameter_type_constructors(
|
|||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
def resolve_instance_class_name(name: str, module_tree: ast.Module) -> str | None:
|
||||
for node in module_tree.body:
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Name) and target.id == name:
|
||||
value = node.value
|
||||
if isinstance(value, ast.Call):
|
||||
func = value.func
|
||||
if isinstance(func, ast.Name):
|
||||
return func.id
|
||||
if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
|
||||
return func.value.id
|
||||
elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == name:
|
||||
ann = node.annotation
|
||||
if isinstance(ann, ast.Name):
|
||||
return ann.id
|
||||
if isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name):
|
||||
return ann.value.id
|
||||
return None
|
||||
|
||||
|
||||
def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
|
||||
import jedi
|
||||
|
||||
|
|
@ -938,6 +972,11 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
|
||||
extract_class_and_bases(name, module_path, module_source, module_tree)
|
||||
|
||||
if (module_path, name) not in extracted_classes:
|
||||
resolved_class = resolve_instance_class_name(name, module_tree)
|
||||
if resolved_class and resolved_class not in existing_classes:
|
||||
extract_class_and_bases(resolved_class, module_path, module_source, module_tree)
|
||||
|
||||
except Exception:
|
||||
logger.debug(f"Error extracting class definition for {name} from {module_name}")
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from codeflash.languages.python.context.code_context_extractor import (
|
|||
extract_init_stub_from_class,
|
||||
extract_parameter_type_constructors,
|
||||
get_code_optimization_context,
|
||||
resolve_instance_class_name,
|
||||
)
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
|
@ -4339,3 +4340,127 @@ def process(c: Config) -> str:
|
|||
)
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 0
|
||||
|
||||
|
||||
# --- Tests for resolve_instance_class_name ---
|
||||
|
||||
|
||||
def test_resolve_instance_class_name_direct_call() -> None:
|
||||
source = "config = MyConfig(debug=True)"
|
||||
tree = ast.parse(source)
|
||||
assert resolve_instance_class_name("config", tree) == "MyConfig"
|
||||
|
||||
|
||||
def test_resolve_instance_class_name_annotated() -> None:
|
||||
source = "config: MyConfig = load()"
|
||||
tree = ast.parse(source)
|
||||
assert resolve_instance_class_name("config", tree) == "MyConfig"
|
||||
|
||||
|
||||
def test_resolve_instance_class_name_factory_method() -> None:
|
||||
source = "config = MyConfig.from_env()"
|
||||
tree = ast.parse(source)
|
||||
assert resolve_instance_class_name("config", tree) == "MyConfig"
|
||||
|
||||
|
||||
def test_resolve_instance_class_name_no_match() -> None:
|
||||
source = "x = 42"
|
||||
tree = ast.parse(source)
|
||||
assert resolve_instance_class_name("x", tree) is None
|
||||
|
||||
|
||||
def test_resolve_instance_class_name_missing_variable() -> None:
|
||||
source = "config = MyConfig()"
|
||||
tree = ast.parse(source)
|
||||
assert resolve_instance_class_name("other", tree) is None
|
||||
|
||||
|
||||
# --- Tests for enhanced extract_init_stub_from_class ---
|
||||
|
||||
|
||||
def test_extract_init_stub_includes_post_init() -> None:
|
||||
source = """\
|
||||
class MyDataclass:
|
||||
def __init__(self, x: int):
|
||||
self.x = x
|
||||
def __post_init__(self):
|
||||
self.y = self.x * 2
|
||||
"""
|
||||
tree = ast.parse(source)
|
||||
stub = extract_init_stub_from_class("MyDataclass", source, tree)
|
||||
assert stub is not None
|
||||
assert "class MyDataclass:" in stub
|
||||
assert "def __init__" in stub
|
||||
assert "def __post_init__" in stub
|
||||
assert "self.y = self.x * 2" in stub
|
||||
|
||||
|
||||
def test_extract_init_stub_includes_properties() -> None:
|
||||
source = """\
|
||||
class MyClass:
|
||||
def __init__(self, name: str):
|
||||
self._name = name
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
"""
|
||||
tree = ast.parse(source)
|
||||
stub = extract_init_stub_from_class("MyClass", source, tree)
|
||||
assert stub is not None
|
||||
assert "def __init__" in stub
|
||||
assert "@property" in stub
|
||||
assert "def name" in stub
|
||||
|
||||
|
||||
def test_extract_init_stub_property_only_class() -> None:
|
||||
source = """\
|
||||
class ReadOnly:
|
||||
@property
|
||||
def value(self) -> int:
|
||||
return 42
|
||||
"""
|
||||
tree = ast.parse(source)
|
||||
stub = extract_init_stub_from_class("ReadOnly", source, tree)
|
||||
assert stub is not None
|
||||
assert "class ReadOnly:" in stub
|
||||
assert "@property" in stub
|
||||
assert "def value" in stub
|
||||
|
||||
|
||||
# --- Tests for enrich_testgen_context resolving instances ---
|
||||
|
||||
|
||||
def test_enrich_testgen_context_resolves_instance_to_class(tmp_path: Path) -> None:
|
||||
package_dir = tmp_path / "mypkg"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
config_module = """\
|
||||
class AppConfig:
|
||||
def __init__(self, debug: bool = False):
|
||||
self.debug = debug
|
||||
|
||||
@property
|
||||
def log_level(self) -> str:
|
||||
return "DEBUG" if self.debug else "INFO"
|
||||
|
||||
app_config = AppConfig(debug=True)
|
||||
"""
|
||||
(package_dir / "config.py").write_text(config_module, encoding="utf-8")
|
||||
|
||||
consumer_code = """\
|
||||
from mypkg.config import app_config
|
||||
|
||||
def get_log_level() -> str:
|
||||
return app_config.log_level
|
||||
"""
|
||||
consumer_path = package_dir / "consumer.py"
|
||||
consumer_path.write_text(consumer_code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=consumer_code, file_path=consumer_path)])
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) >= 1
|
||||
combined = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class AppConfig:" in combined
|
||||
assert "@property" in combined
|
||||
|
|
|
|||
Loading…
Reference in a new issue