Source code for vastdb.util

import logging
import re
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq

from .errors import InvalidArgument, TooWideRow

log = logging.getLogger(__name__)

if TYPE_CHECKING:
    from .schema import Schema
    from .table import ImportConfig, Table


def create_table_from_files(
        schema: "Schema", table_name: str, parquet_files: List[str],
        schema_merge_func: Optional[Callable] = None,
        config: Optional["ImportConfig"] = None) -> "Table":
    """Create a table whose columns come from Parquet files, then import those files.

    Every Parquet file's schema is read (over S3, using the transaction's credentials
    and endpoint) and folded into one schema with ``schema_merge_func``. A table with
    that schema is created in ``schema`` and all files are imported into it.

    Args:
        schema: the database schema in which the table is created.
        table_name: name of the new table.
        parquet_files: object paths, each of which must start with ``'/'``.
        schema_merge_func: one of ``default_schema_merge`` (the default when falsy),
            ``strict_schema_merge`` or ``union_schema_merge``.
        config: optional import configuration forwarded to ``Table.import_files``.

    Returns:
        The newly created table.

    Raises:
        InvalidArgument: if ``schema_merge_func`` is not one of the supported merge
            functions, if a path does not start with ``'/'``, or if the merge
            function rejects the file schemas.
    """
    if not schema_merge_func:
        schema_merge_func = default_schema_merge
    elif schema_merge_func not in (default_schema_merge, strict_schema_merge, union_schema_merge):
        # An explicit check instead of ``assert``: asserts are removed under ``python -O``.
        raise InvalidArgument(f"Unsupported schema_merge_func: {schema_merge_func!r}")

    # Validate every path before reading anything remotely, so a bad path fails fast.
    for prq_file in parquet_files:
        if not prq_file.startswith('/'):
            raise InvalidArgument(f"Path {prq_file} must start with a '/'")

    tx = schema.tx
    s3fs = pa.fs.S3FileSystem(
        access_key=tx._rpc.api.access_key,
        secret_key=tx._rpc.api.secret_key,
        endpoint_override=tx._rpc.api.url)

    current_schema = pa.schema([])
    for prq_file in parquet_files:
        # The S3 filesystem expects "bucket/key", without the leading slash.
        parquet_ds = pq.ParquetDataset(prq_file.lstrip('/'), filesystem=s3fs)
        current_schema = schema_merge_func(current_schema, parquet_ds.schema)

    log.info("Creating table %s from %d Parquet files, with columns: %s",
             table_name, len(parquet_files), list(current_schema))
    table = schema.create_table(table_name, current_schema)

    log.info("Starting import of %d files to table: %s", len(parquet_files), table)
    table.import_files(parquet_files, config=config)
    log.info("Finished import of %d files to table: %s", len(parquet_files), table)
    return table
def default_schema_merge(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
    """Merge two schemas when the fields of one are contained in the other.

    When ``current_schema`` has no fields (the initial accumulator), ``new_schema``
    is returned as-is. Otherwise both schemas are turned into sets of pyarrow
    ``Field`` objects. The smaller set must be a subset of the larger one, and the
    schema with more fields is returned, so its original field order is kept. On
    a tie, ``new_schema`` is returned.

    Raises:
        InvalidArgument: if the smaller field set is not contained in the larger one.
    """
    if not current_schema.names:
        return new_schema

    current_fields = set(current_schema)
    new_fields = set(new_schema)

    # Keep a reference to the larger *schema* (not just its field set) so the
    # returned schema preserves the original field order.
    if len(current_fields) > len(new_fields):
        smaller, larger, merged = new_fields, current_fields, current_schema
    else:
        smaller, larger, merged = current_fields, new_fields, new_schema

    if smaller.issubset(larger):
        return merged

    log.error("Schema mismatch. schema: %s isn't contained in schema: %s.", smaller, larger)
    raise InvalidArgument("Found mismatch in parquet files schemas.")
def strict_schema_merge(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
    """Require every schema to equal the first one that was accumulated.

    An empty ``current_schema`` (no field names) accepts anything. Otherwise the
    two schemas must compare equal.

    Returns:
        ``new_schema``.

    Raises:
        InvalidArgument: if ``current_schema`` is non-empty and differs from ``new_schema``.
    """
    nothing_accumulated_yet = not current_schema.names
    if nothing_accumulated_yet or current_schema == new_schema:
        return new_schema
    raise InvalidArgument(f"Schemas are not identical. \n {current_schema} \n vs \n {new_schema}")
def union_schema_merge(current_schema: pa.Schema, new_schema: pa.Schema) -> pa.Schema:
    """Return the union of two schemas, as computed by ``pa.unify_schemas``.

    Fields are merged by name. Any incompatibility handling (for example, the same
    name with different types) is delegated to ``pa.unify_schemas``.
    """
    schemas_to_unify = [current_schema, new_schema]
    return pa.unify_schemas(schemas_to_unify)
# Size limit, in bytes, for a tabular request: 5 MiB.
MAX_TABULAR_REQUEST_SIZE = 5 * 1024 * 1024
# Serialized record-batch slices and query-data requests are capped at 90% of
# the request limit (presumably to leave headroom for the rest of the request —
# TODO confirm).
MAX_RECORD_BATCH_SLICE_SIZE = int(MAX_TABULAR_REQUEST_SIZE * 0.9)
MAX_QUERY_DATA_REQUEST_SIZE = int(MAX_TABULAR_REQUEST_SIZE * 0.9)
def iter_serialized_slices(batch: Union[pa.RecordBatch, pa.Table], max_rows_per_slice=None):
    """Yield Arrow IPC-serialized slices of ``batch``, each within the slice size limit.

    The initial rows-per-slice count is estimated from the batch's average row size,
    with a 10% safety margin, and is capped by ``max_rows_per_slice`` when that is
    given. If a serialized slice is larger than ``MAX_RECORD_BATCH_SLICE_SIZE``, the
    count is halved and the same offset is retried. The reduced count is then kept
    for all remaining slices.

    Raises:
        TooWideRow: if the rows-per-slice count is below one while rows remain,
            for example when a single row serializes to more than the limit.
    """
    total_rows = len(batch)
    if batch.nbytes:
        step = int(0.9 * total_rows * MAX_RECORD_BATCH_SLICE_SIZE / batch.nbytes)
    else:
        # The batch has no buffers (no rows/columns), so take it whole.
        step = total_rows
    if max_rows_per_slice is not None:
        step = min(step, max_rows_per_slice)

    start = 0
    while start < total_rows:
        if step < 1:
            raise TooWideRow(batch)
        payload = serialize_record_batch(batch.slice(start, step))
        if len(payload) > MAX_RECORD_BATCH_SLICE_SIZE:
            # Too big: retry the same offset with half as many rows.
            step //= 2
            continue
        yield payload
        start += step
def serialize_record_batch(batch: Union[pa.RecordBatch, pa.Table]):
    """Encode ``batch`` as an Arrow IPC stream and return the resulting buffer.

    A ``pa.Table`` made of more than one record batch is combined first, because
    the server expects a single RecordBatch per request.
    """
    has_many_chunks = isinstance(batch, pa.Table) and len(batch.to_batches()) > 1
    if has_many_chunks:
        batch = batch.combine_chunks()

    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        writer.write(batch)
    return sink.getvalue()
def expand_ip_ranges(endpoints):
    """Expand endpoints whose last IPv4 octet is a range into one endpoint per address.

    ``'http://172.19.101.1-3'`` expands to ``'http://172.19.101.1'``,
    ``'http://172.19.101.2'`` and ``'http://172.19.101.3'``. A port and/or path
    written after the range is copied onto every expanded endpoint, so
    ``'http://10.0.0.1-2:9090'`` becomes ``'http://10.0.0.1:9090'`` and
    ``'http://10.0.0.2:9090'``. Both ``http://`` and ``https://`` are accepted.
    Endpoints that do not use the range syntax are passed through unchanged, and
    the input order is kept.

    Raises:
        ValueError: if a range start is greater than its end, or a range end is
            above 255 (not a valid IPv4 octet).
    """
    # Groups: base URL up to the third octet, range start, range end, and an
    # optional ":port" and/or "/path" suffix. Previously the suffix was dropped.
    pattern = re.compile(r"(https?://\d+\.\d+\.\d+)\.(\d+)-(\d+)((?::\d+)?(?:/\S*)?)")
    expanded_endpoints = []
    for endpoint in endpoints:
        match = pattern.match(endpoint)
        if not match:
            expanded_endpoints.append(endpoint)
            continue
        base_url, start_text, end_text, suffix = match.groups(default="")
        start_ip = int(start_text)
        end_ip = int(end_text)
        if start_ip > end_ip:
            raise ValueError("Start IP cannot be greater than end IP in the range.")
        if end_ip > 255:
            raise ValueError(f"IP range end {end_ip} is not a valid IPv4 octet in {endpoint!r}.")
        expanded_endpoints.extend(f"{base_url}.{ip}{suffix}" for ip in range(start_ip, end_ip + 1))
    return expanded_endpoints
def is_sorted(arr):
    """Return True if ``arr`` is in strictly ascending order.

    Each element is compared with its predecessor using ``pc.greater``, so runs of
    equal values (e.g. ``[1, 1, 2]``) are reported as *not* sorted. A comparison
    involving a null yields null, and ``pc.all`` skips nulls by default.

    Arrays with fewer than two elements are trivially sorted. They are handled
    explicitly because ``pc.all`` returns a null scalar when it sees no non-null
    values (its ``min_count`` defaults to 1). In that case ``as_py()`` would return
    ``None`` instead of a boolean.
    """
    if len(arr) < 2:
        return True
    all_ascending = pc.all(pc.greater(arr[1:], arr[:-1])).as_py()
    # ``as_py()`` is None when every comparison was null. Treat that as
    # "not sorted" (it was already falsy) so callers always get a bool.
    return bool(all_ascending)
def sort_record_batch_if_needed(record_batch, sort_column):
    """Return ``record_batch`` ordered by ``sort_column``.

    The batch is returned unchanged when ``is_sorted`` accepts the column's values.
    Otherwise ``record_batch.sort_by(sort_column)`` is returned.
    """
    if is_sorted(record_batch[sort_column]):
        return record_batch
    return record_batch.sort_by(sort_column)