""" Image quality classification using Inception trained on ImageNet 2012 Challenge data set and finetuned on quality data. This program runs inference on input JPEG images in a folder. Change the --image_dir argument to any jpg image folder to compute a classification images in that folder. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function try: import pyspark except ImportError: import findspark findspark.init() import pyspark import argparse from collections import defaultdict from concurrent import futures import csv import glob import json import os.path import random import re import sys import tarfile import time from os import listdir from os.path import isfile, join import numpy as np import pyspark.sql from pyspark.sql import types as T import requests from six.moves import urllib import tensorflow as tf FLAGS = None def proxies(): cluster = 'eqiad' if random.random() > 0.5 else 'codfw' return { 'http': 'http://webproxy.{}.wmnet:8080/'.format(cluster), 'https': 'https://webproxy.{}.wmnet:8080/'.format(cluster), } def create_graph(): """Creates a graph from saved GraphDef file and returns a saver.""" # Creates graph from saved graph_def.pb. local_model_path = pyspark.SparkFiles.get(FLAGS.model_path) with tf.gfile.FastGFile(local_model_path, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) _ = tf.import_graph_def(graph_def, name='') def batch(x, n): batch = [] for item in x: batch.append(item) if len(batch) == n: yield batch batch = [] if batch: yield batch def fetch_url_batch(session, titles, retries=0): response = session.post('https://commons.wikimedia.org/w/api.php', timeout=5, data={ 'format': 'json', 'formatversion': 2, 'action': 'query', 'prop': 'imageinfo', 'iiprop': 'url', 'iiurlwidth': 600, 'titles': 'File:' + '|File:'.join(titles) }) try: res = response.json() except json.decoder.JSONDecodeError: raise Exception(response.content) if 'error' in res or 'query' not in res: if 'error' in res and res['error']['code'] == 'urlparamnormal': # We have a bad title. 


def fetch_urls(titles, batch_size=50):
    """Yield (page_id, title, thumb_url, error) for each title in titles."""
    # Spread out the requests when spark first starts up
    time.sleep(5 * random.random())
    if batch_size > 50:
        raise Exception('Mediawiki api will only resize 50 images at a time')
    with requests.Session() as session:
        for batch_titles in batch(titles, batch_size):
            res = fetch_url_batch(session, batch_titles)
            try:
                normalized = {}
                if 'normalized' in res['query']:
                    for norm in res['query']['normalized']:
                        normalized[norm['to']] = norm['from']
                for page in res['query']['pages']:
                    try:
                        title = page['title']
                    except KeyError:
                        title = '***MISSING TITLE***'
                    else:
                        # Report the title as it was originally requested.
                        if title in normalized:
                            title = normalized[title]
                    if 'invalid' in page and page['invalid']:
                        yield -1, title, None, page['invalidreason']
                    elif 'missing' in page and page['missing']:
                        yield -1, title, None, 'missing'
                    else:
                        for info in page['imageinfo']:
                            url = info['thumburl'] if 'thumburl' in info else info['url']
                            yield page['pageid'], title, url, None
                            break
            except KeyError:
                raise Exception(res)


def buffer_images(image_infos):
    """Download images concurrently, yielding (page_id, title, image_bytes, error)."""
    from requests.adapters import HTTPAdapter
    from requests.exceptions import ConnectionError, RetryError
    from requests.packages.urllib3.util.retry import Retry
    from requests_futures.sessions import FuturesSession

    with FuturesSession(max_workers=10) as session:
        retries = defaultdict(int)

        def on_complete(future, page_id, title, url):
            try:
                res = future.result()
                if res.status_code == 200:
                    yield page_id, title, res.content, None
                elif res.status_code == 429 and retries[future] < 3:
                    # We can't really pause the in-progress requests but we can
                    # at least stop adding new ones for a bit.
                    # Sleep for 10, 20, 40 seconds
                    time.sleep(10 * (2 ** retries[future]))
                    next_future = session.get(url, timeout=120, proxies=proxies())
                    retries[next_future] = retries[future] + 1
                    fs[next_future] = (page_id, title, url)
                else:
                    yield page_id, title, None, 'Received http status code {}'.format(res.status_code)
            except (ConnectionError, RetryError) as e:
                yield page_id, title, None, str(e)
            finally:
                if future in retries:
                    del retries[future]

        fs = {}
        for page_id, title, url, error in image_infos:
            if error is not None:
                # Pass through rows that already failed during url lookup.
                yield page_id, title, None, error
                continue
            future = session.get(url, timeout=120, proxies=proxies())
            fs[future] = (page_id, title, url)
            while len(fs) >= 10:
                done_and_not_done = futures.wait(fs.keys(), return_when=futures.FIRST_COMPLETED)
                for future in done_and_not_done.done:
                    image_info = fs[future]
                    del fs[future]
                    yield from on_complete(future, *image_info)
        for future in futures.as_completed(fs):
            yield from on_complete(future, *fs[future])
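
# A minimal, standalone sketch of the bounded-concurrency pattern used by
# buffer_images() above (the urls list, the worker count, and the handle()
# callback are illustrative assumptions):
#
#   with FuturesSession(max_workers=10) as session:
#       fs = {}
#       for url in urls:
#           fs[session.get(url)] = url
#           while len(fs) >= 10:            # cap the number of requests in flight
#               done, _ = futures.wait(fs.keys(), return_when=futures.FIRST_COMPLETED)
#               for f in done:
#                   handle(f.result(), fs.pop(f))
#       for f in futures.as_completed(fs):  # drain whatever is still pending
#           handle(f.result(), fs[f])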


def run_inference_on_images(image_infos):
    """Score downloaded images, yielding (page_id, title, score, error)."""
    session_conf = tf.ConfigProto(intra_op_parallelism_threads=10,
                                  inter_op_parallelism_threads=10)
    create_graph()
    count = 0
    with tf.Session(config=session_conf) as sess:
        # 'final_result:0': the softmax output of the finetuned graph.
        # 'DecodeJpeg/contents:0': a tensor accepting a string containing the
        #   JPEG encoding of the image.
        # Run the softmax tensor by feeding image_data as input to the graph.
        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
        for page_id, title, image_data, error in buffer_images(image_infos):
            if error is not None:
                yield page_id, title, float('nan'), error
                continue
            if image_data is None:
                yield page_id, title, float('nan'), 'Failed to load image data for {}'.format(title)
                continue
            count += 1
            try:
                predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
            except Exception as e:
                yield page_id, title, float('nan'), '{}: {}'.format(type(e).__name__, e)
            else:
                # Inference succeeded; report the second softmax class as the quality score.
                yield page_id, title, float(predictions[0][1]), None
        print(count)


def main(_):
    conf = pyspark.SparkConf()
    sc = pyspark.SparkContext(appName="classify_image_quality")
    # Ship the model to the executors; create_graph() reads it back with SparkFiles.get.
    sc.addFile(FLAGS.model_path)
    # Concatenate titles from every file matching --image_titles.
    titles = []
    for path in glob.glob(FLAGS.image_titles):
        with open(path, 'r') as f:
            titles.extend(t.strip() for t in f)
    rdd = sc.parallelize(titles, 200) \
        .mapPartitions(fetch_urls) \
        .mapPartitions(run_inference_on_images)
    with open(FLAGS.outfile, 'w') as f:
        writer = csv.writer(f)
        # toLocalIterator would make sense, but it does silly things
        # (running 200 tasks as 200 jobs end to end). Our datasets are
        # only ~100k items, simply load them in memory.
        writer.writerows(rdd.collect())


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default='output_graph_new.pb',
        help='Local path to the finetuned model graph (output_graph_new.pb)'
    )
    parser.add_argument(
        '--image_titles',
        type=str,
        default='example_images.txt',
        help='Local path (glob) to a file listing image titles, one per line'
    )
    parser.add_argument(
        '--outfile',
        type=str,
        default='example_images_graded',
        help='Local path to write csv output to'
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
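
# Example invocation (a hypothetical sketch: the script name and spark-submit
# options are assumptions, not values taken from this repository; only the flag
# names and defaults come from the argparse setup above):
#
#   spark-submit --master yarn classify_image_quality.py \
#       --model_path output_graph_new.pb \
#       --image_titles example_images.txt \
#       --outfile example_images_graded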