diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 678d44b8f..8c53cb5b1 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -48,7 +48,10 @@ print("Hello world") function_name: str = "NewClass.new_function" preexisting_functions: list[str] = ["NewClass.new_function"] new_code: str = replace_functions_in_file( - original_code, [function_name], optim_code, preexisting_functions, + original_code, + [function_name], + optim_code, + preexisting_functions, ) assert new_code == expected @@ -336,7 +339,10 @@ def blab(st): print("Not cool") """ new_main_code: str = replace_functions_in_file( - original_code_main, ["other_function"], optim_code, ["other_function", "yet_another_function", "blob"] + original_code_main, + ["other_function"], + optim_code, + ["other_function", "yet_another_function", "blob"], ) assert new_main_code == expected_main @@ -344,3 +350,195 @@ print("Not cool") original_code_dependent, ["blob"], optim_code, [] ) assert new_dependent_code == expected_dependent + + +def test_test_libcst_code_replacement7(): + optim_code = """@register_deserializable +class CacheSimilarityEvalConfig(BaseConfig): + + def __init__( + self, + strategy: Optional[str] = "distance", + max_distance: Optional[float] = 1.0, + positive: Optional[bool] = False, + ): + self.strategy = strategy + self.max_distance = max_distance + self.positive = positive + + @staticmethod + def from_config(config: Optional[dict[str, Any]]): + if config is None: + return CacheSimilarityEvalConfig() + + strategy = config.get("strategy", "distance") + max_distance = config.get("max_distance", 1.0) + positive = config.get("positive", False) + + return CacheSimilarityEvalConfig(strategy, max_distance, positive) +""" + + original_code = """from typing import Any, Optional + +from embedchain.config.base_config import BaseConfig +from embedchain.helpers.json_serializable import register_deserializable + + +@register_deserializable +class CacheSimilarityEvalConfig(BaseConfig): + + def __init__( + self, + strategy: Optional[str] = "distance", + max_distance: Optional[float] = 1.0, + positive: Optional[bool] = False, + ): + self.strategy = strategy + self.max_distance = max_distance + self.positive = positive + + @staticmethod + def from_config(config: Optional[dict[str, Any]]): + if config is None: + return CacheSimilarityEvalConfig() + else: + return CacheSimilarityEvalConfig( + strategy=config.get("strategy", "distance"), + max_distance=config.get("max_distance", 1.0), + positive=config.get("positive", False), + ) + + +@register_deserializable +class CacheInitConfig(BaseConfig): + + def __init__( + self, + similarity_threshold: Optional[float] = 0.8, + auto_flush: Optional[int] = 20, + ): + if similarity_threshold < 0 or similarity_threshold > 1: + raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1") + + self.similarity_threshold = similarity_threshold + self.auto_flush = auto_flush + + @staticmethod + def from_config(config: Optional[dict[str, Any]]): + if config is None: + return CacheInitConfig() + else: + return CacheInitConfig( + similarity_threshold=config.get("similarity_threshold", 0.8), + auto_flush=config.get("auto_flush", 20), + ) + + +@register_deserializable +class CacheConfig(BaseConfig): + + def __init__( + self, + similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(), + init_config: Optional[CacheInitConfig] = CacheInitConfig(), + ): + self.similarity_eval_config = similarity_eval_config + self.init_config = init_config + + @staticmethod + def from_config(config: Optional[dict[str, Any]]): + if config is None: + return CacheConfig() + else: + return CacheConfig( + similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})), + init_config=CacheInitConfig.from_config(config.get("init_config", {})), + ) +""" + expected = """from typing import Any, Optional + +from embedchain.config.base_config import BaseConfig +from embedchain.helpers.json_serializable import register_deserializable + + +@register_deserializable +class CacheSimilarityEvalConfig(BaseConfig): + + def __init__( + self, + strategy: Optional[str] = "distance", + max_distance: Optional[float] = 1.0, + positive: Optional[bool] = False, + ): + self.strategy = strategy + self.max_distance = max_distance + self.positive = positive + + @staticmethod + def from_config(config: Optional[dict[str, Any]]): + if config is None: + return CacheSimilarityEvalConfig() + + strategy = config.get("strategy", "distance") + max_distance = config.get("max_distance", 1.0) + positive = config.get("positive", False) + + return CacheSimilarityEvalConfig(strategy, max_distance, positive) + + +@register_deserializable +class CacheInitConfig(BaseConfig): + + def __init__( + self, + similarity_threshold: Optional[float] = 0.8, + auto_flush: Optional[int] = 20, + ): + if similarity_threshold < 0 or similarity_threshold > 1: + raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1") + + self.similarity_threshold = similarity_threshold + self.auto_flush = auto_flush + + @staticmethod + def from_config(config: Optional[dict[str, Any]]): + if config is None: + return CacheInitConfig() + else: + return CacheInitConfig( + similarity_threshold=config.get("similarity_threshold", 0.8), + auto_flush=config.get("auto_flush", 20), + ) + + +@register_deserializable +class CacheConfig(BaseConfig): + + def __init__( + self, + similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(), + init_config: Optional[CacheInitConfig] = CacheInitConfig(), + ): + self.similarity_eval_config = similarity_eval_config + self.init_config = init_config + + @staticmethod + def from_config(config: Optional[dict[str, Any]]): + if config is None: + return CacheConfig() + else: + return CacheConfig( + similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})), + init_config=CacheInitConfig.from_config(config.get("init_config", {})), + ) +""" + + function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"] + preexisting_functions: list[str] = [ + "CacheSimilarityEvalConfig.__init__", + "CacheSimilarityEvalConfig.from_config", + ] + new_code: str = replace_functions_in_file( + original_code, function_names, optim_code, preexisting_functions + ) + assert new_code == expected