"""VAST Database table."""
import concurrent.futures
import logging
import os
import queue
from dataclasses import dataclass, field
from math import ceil
from threading import Event
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import ibis
import pyarrow as pa
import urllib3
from . import _internal, errors, schema, util
from .config import ImportConfig, QueryConfig
log = logging.getLogger(__name__)
# Name of the special column that carries the row IDs.
INTERNAL_ROW_ID = "$row_id"
# Arrow field describing the row ID column of regular (unsorted) tables.
INTERNAL_ROW_ID_FIELD = pa.field(INTERNAL_ROW_ID, pa.uint64())
INTERNAL_ROW_ID_SORTED_FIELD = pa.field(INTERNAL_ROW_ID, pa.decimal128(38, 0))  # Sorted tables have longer row IDs
# Maximal number of rows serialized into a single update/delete request.
MAX_ROWS_PER_BATCH = 512 * 1024
# For insert we need a smaller limit due to response amplification:
# e.g. an insert of 512k uint8 values results in a 512k * 8 bytes response, since row IDs are uint64.
# NOTE(review): the value below is currently equal to MAX_ROWS_PER_BATCH, so no smaller limit is actually applied.
MAX_INSERT_ROWS_PER_PATCH = 512 * 1024
# In case an insert fails with TooWideRow, we need to insert in smaller column batches:
# each cell could contain up to 128K, and our wire is limited to 5MB.
MAX_COLUMN_IN_BATCH = int(5 * 1024 / 128)
# The low SORTING_SCORE_BITS bits of the raw sorting score hold the score itself,
# the bits above them hold the "sorting done" indicator.
SORTING_SCORE_BITS = 63
[docs]
@dataclass
class TableStats:
    """Statistics and status flags of a table (also used for projections)."""

    # Mandatory counters.
    num_rows: int
    size_in_bytes: int
    sorting_score: int
    write_amplification: int
    # (sic) the misspelled name is part of the public interface and must stay as-is.
    acummulative_row_inserition_count: int

    # Optional flags, all disabled unless specified.
    is_external_rowid_alloc: bool = False
    sorting_key_enabled: bool = False
    sorting_done: bool = False

    # Endpoints reported together with the statistics (empty if not specified).
    endpoints: Tuple[str, ...] = ()
[docs]
class SelectSplitState:
    """Pagination state of a single split of a running query."""

    def __init__(self, query_data_request, table: "Table", split_id: int, config: QueryConfig) -> None:
        """Remember the request details and start every sub-split from row ID 0."""
        self.query_data_request = query_data_request
        self.table = table
        self.split_id = split_id
        self.config = config
        # maps each sub-split index to the row ID from which its next page should start
        self.subsplits_state = dict.fromkeys(range(config.num_sub_splits), 0)

    def process_split(self, api: _internal.VastdbApi, record_batches_queue: queue.Queue[pa.RecordBatch], check_stop: Callable):
        """Run QueryData requests until this split is exhausted, queueing every non-empty RecordBatch.

        The pagination offsets are kept in `self.subsplits_state`, so this method can be called again
        in order to resume the query after a disconnection / retriable error.

        Raises `errors.ConnectionError` (marked as retriable) when the response stream breaks with
        a `urllib3.exceptions.ProtocolError`. Whatever `check_stop()` raises is propagated as-is.
        """
        request = self.query_data_request
        table = self.table
        config = self.config
        try:
            while not self.done:
                # raises if request parsing fails or throttled due to server load, and will be externally retried
                response = api.query_data(
                    bucket=table.bucket.name,
                    schema=table.schema.name,
                    table=table.name,
                    params=request.serialized,
                    split=(self.split_id, config.num_splits, config.num_row_groups_per_sub_split),
                    num_sub_splits=config.num_sub_splits,
                    response_row_id=False,
                    txid=table.tx.txid,
                    limit_rows=config.limit_rows_per_sub_split,
                    sub_split_start_row_ids=self.subsplits_state.items(),
                    schedule_id=config.queue_priority,
                    enable_sorted_projections=config.use_semi_sorted_projections,
                    query_imports_table=table._imports_table,
                    projection=config.semi_sorted_projection_name)

                # parsing can raise (e.g. due to disconnections), and will be externally retried;
                # the pagination state must therefore be correct at any point of the iteration below
                parsed_chunks = _internal.parse_query_data_response(
                    conn=response.raw,
                    schema=request.response_schema,
                    parser=request.response_parser)

                for stream_id, next_row_id, table_chunk in parsed_chunks:
                    # a chunk was parsed successfully - advance the offset of its sub-split
                    self.subsplits_state[stream_id] = next_row_id
                    # if the loop below fails, the query is not retried
                    for record_batch in table_chunk.to_batches():
                        check_stop()  # may raise in order to early-exit the query (without retries)
                        if not record_batch:
                            continue
                        record_batches_queue.put(record_batch)
        except urllib3.exceptions.ProtocolError as err:
            log.warning("Failed parsing QueryData response table=%r split=%s/%s offsets=%s cause=%s",
                        self.table, self.split_id, self.config.num_splits, self.subsplits_state, err)
            # since this is a read-only idempotent operation, it is safe to retry
            raise errors.ConnectionError(cause=err, may_retry=True)

    @property
    def done(self):
        """Return True iff every sub-split has reached the invalid (end-of-data) row ID."""
        offsets = self.subsplits_state.values()
        return all(offset == _internal.TABULAR_INVALID_ROW_ID for offset in offsets)
[docs]
@dataclass
class Table:
    """VAST Table.

    Constructing an instance also loads the columns' metadata (see `__post_init__`).
    """

    name: str
    schema: "schema.Schema"
    handle: int
    # Populated by `__post_init__` (not an `__init__` argument), refreshed by `columns()`.
    arrow_schema: pa.Schema = field(init=False, compare=False, repr=False)
    # Populated by `__post_init__` with the result of `ibis.table(...)`.
    # NOTE(review): the `ibis.Schema` annotation does not look like the type of that value - confirm.
    _ibis_table: ibis.Schema = field(init=False, compare=False, repr=False)
    # True when this object represents the imports table of `name`.
    _imports_table: bool
    # True for sorted tables (their row IDs use INTERNAL_ROW_ID_SORTED_FIELD).
    sorted_table: bool
def __post_init__(self):
    """Load the columns' metadata and derive the table path and the ibis table from it."""
    self.arrow_schema = self.columns()
    bucket_name = self.schema.bucket.name
    self._table_path = f'{bucket_name}/{self.schema.name}/{self.name}'
    ibis_schema = ibis.Schema.from_pyarrow(self.arrow_schema)
    self._ibis_table = ibis.table(ibis_schema, self._table_path)
@property
def path(self):
    """Return the `<bucket>/<schema>/<table>` path computed when the object was created."""
    table_path = self._table_path
    return table_path
@property
def tx(self):
    """Return the transaction of the schema this table belongs to."""
    owner_schema = self.schema
    return owner_schema.tx
@property
def bucket(self):
    """Return the bucket of the schema this table belongs to."""
    owner_schema = self.schema
    return owner_schema.bucket
@property
def stats(self):
    """Fetch the table's statistics from the server (every access calls `get_stats()`)."""
    current_stats = self.get_stats()
    return current_stats
[docs]
def columns(self) -> pa.Schema:
    """List the table's columns page by page, store them in `self.arrow_schema` and return it."""
    collected = []
    page_key = 0
    truncated = True
    while truncated:
        # a truncated listing is continued from the key returned with the previous page
        page, page_key, truncated, _count = self.tx._rpc.api.list_columns(
            bucket=self.bucket.name, schema=self.schema.name, table=self.name, next_key=page_key,
            txid=self.tx.txid, list_imports_table=self._imports_table)
        collected.extend(page)

    self.arrow_schema = pa.schema(collected)
    return self.arrow_schema
[docs]
def sorted_columns(self) -> list:
    """Return the sorted columns' metadata.

    This is a best-effort listing: BadRequest, InternalServerError and NotSupportedVersion errors
    are not propagated, and whatever was collected until then (possibly nothing) is returned.
    """
    collected = []
    try:
        self.tx._rpc.features.check_elysium()
        page_key = 0
        truncated = True
        while truncated:
            page, page_key, truncated, _count = self.tx._rpc.api.list_sorted_columns(
                bucket=self.bucket.name, schema=self.schema.name, table=self.name, next_key=page_key,
                txid=self.tx.txid, list_imports_table=self._imports_table)
            collected.extend(page)
    except errors.BadRequest:
        # deliberately ignored, the listing is best-effort
        pass
    except errors.InternalServerError as ise:
        log.warning("Failed to get the sorted columns Elysium might not be supported: %s", ise)
    except errors.NotSupportedVersion:
        log.warning("Failed to get the sorted columns, Elysium not supported")
    return collected
[docs]
def projection(self, name: str) -> "Projection":
    """Return the semi-sorted projection of this table called `name`.

    Raises `errors.NotSupportedCommand` for an imports table, and `errors.MissingProjection`
    when no such projection is found.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)

    projs = tuple(self.projections(projection_name=name))
    if len(projs) == 0:
        raise errors.MissingProjection(self.bucket.name, self.schema.name, self.name, name)

    # an exact-match listing is expected to yield a single result
    assert len(projs) == 1, f"Expected to receive only a single projection, but got: {len(projs)}. projections: {projs}"
    (found,) = projs
    log.debug("Found projection: %s", found)
    return found
[docs]
def projections(self, projection_name: str = "") -> Iterable["Projection"]:
    """List the semi-sorted projections of this table.

    All projections are listed when `projection_name` is empty, otherwise only
    the projection having exactly this name (if it exists).
    Raises `errors.NotSupportedCommand` for an imports table.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)

    name_prefix = projection_name or ""
    exact_match = bool(projection_name)

    infos = []
    page_key = 0
    while True:
        (_bucket_name, _schema_name, _table_name, page, page_key, truncated, _) = \
            self.tx._rpc.api.list_projections(
                bucket=self.bucket.name, schema=self.schema.name, table=self.name, next_key=page_key,
                txid=self.tx.txid, exact_match=exact_match, name_prefix=name_prefix)
        # an empty page ends the listing, even if it is marked as truncated
        if not page:
            break
        infos.extend(page)
        if not truncated:
            break

    return [_parse_projection_info(info, self) for info in infos]
[docs]
def import_files(self, files_to_import: Iterable[str], config: Optional[ImportConfig] = None) -> None:
    """Import a list of Parquet files into this table.

    The files must be on VAST S3 server and be accessible using current credentials.
    Each file is specified as a `/<bucket>/<object path>` string.
    Raises `errors.NotSupportedCommand` for an imports table.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)

    # files imported this way carry no partition values, hence the empty payload
    source_files = {_parse_bucket_and_object_names(file_path): b'' for file_path in files_to_import}
    self._execute_import(source_files, config=config)
[docs]
def import_partitioned_files(self, files_and_partitions: Dict[str, pa.RecordBatch], config: Optional[ImportConfig] = None) -> None:
    """Import a list of Parquet files into this table.

    The files must be on VAST S3 server and be accessible using current credentials.
    Each file must have its own partition values defined as an Arrow RecordBatch:
    `files_and_partitions` maps every `/<bucket>/<object path>` string to its partition values.
    Raises `errors.NotSupportedCommand` for an imports table.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)

    source_files = {}
    for f, record_batch in files_and_partitions.items():
        bucket_name, object_path = _parse_bucket_and_object_names(f)
        serialized_batch = _serialize_record_batch(record_batch)
        # add an entry per file - rebinding the whole dict here would import only the last file
        source_files[(bucket_name, object_path)] = serialized_batch.to_pybytes()

    self._execute_import(source_files, config=config)
def _execute_import(self, source_files, config):
    """Import `source_files` into this table using concurrent worker threads.

    `source_files` maps `(bucket_name, object_path)` keys to the serialized partition values of the file
    (`b''` when there are none). `config` is an `ImportConfig`; the defaults are used when it is falsy.

    A failure of any worker signals the other workers to stop, and is re-raised to the caller.
    """
    config = config or ImportConfig()
    assert config.import_concurrency > 0  # TODO: Do we want to validate concurrency isn't too high?
    max_batch_size = 10  # Enforced in server side.
    endpoints = [self.tx._rpc.api.url for _ in range(config.import_concurrency)]  # TODO: use valid endpoints...
    files_queue = queue.Queue()

    key_names = config.key_names or []
    if key_names:
        self.tx._rpc.features.check_zip_import()

    # every queue item is a `((bucket_name, object_path), payload)` pair
    for source_file in source_files.items():
        files_queue.put(source_file)

    stop_event = Event()
    # spread the files evenly between the workers, without exceeding the maximal batch size
    num_files_in_batch = min(ceil(len(source_files) / len(endpoints)), max_batch_size)

    def import_worker(q, endpoint):
        # Drain the queue in batches of up to `num_files_in_batch` files, importing each batch via `endpoint`.
        try:
            with self.tx._rpc.api.with_endpoint(endpoint) as session:
                while not q.empty():
                    if stop_event.is_set():
                        log.debug("stop_event is set, exiting")
                        break
                    files_batch = {}
                    try:
                        for _ in range(num_files_in_batch):
                            # the queued pair is wrapped in a one-element set, which `dict.update()`
                            # consumes as an iterable of key/value pairs
                            files_batch.update({q.get(block=False)})
                    except queue.Empty:
                        # another worker took the remaining files - import whatever was collected
                        pass
                    if files_batch:
                        log.info("Starting import batch of %s files", len(files_batch))
                        log.debug(f"starting import of {files_batch}")
                        session.import_data(
                            self.bucket.name, self.schema.name, self.name, files_batch, txid=self.tx.txid,
                            key_names=key_names)
        except (Exception, KeyboardInterrupt) as e:
            # make the other workers exit before starting their next batch
            stop_event.set()
            log.error("Got exception inside import_worker. exception: %s", e)
            raise

    futures = []
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=config.import_concurrency, thread_name_prefix='import_thread') as pool:
        try:
            for endpoint in endpoints:
                futures.append(pool.submit(import_worker, files_queue, endpoint))

            log.debug("Waiting for import workers to finish")
            for future in concurrent.futures.as_completed(futures):
                # re-raises the exception of a failed worker (if any)
                future.result()
        finally:
            stop_event.set()
            # ThreadPoolExecutor will be joined at the end of the context
[docs]
def get_stats(self) -> TableStats:
    """Fetch the statistics of this table and return them as a `TableStats` object."""
    raw_stats = self.tx._rpc.api.get_table_stats(
        bucket=self.bucket.name,
        schema=self.schema.name,
        name=self.name,
        txid=self.tx.txid,
        imports_table_stats=self._imports_table)
    # the returned named tuple is expected to have exactly the fields of `TableStats`
    stats_fields = raw_stats._asdict()
    return TableStats(**stats_fields)
def _get_row_estimate(self, columns: List[str], predicate: ibis.expr.types.BooleanColumn, arrow_schema: pa.Schema):
    """Return an estimated number of rows matching `predicate`.

    The estimate is the number of rows of the first batch of a dedicated QueryData request
    multiplied by 2**16, or 0 when no batch is returned.
    """
    request = _internal.build_query_data_request(
        schema=arrow_schema,
        predicate=predicate,
        field_names=columns)
    response = self.tx._rpc.api.query_data(
        bucket=self.bucket.name,
        schema=self.schema.name,
        table=self.name,
        params=request.serialized,
        split=(0xffffffff - 3, 1, 1),  # NOTE(review): presumably a reserved split ID for estimations - confirm
        txid=self.tx.txid)
    first_batch = _internal.read_first_batch(response.raw)
    if first_batch is None:
        return 0
    return first_batch.num_rows * 2**16
[docs]
def select(self, columns: Optional[List[str]] = None,
           predicate: Union[ibis.expr.types.BooleanColumn, ibis.common.deferred.Deferred] = None,
           config: Optional[QueryConfig] = None,
           *,
           internal_row_id: bool = False) -> pa.RecordBatchReader:
    """Execute a query over this table.

    To read a subset of the columns, specify their names via `columns` argument. Otherwise, all columns will be read.
    The passed `columns` list is not modified.

    In order to apply a filter, a predicate can be specified. See https://github.com/vast-data/vastdb_sdk/blob/main/README.md#filters-and-projections for more details.
    `True` means no filtering, and `False` results in an empty reader (without querying the server).

    Query-execution configuration options can be specified via the optional `config` argument.
    NOTE: when `config.num_splits` is None, the computed value is stored on the passed `config` object.

    If `internal_row_id` is set, the internal row ID column is requested as well.

    Raises `errors.TooLargeRequest` if the serialized request exceeds the allowed size.
    Worker failures are raised while the returned reader is consumed.
    """
    if config is None:
        config = QueryConfig()

    stats = None
    # Retrieve snapshots only if needed
    if config.data_endpoints is None:
        stats = self.get_stats()
        log.debug("stats: %s", stats)
        endpoints = stats.endpoints
    else:
        endpoints = tuple(config.data_endpoints)
    log.debug("endpoints: %s", endpoints)

    if columns is None:
        columns = [f.name for f in self.arrow_schema]
    else:
        # work on a private copy: the row ID column may be appended below,
        # and the caller's list must not be modified by the query
        columns = list(columns)

    query_schema = self.arrow_schema
    if internal_row_id:
        queried_fields = [INTERNAL_ROW_ID_SORTED_FIELD if self.sorted_table else INTERNAL_ROW_ID_FIELD]
        queried_fields.extend(column for column in self.arrow_schema)
        query_schema = pa.schema(queried_fields)
        columns.append(INTERNAL_ROW_ID)

    if predicate is True:
        predicate = None
    if predicate is False:
        # nothing can match - return an empty reader having the expected schema
        response_schema = _internal.get_response_schema(schema=query_schema, field_names=columns)
        return pa.RecordBatchReader.from_batches(response_schema, [])

    if isinstance(predicate, ibis.common.deferred.Deferred):
        predicate = predicate.resolve(self._ibis_table)  # may raise if the predicate is invalid (e.g. wrong types / missing column)

    if config.num_splits is None:
        num_rows = 0
        if self.sorted_table:
            num_rows = self._get_row_estimate(columns, predicate, query_schema)
            log.debug(f'sorted estimate: {num_rows}')
        if num_rows == 0:
            # no estimate is available - fall back to the total number of rows
            if stats is None:
                stats = self.get_stats()
            num_rows = stats.num_rows

        config.num_splits = max(1, num_rows // config.rows_per_split)
    log.debug("config: %s", config)

    if config.semi_sorted_projection_name:
        self.tx._rpc.features.check_enforce_semisorted_projection()

    query_data_request = _internal.build_query_data_request(
        schema=query_schema,
        predicate=predicate,
        field_names=columns)
    if len(query_data_request.serialized) > util.MAX_QUERY_DATA_REQUEST_SIZE:
        raise errors.TooLargeRequest(f"{len(query_data_request.serialized)} bytes")

    splits_queue: queue.Queue[int] = queue.Queue()

    for split in range(config.num_splits):
        splits_queue.put(split)

    # this queue shouldn't be large, it is merely a pipe through which the results
    # are sent to the main thread. Most of the pages are actually held in the
    # threads that fetch the pages.
    record_batches_queue: queue.Queue[pa.RecordBatch] = queue.Queue(maxsize=2)

    stop_event = Event()

    class StoppedException(Exception):
        pass

    def check_stop():
        if stop_event.is_set():
            raise StoppedException

    def single_endpoint_worker(endpoint: str):
        # Process splits (taken from the shared queue) over a single endpoint, until no splits are left.
        try:
            with self.tx._rpc.api.with_endpoint(endpoint) as host_api:
                backoff_decorator = self.tx._rpc.api._backoff_decorator
                while True:
                    check_stop()
                    try:
                        split = splits_queue.get_nowait()
                    except queue.Empty:
                        log.debug("splits queue is empty")
                        break

                    split_state = SelectSplitState(query_data_request=query_data_request,
                                                   table=self,
                                                   split_id=split,
                                                   config=config)

                    process_with_retries = backoff_decorator(split_state.process_split)
                    process_with_retries(host_api, record_batches_queue, check_stop)

        except StoppedException:
            log.debug("stop signal.", exc_info=True)
            return
        finally:
            # signal that this thread has ended
            log.debug("exiting")
            record_batches_queue.put(None)

    def batches_iterator():
        def propagate_first_exception(futures: List[concurrent.futures.Future], block=False):
            # Re-raise the exception of any completed worker, and return the workers which are still running.
            done, not_done = concurrent.futures.wait(futures, None if block else 0, concurrent.futures.FIRST_EXCEPTION)
            if self.tx.txid is None:
                raise errors.MissingTransaction()
            for future in done:
                future.result()
            return not_done

        threads_prefix = "query-data"
        # This is mainly for testing, it helps to identify running threads in runtime.
        if config.query_id:
            threads_prefix = threads_prefix + "-" + config.query_id

        with concurrent.futures.ThreadPoolExecutor(max_workers=len(endpoints), thread_name_prefix=threads_prefix) as tp:  # TODO: concurrency == endpoints is just a heuristic
            futures = [tp.submit(single_endpoint_worker, endpoint) for endpoint in endpoints]
            tasks_running = len(futures)
            try:
                while tasks_running > 0:
                    futures = propagate_first_exception(futures, block=False)

                    batch = record_batches_queue.get()
                    if batch is not None:
                        yield batch
                    else:
                        # every worker puts a single `None` when it ends
                        tasks_running -= 1
                        log.debug("one worker thread finished, remaining: %d", tasks_running)

                # all host threads ended - wait for all futures to complete
                propagate_first_exception(futures, block=True)
            finally:
                stop_event.set()
                # drain the queue until every worker has signalled its exit, so that
                # none of them stays blocked on a full queue
                while tasks_running > 0:
                    if record_batches_queue.get() is None:
                        tasks_running -= 1

    return pa.RecordBatchReader.from_batches(query_data_request.response_schema, batches_iterator())
[docs]
def insert_in_column_batches(self, rows: pa.RecordBatch):
    """Insert `rows` in groups of columns, each of them small enough for a single RPC.

    The first MAX_COLUMN_IN_BATCH columns are inserted, yielding the row IDs. The rest of the columns
    are then written by updating these rows, MAX_COLUMN_IN_BATCH columns at a time.
    Returns the row IDs of the inserted rows.

    NOTE(review): assumes `rows` has at least MAX_COLUMN_IN_BATCH columns - a narrower batch
    fails when its first MAX_COLUMN_IN_BATCH columns are accessed.
    """
    total_columns = len(rows.schema)
    leading = range(MAX_COLUMN_IN_BATCH)
    first_batch = pa.RecordBatch.from_arrays(
        [_combine_chunks(rows.column(idx)) for idx in leading],
        schema=pa.schema([rows.schema.field(idx) for idx in leading]))
    row_ids = self.insert(rows=first_batch)  # type: ignore

    all_names = [column_field.name for column_field in rows.schema]
    all_fields = list(rows.schema)
    all_arrays = [_combine_chunks(rows.column(idx)) for idx in range(total_columns)]

    for start in range(MAX_COLUMN_IN_BATCH, total_columns, MAX_COLUMN_IN_BATCH):
        end = min(start + MAX_COLUMN_IN_BATCH, total_columns)
        # every group is sent together with the row IDs of the rows it belongs to
        group_fields = all_fields[start:end] + [INTERNAL_ROW_ID_FIELD]
        group_arrays = all_arrays[start:end] + [row_ids.to_pylist()]
        group_batch = pa.RecordBatch.from_arrays(group_arrays, schema=pa.schema(group_fields))
        self.update(rows=group_batch, columns=all_names[start:end])
    return row_ids
[docs]
def insert(self, rows: Union[pa.RecordBatch, pa.Table]):
    """Insert a RecordBatch into this table.

    Returns the row IDs of the inserted rows as a chunked array, or None if the server version
    does not support returning them.

    If the rows are too wide for a single request and have more than MAX_COLUMN_IN_BATCH columns,
    they are inserted in column batches (see `insert_in_column_batches`).

    Raises `errors.NotSupportedCommand` for an imports table, and `errors.TooWideRow` when the rows
    are too wide but cannot be split into smaller column batches.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
    try:
        row_ids = []
        serialized_slices = util.iter_serialized_slices(rows, MAX_INSERT_ROWS_PER_PATCH)
        for serialized_slice in serialized_slices:
            res = self.tx._rpc.api.insert_rows(self.bucket.name, self.schema.name, self.name,
                                               record_batch=serialized_slice, txid=self.tx.txid)
            # the response is expected to hold a single batch, carrying the new row IDs
            (batch,) = pa.RecordBatchStreamReader(res.content)
            row_ids.append(batch[INTERNAL_ROW_ID])
        try:
            self.tx._rpc.features.check_return_row_ids()
        except errors.NotSupportedVersion:
            return  # type: ignore
        return pa.chunked_array(row_ids)
    except errors.TooWideRow:
        # NOTE(review): slices inserted before the failure are not rolled back here - confirm
        # that TooWideRow can only be raised for the first slice.
        self.tx._rpc.features.check_return_row_ids()
        if len(rows.schema) <= MAX_COLUMN_IN_BATCH:
            # Splitting by columns cannot make this request any narrower: the first column batch
            # would be the very same rows (or would access missing columns), so falling back
            # would fail with IndexError or recurse endlessly. Report the original error instead.
            raise
        return self.insert_in_column_batches(rows)
[docs]
def update(self, rows: Union[pa.RecordBatch, pa.Table], columns: Optional[List[str]] = None) -> None:
    """Update a subset of cells in this table.

    The rows to update are identified by a special "$row_id" field, which must be part of `rows`.
    Only the columns named by `columns` are updated; if omitted, all the columns of `rows`
    (except for the row ID) are updated.

    Raises `errors.NotSupportedCommand` for an imports table, and `errors.MissingRowIdColumn`
    if `rows` has no row ID column.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)

    try:
        row_ids = rows[INTERNAL_ROW_ID]
    except KeyError:
        raise errors.MissingRowIdColumn

    if columns is None:
        columns = [name for name in rows.schema.names if name != INTERNAL_ROW_ID]

    # the row ID always comes first, typed according to the kind of the table
    row_id_field = INTERNAL_ROW_ID_SORTED_FIELD if self.sorted_table else INTERNAL_ROW_ID_FIELD
    fields = [row_id_field]
    values = [_combine_chunks(row_ids)]
    for column_name in columns:
        fields.append(rows.field(column_name))
        values.append(_combine_chunks(rows[column_name]))

    batch = pa.record_batch(schema=pa.schema(fields), data=values)
    batch = util.sort_record_batch_if_needed(batch, INTERNAL_ROW_ID)

    for serialized_slice in util.iter_serialized_slices(batch, MAX_ROWS_PER_BATCH):
        self.tx._rpc.api.update_rows(self.bucket.name, self.schema.name, self.name,
                                     record_batch=serialized_slice, txid=self.tx.txid)
[docs]
def delete(self, rows: Union[pa.RecordBatch, pa.Table]) -> None:
    """Delete a subset of rows in this table.

    The rows to delete are identified by a special "$row_id" field, which must be part of `rows`.

    Raises `errors.NotSupportedCommand` for an imports table, and `errors.MissingRowIdColumn`
    if `rows` has no row ID column.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)

    try:
        row_ids = rows[INTERNAL_ROW_ID]
    except KeyError:
        raise errors.MissingRowIdColumn

    # only the row IDs are sent, typed according to the kind of the table
    row_id_field = INTERNAL_ROW_ID_SORTED_FIELD if self.sorted_table else INTERNAL_ROW_ID_FIELD
    batch = pa.record_batch(schema=pa.schema([row_id_field]), data=[_combine_chunks(row_ids)])
    batch = util.sort_record_batch_if_needed(batch, INTERNAL_ROW_ID)

    for serialized_slice in util.iter_serialized_slices(batch, MAX_ROWS_PER_BATCH):
        self.tx._rpc.api.delete_rows(self.bucket.name, self.schema.name, self.name,
                                     record_batch=serialized_slice, txid=self.tx.txid,
                                     delete_from_imports_table=self._imports_table)
[docs]
def drop(self) -> None:
    """Drop this table (or the imports table, when this object represents one)."""
    api = self.tx._rpc.api
    api.drop_table(
        self.bucket.name, self.schema.name, self.name,
        txid=self.tx.txid,
        remove_imports_table=self._imports_table)
    log.info("Dropped table: %s", self.name)
[docs]
def rename(self, new_name: str) -> None:
    """Rename this table.

    On success, `name` and `path` of this object reflect the new name.
    Raises `errors.NotSupportedCommand` for an imports table.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
    self.tx._rpc.api.alter_table(
        self.bucket.name, self.schema.name, self.name, txid=self.tx.txid, new_name=new_name)
    log.info("Renamed table from %s to %s ", self.name, new_name)
    self.name = new_name
    # the path is computed once, when the object is created - keep it in sync with the new name
    self._table_path = f'{self.schema.bucket.name}/{self.schema.name}/{self.name}'
[docs]
def add_sorting_key(self, sorting_key: list) -> None:
    """Add a sorting key to a table that doesn't have any."""
    rpc = self.tx._rpc
    rpc.features.check_elysium()
    rpc.api.alter_table(
        self.bucket.name, self.schema.name, self.name,
        txid=self.tx.txid,
        sorting_key=sorting_key)
    log.info("Enabled Elysium for table %s with sorting key %s ", self.name, str(sorting_key))
[docs]
def add_column(self, new_column: pa.Schema) -> None:
    """Add the column(s) described by `new_column`, and reload the table's schema.

    Raises `errors.NotSupportedCommand` for an imports table.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)

    api = self.tx._rpc.api
    api.add_columns(self.bucket.name, self.schema.name, self.name, new_column, txid=self.tx.txid)
    log.info("Added column(s): %s", new_column)
    # fetch the updated schema from the server
    self.arrow_schema = self.columns()
[docs]
def drop_column(self, column_to_drop: pa.Schema) -> None:
    """Drop the existing column(s) described by `column_to_drop`, and reload the table's schema.

    Raises `errors.NotSupportedCommand` for an imports table (same as the other schema-changing methods).
    """
    # a single guard: it used to be preceded by an identical check raising a different error,
    # which made this one unreachable
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)
    self.tx._rpc.api.drop_columns(self.bucket.name, self.schema.name, self.name, column_to_drop, txid=self.tx.txid)
    log.info("Dropped column(s): %s", column_to_drop)
    # fetch the updated schema from the server
    self.arrow_schema = self.columns()
[docs]
def rename_column(self, current_column_name: str, new_column_name: str) -> None:
    """Rename an existing column, and reload the table's schema.

    Raises `errors.NotSupportedCommand` for an imports table.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)

    api = self.tx._rpc.api
    api.alter_column(
        self.bucket.name, self.schema.name, self.name,
        name=current_column_name,
        new_name=new_column_name,
        txid=self.tx.txid)
    log.info("Renamed column: %s to %s", current_column_name, new_column_name)
    # fetch the updated schema from the server
    self.arrow_schema = self.columns()
[docs]
def create_projection(self, projection_name: str, sorted_columns: List[str], unsorted_columns: List[str]) -> "Projection":
    """Create a new semi-sorted projection, and return it.

    Raises `errors.NotSupportedCommand` for an imports table.
    """
    if self._imports_table:
        raise errors.NotSupportedCommand(self.bucket.name, self.schema.name, self.name)

    # every column is passed together with its kind, the sorted columns come first
    columns = [(column_name, "Sorted") for column_name in sorted_columns]
    # NOTE(review): "Unorted" looks like a typo of "Unsorted"; it is kept as-is since the value
    # is passed to the API layer - confirm what is expected there before changing it.
    columns = columns + [(column_name, "Unorted") for column_name in unsorted_columns]

    api = self.tx._rpc.api
    api.create_projection(self.bucket.name, self.schema.name, self.name, projection_name,
                          columns=columns, txid=self.tx.txid)
    log.info("Created projection: %s", projection_name)
    return self.projection(projection_name)
[docs]
def create_imports_table(self, fail_if_exists=True) -> "Table":
    """Create the imports table of this table, and return it.

    NOTE(review): `fail_if_exists` is accepted but not used by this method.
    """
    rpc = self.tx._rpc
    rpc.features.check_imports_table()
    # the imports table is created with an empty schema
    rpc.api.create_table(
        self.bucket.name, self.schema.name, self.name, pa.schema([]),
        txid=self.tx.txid,
        create_imports_table=True)
    log.info("Created imports table for table: %s", self.name)
    return self.imports_table()  # type: ignore[return-value]
[docs]
def imports_table(self) -> Optional["Table"]:
    """Return a `Table` object representing the imports table of this table."""
    self.tx._rpc.features.check_imports_table()
    # same name, schema and handle - only the imports-table flag differs
    return Table(
        name=self.name,
        schema=self.schema,
        handle=int(self.handle),
        _imports_table=True,
        sorted_table=self.sorted_table)
def __getitem__(self, col_name: str):
    """Return an ibis-like column expression for the column named `col_name`.

    Such expressions are useful for building the predicates pushed down by `Table.select()`.
    """
    ibis_table = self._ibis_table
    return ibis_table[col_name]
[docs]
def sorting_done(self) -> bool:
    """Sorting done indicator for the table. Always False for unsorted tables."""
    if not self.sorted_table:
        return False
    raw_sorting_score = self.tx._rpc.api.raw_sorting_score(self.schema.bucket.name, self.schema.name, self.schema.tx.txid, self.name)
    # the indicator is stored above the low SORTING_SCORE_BITS bits, which hold the score itself
    return bool(raw_sorting_score >> SORTING_SCORE_BITS)
[docs]
def sorting_score(self) -> int:
    """Sorting score for the table. Always 0 for unsorted tables."""
    if not self.sorted_table:
        return 0
    api = self.tx._rpc.api
    raw_score = api.raw_sorting_score(self.schema.bucket.name, self.schema.name, self.schema.tx.txid, self.name)
    # keep only the low SORTING_SCORE_BITS bits, i.e. drop the "sorting done" indicator
    score_mask = (1 << SORTING_SCORE_BITS) - 1
    return raw_score & score_mask
[docs]
@dataclass
class Projection:
    """VAST semi-sorted projection."""

    name: str
    table: Table
    handle: int
    stats: TableStats

    @property
    def bucket(self):
        """Return the bucket of the projected table."""
        return self.table.schema.bucket

    @property
    def schema(self):
        """Return the schema of the projected table."""
        return self.table.schema

    @property
    def tx(self):
        """Return the transaction of the projected table."""
        return self.table.schema.tx

    def columns(self) -> pa.Schema:
        """List this projection's columns, store them in `self.arrow_schema` as an Arrow schema and return it."""
        collected = []
        page_key = 0
        while True:
            page, page_key, truncated, _count, _ = \
                self.tx._rpc.api.list_projection_columns(
                    self.bucket.name, self.schema.name, self.table.name, self.name,
                    txid=self.table.tx.txid, next_key=page_key)
            # an empty page ends the listing, even if it is marked as truncated
            if not page:
                break
            collected.extend(page)
            if not truncated:
                break

        # only the first two items of every listed column (its name and type) are used
        self.arrow_schema = pa.schema([(column[0], column[1]) for column in collected])
        return self.arrow_schema

    def rename(self, new_name: str) -> None:
        """Rename this projection, and update `name` accordingly."""
        old_name = self.name
        self.tx._rpc.api.alter_projection(
            self.bucket.name, self.schema.name, self.table.name, old_name,
            txid=self.tx.txid, new_name=new_name)
        log.info("Renamed projection from %s to %s ", old_name, new_name)
        self.name = new_name

    def drop(self) -> None:
        """Drop this projection."""
        api = self.tx._rpc.api
        api.drop_projection(
            self.bucket.name, self.schema.name, self.table.name, self.name,
            txid=self.tx.txid)
        log.info("Dropped projection: %s", self.name)
def _parse_projection_info(projection_info, table: "Table"):
    """Build a `Projection` of `table` out of a listed projection info."""
    log.info("Projection info %s", str(projection_info))
    # only the size-related statistics are taken from the info, the other counters are zeroed
    projection_stats = TableStats(
        num_rows=projection_info.num_rows,
        size_in_bytes=projection_info.size_in_bytes,
        sorting_score=0,
        write_amplification=0,
        acummulative_row_inserition_count=0)
    return Projection(
        name=projection_info.name,
        table=table,
        stats=projection_stats,
        handle=int(projection_info.handle))
def _parse_bucket_and_object_names(path: str) -> Tuple[str, str]:
if not path.startswith('/'):
raise errors.InvalidArgument(f"Path {path} must start with a '/'")
components = path.split(os.path.sep)
bucket_name = components[1]
object_path = os.path.sep.join(components[2:])
return bucket_name, object_path
def _serialize_record_batch(record_batch: pa.RecordBatch) -> pa.lib.Buffer:
    """Serialize `record_batch` as an Arrow IPC stream, and return the resulting buffer."""
    output = pa.BufferOutputStream()
    # the stream must be closed (by leaving the context) before the buffer is taken
    with pa.ipc.new_stream(output, record_batch.schema) as stream_writer:
        stream_writer.write(record_batch)
    serialized = output.getvalue()
    return serialized
def _combine_chunks(col):
if hasattr(col, "combine_chunks"):
return col.combine_chunks()
else:
return col