codeflash-internal/django/aiservice/tests/optimizer/test_optimizer_java.py
Saurabh Misra 198c0c1a4e
codeflash-omni-java (#2335)
# Pull Request Checklist

## Description
- [ ] **Breaking Changes**: Document any breaking changes (if
applicable)
- [ ] **Description of PR**: Clear and concise description of what this
PR accomplishes
- [ ] **Related Issues**: Link to any related issues or tickets

## Testing
- [ ] **Test cases Attached**: All relevant test cases have been
added/updated
- [ ] **Manual Testing**: Manual testing completed for the changes

## Monitoring & Debugging
- [ ] **Logging in place**: Appropriate logging has been added for
debugging user issues
- [ ] **Sentry will be able to catch errors**: Error handling ensures
Sentry can capture and report errors
- [ ] **Avoid Dev based/Prisma logging**: No development-only or
Prisma-specific logging in production code

## Configuration
- [ ] **Env variables newly added**: Any new environment variables are
documented in .env.example file or mentioned in description
---

## Additional Notes
<!-- Add any additional context, screenshots, or notes for reviewers
here -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-200.ec2.internal>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Kevin Turcios <turcioskevinr@gmail.com>
Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
2026-02-13 23:26:55 +05:30

884 lines
24 KiB
Python

"""Tests for Java optimizer module.
Tests the code extraction, normalization, and validation functions.
"""
import re
from aiservice.validators.java_validator import validate_java_syntax
# Pattern to extract code blocks from LLM response (handles both ```java and ```java:filename.java)
JAVA_CODE_PATTERN = re.compile(r"```(?:java)(?::[^\n]*)?\s*\n(.*?)```", re.MULTILINE | re.DOTALL)
# Pattern to extract code blocks with file paths (multi-file context)
JAVA_CODE_WITH_PATH_PATTERN = re.compile(r"```(?:java):([^\n]+)\n(.*?)```", re.MULTILINE | re.DOTALL)
def extract_code_and_explanation(content: str, is_multi_file: bool = False) -> tuple[str | dict[str, str], str]:
"""Extract code and explanation from LLM response.
Args:
content: The raw LLM response content
is_multi_file: Whether to expect multi-file format
Returns:
Tuple of (code, explanation) where code is a string for single file
or dict[str, str] for multi-file
"""
if is_multi_file:
# Extract all code blocks with file paths
matches = JAVA_CODE_WITH_PATH_PATTERN.findall(content)
if matches:
file_to_code: dict[str, str] = {}
first_match_pos = content.find("```")
explanation = content[:first_match_pos].strip() if first_match_pos > 0 else ""
for file_path, code in matches:
file_to_code[file_path.strip()] = code.strip()
return file_to_code, explanation
# Fall back to single file extraction
return extract_code_and_explanation(content, is_multi_file=False)
# Single file extraction
match = JAVA_CODE_PATTERN.search(content)
if match:
code = match.group(1).strip()
# Explanation is everything before the code block
explanation_end = match.start()
explanation = content[:explanation_end].strip()
return code, explanation
# No code block found, return empty code
return "", content
def is_multi_context_java(source_code: str) -> bool:
"""Check if source code contains multiple Java file blocks."""
return source_code.count("```java:") >= 1
class TestExtractCodeAndExplanation:
    """Tests for extracting code and explanation from LLM responses."""

    def test_extract_java_code_block(self) -> None:
        """Test extracting code from a Java code block."""
        response = """**Optimization Explanation:**
I replaced the O(n²) nested loop with a more efficient HashMap-based lookup.
```java
public List<Integer> findDuplicates(int[] arr) {
Map<Integer, Boolean> seen = new HashMap<>();
List<Integer> duplicates = new ArrayList<>();
for (int item : arr) {
if (seen.containsKey(item)) {
duplicates.add(item);
}
seen.put(item, true);
}
return duplicates;
}
```
"""
        code, explanation = extract_code_and_explanation(response)
        assert "findDuplicates" in code
        assert "HashMap" in code
        # The explanation is exactly the prose before the fence: both phrases
        # must be present and none of the code may leak into it.
        assert "O(n²)" in explanation
        assert "HashMap" in explanation
        assert "```" not in explanation
        assert "findDuplicates" not in explanation

    def test_extract_with_filename(self) -> None:
        """Test extracting code from a code block with filename."""
        response = """Here's the optimized code:
```java:Calculator.java
public class Calculator {
public long fibonacci(int n) {
if (n <= 1) return n;
long a = 0, b = 1;
for (int i = 2; i <= n; i++) {
long temp = a + b;
a = b;
b = temp;
}
return b;
}
}
```
"""
        code, explanation = extract_code_and_explanation(response, is_multi_file=True)
        assert isinstance(code, dict)
        assert "Calculator.java" in code
        assert "fibonacci" in code["Calculator.java"]
        assert explanation == "Here's the optimized code:"

    def test_no_code_block_returns_empty(self) -> None:
        """Test that missing code block returns empty code."""
        response = "This response has no code block, just explanation."
        code, explanation = extract_code_and_explanation(response)
        assert code == ""
        # With no code block the whole response is returned as the explanation.
        assert explanation == response

    def test_multiple_code_blocks_takes_first(self) -> None:
        """Test that only the first code block is extracted."""
        response = """First version:
```java
public int first() { return 1; }
```
Alternative version:
```java
public int second() { return 2; }
```
"""
        code, explanation = extract_code_and_explanation(response)
        assert code == "public int first() { return 1; }"
        assert "second" not in code
        assert explanation == "First version:"

    def test_multi_file_extraction(self) -> None:
        """Test extracting multiple files from response."""
        response = """Here are the optimized classes:
```java:MathUtils.java
public class MathUtils {
public static int add(int a, int b) {
return a + b;
}
}
```
```java:Calculator.java
public class Calculator {
private MathUtils utils;
public int compute(int x, int y) {
return MathUtils.add(x, y);
}
}
```
"""
        code, explanation = extract_code_and_explanation(response, is_multi_file=True)
        assert isinstance(code, dict)
        assert len(code) == 2
        # Files are returned in the order they appear in the response.
        assert list(code) == ["MathUtils.java", "Calculator.java"]
        assert "add" in code["MathUtils.java"]
        assert "compute" in code["Calculator.java"]
        assert explanation == "Here are the optimized classes:"

    def test_multi_file_falls_back_to_single_file(self) -> None:
        """Test that multi-file mode returns a plain string when no block carries a path."""
        response = """Optimized:
```java
public int f() { return 1; }
```
"""
        code, explanation = extract_code_and_explanation(response, is_multi_file=True)
        assert code == "public int f() { return 1; }"
        assert explanation == "Optimized:"

    def test_single_file_mode_accepts_filename_tag(self) -> None:
        """Test that single-file mode extracts a ```java:Name.java block as a string."""
        response = """Result:
```java:Calculator.java
public class Calculator {}
```
"""
        code, explanation = extract_code_and_explanation(response)
        assert code == "public class Calculator {}"
        assert explanation == "Result:"
class TestIsMultiContextJava:
    """Tests for detecting multi-file Java context.

    A source counts as multi-context as soon as it contains one ```java:<path>
    block; a plain ```java fence without a path does not count.
    """

    def test_single_file_not_multi_context(self) -> None:
        """Test that single file code is not detected as multi-context."""
        code = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
        assert not is_multi_context_java(code)

    def test_multi_file_is_multi_context(self) -> None:
        """Test that multi-file code is detected as multi-context."""
        code = """```java:Calculator.java
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
```
"""
        assert is_multi_context_java(code)

    def test_multiple_files_is_multi_context(self) -> None:
        """Test that multiple Java files are detected as multi-context."""
        code = """```java:A.java
public class A {}
```
```java:B.java
public class B {}
```
"""
        assert is_multi_context_java(code)

    def test_plain_java_fence_is_not_multi_context(self) -> None:
        """Test that a fenced ```java block without a file path is not multi-context."""
        code = """```java
public class A {}
```
"""
        assert not is_multi_context_java(code)

    def test_empty_source_is_not_multi_context(self) -> None:
        """Test that empty input is not multi-context."""
        assert not is_multi_context_java("")
class TestValidateJavaSyntax:
    """Tests for Java syntax validation using tree-sitter."""

    def test_valid_java_code(self) -> None:
        """Test that valid Java code passes validation."""
        code = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None

    def test_empty_code_fails(self) -> None:
        """Test that empty code fails validation."""
        is_valid, error = validate_java_syntax("")
        assert not is_valid
        assert error is not None
        is_valid, error = validate_java_syntax(" ")
        assert not is_valid
        assert error is not None

    def test_unbalanced_braces_fails(self) -> None:
        """Test that unbalanced braces fail validation."""
        code = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
""" # Missing closing brace
        is_valid, _ = validate_java_syntax(code)
        assert not is_valid

    def test_unbalanced_parentheses_fails(self) -> None:
        """Test that unbalanced parentheses fail validation."""
        code = """
public class Calculator {
public int add(int a, int b {
return a + b;
}
}
""" # Missing closing parenthesis
        is_valid, _ = validate_java_syntax(code)
        assert not is_valid

    def test_complex_valid_code(self) -> None:
        """Test that complex valid Java code passes validation."""
        code = """
public class Fibonacci {
private Map<Integer, Long> memo = new HashMap<>();
public long fibonacci(int n) {
if (n <= 1) {
return n;
}
if (memo.containsKey(n)) {
return memo.get(n);
}
long result = fibonacci(n - 1) + fibonacci(n - 2);
memo.put(n, result);
return result;
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None

    def test_braces_in_string_are_handled(self) -> None:
        """Test that braces inside strings are handled correctly by tree-sitter."""
        code = """
public class Test {
public String getBraces() {
return "{ } ( ) [ ]";
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None

    def test_braces_in_single_line_comment_are_handled(self) -> None:
        """Test that braces in single-line comments are handled correctly."""
        code = """
public class Test {
public void method() {
// This comment has unbalanced braces: { { {
int x = 1;
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None

    def test_braces_in_multi_line_comment_are_handled(self) -> None:
        """Test that braces in multi-line comments are handled correctly."""
        code = """
public class Test {
/*
* This comment has unbalanced braces: { { {
* And parentheses: ( ( (
*/
public void method() {
int x = 1;
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None

    def test_unbalanced_brackets_fails(self) -> None:
        """Test that unbalanced brackets fail validation."""
        code = """
public class Test {
public int[] getArray() {
return new int[5;
}
}
"""
        is_valid, _ = validate_java_syntax(code)
        assert not is_valid

    def test_improper_nesting_fails(self) -> None:
        """Test that improperly nested delimiters fail validation."""
        # Every delimiter is balanced by count (3 "{" / 3 "}", 1 "(" / 1 ")"),
        # so the only defect is the nesting: a "{" opened inside the parameter
        # list is closed after ")". That keeps this test about nesting rather
        # than about an unclosed class body.
        code = """
public class Test {
public void method({)} {
}
}
"""
        is_valid, _ = validate_java_syntax(code)
        assert not is_valid

    def test_array_code_passes(self) -> None:
        """Test that code with arrays passes validation."""
        code = """
public class ArrayTest {
public int[] processArray(int[] input) {
int[] result = new int[input.length];
for (int i = 0; i < input.length; i++) {
result[i] = input[i] * 2;
}
return result;
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None

    def test_escaped_quotes_in_string(self) -> None:
        """Test that escaped quotes in strings are handled correctly."""
        code = """
public class Test {
public String getQuote() {
return "He said \\"Hello\\" to me";
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None

    def test_char_literal_with_special_chars(self) -> None:
        """Test that character literals are handled correctly."""
        code = """
public class Test {
public char getBrace() {
return '{';
}
public char getParen() {
return '(';
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None

    def test_lambda_expression_passes(self) -> None:
        """Test that lambda expressions pass validation."""
        code = """
public class Test {
public void process() {
List<String> items = Arrays.asList("a", "b", "c");
items.forEach(item -> {
System.out.println(item);
});
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None

    def test_generic_types_pass(self) -> None:
        """Test that generic types with angle brackets pass validation."""
        code = """
public class Test {
private Map<String, List<Integer>> data = new HashMap<>();
public List<Map<String, Object>> process(Set<String> keys) {
return new ArrayList<>();
}
}
"""
        is_valid, error = validate_java_syntax(code)
        assert is_valid
        assert error is None
class TestValidateJavaSyntaxEdgeCases:
    """Valid-but-tricky Java inputs that the syntax validator must accept."""

    @staticmethod
    def _assert_parses_cleanly(java_source: str) -> None:
        """Assert that validate_java_syntax accepts ``java_source`` and reports no error."""
        ok, err = validate_java_syntax(java_source)
        assert ok
        assert err is None

    def test_nested_braces_pass(self) -> None:
        """Deeply nested if/while/for/try/synchronized blocks are accepted."""
        java_source = """
public class Test {
public void method() {
if (true) {
while (true) {
for (int i = 0; i < 10; i++) {
try {
synchronized (this) {
doSomething();
}
} catch (Exception e) {
// handle
}
}
}
}
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_string_with_escaped_backslash(self) -> None:
        """A string ending in an escaped backslash right before its quote is accepted."""
        java_source = r"""
public class Test {
public String getPath() {
return "C:\\Users\\test\\";
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_empty_string_literal(self) -> None:
        """An empty string literal is accepted."""
        java_source = """
public class Test {
public String empty() {
return "";
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_string_with_newline_escape(self) -> None:
        """A string containing \\n escape sequences is accepted."""
        java_source = """
public class Test {
public String multiline() {
return "line1\\nline2\\nline3";
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_division_not_confused_with_comment(self) -> None:
        """The division operator is not mistaken for the start of a comment."""
        java_source = """
public class Test {
public int divide(int a, int b) {
return a / b;
}
public double ratio(double x, double y) {
return x / y / 2.0;
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_regex_in_string(self) -> None:
        """A regex string full of escaped delimiters is accepted."""
        java_source = r"""
public class Test {
public Pattern getPattern() {
return Pattern.compile("\\{.*\\}|\\[.*\\]|\\(.*\\)");
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_javadoc_comment(self) -> None:
        """Braces inside a Javadoc example block are ignored."""
        java_source = """
public class Test {
/**
* Example usage:
* <pre>
* Map<String, Object> map = new HashMap<>() {{
* put("key", "value");
* }};
* </pre>
* Note: The above uses double-brace initialization {{}}.
*/
public void method() {
int x = 1;
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_annotation_with_array(self) -> None:
        """Annotations with array-valued elements are accepted."""
        java_source = """
@SuppressWarnings({"unchecked", "rawtypes"})
public class Test {
@RequestMapping(value = {"/path1", "/path2"})
public void method() {
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_anonymous_inner_class(self) -> None:
        """Anonymous inner classes are accepted."""
        java_source = """
public class Test {
public Runnable getRunnable() {
return new Runnable() {
@Override
public void run() {
System.out.println("Running");
}
};
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_double_brace_initialization(self) -> None:
        """Double-brace initialization is accepted."""
        java_source = """
public class Test {
public Map<String, String> getMap() {
return new HashMap<String, String>() {{
put("key1", "value1");
put("key2", "value2");
}};
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_switch_expression(self) -> None:
        """Arrow-form switch expressions (Java 14+) are accepted."""
        java_source = """
public class Test {
public String getDay(int day) {
return switch (day) {
case 1, 2, 3, 4, 5 -> "Weekday";
case 6, 7 -> "Weekend";
default -> "Invalid";
};
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_record_class(self) -> None:
        """Records with a compact constructor (Java 16+) are accepted."""
        java_source = """
public record Point(int x, int y) {
public Point {
if (x < 0 || y < 0) {
throw new IllegalArgumentException("Coordinates must be positive");
}
}
public double distance() {
return Math.sqrt(x * x + y * y);
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_method_reference(self) -> None:
        """Method references in a stream pipeline are accepted."""
        java_source = """
public class Test {
public void process(List<String> items) {
items.stream()
.map(String::toUpperCase)
.filter(s -> s.length() > 3)
.forEach(System.out::println);
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_try_with_resources(self) -> None:
        """Try-with-resources declaring two resources is accepted."""
        java_source = """
public class Test {
public String readFile(String path) throws IOException {
try (BufferedReader reader = new BufferedReader(new FileReader(path));
PrintWriter writer = new PrintWriter(System.out)) {
return reader.readLine();
}
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_multiline_string_concat(self) -> None:
        """String concatenation spread over several lines is accepted."""
        java_source = """
public class Test {
public String getJson() {
return "{"
+ "\\"name\\": \\"test\\","
+ "\\"value\\": 123"
+ "}";
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_char_literal_escape_sequences(self) -> None:
        """Escaped and delimiter character literals are accepted."""
        java_source = """
public class Test {
char tab = '\\t';
char newline = '\\n';
char backslash = '\\\\';
char quote = '\\'';
char bracket = '[';
char brace = '}';
}
"""
        self._assert_parses_cleanly(java_source)

    def test_interface_with_default_method(self) -> None:
        """Interfaces with default and static methods are accepted."""
        java_source = """
public interface Calculator {
int add(int a, int b);
default int subtract(int a, int b) {
return a - b;
}
static int multiply(int a, int b) {
return a * b;
}
}
"""
        self._assert_parses_cleanly(java_source)

    def test_enum_with_methods(self) -> None:
        """Enums with constant bodies, a constructor and an abstract method are accepted."""
        java_source = """
public enum Status {
ACTIVE("A") {
@Override
public String getDescription() {
return "Active status";
}
},
INACTIVE("I") {
@Override
public String getDescription() {
return "Inactive status";
}
};
private final String code;
Status(String code) {
this.code = code;
}
public abstract String getDescription();
}
"""
        self._assert_parses_cleanly(java_source)
class TestValidateJavaSyntaxFailureCases:
    """Test cases that should fail validation."""

    @staticmethod
    def _assert_rejected(java_source: str) -> None:
        """Assert that validate_java_syntax reports ``java_source`` as invalid."""
        is_valid, _ = validate_java_syntax(java_source)
        assert not is_valid

    def test_missing_closing_brace_at_end(self) -> None:
        """Test missing closing brace at end of class."""
        code = """
public class Test {
public void method() {
int x = 1;
}
"""
        self._assert_rejected(code)

    def test_extra_closing_brace(self) -> None:
        """Test extra closing brace."""
        code = """
public class Test {
public void method() {
int x = 1;
}
}}
"""
        self._assert_rejected(code)

    def test_mismatched_bracket_brace(self) -> None:
        """Test bracket closed with brace."""
        code = """
public class Test {
int[] arr = new int[5};
}
"""
        self._assert_rejected(code)

    def test_mismatched_paren_bracket(self) -> None:
        """Test parenthesis closed with bracket."""
        code = """
public class Test {
public void method(int x] {
}
}
"""
        self._assert_rejected(code)

    def test_unclosed_string_with_brace(self) -> None:
        """Test unclosed string should not hide syntax error."""
        # This has unbalanced braces AND an unclosed string.
        code = """
public class Test {
String s = "unclosed
{
}
"""
        # Outside the string there are 2 "{" and 1 "}", so the snippet is
        # invalid whether or not the string is treated as ending at the line
        # break; the validator must not be fooled by the unterminated literal.
        self._assert_rejected(code)

    def test_only_opening_delimiters(self) -> None:
        """Test code with only opening delimiters."""
        code = "public class Test { public void method( int[] arr = new int["
        self._assert_rejected(code)

    def test_only_closing_delimiters(self) -> None:
        """Test code with only closing delimiters."""
        code = "} } ) ] }"
        self._assert_rejected(code)

    def test_interleaved_wrong_nesting(self) -> None:
        """Test interleaved but wrongly nested delimiters."""
        code = """
public class Test {
public void method() {
int[] arr = ([)];
}
}
"""
        self._assert_rejected(code)

    def test_whitespace_only(self) -> None:
        """Test whitespace-only code fails."""
        self._assert_rejected(" \n\t\n ")

    def test_newlines_only(self) -> None:
        """Test newlines-only code fails."""
        self._assert_rejected("\n\n\n")

    def test_unterminated_string_fails(self) -> None:
        """Test that unterminated string literal fails validation."""
        code = """
public class Test {
String s = "this string never ends
}
"""
        self._assert_rejected(code)

    def test_unterminated_char_literal_fails(self) -> None:
        """Test that unterminated character literal fails validation."""
        code = """
public class Test {
char c = 'x
}
"""
        self._assert_rejected(code)

    def test_unterminated_multiline_comment_fails(self) -> None:
        """Test that unterminated multi-line comment fails validation."""
        code = """
public class Test {
/* This comment never ends
public void method() {
}
}
"""
        self._assert_rejected(code)

    def test_unterminated_string_with_balanced_braces_fails(self) -> None:
        """Test that a string left open until end of input fails validation."""
        # The string literal is never closed, so nothing after its opening
        # quote can terminate it; the class body opened on the line above is
        # also never closed. Either defect alone makes the snippet invalid.
        code = """
public class Test {
String s = "unterminated
"""
        self._assert_rejected(code)