Add code extractor for functions/classes

This commit is contained in:
afik.cohen 2023-10-19 20:50:57 -07:00
parent c9ed4cd1b6
commit 2645a19be0

View file

@ -0,0 +1,50 @@
import ast
from typing import Optional
def get_code(file_path: str, target_name: str) -> Optional[str]:
"""Returns the code for a class or function in a file."""
class_skeleton = []
def find_target(node_list, name_parts):
target_node = None
for node in node_list:
if (
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))
and node.name == name_parts[0]
):
target_node = node
break
if target_node is None or len(name_parts) == 1:
return target_node
if isinstance(target_node, ast.ClassDef):
class_skeleton.append([node.lineno, node.lineno])
cbody = target_node.body
if isinstance(cbody[0], ast.expr): # Is a docstring
class_skeleton.append([cbody[0].lineno, cbody[0].end_lineno])
cbody = cbody[1:]
for cnode in cbody:
if hasattr(cnode, "name") and cnode.name == "__init__":
class_skeleton.append([cnode.lineno, cnode.end_lineno])
return find_target(target_node.body, name_parts[1:])
return None
with open(file_path, "r") as file:
source_code = file.read()
module_node = ast.parse(source_code)
name_parts = target_name.split(".")
target_node = find_target(module_node.body, name_parts)
if target_node is None:
return None
# Get the source code lines for the target node
lines = source_code.splitlines(keepends=True)
class_code = "".join(["".join(lines[s_lineno - 1: e_lineno]) for (s_lineno, e_lineno) in class_skeleton])
target_code = "".join(lines[target_node.lineno - 1: target_node.end_lineno])
return class_code + target_code