diff --git a/.idea/.gitignore b/.idea/.gitignore index 13566b81b..a9d7db9c0 100644 --- a/.idea/.gitignore +++ b/.idea/.gitignore @@ -6,3 +6,5 @@ # Datasource local storage ignored files /dataSources/ /dataSources.local.xml +# GitHub Copilot persisted chat sessions +/copilot/chatSessions diff --git a/django/aiservice/optimizer/postprocess.py b/django/aiservice/optimizer/postprocess.py index a5e3ba61f..7c2783c46 100644 --- a/django/aiservice/optimizer/postprocess.py +++ b/django/aiservice/optimizer/postprocess.py @@ -1,5 +1,8 @@ -import libcst as cst import re +from typing import List, Tuple, Dict, Optional + +import isort +import libcst as cst from libcst import ( CSTVisitor, CSTTransformer, @@ -9,7 +12,6 @@ from libcst import ( Expr, SimpleString, ) -from typing import List, Tuple, Dict, Optional from optimizer.optimizer_utils import unparse_parse_source, compare_unparsed_ast_to_source @@ -146,12 +148,28 @@ def fix_missing_docstring( return new_optimized_code_and_explanations +def dedup_and_sort_imports( + original_source_code: str, optimized_code_and_explanations: List[Tuple[str, str]] +) -> List[Tuple[str, str]]: + new_optimized_code_and_explanations = [] + for code, explanation in optimized_code_and_explanations: + try: + # Use isort to sort and deduplicate the imports + sorted_code = isort.code(code) + new_optimized_code_and_explanations.append((sorted_code, explanation)) + except Exception as e: + new_optimized_code_and_explanations.append((code, explanation)) + + return new_optimized_code_and_explanations + + def optimizations_postprocessing_pipeline( original_source_code: str, optimized_code_and_explanations: List[Tuple[str, str]] ) -> List[Tuple[str, str]]: pipeline = [ deduplicate_optimizations, equality_check, + dedup_and_sort_imports, cleanup_explanations, fix_missing_docstring, ] diff --git a/django/aiservice/optimizer/test_optimizer.py b/django/aiservice/optimizer/test_optimizer.py index 4159b661e..b3be7403c 100644 --- a/django/aiservice/optimizer/test_optimizer.py +++ b/django/aiservice/optimizer/test_optimizer.py @@ -3,6 +3,7 @@ from optimizer.postprocess import ( cleanup_explanations, deduplicate_optimizations, fix_missing_docstring, + dedup_and_sort_imports, ) @@ -285,3 +286,141 @@ class TestClass: actual = fix_missing_docstring(original_code, [(original_code, "")]) assert actual[0][0] == original_code + + +def test_cleanup_imports_deduplicates(): + original_code = """ +import os +import sys + + +def foo(): + return os.path.join(sys.path[0], 'bar') +""" + optimizations = [ + ( + """ +import os +import sys +import os + + +def foo(): + return os.path.join(sys.path[0], 'bar') +""", + "Removed duplicate imports", + ), + ] + + expected = [ + ( + """ +import os +import sys + + +def foo(): + return os.path.join(sys.path[0], 'bar') +""", + "Removed duplicate imports", + ) + ] + + actual = dedup_and_sort_imports(original_code, optimizations) + + assert actual == expected + + +def test_cleanup_imports_sorts_and_deduplicates(): + original_code = """ +import os +import sys +import json +import os + + +def foo(): + return os.path.join(sys.path[0], 'bar') +""" + optimizations = [ + ( + """ +import sys +import os +import json +import os + + +def foo(): + return os.path.join(sys.path[0], 'bar') +""", + "Sorted and removed duplicate imports", + ), + ] + + expected = [ + ( + """ +import json +import os +import sys + + +def foo(): + return os.path.join(sys.path[0], 'bar') +""", + "Sorted and removed duplicate imports", + ) + ] + + actual = dedup_and_sort_imports(original_code, optimizations) + + assert actual == expected + + +# Doesn't work with isort, but maybe we don't care for now +# +# def test_cleanup_imports_sorts_multiple_blocks(): +# original_code = """ +# import os +# import sys +# +# def foo(): +# return os.path.join(sys.path[0], 'bar') +# +# import json +# import os +# """ +# optimizations = [ +# ( +# """ +# import sys +# import os +# +# def foo(): +# return os.path.join(sys.path[0], 'bar') +# +# import os +# import json +# """, +# "Sorted imports and ensured they are not duplicated across multiple blocks", +# ), +# ] +# +# expected = [ +# ( +# """ +# import json +# import os +# import sys +# +# def foo(): +# return os.path.join(sys.path[0], 'bar') +# """, +# "Sorted imports and ensured they are not duplicated across multiple blocks", +# ) +# ] +# +# actual = cleanup_imports(original_code, optimizations) +# +# assert actual == expected diff --git a/django/aiservice/pyproject.toml b/django/aiservice/pyproject.toml index 7fdfd1183..d15200dea 100644 --- a/django/aiservice/pyproject.toml +++ b/django/aiservice/pyproject.toml @@ -22,6 +22,7 @@ libcst = "^1.1.0" posthog = "^3.3.1" pytest = "^8.0.1" sentry-sdk = { extras = ["django"], version = "^1.40.6" } +isort = "^5.13.2" [tool.poetry.group.dev]