tests and fixes for context extraction (handle generator functions and complex types)

This commit is contained in:
ali 2026-01-29 14:46:10 +02:00
parent 20cd0bd239
commit 1f184cbc52
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
4 changed files with 1881 additions and 16 deletions

View file

@ -723,6 +723,7 @@ class JavaScriptSupport:
1. Function parameters
2. Function return type
3. Class fields (if the function is a class method)
4. Types referenced within other type definitions (recursive)
Then looks up these type definitions in:
1. The same file
@ -762,21 +763,46 @@ class JavaScriptSupport:
# Track which types we've found (avoid duplicates)
found_type_names: set[str] = set()
# First, look for types defined in the same file
for type_name in type_names:
if type_name in same_file_type_map and type_name not in found_type_names:
found_definitions.append(same_file_type_map[type_name])
found_type_names.add(type_name)
# Recursively find types - including types referenced within type definitions
types_to_find = set(type_names)
processed_types: set[str] = set()
max_iterations = 10 # Prevent infinite loops
# For types not found in same file, look in imported files
remaining_types = type_names - found_type_names
if remaining_types:
imported_definitions = self._find_imported_type_definitions(
remaining_types, imports, module_root, function.file_path
)
for defn in imported_definitions:
found_definitions.append(defn)
found_type_names.add(defn.name)
for _ in range(max_iterations):
if not types_to_find:
break
new_types_to_find: set[str] = set()
types_not_in_same_file: set[str] = set()
for type_name in types_to_find:
if type_name in processed_types:
continue
processed_types.add(type_name)
# Look in same file first
if type_name in same_file_type_map and type_name not in found_type_names:
defn = same_file_type_map[type_name]
found_definitions.append(defn)
found_type_names.add(type_name)
# Extract types referenced in this type definition
referenced_types = self._extract_types_from_definition(defn.source_code, analyzer)
new_types_to_find.update(referenced_types - found_type_names - processed_types)
elif type_name not in same_file_type_map and type_name not in found_type_names:
# Type not found in same file, needs to be looked up in imports
types_not_in_same_file.add(type_name)
# For types not found in same file, look in imported files
if types_not_in_same_file:
imported_definitions = self._find_imported_type_definitions(
types_not_in_same_file, imports, module_root, function.file_path
)
for defn in imported_definitions:
if defn.name not in found_type_names:
found_definitions.append(defn)
found_type_names.add(defn.name)
types_to_find = new_types_to_find
if not found_definitions:
return "", found_type_names
@ -799,6 +825,48 @@ class JavaScriptSupport:
return "\n\n".join(type_def_parts), found_type_names
def _extract_types_from_definition(self, type_source: str, analyzer: TreeSitterAnalyzer) -> set[str]:
"""Extract type names referenced in a type definition's source code.
Args:
type_source: Source code of the type definition.
analyzer: TreeSitterAnalyzer for parsing.
Returns:
Set of type names found in the definition.
"""
# Parse the type definition and find type identifiers
source_bytes = type_source.encode("utf8")
tree = analyzer.parse(source_bytes)
type_names: set[str] = set()
def walk_for_types(node):
# Look for type_identifier nodes (user-defined types)
if node.type == "type_identifier":
type_name = source_bytes[node.start_byte : node.end_byte].decode("utf8")
# Skip primitive types
if type_name not in (
"number",
"string",
"boolean",
"void",
"null",
"undefined",
"any",
"never",
"unknown",
"object",
"symbol",
"bigint",
):
type_names.add(type_name)
for child in node.children:
walk_for_types(child)
walk_for_types(tree.root_node)
return type_names
def _find_imported_type_definitions(
self, type_names: set[str], imports: list[Any], module_root: Path, source_file_path: Path
) -> list[TypeDefinition]:

View file

@ -303,6 +303,12 @@ class TreeSitterAnalyzer:
name_node = node.child_by_field_name("name")
if name_node:
name = self.get_node_text(name_node, source_bytes)
else:
# Fallback: search for identifier child (some tree-sitter versions)
for child in node.children:
if child.type == "identifier":
name = self.get_node_text(child, source_bytes)
break
elif node.type == "method_definition":
is_method = True
name_node = node.child_by_field_name("name")
@ -1108,9 +1114,14 @@ class TreeSitterAnalyzer:
if parent.type == "export_specifier":
return
# Skip parameter names in function definitions
if parent.type in {"formal_parameters", "required_parameter"}:
# Skip parameter names in function definitions (but NOT default values)
if parent.type == "formal_parameters":
return
if parent.type == "required_parameter":
# Only skip if this is the parameter name (pattern field), not the default value
if parent.child_by_field_name("pattern") == node:
return
# If it's the value field (default value), it's a reference - don't skip
# This is a reference
references.add(self.get_node_text(node, source_bytes))
@ -1133,6 +1144,10 @@ class TreeSitterAnalyzer:
"""
source_bytes = source.encode("utf8")
# Generator functions always implicitly return a Generator/Iterator
if function_node.is_generator:
return True
# For arrow functions with expression body, there's an implicit return
if function_node.is_arrow:
body_node = function_node.node.child_by_field_name("body")

View file

@ -0,0 +1,4 @@
{
"name": "test",
"type": "module"
}

File diff suppressed because it is too large Load diff