# Patch summary (commentary only; placed ahead of the first `diff --git` header).
#
# Purpose: move image validation out of `Autotagger.predict`
# (autotagger/autotagger.py) and into a new `open_image` helper in the
# `autotag` script.
#
# autotag
#   * Adds `import logging` and `import PIL`.
#   * Main loop: `files` is now built by calling `open_image(filepath)` for
#     each path in the batch and dropping the `None` results via
#     `filter(None, ...)`.
#   * Each prediction is paired with `Path(file.name)` taken from the file
#     objects that survived, instead of with the original `filepaths` list.
#     In the removed code, main zipped the unfiltered `filepaths` with the
#     predictions while `predict` dropped images that failed to load.
#     NOTE(review): presumably that could attach tags to the wrong path after
#     a skipped file, which pairing on `files` avoids -- confirm.
#   * New `open_image(filepath)`: opens the path with `click.open_file`, runs
#     `PIL.Image.open` + `load()` + `close()` on it as a validity check, then
#     returns a second, freshly opened handle. On `PIL.UnidentifiedImageError`
#     or any other `Exception` it logs a warning and returns `None`.
#
# autotagger/autotagger.py
#   * `predict` returns immediately when `files` is empty.
#   * Removes the nested `create_image` helper (bare `except:` plus a print to
#     stderr) and the `None` filter; `images` is built directly with
#     `PILImage.create(file)`.
#
# NOTE(review): on both `except` paths of `open_image` the first handle opened
#   is not explicitly closed -- confirm whether that is acceptable.
# NOTE(review): the handles returned by `open_image` are not explicitly closed
#   in the visible part of the main loop -- check the rest of `main`.
# NOTE(review): `import PIL` on its own does not import the `PIL.Image`
#   submodule; `PIL.Image.open` presumably relies on another import having
#   loaded it -- confirm, or import `PIL.Image` explicitly.
# NOTE(review): the bare `return` in `predict` only suits the `zip(files,
#   predictions)` caller if `predict` is a generator; the rest of its body is
#   outside this hunk -- confirm.
# NOTE(review): this copy of the patch appears to have its line breaks
#   collapsed; the hunk layout must be restored before it can be applied.
diff --git a/autotag b/autotag index a0165f3..1d9c93b 100755 --- a/autotag +++ b/autotag @@ -3,6 +3,8 @@ import json import click import itertools +import logging +import PIL from autotagger import Autotagger from pathlib import Path from more_itertools import ichunked @@ -29,12 +31,11 @@ def main(files, threshold, limit, bs, csv, input_file, group_tags, name_only, mo click.get_current_context().exit() for filepaths in ichunked(paths, bs): - filepaths = list(filepaths) - files = [click.open_file(filepath, "rb") for filepath in filepaths] + files = list(filter(None, [open_image(filepath) for filepath in filepaths])) predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs) - for filepath, tags in zip(filepaths, predictions): - output_result(filepath, tags, csv, group_tags, name_only) + for file, tags in zip(files, predictions): + output_result(Path(file.name), tags, csv, group_tags, name_only) def output_result(filepath, tags, csv, group_tags, name_only): name = filepath.stem if name_only else str(filepath) @@ -60,5 +61,20 @@ def get_filepaths(paths): def recurse_dir(directory): return (path for path in directory.glob("**/*") if not path.is_dir()) +def open_image(filepath): + try: + # Load the image to check that it's a valid file. Open the file twice because PIL closes the underlying file. 
+ file = click.open_file(filepath, "rb") + image = PIL.Image.open(file) + image.load() + image.close() + return click.open_file(filepath, "rb") + except PIL.UnidentifiedImageError as err: + logging.warning(f"Skipped {filepath} (not an image)") + return None + except Exception as err: + logging.warning(f"Skipped {filepath} ({type(err).__name__}: {str(err)})") + return None + if __name__ == "__main__": main() diff --git a/autotagger/autotagger.py b/autotagger/autotagger.py index 76edc29..b2388e8 100644 --- a/autotagger/autotagger.py +++ b/autotagger/autotagger.py @@ -28,14 +28,11 @@ class Autotagger: return learn def predict(self, files, threshold=0.01, limit=50, bs=64): + if not files: + return + with self.learn.no_bar(), self.learn.no_logging(): - def create_image(file): - try: - return PILImage.create(file) - except: - print("skipped file " + file.name, file=sys.stderr) - return None - images = list(filter(lambda i: i != None, [create_image(file) for file in files])) + images = [PILImage.create(file) for file in files] dl = self.learn.dls.test_dl(images, bs=bs) batch, _ = self.learn.get_preds(dl=dl)