feat: add Java stdlib import postprocessing for generated tests

LLM-generated Java tests often miss stdlib imports (e.g., java.util.Optional,
java.math.BigDecimal), causing valid tests to fail tree-sitter validation and
be silently removed. This adds ensure_java_stdlib_imports() — analogous to
Python's add_missing_imports — that runs before validation in the testgen
pipeline. Covers 80+ stdlib classes across java.util, java.math, java.io,
java.nio, java.time, and java.util.concurrent.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Mohamed Ashraf 2026-02-25 22:04:57 +00:00
parent 35abf30226
commit 10fa19392d
3 changed files with 330 additions and 0 deletions

View file

@ -0,0 +1,165 @@
import re
# Mapping of Java stdlib class names to their full import statements.
# Only standard library classes that LLMs commonly use in generated tests.
JAVA_STDLIB_CLASSES: dict[str, str] = {
# java.util
"ArrayList": "import java.util.ArrayList;",
"Arrays": "import java.util.Arrays;",
"Base64": "import java.util.Base64;",
"BitSet": "import java.util.BitSet;",
"Collections": "import java.util.Collections;",
"Comparator": "import java.util.Comparator;",
"Deque": "import java.util.Deque;",
"EnumMap": "import java.util.EnumMap;",
"EnumSet": "import java.util.EnumSet;",
"HashMap": "import java.util.HashMap;",
"HashSet": "import java.util.HashSet;",
"Hashtable": "import java.util.Hashtable;",
"Iterator": "import java.util.Iterator;",
"LinkedHashMap": "import java.util.LinkedHashMap;",
"LinkedHashSet": "import java.util.LinkedHashSet;",
"LinkedList": "import java.util.LinkedList;",
"List": "import java.util.List;",
"Map": "import java.util.Map;",
"Objects": "import java.util.Objects;",
"Optional": "import java.util.Optional;",
"PriorityQueue": "import java.util.PriorityQueue;",
"Properties": "import java.util.Properties;",
"Queue": "import java.util.Queue;",
"Random": "import java.util.Random;",
"Set": "import java.util.Set;",
"Stack": "import java.util.Stack;",
"StringJoiner": "import java.util.StringJoiner;",
"TreeMap": "import java.util.TreeMap;",
"TreeSet": "import java.util.TreeSet;",
"UUID": "import java.util.UUID;",
"Vector": "import java.util.Vector;",
# java.util.stream
"Collectors": "import java.util.stream.Collectors;",
"DoubleStream": "import java.util.stream.DoubleStream;",
"IntStream": "import java.util.stream.IntStream;",
"LongStream": "import java.util.stream.LongStream;",
"Stream": "import java.util.stream.Stream;",
# java.util.function
"BiConsumer": "import java.util.function.BiConsumer;",
"BiFunction": "import java.util.function.BiFunction;",
"BiPredicate": "import java.util.function.BiPredicate;",
"Consumer": "import java.util.function.Consumer;",
"Function": "import java.util.function.Function;",
"Predicate": "import java.util.function.Predicate;",
"Supplier": "import java.util.function.Supplier;",
"UnaryOperator": "import java.util.function.UnaryOperator;",
# java.util.regex
"Matcher": "import java.util.regex.Matcher;",
"Pattern": "import java.util.regex.Pattern;",
# java.util.concurrent
"ConcurrentHashMap": "import java.util.concurrent.ConcurrentHashMap;",
"ConcurrentLinkedQueue": "import java.util.concurrent.ConcurrentLinkedQueue;",
"CopyOnWriteArrayList": "import java.util.concurrent.CopyOnWriteArrayList;",
"CountDownLatch": "import java.util.concurrent.CountDownLatch;",
"ExecutorService": "import java.util.concurrent.ExecutorService;",
"Executors": "import java.util.concurrent.Executors;",
"Future": "import java.util.concurrent.Future;",
"TimeUnit": "import java.util.concurrent.TimeUnit;",
# java.util.concurrent.atomic
"AtomicInteger": "import java.util.concurrent.atomic.AtomicInteger;",
"AtomicLong": "import java.util.concurrent.atomic.AtomicLong;",
"AtomicReference": "import java.util.concurrent.atomic.AtomicReference;",
# java.math
"BigDecimal": "import java.math.BigDecimal;",
"BigInteger": "import java.math.BigInteger;",
"MathContext": "import java.math.MathContext;",
"RoundingMode": "import java.math.RoundingMode;",
# java.io
"BufferedReader": "import java.io.BufferedReader;",
"BufferedWriter": "import java.io.BufferedWriter;",
"ByteArrayInputStream": "import java.io.ByteArrayInputStream;",
"ByteArrayOutputStream": "import java.io.ByteArrayOutputStream;",
"File": "import java.io.File;",
"FileReader": "import java.io.FileReader;",
"FileWriter": "import java.io.FileWriter;",
"IOException": "import java.io.IOException;",
"InputStream": "import java.io.InputStream;",
"OutputStream": "import java.io.OutputStream;",
"StringReader": "import java.io.StringReader;",
"StringWriter": "import java.io.StringWriter;",
# java.nio
"ByteBuffer": "import java.nio.ByteBuffer;",
"CharBuffer": "import java.nio.CharBuffer;",
"Files": "import java.nio.file.Files;",
"Path": "import java.nio.file.Path;",
"Paths": "import java.nio.file.Paths;",
"StandardCharsets": "import java.nio.charset.StandardCharsets;",
# java.time
"Duration": "import java.time.Duration;",
"Instant": "import java.time.Instant;",
"LocalDate": "import java.time.LocalDate;",
"LocalDateTime": "import java.time.LocalDateTime;",
"LocalTime": "import java.time.LocalTime;",
"ZoneId": "import java.time.ZoneId;",
"ZonedDateTime": "import java.time.ZonedDateTime;",
}
# Pre-compiled regex patterns for class name detection
_CLASS_PATTERNS: dict[str, re.Pattern[str]] = {
class_name: re.compile(rf"\b{class_name}\b") for class_name in JAVA_STDLIB_CLASSES
}
# Pattern to match existing import lines
_IMPORT_LINE_RE = re.compile(r"^import\s+(?:static\s+)?[\w.]+(?:\.\*)?;\s*$", re.MULTILINE)
def _has_import(code: str, class_name: str, import_stmt: str) -> bool:
"""Check if the code already has an import for the given class (exact or wildcard)."""
if import_stmt in code:
return True
# Check for wildcard import of the same package
package = import_stmt.split()[1].rsplit(".", 1)[0]
return f"import {package}.*;" in code
# Matches a Java package declaration that starts a line.
_PACKAGE_DECL_RE = re.compile(r"^[ \t]*package\s+[\w.]+\s*;", re.MULTILINE)


def _insert_after_statement_line(code: str, stmt_start: int, text: str) -> str:
    """Insert ``text`` on a new line after the line that ends the statement at ``stmt_start``."""
    # Anchor on the terminating ';' rather than on a regex match end, so the
    # result does not depend on how much trailing whitespace a pattern consumed.
    line_end = code.find("\n", code.index(";", stmt_start))
    if line_end == -1:
        line_end = len(code)
    return f"{code[:line_end]}\n{text}{code[line_end:]}"


def ensure_java_stdlib_imports(code: str) -> str:
    """Add missing Java stdlib imports to generated test code.

    Scans ``code`` for whole-word occurrences of the class names in
    ``JAVA_STDLIB_CLASSES`` and adds the import statement for each one that
    ``_has_import`` reports as absent.  Detection is purely textual, so a
    name that only appears in a comment or string literal also counts.

    Placement of the new imports, in order of preference:

    1. directly below the last existing import line;
    2. below the ``package`` declaration, separated by a blank line;
    3. at the top of the file (after any leading blank lines).

    Args:
        code: Java source text of the generated test.

    Returns:
        The source with the missing imports inserted, or ``code`` unchanged
        when it is empty or nothing needs to be added.
    """
    if not code:
        return code

    imports_to_add = [
        import_stmt
        for class_name, import_stmt in JAVA_STDLIB_CLASSES.items()
        if _CLASS_PATTERNS[class_name].search(code) and not _has_import(code, class_name, import_stmt)
    ]
    if not imports_to_add:
        return code
    import_block = "\n".join(imports_to_add)

    # 1. Extend the existing import section.
    existing_imports = list(_IMPORT_LINE_RE.finditer(code))
    if existing_imports:
        return _insert_after_statement_line(code, existing_imports[-1].start(), import_block)

    # 2. No imports yet: they must follow the package declaration.  Searching
    #    for the declaration (instead of assuming it is the first non-blank
    #    line) keeps the output valid when a comment or license header
    #    precedes it.
    package_decl = _PACKAGE_DECL_RE.search(code)
    if package_decl is not None:
        return _insert_after_statement_line(code, package_decl.start(), "\n" + import_block)

    # 3. Neither imports nor a package: put the imports ahead of the first
    #    non-blank line, keeping any leading blank lines where they are.
    first_content = len(code) - len(code.lstrip())
    line_start = code.rfind("\n", 0, first_content) + 1
    return f"{code[:line_start]}{import_block}\n\n{code[line_start:]}"

View file

@ -365,6 +365,12 @@ def parse_and_validate_java_output(response_content: str) -> str:
code = pattern_res.group(1).strip()
# Add missing stdlib imports before validation — prevents valid tests
# from being removed just because the LLM forgot an import statement
from core.languages.java.postprocessing.add_missing_imports import ensure_java_stdlib_imports
code = ensure_java_stdlib_imports(code)
# Individual test validation: validate each @Test method separately,
# removing broken ones while keeping valid ones (mirrors Python's approach)
try:

View file

@ -0,0 +1,159 @@
from core.languages.java.postprocessing.add_missing_imports import ensure_java_stdlib_imports
class TestEnsureJavaStdlibImports:
    """Behavioural tests for ``ensure_java_stdlib_imports``."""

    def test_adds_missing_optional_import(self):
        """A class that is used but not imported gets its import added."""
        source = (
            "import org.junit.jupiter.api.Test;\n"
            "\n"
            "public class FooTest {\n"
            " @Test\n"
            " void testOptional() {\n"
            " Optional<String> opt = Optional.of(\"hello\");\n"
            " }\n"
            "}\n"
        )
        assert "import java.util.Optional;" in ensure_java_stdlib_imports(source)

    def test_no_change_when_import_exists(self):
        """Source that already has the exact import is returned untouched."""
        source = (
            "import java.util.Optional;\n"
            "import org.junit.jupiter.api.Test;\n"
            "\n"
            "public class FooTest {\n"
            " @Test\n"
            " void testOptional() {\n"
            " Optional<String> opt = Optional.of(\"hello\");\n"
            " }\n"
            "}\n"
        )
        assert ensure_java_stdlib_imports(source) == source

    def test_no_change_when_wildcard_import_exists(self):
        """``java.util.*`` already covers ``java.util.Optional``, so nothing is added."""
        source = (
            "import java.util.*;\n"
            "import org.junit.jupiter.api.Test;\n"
            "\n"
            "public class FooTest {\n"
            " @Test\n"
            " void testOptional() {\n"
            " Optional<String> opt = Optional.of(\"hello\");\n"
            " }\n"
            "}\n"
        )
        assert ensure_java_stdlib_imports(source) == source

    def test_adds_multiple_missing_imports(self):
        """Every missing class across several packages is imported."""
        source = (
            "import org.junit.jupiter.api.Test;\n"
            "\n"
            "public class FooTest {\n"
            " @Test\n"
            " void testCollections() {\n"
            " List<String> list = new ArrayList<>();\n"
            " Map<String, Integer> map = new HashMap<>();\n"
            " BigDecimal bd = new BigDecimal(\"1.0\");\n"
            " }\n"
            "}\n"
        )
        patched = ensure_java_stdlib_imports(source)
        for expected in (
            "import java.util.List;",
            "import java.util.ArrayList;",
            "import java.util.Map;",
            "import java.util.HashMap;",
            "import java.math.BigDecimal;",
        ):
            assert expected in patched

    def test_empty_code(self):
        """An empty string is passed through unchanged."""
        patched = ensure_java_stdlib_imports("")
        assert patched == ""

    def test_no_stdlib_classes_used(self):
        """Source that uses no known stdlib class is returned untouched."""
        source = (
            "import org.junit.jupiter.api.Test;\n"
            "\n"
            "public class FooTest {\n"
            " @Test\n"
            " void testSimple() {\n"
            " int x = 42;\n"
            " }\n"
            "}\n"
        )
        assert ensure_java_stdlib_imports(source) == source

    def test_inserts_after_last_import(self):
        """The new import is placed below the existing import lines."""
        source = (
            "import org.junit.jupiter.api.Test;\n"
            "import static org.junit.jupiter.api.Assertions.*;\n"
            "\n"
            "public class FooTest {\n"
            " @Test\n"
            " void test() {\n"
            " Arrays.sort(new int[]{3, 1, 2});\n"
            " }\n"
            "}\n"
        )
        patched = ensure_java_stdlib_imports(source)
        # Line numbers (0-based) on which the added import shows up.
        hits = [
            row_no
            for row_no, row in enumerate(patched.split("\n"))
            if "import java.util.Arrays;" in row
        ]
        assert hits
        # The static import sits at index 1; the new import must come later.
        assert hits[0] > 1

    def test_stream_collectors_import(self):
        """Classes from ``java.util.stream`` are recognised."""
        source = (
            "import org.junit.jupiter.api.Test;\n"
            "\n"
            "public class FooTest {\n"
            " @Test\n"
            " void test() {\n"
            " Stream.of(1, 2, 3).collect(Collectors.toList());\n"
            " }\n"
            "}\n"
        )
        patched = ensure_java_stdlib_imports(source)
        for expected in ("import java.util.stream.Stream;", "import java.util.stream.Collectors;"):
            assert expected in patched

    def test_concurrent_imports(self):
        """Classes from ``java.util.concurrent`` and its ``atomic`` subpackage are recognised."""
        source = (
            "import org.junit.jupiter.api.Test;\n"
            "\n"
            "public class FooTest {\n"
            " @Test\n"
            " void test() {\n"
            " ConcurrentHashMap<String, Integer> map = new ConcurrentHashMap<>();\n"
            " AtomicInteger counter = new AtomicInteger(0);\n"
            " }\n"
            "}\n"
        )
        patched = ensure_java_stdlib_imports(source)
        for expected in (
            "import java.util.concurrent.ConcurrentHashMap;",
            "import java.util.concurrent.atomic.AtomicInteger;",
        ):
            assert expected in patched

    def test_code_with_package_and_no_imports(self):
        """With a package line and no imports, the imports land below the package."""
        source = (
            "package com.example;\n"
            "\n"
            "public class FooTest {\n"
            " void test() {\n"
            " List<String> list = new ArrayList<>();\n"
            " }\n"
            "}\n"
        )
        patched = ensure_java_stdlib_imports(source)
        assert "import java.util.List;" in patched
        assert "import java.util.ArrayList;" in patched
        # Compare the line index of the package declaration with that of the
        # first added java.util import.
        rows = patched.split("\n")
        package_row = next(pos for pos, row in enumerate(rows) if row.startswith("package"))
        first_import_row = next(pos for pos, row in enumerate(rows) if "import java.util" in row)
        assert first_import_row > package_row