mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
feat: add Java stdlib import postprocessing for generated tests
LLM-generated Java tests often miss stdlib imports (e.g., java.util.Optional, java.math.BigDecimal), causing valid tests to fail tree-sitter validation and be silently removed. This adds ensure_java_stdlib_imports() — analogous to Python's add_missing_imports — that runs before validation in the testgen pipeline. Covers 90+ stdlib classes across java.util, java.math, java.io, java.nio, java.time, and java.util.concurrent. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
35abf30226
commit
10fa19392d
3 changed files with 330 additions and 0 deletions
|
|
@ -0,0 +1,165 @@
|
|||
import re
|
||||
|
||||
# Mapping of Java stdlib class names to their full import statements.
|
||||
# Only standard library classes that LLMs commonly use in generated tests.
|
||||
JAVA_STDLIB_CLASSES: dict[str, str] = {
|
||||
# java.util
|
||||
"ArrayList": "import java.util.ArrayList;",
|
||||
"Arrays": "import java.util.Arrays;",
|
||||
"Base64": "import java.util.Base64;",
|
||||
"BitSet": "import java.util.BitSet;",
|
||||
"Collections": "import java.util.Collections;",
|
||||
"Comparator": "import java.util.Comparator;",
|
||||
"Deque": "import java.util.Deque;",
|
||||
"EnumMap": "import java.util.EnumMap;",
|
||||
"EnumSet": "import java.util.EnumSet;",
|
||||
"HashMap": "import java.util.HashMap;",
|
||||
"HashSet": "import java.util.HashSet;",
|
||||
"Hashtable": "import java.util.Hashtable;",
|
||||
"Iterator": "import java.util.Iterator;",
|
||||
"LinkedHashMap": "import java.util.LinkedHashMap;",
|
||||
"LinkedHashSet": "import java.util.LinkedHashSet;",
|
||||
"LinkedList": "import java.util.LinkedList;",
|
||||
"List": "import java.util.List;",
|
||||
"Map": "import java.util.Map;",
|
||||
"Objects": "import java.util.Objects;",
|
||||
"Optional": "import java.util.Optional;",
|
||||
"PriorityQueue": "import java.util.PriorityQueue;",
|
||||
"Properties": "import java.util.Properties;",
|
||||
"Queue": "import java.util.Queue;",
|
||||
"Random": "import java.util.Random;",
|
||||
"Set": "import java.util.Set;",
|
||||
"Stack": "import java.util.Stack;",
|
||||
"StringJoiner": "import java.util.StringJoiner;",
|
||||
"TreeMap": "import java.util.TreeMap;",
|
||||
"TreeSet": "import java.util.TreeSet;",
|
||||
"UUID": "import java.util.UUID;",
|
||||
"Vector": "import java.util.Vector;",
|
||||
# java.util.stream
|
||||
"Collectors": "import java.util.stream.Collectors;",
|
||||
"DoubleStream": "import java.util.stream.DoubleStream;",
|
||||
"IntStream": "import java.util.stream.IntStream;",
|
||||
"LongStream": "import java.util.stream.LongStream;",
|
||||
"Stream": "import java.util.stream.Stream;",
|
||||
# java.util.function
|
||||
"BiConsumer": "import java.util.function.BiConsumer;",
|
||||
"BiFunction": "import java.util.function.BiFunction;",
|
||||
"BiPredicate": "import java.util.function.BiPredicate;",
|
||||
"Consumer": "import java.util.function.Consumer;",
|
||||
"Function": "import java.util.function.Function;",
|
||||
"Predicate": "import java.util.function.Predicate;",
|
||||
"Supplier": "import java.util.function.Supplier;",
|
||||
"UnaryOperator": "import java.util.function.UnaryOperator;",
|
||||
# java.util.regex
|
||||
"Matcher": "import java.util.regex.Matcher;",
|
||||
"Pattern": "import java.util.regex.Pattern;",
|
||||
# java.util.concurrent
|
||||
"ConcurrentHashMap": "import java.util.concurrent.ConcurrentHashMap;",
|
||||
"ConcurrentLinkedQueue": "import java.util.concurrent.ConcurrentLinkedQueue;",
|
||||
"CopyOnWriteArrayList": "import java.util.concurrent.CopyOnWriteArrayList;",
|
||||
"CountDownLatch": "import java.util.concurrent.CountDownLatch;",
|
||||
"ExecutorService": "import java.util.concurrent.ExecutorService;",
|
||||
"Executors": "import java.util.concurrent.Executors;",
|
||||
"Future": "import java.util.concurrent.Future;",
|
||||
"TimeUnit": "import java.util.concurrent.TimeUnit;",
|
||||
# java.util.concurrent.atomic
|
||||
"AtomicInteger": "import java.util.concurrent.atomic.AtomicInteger;",
|
||||
"AtomicLong": "import java.util.concurrent.atomic.AtomicLong;",
|
||||
"AtomicReference": "import java.util.concurrent.atomic.AtomicReference;",
|
||||
# java.math
|
||||
"BigDecimal": "import java.math.BigDecimal;",
|
||||
"BigInteger": "import java.math.BigInteger;",
|
||||
"MathContext": "import java.math.MathContext;",
|
||||
"RoundingMode": "import java.math.RoundingMode;",
|
||||
# java.io
|
||||
"BufferedReader": "import java.io.BufferedReader;",
|
||||
"BufferedWriter": "import java.io.BufferedWriter;",
|
||||
"ByteArrayInputStream": "import java.io.ByteArrayInputStream;",
|
||||
"ByteArrayOutputStream": "import java.io.ByteArrayOutputStream;",
|
||||
"File": "import java.io.File;",
|
||||
"FileReader": "import java.io.FileReader;",
|
||||
"FileWriter": "import java.io.FileWriter;",
|
||||
"IOException": "import java.io.IOException;",
|
||||
"InputStream": "import java.io.InputStream;",
|
||||
"OutputStream": "import java.io.OutputStream;",
|
||||
"StringReader": "import java.io.StringReader;",
|
||||
"StringWriter": "import java.io.StringWriter;",
|
||||
# java.nio
|
||||
"ByteBuffer": "import java.nio.ByteBuffer;",
|
||||
"CharBuffer": "import java.nio.CharBuffer;",
|
||||
"Files": "import java.nio.file.Files;",
|
||||
"Path": "import java.nio.file.Path;",
|
||||
"Paths": "import java.nio.file.Paths;",
|
||||
"StandardCharsets": "import java.nio.charset.StandardCharsets;",
|
||||
# java.time
|
||||
"Duration": "import java.time.Duration;",
|
||||
"Instant": "import java.time.Instant;",
|
||||
"LocalDate": "import java.time.LocalDate;",
|
||||
"LocalDateTime": "import java.time.LocalDateTime;",
|
||||
"LocalTime": "import java.time.LocalTime;",
|
||||
"ZoneId": "import java.time.ZoneId;",
|
||||
"ZonedDateTime": "import java.time.ZonedDateTime;",
|
||||
}
|
||||
|
||||
# Pre-compiled regex patterns for class name detection
|
||||
_CLASS_PATTERNS: dict[str, re.Pattern[str]] = {
|
||||
class_name: re.compile(rf"\b{class_name}\b") for class_name in JAVA_STDLIB_CLASSES
|
||||
}
|
||||
|
||||
# Pattern to match existing import lines
|
||||
_IMPORT_LINE_RE = re.compile(r"^import\s+(?:static\s+)?[\w.]+(?:\.\*)?;\s*$", re.MULTILINE)
|
||||
|
||||
|
||||
def _has_import(code: str, class_name: str, import_stmt: str) -> bool:
|
||||
"""Check if the code already has an import for the given class (exact or wildcard)."""
|
||||
if import_stmt in code:
|
||||
return True
|
||||
# Check for wildcard import of the same package
|
||||
package = import_stmt.split()[1].rsplit(".", 1)[0]
|
||||
return f"import {package}.*;" in code
|
||||
|
||||
|
||||
def ensure_java_stdlib_imports(code: str) -> str:
|
||||
"""Add missing Java stdlib imports to generated test code.
|
||||
|
||||
Scans the code for usage of known stdlib classes and adds their import
|
||||
statements if not already present. Inserts after the last existing import line.
|
||||
"""
|
||||
if not code:
|
||||
return code
|
||||
|
||||
imports_to_add: list[str] = []
|
||||
|
||||
for class_name, import_stmt in JAVA_STDLIB_CLASSES.items():
|
||||
if not _CLASS_PATTERNS[class_name].search(code):
|
||||
continue
|
||||
if _has_import(code, class_name, import_stmt):
|
||||
continue
|
||||
imports_to_add.append(import_stmt)
|
||||
|
||||
if not imports_to_add:
|
||||
return code
|
||||
|
||||
# Find the last import line to insert after it
|
||||
last_import_match = None
|
||||
for match in _IMPORT_LINE_RE.finditer(code):
|
||||
last_import_match = match
|
||||
|
||||
if last_import_match:
|
||||
insert_pos = last_import_match.end()
|
||||
import_block = "\n".join(imports_to_add)
|
||||
return code[:insert_pos] + "\n" + import_block + code[insert_pos:]
|
||||
|
||||
# No existing imports — insert before the first non-blank, non-package line
|
||||
lines = code.split("\n")
|
||||
insert_idx = 0
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("package ") or stripped == "":
|
||||
insert_idx = i + 1
|
||||
else:
|
||||
break
|
||||
|
||||
import_block = "\n".join(imports_to_add)
|
||||
lines.insert(insert_idx, import_block)
|
||||
return "\n".join(lines)
|
||||
|
|
@ -365,6 +365,12 @@ def parse_and_validate_java_output(response_content: str) -> str:
|
|||
|
||||
code = pattern_res.group(1).strip()
|
||||
|
||||
# Add missing stdlib imports before validation — prevents valid tests
|
||||
# from being removed just because the LLM forgot an import statement
|
||||
from core.languages.java.postprocessing.add_missing_imports import ensure_java_stdlib_imports
|
||||
|
||||
code = ensure_java_stdlib_imports(code)
|
||||
|
||||
# Individual test validation: validate each @Test method separately,
|
||||
# removing broken ones while keeping valid ones (mirrors Python's approach)
|
||||
try:
|
||||
|
|
|
|||
159
django/aiservice/tests/testgen/test_java_add_missing_imports.py
Normal file
159
django/aiservice/tests/testgen/test_java_add_missing_imports.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
from core.languages.java.postprocessing.add_missing_imports import ensure_java_stdlib_imports
|
||||
|
||||
|
||||
class TestEnsureJavaStdlibImports:
|
||||
def test_adds_missing_optional_import(self):
|
||||
code = (
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"\n"
|
||||
"public class FooTest {\n"
|
||||
" @Test\n"
|
||||
" void testOptional() {\n"
|
||||
" Optional<String> opt = Optional.of(\"hello\");\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
result = ensure_java_stdlib_imports(code)
|
||||
assert "import java.util.Optional;" in result
|
||||
|
||||
def test_no_change_when_import_exists(self):
|
||||
code = (
|
||||
"import java.util.Optional;\n"
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"\n"
|
||||
"public class FooTest {\n"
|
||||
" @Test\n"
|
||||
" void testOptional() {\n"
|
||||
" Optional<String> opt = Optional.of(\"hello\");\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
result = ensure_java_stdlib_imports(code)
|
||||
assert result == code
|
||||
|
||||
def test_no_change_when_wildcard_import_exists(self):
|
||||
code = (
|
||||
"import java.util.*;\n"
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"\n"
|
||||
"public class FooTest {\n"
|
||||
" @Test\n"
|
||||
" void testOptional() {\n"
|
||||
" Optional<String> opt = Optional.of(\"hello\");\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
result = ensure_java_stdlib_imports(code)
|
||||
# Should not add java.util.Optional since java.util.* covers it
|
||||
assert result == code
|
||||
|
||||
def test_adds_multiple_missing_imports(self):
|
||||
code = (
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"\n"
|
||||
"public class FooTest {\n"
|
||||
" @Test\n"
|
||||
" void testCollections() {\n"
|
||||
" List<String> list = new ArrayList<>();\n"
|
||||
" Map<String, Integer> map = new HashMap<>();\n"
|
||||
" BigDecimal bd = new BigDecimal(\"1.0\");\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
result = ensure_java_stdlib_imports(code)
|
||||
assert "import java.util.List;" in result
|
||||
assert "import java.util.ArrayList;" in result
|
||||
assert "import java.util.Map;" in result
|
||||
assert "import java.util.HashMap;" in result
|
||||
assert "import java.math.BigDecimal;" in result
|
||||
|
||||
def test_empty_code(self):
|
||||
assert ensure_java_stdlib_imports("") == ""
|
||||
|
||||
def test_no_stdlib_classes_used(self):
|
||||
code = (
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"\n"
|
||||
"public class FooTest {\n"
|
||||
" @Test\n"
|
||||
" void testSimple() {\n"
|
||||
" int x = 42;\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
result = ensure_java_stdlib_imports(code)
|
||||
assert result == code
|
||||
|
||||
def test_inserts_after_last_import(self):
|
||||
code = (
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"import static org.junit.jupiter.api.Assertions.*;\n"
|
||||
"\n"
|
||||
"public class FooTest {\n"
|
||||
" @Test\n"
|
||||
" void test() {\n"
|
||||
" Arrays.sort(new int[]{3, 1, 2});\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
result = ensure_java_stdlib_imports(code)
|
||||
# The import should appear after the last import line
|
||||
lines = result.split("\n")
|
||||
import_idx = None
|
||||
for i, line in enumerate(lines):
|
||||
if "import java.util.Arrays;" in line:
|
||||
import_idx = i
|
||||
break
|
||||
assert import_idx is not None
|
||||
# Should be after the static import (line index 1), with a newline separator
|
||||
assert import_idx > 1
|
||||
|
||||
def test_stream_collectors_import(self):
|
||||
code = (
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"\n"
|
||||
"public class FooTest {\n"
|
||||
" @Test\n"
|
||||
" void test() {\n"
|
||||
" Stream.of(1, 2, 3).collect(Collectors.toList());\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
result = ensure_java_stdlib_imports(code)
|
||||
assert "import java.util.stream.Stream;" in result
|
||||
assert "import java.util.stream.Collectors;" in result
|
||||
|
||||
def test_concurrent_imports(self):
|
||||
code = (
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"\n"
|
||||
"public class FooTest {\n"
|
||||
" @Test\n"
|
||||
" void test() {\n"
|
||||
" ConcurrentHashMap<String, Integer> map = new ConcurrentHashMap<>();\n"
|
||||
" AtomicInteger counter = new AtomicInteger(0);\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
result = ensure_java_stdlib_imports(code)
|
||||
assert "import java.util.concurrent.ConcurrentHashMap;" in result
|
||||
assert "import java.util.concurrent.atomic.AtomicInteger;" in result
|
||||
|
||||
def test_code_with_package_and_no_imports(self):
|
||||
code = (
|
||||
"package com.example;\n"
|
||||
"\n"
|
||||
"public class FooTest {\n"
|
||||
" void test() {\n"
|
||||
" List<String> list = new ArrayList<>();\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
result = ensure_java_stdlib_imports(code)
|
||||
assert "import java.util.List;" in result
|
||||
assert "import java.util.ArrayList;" in result
|
||||
# Imports should be after the package line
|
||||
lines = result.split("\n")
|
||||
pkg_idx = next(i for i, l in enumerate(lines) if l.startswith("package"))
|
||||
import_idx = next(i for i, l in enumerate(lines) if "import java.util" in l)
|
||||
assert import_idx > pkg_idx
|
||||
Loading…
Reference in a new issue