Test reproducing CF-137 issue.
This commit is contained in:
parent
0e33f211a0
commit
084359c792
1 changed files with 200 additions and 2 deletions
|
|
@ -48,7 +48,10 @@ print("Hello world")
|
|||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["NewClass.new_function"]
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, [function_name], optim_code, preexisting_functions,
|
||||
original_code,
|
||||
[function_name],
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -336,7 +339,10 @@ def blab(st):
|
|||
print("Not cool")
|
||||
"""
|
||||
new_main_code: str = replace_functions_in_file(
|
||||
original_code_main, ["other_function"], optim_code, ["other_function", "yet_another_function", "blob"]
|
||||
original_code_main,
|
||||
["other_function"],
|
||||
optim_code,
|
||||
["other_function", "yet_another_function", "blob"],
|
||||
)
|
||||
assert new_main_code == expected_main
|
||||
|
||||
|
|
@ -344,3 +350,195 @@ print("Not cool")
|
|||
original_code_dependent, ["blob"], optim_code, []
|
||||
)
|
||||
assert new_dependent_code == expected_dependent
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement7():
|
||||
optim_code = """@register_deserializable
|
||||
class CacheSimilarityEvalConfig(BaseConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: Optional[str] = "distance",
|
||||
max_distance: Optional[float] = 1.0,
|
||||
positive: Optional[bool] = False,
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.max_distance = max_distance
|
||||
self.positive = positive
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheSimilarityEvalConfig()
|
||||
|
||||
strategy = config.get("strategy", "distance")
|
||||
max_distance = config.get("max_distance", 1.0)
|
||||
positive = config.get("positive", False)
|
||||
|
||||
return CacheSimilarityEvalConfig(strategy, max_distance, positive)
|
||||
"""
|
||||
|
||||
original_code = """from typing import Any, Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CacheSimilarityEvalConfig(BaseConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: Optional[str] = "distance",
|
||||
max_distance: Optional[float] = 1.0,
|
||||
positive: Optional[bool] = False,
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.max_distance = max_distance
|
||||
self.positive = positive
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheSimilarityEvalConfig()
|
||||
else:
|
||||
return CacheSimilarityEvalConfig(
|
||||
strategy=config.get("strategy", "distance"),
|
||||
max_distance=config.get("max_distance", 1.0),
|
||||
positive=config.get("positive", False),
|
||||
)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CacheInitConfig(BaseConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
similarity_threshold: Optional[float] = 0.8,
|
||||
auto_flush: Optional[int] = 20,
|
||||
):
|
||||
if similarity_threshold < 0 or similarity_threshold > 1:
|
||||
raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
|
||||
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.auto_flush = auto_flush
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheInitConfig()
|
||||
else:
|
||||
return CacheInitConfig(
|
||||
similarity_threshold=config.get("similarity_threshold", 0.8),
|
||||
auto_flush=config.get("auto_flush", 20),
|
||||
)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CacheConfig(BaseConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(),
|
||||
init_config: Optional[CacheInitConfig] = CacheInitConfig(),
|
||||
):
|
||||
self.similarity_eval_config = similarity_eval_config
|
||||
self.init_config = init_config
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheConfig()
|
||||
else:
|
||||
return CacheConfig(
|
||||
similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
|
||||
init_config=CacheInitConfig.from_config(config.get("init_config", {})),
|
||||
)
|
||||
"""
|
||||
expected = """from typing import Any, Optional
|
||||
|
||||
from embedchain.config.base_config import BaseConfig
|
||||
from embedchain.helpers.json_serializable import register_deserializable
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CacheSimilarityEvalConfig(BaseConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: Optional[str] = "distance",
|
||||
max_distance: Optional[float] = 1.0,
|
||||
positive: Optional[bool] = False,
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.max_distance = max_distance
|
||||
self.positive = positive
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheSimilarityEvalConfig()
|
||||
|
||||
strategy = config.get("strategy", "distance")
|
||||
max_distance = config.get("max_distance", 1.0)
|
||||
positive = config.get("positive", False)
|
||||
|
||||
return CacheSimilarityEvalConfig(strategy, max_distance, positive)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CacheInitConfig(BaseConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
similarity_threshold: Optional[float] = 0.8,
|
||||
auto_flush: Optional[int] = 20,
|
||||
):
|
||||
if similarity_threshold < 0 or similarity_threshold > 1:
|
||||
raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")
|
||||
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.auto_flush = auto_flush
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheInitConfig()
|
||||
else:
|
||||
return CacheInitConfig(
|
||||
similarity_threshold=config.get("similarity_threshold", 0.8),
|
||||
auto_flush=config.get("auto_flush", 20),
|
||||
)
|
||||
|
||||
|
||||
@register_deserializable
|
||||
class CacheConfig(BaseConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(),
|
||||
init_config: Optional[CacheInitConfig] = CacheInitConfig(),
|
||||
):
|
||||
self.similarity_eval_config = similarity_eval_config
|
||||
self.init_config = init_config
|
||||
|
||||
@staticmethod
|
||||
def from_config(config: Optional[dict[str, Any]]):
|
||||
if config is None:
|
||||
return CacheConfig()
|
||||
else:
|
||||
return CacheConfig(
|
||||
similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
|
||||
init_config=CacheInitConfig.from_config(config.get("init_config", {})),
|
||||
)
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
|
||||
preexisting_functions: list[str] = [
|
||||
"CacheSimilarityEvalConfig.__init__",
|
||||
"CacheSimilarityEvalConfig.from_config",
|
||||
]
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue