mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
Add code extractor for functions/classes
This commit is contained in:
parent
c9ed4cd1b6
commit
2645a19be0
1 changed files with 50 additions and 0 deletions
50
codeflash/discovery/code_extractor.py
Normal file
50
codeflash/discovery/code_extractor.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
import ast
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_code(file_path: str, target_name: str) -> Optional[str]:
|
||||
"""Returns the code for a class or function in a file."""
|
||||
class_skeleton = []
|
||||
|
||||
def find_target(node_list, name_parts):
|
||||
target_node = None
|
||||
for node in node_list:
|
||||
if (
|
||||
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
|
||||
and node.name == name_parts[0]
|
||||
):
|
||||
target_node = node
|
||||
break
|
||||
|
||||
if target_node is None or len(name_parts) == 1:
|
||||
return target_node
|
||||
|
||||
if isinstance(target_node, ast.ClassDef):
|
||||
class_skeleton.append([node.lineno, node.lineno])
|
||||
cbody = target_node.body
|
||||
if isinstance(cbody[0], ast.expr): # Is a docstring
|
||||
class_skeleton.append([cbody[0].lineno, cbody[0].end_lineno])
|
||||
cbody = cbody[1:]
|
||||
for cnode in cbody:
|
||||
if hasattr(cnode, "name") and cnode.name == "__init__":
|
||||
class_skeleton.append([cnode.lineno, cnode.end_lineno])
|
||||
|
||||
return find_target(target_node.body, name_parts[1:])
|
||||
|
||||
return None
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
source_code = file.read()
|
||||
|
||||
module_node = ast.parse(source_code)
|
||||
name_parts = target_name.split(".")
|
||||
target_node = find_target(module_node.body, name_parts)
|
||||
if target_node is None:
|
||||
return None
|
||||
|
||||
# Get the source code lines for the target node
|
||||
lines = source_code.splitlines(keepends=True)
|
||||
class_code = "".join(["".join(lines[s_lineno - 1: e_lineno]) for (s_lineno, e_lineno) in class_skeleton])
|
||||
target_code = "".join(lines[target_node.lineno - 1: target_node.end_lineno])
|
||||
|
||||
return class_code + target_code
|
||||
Loading…
Reference in a new issue