Use isort to deduplicate and sort imports in generated candidates

This commit is contained in:
afik.cohen 2024-03-05 11:40:27 -08:00
parent 5a6f96f861
commit 07e24d3d57
4 changed files with 162 additions and 2 deletions

2
.idea/.gitignore vendored
View file

@ -6,3 +6,5 @@
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# GitHub Copilot persisted chat sessions
/copilot/chatSessions

View file

@ -1,5 +1,8 @@
import libcst as cst
import re
from typing import List, Tuple, Dict, Optional
import isort
import libcst as cst
from libcst import (
CSTVisitor,
CSTTransformer,
@ -9,7 +12,6 @@ from libcst import (
Expr,
SimpleString,
)
from typing import List, Tuple, Dict, Optional
from optimizer.optimizer_utils import unparse_parse_source, compare_unparsed_ast_to_source
@ -146,12 +148,28 @@ def fix_missing_docstring(
return new_optimized_code_and_explanations
def dedup_and_sort_imports(
    original_source_code: str, optimized_code_and_explanations: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    """Run isort over every optimization candidate to sort and deduplicate imports.

    Args:
        original_source_code: Not used by this step; presumably kept so the
            function can be called like the other post-processing steps.
        optimized_code_and_explanations: ``(code, explanation)`` pairs to process.

    Returns:
        A new list with one pair per input pair, in the same order. Each code
        string is replaced by the output of ``isort.code``; explanations are
        passed through untouched. If isort raises for a candidate, that
        candidate's code is kept as-is.
    """
    # Function-scope import so this block does not rely on a module-level logger.
    import logging

    new_optimized_code_and_explanations = []
    for code, explanation in optimized_code_and_explanations:
        try:
            # Use isort to sort and deduplicate the imports
            sorted_code = isort.code(code)
        except Exception:
            # Best effort: a candidate isort cannot handle is still a usable
            # candidate, so keep it unchanged, but leave a trace of the failure
            # instead of swallowing it silently.
            logging.getLogger(__name__).warning(
                "isort failed on an optimization candidate; keeping its imports as-is",
                exc_info=True,
            )
            sorted_code = code
        new_optimized_code_and_explanations.append((sorted_code, explanation))
    return new_optimized_code_and_explanations
def optimizations_postprocessing_pipeline(
original_source_code: str, optimized_code_and_explanations: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
pipeline = [
deduplicate_optimizations,
equality_check,
dedup_and_sort_imports,
cleanup_explanations,
fix_missing_docstring,
]

View file

@ -3,6 +3,7 @@ from optimizer.postprocess import (
cleanup_explanations,
deduplicate_optimizations,
fix_missing_docstring,
dedup_and_sort_imports,
)
@ -285,3 +286,141 @@ class TestClass:
actual = fix_missing_docstring(original_code, [(original_code, "")])
assert actual[0][0] == original_code
def test_cleanup_imports_deduplicates():
    """A candidate that imports ``os`` twice comes back with a single ``import os``.

    The explanation attached to the candidate must be returned unchanged.
    """
    source = """
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')
"""
    candidate_code = """
import os
import sys
import os
def foo():
return os.path.join(sys.path[0], 'bar')
"""
    cleaned_code = """
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')
"""
    note = "Removed duplicate imports"

    result = dedup_and_sort_imports(source, [(candidate_code, note)])

    assert result == [(cleaned_code, "Removed duplicate imports")]
def test_cleanup_imports_sorts_and_deduplicates():
    """Unsorted imports with a repeated ``import os`` come back sorted and unique.

    The explanation attached to the candidate must be returned unchanged.
    """
    source = """
import os
import sys
import json
import os
def foo():
return os.path.join(sys.path[0], 'bar')
"""
    candidate_code = """
import sys
import os
import json
import os
def foo():
return os.path.join(sys.path[0], 'bar')
"""
    cleaned_code = """
import json
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')
"""
    note = "Sorted and removed duplicate imports"

    result = dedup_and_sort_imports(source, [(candidate_code, note)])

    assert result == [(cleaned_code, "Sorted and removed duplicate imports")]
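# Disabled: this case does not pass with isort, which appears to process each
# contiguous import block separately rather than merging imports across blocks
# separated by code. Left here in case multi-block support is wanted later.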
#
# def test_cleanup_imports_sorts_multiple_blocks():
# original_code = """
# import os
# import sys
#
# def foo():
# return os.path.join(sys.path[0], 'bar')
#
# import json
# import os
# """
# optimizations = [
# (
# """
# import sys
# import os
#
# def foo():
# return os.path.join(sys.path[0], 'bar')
#
# import os
# import json
# """,
# "Sorted imports and ensured they are not duplicated across multiple blocks",
# ),
# ]
#
# expected = [
# (
# """
# import json
# import os
# import sys
#
# def foo():
# return os.path.join(sys.path[0], 'bar')
# """,
# "Sorted imports and ensured they are not duplicated across multiple blocks",
# )
# ]
#
# actual = dedup_and_sort_imports(original_code, optimizations)
#
# assert actual == expected

View file

@ -22,6 +22,7 @@ libcst = "^1.1.0"
posthog = "^3.3.1"
pytest = "^8.0.1"
sentry-sdk = { extras = ["django"], version = "^1.40.6" }
isort = "^5.13.2"
[tool.poetry.group.dev]