mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
docs: update rules and architecture for new language structure
- Add JS optimizer, normalizer, and support files to architecture tree - Update key entry points table for per-function optimization and test execution - Add protocol dispatch preference to language-patterns rules - Fix stale context path in optimization-patterns - Add explicit utf-8 encoding rule scoped to new/changed code
This commit is contained in:
parent
04a94f2b03
commit
341c622d40
6 changed files with 23 additions and 524 deletions
|
|
@ -1,5 +1,7 @@
|
|||
# Architecture
|
||||
|
||||
When adding, moving, or deleting source files, update this doc to match.
|
||||
|
||||
```
|
||||
codeflash/
|
||||
├── main.py # CLI entry point
|
||||
|
|
@ -15,9 +17,20 @@ codeflash/
|
|||
├── code_utils/ # Code parsing, git utilities
|
||||
├── models/ # Pydantic models and types
|
||||
├── languages/ # Multi-language support (Python, JavaScript/TypeScript)
|
||||
│ └── python/
|
||||
│ ├── function_optimizer.py # PythonFunctionOptimizer (Python-specific hooks)
|
||||
│ └── optimizer.py # Python module preparation & AST resolution
|
||||
│ ├── base.py # LanguageSupport protocol and shared data types
|
||||
│ ├── registry.py # Language registration and lookup by extension/enum
|
||||
│ ├── current.py # Current language singleton (set_current_language / current_language_support)
|
||||
│ ├── code_replacer.py # Language-agnostic code replacement
|
||||
│ ├── python/
|
||||
│ │ ├── support.py # PythonSupport (LanguageSupport implementation)
|
||||
│ │ ├── function_optimizer.py # PythonFunctionOptimizer subclass
|
||||
│ │ ├── optimizer.py # Python module preparation & AST resolution
|
||||
│ │ └── normalizer.py # Python code normalization for deduplication
|
||||
│ └── javascript/
|
||||
│ ├── support.py # JavaScriptSupport (LanguageSupport implementation)
|
||||
│ ├── function_optimizer.py # JavaScriptFunctionOptimizer subclass
|
||||
│ ├── optimizer.py # JS project root finding & module preparation
|
||||
│ └── normalizer.py # JS/TS code normalization for deduplication
|
||||
├── setup/ # Config schema, auto-detection, first-run experience
|
||||
├── picklepatch/ # Serialization/deserialization utilities
|
||||
├── tracing/ # Function call tracing
|
||||
|
|
@ -35,10 +48,10 @@ codeflash/
|
|||
|------|------------|
|
||||
| CLI arguments & commands | `cli_cmds/cli.py` |
|
||||
| Optimization orchestration | `optimization/optimizer.py` → `run()` |
|
||||
| Per-function optimization | `optimization/function_optimizer.py` (base), `languages/python/function_optimizer.py` (Python subclass) |
|
||||
| Per-function optimization | `optimization/function_optimizer.py` (base), `languages/python/function_optimizer.py`, `languages/javascript/function_optimizer.py` |
|
||||
| Function discovery | `discovery/functions_to_optimize.py` |
|
||||
| Context extraction | `languages/<lang>/context/code_context_extractor.py` |
|
||||
| Test execution | `verification/test_runner.py`, `verification/pytest_plugin.py` |
|
||||
| Test execution | `languages/<lang>/support.py` (`run_behavioral_tests`, etc.), `verification/pytest_plugin.py` |
|
||||
| Performance ranking | `benchmarking/function_ranker.py` |
|
||||
| Domain types | `models/models.py`, `models/function_types.py` |
|
||||
| Result handling | `either.py` (`Result`, `Success`, `Failure`, `is_successful`) |
|
||||
|
|
|
|||
|
|
@ -7,4 +7,5 @@
|
|||
- **Comments**: Minimal - only explain "why", not "what"
|
||||
- **Docstrings**: Do not add unless explicitly requested
|
||||
- **Naming**: NEVER use leading underscores (`_function_name`) - Python has no true private functions, use public names
|
||||
- **Paths**: Always use absolute paths, handle encoding explicitly (UTF-8)
|
||||
- **Paths**: Always use absolute paths
|
||||
- **Encoding**: Always pass `encoding="utf-8"` to `open()`, `read_text()`, `write_text()`, etc. in new or changed code — Windows defaults to `cp1252` which breaks on non-ASCII content. Don't flag pre-existing code that lacks it unless you're already modifying that line.
|
||||
|
|
|
|||
|
|
@ -9,4 +9,5 @@ paths:
|
|||
- Use `get_language_support(identifier)` from `languages/registry.py` to get a `LanguageSupport` instance — never import language classes directly
|
||||
- New language support classes must use the `@register_language` decorator to register with the extension and language registries
|
||||
- `languages/__init__.py` uses `__getattr__` for lazy imports to avoid circular dependencies — follow this pattern when adding new exports
|
||||
- `is_javascript()` returns `True` for both JavaScript and TypeScript
|
||||
- Prefer `LanguageSupport` protocol dispatch over `is_python()`/`is_javascript()` guards — remaining guards are being migrated to protocol methods
|
||||
- `is_javascript()` returns `True` for both JavaScript and TypeScript (still used in ~15 call sites pending migration)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ paths:
|
|||
- "codeflash/optimization/**/*.py"
|
||||
- "codeflash/verification/**/*.py"
|
||||
- "codeflash/benchmarking/**/*.py"
|
||||
- "codeflash/context/**/*.py"
|
||||
- "codeflash/languages/*/context/**/*.py"
|
||||
---
|
||||
|
||||
# Optimization Pipeline Patterns
|
||||
|
|
|
|||
|
|
@ -1,290 +0,0 @@
|
|||
"""JavaScript/TypeScript code normalizer using tree-sitter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Node
|
||||
|
||||
|
||||
# TODO:{claude} move to language support directory to keep the directory structure clean
|
||||
class JavaScriptVariableNormalizer:
|
||||
"""Normalizes JavaScript/TypeScript code for duplicate detection using tree-sitter.
|
||||
|
||||
Normalizes local variable names while preserving function names, class names,
|
||||
parameters, and imported names.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.var_counter = 0
|
||||
self.var_mapping: dict[str, str] = {}
|
||||
self.preserved_names: set[str] = set()
|
||||
# Common JavaScript builtins
|
||||
self.builtins = {
|
||||
"console",
|
||||
"window",
|
||||
"document",
|
||||
"Math",
|
||||
"JSON",
|
||||
"Object",
|
||||
"Array",
|
||||
"String",
|
||||
"Number",
|
||||
"Boolean",
|
||||
"Date",
|
||||
"RegExp",
|
||||
"Error",
|
||||
"Promise",
|
||||
"Map",
|
||||
"Set",
|
||||
"WeakMap",
|
||||
"WeakSet",
|
||||
"Symbol",
|
||||
"Proxy",
|
||||
"Reflect",
|
||||
"undefined",
|
||||
"null",
|
||||
"NaN",
|
||||
"Infinity",
|
||||
"globalThis",
|
||||
"parseInt",
|
||||
"parseFloat",
|
||||
"isNaN",
|
||||
"isFinite",
|
||||
"eval",
|
||||
"setTimeout",
|
||||
"setInterval",
|
||||
"clearTimeout",
|
||||
"clearInterval",
|
||||
"fetch",
|
||||
"require",
|
||||
"module",
|
||||
"exports",
|
||||
"process",
|
||||
"__dirname",
|
||||
"__filename",
|
||||
"Buffer",
|
||||
}
|
||||
|
||||
def get_normalized_name(self, name: str) -> str:
|
||||
"""Get or create normalized name for a variable."""
|
||||
if name in self.builtins or name in self.preserved_names:
|
||||
return name
|
||||
if name not in self.var_mapping:
|
||||
self.var_mapping[name] = f"var_{self.var_counter}"
|
||||
self.var_counter += 1
|
||||
return self.var_mapping[name]
|
||||
|
||||
def collect_preserved_names(self, node: Node, source_code: bytes) -> None:
|
||||
"""Collect names that should be preserved (function names, class names, imports, params)."""
|
||||
# Function declarations and expressions - preserve the function name
|
||||
if node.type in ("function_declaration", "function_expression", "method_definition", "arrow_function"):
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
|
||||
# Preserve parameters
|
||||
params_node = node.child_by_field_name("parameters") or node.child_by_field_name("parameter")
|
||||
if params_node:
|
||||
self._collect_parameter_names(params_node, source_code)
|
||||
|
||||
# Class declarations
|
||||
elif node.type == "class_declaration":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
|
||||
|
||||
# Import declarations
|
||||
elif node.type in ("import_statement", "import_declaration"):
|
||||
for child in node.children:
|
||||
if child.type == "import_clause":
|
||||
self._collect_import_names(child, source_code)
|
||||
elif child.type == "identifier":
|
||||
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
|
||||
|
||||
# Recurse
|
||||
for child in node.children:
|
||||
self.collect_preserved_names(child, source_code)
|
||||
|
||||
def _collect_parameter_names(self, node: Node, source_code: bytes) -> None:
|
||||
"""Collect parameter names from a parameters node."""
|
||||
for child in node.children:
|
||||
if child.type == "identifier":
|
||||
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
|
||||
elif child.type in ("required_parameter", "optional_parameter", "rest_parameter"):
|
||||
pattern_node = child.child_by_field_name("pattern")
|
||||
if pattern_node and pattern_node.type == "identifier":
|
||||
self.preserved_names.add(
|
||||
source_code[pattern_node.start_byte : pattern_node.end_byte].decode("utf-8")
|
||||
)
|
||||
# Recurse for nested patterns
|
||||
self._collect_parameter_names(child, source_code)
|
||||
|
||||
def _collect_import_names(self, node: Node, source_code: bytes) -> None:
|
||||
"""Collect imported names from import clause."""
|
||||
for child in node.children:
|
||||
if child.type == "identifier":
|
||||
self.preserved_names.add(source_code[child.start_byte : child.end_byte].decode("utf-8"))
|
||||
elif child.type == "import_specifier":
|
||||
# Get the local name (alias or original)
|
||||
alias_node = child.child_by_field_name("alias")
|
||||
name_node = child.child_by_field_name("name")
|
||||
if alias_node:
|
||||
self.preserved_names.add(source_code[alias_node.start_byte : alias_node.end_byte].decode("utf-8"))
|
||||
elif name_node:
|
||||
self.preserved_names.add(source_code[name_node.start_byte : name_node.end_byte].decode("utf-8"))
|
||||
self._collect_import_names(child, source_code)
|
||||
|
||||
def normalize_tree(self, node: Node, source_code: bytes) -> str:
|
||||
"""Normalize the AST tree to a string representation for comparison."""
|
||||
parts: list[str] = []
|
||||
self._normalize_node(node, source_code, parts)
|
||||
return " ".join(parts)
|
||||
|
||||
def _normalize_node(self, node: Node, source_code: bytes, parts: list[str]) -> None:
|
||||
"""Recursively normalize a node."""
|
||||
# Skip comments
|
||||
if node.type in ("comment", "line_comment", "block_comment"):
|
||||
return
|
||||
|
||||
# Handle identifiers - normalize variable names
|
||||
if node.type == "identifier":
|
||||
name = source_code[node.start_byte : node.end_byte].decode("utf-8")
|
||||
normalized = self.get_normalized_name(name)
|
||||
parts.append(normalized)
|
||||
return
|
||||
|
||||
# Handle type identifiers (TypeScript) - preserve as-is
|
||||
if node.type == "type_identifier":
|
||||
parts.append(source_code[node.start_byte : node.end_byte].decode("utf-8"))
|
||||
return
|
||||
|
||||
# Handle string literals - normalize to placeholder
|
||||
if node.type in ("string", "template_string", "string_fragment"):
|
||||
parts.append('"STR"')
|
||||
return
|
||||
|
||||
# Handle number literals - normalize to placeholder
|
||||
if node.type == "number":
|
||||
parts.append("NUM")
|
||||
return
|
||||
|
||||
# For leaf nodes, output the node type
|
||||
if len(node.children) == 0:
|
||||
text = source_code[node.start_byte : node.end_byte].decode("utf-8")
|
||||
parts.append(text)
|
||||
return
|
||||
|
||||
# Output node type for structure
|
||||
parts.append(f"({node.type}")
|
||||
|
||||
# Recurse into children
|
||||
for child in node.children:
|
||||
self._normalize_node(child, source_code, parts)
|
||||
|
||||
parts.append(")")
|
||||
|
||||
|
||||
def _basic_normalize(code: str) -> str:
|
||||
"""Basic normalization: remove comments and normalize whitespace."""
|
||||
# Remove single-line comments
|
||||
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
|
||||
# Remove multi-line comments
|
||||
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
|
||||
# Normalize whitespace
|
||||
return " ".join(code.split())
|
||||
|
||||
|
||||
class JavaScriptNormalizer(CodeNormalizer):
|
||||
"""JavaScript code normalizer using tree-sitter.
|
||||
|
||||
Normalizes JavaScript code by:
|
||||
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
|
||||
- Preserving function names, class names, parameters, and imports
|
||||
- Removing comments
|
||||
- Normalizing string and number literals
|
||||
"""
|
||||
|
||||
@property
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
return "javascript"
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return (".js", ".jsx", ".mjs", ".cjs")
|
||||
|
||||
def _get_tree_sitter_language(self) -> str:
|
||||
"""Get the tree-sitter language identifier."""
|
||||
return "javascript"
|
||||
|
||||
def normalize(self, code: str) -> str:
|
||||
"""Normalize JavaScript code to a canonical form.
|
||||
|
||||
Args:
|
||||
code: JavaScript source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation of the code
|
||||
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
|
||||
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
|
||||
analyzer = TreeSitterAnalyzer(lang)
|
||||
tree = analyzer.parse(code)
|
||||
|
||||
if tree.root_node.has_error:
|
||||
return _basic_normalize(code)
|
||||
|
||||
normalizer = JavaScriptVariableNormalizer()
|
||||
source_bytes = code.encode("utf-8")
|
||||
|
||||
# First pass: collect preserved names
|
||||
normalizer.collect_preserved_names(tree.root_node, source_bytes)
|
||||
|
||||
# Second pass: normalize and build representation
|
||||
return normalizer.normalize_tree(tree.root_node, source_bytes)
|
||||
except Exception:
|
||||
return _basic_normalize(code)
|
||||
|
||||
def normalize_for_hash(self, code: str) -> str:
|
||||
"""Normalize JavaScript code optimized for hashing.
|
||||
|
||||
For JavaScript, this is the same as normalize().
|
||||
|
||||
Args:
|
||||
code: JavaScript source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation suitable for hashing
|
||||
|
||||
"""
|
||||
return self.normalize(code)
|
||||
|
||||
|
||||
class TypeScriptNormalizer(JavaScriptNormalizer):
|
||||
"""TypeScript code normalizer using tree-sitter.
|
||||
|
||||
Inherits from JavaScriptNormalizer and overrides language-specific settings.
|
||||
"""
|
||||
|
||||
@property
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
return "typescript"
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return (".ts", ".tsx", ".mts", ".cts")
|
||||
|
||||
def _get_tree_sitter_language(self) -> str:
|
||||
"""Get the tree-sitter language identifier."""
|
||||
return "typescript"
|
||||
|
|
@ -1,226 +0,0 @@
|
|||
"""Python code normalizer using AST transformation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
|
||||
|
||||
class VariableNormalizer(ast.NodeTransformer):
|
||||
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
|
||||
|
||||
Preserves function names, class names, parameters, built-ins, and imported names.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.var_counter = 0
|
||||
self.var_mapping: dict[str, str] = {}
|
||||
self.scope_stack: list[dict] = []
|
||||
self.builtins = set(dir(__builtins__))
|
||||
self.imports: set[str] = set()
|
||||
self.global_vars: set[str] = set()
|
||||
self.nonlocal_vars: set[str] = set()
|
||||
self.parameters: set[str] = set()
|
||||
|
||||
def enter_scope(self) -> None:
|
||||
"""Enter a new scope (function/class)."""
|
||||
self.scope_stack.append(
|
||||
{"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
|
||||
)
|
||||
|
||||
def exit_scope(self) -> None:
|
||||
"""Exit current scope and restore parent scope."""
|
||||
if self.scope_stack:
|
||||
scope = self.scope_stack.pop()
|
||||
self.var_mapping = scope["var_mapping"]
|
||||
self.var_counter = scope["var_counter"]
|
||||
self.parameters = scope["parameters"]
|
||||
|
||||
def get_normalized_name(self, name: str) -> str:
|
||||
"""Get or create normalized name for a variable."""
|
||||
if (
|
||||
name in self.builtins
|
||||
or name in self.imports
|
||||
or name in self.global_vars
|
||||
or name in self.nonlocal_vars
|
||||
or name in self.parameters
|
||||
):
|
||||
return name
|
||||
|
||||
if name not in self.var_mapping:
|
||||
self.var_mapping[name] = f"var_{self.var_counter}"
|
||||
self.var_counter += 1
|
||||
return self.var_mapping[name]
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> ast.Import:
|
||||
"""Track imported names."""
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
self.imports.add(name.split(".")[0])
|
||||
return node
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
|
||||
"""Track imported names from modules."""
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
self.imports.add(name)
|
||||
return node
|
||||
|
||||
def visit_Global(self, node: ast.Global) -> ast.Global:
|
||||
"""Track global variable declarations."""
|
||||
self.global_vars.update(node.names)
|
||||
return node
|
||||
|
||||
def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal:
|
||||
"""Track nonlocal variable declarations."""
|
||||
self.nonlocal_vars.update(node.names)
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
|
||||
"""Process function but keep function name and parameters unchanged."""
|
||||
self.enter_scope()
|
||||
|
||||
for arg in node.args.args:
|
||||
self.parameters.add(arg.arg)
|
||||
if node.args.vararg:
|
||||
self.parameters.add(node.args.vararg.arg)
|
||||
if node.args.kwarg:
|
||||
self.parameters.add(node.args.kwarg.arg)
|
||||
for arg in node.args.kwonlyargs:
|
||||
self.parameters.add(arg.arg)
|
||||
|
||||
node = self.generic_visit(node)
|
||||
self.exit_scope()
|
||||
return node
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
|
||||
"""Handle async functions same as regular functions."""
|
||||
return self.visit_FunctionDef(node) # type: ignore[return-value]
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
||||
"""Process class but keep class name unchanged."""
|
||||
self.enter_scope()
|
||||
node = self.generic_visit(node)
|
||||
self.exit_scope()
|
||||
return node
|
||||
|
||||
def visit_Name(self, node: ast.Name) -> ast.Name:
|
||||
"""Normalize variable names in Name nodes."""
|
||||
if isinstance(node.ctx, (ast.Store, ast.Del)):
|
||||
if (
|
||||
node.id not in self.builtins
|
||||
and node.id not in self.imports
|
||||
and node.id not in self.parameters
|
||||
and node.id not in self.global_vars
|
||||
and node.id not in self.nonlocal_vars
|
||||
):
|
||||
node.id = self.get_normalized_name(node.id)
|
||||
elif isinstance(node.ctx, ast.Load) and node.id in self.var_mapping:
|
||||
node.id = self.var_mapping[node.id]
|
||||
return node
|
||||
|
||||
def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.ExceptHandler:
|
||||
"""Normalize exception variable names."""
|
||||
if node.name:
|
||||
node.name = self.get_normalized_name(node.name)
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_comprehension(self, node: ast.comprehension) -> ast.comprehension:
|
||||
"""Normalize comprehension target variables."""
|
||||
old_mapping = dict(self.var_mapping)
|
||||
old_counter = self.var_counter
|
||||
|
||||
node = self.generic_visit(node)
|
||||
|
||||
self.var_mapping = old_mapping
|
||||
self.var_counter = old_counter
|
||||
return node
|
||||
|
||||
def visit_For(self, node: ast.For) -> ast.For:
|
||||
"""Handle for loop target variables."""
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_With(self, node: ast.With) -> ast.With:
|
||||
"""Handle with statement as variables."""
|
||||
return self.generic_visit(node)
|
||||
|
||||
|
||||
def _remove_docstrings_from_ast(node: ast.AST) -> None:
|
||||
"""Remove docstrings from AST nodes."""
|
||||
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
|
||||
stack = [node]
|
||||
while stack:
|
||||
current_node = stack.pop()
|
||||
if isinstance(current_node, node_types):
|
||||
body = current_node.body
|
||||
if (
|
||||
body
|
||||
and isinstance(body[0], ast.Expr)
|
||||
and isinstance(body[0].value, ast.Constant)
|
||||
and isinstance(body[0].value.value, str)
|
||||
):
|
||||
current_node.body = body[1:]
|
||||
stack.extend([child for child in body if isinstance(child, node_types)])
|
||||
|
||||
|
||||
class PythonNormalizer(CodeNormalizer):
|
||||
"""Python code normalizer using AST transformation.
|
||||
|
||||
Normalizes Python code by:
|
||||
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
|
||||
- Preserving function names, class names, parameters, and imports
|
||||
- Optionally removing docstrings
|
||||
"""
|
||||
|
||||
@property
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
return "python"
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return (".py", ".pyw", ".pyi")
|
||||
|
||||
def normalize(self, code: str, remove_docstrings: bool = True) -> str:
|
||||
"""Normalize Python code to a canonical form.
|
||||
|
||||
Args:
|
||||
code: Python source code to normalize
|
||||
remove_docstrings: Whether to remove docstrings
|
||||
|
||||
Returns:
|
||||
Normalized Python code as a string
|
||||
|
||||
"""
|
||||
tree = ast.parse(code)
|
||||
|
||||
if remove_docstrings:
|
||||
_remove_docstrings_from_ast(tree)
|
||||
|
||||
normalizer = VariableNormalizer()
|
||||
normalized_tree = normalizer.visit(tree)
|
||||
ast.fix_missing_locations(normalized_tree)
|
||||
|
||||
return ast.unparse(normalized_tree)
|
||||
|
||||
def normalize_for_hash(self, code: str) -> str:
|
||||
"""Normalize Python code optimized for hashing.
|
||||
|
||||
Returns AST dump which is faster than unparsing.
|
||||
|
||||
Args:
|
||||
code: Python source code to normalize
|
||||
|
||||
Returns:
|
||||
AST dump string suitable for hashing
|
||||
|
||||
"""
|
||||
tree = ast.parse(code)
|
||||
_remove_docstrings_from_ast(tree)
|
||||
|
||||
normalizer = VariableNormalizer()
|
||||
normalized_tree = normalizer.visit(tree)
|
||||
|
||||
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
|
||||
Loading…
Reference in a new issue