Postgres schema name fix (#5432)
* Handle postgres schema name

Signed-off-by: Hai Yan <[email protected]>

* Remove an empty file

Signed-off-by: Hai Yan <[email protected]>

* Fix missing column types

Signed-off-by: Hai Yan <[email protected]>

* Address review comments

Signed-off-by: Hai Yan <[email protected]>

---------

Signed-off-by: Hai Yan <[email protected]>
oeyh authored Feb 14, 2025
1 parent 3522ce1 commit 85127bc
Showing 16 changed files with 99 additions and 37 deletions.
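In short: records produced by the RDS source now carry a schema_name metadata attribute alongside database_name and table_name, and RecordConverter.convert takes the schema name as a new parameter. A minimal sketch of the updated call, with hypothetical values (the real call sites appear in the diffs below); on MySQL paths the database name is reused as the schema name, since MySQL has no separate schema:

final Event converted = recordConverter.convert(
        event,
        "shop",                       // databaseName (hypothetical value)
        "public",                     // schemaName: the new parameter (MySQL paths pass the database name here)
        "orders",                     // tableName (hypothetical value)
        OpenSearchBulkActions.INDEX,
        primaryKeys,
        eventCreateTimeEpochMillis,
        eventVersionNumber,
        null);                        // event type; null on the export path, per the tests below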
@@ -18,6 +18,8 @@ public class MetadataKeyAttributes {

static final String EVENT_DATABASE_NAME_METADATA_ATTRIBUTE = "database_name";

static final String EVENT_SCHEMA_NAME_METADATA_ATTRIBUTE = "schema_name";

static final String EVENT_TABLE_NAME_METADATA_ATTRIBUTE = "table_name";

static final String INGESTION_EVENT_TYPE_ATTRIBUTE = "ingestion_type";
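Consumers read the new attribute the same way as the existing ones; a minimal sketch, assuming an Event in hand and using the literal keys defined above:

final String databaseName = (String) event.getMetadata().getAttribute("database_name");
final String schemaName = (String) event.getMetadata().getAttribute("schema_name");   // new in this change
final String tableName = (String) event.getMetadata().getAttribute("table_name");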
@@ -20,6 +20,7 @@
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.BULK_ACTION_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.CHANGE_EVENT_TYPE_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_DATABASE_NAME_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_SCHEMA_NAME_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_TABLE_NAME_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_TIMESTAMP_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_VERSION_FROM_TIMESTAMP;
@@ -45,6 +46,7 @@ public RecordConverter(final String s3Prefix, final int partitionCount) {

public Event convert(final Event event,
final String databaseName,
final String schemaName,
final String tableName,
final OpenSearchBulkActions bulkAction,
final List<String> primaryKeys,
@@ -62,6 +64,7 @@ public Event convert(final Event event,
}

eventMetadata.setAttribute(EVENT_DATABASE_NAME_METADATA_ATTRIBUTE, databaseName);
eventMetadata.setAttribute(EVENT_SCHEMA_NAME_METADATA_ATTRIBUTE, schemaName);
eventMetadata.setAttribute(EVENT_TABLE_NAME_METADATA_ATTRIBUTE, tableName);
eventMetadata.setAttribute(BULK_ACTION_METADATA_ATTRIBUTE, bulkAction.toString());
setIngestionTypeMetadata(event);
@@ -5,7 +5,9 @@

package org.opensearch.dataprepper.plugins.source.rds.coordination.state;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType;

import java.util.List;
import java.util.Map;
@@ -25,9 +27,12 @@ public class DataFileProgressState {
private String sourceDatabase;

/**
* For MySQL, sourceTable is in the format of tableName
* For Postgres, sourceTable is in the format of schemaName.tableName
* For the PostgreSQL engine type, sourceSchema is the schema name.
* For the MySQL engine type, this field stores the database name, the same value as the sourceDatabase field.
*/
@JsonProperty("sourceSchema")
private String sourceSchema;

@JsonProperty("sourceTable")
private String sourceTable;

@@ -72,6 +77,25 @@ public void setSourceDatabase(String sourceDatabase) {
this.sourceDatabase = sourceDatabase;
}

public String getSourceSchema() {
return sourceSchema;
}

public void setSourceSchema(String sourceSchema) {
this.sourceSchema = sourceSchema;
}

@JsonIgnore
public String getFullSourceTableName() {
if (EngineType.fromString(engineType) == EngineType.MYSQL) {
return sourceDatabase + "." + sourceTable;
} else if (EngineType.fromString(engineType) == EngineType.POSTGRES) {
return sourceDatabase + "." + sourceSchema + "." + sourceTable;
} else {
throw new RuntimeException("Unsupported engine type: " + engineType);
}
}

public String getSourceTable() {
return sourceTable;
}
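The new helper makes the fully qualified table name engine-dependent; a usage sketch with hypothetical values (the exact strings accepted by EngineType.fromString are an assumption here):

final DataFileProgressState state = new DataFileProgressState();
state.setEngineType("postgres");      // assumed engine-type string
state.setSourceDatabase("shop");
state.setSourceSchema("public");
state.setSourceTable("orders");
// state.getFullSourceTableName() -> "shop.public.orders"
// With engine type "mysql" and no schema set, the result would be "shop.orders".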
@@ -33,7 +33,6 @@
import java.util.concurrent.atomic.AtomicLong;

import static org.opensearch.dataprepper.logging.DataPrepperMarkers.SENSITIVE;
import static org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata.DOT_DELIMITER;

public class DataFileLoader implements Runnable {

@@ -126,7 +125,7 @@ public void run() {

DataFileProgressState progressState = dataFilePartition.getProgressState().get();

final String fullTableName = progressState.getSourceDatabase() + DOT_DELIMITER + progressState.getSourceTable();
final String fullTableName = progressState.getFullSourceTableName();
final List<String> primaryKeys = progressState.getPrimaryKeyMap().getOrDefault(fullTableName, List.of());
transformEvent(event, fullTableName, EngineType.fromString(progressState.getEngineType()));

@@ -135,6 +134,7 @@
final Event transformedEvent = recordConverter.convert(
event,
progressState.getSourceDatabase(),
progressState.getSourceSchema(),
progressState.getSourceTable(),
OpenSearchBulkActions.INDEX,
primaryKeys,
@@ -320,13 +320,11 @@ private void createDataFilePartitions(String bucket,
for (final String objectKey : dataFileObjectKeys) {
final DataFileProgressState progressState = new DataFileProgressState();
final ExportObjectKey exportObjectKey = ExportObjectKey.fromString(objectKey);
final String database = exportObjectKey.getDatabaseName();
final String table = engineType == EngineType.MYSQL ?
exportObjectKey.getTableName() :
exportObjectKey.getSchemaName() + DOT_DELIMITER + exportObjectKey.getTableName();

progressState.setSourceDatabase(database);
progressState.setSourceTable(table);
progressState.setEngineType(engineType.toString());
progressState.setSourceDatabase(exportObjectKey.getDatabaseName());
progressState.setSourceSchema(exportObjectKey.getSchemaName());
progressState.setSourceTable(exportObjectKey.getTableName());
progressState.setSnapshotTime(snapshotTime);
progressState.setPrimaryKeyMap(primaryKeyMap);

@@ -13,6 +13,7 @@ public class TableMetadata {
public static final String DOT_DELIMITER = ".";

private final String databaseName;
private final String schemaName;
private final String tableName;
private final List<String> columnNames;
private final List<String> columnTypes;
@@ -22,6 +23,7 @@

private TableMetadata(final Builder builder) {
this.databaseName = builder.databaseName;
this.schemaName = builder.schemaName != null ? builder.schemaName : builder.databaseName;
this.tableName = builder.tableName;
this.columnNames = builder.columnNames;
this.columnTypes = builder.columnTypes;
@@ -34,6 +36,10 @@ public String getDatabaseName() {
return databaseName;
}

public String getSchemaName() {
return schemaName;
}

public String getTableName() {
return tableName;
}
@@ -68,6 +74,7 @@ public static Builder builder() {

public static class Builder {
private String databaseName;
private String schemaName;
private String tableName;
private List<String> columnNames = Collections.emptyList();
private List<String> columnTypes = Collections.emptyList();
@@ -83,6 +90,11 @@ public Builder withDatabaseName(String databaseName) {
return this;
}

public Builder withSchemaName(String schemaName) {
this.schemaName = schemaName;
return this;
}

public Builder withTableName(String tableName) {
this.tableName = tableName;
return this;
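Because the constructor falls back to the database name when no schema name is supplied, existing MySQL-style builders keep working unchanged; a sketch with hypothetical values:

// No explicit schema: schemaName defaults to the database name (MySQL-style).
final TableMetadata mysqlMeta = TableMetadata.builder()
        .withDatabaseName("shop")
        .withTableName("orders")
        .build();
// mysqlMeta.getSchemaName() -> "shop"

// Explicit schema (Postgres-style).
final TableMetadata postgresMeta = TableMetadata.builder()
        .withDatabaseName("shop")
        .withSchemaName("public")
        .withTableName("orders")
        .build();
// postgresMeta.getSchemaName() -> "public"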
@@ -101,6 +101,7 @@ private void processRows(List<Map<String, Object>> rows, String database, String
final Event pipelineEvent = recordConverter.convert(
dataPrepperEvent,
database,
database,
table,
OpenSearchBulkActions.INDEX,
primaryKeys,
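Passing database for both the databaseName and schemaName arguments above is deliberate, not a typo: this call site serves an engine with no separate schema concept (MySQL), so the database name fills both slots, matching the schema-name fallback in TableMetadata and the equivalent change in the binlog listener further down.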
@@ -102,19 +102,19 @@ public List<String> getPrimaryKeys(final String fullTableName) {
primaryKeys.add(rs.getString(COLUMN_NAME));
}
if (primaryKeys.isEmpty()) {
throw new NoSuchElementException("No primary keys found for table " + table);
throw new NoSuchElementException("No primary keys found for table " + fullTableName);
}
return primaryKeys;
}
} catch (NoSuchElementException e) {
throw e;
} catch (Exception e) {
LOG.error("Failed to get primary keys for table {}, retrying", table, e);
LOG.error("Failed to get primary keys for table {}, retrying", fullTableName, e);
}
applyBackoff();
retry++;
}
throw new RuntimeException("Failed to get primary keys for table " + table);
throw new RuntimeException("Failed to get primary keys for table " + fullTableName);
}

private void applyBackoff() {
@@ -390,6 +390,7 @@ void handleRowChangeEvent(com.github.shyiko.mysql.binlog.event.Event event,
final Event pipelineEvent = recordConverter.convert(
dataPrepperEvent,
tableMetadata.getDatabaseName(),
tableMetadata.getDatabaseName(),
tableMetadata.getTableName(),
bulkAction,
primaryKeys,
@@ -45,6 +45,7 @@
import java.util.function.Consumer;

public class LogicalReplicationEventProcessor {

enum TupleDataType {
NEW('N'),
KEY('K'),
@@ -76,6 +77,7 @@ public static TupleDataType fromValue(char value) {
static final int DEFAULT_BUFFER_BATCH_SIZE = 1_000;
static final int NUM_OF_RETRIES = 3;
static final int BACKOFF_IN_MILLIS = 500;
static final String DOT_DELIMITER_REGEX = "\\.";
static final String CHANGE_EVENTS_PROCESSED_COUNT = "changeEventsProcessed";
static final String CHANGE_EVENTS_PROCESSING_ERROR_COUNT = "changeEventsProcessingErrors";
static final String BYTES_RECEIVED = "bytesReceived";
@@ -222,14 +224,16 @@ void processRelationMessage(ByteBuffer msg) {
columnNames.add(columnName);
}

final List<String> primaryKeys = getPrimaryKeys(schemaName, tableName);
final TableMetadata tableMetadata = TableMetadata.builder().
withTableName(tableName).
withDatabaseName(schemaName).
withColumnNames(columnNames).
withColumnTypes(columnTypes).
withPrimaryKeys(primaryKeys).
build();
final String databaseName = getDatabaseName(sourceConfig.getTableNames());
final List<String> primaryKeys = getPrimaryKeys(databaseName, schemaName, tableName);
final TableMetadata tableMetadata = TableMetadata.builder()
.withDatabaseName(databaseName)
.withSchemaName(schemaName)
.withTableName(tableName)
.withColumnNames(columnNames)
.withColumnTypes(columnTypes)
.withPrimaryKeys(primaryKeys)
.build();

tableMetadataMap.put((long) tableId, tableMetadata);

@@ -375,6 +379,7 @@ private void createPipelineEvent(Map<String, Object> rowDataMap, TableMetadata t
final Event pipelineEvent = recordConverter.convert(
dataPrepperEvent,
tableMetadata.getDatabaseName(),
tableMetadata.getSchemaName(),
tableMetadata.getTableName(),
bulkAction,
primaryKeys,
@@ -432,13 +437,16 @@ private String getNullTerminatedString(ByteBuffer msg) {
return sb.toString();
}

private List<String> getPrimaryKeys(String schemaName, String tableName) {
final String databaseName = sourceConfig.getTableNames().get(0).split("\\.")[0];
private List<String> getPrimaryKeys(String databaseName, String schemaName, String tableName) {
StreamProgressState progressState = streamPartition.getProgressState().get();

return progressState.getPrimaryKeyMap().get(databaseName + "." + schemaName + "." + tableName);
}

private String getDatabaseName(List<String> tableNames) {
return tableNames.get(0).split(DOT_DELIMITER_REGEX)[0];
}
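    // Note: getDatabaseName assumes each configured table name is fully qualified as
    // database.schema.table, so the first dot-separated token is the database name. For a
    // hypothetical config entry "shop.public.orders":
    //     getDatabaseName(List.of("shop.public.orders"))  ->  "shop"
    // The primary-key map lookup above uses the same three-part key, e.g. "shop.public.orders".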

private void handleMessageWithRetries(ByteBuffer message, Consumer<ByteBuffer> function, MessageType messageType) {
int retry = 0;
while (retry <= NUM_OF_RETRIES) {
@@ -29,6 +29,7 @@
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.CHANGE_EVENT_TYPE_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_DATABASE_NAME_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_S3_PARTITION_KEY;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_SCHEMA_NAME_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_TABLE_NAME_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_TIMESTAMP_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_VERSION_FROM_TIMESTAMP;
@@ -54,6 +55,7 @@ void setUp() {
@Test
void test_convert() {
final String databaseName = UUID.randomUUID().toString();
final String schemaName = UUID.randomUUID().toString();
final String tableName = UUID.randomUUID().toString();
final String primaryKeyName = UUID.randomUUID().toString();
final List<String> primaryKeys = List.of(primaryKeyName);
@@ -67,11 +69,12 @@
.build();

Event actualEvent = exportRecordConverter.convert(
testEvent, databaseName, tableName, OpenSearchBulkActions.INDEX, primaryKeys,
testEvent, databaseName, schemaName, tableName, OpenSearchBulkActions.INDEX, primaryKeys,
eventCreateTimeEpochMillis, eventVersionNumber, null);

// Assert
assertThat(actualEvent.getMetadata().getAttribute(EVENT_DATABASE_NAME_METADATA_ATTRIBUTE), equalTo(databaseName));
assertThat(actualEvent.getMetadata().getAttribute(EVENT_SCHEMA_NAME_METADATA_ATTRIBUTE), equalTo(schemaName));
assertThat(actualEvent.getMetadata().getAttribute(EVENT_TABLE_NAME_METADATA_ATTRIBUTE), equalTo(tableName));
assertThat(actualEvent.getMetadata().getAttribute(BULK_ACTION_METADATA_ATTRIBUTE), equalTo(OpenSearchBulkActions.INDEX.toString()));
assertThat(actualEvent.getMetadata().getAttribute(PRIMARY_KEY_DOCUMENT_ID_METADATA_ATTRIBUTE), equalTo(primaryKeyValue));
@@ -27,6 +27,7 @@
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.CHANGE_EVENT_TYPE_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_DATABASE_NAME_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_S3_PARTITION_KEY;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_SCHEMA_NAME_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_TABLE_NAME_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_TIMESTAMP_METADATA_ATTRIBUTE;
import static org.opensearch.dataprepper.plugins.source.rds.converter.MetadataKeyAttributes.EVENT_VERSION_FROM_TIMESTAMP;
@@ -54,6 +55,7 @@ void setUp() {
void test_convert_returns_expected_event() {
final Map<String, Object> rowData = Map.of("key1", "value1", "key2", "value2");
final String databaseName = UUID.randomUUID().toString();
final String schemaName = UUID.randomUUID().toString();
final String tableName = UUID.randomUUID().toString();
final EventType eventType = EventType.EXT_WRITE_ROWS;
final OpenSearchBulkActions bulkAction = OpenSearchBulkActions.INDEX;
@@ -67,11 +69,12 @@
.build();

Event event = streamRecordConverter.convert(
testEvent, databaseName, tableName, bulkAction, primaryKeys,
testEvent, databaseName, schemaName, tableName, bulkAction, primaryKeys,
eventCreateTimeEpochMillis, eventVersionNumber, eventType);

assertThat(event.toMap(), is(rowData));
assertThat(event.getMetadata().getAttribute(EVENT_DATABASE_NAME_METADATA_ATTRIBUTE), is(databaseName));
assertThat(event.getMetadata().getAttribute(EVENT_SCHEMA_NAME_METADATA_ATTRIBUTE), equalTo(schemaName));
assertThat(event.getMetadata().getAttribute(EVENT_TABLE_NAME_METADATA_ATTRIBUTE), is(tableName));
assertThat(event.getMetadata().getAttribute(CHANGE_EVENT_TYPE_METADATA_ATTRIBUTE), is(eventType.toString()));
assertThat(event.getMetadata().getAttribute(BULK_ACTION_METADATA_ATTRIBUTE), is(bulkAction.toString()));
@@ -22,11 +22,11 @@ public void testHandleByteArrayData() {
final MySQLDataType columnType = MySQLDataType.BINARY;
final String columnName = "binaryColumn";
final String testData = UUID.randomUUID().toString();
final TableMetadata metadata = TableMetadata.builder().
withTableName(UUID.randomUUID().toString()).
withDatabaseName(UUID.randomUUID().toString()).
withColumnNames(List.of(columnName)).
withPrimaryKeys(List.of(columnName))
final TableMetadata metadata = TableMetadata.builder()
.withTableName(UUID.randomUUID().toString())
.withDatabaseName(UUID.randomUUID().toString())
.withColumnNames(List.of(columnName))
.withPrimaryKeys(List.of(columnName))
.build();
final Object result = handler.handle(columnType, columnName, testData.getBytes(), metadata);

@@ -134,7 +134,7 @@ void test_run_success() throws Exception {
when(eventFactory.eventBuilder(any())).thenReturn(eventBuilder);
when(eventBuilder.withEventType(any()).withData(any()).build()).thenReturn(event);
when(event.toJsonString()).thenReturn(randomString);
when(recordConverter.convert(any(), any(), any(), any(), any(), anyLong(), anyLong(), any())).thenReturn(event);
when(recordConverter.convert(any(), any(), any(), any(), any(), any(), anyLong(), anyLong(), any())).thenReturn(event);

AvroParquetReader.Builder<GenericRecord> builder = mock(AvroParquetReader.Builder.class);
ParquetReader<GenericRecord> parquetReader = mock(ParquetReader.class);
@@ -184,7 +184,7 @@ void test_flush_failure_then_error_metric_updated() throws Exception {
when(eventBuilder.withEventType(any()).withData(any()).build()).thenReturn(event);
when(event.toJsonString()).thenReturn(randomString);

when(recordConverter.convert(any(), any(), any(), any(), any(), anyLong(), anyLong(), any())).thenReturn(event);
when(recordConverter.convert(any(), any(), any(), any(), any(), any(), anyLong(), anyLong(), any())).thenReturn(event);

ParquetReader<GenericRecord> parquetReader = mock(ParquetReader.class);
AvroParquetReader.Builder<GenericRecord> builder = mock(AvroParquetReader.Builder.class);