revert context to single for single file

This commit is contained in:
ali 2025-09-25 03:48:35 +03:00
parent d1cfea39a2
commit b2ad6a52ff
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
2 changed files with 19 additions and 10 deletions

View file

@ -83,7 +83,6 @@ def print_message_delta(delta, color_prefix_by_role=color_prefix_by_role) -> Non
async def generate_regression_tests_from_function(
ctx: BaseTestGenContext,
user_id: str,
function_code: str, # Python function to test, as a string
function_name: str, # the function to test
python_version: tuple[int, int, int], # Python version to use for ast parsing
unit_test_package: str = "pytest", # unit testing package; use the name as it appears in the import statement
@ -158,7 +157,7 @@ To help unit test the function above, list diverse scenarios that the function s
"content": execute_user_prompt.format(
unit_test_package=unit_test_package,
function_name=function_name,
function_code=function_code,
function_code=ctx.data.source_code_being_tested,
package_comment=package_comment,
),
}
@ -351,7 +350,7 @@ async def testgen(
try:
print("/testgen: Generating tests...")
debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")
debug_log_sensitive_data(f"Source code being tested: {data.source_code_being_tested}")
debug_log_sensitive_data(f"Source code being tested: {ctx.data.source_code_being_tested}")
max_tries = 2
count = 0
generated_test_source = ""
@ -364,7 +363,6 @@ async def testgen(
og_generated_test_source = await generate_regression_tests_from_function(
ctx=ctx,
user_id=request.user,
function_code=data.source_code_being_tested,
function_name=data.function_to_optimize.function_name,
unit_test_package=data.test_framework,
approx_min_cases_to_cover=10,
@ -391,7 +389,7 @@ async def testgen(
f"The generated test code is not valid Python code. \n"
f" Original generated test code: {og_generated_test_source} \n"
f" Modified generated test code: {generated_test_source}"
f" code being tested: {data.source_code_being_tested}"
f" code being tested: {ctx.data.source_code_being_tested}"
)
logging.exception(msg)
await asyncio.sleep(0.5)

View file

@ -40,9 +40,21 @@ class BaseTestGenContext:
@staticmethod
def get_dynamic_context(ctx_data: TestGenContextData | None) -> "BaseTestGenContext":
if is_multi_context(ctx_data.source_code_being_tested):
source = ctx_data.source_code_being_tested
if is_multi_context(source):
file_to_code = split_markdown_code(source)
files = list(file_to_code.keys())
if len(files) == 1:
file_name = files[0]
return SingleTestGenContext(
# create a context data without markdown
TestGenContextData(
source_code_being_tested=file_to_code.get(file_name),
function_name=ctx_data.function_name
)
)
return MultiTestGenContext(ctx_data)
return MultiSingleTestGenContext(ctx_data)
return SingleTestGenContext(ctx_data)
def validate_python_module(self, feature_version: tuple) -> None:
raise NotImplementedError
@ -53,11 +65,10 @@ class BaseTestGenContext:
def did_generate_ellipsis(self, generated_code: str, python_version: tuple) -> bool:
raise NotImplementedError
##########################################################################################
# MultiSingleTestGenContext #
# SingleTestGenContext #
##########################################################################################
class MultiSingleTestGenContext(BaseTestGenContext):
class SingleTestGenContext(BaseTestGenContext):
def __init__(self, ctx_data: TestGenContextData) -> None:
super().__init__(ctx_data)