make ArrowWriter thread safe

This commit is contained in:
Kevin Hester
2024-07-11 11:48:53 -07:00
parent 628a4cb9be
commit 7e007e7e24

View File

@@ -1,6 +1,7 @@
"""Utilities for Apache Arrow serialization."""
import logging
import threading
import os
from typing import Optional
@@ -22,13 +23,15 @@ class ArrowWriter:
self.new_rows: list[dict] = []
self.schema: Optional[pa.Schema] = None # haven't yet learned the schema
self.writer: Optional[pa.RecordBatchStreamWriter] = None
self._lock = threading.Condition() # Ensure only one thread writes at a time
def close(self):
    """Flush any pending rows, then close the stream writer and the sink.

    The whole sequence runs exactly once while holding ``self._lock``, so it
    cannot interleave with another thread that is appending rows under the
    same lock.

    Raises:
        Any exception raised by ``self._write()`` or ``self.writer.close()``.
        The sink is still closed before that exception propagates.
    """
    with self._lock:
        try:
            self._write()  # write out rows that have not been flushed yet
            if self.writer:
                # self.writer is still None when no stream writer was created
                self.writer.close()
        finally:
            # Always release the sink, even if the final flush or the
            # writer close failed, so the underlying resource is not leaked.
            self.sink.close()
def set_schema(self, schema: pa.Schema):
"""Set the schema for the file.
@@ -36,9 +39,10 @@ class ArrowWriter:
schema (pa.Schema): The schema to use.
"""
assert self.schema is None
self.schema = schema
self.writer = pa.ipc.new_stream(self.sink, schema)
with self._lock:
assert self.schema is None
self.schema = schema
self.writer = pa.ipc.new_stream(self.sink, schema)
def _write(self):
"""Write the new rows to the file."""
@@ -56,9 +60,10 @@ class ArrowWriter:
"""Add a row to the arrow file.
The schema is learned automatically from the first row; all later rows must use that same schema.
"""
self.new_rows.append(row_dict)
if len(self.new_rows) >= chunk_size:
self._write()
with self._lock:
self.new_rows.append(row_dict)
if len(self.new_rows) >= chunk_size:
self._write()
class FeatherWriter(ArrowWriter):