docs: update rules and architecture for new language structure

- Add JS optimizer, normalizer, and support files to architecture tree
- Update key entry points table for per-function optimization and test execution
- Add protocol dispatch preference to language-patterns rules
- Fix stale context path in optimization-patterns
- Add explicit utf-8 encoding rule scoped to new/changed code
This commit is contained in:
Kevin Turcios 2026-03-02 06:09:43 -05:00
parent 04a94f2b03
commit 341c622d40
6 changed files with 23 additions and 524 deletions

View file

@ -1,5 +1,7 @@
# Architecture
When adding, moving, or deleting source files, update this doc to match.
```
codeflash/
├── main.py # CLI entry point
@ -15,9 +17,20 @@ codeflash/
├── code_utils/ # Code parsing, git utilities
├── models/ # Pydantic models and types
├── languages/ # Multi-language support (Python, JavaScript/TypeScript)
│ └── python/
│ ├── function_optimizer.py # PythonFunctionOptimizer (Python-specific hooks)
│ └── optimizer.py # Python module preparation & AST resolution
│ ├── base.py # LanguageSupport protocol and shared data types
│ ├── registry.py # Language registration and lookup by extension/enum
│ ├── current.py # Current language singleton (set_current_language / current_language_support)
│ ├── code_replacer.py # Language-agnostic code replacement
│ ├── python/
│ │ ├── support.py # PythonSupport (LanguageSupport implementation)
│ │ ├── function_optimizer.py # PythonFunctionOptimizer subclass
│ │ ├── optimizer.py # Python module preparation & AST resolution
│ │ └── normalizer.py # Python code normalization for deduplication
│ └── javascript/
│ ├── support.py # JavaScriptSupport (LanguageSupport implementation)
│ ├── function_optimizer.py # JavaScriptFunctionOptimizer subclass
│ ├── optimizer.py # JS project root finding & module preparation
│ └── normalizer.py # JS/TS code normalization for deduplication
├── setup/ # Config schema, auto-detection, first-run experience
├── picklepatch/ # Serialization/deserialization utilities
├── tracing/ # Function call tracing
@ -35,10 +48,10 @@ codeflash/
|------|------------|
| CLI arguments & commands | `cli_cmds/cli.py` |
| Optimization orchestration | `optimization/optimizer.py` `run()` |
| Per-function optimization | `optimization/function_optimizer.py` (base), `languages/python/function_optimizer.py` (Python subclass) |
| Per-function optimization | `optimization/function_optimizer.py` (base), `languages/python/function_optimizer.py`, `languages/javascript/function_optimizer.py` |
| Function discovery | `discovery/functions_to_optimize.py` |
| Context extraction | `languages/<lang>/context/code_context_extractor.py` |
| Test execution | `verification/test_runner.py`, `verification/pytest_plugin.py` |
| Test execution | `languages/<lang>/support.py` (`run_behavioral_tests`, etc.), `verification/pytest_plugin.py` |
| Performance ranking | `benchmarking/function_ranker.py` |
| Domain types | `models/models.py`, `models/function_types.py` |
| Result handling | `either.py` (`Result`, `Success`, `Failure`, `is_successful`) |

View file

@ -7,4 +7,5 @@
- **Comments**: Minimal - only explain "why", not "what"
- **Docstrings**: Do not add unless explicitly requested
- **Naming**: NEVER use leading underscores (`_function_name`) - Python has no true private functions, use public names
- **Paths**: Always use absolute paths, handle encoding explicitly (UTF-8)
- **Paths**: Always use absolute paths
- **Encoding**: Always pass `encoding="utf-8"` to `open()`, `read_text()`, `write_text()`, etc. in new or changed code — Windows defaults to `cp1252` which breaks on non-ASCII content. Don't flag pre-existing code that lacks it unless you're already modifying that line.

View file

@ -9,4 +9,5 @@ paths:
- Use `get_language_support(identifier)` from `languages/registry.py` to get a `LanguageSupport` instance — never import language classes directly
- New language support classes must use the `@register_language` decorator to register with the extension and language registries
- `languages/__init__.py` uses `__getattr__` for lazy imports to avoid circular dependencies — follow this pattern when adding new exports
- `is_javascript()` returns `True` for both JavaScript and TypeScript
- Prefer `LanguageSupport` protocol dispatch over `is_python()`/`is_javascript()` guards — remaining guards are being migrated to protocol methods
- `is_javascript()` returns `True` for both JavaScript and TypeScript (still used in ~15 call sites pending migration)

View file

@ -3,7 +3,7 @@ paths:
- "codeflash/optimization/**/*.py"
- "codeflash/verification/**/*.py"
- "codeflash/benchmarking/**/*.py"
- "codeflash/context/**/*.py"
- "codeflash/languages/*/context/**/*.py"
---
# Optimization Pipeline Patterns

View file

@ -1,290 +0,0 @@
"""JavaScript/TypeScript code normalizer using tree-sitter."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from codeflash.code_utils.normalizers.base import CodeNormalizer
if TYPE_CHECKING:
from tree_sitter import Node
# TODO:{claude} move to language support directory to keep the directory structure clean
class JavaScriptVariableNormalizer:
    """Normalize JS/TS code for duplicate detection via a tree-sitter walk.

    Local variable names are mapped to canonical placeholders (var_0, var_1, ...)
    while function names, class names, parameters, imported names, and common
    JavaScript builtins are left untouched.
    """

    def __init__(self) -> None:
        self.var_counter = 0
        self.var_mapping: dict[str, str] = {}
        self.preserved_names: set[str] = set()
        # Well-known JavaScript globals that must never be renamed.
        self.builtins = {
            "console", "window", "document", "Math", "JSON", "Object", "Array",
            "String", "Number", "Boolean", "Date", "RegExp", "Error", "Promise",
            "Map", "Set", "WeakMap", "WeakSet", "Symbol", "Proxy", "Reflect",
            "undefined", "null", "NaN", "Infinity", "globalThis", "parseInt",
            "parseFloat", "isNaN", "isFinite", "eval", "setTimeout",
            "setInterval", "clearTimeout", "clearInterval", "fetch", "require",
            "module", "exports", "process", "__dirname", "__filename", "Buffer",
        }

    def span_text(self, node: Node, source_code: bytes) -> str:
        """Decode the UTF-8 source slice covered by *node*."""
        return source_code[node.start_byte : node.end_byte].decode("utf-8")

    def get_normalized_name(self, name: str) -> str:
        """Return the canonical placeholder for *name*, allocating one on first use.

        Builtins and preserved names pass through unchanged.
        """
        if name in self.builtins or name in self.preserved_names:
            return name
        try:
            return self.var_mapping[name]
        except KeyError:
            placeholder = f"var_{self.var_counter}"
            self.var_counter += 1
            self.var_mapping[name] = placeholder
            return placeholder

    def collect_preserved_names(self, node: Node, source_code: bytes) -> None:
        """First pass: record names that must survive normalization.

        Covers function/method names and their parameters, class names, and
        every name bound by an import statement.
        """
        kind = node.type
        if kind in ("function_declaration", "function_expression", "method_definition", "arrow_function"):
            name_node = node.child_by_field_name("name")
            if name_node:
                self.preserved_names.add(self.span_text(name_node, source_code))
            # Arrow functions expose a single bare parameter under "parameter".
            params = node.child_by_field_name("parameters") or node.child_by_field_name("parameter")
            if params:
                self._collect_parameter_names(params, source_code)
        elif kind == "class_declaration":
            name_node = node.child_by_field_name("name")
            if name_node:
                self.preserved_names.add(self.span_text(name_node, source_code))
        elif kind in ("import_statement", "import_declaration"):
            for sub in node.children:
                if sub.type == "import_clause":
                    self._collect_import_names(sub, source_code)
                elif sub.type == "identifier":
                    self.preserved_names.add(self.span_text(sub, source_code))
        # Descend into every child regardless of which branch matched above.
        for sub in node.children:
            self.collect_preserved_names(sub, source_code)

    def _collect_parameter_names(self, node: Node, source_code: bytes) -> None:
        """Record parameter names declared under a parameters node."""
        for sub in node.children:
            if sub.type == "identifier":
                self.preserved_names.add(self.span_text(sub, source_code))
            elif sub.type in ("required_parameter", "optional_parameter", "rest_parameter"):
                pattern = sub.child_by_field_name("pattern")
                if pattern and pattern.type == "identifier":
                    self.preserved_names.add(self.span_text(pattern, source_code))
                # Descend for nested patterns inside typed/optional parameters.
                self._collect_parameter_names(sub, source_code)

    def _collect_import_names(self, node: Node, source_code: bytes) -> None:
        """Record locally bound names from an import clause (aliases win over originals)."""
        for sub in node.children:
            if sub.type == "identifier":
                self.preserved_names.add(self.span_text(sub, source_code))
            elif sub.type == "import_specifier":
                alias_node = sub.child_by_field_name("alias")
                name_node = sub.child_by_field_name("name")
                if alias_node:
                    self.preserved_names.add(self.span_text(alias_node, source_code))
                elif name_node:
                    self.preserved_names.add(self.span_text(name_node, source_code))
            self._collect_import_names(sub, source_code)

    def normalize_tree(self, node: Node, source_code: bytes) -> str:
        """Second pass: flatten the tree into a canonical space-joined token string."""
        tokens: list[str] = []
        self._normalize_node(node, source_code, tokens)
        return " ".join(tokens)

    def _normalize_node(self, node: Node, source_code: bytes, parts: list[str]) -> None:
        """Append the canonical tokens for *node* (and its subtree) to *parts*."""
        kind = node.type
        # Comments carry no semantic weight for duplicate detection.
        if kind in ("comment", "line_comment", "block_comment"):
            return
        if kind == "identifier":
            parts.append(self.get_normalized_name(self.span_text(node, source_code)))
        elif kind == "type_identifier":
            # TypeScript type names are kept verbatim.
            parts.append(self.span_text(node, source_code))
        elif kind in ("string", "template_string", "string_fragment"):
            parts.append('"STR"')
        elif kind == "number":
            parts.append("NUM")
        elif not node.children:
            # Other leaves (keywords, operators, punctuation) keep their text.
            parts.append(self.span_text(node, source_code))
        else:
            # Interior nodes contribute their type for structural comparison.
            parts.append(f"({kind}")
            for sub in node.children:
                self._normalize_node(sub, source_code, parts)
            parts.append(")")
def _basic_normalize(code: str) -> str:
"""Basic normalization: remove comments and normalize whitespace."""
# Remove single-line comments
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
# Remove multi-line comments
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
# Normalize whitespace
return " ".join(code.split())
class JavaScriptNormalizer(CodeNormalizer):
    """Tree-sitter-backed JavaScript code normalizer.

    Produces a canonical representation by renaming local variables to
    var_0/var_1/..., preserving function, class, parameter, and import names,
    dropping comments, and collapsing string/number literals to placeholders.
    Falls back to textual normalization whenever parsing is unavailable or fails.
    """

    @property
    def language(self) -> str:
        """Language identifier handled by this normalizer."""
        return "javascript"

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """File extensions this normalizer accepts."""
        return (".js", ".jsx", ".mjs", ".cjs")

    def _get_tree_sitter_language(self) -> str:
        """Tree-sitter language identifier (overridden by the TypeScript subclass)."""
        return "javascript"

    def normalize(self, code: str) -> str:
        """Return the canonical representation of *code*.

        Args:
            code: JavaScript source code to normalize

        Returns:
            Normalized representation of the code
        """
        try:
            # Imported lazily so environments without tree-sitter still work
            # via the textual fallback below.
            from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage

            wanted = self._get_tree_sitter_language()
            ts_language = TreeSitterLanguage.TYPESCRIPT if wanted == "typescript" else TreeSitterLanguage.JAVASCRIPT
            tree = TreeSitterAnalyzer(ts_language).parse(code)
            if tree.root_node.has_error:
                return _basic_normalize(code)
            walker = JavaScriptVariableNormalizer()
            raw = code.encode("utf-8")
            # Pass 1: gather names that must not be renamed.
            walker.collect_preserved_names(tree.root_node, raw)
            # Pass 2: emit the canonical token string.
            return walker.normalize_tree(tree.root_node, raw)
        except Exception:
            # Any parse/import failure degrades to the textual fallback.
            return _basic_normalize(code)

    def normalize_for_hash(self, code: str) -> str:
        """Return a representation suitable for hashing; identical to normalize() for JS."""
        return self.normalize(code)
class TypeScriptNormalizer(JavaScriptNormalizer):
    """TypeScript variant of JavaScriptNormalizer.

    Reuses the whole JavaScript pipeline; only the language identifier and the
    recognized file extensions differ.
    """

    @property
    def language(self) -> str:
        """Language identifier handled by this normalizer."""
        return "typescript"

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """File extensions this normalizer accepts."""
        return (".ts", ".tsx", ".mts", ".cts")

    def _get_tree_sitter_language(self) -> str:
        """Tree-sitter language identifier used by the base class's normalize()."""
        return "typescript"

View file

@ -1,226 +0,0 @@
"""Python code normalizer using AST transformation."""
from __future__ import annotations
import ast
from codeflash.code_utils.normalizers.base import CodeNormalizer
class VariableNormalizer(ast.NodeTransformer):
    """Normalizes only local variable names in an AST to canonical forms (var_0, var_1, ...).

    Function names, class names, parameters, builtins, imported names, and
    names declared ``global``/``nonlocal`` are preserved so the normalized
    output keeps the original public surface.
    """

    def __init__(self) -> None:
        # Local import on purpose: in an *imported* module, the implicit
        # `__builtins__` name is a plain dict, so `dir(__builtins__)` yields
        # dict method names instead of builtin names. The `builtins` module
        # is reliable in every context.
        import builtins

        self.var_counter = 0  # next suffix handed out for var_N
        self.var_mapping: dict[str, str] = {}  # original name -> canonical name
        self.scope_stack: list[dict] = []  # saved mapping/counter/params per enclosing scope
        self.builtins = set(dir(builtins))
        self.imports: set[str] = set()
        self.global_vars: set[str] = set()
        self.nonlocal_vars: set[str] = set()
        self.parameters: set[str] = set()

    def enter_scope(self) -> None:
        """Snapshot the current mapping/counter/parameters before entering a function or class."""
        self.scope_stack.append(
            {"var_mapping": dict(self.var_mapping), "var_counter": self.var_counter, "parameters": set(self.parameters)}
        )

    def exit_scope(self) -> None:
        """Restore the snapshot taken by the matching enter_scope()."""
        if self.scope_stack:
            scope = self.scope_stack.pop()
            self.var_mapping = scope["var_mapping"]
            self.var_counter = scope["var_counter"]
            self.parameters = scope["parameters"]

    def get_normalized_name(self, name: str) -> str:
        """Return the canonical name for *name*, allocating var_N on first use.

        Preserved categories (builtins, imports, globals, nonlocals, parameters)
        pass through unchanged.
        """
        if (
            name in self.builtins
            or name in self.imports
            or name in self.global_vars
            or name in self.nonlocal_vars
            or name in self.parameters
        ):
            return name
        if name not in self.var_mapping:
            self.var_mapping[name] = f"var_{self.var_counter}"
            self.var_counter += 1
        return self.var_mapping[name]

    def visit_Import(self, node: ast.Import) -> ast.Import:
        """Track imported names (only the first dotted component is bound)."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            self.imports.add(name.split(".")[0])
        return node

    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
        """Track names imported from modules."""
        for alias in node.names:
            name = alias.asname if alias.asname else alias.name
            self.imports.add(name)
        return node

    def visit_Global(self, node: ast.Global) -> ast.Global:
        """Track global declarations so those names are never renamed."""
        self.global_vars.update(node.names)
        return node

    def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal:
        """Track nonlocal declarations so those names are never renamed."""
        self.nonlocal_vars.update(node.names)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        """Visit a function in a fresh scope; its name and parameters stay unchanged."""
        self.enter_scope()
        # Positional-only parameters (PEP 570) must be preserved too —
        # otherwise a store to one inside the body is renamed while the
        # signature keeps the original name, yielding broken normalized code.
        for arg in node.args.posonlyargs:
            self.parameters.add(arg.arg)
        for arg in node.args.args:
            self.parameters.add(arg.arg)
        if node.args.vararg:
            self.parameters.add(node.args.vararg.arg)
        if node.args.kwarg:
            self.parameters.add(node.args.kwarg.arg)
        for arg in node.args.kwonlyargs:
            self.parameters.add(arg.arg)
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
        """Handle async functions identically to regular functions."""
        return self.visit_FunctionDef(node)  # type: ignore[return-value]

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Visit a class body in a fresh scope; the class name stays unchanged."""
        self.enter_scope()
        node = self.generic_visit(node)
        self.exit_scope()
        return node

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """Rename stores/deletes of local names; rewrite loads via the existing mapping."""
        if isinstance(node.ctx, (ast.Store, ast.Del)):
            if (
                node.id not in self.builtins
                and node.id not in self.imports
                and node.id not in self.parameters
                and node.id not in self.global_vars
                and node.id not in self.nonlocal_vars
            ):
                node.id = self.get_normalized_name(node.id)
        elif isinstance(node.ctx, ast.Load) and node.id in self.var_mapping:
            node.id = self.var_mapping[node.id]
        return node

    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.ExceptHandler:
        """Rename `except ... as name` variables."""
        if node.name:
            node.name = self.get_normalized_name(node.name)
        return self.generic_visit(node)

    def visit_comprehension(self, node: ast.comprehension) -> ast.comprehension:
        """Rename comprehension targets without leaking the mapping to the outer scope.

        NOTE(review): the comprehension's element expression is visited outside
        this node, after the mapping is restored, so loads of the target there
        keep the original name — pre-existing behavior, left unchanged.
        """
        old_mapping = dict(self.var_mapping)
        old_counter = self.var_counter
        node = self.generic_visit(node)
        self.var_mapping = old_mapping
        self.var_counter = old_counter
        return node

    def visit_For(self, node: ast.For) -> ast.For:
        """For-loop targets are handled by the generic Name visitor."""
        return self.generic_visit(node)

    def visit_With(self, node: ast.With) -> ast.With:
        """With-statement `as` variables are handled by the generic Name visitor."""
        return self.generic_visit(node)
def _remove_docstrings_from_ast(node: ast.AST) -> None:
"""Remove docstrings from AST nodes."""
node_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)
stack = [node]
while stack:
current_node = stack.pop()
if isinstance(current_node, node_types):
body = current_node.body
if (
body
and isinstance(body[0], ast.Expr)
and isinstance(body[0].value, ast.Constant)
and isinstance(body[0].value.value, str)
):
current_node.body = body[1:]
stack.extend([child for child in body if isinstance(child, node_types)])
class PythonNormalizer(CodeNormalizer):
    """AST-based Python code normalizer for duplicate detection.

    Canonicalizes code by renaming local variables to var_0/var_1/...,
    preserving function names, class names, parameters, and imports, and
    optionally dropping docstrings.
    """

    @property
    def language(self) -> str:
        """Language identifier handled by this normalizer."""
        return "python"

    @property
    def supported_extensions(self) -> tuple[str, ...]:
        """File extensions this normalizer accepts."""
        return (".py", ".pyw", ".pyi")

    def normalize(self, code: str, remove_docstrings: bool = True) -> str:
        """Return *code* rewritten into its canonical source form.

        Args:
            code: Python source code to normalize
            remove_docstrings: Whether to remove docstrings

        Returns:
            Normalized Python code as a string
        """
        tree = ast.parse(code)
        if remove_docstrings:
            _remove_docstrings_from_ast(tree)
        renamed = VariableNormalizer().visit(tree)
        # Nodes may have been rewritten; repair location info before unparsing.
        ast.fix_missing_locations(renamed)
        return ast.unparse(renamed)

    def normalize_for_hash(self, code: str) -> str:
        """Return a hash-friendly canonical form of *code*.

        Uses ast.dump (always docstring-free) instead of unparsing, which is
        faster when the output is only fed to a hash.
        """
        stripped = ast.parse(code)
        _remove_docstrings_from_ast(stripped)
        renamed = VariableNormalizer().visit(stripped)
        return ast.dump(renamed, annotate_fields=False, include_attributes=False)