Add type hints
This commit is contained in:
parent
a0e3b9f879
commit
4895896c6a
2 changed files with 49 additions and 27 deletions
6
.idea/mypy.xml
Normal file
6
.idea/mypy.xml
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="MypyConfigService">
|
||||
<option name="scanBeforeCheckin" value="true" />
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -1,14 +1,17 @@
|
|||
import ast
|
||||
import os
|
||||
import site
|
||||
from typing import List
|
||||
|
||||
import jedi
|
||||
import tiktoken
|
||||
from jedi.api.classes import Name
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.code_utils.code_extractor import get_code_no_skeleton, get_code
|
||||
|
||||
|
||||
def belongs_to_function(name, function_name):
|
||||
def belongs_to_function(name: Name, function_name: str) -> bool:
|
||||
if name.full_name and name.full_name.startswith(name.module_name):
|
||||
subname = name.full_name[len(name.module_name) :]
|
||||
if f".{function_name}." in subname:
|
||||
|
|
@ -16,10 +19,19 @@ def belongs_to_function(name, function_name):
|
|||
return False
|
||||
|
||||
|
||||
def get_type_annotation_context(function_name, file_path, jedi_script, project_root_path):
|
||||
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class Source:
    """A resolved definition together with the text of its source code.

    Instances are immutable (``frozen=True``). ``arbitrary_types_allowed``
    is enabled so that pydantic accepts the jedi ``Name`` object as a field
    type.
    """

    # Key the source is recorded under: the annotation identifier for type
    # annotations, or the reference's full name for names used in a function.
    name: str
    # The jedi definition that the name resolved to.
    definition: Name
    # Source text extracted for that definition.
    source_code: str
|
||||
|
||||
|
||||
def get_type_annotation_context(
|
||||
function_name: str, file_path: str, jedi_script: jedi.Script, project_root_path: str
|
||||
) -> List[Source]:
|
||||
with open(file_path, "r") as file:
|
||||
source_code = file.read()
|
||||
module = ast.parse(source_code)
|
||||
file_contents = file.read()
|
||||
module = ast.parse(file_contents)
|
||||
sources = []
|
||||
for node in ast.walk(module):
|
||||
if isinstance(node, ast.FunctionDef) and node.name == function_name:
|
||||
|
|
@ -28,7 +40,7 @@ def get_type_annotation_context(function_name, file_path, jedi_script, project_r
|
|||
name = arg.annotation.id
|
||||
line_no = arg.annotation.lineno
|
||||
col_no = arg.annotation.col_offset
|
||||
definition = jedi_script.goto(
|
||||
definition: List[Name] = jedi_script.goto(
|
||||
line=line_no,
|
||||
column=col_no,
|
||||
follow_imports=True,
|
||||
|
|
@ -44,11 +56,13 @@ def get_type_annotation_context(function_name, file_path, jedi_script, project_r
|
|||
):
|
||||
source_code = get_code(definition_path, definition[0].name)
|
||||
if source_code:
|
||||
sources.append((name, definition[0], source_code))
|
||||
sources.append(Source(name, definition[0], source_code))
|
||||
return sources
|
||||
|
||||
|
||||
def get_function_variables_definition(function_name, file_path, project_root_path: str):
|
||||
def get_function_variables_definition(
|
||||
function_name: str, file_path: str, project_root_path: str
|
||||
) -> List[Source]:
|
||||
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
||||
sources = []
|
||||
# TODO: The function name condition can be stricter so that it does not clash with other class names etc.
|
||||
|
|
@ -59,15 +73,15 @@ def get_function_variables_definition(function_name, file_path, project_root_pat
|
|||
if ref.full_name and belongs_to_function(ref, function_name)
|
||||
]
|
||||
for name in names:
|
||||
definition = script.goto(
|
||||
definitions: List[Name] = script.goto(
|
||||
line=name.line,
|
||||
column=name.column,
|
||||
follow_imports=True,
|
||||
follow_builtin_imports=False,
|
||||
)
|
||||
if definition:
|
||||
if definitions:
|
||||
# TODO: there can be multiple definitions, see how to handle such cases
|
||||
definition_path = str(definition[0].module_path)
|
||||
definition_path = str(definitions[0].module_path)
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
definition_path.startswith(project_root_path + os.sep)
|
||||
|
|
@ -77,12 +91,12 @@ def get_function_variables_definition(function_name, file_path, project_root_pat
|
|||
for site_package_path in site.getsitepackages()
|
||||
]
|
||||
)
|
||||
and definition[0].full_name
|
||||
and not belongs_to_function(definition[0], function_name)
|
||||
and definitions[0].full_name
|
||||
and not belongs_to_function(definitions[0], function_name)
|
||||
):
|
||||
source_code = get_code_no_skeleton(definition_path, definition[0].name)
|
||||
source_code = get_code_no_skeleton(definition_path, definitions[0].name)
|
||||
if source_code:
|
||||
sources.append((name.full_name, definition[0], source_code))
|
||||
sources.append(Source(name.full_name, definitions[0], source_code))
|
||||
annotation_sources = get_type_annotation_context(
|
||||
function_name, file_path, script, project_root_path
|
||||
)
|
||||
|
|
@ -90,37 +104,39 @@ def get_function_variables_definition(function_name, file_path, project_root_pat
|
|||
deduped_sources = []
|
||||
existing_full_names = set()
|
||||
for source in sources:
|
||||
if source[0] not in existing_full_names:
|
||||
if source.name not in existing_full_names:
|
||||
deduped_sources.append(source)
|
||||
existing_full_names.add(source[0])
|
||||
existing_full_names.add(source.name)
|
||||
return deduped_sources
|
||||
|
||||
|
||||
def get_function_context_len_constrained(
    function_name: str, path: str, project_root_path: str, code_to_optimize: str, max_tokens: int
) -> tuple[str, list[Source]]:
    """Build an optimization context for a function that fits a token budget.

    The dependencies of ``function_name`` are looked up with
    ``get_function_variables_definition``. Their source code is then added to
    the context, in order, for as long as the running token count (starting
    from the token count of ``code_to_optimize``) stays within ``max_tokens``.
    Collection stops at the first dependency that does not fit.

    Token counts are measured with the tiktoken encoding for
    ``gpt-3.5-turbo``. The original and final token counts are printed.

    Args:
        function_name: Name of the function whose dependencies are collected.
        path: Path of the file that contains the function.
        project_root_path: Root directory of the project.
        code_to_optimize: Source code of the function being optimized.
        max_tokens: Upper bound on the token count of the returned context.

    Returns:
        A tuple ``(context, dependent_functions)``. ``context`` is the
        selected dependency sources joined by newlines, followed by a newline
        and ``code_to_optimize``. ``dependent_functions`` is the full list of
        dependencies found, including any that did not fit in the budget.
    """
    # TODO: Not just do static analysis, but also find the datatypes of function arguments by running the existing
    #  unittests and inspecting the arguments to resolve the real definitions and dependencies.
    dependent_functions: list[Source] = get_function_variables_definition(
        function_name, path, project_root_path
    )
    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
    code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
    dependent_functions_sources = [function.source_code for function in dependent_functions]
    dependent_functions_tokens = [
        len(tokenizer.encode(function)) for function in dependent_functions_sources
    ]
    context_list = []
    # The code being optimized is always included, so it seeds the budget.
    context_len = len(code_to_optimize_tokens)
    print(
        "ORIGINAL CODE TOKENS LENGTH:",
        context_len,
        "ALL DEPENDENCIES TOKENS LENGTH:",
        sum(dependent_functions_tokens),
    )
    for function_source, source_len in zip(dependent_functions_sources, dependent_functions_tokens):
        if context_len + source_len <= max_tokens:
            context_list.append(function_source)
            context_len += source_len
        else:
            # Stop at the first dependency that would exceed the budget;
            # later, smaller dependencies are not considered.
            break
    print("FINAL OPTIMIZATION CONTEXT TOKENS LENGTH:", context_len)
    return "\n".join(context_list) + "\n" + code_to_optimize, dependent_functions
|
||||
|
|
|
|||
Loading…
Reference in a new issue