mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
Use isort to deduplicate and sort imports in generated candidates
This commit is contained in:
parent
5a6f96f861
commit
07e24d3d57
4 changed files with 162 additions and 2 deletions
2
.idea/.gitignore
vendored
2
.idea/.gitignore
vendored
|
|
@ -6,3 +6,5 @@
|
|||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
# GitHub Copilot persisted chat sessions
|
||||
/copilot/chatSessions
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
import libcst as cst
|
||||
import re
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
|
||||
import isort
|
||||
import libcst as cst
|
||||
from libcst import (
|
||||
CSTVisitor,
|
||||
CSTTransformer,
|
||||
|
|
@ -9,7 +12,6 @@ from libcst import (
|
|||
Expr,
|
||||
SimpleString,
|
||||
)
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
|
||||
from optimizer.optimizer_utils import unparse_parse_source, compare_unparsed_ast_to_source
|
||||
|
||||
|
|
@ -146,12 +148,28 @@ def fix_missing_docstring(
|
|||
return new_optimized_code_and_explanations
|
||||
|
||||
|
||||
def dedup_and_sort_imports(
|
||||
original_source_code: str, optimized_code_and_explanations: List[Tuple[str, str]]
|
||||
) -> List[Tuple[str, str]]:
|
||||
new_optimized_code_and_explanations = []
|
||||
for code, explanation in optimized_code_and_explanations:
|
||||
try:
|
||||
# Use isort to sort and deduplicate the imports
|
||||
sorted_code = isort.code(code)
|
||||
new_optimized_code_and_explanations.append((sorted_code, explanation))
|
||||
except Exception as e:
|
||||
new_optimized_code_and_explanations.append((code, explanation))
|
||||
|
||||
return new_optimized_code_and_explanations
|
||||
|
||||
|
||||
def optimizations_postprocessing_pipeline(
|
||||
original_source_code: str, optimized_code_and_explanations: List[Tuple[str, str]]
|
||||
) -> List[Tuple[str, str]]:
|
||||
pipeline = [
|
||||
deduplicate_optimizations,
|
||||
equality_check,
|
||||
dedup_and_sort_imports,
|
||||
cleanup_explanations,
|
||||
fix_missing_docstring,
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from optimizer.postprocess import (
|
|||
cleanup_explanations,
|
||||
deduplicate_optimizations,
|
||||
fix_missing_docstring,
|
||||
dedup_and_sort_imports,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -285,3 +286,141 @@ class TestClass:
|
|||
|
||||
actual = fix_missing_docstring(original_code, [(original_code, "")])
|
||||
assert actual[0][0] == original_code
|
||||
|
||||
|
||||
def test_cleanup_imports_deduplicates():
|
||||
original_code = """
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def foo():
|
||||
return os.path.join(sys.path[0], 'bar')
|
||||
"""
|
||||
optimizations = [
|
||||
(
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
def foo():
|
||||
return os.path.join(sys.path[0], 'bar')
|
||||
""",
|
||||
"Removed duplicate imports",
|
||||
),
|
||||
]
|
||||
|
||||
expected = [
|
||||
(
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def foo():
|
||||
return os.path.join(sys.path[0], 'bar')
|
||||
""",
|
||||
"Removed duplicate imports",
|
||||
)
|
||||
]
|
||||
|
||||
actual = dedup_and_sort_imports(original_code, optimizations)
|
||||
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_cleanup_imports_sorts_and_deduplicates():
|
||||
original_code = """
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def foo():
|
||||
return os.path.join(sys.path[0], 'bar')
|
||||
"""
|
||||
optimizations = [
|
||||
(
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def foo():
|
||||
return os.path.join(sys.path[0], 'bar')
|
||||
""",
|
||||
"Sorted and removed duplicate imports",
|
||||
),
|
||||
]
|
||||
|
||||
expected = [
|
||||
(
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def foo():
|
||||
return os.path.join(sys.path[0], 'bar')
|
||||
""",
|
||||
"Sorted and removed duplicate imports",
|
||||
)
|
||||
]
|
||||
|
||||
actual = dedup_and_sort_imports(original_code, optimizations)
|
||||
|
||||
assert actual == expected
|
||||
|
||||
|
||||
# Doesn't work with isort, but maybe we don't care for now
|
||||
#
|
||||
# def test_cleanup_imports_sorts_multiple_blocks():
|
||||
# original_code = """
|
||||
# import os
|
||||
# import sys
|
||||
#
|
||||
# def foo():
|
||||
# return os.path.join(sys.path[0], 'bar')
|
||||
#
|
||||
# import json
|
||||
# import os
|
||||
# """
|
||||
# optimizations = [
|
||||
# (
|
||||
# """
|
||||
# import sys
|
||||
# import os
|
||||
#
|
||||
# def foo():
|
||||
# return os.path.join(sys.path[0], 'bar')
|
||||
#
|
||||
# import os
|
||||
# import json
|
||||
# """,
|
||||
# "Sorted imports and ensured they are not duplicated across multiple blocks",
|
||||
# ),
|
||||
# ]
|
||||
#
|
||||
# expected = [
|
||||
# (
|
||||
# """
|
||||
# import json
|
||||
# import os
|
||||
# import sys
|
||||
#
|
||||
# def foo():
|
||||
# return os.path.join(sys.path[0], 'bar')
|
||||
# """,
|
||||
# "Sorted imports and ensured they are not duplicated across multiple blocks",
|
||||
# )
|
||||
# ]
|
||||
#
|
||||
# actual = cleanup_imports(original_code, optimizations)
|
||||
#
|
||||
# assert actual == expected
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ libcst = "^1.1.0"
|
|||
posthog = "^3.3.1"
|
||||
pytest = "^8.0.1"
|
||||
sentry-sdk = { extras = ["django"], version = "^1.40.6" }
|
||||
isort = "^5.13.2"
|
||||
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
|
|
|
|||
Loading…
Reference in a new issue