mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
format_generated_code was called without the language parameter in process_review(), defaulting to "python". This created temp files with .py extension when formatting JS/TS code, causing prettier to fail. Trace IDs: 11e9745d, 1578f081, 7e8abab2 (73 affected logs total) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1470 lines
45 KiB
Python
1470 lines
45 KiB
Python
import argparse
|
|
import shutil
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from codeflash.code_utils.config_parser import parse_config_file
|
|
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
|
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
|
from codeflash.languages.function_optimizer import FunctionOptimizer
|
|
from codeflash.verification.verification_utils import TestConfig
|
|
|
|
|
|
@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory as a ``Path``; it is deleted when the test finishes."""
    with tempfile.TemporaryDirectory() as scratch:
        scratch_path = Path(scratch)
        yield scratch_path
|
|
|
|
|
|
def test_remove_duplicate_imports():
    """A repeated ``import os`` statement collapses to a single import."""
    deduplicated = sort_imports("import os\nimport os\n")
    assert deduplicated == "import os\n"
|
|
|
|
|
|
def test_remove_multiple_duplicate_imports():
    """Duplicates separated by another import are removed and the remainder is sorted."""
    with_repeats = "import sys\nimport os\nimport sys\n"

    cleaned = sort_imports(with_repeats)
    assert cleaned == "import os\nimport sys\n"
|
|
|
|
|
|
def test_sorting_imports():
    """Plain ``import`` statements come back in alphabetical order."""
    unsorted_imports = "import sys\nimport unittest\nimport os\n"

    ordered = sort_imports(unsorted_imports)
    assert ordered == "import os\nimport sys\nimport unittest\n"
|
|
|
|
|
|
def test_sort_imports_without_formatting(temp_dir):
    """With the formatter set to "disabled", the code returned by format_code can still be import-sorted."""
    module_path = temp_dir / "test_file.py"
    module_path.write_text("import sys\nimport unittest\nimport os\n")

    unformatted = format_code(formatter_cmds=["disabled"], path=module_path)
    assert unformatted is not None

    resorted = sort_imports(unformatted)
    assert resorted == "import os\nimport sys\nimport unittest\n"
|
|
|
|
|
|
def test_dedup_and_sort_imports_deduplicates():
    """Code whose imports are already unique and ordered is returned unchanged by sort_imports."""
    # Input and expected output are the same text, so one literal serves both roles.
    already_clean = """
import os
import sys


def foo():
    return os.path.join(sys.path[0], 'bar')
"""

    assert sort_imports(already_clean) == already_clean
|
|
|
|
|
|
def test_dedup_and_sort_imports_sorts_and_deduplicates():
    """An unordered import block with a repeated ``import os`` is sorted and de-duplicated."""
    messy_imports = """
import os
import sys
import json
import os


def foo():
    return os.path.join(sys.path[0], 'bar')
"""

    tidy_imports = """
import json
import os
import sys


def foo():
    return os.path.join(sys.path[0], 'bar')
"""

    result = sort_imports(messy_imports)

    assert result == tidy_imports
|
|
|
|
|
|
def test_formatter_cmds_non_existent(temp_dir):
    """formatter_cmds defaults to an empty list when the TOML omits it.

    After that check, the test also formats a small file with an explicit
    ``black $file`` command (skipped when black is not importable).
    """
    toml_path = temp_dir / "config.toml"
    toml_path.write_text(
        """
[tool.codeflash]
module-root = "src"
tests-root = "tests"
test-framework = "pytest"
ignore-paths = []
"""
    )

    parsed, _ = parse_config_file(toml_path)
    # Default is now empty list - formatters are detected by project detector
    assert parsed["formatter_cmds"] == []

    try:
        import black
    except ImportError:
        pytest.skip("black is not installed")

    unformatted = """
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    formatted = """import os
import sys


def foo():
    return os.path.join(sys.path[0], \"bar\")
"""
    source_path = temp_dir / "test_file.py"
    source_path.write_text(unformatted)

    result = format_code(formatter_cmds=["black $file"], path=source_path)
    assert result == formatted
|
|
|
|
|
|
def test_formatter_black(temp_dir):
    """format_code with a ``black $file`` command returns the black-formatted source."""
    try:
        import black
    except ImportError:
        pytest.skip("black is not installed")

    unformatted = """
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    formatted = """import os
import sys


def foo():
    return os.path.join(sys.path[0], \"bar\")
"""
    source_path = temp_dir / "test_file.py"
    source_path.write_text(unformatted)

    result = format_code(formatter_cmds=["black $file"], path=source_path)
    assert result == formatted
|
|
|
|
|
|
def test_formatter_ruff(temp_dir):
    """format_code runs the ruff check/fix command followed by ruff format and returns the result."""
    try:
        import ruff  # type: ignore
    except ImportError:
        pytest.skip("ruff is not installed")

    unformatted = """
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    formatted = """import os
import sys


def foo():
    return os.path.join(sys.path[0], \"bar\")
"""
    source_path = temp_dir / "test_file.py"
    source_path.write_text(unformatted)

    ruff_cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
    result = format_code(formatter_cmds=ruff_cmds, path=source_path)
    assert result == formatted
|
|
|
|
|
|
def test_formatter_error(tmp_path: Path):
    """A failing formatter command must not raise and must leave the code unchanged.

    ``format_code`` is called with the command ``exit 1`` and
    ``exit_on_failure=False``; the returned code has to equal what was written.
    """
    original_code = """
import os
import sys
def foo():
    return os.path.join(sys.path[0], 'bar')"""
    temp_file = tmp_path / "test_formatter_error.py"
    temp_file.write_text(original_code, encoding="utf-8")

    # Only the call itself is guarded, so a content mismatch below is reported
    # as a plain assertion failure instead of being relabelled as an exception.
    try:
        new_code = format_code(formatter_cmds=["exit 1"], path=temp_file, exit_on_failure=False)
    except Exception as e:
        # pytest.fail (unlike ``assert False``) is not stripped under ``python -O``.
        pytest.fail(f"Shouldn't throw an exception even if the formatter is not found: {e}")

    assert new_code == original_code
|
|
|
|
|
|
def _run_formatting_test(source_code: str, should_content_change: bool, expected=None, optimized_function: str = ""):
    """Run ``FunctionOptimizer.reformat_code_and_helpers`` on a temp copy of ``source_code`` and check the file.

    Args:
        source_code: Content written to ``source.py`` and copied to ``target.py``.
        should_content_change: When True, ``target.py`` must differ from
            ``source_code`` afterwards; when False it must be identical.
        expected: Optional exact content ``target.py`` must have afterwards.
        optimized_function: Code used for the single ``CodeString`` in the
            optimized context handed to the optimizer.

    The calling test is skipped when the ``ruff`` package cannot be imported.
    """
    try:
        import ruff  # type: ignore
    except ImportError:
        pytest.skip("ruff is not installed")

    with tempfile.TemporaryDirectory() as test_dir_str:
        test_dir = Path(test_dir_str)
        source_file = test_dir / "source.py"

        # Written as UTF-8 so it matches the UTF-8 read of target.py below.
        source_file.write_text(source_code, encoding="utf8")
        original = source_code
        target_path = test_dir / "target.py"

        shutil.copy2(source_file, target_path)

        function_to_optimize = FunctionToOptimize(function_name="process_data", parents=[], file_path=target_path)

        test_cfg = TestConfig(
            tests_root=test_dir, project_root_path=test_dir, test_framework="pytest", tests_project_rootdir=test_dir
        )

        args = argparse.Namespace(
            disable_imports_sorting=False, formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"]
        )

        optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_cfg, args=args)

        optimizer.reformat_code_and_helpers(
            helper_functions=[],
            path=target_path,
            original_code=optimizer.function_to_optimize_source_code,
            optimized_context=CodeStringsMarkdown(
                code_strings=[CodeString(code=optimized_function, file_path=target_path.relative_to(test_dir))]
            ),
        )

        # Read while the temporary directory still exists.
        content = target_path.read_text(encoding="utf8")

        if expected is not None:
            assert content == expected, (
                f"Expected content to be \n===========\n{expected}\n===========\nbut got\n===========\n{content}\n===========\n"
            )

        if should_content_change:
            assert content != original, "Expected content to change for source.py"
        else:
            assert content == original, "Expected content to remain unchanged for source.py"
|
|
|
|
|
|
def test_formatting_file_with_many_diffs():
|
|
"""Test that files with many formatting errors are skipped (content unchanged)."""
|
|
source_code = """import os,sys,json,datetime,re
|
|
from collections import defaultdict,OrderedDict
|
|
import numpy as np,pandas as pd
|
|
|
|
class DataProcessor:
|
|
def __init__(self,config_path,data_path,output_path):
|
|
self.config_path=config_path
|
|
self.data_path=data_path
|
|
self.output_path=output_path
|
|
self.config={}
|
|
self.data=[]
|
|
self.results={}
|
|
|
|
def load_config(self):
|
|
with open(self.config_path,'r') as f:
|
|
self.config=json.load(f)
|
|
if 'required_fields' not in self.config:self.config['required_fields']=[]
|
|
if 'optional_fields' not in self.config:self.config['optional_fields']=[]
|
|
return self.config
|
|
|
|
def validate_data(self,data):
|
|
errors=[]
|
|
for idx,record in enumerate(data):
|
|
if not isinstance(record,dict):
|
|
errors.append(f"Record {idx} is not a dictionary")
|
|
continue
|
|
for field in self.config.get('required_fields',[]):
|
|
if field not in record:
|
|
errors.append(f"Record {idx} missing required field: {field}")
|
|
elif record[field] is None or record[field]=='':
|
|
errors.append(f"Record {idx} has empty required field: {field}")
|
|
return errors
|
|
|
|
def process_data(self,data,filter_func=None,transform_func=None,sort_key=None):
|
|
if filter_func:data=[item for item in data if filter_func(item)]
|
|
if transform_func:data=[transform_func(item) for item in data]
|
|
if sort_key:data=sorted(data,key=sort_key)
|
|
aggregated_data=defaultdict(list)
|
|
for item in data:
|
|
category=item.get('category','unknown')
|
|
aggregated_data[category].append(item)
|
|
final_results={}
|
|
for category,items in aggregated_data.items():
|
|
total_value=sum(item.get('value',0) for item in items)
|
|
avg_value=total_value/len(items) if items else 0
|
|
final_results[category]={'count':len(items),'total':total_value,'average':avg_value,'items':items}
|
|
return final_results
|
|
|
|
def save_results(self,results):
|
|
with open(self.output_path,'w') as f:
|
|
json.dump(results,f,indent=2,default=str)
|
|
print(f"Results saved to {self.output_path}")
|
|
|
|
def run_pipeline(self):
|
|
try:
|
|
config=self.load_config()
|
|
with open(self.data_path,'r') as f:
|
|
raw_data=json.load(f)
|
|
validation_errors=self.validate_data(raw_data)
|
|
if validation_errors:
|
|
print("Validation errors found:")
|
|
for error in validation_errors:print(f" - {error}")
|
|
return False
|
|
processed_results=self.process_data(raw_data,filter_func=lambda x:x.get('active',True),transform_func=lambda x:{**x,'processed_at':datetime.datetime.now().isoformat()},sort_key=lambda x:x.get('name',''))
|
|
self.save_results(processed_results)
|
|
return True
|
|
except Exception as e:
|
|
print(f"Pipeline failed: {str(e)}")
|
|
return False
|
|
|
|
def main():
|
|
processor=DataProcessor('/path/to/config.json','/path/to/data.json','/path/to/output.json')
|
|
success=processor.run_pipeline()
|
|
if success:print("Pipeline completed successfully")
|
|
else:print("Pipeline failed")
|
|
|
|
if __name__=='__main__':main()
|
|
"""
|
|
# should_content_change=False: _run_formatting_test asserts the file on disk still equals source_code afterwards.
_run_formatting_test(source_code, False)
|
|
|
|
|
|
def test_formatting_file_with_few_diffs():
|
|
"""Test that files with few formatting errors are formatted (content changed)."""
|
|
source_code = '''import json
|
|
from datetime import datetime
|
|
|
|
def process_data(data, config=None):
|
|
"""Process data with optional configuration."""
|
|
if not data:
|
|
return {"success": False, "error": "No data provided"}
|
|
|
|
if config is None:
|
|
config = {"filter_active": True}
|
|
|
|
# Minor formatting issues that should be fixed
|
|
result=[]
|
|
for item in data:
|
|
if config.get("filter_active") and not item.get("active",True):
|
|
continue
|
|
processed_item={
|
|
"id": item.get("id"),
|
|
"name": item.get("name",""),
|
|
"value": item.get("value",0),
|
|
"processed_at": datetime.now().isoformat()
|
|
}
|
|
result.append(processed_item)
|
|
|
|
return {"success": True, "data": result, "count": len(result)}
|
|
'''
|
|
# should_content_change=True: _run_formatting_test asserts the file on disk differs from source_code afterwards.
_run_formatting_test(source_code, True)
|
|
|
|
|
|
def test_formatting_file_with_no_diffs():
|
|
"""Test that files with no formatting errors are unchanged."""
|
|
# this test assumes you use ruff defaults for formatting
|
|
source_code = '''from datetime import datetime
|
|
|
|
|
|
def process_data(data, config=None):
|
|
"""Process data with optional configuration."""
|
|
if not data:
|
|
return {"success": False, "error": "No data provided"}
|
|
|
|
if config is None:
|
|
config = {"filter_active": True}
|
|
|
|
result = []
|
|
for item in data:
|
|
if config.get("filter_active") and not item.get("active", True):
|
|
continue
|
|
|
|
processed_item = {
|
|
"id": item.get("id"),
|
|
"name": item.get("name", ""),
|
|
"value": item.get("value", 0),
|
|
"processed_at": datetime.now().isoformat(),
|
|
}
|
|
result.append(processed_item)
|
|
|
|
return {"success": True, "data": result, "count": len(result)}
|
|
'''
|
|
# should_content_change=False: _run_formatting_test asserts the file on disk still equals source_code afterwards.
_run_formatting_test(source_code, False)
|
|
|
|
|
|
def test_formatting_extremely_messy_file():
|
|
"""Test that extremely messy files with 100+ potential changes are skipped."""
|
|
source_code = """import os,sys,json,datetime,re,collections,itertools,functools,operator
|
|
from pathlib import Path
|
|
from typing import Dict,List,Optional,Union,Any,Tuple
|
|
import numpy as np,pandas as pd,matplotlib.pyplot as plt
|
|
from dataclasses import dataclass,field
|
|
|
|
@dataclass
|
|
class Config:
|
|
input_path:str
|
|
output_path:str
|
|
batch_size:int=100
|
|
max_retries:int=3
|
|
timeout:float=30.0
|
|
debug:bool=False
|
|
filters:List[str]=field(default_factory=list)
|
|
transformations:Dict[str,Any]=field(default_factory=dict)
|
|
|
|
class DataProcessorAdvanced:
|
|
def __init__(self,config:Config):
|
|
self.config=config
|
|
self.data=[]
|
|
self.results={}
|
|
self.errors=[]
|
|
self.stats={'processed':0,'failed':0,'skipped':0}
|
|
|
|
def load_data(self,file_path:str)->List[Dict]:
|
|
try:
|
|
with open(file_path,'r',encoding='utf-8') as f:
|
|
if file_path.endswith('.json'):data=json.load(f)
|
|
elif file_path.endswith('.csv'):
|
|
import csv
|
|
reader=csv.DictReader(f)
|
|
data=[row for row in reader]
|
|
else:raise ValueError(f"Unsupported file format: {file_path}")
|
|
return data
|
|
except Exception as e:self.errors.append(f"Failed to load {file_path}: {str(e)}");return[]
|
|
|
|
def validate_record(self,record:Dict,schema:Dict)->Tuple[bool,List[str]]:
|
|
errors=[]
|
|
for field,rules in schema.items():
|
|
if rules.get('required',False) and field not in record:
|
|
errors.append(f"Missing required field: {field}")
|
|
elif field in record:
|
|
value=record[field]
|
|
if 'type' in rules and not isinstance(value,rules['type']):
|
|
errors.append(f"Field {field} has wrong type")
|
|
if 'min_length' in rules and isinstance(value,str) and len(value)<rules['min_length']:
|
|
errors.append(f"Field {field} too short")
|
|
if 'max_length' in rules and isinstance(value,str) and len(value)>rules['max_length']:
|
|
errors.append(f"Field {field} too long")
|
|
if 'min_value' in rules and isinstance(value,(int,float)) and value<rules['min_value']:
|
|
errors.append(f"Field {field} below minimum")
|
|
if 'max_value' in rules and isinstance(value,(int,float)) and value>rules['max_value']:
|
|
errors.append(f"Field {field} above maximum")
|
|
return len(errors)==0,errors
|
|
|
|
def apply_filters(self,data:List[Dict])->List[Dict]:
|
|
filtered_data=data
|
|
for filter_name in self.config.filters:
|
|
if filter_name=='active_only':filtered_data=[r for r in filtered_data if r.get('active',True)]
|
|
elif filter_name=='has_value':filtered_data=[r for r in filtered_data if r.get('value') is not None]
|
|
elif filter_name=='recent_only':
|
|
cutoff=datetime.datetime.now()-datetime.timedelta(days=30)
|
|
filtered_data=[r for r in filtered_data if datetime.datetime.fromisoformat(r.get('created_at','1970-01-01'))>cutoff]
|
|
return filtered_data
|
|
|
|
def apply_transformations(self,data:List[Dict])->List[Dict]:
|
|
for transform_name,params in self.config.transformations.items():
|
|
if transform_name=='add_timestamp':
|
|
for record in data:record['processed_at']=datetime.datetime.now().isoformat()
|
|
elif transform_name=='normalize_names':
|
|
for record in data:
|
|
if 'name' in record:record['name']=record['name'].strip().title()
|
|
elif transform_name=='calculate_derived':
|
|
for record in data:
|
|
if 'value' in record and 'multiplier' in params:
|
|
record['derived_value']=record['value']*params['multiplier']
|
|
return data
|
|
|
|
def process_batch(self,batch:List[Dict])->Dict[str,Any]:
|
|
try:
|
|
processed_batch=[]
|
|
for record in batch:
|
|
try:
|
|
processed_record=dict(record)
|
|
processed_record['batch_id']=len(self.results)
|
|
processed_record['processed_at']=datetime.datetime.now().isoformat()
|
|
processed_batch.append(processed_record)
|
|
self.stats['processed']+=1
|
|
except Exception as e:
|
|
self.errors.append(f"Failed to process record: {str(e)}")
|
|
self.stats['failed']+=1
|
|
return {'success':True,'data':processed_batch,'count':len(processed_batch)}
|
|
except Exception as e:
|
|
self.errors.append(f"Batch processing failed: {str(e)}")
|
|
return {'success':False,'error':str(e)}
|
|
|
|
def run_processing_pipeline(self)->bool:
|
|
try:
|
|
raw_data=self.load_data(self.config.input_path)
|
|
if not raw_data:return False
|
|
filtered_data=self.apply_filters(raw_data)
|
|
transformed_data=self.apply_transformations(filtered_data)
|
|
batches=[transformed_data[i:i+self.config.batch_size] for i in range(0,len(transformed_data),self.config.batch_size)]
|
|
all_results=[]
|
|
for i,batch in enumerate(batches):
|
|
if self.config.debug:print(f"Processing batch {i+1}/{len(batches)}")
|
|
result=self.process_batch(batch)
|
|
if result['success']:all_results.extend(result['data'])
|
|
else:self.stats['failed']+=len(batch)
|
|
with open(self.config.output_path,'w',encoding='utf-8') as f:
|
|
json.dump({'results':all_results,'stats':self.stats,'errors':self.errors},f,indent=2,default=str)
|
|
return True
|
|
except Exception as e:
|
|
self.errors.append(f"Pipeline failed: {str(e)}")
|
|
return False
|
|
|
|
def create_sample_config()->Config:
|
|
return Config(input_path='input.json',output_path='output.json',batch_size=50,max_retries=3,timeout=60.0,debug=True,filters=['active_only','has_value'],transformations={'add_timestamp':{},'normalize_names':{},'calculate_derived':{'multiplier':1.5}})
|
|
|
|
def main():
|
|
config=create_sample_config()
|
|
processor=DataProcessorAdvanced(config)
|
|
success=processor.run_processing_pipeline()
|
|
print(f"Processing {'completed' if success else 'failed'}")
|
|
print(f"Stats: {processor.stats}")
|
|
if processor.errors:
|
|
print("Errors encountered:")
|
|
for error in processor.errors:print(f" - {error}")
|
|
|
|
if __name__=='__main__':main()
|
|
"""
|
|
# should_content_change=False: _run_formatting_test asserts the file on disk still equals source_code afterwards.
# NOTE(review): the helper targets function_name="process_data", which this fixture does not define - confirm that is intended.
_run_formatting_test(source_code, False)
|
|
|
|
|
|
def test_formatting_edge_case_exactly_100_diffs():
|
|
"""Test behavior when exactly at the threshold of 100 changes."""
|
|
# Create a file with exactly 100 minor formatting issues
|
|
snippet = (
|
|
"""import json\n"""
|
|
"""
|
|
def func_{i}():
|
|
x=1;y=2;z=3
|
|
return x+y+z
|
|
"""
|
|
)
|
|
source_code = "".join([snippet.format(i=i) for i in range(100)])
|
|
# The two adjacent literals above form one template, so the joined source repeats the `import json` line for each of the 100 functions.
# should_content_change=False: _run_formatting_test asserts the file on disk still equals source_code afterwards.
_run_formatting_test(source_code, False)
|
|
|
|
|
|
def test_formatting_with_syntax_errors():
|
|
"""Test that files with syntax errors are handled gracefully."""
|
|
source_code = """import json
|
|
|
|
def process_data(data):
|
|
if not data:
|
|
return {"error": "No data"
|
|
# Missing closing brace above
|
|
|
|
result = []
|
|
for item in data
|
|
# Missing colon above
|
|
result.append(item)
|
|
|
|
return result
|
|
"""
|
|
# should_content_change=False: _run_formatting_test asserts the file on disk still equals source_code afterwards.
_run_formatting_test(source_code, False)
|
|
|
|
|
|
def test_formatting_mixed_quotes_and_spacing():
|
|
"""Test files with mixed quote styles and inconsistent spacing."""
|
|
source_code = '''import json
|
|
from datetime import datetime
|
|
|
|
def process_mixed_style(data):
|
|
"""Process data with mixed formatting styles."""
|
|
config={'default_value':0,'required_fields':["id","name"],'optional_fields':["description","tags"]}
|
|
|
|
results=[]
|
|
for item in data:
|
|
if not isinstance(item,dict):continue
|
|
|
|
# Mixed quote styles
|
|
item_id=item.get("id")
|
|
item_name=item.get('name')
|
|
item_desc=item.get("description",'')
|
|
|
|
# Inconsistent spacing
|
|
processed={
|
|
'id':item_id,
|
|
"name": item_name,
|
|
'description':item_desc,
|
|
"processed_at":datetime.now().isoformat( ),
|
|
'status':'processed'
|
|
}
|
|
results.append(processed)
|
|
|
|
return {'data':results,"count":len(results)}
|
|
'''
|
|
# should_content_change=True: _run_formatting_test asserts the file on disk differs from source_code afterwards.
# NOTE(review): the helper targets function_name="process_data", which this fixture does not define - confirm that is intended.
_run_formatting_test(source_code, True)
|
|
|
|
|
|
def test_formatting_long_lines_and_imports():
|
|
"""Test files with long lines and import formatting issues."""
|
|
source_code = '''import os, sys, json, datetime, re, collections, itertools
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional
|
|
|
|
def process_with_long_lines(data, filter_func=lambda x: x.get('active', True) and x.get('value', 0) > 0, transform_func=lambda x: {**x, 'processed_at': datetime.datetime.now().isoformat(), 'status': 'processed'}):
|
|
"""Function with very long parameter line."""
|
|
return [transform_func(item) for item in data if filter_func(item) and isinstance(item, dict) and 'id' in item]
|
|
|
|
def another_function_with_long_line():
|
|
very_long_dictionary = {'key1': 'value1', 'key2': 'value2', 'key3': 'value3', 'key4': 'value4', 'key5': 'value5'}
|
|
return very_long_dictionary
|
|
'''
|
|
# should_content_change=True: _run_formatting_test asserts the file on disk differs from source_code afterwards.
# NOTE(review): the helper targets function_name="process_data", which this fixture does not define - confirm that is intended.
_run_formatting_test(source_code, True)
|
|
|
|
|
|
def test_formatting_class_with_methods():
|
|
"""Test formatting of classes with multiple methods and minor issues."""
|
|
source_code = """class DataProcessor:
|
|
def __init__(self, config):
|
|
self.config=config
|
|
self.data=[]
|
|
|
|
def load_data(self,file_path):
|
|
with open(file_path,'r') as f:
|
|
self.data=json.load(f)
|
|
return len(self.data)
|
|
|
|
def process(self):
|
|
result=[]
|
|
for item in self.data:
|
|
if item.get('active',True):
|
|
result.append({
|
|
'id':item['id'],
|
|
'processed':True
|
|
})
|
|
return result
|
|
"""
|
|
# should_content_change=True: _run_formatting_test asserts the file on disk differs from source_code afterwards.
# NOTE(review): the helper targets function_name="process_data", which this fixture does not define - confirm that is intended.
_run_formatting_test(source_code, True)
|
|
|
|
|
|
def test_formatting_with_complex_comprehensions():
|
|
"""Test files with complex list/dict comprehensions and formatting."""
|
|
source_code = """def complex_comprehensions(data):
|
|
# Various comprehension styles with formatting issues
|
|
result1=[item['value'] for item in data if item.get('active',True) and 'value' in item]
|
|
|
|
result2={item['id']:item['name'] for item in data if item.get('type')=='user'}
|
|
|
|
result3=[[x,y] for x in range(10) for y in range(5) if x*y>10]
|
|
|
|
# Nested comprehensions
|
|
nested=[[item for item in sublist if item%2==0] for sublist in data if isinstance(sublist,list)]
|
|
|
|
return {
|
|
'simple':result1,
|
|
'mapping':result2,
|
|
'complex':result3,
|
|
'nested':nested
|
|
}
|
|
"""
|
|
# should_content_change=True: _run_formatting_test asserts the file on disk differs from source_code afterwards.
# NOTE(review): the helper targets function_name="process_data", which this fixture does not define - confirm that is intended.
_run_formatting_test(source_code, True)
|
|
|
|
|
|
def test_formatting_with_decorators_and_async():
|
|
"""Test files with decorators and async functions."""
|
|
source_code = """import asyncio
|
|
from functools import wraps
|
|
|
|
def timer_decorator(func):
|
|
@wraps(func)
|
|
def wrapper(*args,**kwargs):
|
|
start=time.time()
|
|
result=func(*args,**kwargs)
|
|
end=time.time()
|
|
print(f"{func.__name__} took {end-start:.2f} seconds")
|
|
return result
|
|
return wrapper
|
|
|
|
@timer_decorator
|
|
async def async_process_data(data):
|
|
result=[]
|
|
for item in data:
|
|
await asyncio.sleep(0.01) # Simulate async work
|
|
processed_item={'id':item.get('id'),'processed':True}
|
|
result.append(processed_item)
|
|
return result
|
|
|
|
class AsyncProcessor:
|
|
@staticmethod
|
|
async def process_batch(batch):
|
|
return [{'id':item['id'],'status':'done'} for item in batch if 'id' in item]
|
|
"""
|
|
# should_content_change=True: _run_formatting_test asserts the file on disk differs from source_code afterwards.
# NOTE(review): the helper targets function_name="process_data", which this fixture does not define - confirm that is intended.
_run_formatting_test(source_code, True)
|
|
|
|
|
|
def test_formatting_threshold_configuration():
|
|
"""Test that the diff threshold can be configured (if supported)."""
|
|
# This test assumes the threshold might be configurable
|
|
source_code = """import json,os,sys
|
|
def func1():x=1;y=2;return x+y
|
|
def func2():a=1;b=2;return a+b
|
|
def func3():c=1;d=2;return c+d
|
|
"""
|
|
# Test with a file that has moderate formatting issues
|
|
# NOTE(review): no threshold is set anywhere in this test; it passes func2's source as optimized_function and
# only asserts (should_content_change=True) that the file on disk differs from source_code afterwards.
_run_formatting_test(source_code, True, optimized_function="def func2():a=1;b=2;return a+b")
|
|
|
|
|
|
def test_formatting_empty_file():
|
|
"""Test formatting of empty or minimal files."""
|
|
source_code = """# Just a comment pass
|
|
"""
|
|
# should_content_change=False: _run_formatting_test asserts the file on disk still equals source_code afterwards.
_run_formatting_test(source_code, False)
|
|
|
|
|
|
def test_formatting_with_docstrings():
|
|
"""Test files with various docstring formats."""
|
|
source_code = """def function_with_docstring( data):
|
|
'''
|
|
This is a function with a docstring.
|
|
|
|
Args:
|
|
data: Input data to process
|
|
|
|
Returns:
|
|
Processed data
|
|
'''
|
|
return [item for item in data if item.get('active',True)]
|
|
|
|
class ProcessorWithDocs:
|
|
'''A processor class with documentation.'''
|
|
|
|
def __init__(self,config):
|
|
'''Initialize with configuration.'''
|
|
self.config=config
|
|
|
|
def process(self,data):
|
|
'''Single quote docstring with formatting issues.'''
|
|
return{'result':[item for item in data if self._is_valid(item)]}
|
|
|
|
def _is_valid(self,item):
|
|
return isinstance(item,dict) and 'id' in item"""
|
|
expected = '''def function_with_docstring(data):
|
|
"""This is a function with a docstring.
|
|
|
|
Args:
|
|
data: Input data to process
|
|
|
|
Returns:
|
|
Processed data
|
|
|
|
"""
|
|
return [item for item in data if item.get("active", True)]
|
|
|
|
|
|
class ProcessorWithDocs:
|
|
"""A processor class with documentation."""
|
|
|
|
def __init__(self, config):
|
|
"""Initialize with configuration."""
|
|
self.config = config
|
|
|
|
def process(self, data):
|
|
"""Single quote docstring with formatting issues."""
|
|
return {"result": [item for item in data if self._is_valid(item)]}
|
|
|
|
def _is_valid(self, item):
|
|
return isinstance(item, dict) and "id" in item
|
|
'''
|
|
|
|
optimization_function = """def process(self,data):
|
|
'''Single quote docstring with formatting issues.'''
|
|
return{'result':[item for item in data if self._is_valid(item)]}"""
|
|
# `expected` pins the exact content of the file after reformatting; should_content_change=True additionally
# makes _run_formatting_test assert that the result differs from source_code.
_run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected)
|
|
|
|
|
|
def test_sort_imports_skip_file():
    """Code carrying the ``# isort:skip_file`` marker is returned untouched by sort_imports."""
    skipped_source = """# isort:skip_file

import sys, os, json # isort will ignore this file completely"""
    assert sort_imports(skipped_source) == skipped_source
|
|
|
|
|
|
# ==================== Tests for format_generated_code ====================
|
|
|
|
|
|
def test_format_generated_code_disabled():
    """With the formatter disabled, format_generated_code only collapses runs of blank lines."""
    raw_code = """import os


def test_function():
    pass



def another_function():
    return 42"""

    # Multiple newlines (3+) are reduced to 2
    collapsed = """import os

def test_function():
    pass

def another_function():
    return 42"""

    # The call is made twice with the same ["disabled"] formatter list, as in the original test.
    for _ in range(2):
        assert format_generated_code(raw_code, ["disabled"]) == collapsed
|
|
|
|
|
|
def test_format_generated_code_disabled_case_insensitive():
    """The "disabled" formatter keyword is recognised regardless of letter case."""
    raw_code = """def test():


    pass"""

    # Multiple newlines are reduced to at most 2
    collapsed = """def test():

    pass"""

    for spelling in ("Disabled", "DISABLED", "DiSaBlEd"):
        assert format_generated_code(raw_code, [spelling]) == collapsed
|
|
|
|
|
|
def test_format_generated_code_empty_string():
    """An empty input string stays empty when the formatter is disabled."""
    # The same call is issued twice, mirroring the original test.
    for _ in range(2):
        assert format_generated_code("", ["disabled"]) == ""
|
|
|
|
|
|
def test_format_generated_code_with_black():
    """format_generated_code with a ``black $file`` command returns black-formatted code."""
    try:
        import black
    except ImportError:
        pytest.skip("black is not installed")

    unformatted = """import os,sys
def test_function(x,y,z):
    result=x+y+z
    return result"""

    formatted = """import os, sys


def test_function(x, y, z):
    result = x + y + z
    return result
"""

    assert format_generated_code(unformatted, ["black $file"]) == formatted
|
|
|
|
|
|
def test_format_generated_code_with_inference():
|
|
"""Test format_generated_code with ruff formatter."""
|
|
try:
|
|
import ruff # type: ignore
|
|
except ImportError:
|
|
pytest.skip("ruff is not installed")
|
|
|
|
test_code = '''from time import sleep
|
|
from typing import List, Union
|
|
|
|
# imports
|
|
import pytest
|
|
from inference.core.models.base import Model
|
|
|
|
# --- Dummy classes to mimic the actual entities used in the function ---
|
|
|
|
class InferenceRequest:
|
|
def __init__(self, image, visualize_predictions=False, id=None):
|
|
self.image = image
|
|
self.visualize_predictions = visualize_predictions
|
|
self.id = id
|
|
|
|
def dict(self):
|
|
# Simulate the dict() method to unpack arguments for infer()
|
|
return {
|
|
"image": self.image,
|
|
"visualize_predictions": self.visualize_predictions,
|
|
"id": self.id
|
|
}
|
|
|
|
class InferenceResponse:
|
|
def __init__(self, instances=None):
|
|
self.instances = instances if instances is not None else []
|
|
self.time = None
|
|
self.visualization = None
|
|
self.inference_id = None
|
|
from inference.core.models.base import Model
|
|
|
|
# --- Unit tests for infer_from_request ---
|
|
|
|
@pytest.fixture
|
|
def model():
|
|
# Returns a fresh instance of Model for each test
|
|
return Model()
|
|
|
|
# --------------------------
|
|
# 1. Basic Test Cases
|
|
# --------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_visualization_true_but_no_draw_method(monkeypatch, model):
|
|
"""Test with visualize_predictions=True but draw_predictions raises exception."""
|
|
def broken_draw_predictions(request, response):
|
|
raise RuntimeError("Visualization failed")
|
|
monkeypatch.setattr(model, "draw_predictions", broken_draw_predictions)
|
|
req = InferenceRequest(image="img1", visualize_predictions=True)
|
|
with pytest.raises(RuntimeError):
|
|
model.infer_from_request(req)
|
|
|
|
|
|
|
|
|
|
|
|
def test_large_image_list_empty_instances(model):
|
|
"""Test with large image list and infer returns empty instances."""
|
|
# Patch the model.infer to return responses with empty instances
|
|
def empty_infer(image, **kwargs):
|
|
if isinstance(image, list):
|
|
return [InferenceResponse(instances=[]) for _ in image]
|
|
return [InferenceResponse(instances=[])]
|
|
model.infer = empty_infer
|
|
images = [f"img_{i}" for i in range(900)]
|
|
req = InferenceRequest(image=images)
|
|
codeflash_output = model.infer_from_request(req); resp = codeflash_output # 1.42ms -> 471μs (201% faster)
|
|
for r in resp:
|
|
pass
|
|
|
|
|
|
#------------------------------------------------
|
|
import time
|
|
from typing import Any, List, Tuple, Union
|
|
|
|
# imports
|
|
import pytest
|
|
from inference.core.models.base import Model
|
|
|
|
# --- Minimal stubs/mocks for dependencies ---
|
|
|
|
class DummyLogger:
|
|
def debug(self, msg):
|
|
pass
|
|
|
|
logger = DummyLogger()
|
|
|
|
def perf_counter():
|
|
# Use time.monotonic() for monotonic clock
|
|
return time.monotonic()
|
|
|
|
# --- Entities and types ---
|
|
|
|
class InferenceRequest:
|
|
def __init__(self, image, id=None, visualize_predictions=False, **kwargs):
|
|
self.image = image
|
|
self.id = id
|
|
self.visualize_predictions = visualize_predictions
|
|
self.kwargs = kwargs
|
|
def dict(self):
|
|
d = {"image": self.image}
|
|
d.update(self.kwargs)
|
|
return d
|
|
|
|
class InferenceResponse:
|
|
def __init__(self, result=None):
|
|
self.result = result
|
|
self.time = None
|
|
self.inference_id = None
|
|
self.visualization = None
|
|
from inference.core.models.base import Model
|
|
|
|
# --- Unit tests ---
|
|
|
|
# 1. BASIC TEST CASES
|
|
'''
|
|
expected = '''from time import sleep
|
|
from typing import List, Union
|
|
|
|
# imports
|
|
import pytest
|
|
from inference.core.models.base import Model
|
|
|
|
# --- Dummy classes to mimic the actual entities used in the function ---
|
|
|
|
|
|
class InferenceRequest:
|
|
def __init__(self, image, visualize_predictions=False, id=None):
|
|
self.image = image
|
|
self.visualize_predictions = visualize_predictions
|
|
self.id = id
|
|
|
|
def dict(self):
|
|
# Simulate the dict() method to unpack arguments for infer()
|
|
return {"image": self.image, "visualize_predictions": self.visualize_predictions, "id": self.id}
|
|
|
|
|
|
class InferenceResponse:
|
|
def __init__(self, instances=None):
|
|
self.instances = instances if instances is not None else []
|
|
self.time = None
|
|
self.visualization = None
|
|
self.inference_id = None
|
|
|
|
|
|
from inference.core.models.base import Model
|
|
|
|
# --- Unit tests for infer_from_request ---
|
|
|
|
|
|
@pytest.fixture
|
|
def model():
|
|
# Returns a fresh instance of Model for each test
|
|
return Model()
|
|
|
|
|
|
# --------------------------
|
|
# 1. Basic Test Cases
|
|
# --------------------------
|
|
|
|
|
|
def test_visualization_true_but_no_draw_method(monkeypatch, model):
|
|
"""Test with visualize_predictions=True but draw_predictions raises exception."""
|
|
|
|
def broken_draw_predictions(request, response):
|
|
raise RuntimeError("Visualization failed")
|
|
|
|
monkeypatch.setattr(model, "draw_predictions", broken_draw_predictions)
|
|
req = InferenceRequest(image="img1", visualize_predictions=True)
|
|
with pytest.raises(RuntimeError):
|
|
model.infer_from_request(req)
|
|
|
|
|
|
def test_large_image_list_empty_instances(model):
|
|
"""Test with large image list and infer returns empty instances."""
|
|
|
|
# Patch the model.infer to return responses with empty instances
|
|
def empty_infer(image, **kwargs):
|
|
if isinstance(image, list):
|
|
return [InferenceResponse(instances=[]) for _ in image]
|
|
return [InferenceResponse(instances=[])]
|
|
|
|
model.infer = empty_infer
|
|
images = [f"img_{i}" for i in range(900)]
|
|
req = InferenceRequest(image=images)
|
|
codeflash_output = model.infer_from_request(req)
|
|
resp = codeflash_output # 1.42ms -> 471μs (201% faster)
|
|
for r in resp:
|
|
pass
|
|
|
|
|
|
# ------------------------------------------------
|
|
import time
|
|
from typing import Any, List, Tuple, Union
|
|
|
|
# imports
|
|
import pytest
|
|
from inference.core.models.base import Model
|
|
|
|
# --- Minimal stubs/mocks for dependencies ---
|
|
|
|
|
|
class DummyLogger:
|
|
def debug(self, msg):
|
|
pass
|
|
|
|
|
|
logger = DummyLogger()
|
|
|
|
|
|
def perf_counter():
|
|
# Use time.monotonic() for monotonic clock
|
|
return time.monotonic()
|
|
|
|
|
|
# --- Entities and types ---
|
|
|
|
|
|
class InferenceRequest:
|
|
def __init__(self, image, id=None, visualize_predictions=False, **kwargs):
|
|
self.image = image
|
|
self.id = id
|
|
self.visualize_predictions = visualize_predictions
|
|
self.kwargs = kwargs
|
|
|
|
def dict(self):
|
|
d = {"image": self.image}
|
|
d.update(self.kwargs)
|
|
return d
|
|
|
|
|
|
class InferenceResponse:
|
|
def __init__(self, result=None):
|
|
self.result = result
|
|
self.time = None
|
|
self.inference_id = None
|
|
self.visualization = None
|
|
|
|
|
|
from inference.core.models.base import Model
|
|
|
|
# --- Unit tests ---
|
|
|
|
# 1. BASIC TEST CASES
|
|
'''
|
|
|
|
result = format_generated_code(test_code, ["ruff format $file"])
|
|
assert result == expected
|
|
|
|
|
|
def test_format_generated_code_with_ruff():
    """Run ``ruff format`` on compact code and compare with the exact formatted text."""
    try:
        import ruff  # type: ignore  # noqa: F401
    except ImportError:
        pytest.skip("ruff is not installed")

    unformatted = "\n".join(
        [
            "import os,sys",
            "def test_function(x,y,z):",
            "    result=x+y+z",
            "    return result",
        ]
    )
    # Expected text: spaced-out operators/arguments, two blank lines before
    # the function, and a trailing newline.
    formatted = (
        "import os, sys\n"
        "\n"
        "\n"
        "def test_function(x, y, z):\n"
        "    result = x + y + z\n"
        "    return result\n"
    )

    assert format_generated_code(unformatted, ["ruff format $file"]) == formatted
|
|
|
|
|
|
def test_format_generated_code_multiple_formatters():
    """Test format_generated_code with multiple formatter commands.

    NOTE(review): the command list below holds a single ``ruff format`` entry;
    confirm whether a second command was intended.
    """
    try:
        import ruff  # type: ignore  # noqa: F401
    except ImportError:
        pytest.skip("ruff is not installed")

    compact_source = "\n".join(
        [
            "import sys,os # wrong order",
            "def test_function(x,y,z):",
            "    result=x+y+z",
            "    return result",
        ]
    )

    # Ruff format will fix spacing
    formatted = format_generated_code(compact_source, ["ruff format $file"])

    # Check that formatting happened
    spacing_fixed = "result = x + y + z" in formatted
    params_spaced = "def test_function(x, y, z):" in formatted
    assert spacing_fixed  # spacing should be fixed
    assert params_spaced  # parameters should have spaces
|
|
|
|
|
|
def test_format_generated_code_invalid_formatter():
    """A formatter command that does not exist must leave the code as it was passed in."""
    source = "def test():\n    pass"

    # Should handle gracefully and return code with normalized newlines
    outcome = format_generated_code(source, ["nonexistent_formatter $file"])

    unchanged = "def test():\n    pass"
    assert outcome == unchanged
|
|
|
|
|
|
def test_format_generated_code_syntax_error():
    """Code the formatter cannot parse must come back unchanged instead of raising."""
    broken_source = "def test(: # syntax error\n    pass"

    # Formatter should fail but function should handle it gracefully
    outcome = format_generated_code(broken_source, ["black $file"])

    # Should return code with normalized newlines when formatting fails
    still_broken = "def test(: # syntax error\n    pass"
    assert outcome == still_broken
|
|
|
|
|
|
def test_format_generated_code_already_formatted():
    """Formatting code that is already well-formatted must return it unchanged."""
    try:
        import black  # noqa: F401
    except ImportError:
        pytest.skip("black is not installed")

    already_clean = (
        "import os\n"
        "import sys\n"
        "\n"
        "\n"
        "def test_function(x, y, z):\n"
        "    result = x + y + z\n"
        "    return result\n"
    )

    # Code is already formatted, should return the same
    assert format_generated_code(already_clean, ["black $file"]) == already_clean
|
|
|
|
|
|
def test_format_generated_code_with_tabs():
    """Test format_generated_code with code containing tabs.

    The assertions pin the re-indented lines (four spaces per level, Black's
    default) instead of only checking that some space exists: the line
    ``def test():`` already contains a space, so that check could never fail.
    """
    try:
        import black  # noqa: F401
    except ImportError:
        pytest.skip("black is not installed")

    test_code = """def test():
\tif True:
\t\treturn 42
\treturn 0"""

    # Black should convert tabs to spaces
    result = format_generated_code(test_code, ["black $file"])
    assert "\t" not in result  # No tabs should remain
    # Each former tab level must now be a four-space indent.
    assert "    if True:" in result
    assert "        return 42" in result
    assert "    return 0" in result
|
|
|
|
|
|
def test_format_generated_code_trailing_whitespace():
    """Test format_generated_code removes trailing whitespace.

    The fixture spells its trailing blanks out in a regular string literal so
    that editors or whitespace-stripping tools cannot silently remove them and
    turn the test into a no-op.
    """
    try:
        import black  # noqa: F401
    except ImportError:
        pytest.skip("black is not installed")

    test_code = "def test():   \n    pass  \n"
    # Guard: the input must really contain trailing whitespace, otherwise the
    # assertions below would pass without exercising the formatter.
    assert any(line != line.rstrip() for line in test_code.split("\n"))

    result = format_generated_code(test_code, ["black $file"])
    for line in result.split("\n"):
        assert line == line.rstrip(), f"Line has trailing whitespace: {line!r}"
|
|
|
|
|
|
def test_format_generated_code_preserves_comments():
    """Module, inline and in-function comments must all survive formatting."""
    try:
        import black  # noqa: F401
    except ImportError:
        pytest.skip("black is not installed")

    commented_source = (
        "# This is a module comment\n"
        "import os # import os module\n"
        "\n"
        "def test():\n"
        "    # This function does something\n"
        "    pass # TODO: implement this\n"
    )

    formatted = format_generated_code(commented_source, ["black $file"])

    # Every comment from the input has to appear verbatim in the output.
    for comment in (
        "# This is a module comment",
        "# import os module",
        "# This function does something",
        "# TODO: implement this",
    ):
        assert comment in formatted
|
|
|
|
|
|
def test_format_generated_code_with_docstrings():
    """One-line, multi-line and single-quoted docstrings must pass through the formatter."""
    try:
        import black  # noqa: F401
    except ImportError:
        pytest.skip("black is not installed")

    source_lines = [
        "def test():",
        '    """This is a docstring."""',
        "    pass",
        "",
        "class TestClass:",
        '    """',
        "    Multi-line",
        "    docstring",
        '    """',
        "    def method(self):",
        "        '''Single quote docstring'''",
        "        pass",
    ]
    docstring_source = "\n".join(source_lines)

    formatted = format_generated_code(docstring_source, ["black $file"])

    assert '"""This is a docstring."""' in formatted
    assert "Multi-line" in formatted
    assert "docstring" in formatted
|
|
|
|
|
|
def test_format_generated_code_normalizes_multiple_newlines():
    """Runs of blank lines must be collapsed, even with the formatter set to "disabled"."""
    # Four blank lines after the import and three between the functions.
    spaced_out = (
        "import os"
        + "\n" * 5
        + "def func1():\n    pass"
        + "\n" * 4
        + "def func2():\n    pass"
    )

    normalized = format_generated_code(spaced_out, ["disabled"])

    # Should have at most two consecutive newlines
    assert "\n\n\n" not in normalized
    for expected_fragment in ("import os\n\n", "pass\n\n"):
        assert expected_fragment in normalized
|
|
|
|
|
|
def test_format_generated_code_complex_code():
    """Format a larger, densely written test module and spot-check the result."""
    try:
        import black  # noqa: F401
    except ImportError:
        pytest.skip("black is not installed")

    dense_source = "\n".join(
        [
            "import unittest",
            "from unittest.mock import patch,Mock,MagicMock",
            "import os,sys",
            "from typing import Dict,List,Optional",
            "",
            "class TestComplexClass(unittest.TestCase):",
            "    def setUp(self):",
            "        self.config={'key1':'value1','key2':'value2'}",
            "        self.data=[{'id':1,'name':'test1'},{'id':2,'name':'test2'}]",
            "",
            "    def test_something(self):",
            "        result=process_data(self.data,lambda x:x['id']>0)",
            "        self.assertEqual(len(result),2)",
            "",
            "    @patch('module.function')",
            "    def test_with_mock(self,mock_func):",
            "        mock_func.return_value={'status':'ok'}",
            "        response=make_request()",
            "        self.assertEqual(response['status'],'ok')",
            "",
            "def process_data(data:List[Dict],filter_func)->List[Dict]:",
            "    return [item for item in data if filter_func(item)]",
        ]
    )

    formatted = format_generated_code(dense_source, ["black $file"])

    # Check that formatting was applied
    assignment_fragments = (
        "self.config = {",
        "self.data = [",
        "result = process_data",
        "mock_func.return_value = {",
    )
    # Check imports are formatted
    import_fragments = (
        "from unittest.mock import ",
        "from typing import Dict, List, Optional",
    )
    for fragment in assignment_fragments + import_fragments:
        assert fragment in formatted
|
|
|
|
|
|
def test_format_generated_code_unicode():
    """Non-ASCII text inside the code must survive the call unchanged."""
    greeting = "Hello, 世界! 🌍"
    unicode_source = "\n".join(
        [
            "def test():",
            f'    message = "{greeting}"',
            "    return message",
        ]
    )

    outcome = format_generated_code(unicode_source, ["disabled"])

    assert greeting in outcome
|
|
|
|
|
|
def test_format_generated_code_uses_correct_extension_for_javascript():
    """Test that format_generated_code creates temp files with .js extension for JavaScript code.

    ``apply_formatter_cmds`` is patched, so no real formatter runs; the test
    only inspects the path argument the patched function receives.
    """
    from unittest.mock import patch

    js_code = """function test() {
    return 42;
}"""

    with patch("codeflash.code_utils.formatter.apply_formatter_cmds") as mock_apply:
        mock_apply.return_value = (Path("/tmp/temp.js"), js_code, False)
        format_generated_code(js_code, ["npx prettier --write $file"], language="javascript")

        # Fail with a clear message (not a TypeError on ``None``) if the
        # patched formatter entry point was never reached.
        mock_apply.assert_called()

        # Verify the temp file path has .js extension
        call_args = mock_apply.call_args
        original_temp_path = call_args[0][1]  # second positional arg is the path
        assert original_temp_path.suffix == ".js", (
            f"Expected .js extension for JavaScript, got {original_temp_path.suffix}"
        )
|
|
|
|
|
|
def test_format_generated_code_uses_correct_extension_for_typescript():
    """TypeScript code must be handed to the formatter through a temp file ending in ``.ts``."""
    from unittest.mock import patch

    ts_code = "function test(): number {\n    return 42;\n}"

    with patch("codeflash.code_utils.formatter.apply_formatter_cmds") as mock_apply:
        # The patched formatter returns a canned result; only its inputs matter here.
        mock_apply.return_value = (Path("/tmp/temp.ts"), ts_code, False)
        format_generated_code(ts_code, ["npx prettier --write $file"], language="typescript")

        positional_args, _keyword_args = mock_apply.call_args
        original_temp_path = positional_args[1]
        suffix_message = f"Expected .ts extension for TypeScript, got {original_temp_path.suffix}"
        assert original_temp_path.suffix == ".ts", suffix_message
|
|
|
|
|
|
def test_format_generated_code_defaults_to_py_extension():
    """Without a ``language`` argument the temp file passed to the formatter must end in ``.py``."""
    from unittest.mock import patch

    py_code = "def test():\n    return 42"

    with patch("codeflash.code_utils.formatter.apply_formatter_cmds") as mock_apply:
        # The patched formatter returns a canned result; only its inputs matter here.
        mock_apply.return_value = (Path("/tmp/temp.py"), py_code, False)
        format_generated_code(py_code, ["black $file"])

        positional_args, _keyword_args = mock_apply.call_args
        original_temp_path = positional_args[1]
        suffix_message = f"Expected .py extension for Python, got {original_temp_path.suffix}"
        assert original_temp_path.suffix == ".py", suffix_message
|
|
|
|
|
|
def test_format_generated_code_f_strings():
|
|
"""Test format_generated_code with f-strings."""
|
|
try:
|
|
import black
|
|
except ImportError:
|
|
pytest.skip("black is not installed")
|
|
|
|
test_code = """def test(name,age):
|
|
return f"Hello {name}, you are {age} years old"
|
|
|
|
def test2():
|
|
x=10
|
|
y=20
|
|
return f"{x}+{y}={x+y}" """
|
|
|
|
result = format_generated_code(test_code, ["black $file"])
|
|
assert 'f"Hello {name}, you are {age} years old"' in result
|
|
assert "x = 10" in result
|
|
assert "y = 20" in result
|