diff --git a/test/python/test_serialize.py b/test/python/test_serialize.py index 38734d1..565a160 100644 --- a/test/python/test_serialize.py +++ b/test/python/test_serialize.py @@ -21,6 +21,7 @@ ERROR_UNCATCHED = os.environ.get("ERROR_UNCATCHED", "false").lower() == "true" SLEEP_TIME = float(os.environ.get("SLEEP", "0")) CUESOR_TEST_COUNT = int(os.environ.get("CUESOR_TEST_COUNT", "3")) STRICT_MODE = os.environ.get("STRICT_MODE", "false").lower() == "true" +MULTI_THREAD = os.environ.get("MULTI_THREAD", "true").lower() == "true" def get_key(snake_str): @@ -64,13 +65,12 @@ def get_kwargs(key, additional): return kwargs -def match_rate_zero(key): - if STRICT_MODE: - raise Exception(f"Strict mode: {key}") +def match_rate_zero(a, b, key, fn): + fn(a, b, key) return 0 -def match_rate(a, b, key=""): +def match_rate(a, b, key="", fn=lambda x: None): if isinstance(a, aenum.Enum): a = a.value if isinstance(b, aenum.Enum): @@ -79,26 +79,26 @@ def match_rate(a, b, key=""): return 1 if a is False and b is None: return 1 - if isinstance(a, list): - data = [match_rate(a[i], b[i], key=f"{key}[{i}]") for i in range(len(a))] + if a is None and isinstance(b, list) and len(b) == 0: + return 1 + if isinstance(a, list) and b is None and len(a) == 0: + return 1 + if isinstance(a, dict) and isinstance(b, dict): if len(a) == 0 and len(b) == 0: return 1 - if len(a) == 0 or len(b) == 0: - return match_rate_zero(key) - - data = [match_rate(a.get(k), b.get(k), key=f"{key}.{k}") for k in a.keys()] - - return sum(data) / len(a) + marge_key = set(a.keys()) | set(b.keys()) + data = [match_rate(a.get(k), b.get(k), [*key, k], fn) for k in marge_key] + return sum(data) / len(b) if isinstance(a, list) and isinstance(b, list): if len(a) == 0 and len(b) == 0: return 1 if len(a) != len(b): - return match_rate_zero(key) - data = [match_rate(a[i], b[i], key=f"{key}[{i}]") for i in range(len(a))] + return match_rate_zero(a, b, key, fn) + data = [match_rate(a[i], b[i], [*key, i], fn) for i in range(len(a))] return sum(data) / len(a) if a == b: return 1 - return match_rate_zero(key) + return match_rate_zero(a, b, key, fn) def save_cache(data): @@ -108,12 +108,44 @@ def save_cache(data): json.dump(data, f, indent=4) +def super_get(obj: dict, key: str): + keys = [ + key, + "".join(["_" + c.lower() if c.isupper() else c for c in key]).lstrip("_"), + ] + + for k in keys: + if obj.get(k) is not None: + return obj[k] + raise KeyError(key) + + def task_callback(file, thread=True): try: with open(file, "r") as f: cache = json.load(f) data = pt.__dict__[cache["type"]].from_json(cache["raw"]) - rate = match_rate(data.to_dict(), json.loads(cache["raw"])) + + def get(obj, key): + if isinstance(obj, list): + return get(obj[key[0]], key[1:]) + if obj.__dict__.get("actual_instance") is not None: + return get(obj.actual_instance, key) + if len(key) == 0: + return obj + return get(super_get(obj.__dict__, key[0]), key[1:]) + + def match_rate_hook(a, b, key): + if STRICT_MODE: + obj_name = type(get(data, key[:-1])) + obj_key = f"{obj_name.__name__}.{key[-1]}" + raise Exception(f"Not defined: {obj_key}\nContents: {b}") + + rate = match_rate( + data.to_dict(), + json.loads(cache["raw"]), + fn=match_rate_hook, + ) return rate, file except Exception: if thread: @@ -151,10 +183,18 @@ if __name__ == "__main__": placeholder = json.load(f) fail = [] - with concurrent.futures.ProcessPoolExecutor() as executor: - tasks = [executor.submit(task_callback, x) for x in glob.glob("cache/*.json")] - for task in concurrent.futures.as_completed(tasks): - rate, file = task.result() + files = glob.glob("cache/*.json") + if MULTI_THREAD: + with concurrent.futures.ProcessPoolExecutor() as executor: + tasks = [executor.submit(task_callback, x) for x in files] + for task in concurrent.futures.as_completed(tasks): + rate, file = task.result() + if rate < 1: + fail.append(file) + logger.info(f"Match rate: {rate}") + else: + for file in files: + rate, file = task_callback(file, thread=False) if rate < 1: fail.append(file) logger.info(f"Match rate: {rate}")