前言
在使用並研究Kafka的時候,發現它在Producer的階段,會對publish進行優化:並不是一次一個message publish到kafka broke,而是一個batch、一個batch來傳入。這樣的方式可以優化I/O次數、以及處理效率。
因此,本文想簡化並抽象出這樣的batch處理方式。
入參
查看 python kafka SDK,可以看到關於batch的主要是兩個入參:batch_size
和 linger_ms
,這兩個參數為指定batch要在什麼條件下執行:
- 當batch size大於某個數的時候
- 當publish到一段時間後
因此,先初始化一個 Python BatchProcessor 物件:
class BatchProcessor:
def __init__(self, batch_size=0, linger_ms=0):
self.batch_size = batch_size
self.linger_ms = linger_ms
初始化過程
再來,需要一個queue來當作 task/message 的buffer,同時必須保證 thread safty,不然在執行 push/pop 的時候可能會造成race condition
class BatchProcessor:
def __init__(self, batch_size=0, linger_ms=0):
self.batch_size = batch_size
self.linger_ms = linger_ms
self.queue = asyncio.Queue() # ensure thread safty
然後需要一個flag來控制 BatchProcessor 的狀態:是要發布還是不要發布
class BatchProcessor:
def __init__(self, batch_size=0, linger_ms=0):
# ...
self._batch_event_flag = asyncio.Event()
然後就是啟動監聽 執行時間有沒有超過 self.linger_ms、以及buffer queue有沒有超過 self.batch_size
監聽執行時間
這個部分就蠻簡單的,直接在 while 循環中暫停 self.linger_ms
的時間,下一步就執行 self._batch_event_flag.set()
,以用來啟動batch process
async def _batch_timer(self):
while True:
await asyncio.sleep(self.linger_ms)
self._batch_event_flag.set()
監聽batch size
由於在 execute的時候會對queue 產生變化,所以就直接在 execute 函數中檢查 queue.qsize():
async def execute(self, task):
await self.queue.put(task)
if self.queue.qsize() >= self.batch_size:
self._batch_event.set()
執行batch
主要就是在 while循環中監聽self._batch_event_flag
,然後把 queue buffer裡面的task/message 給灌到 batch 中,進而執行self._process_batch(batch)
async def _batch_processor(self):
while True:
await self._batch_event_flag.wait()
batch = []
while not self.queue.empty() and len(batch) < self.batch_size:
batch.append(await self.queue.get())
if batch:
self._process_batch(batch)
# 重置事件
self. _batch_event_flag.clear()
而這個 self._process_batch(batch)
,可以看業務需求,要如何處理task/message,來決定其中的邏輯、或者要不要是異步函數。
把監聽函數執行在 BatchProcessor 初始化時
把兩個監聽函數self._batch_processor
和 self._batch_timer
利用 asyncio.create_task
來執行在BatchProcessor 初始化的時候
全部程式碼
import asyncio
class BatchProcessor:
def __init__(self, batch_size=0, linger_ms=0):
self.queue = asyncio.Queue() # ensure thread safty
self.batch_size = batch_size
self.linger_ms = linger_ms
self._batch_event_flag = asyncio.Event()
asyncio.create_task(self._batch_processor())
asyncio.create_task(self._batch_timer())
async def execute(self, task):
await self.queue.put(task)
if self.queue.qsize() >= self.batch_size:
self._batch_event_flag.set()
async def _batch_timer(self):
while True:
await asyncio.sleep(self.linger_ms)
self._batch_event_flag.set()
async def _batch_processor(self):
while True:
await self._batch_event_flag.wait()
batch = []
while not self.queue.empty() and len(batch) < self.batch_size:
batch.append(await self.queue.get())
if batch:
self._process_batch(batch)
# reset
self._batch_event_flag.clear()
def _process_batch(self, batch):
print(f"data:{batch}")
async def main():
processor = BatchProcessor(batch_size=2,linger_ms=5)
for i in range(100):
await asyncio.sleep(1)
await processor.execute(f"task {i}")
if __name__ == "__main__":
asyncio.run(main())
ChangeLog
- 20240706–初稿