"""Cache interfaces for GrowthBook feature storage.

Defines the abstract synchronous and asynchronous cache contracts that
custom cache implementations (e.g. Redis-backed caches) must satisfy
before being passed to the SDK via ``Options.cache`` /
``Options.async_cache``.
"""

from abc import ABC, abstractmethod
from typing import Dict, Optional


class AbstractFeatureCache(ABC):
    """Abstract base class for synchronous feature caching implementations."""

    @abstractmethod
    def get(self, key: str) -> Optional[Dict]:
        """
        Retrieve cached features by key.

        Args:
            key: Cache key

        Returns:
            Cached dictionary or None if not found/expired
        """
        pass

    @abstractmethod
    def set(self, key: str, value: Dict, ttl: int) -> None:
        """
        Store features in cache with TTL.

        Args:
            key: Cache key
            value: Features dictionary to cache
            ttl: Time to live in seconds
        """
        pass

    def clear(self) -> None:
        """Clear all cached entries (optional to override)."""
        pass


class AbstractAsyncFeatureCache(ABC):
    """Abstract base class for async feature caching implementations"""

    @abstractmethod
    async def get(self, key: str) -> Optional[Dict]:
        """
        Retrieve cached features by key.

        Args:
            key: Cache key

        Returns:
            Cached dictionary or None if not found/expired
        """
        pass

    @abstractmethod
    async def set(self, key: str, value: Dict, ttl: int) -> None:
        """
        Store features in cache with TTL.

        Args:
            key: Cache key
            value: Features dictionary to cache
            ttl: Time to live in seconds
        """
        pass

    async def clear(self) -> None:
        """Clear all cached entries (optional to override)"""
        pass
class InMemoryAsyncFeatureCache(AbstractAsyncFeatureCache):
    """
    Async in-memory cache implementation.

    Stores the same CacheEntry objects as the sync cache, but guards all
    access with an asyncio lock so concurrent coroutines cannot interleave
    reads and writes.
    """

    def __init__(self) -> None:
        # key -> CacheEntry; protected by self._lock
        self._cache: Dict[str, CacheEntry] = {}
        self._lock = asyncio.Lock()

    async def get(self, key: str) -> Optional[Dict]:
        """Return the cached value for *key*, or None if missing or expired."""
        async with self._lock:
            entry = self._cache.get(key)
            if entry is not None and entry.expires >= time():
                return entry.value
            # Missing key, or entry past its expiry time.
            return None

    async def set(self, key: str, value: Dict, ttl: int) -> None:
        """Cache *value* under *key*.

        A new key gets a fresh CacheEntry(value, ttl); an existing key is
        refreshed via CacheEntry.update(value), matching the sync cache.
        """
        async with self._lock:
            existing = self._cache.get(key)
            if existing is None:
                self._cache[key] = CacheEntry(value, ttl)
            else:
                existing.update(value)

    async def clear(self) -> None:
        """Drop every cached entry."""
        async with self._lock:
            self._cache.clear()
str: async def _init_session(self): url = self._get_sse_url(self.api_host, self.client_key) - + try: while self.is_running: try: - async with aiohttp.ClientSession(headers=self.headers, - timeout=aiohttp.ClientTimeout(connect=self.timeout)) as session: + async with aiohttp.ClientSession(headers=self.headers, + timeout=aiohttp.ClientTimeout(connect=self.timeout)) as session: self._sse_session = session async with session.get(url) as response: @@ -234,7 +253,7 @@ async def _process_response(self, response): if not self.is_running: logger.debug("SSE processing stopped - is_running is False") break - + decoded_line = line.decode('utf-8').strip() if decoded_line.startswith("event:"): event_data['type'] = decoded_line[len("event:"):].strip() @@ -247,7 +266,7 @@ async def _process_response(self, response): except Exception as e: logger.warning(f"Error in event handler: {e}") event_data = {} - + # Process any remaining event data if 'type' in event_data and 'data' in event_data: try: @@ -276,7 +295,7 @@ async def _close_session(self): def _run_sse_channel(self): self._loop = asyncio.new_event_loop() - + try: self._loop.run_until_complete(self._init_session()) except asyncio.CancelledError: @@ -288,7 +307,7 @@ def _run_sse_channel(self): async def _stop_session(self, timeout=10): """Stop the SSE session and cancel all tasks with timeout""" logger.debug("Stopping SSE session") - + # Close the session first if self._sse_session and not self._sse_session.closed: try: @@ -301,15 +320,15 @@ async def _stop_session(self, timeout=10): if self._loop and self._loop.is_running(): try: # Get all tasks for this specific loop - tasks = [task for task in asyncio.all_tasks(self._loop) - if not task.done() and task is not asyncio.current_task(self._loop)] - + tasks = [task for task in asyncio.all_tasks(self._loop) + if not task.done() and task is not asyncio.current_task(self._loop)] + if tasks: logger.debug(f"Cancelling {len(tasks)} SSE tasks") # Cancel all tasks for task in tasks: 
task.cancel() - + # Wait for tasks to complete with timeout try: await asyncio.wait_for( @@ -324,22 +343,20 @@ async def _stop_session(self, timeout=10): except Exception as e: logger.warning(f"Error during SSE task cleanup: {e}") -from collections import OrderedDict - -# ... (imports) class FeatureRepository(object): def __init__(self) -> None: self.cache: AbstractFeatureCache = InMemoryFeatureCache() + self.async_cache: Optional[AbstractAsyncFeatureCache] = None self.http: Optional[PoolManager] = None self.sse_client: Optional[SSEClient] = None self._feature_update_callbacks: List[Callable[[Dict], None]] = [] - + # Background refresh support self._refresh_thread: Optional[threading.Thread] = None self._refresh_stop_event = threading.Event() self._refresh_lock = threading.Lock() - + # ETag cache for bandwidth optimization # Using OrderedDict for LRU cache (max 100 entries) self._etag_cache: OrderedDict[str, Tuple[str, Dict[str, Any]]] = OrderedDict() @@ -349,6 +366,13 @@ def __init__(self) -> None: def set_cache(self, cache: AbstractFeatureCache) -> None: self.cache = cache + def set_async_cache(self, cache: AbstractAsyncFeatureCache) -> None: + """ + Set asynchronous cache implementation. + When set, load_features_async() will use this instead of sync cache. 
+ """ + self.async_cache = cache + def clear_cache(self): self.cache.clear() @@ -379,7 +403,7 @@ def load_features( ) -> Optional[Dict]: if not client_key: raise ValueError("Must specify `client_key` to refresh features") - + key = api_host + "::" + client_key cached = self.cache.get(key) @@ -392,28 +416,40 @@ def load_features( self._notify_feature_update_callbacks(res) return res return cached - - + async def load_features_async( self, api_host: str, client_key: str, decryption_key: str = "", ttl: int = 600 ) -> Optional[Dict]: + if not client_key: + raise ValueError("Must specify `client_key` to refresh features") + key = api_host + "::" + client_key - cached = self.cache.get(key) + # Use async cache if existed, unless fallback to sync + if self.async_cache: + cached = await self.async_cache.get(key) # Async + else: + cached = self.cache.get(key) # Fallback to sync cache + if not cached: res = await self._fetch_features_async(api_host, client_key, decryption_key) if res is not None: - self.cache.set(key, res, ttl) + # save in cache + if self.async_cache: + await self.async_cache.set(key, res, ttl) # Async! 
+ else: + self.cache.set(key, res, ttl) + logger.debug("Fetched features from API, stored in cache") # Notify callbacks about fresh features self._notify_feature_update_callbacks(res) return res return cached - + @property def user_agent_suffix(self) -> Optional[str]: return getattr(self, "_user_agent_suffix", None) - + @user_agent_suffix.setter def user_agent_suffix(self, value: Optional[str]) -> None: self._user_agent_suffix = value @@ -422,23 +458,23 @@ def user_agent_suffix(self, value: Optional[str]) -> None: def _get(self, url: str, headers: Optional[Dict[str, str]] = None): self.http = self.http or PoolManager() return self.http.request("GET", url, headers=headers or {}) - + def _get_headers(self, client_key: str, existing_headers: Dict[str, str] = None) -> Dict[str, str]: headers = existing_headers or {} headers['Accept-Encoding'] = "gzip, deflate" - + # Add User-Agent with optional suffix ua = "Gb-Python" ua += f"-{self.user_agent_suffix}" if self.user_agent_suffix else f"-{client_key[-4:]}" headers['User-Agent'] = ua - + return headers def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: url = self._get_features_url(api_host, client_key) headers = self._get_headers(client_key) logger.debug(f"Fetching features from {url} with headers {headers}") - + # Check if we have a cached ETag for this URL cached_etag = None cached_data = None @@ -451,10 +487,10 @@ def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: logger.debug(f"Using cached ETag for request: {cached_etag[:20]}...") else: logger.debug(f"No ETag cache found for URL: {url}") - + try: r = self._get(url, headers) - + # Handle 304 Not Modified - content hasn't changed if r.status == 304: logger.debug(f"ETag match! 
Server returned 304 Not Modified - using cached data (saved bandwidth)") @@ -464,15 +500,15 @@ def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: else: logger.warning("Received 304 but no cached data available") return None - + if r.status >= 400: logger.warning( "Failed to fetch features, received status code %d", r.status ) return None - + decoded = json.loads(r.data.decode("utf-8")) - + # Store the new ETag if present response_etag = r.headers.get('ETag') if response_etag: @@ -481,7 +517,7 @@ def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: # Enforce max size if len(self._etag_cache) > self._max_etag_entries: self._etag_cache.popitem(last=False) - + if cached_etag: logger.debug(f"ETag updated: {cached_etag[:20]}... -> {response_etag[:20]}...") else: @@ -489,17 +525,17 @@ def _fetch_and_decode(self, api_host: str, client_key: str) -> Optional[Dict]: logger.debug(f"ETag cache now contains {len(self._etag_cache)} entries") else: logger.debug("No ETag header in response") - + return decoded # type: ignore[no-any-return] except Exception as e: logger.error(f"Failed to decode feature JSON from GrowthBook API: {e}") return None - + async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optional[Dict]: url = self._get_features_url(api_host, client_key) headers = self._get_headers(client_key=client_key) logger.debug(f"[Async] Fetching features from {url} with headers {headers}") - + # Check if we have a cached ETag for this URL cached_etag = None cached_data = None @@ -512,26 +548,27 @@ async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optio logger.debug(f"[Async] Using cached ETag for request: {cached_etag[:20]}...") else: logger.debug(f"[Async] No ETag cache found for URL: {url}") - + try: async with aiohttp.ClientSession() as session: async with session.get(url, headers=headers) as response: # Handle 304 Not Modified - content hasn't changed if response.status == 
304: - logger.debug(f"[Async] ETag match! Server returned 304 Not Modified - using cached data (saved bandwidth)") + logger.debug( + f"[Async] ETag match! Server returned 304 Not Modified - using cached data (saved bandwidth)") if cached_data is not None: logger.debug(f"[Async] Returning cached response ({len(str(cached_data))} bytes)") return cached_data else: logger.warning("[Async] Received 304 but no cached data available") return None - + if response.status >= 400: logger.warning("Failed to fetch features, received status code %d", response.status) return None - + decoded = await response.json() - + # Store the new ETag if present response_etag = response.headers.get('ETag') if response_etag: @@ -540,15 +577,16 @@ async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optio # Enforce max size if len(self._etag_cache) > self._max_etag_entries: self._etag_cache.popitem(last=False) - + if cached_etag: logger.debug(f"[Async] ETag updated: {cached_etag[:20]}... -> {response_etag[:20]}...") else: - logger.debug(f"[Async] New ETag cached: {response_etag[:20]}... ({len(str(decoded))} bytes)") + logger.debug( + f"[Async] New ETag cached: {response_etag[:20]}... 
({len(str(decoded))} bytes)") logger.debug(f"[Async] ETag cache now contains {len(self._etag_cache)} entries") else: logger.debug("[Async] No ETag header in response") - + return decoded # type: ignore[no-any-return] except aiohttp.ClientError as e: logger.warning(f"HTTP request failed: {e}") @@ -556,7 +594,7 @@ async def _fetch_and_decode_async(self, api_host: str, client_key: str) -> Optio except Exception as e: logger.error(f"Failed to decode feature JSON from GrowthBook API: {e}") return None - + def decrypt_response(self, data, decryption_key: str): if "encryptedFeatures" in data: if not decryption_key: @@ -572,7 +610,7 @@ def decrypt_response(self, data, decryption_key: str): return None elif "features" not in data: logger.warning("GrowthBook API response missing features") - + if "encryptedSavedGroups" in data: if not decryption_key: raise ValueError("Must specify decryption_key") @@ -585,7 +623,7 @@ def decrypt_response(self, data, decryption_key: str): logger.warning( "Failed to decrypt saved groups from GrowthBook API response" ) - + return data # Fetch features from the GrowthBook API @@ -599,7 +637,7 @@ def _fetch_features( data = self.decrypt_response(decoded, decryption_key) return data # type: ignore[no-any-return] - + async def _fetch_features_async( self, api_host: str, client_key: str, decryption_key: str = "" ) -> Optional[Dict]: @@ -611,11 +649,11 @@ async def _fetch_features_async( return data # type: ignore[no-any-return] - def startAutoRefresh(self, api_host, client_key, cb, streaming_timeout=30): if not client_key: raise ValueError("Must specify `client_key` to start features streaming") - self.sse_client = self.sse_client or SSEClient(api_host=api_host, client_key=client_key, on_event=cb, timeout=streaming_timeout) + self.sse_client = self.sse_client or SSEClient(api_host=api_host, client_key=client_key, on_event=cb, + timeout=streaming_timeout) self.sse_client.connect() def stopAutoRefresh(self, timeout=10): @@ -623,8 +661,9 @@ def 
stopAutoRefresh(self, timeout=10): if self.sse_client: self.sse_client.disconnect(timeout=timeout) self.sse_client = None - - def start_background_refresh(self, api_host: str, client_key: str, decryption_key: str, ttl: int = 600, refresh_interval: int = 300) -> None: + + def start_background_refresh(self, api_host: str, client_key: str, decryption_key: str, ttl: int = 600, + refresh_interval: int = 300) -> None: """Start periodic background refresh task""" if not client_key: @@ -633,7 +672,7 @@ def start_background_refresh(self, api_host: str, client_key: str, decryption_ke with self._refresh_lock: if self._refresh_thread is not None: return # Already running - + self._refresh_stop_event.clear() self._refresh_thread = threading.Thread( target=self._background_refresh_worker, @@ -642,15 +681,16 @@ def start_background_refresh(self, api_host: str, client_key: str, decryption_ke ) self._refresh_thread.start() logger.debug("Started background refresh task") - - def _background_refresh_worker(self, api_host: str, client_key: str, decryption_key: str, ttl: int, refresh_interval: int) -> None: + + def _background_refresh_worker(self, api_host: str, client_key: str, decryption_key: str, ttl: int, + refresh_interval: int) -> None: """Worker method for periodic background refresh""" while not self._refresh_stop_event.is_set(): try: # Wait for the refresh interval or stop event if self._refresh_stop_event.wait(refresh_interval): break # Stop event was set - + logger.debug("Background refresh for Features - started") res = self._fetch_features(api_host, client_key, decryption_key) if res is not None: @@ -663,11 +703,11 @@ def _background_refresh_worker(self, api_host: str, client_key: str, decryption_ logger.warning("Background refresh failed") except Exception as e: logger.warning(f"Background refresh error: {e}") - + def stop_background_refresh(self) -> None: """Stop background refresh task""" self._refresh_stop_event.set() - + with self._refresh_lock: if 
self._refresh_thread is not None: self._refresh_thread.join(timeout=1.0) # Wait up to 1 second @@ -744,7 +784,8 @@ def __init__( self._user = user self._groups = groups self._overrides = overrides - self._forcedVariations = (forced_variations if forced_variations is not None else forcedVariations) if forced_variations is not None or forcedVariations else {} + self._forcedVariations = ( + forced_variations if forced_variations is not None else forcedVariations) if forced_variations is not None or forcedVariations else {} self._tracked: Dict[str, Any] = {} self._assigned: Dict[str, Any] = {} @@ -769,7 +810,7 @@ def __init__( ), features={}, saved_groups=self._saved_groups - ) + ) # Create a user context for the current user self._user_ctx: UserContext = UserContext( url=self._url, @@ -797,7 +838,7 @@ def __init__( # Start background refresh task for stale-while-revalidate self.load_features() # Initial load feature_repo.start_background_refresh( - self._api_host, self._client_key, self._decryption_key, + self._api_host, self._client_key, self._decryption_key, self._cache_ttl, self._stale_ttl ) @@ -837,7 +878,7 @@ def _features_event_handler(self, features): decoded = json.loads(features) if not decoded: return None - + data = feature_repo.decrypt_response(decoded, self._decryption_key) if data is not None: @@ -855,13 +896,12 @@ def _dispatch_sse_event(self, event_data): elif event_type == 'features': self._features_event_handler(data) - def startAutoRefresh(self): if not self._client_key: raise ValueError("Must specify `client_key` to start features streaming") - + feature_repo.startAutoRefresh( - api_host=self._api_host, + api_host=self._api_host, client_key=self._client_key, cb=self._dispatch_sse_event, streaming_timeout=self._streaming_timeout @@ -926,34 +966,34 @@ def get_attributes(self) -> dict: def destroy(self, timeout=10) -> None: """Gracefully destroy the GrowthBook instance""" logger.debug("Starting GrowthBook destroy process") - + try: # Clean up plugins 
logger.debug("Cleaning up plugins") self._cleanup_plugins() except Exception as e: logger.warning(f"Error cleaning up plugins: {e}") - + try: logger.debug("Stopping auto refresh during destroy") self.stopAutoRefresh(timeout=timeout) except Exception as e: logger.warning(f"Error stopping auto refresh during destroy: {e}") - + try: # Stop background refresh operations if self._stale_while_revalidate and self._client_key: feature_repo.stop_background_refresh() except Exception as e: logger.warning(f"Error stopping background refresh during destroy: {e}") - + try: # Clean up feature update callback if self._client_key: feature_repo.remove_feature_update_callback(self._on_feature_update) except Exception as e: logger.warning(f"Error removing feature update callback: {e}") - + # Clear all internal state try: self._subscriptions.clear() @@ -995,14 +1035,14 @@ def get_feature_value(self, key: str, fallback): def evalFeature(self, key: str) -> FeatureResult: warnings.warn("evalFeature is deprecated, use eval_feature instead", DeprecationWarning) return self.eval_feature(key) - + def _ensure_fresh_features(self) -> None: """Lazy refresh: Check cache expiry and refresh if needed, but only if client_key is provided""" - + # Prevent infinite recursion when updating features (e.g., during sticky bucket refresh) if self._is_updating_features: return - + if self._streaming or self._stale_while_revalidate or not self._client_key: return # Skip cache checks - SSE or background refresh handles freshness @@ -1014,7 +1054,7 @@ def _ensure_fresh_features(self) -> None: def _get_eval_context(self) -> EvaluationContext: # Lazy refresh: ensure features are fresh before evaluation self._ensure_fresh_features() - + # use the latest attributes for every evaluation. 
self._user_ctx.attributes = self._attributes self._user_ctx.url = self._url @@ -1028,8 +1068,8 @@ def _get_eval_context(self) -> EvaluationContext: ) def eval_feature(self, key: str) -> FeatureResult: - result = core_eval_feature(key=key, - evalContext=self._get_eval_context(), + result = core_eval_feature(key=key, + evalContext=self._get_eval_context(), callback_subscription=self._fireSubscriptions, tracking_cb=self._track ) @@ -1068,7 +1108,7 @@ def _fireSubscriptions(self, experiment: Experiment, result: Result): def run(self, experiment: Experiment) -> Result: # result = self._run(experiment) - result = run_experiment(experiment=experiment, + result = run_experiment(experiment=experiment, evalContext=self._get_eval_context(), tracking_cb=self._track ) @@ -1157,7 +1197,7 @@ def _initialize_plugins(self) -> None: def user_agent_suffix(self) -> Optional[str]: """Get the suffix appended to the User-Agent header""" return feature_repo.user_agent_suffix - + @user_agent_suffix.setter def user_agent_suffix(self, value: Optional[str]) -> None: """Set a suffix to be appended to the User-Agent header""" diff --git a/growthbook/growthbook_client.py b/growthbook/growthbook_client.py index 70cf6c3..316bf6b 100644 --- a/growthbook/growthbook_client.py +++ b/growthbook/growthbook_client.py @@ -43,9 +43,9 @@ def __call__(cls, *args, **kwargs): class BackoffStrategy: """Exponential backoff with jitter for failed requests""" def __init__( - self, - initial_delay: float = 1.0, - max_delay: float = 60.0, + self, + initial_delay: float = 1.0, + max_delay: float = 60.0, multiplier: float = 2.0, jitter: float = 0.1 ): @@ -59,7 +59,7 @@ def __init__( def next_delay(self) -> float: """Calculate next delay with jitter""" delay = min( - self.current_delay * (self.multiplier ** self.attempt), + self.current_delay * (self.multiplier ** self.attempt), self.max_delay ) # Add random jitter @@ -252,7 +252,7 @@ async def refresh_loop() -> None: async def start_feature_refresh(self, strategy: 
FeatureRefreshStrategy, callback=None): """Initialize feature refresh based on strategy""" self._refresh_callback = callback - + if strategy == FeatureRefreshStrategy.SERVER_SENT_EVENTS: await self._start_sse_refresh() else: @@ -281,7 +281,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.stop_refresh() - + async def load_features_async( self, api_host: str, client_key: str, decryption_key: str = "", ttl: int = 60 ) -> Optional[Dict]: @@ -295,17 +295,17 @@ class GrowthBookClient: def __init__( self, options: Optional[Union[Dict[str, Any], Options]] = None - ): + ): self.options = ( options if isinstance(options, Options) else Options(**options) if options else Options() ) - + # Thread-safe tracking state self._tracked: Dict[str, bool] = {} # Access only within async context self._tracked_lock = threading.Lock() - + # Thread-safe subscription management self._subscriptions: Set[Callable[[Experiment, Result], None]] = set() self._subscriptions_lock = threading.Lock() @@ -316,25 +316,37 @@ def __init__( 'assignments': {} } self._sticky_bucket_cache_lock = False - + # Plugin support self._tracking_plugins: List[Any] = self.options.tracking_plugins or [] self._initialized_plugins: List[Any] = [] - + self._features_repository = ( EnhancedFeatureRepository( - self.options.api_host or "https://cdn.growthbook.io", - self.options.client_key or "", - self.options.decryption_key or "", + self.options.api_host or "https://cdn.growthbook.io", + self.options.client_key or "", + self.options.decryption_key or "", self.options.cache_ttl ) if self.options.client_key else None ) - + + # Check if repo was initialized + if self._features_repository is not None: + # 1. set sync cache + if self.options.cache is not None: + self._features_repository.set_cache(self.options.cache) + logger.debug("Custom sync cache set for FeatureRepository.") + + # 2. 
set async cache + if self.options.async_cache is not None: + self._features_repository.set_async_cache(self.options.async_cache) + logger.debug("Custom async cache set for FeatureRepository.") + self._global_context: Optional[GlobalContext] = None self._context_lock = asyncio.Lock() - + # Initialize plugins self._initialize_plugins() @@ -383,8 +395,8 @@ def _fire_subscriptions(self, experiment: Experiment, result: Result) -> None: async def set_features(self, features: dict) -> None: await self._feature_update_callback({"features": features}) - - + + async def _refresh_sticky_buckets(self, attributes: Dict[str, Any]) -> Dict[str, Any]: """Refresh sticky bucket assignments only if attributes have changed""" if not self.options.sticky_bucket_service: @@ -394,7 +406,7 @@ async def _refresh_sticky_buckets(self, attributes: Dict[str, Any]) -> Dict[str, while not self._sticky_bucket_cache_lock: if attributes == self._sticky_bucket_cache['attributes']: return self._sticky_bucket_cache['assignments'] - + self._sticky_bucket_cache_lock = True try: assignments = self.options.sticky_bucket_service.get_all_assignments(attributes) @@ -403,7 +415,7 @@ async def _refresh_sticky_buckets(self, attributes: Dict[str, Any]) -> Dict[str, return assignments finally: self._sticky_bucket_cache_lock = False - + # Fallback return for edge case where loop condition is never satisfied return {} @@ -416,9 +428,9 @@ async def initialize(self) -> bool: try: # Initial feature load initial_features = await self._features_repository.load_features_async( - self.options.api_host or "https://cdn.growthbook.io", - self.options.client_key or "", - self.options.decryption_key or "", + self.options.api_host or "https://cdn.growthbook.io", + self.options.client_key or "", + self.options.decryption_key or "", self.options.cache_ttl ) if not initial_features: @@ -427,15 +439,15 @@ async def initialize(self) -> bool: # Create global context with initial features await 
self._feature_update_callback(initial_features) - + # Set up callback for future updates self._features_repository.add_callback(self._feature_update_callback) - + # Start feature refresh refresh_strategy = self.options.refresh_strategy or FeatureRefreshStrategy.STALE_WHILE_REVALIDATE await self._features_repository.start_feature_refresh(refresh_strategy) return True - + except Exception as e: logger.error(f"Initialization failed: {str(e)}", exc_info=True) traceback.print_exc() @@ -482,10 +494,10 @@ async def create_evaluation_context(self, user_context: UserContext) -> Evaluati """Create evaluation context for feature evaluation""" if self._global_context is None: raise RuntimeError("GrowthBook client not properly initialized") - + # Get sticky bucket assignments if needed sticky_assignments = await self._refresh_sticky_buckets(user_context.attributes) - + # update user context with sticky bucket assignments user_context.sticky_bucket_assignment_docs = sticky_assignments @@ -520,7 +532,7 @@ async def is_on(self, key: str, user_context: UserContext) -> bool: except Exception: logger.exception("Error in feature usage callback") return result.on - + async def is_off(self, key: str, user_context: UserContext) -> bool: """Check if a feature is set to off with proper async context management""" async with self._context_lock: @@ -533,7 +545,7 @@ async def is_off(self, key: str, user_context: UserContext) -> bool: except Exception: logger.exception("Error in feature usage callback") return result.off - + async def get_feature_value(self, key: str, fallback: Any, user_context: UserContext) -> Any: async with self._context_lock: context = await self.create_evaluation_context(user_context) @@ -551,14 +563,14 @@ async def run(self, experiment: Experiment, user_context: UserContext) -> Result async with self._context_lock: context = await self.create_evaluation_context(user_context) result = run_experiment( - experiment=experiment, + experiment=experiment, evalContext=context, 
tracking_cb=self._track ) # Fire subscriptions synchronously self._fire_subscriptions(experiment, result) return result - + async def close(self) -> None: """Clean shutdown with proper cleanup""" if self._features_repository: @@ -572,7 +584,7 @@ async def close(self) -> None: # Clear context async with self._context_lock: self._global_context = None - + # Cleanup plugins self._cleanup_plugins() @@ -614,4 +626,4 @@ def _cleanup_plugins(self) -> None: logger.debug(f"Cleaned up plugin: {plugin.__class__.__name__}") except Exception as e: logger.error(f"Error cleaning up plugin {plugin}: {e}") - self._initialized_plugins.clear() \ No newline at end of file + self._initialized_plugins.clear()