# Return DataFrames re-read from Hive after saving, instead of the in-memory ones.
#
# commonswiki_file.py
#   - save_wikidata_data / save_lead_image_data: after the existing save call,
#     return `<hive_db>.<shared.WIKIDATA_DATA | shared.LEAD_IMAGE_DATA>` read
#     back via `<df>.sql_ctx.read.table(...)`, filtered to `snapshot == snapshot`,
#     with the `snapshot` column dropped.
#   - gather_wikidata_data / gather_lead_image_data: return the value of that
#     save call (the re-read DataFrame) rather than the DataFrame they built.
# shared.py
#   - write_search_index_data: after save_table, rebind `data` to the re-read
#     `<hive_db>.SEARCH_INDEX_FULL_TABLE_NAME` snapshot (snapshot column dropped),
#     so the following `delta = data` branch (previous_snapshot is None) uses it.
#     This hunk continues beyond the portion reviewed here.
#
# Why: presumably to truncate the Spark lineage so downstream steps consume the
# persisted rows instead of recomputing the joins -- confirm with the author.
#
# NOTE(review): the shared.py hunk header says `-88,6 +88,11` (net +5 lines),
# but six `+` lines are visible before the reviewed text ends. Verify the hunk
# counts before applying; a `-` line may follow outside this view.
# NOTE(review): the same read-back expression (read.table -> where snapshot ->
# drop snapshot) is written out three times. A small shared helper would keep
# the copies in sync.
# NOTE(review): newer PySpark releases deprecate DataFrame.sql_ctx in favour of
# DataFrame.sparkSession. Check which Spark version this job targets.
diff --git a/image-suggestions/pyspark/src/commonswiki_file.py b/image-suggestions/pyspark/src/commonswiki_file.py index 4ade537..2fbdeef 100644 --- a/image-suggestions/pyspark/src/commonswiki_file.py +++ b/image-suggestions/pyspark/src/commonswiki_file.py @@ -25,6 +25,12 @@ def save_wikidata_data(wikidata_data, hive_db, snapshot): # pragma: no cover wikidata_data.withColumn('snapshot', F.lit(snapshot)), hive_db, shared.WIKIDATA_DATA ) + return ( + wikidata_data.sql_ctx.read + .table(hive_db + '.' + shared.WIKIDATA_DATA) + .where(F.col('snapshot') == F.lit(snapshot)) + .drop('snapshot') + ) def save_lead_image_data(lead_image_data, hive_db, snapshot): # pragma: no cover @@ -33,6 +39,13 @@ def save_lead_image_data(lead_image_data, hive_db, snapshot): # pragma: no cove hive_db, shared.LEAD_IMAGE_DATA ) + return ( + lead_image_data.sql_ctx.read + .table(hive_db + '.' + shared.LEAD_IMAGE_DATA) + .where(F.col('snapshot') == F.lit(snapshot)) + .drop('snapshot') + ) + def parse_args(): # pragma: no cover parser = argparse.ArgumentParser( @@ -116,9 +129,7 @@ def gather_wikidata_data( ) wikidata_data = commons_with_reverse_p18.union(commons_with_reverse_p373) - save_wikidata_data(wikidata_data, hive_db, snapshot) - - return wikidata_data + return save_wikidata_data(wikidata_data, hive_db, snapshot) def gather_commons_with_reverse_p18(commons_file_pages, wikidata_items_with_P18): @@ -287,9 +298,7 @@ def gather_lead_image_data(snapshot, hive_db): ) ) - save_lead_image_data(lead_image_data, hive_db, snapshot) - - return lead_image_data + return save_lead_image_data(lead_image_data, hive_db, snapshot) def get_commonswiki_file_data(wd_data, li_data): diff --git a/image-suggestions/pyspark/src/shared.py b/image-suggestions/pyspark/src/shared.py index 3cfa845..cd2be0e 100644 --- a/image-suggestions/pyspark/src/shared.py +++ b/image-suggestions/pyspark/src/shared.py @@ -88,6 +88,11 @@ def write_search_index_data(data, wiki, hive_db, snapshot, previous_snapshot, sp save_table(
data.withColumn('snapshot', F.lit(snapshot)), hive_db, SEARCH_INDEX_FULL_TABLE_NAME ) + data = ( + data.sql_ctx.read + .table(hive_db + '.' + SEARCH_INDEX_FULL_TABLE_NAME) + .where(F.col('snapshot') == F.lit(snapshot)) + .drop('snapshot') + ) if previous_snapshot is None: delta = data