mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
tests and fixes for context extraction (handle generator functions and complex types)
This commit is contained in:
parent
20cd0bd239
commit
1f184cbc52
4 changed files with 1881 additions and 16 deletions
|
|
@ -723,6 +723,7 @@ class JavaScriptSupport:
|
|||
1. Function parameters
|
||||
2. Function return type
|
||||
3. Class fields (if the function is a class method)
|
||||
4. Types referenced within other type definitions (recursive)
|
||||
|
||||
Then looks up these type definitions in:
|
||||
1. The same file
|
||||
|
|
@ -762,21 +763,46 @@ class JavaScriptSupport:
|
|||
# Track which types we've found (avoid duplicates)
|
||||
found_type_names: set[str] = set()
|
||||
|
||||
# First, look for types defined in the same file
|
||||
for type_name in type_names:
|
||||
if type_name in same_file_type_map and type_name not in found_type_names:
|
||||
found_definitions.append(same_file_type_map[type_name])
|
||||
found_type_names.add(type_name)
|
||||
# Recursively find types - including types referenced within type definitions
|
||||
types_to_find = set(type_names)
|
||||
processed_types: set[str] = set()
|
||||
max_iterations = 10 # Prevent infinite loops
|
||||
|
||||
# For types not found in same file, look in imported files
|
||||
remaining_types = type_names - found_type_names
|
||||
if remaining_types:
|
||||
imported_definitions = self._find_imported_type_definitions(
|
||||
remaining_types, imports, module_root, function.file_path
|
||||
)
|
||||
for defn in imported_definitions:
|
||||
found_definitions.append(defn)
|
||||
found_type_names.add(defn.name)
|
||||
for _ in range(max_iterations):
|
||||
if not types_to_find:
|
||||
break
|
||||
|
||||
new_types_to_find: set[str] = set()
|
||||
types_not_in_same_file: set[str] = set()
|
||||
|
||||
for type_name in types_to_find:
|
||||
if type_name in processed_types:
|
||||
continue
|
||||
processed_types.add(type_name)
|
||||
|
||||
# Look in same file first
|
||||
if type_name in same_file_type_map and type_name not in found_type_names:
|
||||
defn = same_file_type_map[type_name]
|
||||
found_definitions.append(defn)
|
||||
found_type_names.add(type_name)
|
||||
# Extract types referenced in this type definition
|
||||
referenced_types = self._extract_types_from_definition(defn.source_code, analyzer)
|
||||
new_types_to_find.update(referenced_types - found_type_names - processed_types)
|
||||
elif type_name not in same_file_type_map and type_name not in found_type_names:
|
||||
# Type not found in same file, needs to be looked up in imports
|
||||
types_not_in_same_file.add(type_name)
|
||||
|
||||
# For types not found in same file, look in imported files
|
||||
if types_not_in_same_file:
|
||||
imported_definitions = self._find_imported_type_definitions(
|
||||
types_not_in_same_file, imports, module_root, function.file_path
|
||||
)
|
||||
for defn in imported_definitions:
|
||||
if defn.name not in found_type_names:
|
||||
found_definitions.append(defn)
|
||||
found_type_names.add(defn.name)
|
||||
|
||||
types_to_find = new_types_to_find
|
||||
|
||||
if not found_definitions:
|
||||
return "", found_type_names
|
||||
|
|
@ -799,6 +825,48 @@ class JavaScriptSupport:
|
|||
|
||||
return "\n\n".join(type_def_parts), found_type_names
|
||||
|
||||
def _extract_types_from_definition(self, type_source: str, analyzer: TreeSitterAnalyzer) -> set[str]:
|
||||
"""Extract type names referenced in a type definition's source code.
|
||||
|
||||
Args:
|
||||
type_source: Source code of the type definition.
|
||||
analyzer: TreeSitterAnalyzer for parsing.
|
||||
|
||||
Returns:
|
||||
Set of type names found in the definition.
|
||||
|
||||
"""
|
||||
# Parse the type definition and find type identifiers
|
||||
source_bytes = type_source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
type_names: set[str] = set()
|
||||
|
||||
def walk_for_types(node):
|
||||
# Look for type_identifier nodes (user-defined types)
|
||||
if node.type == "type_identifier":
|
||||
type_name = source_bytes[node.start_byte : node.end_byte].decode("utf8")
|
||||
# Skip primitive types
|
||||
if type_name not in (
|
||||
"number",
|
||||
"string",
|
||||
"boolean",
|
||||
"void",
|
||||
"null",
|
||||
"undefined",
|
||||
"any",
|
||||
"never",
|
||||
"unknown",
|
||||
"object",
|
||||
"symbol",
|
||||
"bigint",
|
||||
):
|
||||
type_names.add(type_name)
|
||||
for child in node.children:
|
||||
walk_for_types(child)
|
||||
|
||||
walk_for_types(tree.root_node)
|
||||
return type_names
|
||||
|
||||
def _find_imported_type_definitions(
|
||||
self, type_names: set[str], imports: list[Any], module_root: Path, source_file_path: Path
|
||||
) -> list[TypeDefinition]:
|
||||
|
|
|
|||
|
|
@ -303,6 +303,12 @@ class TreeSitterAnalyzer:
|
|||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
name = self.get_node_text(name_node, source_bytes)
|
||||
else:
|
||||
# Fallback: search for identifier child (some tree-sitter versions)
|
||||
for child in node.children:
|
||||
if child.type == "identifier":
|
||||
name = self.get_node_text(child, source_bytes)
|
||||
break
|
||||
elif node.type == "method_definition":
|
||||
is_method = True
|
||||
name_node = node.child_by_field_name("name")
|
||||
|
|
@ -1108,9 +1114,14 @@ class TreeSitterAnalyzer:
|
|||
if parent.type == "export_specifier":
|
||||
return
|
||||
|
||||
# Skip parameter names in function definitions
|
||||
if parent.type in {"formal_parameters", "required_parameter"}:
|
||||
# Skip parameter names in function definitions (but NOT default values)
|
||||
if parent.type == "formal_parameters":
|
||||
return
|
||||
if parent.type == "required_parameter":
|
||||
# Only skip if this is the parameter name (pattern field), not the default value
|
||||
if parent.child_by_field_name("pattern") == node:
|
||||
return
|
||||
# If it's the value field (default value), it's a reference - don't skip
|
||||
|
||||
# This is a reference
|
||||
references.add(self.get_node_text(node, source_bytes))
|
||||
|
|
@ -1133,6 +1144,10 @@ class TreeSitterAnalyzer:
|
|||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
|
||||
# Generator functions always implicitly return a Generator/Iterator
|
||||
if function_node.is_generator:
|
||||
return True
|
||||
|
||||
# For arrow functions with expression body, there's an implicit return
|
||||
if function_node.is_arrow:
|
||||
body_node = function_node.node.child_by_field_name("body")
|
||||
|
|
|
|||
4
tests/test_languages/fixtures/js_esm/package.json
Normal file
4
tests/test_languages/fixtures/js_esm/package.json
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"name": "test",
|
||||
"type": "module"
|
||||
}
|
||||
1778
tests/test_languages/test_code_context_extraction.py
Normal file
1778
tests/test_languages/test_code_context_extraction.py
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue