diff --git a/growthbook/core.py b/growthbook/core.py index 158b244..bd09728 100644 --- a/growthbook/core.py +++ b/growthbook/core.py @@ -50,47 +50,48 @@ def isOperatorObject(obj: Any) -> bool: return False return True -def getType(attributeValue) -> str: - t = type(attributeValue) +def _is_numeric(v: Any) -> bool: + return isinstance(v, (int, float)) and not isinstance(v, bool) +def getType(attributeValue) -> str: if attributeValue is None: return "null" - if t is int or t is float: + if isinstance(attributeValue, bool): + return "boolean" + if _is_numeric(attributeValue): return "number" - if t is str: + if isinstance(attributeValue, str): return "string" - if t is list or t is set: + if isinstance(attributeValue, (list, set)): return "array" - if t is dict: + if isinstance(attributeValue, dict): return "object" - if t is bool: - return "boolean" return "unknown" def getPath(attributes, path): current = attributes for segment in path.split("."): - if type(current) is dict and segment in current: + if isinstance(current, dict) and segment in current: current = current[segment] else: return None return current def evalConditionValue(conditionValue, attributeValue, savedGroups, insensitive: bool = False) -> bool: - if type(conditionValue) is dict and isOperatorObject(conditionValue): + if isinstance(conditionValue, dict) and isOperatorObject(conditionValue): for key, value in conditionValue.items(): if not evalOperatorCondition(key, attributeValue, value, savedGroups): return False return True # Simple equality comparison with optional case-insensitivity - if insensitive and type(conditionValue) is str and type(attributeValue) is str: + if insensitive and isinstance(conditionValue, str) and isinstance(attributeValue, str): return conditionValue.lower() == attributeValue.lower() return bool(conditionValue == attributeValue) def elemMatch(condition, attributeValue, savedGroups) -> bool: - if not type(attributeValue) is list: + if not isinstance(attributeValue, list): return False for item in attributeValue: @@ -104,13 +105,13 @@ def elemMatch(condition, attributeValue, savedGroups) -> bool: return False def compare(val1, val2) -> int: - if (type(val1) is int or type(val1) is float) and not (type(val2) is int or type(val2) is float): + if _is_numeric(val1) and not _is_numeric(val2): if (val2 is None): val2 = 0 else: val2 = float(val2) - if (type(val2) is int or type(val2) is float) and not (type(val1) is int or type(val1) is float): + if _is_numeric(val2) and not _is_numeric(val1): if (val1 is None): val1 = 0 else: @@ -166,13 +167,13 @@ def evalOperatorCondition(operator, attributeValue, conditionValue, savedGroups) elif operator == "$vgte": return paddedVersionString(attributeValue) >= paddedVersionString(conditionValue) elif operator == "$inGroup": - if not type(conditionValue) is str: + if not isinstance(conditionValue, str): return False if not conditionValue in savedGroups: return False return isIn(savedGroups[conditionValue] or [], attributeValue) elif operator == "$notInGroup": - if not type(conditionValue) is str: + if not isinstance(conditionValue, str): return False if not conditionValue in savedGroups: return True @@ -202,33 +203,33 @@ def evalOperatorCondition(operator, attributeValue, conditionValue, savedGroups) except Exception: return False elif operator == "$in": - if not type(conditionValue) is list: + if not isinstance(conditionValue, list): return False return isIn(conditionValue, attributeValue) elif operator == "$nin": - if not type(conditionValue) is list: + if not isinstance(conditionValue, list): return False return not isIn(conditionValue, attributeValue) elif operator == "$ini": - if not type(conditionValue) is list: + if not isinstance(conditionValue, list): return False return isIn(conditionValue, attributeValue, insensitive=True) elif operator == "$nini": - if not type(conditionValue) is list: + if not isinstance(conditionValue, list): return False return not isIn(conditionValue, attributeValue, insensitive=True) elif operator == "$elemMatch": return elemMatch(conditionValue, attributeValue, savedGroups) elif operator == "$size": - if not (type(attributeValue) is list): + if not isinstance(attributeValue, list): return False return evalConditionValue(conditionValue, len(attributeValue), savedGroups) elif operator == "$all": - if not type(conditionValue) is list: + if not isinstance(conditionValue, list): return False return isInAll(conditionValue, attributeValue, savedGroups, insensitive=False) elif operator == "$alli": - if not type(conditionValue) is list: + if not isinstance(conditionValue, list): return False return isInAll(conditionValue, attributeValue, savedGroups, insensitive=True) elif operator == "$exists": @@ -243,10 +244,10 @@ def evalOperatorCondition(operator, attributeValue, conditionValue, savedGroups) def paddedVersionString(input) -> str: # If input is a number, convert to a string - if type(input) is int or type(input) is float: + if _is_numeric(input): input = str(input) - if not input or type(input) is not str: + if not input or not isinstance(input, str): input = "0" # Remove build info and leading `v` if any @@ -268,10 +269,10 @@ def isIn(conditionValue, attributeValue, insensitive: bool = False) -> bool: if insensitive: # Helper function to case-fold values (lowercase for strings) def case_fold(val): - return val.lower() if type(val) is str else val + return val.lower() if isinstance(val, str) else val # Do an intersection if attribute is an array (insensitive) - if type(attributeValue) is list: + if isinstance(attributeValue, list): return any( case_fold(el) == case_fold(exp) for el in attributeValue @@ -280,13 +281,13 @@ def case_fold(val): return any(case_fold(attributeValue) == case_fold(exp) for exp in conditionValue) # Case-sensitive behavior (original) - if type(attributeValue) is list: + if isinstance(attributeValue, list): return bool(set(conditionValue) & set(attributeValue)) return attributeValue in conditionValue def isInAll(conditionValue, attributeValue, savedGroups, insensitive: bool = False) -> bool: """Check if attributeValue (array) contains all elements in conditionValue""" - if not type(attributeValue) is list: + if not isinstance(attributeValue, list): return False for cond in conditionValue: diff --git a/tests/test_dict_subclass.py b/tests/test_dict_subclass.py new file mode 100644 index 0000000..445e7dc --- /dev/null +++ b/tests/test_dict_subclass.py @@ -0,0 +1,33 @@ +import unittest +from growthbook.core import getPath, evalCondition + +class MyDict(dict): + pass + +class TestDictSubclass(unittest.TestCase): + def test_get_path_with_subclass(self): + # Test getPath with a dict subclass + attributes = MyDict({"user": MyDict({"id": "123", "name": "John"})}) + + self.assertEqual(getPath(attributes, "user.id"), "123") + self.assertEqual(getPath(attributes, "user.name"), "John") + self.assertEqual(getPath(attributes, "user.nonexistent"), None) + + def test_eval_condition_with_subclass(self): + # Test evalCondition with a dict subclass + attributes = MyDict({"company": "GrowthBook", "meta": MyDict({"plan": "pro"})}) + + # Simple condition + condition = {"company": "GrowthBook"} + self.assertTrue(evalCondition(attributes, condition)) + + # Nested condition using getPath (indirectly) + condition = {"meta.plan": "pro"} + self.assertTrue(evalCondition(attributes, condition)) + + # Condition failing + condition = {"meta.plan": "free"} + self.assertFalse(evalCondition(attributes, condition)) + +if __name__ == '__main__': + unittest.main()