diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py index 5542d66c0..f3f23c1d9 100644 --- a/tests/test_git_utils.py +++ b/tests/test_git_utils.py @@ -303,7 +303,7 @@ class TestGetGitDiffMultiLanguage(unittest.TestCase): reset_current_language() @patch("codeflash.code_utils.git_utils.git.Repo") - def test_java_diff_ignored_when_language_is_python(self, mock_repo_cls): + def test_java_diff_found_regardless_of_current_language(self, mock_repo_cls): from codeflash.languages.current import reset_current_language, set_current_language repo = mock_repo_cls.return_value @@ -311,15 +311,18 @@ class TestGetGitDiffMultiLanguage(unittest.TestCase): repo.working_dir = "/repo" repo.git.diff.return_value = JAVA_ADDITION_DIFF + # get_git_diff uses all registered extensions, not just the current language's set_current_language("python") try: result = get_git_diff(repo_directory=None, uncommitted_changes=True) - assert len(result) == 0 + assert len(result) == 1 + key = list(result.keys())[0] + assert str(key).endswith("Fibonacci.java") finally: reset_current_language() @patch("codeflash.code_utils.git_utils.git.Repo") - def test_mixed_lang_diff_filters_by_current_language(self, mock_repo_cls): + def test_mixed_lang_diff_returns_all_supported_extensions(self, mock_repo_cls): from codeflash.languages.current import reset_current_language, set_current_language repo = mock_repo_cls.return_value @@ -327,23 +330,14 @@ class TestGetGitDiffMultiLanguage(unittest.TestCase): repo.working_dir = "/repo" repo.git.diff.return_value = MIXED_LANG_DIFF - # When language is Python, only .py file should be found + # All supported extensions are returned regardless of current language set_current_language("python") try: result = get_git_diff(repo_directory=None, uncommitted_changes=True) - assert len(result) == 1 - key = list(result.keys())[0] - assert str(key).endswith("utils.py") - finally: - reset_current_language() - - # When language is Java, only .java file should be found - set_current_language("java") - try: - result = get_git_diff(repo_directory=None, uncommitted_changes=True) - assert len(result) == 1 - key = list(result.keys())[0] - assert str(key).endswith("App.java") + assert len(result) == 2 + paths = [str(k) for k in result.keys()] + assert any(p.endswith("utils.py") for p in paths) + assert any(p.endswith("App.java") for p in paths) finally: reset_current_language()