
Commit aac3a46

Implement an LMDB stream state store
1 parent 710c4e3 commit aac3a46

3 files changed: +754 -0 lines changed


src/amp/streaming/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Streaming module for continuous data loading
 from .iterator import StreamingResultIterator
+from .lmdb_state import LMDBStreamStateStore
 from .parallel import (
     BlockRangePartitionStrategy,
     ParallelConfig,
@@ -35,6 +36,7 @@
     'StreamStateStore',
     'InMemoryStreamStateStore',
     'NullStreamStateStore',
+    'LMDBStreamStateStore',
     'BatchIdentifier',
     'ProcessedBatch',
 ]
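
With the export added above, the store becomes part of the package's public surface. A minimal sketch of pulling it in, assuming the project is installed so that src/amp is importable as amp:

from amp.streaming import LMDBStreamStateStore  # re-exported by this commit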

src/amp/streaming/lmdb_state.py

Lines changed: 375 additions & 0 deletions
@@ -0,0 +1,375 @@
"""
LMDB-based stream state store for durable batch tracking.

This implementation uses LMDB (Lightning Memory-Mapped Database) for fast,
embedded, durable storage of batch processing state. It can be used with any
loader (Kafka, PostgreSQL, etc.) to provide crash recovery and idempotency.
"""

import json
import logging
from pathlib import Path
from typing import Dict, List, Optional

import lmdb

from .state import BatchIdentifier, StreamStateStore
from .types import BlockRange, ResumeWatermark


class LMDBStreamStateStore(StreamStateStore):
    """
    Generic LMDB-based state store for tracking processed batches.

    Uses LMDB for fast, durable key-value storage with ACID transactions.
    Tracks individual batches with unique hash-based IDs to support:
    - Crash recovery and resume
    - Idempotency (duplicate detection)
    - Reorg handling (invalidate by block hash)
    - Gap detection for parallel loading

    Uses two LMDB sub-databases for efficient queries:
    1. "batches" - Individual batch records keyed by batch_id
    2. "metadata" - Max block metadata per network for fast resume position queries

    Batch database layout:
    - Key: {connection_name}|{table_name}|{batch_id}
    - Value: JSON with {network, start_block, end_block, end_hash, start_parent_hash}

    Metadata database layout:
    - Key: {connection_name}|{table_name}|{network}
    - Value: JSON with {end_block, end_hash, start_parent_hash} (max processed block)
    """

    env: lmdb.Environment

    def __init__(
        self,
        connection_name: str,
        data_dir: str = '.amp_state',
        map_size: int = 10 * 1024 * 1024 * 1024,
        sync: bool = True,
    ):
        """
        Initialize LMDB state store with two sub-databases.

        Args:
            connection_name: Name of the connection (for multi-connection support)
            data_dir: Directory to store LMDB database files
            map_size: Maximum database size in bytes (default: 10GB)
            sync: Whether to sync writes to disk (True for durability, False for speed)
        """
        self.connection_name = connection_name
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.logger = logging.getLogger(__name__)

        self.env = lmdb.open(str(self.data_dir), map_size=map_size, sync=sync, max_dbs=2)

        self.batches_db = self.env.open_db(b'batches')
        self.metadata_db = self.env.open_db(b'metadata')

        self.logger.info(f'Initialized LMDB state store at {self.data_dir} with 2 sub-databases')

    def _make_batch_key(self, connection_name: str, table_name: str, batch_id: str) -> bytes:
        """Create composite key for batch database."""
        return f'{connection_name}|{table_name}|{batch_id}'.encode('utf-8')

    def _make_metadata_key(self, connection_name: str, table_name: str, network: str) -> bytes:
        """Create composite key for metadata database."""
        return f'{connection_name}|{table_name}|{network}'.encode('utf-8')

    def _parse_key(self, key: bytes) -> tuple[str, str, str]:
        """Parse composite key into (connection_name, table_name, batch_id/network)."""
        parts = key.decode('utf-8').split('|')
        return parts[0], parts[1], parts[2]

    def _serialize_batch(self, batch: BatchIdentifier) -> bytes:
        """Serialize BatchIdentifier to JSON bytes."""
        batch_value_dict = {
            'network': batch.network,
            'start_block': batch.start_block,
            'end_block': batch.end_block,
            'end_hash': batch.end_hash,
            'start_parent_hash': batch.start_parent_hash,
        }
        return json.dumps(batch_value_dict).encode('utf-8')

    def _serialize_metadata(self, end_block: int, end_hash: str, start_parent_hash: str) -> bytes:
        """Serialize metadata to JSON bytes."""
        meta_value_dict = {
            'end_block': end_block,
            'end_hash': end_hash,
            'start_parent_hash': start_parent_hash,
        }
        return json.dumps(meta_value_dict).encode('utf-8')

    def _deserialize_batch(self, value: bytes) -> Dict:
        """Deserialize batch data from JSON bytes."""
        return json.loads(value.decode('utf-8'))

    def is_processed(self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier]) -> bool:
        """
        Check if all given batches have already been processed.

        Args:
            connection_name: Connection identifier
            table_name: Name of the table being loaded
            batch_ids: List of batch identifiers to check

        Returns:
            True only if ALL batches are already processed
        """
        if not batch_ids:
            return True

        with self.env.begin(db=self.batches_db) as txn:
            for batch_id in batch_ids:
                key = self._make_batch_key(connection_name, table_name, batch_id.unique_id)
                value = txn.get(key)
                if value is None:
                    return False

        return True

    def mark_processed(self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier]) -> None:
        """
        Mark batches as processed in durable storage.

        Atomically updates both batch records and metadata (max block per network).

        Args:
            connection_name: Connection identifier
            table_name: Name of the table being loaded
            batch_ids: List of batch identifiers to mark as processed
        """
        with self.env.begin(write=True) as txn:
            for batch in batch_ids:
                batch_key = self._make_batch_key(connection_name, table_name, batch.unique_id)
                batch_value = self._serialize_batch(batch)
                txn.put(batch_key, batch_value, db=self.batches_db)

                meta_key = self._make_metadata_key(connection_name, table_name, batch.network)
                current_meta = txn.get(meta_key, db=self.metadata_db)

                should_update = False
                if current_meta is None:
                    should_update = True
                else:
                    current_meta_dict = self._deserialize_batch(current_meta)
                    if batch.end_block > current_meta_dict['end_block']:
                        should_update = True

                if should_update:
                    meta_value = self._serialize_metadata(batch.end_block, batch.end_hash, batch.start_parent_hash)
                    txn.put(meta_key, meta_value, db=self.metadata_db)

        self.logger.debug(f'Marked {len(batch_ids)} batches as processed in {table_name}')

    def get_resume_position(
        self, connection_name: str, table_name: str, detect_gaps: bool = False
    ) -> Optional[ResumeWatermark]:
        """
        Get the resume watermark (max processed block per network).

        Reads only from the metadata database. Does not scan batch records.

        Args:
            connection_name: Connection identifier
            table_name: Destination table name
            detect_gaps: If True, detect gaps. Not implemented; raises an error.

        Returns:
            ResumeWatermark with max block ranges for all networks, or None if no state exists

        Raises:
            NotImplementedError: If detect_gaps=True
        """
        if detect_gaps:
            raise NotImplementedError('Gap detection not implemented in LMDB state store')

        prefix = f'{connection_name}|{table_name}|'.encode('utf-8')
        ranges = []

        with self.env.begin(db=self.metadata_db) as txn:
            cursor = txn.cursor()

            if not cursor.set_range(prefix):
                return None

            for key, value in cursor:
                if not key.startswith(prefix):
                    break

                try:
                    _, _, network = self._parse_key(key)
                    meta_data = self._deserialize_batch(value)

                    ranges.append(
                        BlockRange(
                            network=network,
                            start=meta_data['end_block'],
                            end=meta_data['end_block'],
                            hash=meta_data.get('end_hash'),
                            prev_hash=meta_data.get('start_parent_hash'),
                        )
                    )

                except (json.JSONDecodeError, KeyError) as e:
                    self.logger.warning(f'Failed to parse metadata: {e}')
                    continue

        if not ranges:
            return None

        return ResumeWatermark(ranges=ranges)

    def invalidate_from_block(
        self, connection_name: str, table_name: str, network: str, from_block: int
    ) -> List[BatchIdentifier]:
        """
        Invalidate (delete) all batches from a specific block onwards.

        Used for reorg handling to remove invalidated data. Requires a full scan
        of the batches database to find matching batches.

        Args:
            connection_name: Connection identifier
            table_name: Name of the table
            network: Network name
            from_block: Block number to invalidate from (inclusive)

        Returns:
            List of BatchIdentifier objects that were invalidated
        """
        prefix = f'{connection_name}|{table_name}|'.encode('utf-8')
        invalidated_batch_ids = []
        keys_to_delete = []

        with self.env.begin(db=self.batches_db) as txn:
            cursor = txn.cursor()

            if not cursor.set_range(prefix):
                return []

            for key, value in cursor:
                if not key.startswith(prefix):
                    break

                try:
                    batch_data = self._deserialize_batch(value)

                    if batch_data['network'] == network and batch_data['end_block'] >= from_block:
                        batch_id = BatchIdentifier(
                            network=batch_data['network'],
                            start_block=batch_data['start_block'],
                            end_block=batch_data['end_block'],
                            end_hash=batch_data.get('end_hash'),
                            start_parent_hash=batch_data.get('start_parent_hash'),
                        )
                        invalidated_batch_ids.append(batch_id)
                        keys_to_delete.append(key)

                except (json.JSONDecodeError, KeyError) as e:
                    self.logger.warning(f'Failed to parse batch data during invalidation: {e}')
                    continue

        if keys_to_delete:
            with self.env.begin(write=True) as txn:
                for key in keys_to_delete:
                    txn.delete(key, db=self.batches_db)

                meta_key = self._make_metadata_key(connection_name, table_name, network)

                remaining_batches = []
                cursor = txn.cursor(db=self.batches_db)
                if cursor.set_range(prefix):
                    for key, value in cursor:
                        if not key.startswith(prefix):
                            break
                        try:
                            batch_data = self._deserialize_batch(value)
                            if batch_data['network'] == network:
                                remaining_batches.append(batch_data)
                        except (json.JSONDecodeError, KeyError) as e:
                            self.logger.warning(f'Failed to parse batch data during metadata recalculation: {e}')
                            continue

                if remaining_batches:
                    remaining_batches.sort(key=lambda b: b['end_block'])
                    max_batch = remaining_batches[-1]
                    meta_value = self._serialize_metadata(
                        max_batch['end_block'],
                        max_batch.get('end_hash'),
                        max_batch.get('start_parent_hash')
                    )
                    txn.put(meta_key, meta_value, db=self.metadata_db)
                else:
                    txn.delete(meta_key, db=self.metadata_db)

        self.logger.info(
            f'Invalidated {len(invalidated_batch_ids)} batches from block {from_block} '
            f'on {network} in {table_name}'
        )

        return invalidated_batch_ids

    def cleanup_before_block(self, connection_name: str, table_name: str, network: str, before_block: int) -> None:
        """
        Clean up old batch records before a specific block.

        Removes batches where end_block < before_block. Requires a full scan
        to find matching batches for the given network.

        Args:
            connection_name: Connection identifier
            table_name: Name of the table
            network: Network name
            before_block: Block number to clean up before (exclusive)
        """
        prefix = f'{connection_name}|{table_name}|'.encode('utf-8')
        keys_to_delete = []

        with self.env.begin(db=self.batches_db) as txn:
            cursor = txn.cursor()

            if not cursor.set_range(prefix):
                return

            for key, value in cursor:
                if not key.startswith(prefix):
                    break

                try:
                    batch_data = self._deserialize_batch(value)

                    if batch_data['network'] == network and batch_data['end_block'] < before_block:
                        keys_to_delete.append(key)

                except (json.JSONDecodeError, KeyError) as e:
                    self.logger.warning(f'Failed to parse batch data during cleanup: {e}')
                    continue

        if keys_to_delete:
            with self.env.begin(write=True, db=self.batches_db) as txn:
                for key in keys_to_delete:
                    txn.delete(key)

            self.logger.info(
                f'Cleaned up {len(keys_to_delete)} old batches before block {before_block} '
                f'on {network} in {table_name}'
            )

    def close(self) -> None:
        """Close the LMDB environment."""
        if self.env:
            self.env.close()
            self.logger.info('Closed LMDB state store')

    def __enter__(self) -> 'LMDBStreamStateStore':
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit."""
        self.close()
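
For orientation, a minimal usage sketch of the store's lifecycle, not part of the commit: the connection name ('my_conn'), table name ('blocks'), network ('ethereum'), block numbers, and hashes are made-up placeholders, and the top-level import path assumes the package is importable as amp. Only the LMDBStreamStateStore, BatchIdentifier, and ResumeWatermark APIs are taken from the code above.

from amp.streaming import BatchIdentifier, LMDBStreamStateStore

with LMDBStreamStateStore(connection_name='my_conn', data_dir='.amp_state') as store:
    # Placeholder batch; the field names mirror those read back in invalidate_from_block.
    batch = BatchIdentifier(
        network='ethereum',
        start_block=1_000_000,
        end_block=1_000_099,
        end_hash='0xabc',            # placeholder block hash
        start_parent_hash='0xdef',   # placeholder parent hash
    )

    # Idempotency: skip work that has already been recorded for this table.
    if not store.is_processed('my_conn', 'blocks', [batch]):
        # ... load the batch into the destination table, then record it ...
        store.mark_processed('my_conn', 'blocks', [batch])

    # Crash recovery: the metadata sub-database yields one BlockRange per
    # network, positioned at the max processed block for that network.
    watermark = store.get_resume_position('my_conn', 'blocks')
    if watermark is not None:
        for block_range in watermark.ranges:
            print(f'{block_range.network}: resume after block {block_range.end}')

    # Reorg handling: delete every recorded batch on this network whose
    # end_block >= 1_000_050; the metadata entry is recomputed from what remains.
    rolled_back = store.invalidate_from_block('my_conn', 'blocks', 'ethereum', 1_000_050)
    print(f'Invalidated {len(rolled_back)} batches after a reorg')

Because mark_processed writes the batch records and the per-network max-block metadata in a single LMDB write transaction, a crash between loading a batch and marking it leaves no partial state: on restart the batch is simply absent from the batches sub-database, is_processed returns False, and the batch is re-processed.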
