Add type hints

This commit is contained in:
afik.cohen 2023-10-31 14:35:35 -07:00
parent a0e3b9f879
commit 4895896c6a
2 changed files with 49 additions and 27 deletions

6
.idea/mypy.xml Normal file
View file

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MypyConfigService">
<option name="scanBeforeCheckin" value="true" />
</component>
</project>

View file

@@ -1,14 +1,17 @@
import ast
import os
import site
from typing import List
import jedi
import tiktoken
from jedi.api.classes import Name
from pydantic.dataclasses import dataclass
from codeflash.code_utils.code_extractor import get_code_no_skeleton, get_code
def belongs_to_function(name, function_name):
def belongs_to_function(name: Name, function_name: str) -> bool:
if name.full_name and name.full_name.startswith(name.module_name):
subname = name.full_name[len(name.module_name) :]
if f".{function_name}." in subname:
@@ -16,10 +19,19 @@ def belongs_to_function(name, function_name):
return False
def get_type_annotation_context(function_name, file_path, jedi_script, project_root_path):
# Immutable record tying a symbol name to its jedi definition and the source
# text extracted for it. frozen=True makes instances hashable so they can be
# deduplicated; arbitrary_types_allowed lets the pydantic dataclass accept the
# non-pydantic jedi `Name` type for the `definition` field.
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class Source:
    # Symbol name (full name or annotation identifier) used as the dedup key
    name: str
    # jedi Name describing where the symbol is defined
    definition: Name
    # Extracted source code of the definition
    source_code: str
def get_type_annotation_context(
function_name: str, file_path: str, jedi_script: jedi.Script, project_root_path: str
) -> List[Source]:
with open(file_path, "r") as file:
source_code = file.read()
module = ast.parse(source_code)
file_contents = file.read()
module = ast.parse(file_contents)
sources = []
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef) and node.name == function_name:
@@ -28,7 +40,7 @@ def get_type_annotation_context(function_name, file_path, jedi_script, project_r
name = arg.annotation.id
line_no = arg.annotation.lineno
col_no = arg.annotation.col_offset
definition = jedi_script.goto(
definition: List[Name] = jedi_script.goto(
line=line_no,
column=col_no,
follow_imports=True,
@@ -44,11 +56,13 @@ def get_type_annotation_context(function_name, file_path, jedi_script, project_r
):
source_code = get_code(definition_path, definition[0].name)
if source_code:
sources.append((name, definition[0], source_code))
sources.append(Source(name, definition[0], source_code))
return sources
def get_function_variables_definition(function_name, file_path, project_root_path: str):
def get_function_variables_definition(
function_name: str, file_path: str, project_root_path: str
) -> List[Source]:
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
sources = []
# TODO: The function name condition can be stricter so that it does not clash with other class names etc.
@@ -59,15 +73,15 @@ def get_function_variables_definition(function_name, file_path, project_root_pat
if ref.full_name and belongs_to_function(ref, function_name)
]
for name in names:
definition = script.goto(
definitions: List[Name] = script.goto(
line=name.line,
column=name.column,
follow_imports=True,
follow_builtin_imports=False,
)
if definition:
if definitions:
# TODO: there can be multiple definitions, see how to handle such cases
definition_path = str(definition[0].module_path)
definition_path = str(definitions[0].module_path)
# The definition is part of this project and not defined within the original function
if (
definition_path.startswith(project_root_path + os.sep)
@@ -77,12 +91,12 @@ def get_function_variables_definition(function_name, file_path, project_root_pat
for site_package_path in site.getsitepackages()
]
)
and definition[0].full_name
and not belongs_to_function(definition[0], function_name)
and definitions[0].full_name
and not belongs_to_function(definitions[0], function_name)
):
source_code = get_code_no_skeleton(definition_path, definition[0].name)
source_code = get_code_no_skeleton(definition_path, definitions[0].name)
if source_code:
sources.append((name.full_name, definition[0], source_code))
sources.append(Source(name.full_name, definitions[0], source_code))
annotation_sources = get_type_annotation_context(
function_name, file_path, script, project_root_path
)
@@ -90,37 +104,39 @@ def get_function_variables_definition(function_name, file_path, project_root_pat
deduped_sources = []
existing_full_names = set()
for source in sources:
if source[0] not in existing_full_names:
if source.name not in existing_full_names:
deduped_sources.append(source)
existing_full_names.add(source[0])
existing_full_names.add(source.name)
return deduped_sources
def get_function_context_len_constrained(
function_name, path, project_root_path, code_to_optimize, max_tokens
):
function_name: str, path: str, project_root_path: str, code_to_optimize: str, max_tokens: int
) -> tuple[str, list[Source]]:
# TODO: Not just do static analysis, but also find the datatypes of function arguments by running the existing
# unittests and inspecting the arguments to resolve the real definitions and dependencies.
function_dependencies = get_function_variables_definition(
dependent_functions: list[Source] = get_function_variables_definition(
function_name, path, project_root_path
)
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
code_to_optimize_tokens = tokenizer.encode(code_to_optimize)
dependencies = [definition[2] for definition in function_dependencies]
dependencies_tokens = [len(tokenizer.encode(dependency)) for dependency in dependencies]
dependent_functions_sources = [function.source_code for function in dependent_functions]
dependent_functions_tokens = [
len(tokenizer.encode(function)) for function in dependent_functions_sources
]
context_list = []
context_len = len(code_to_optimize_tokens)
print(
"ORIGINAL CODE TOKENS LENGTH:",
context_len,
"ALL DEPENDENCIES TOKENS LENGTH:",
sum(dependencies_tokens),
sum(dependent_functions_tokens),
)
for dependency, dependency_len in zip(dependencies, dependencies_tokens):
if context_len + dependency_len <= max_tokens:
context_list.append(dependency)
context_len += dependency_len
for function_source, source_len in zip(dependent_functions_sources, dependent_functions_tokens):
if context_len + source_len <= max_tokens:
context_list.append(function_source)
context_len += source_len
else:
break
print("FINAL OPTIMIZATION CONTEXT TOKENS LENGTH:", context_len)
return "\n".join(context_list) + "\n" + code_to_optimize, function_dependencies
return "\n".join(context_list) + "\n" + code_to_optimize, dependent_functions