#!/usr/bin/env python
"""Command-line tool that prints automatically generated tags for images."""

import csv as csv_module  # aliased: `csv` is also the name of a CLI flag parameter
import io
import itertools
import json
import logging
from pathlib import Path

import click
import PIL
from fastai.vision.core import PILImage
from more_itertools import ichunked

from autotagger import Autotagger


@click.command(help="Automatically generate tags for a list of images.", context_settings=dict(max_content_width=140))
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
@click.option("-c", "--csv", is_flag=True, help="Output CSV instead of JSON.")
@click.option("-i", "--input-file", type=click.File(), help="Read list of image filenames from text file.")
@click.option("-g/-f", "--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
@click.option("-N", "--name-only", is_flag=True, help="Output only the filename without the full path or extension.")
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
@click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path))
def main(files, threshold, limit, bs, csv, input_file, group_tags, name_only, model):
    """Tag every image named on the command line or in ``--input-file``.

    Paths are processed in batches of ``bs``. Files that cannot be opened as
    images are skipped with a warning. One result per image (or one per
    image/tag pair with ``--flatten-tags``) is written to stdout as JSON, or
    as CSV when ``--csv`` is given. With no input at all, the help text is
    printed and the command exits.
    """
    ctx = click.get_current_context()

    if input_file:
        # Skip blank lines: an empty string would become Path("."), which only
        # produces a confusing "Skipped ." warning later on.
        stripped = (line.rstrip() for line in input_file)
        paths = (Path(entry) for entry in stripped if entry)
    elif files:
        paths = get_filepaths(files)
    else:
        click.echo(ctx.get_help())
        ctx.exit()

    # Load the model only after the arguments have been validated, so a
    # help-only invocation does not pay for it.
    autotagger = Autotagger(model)

    for batch in ichunked(paths, bs):
        opened = [pair for pair in map(open_image, batch) if pair]
        if not opened:
            # Every file in this batch was skipped; nothing to predict.
            continue

        image_paths = [pair[0] for pair in opened]
        images = [pair[1] for pair in opened]
        predictions = autotagger.predict(images, threshold=threshold, limit=limit, bs=bs)

        for image_path, tags in zip(image_paths, predictions):
            # NOTE(review): only the basename is passed on, so the directory
            # part never reaches the output even without --name-only — confirm
            # this is intended.
            output_result(Path(image_path.name), tags, csv, group_tags, name_only)


def _format_csv_row(fields):
    """Return *fields* as one newline-terminated CSV record.

    Goes through the ``csv`` writer so that fields containing commas, quotes
    or newlines are quoted instead of corrupting the row.
    """
    buffer = io.StringIO()
    csv_module.writer(buffer, lineterminator="\n").writerow(fields)
    return buffer.getvalue()


def output_result(filepath, tags, csv, group_tags, name_only):
    """Write the tags predicted for one image to stdout.

    Args:
        filepath: Path of the image the tags belong to.
        tags: Mapping iterated as ``tag -> score`` pairs.
        csv: Emit CSV rows when true, JSON objects otherwise.
        group_tags: When true, emit a single row holding all tags of the
            image; otherwise emit one row per tag with its score.
        name_only: When true, identify the image by ``filepath.stem`` instead
            of the full ``str(filepath)``.
    """
    name = filepath.stem if name_only else str(filepath)

    if csv and group_tags:
        tag_names = " ".join(sorted(tags))
        click.echo(_format_csv_row([name, tag_names]), nl=False)
    elif csv:
        for tag, score in tags.items():
            # Pre-format the score so its text matches plain f-string output.
            click.echo(_format_csv_row([name, tag, f"{score}"]), nl=False)
    elif group_tags:
        click.echo(json.dumps({"filename": name, "tags": tags}))
    else:
        for tag, score in tags.items():
            click.echo(json.dumps({"filename": name, "tag": tag, "score": score}))


def get_filepaths(paths):
    """Return an iterator over *paths* with directories expanded recursively.

    Each directory is replaced by every non-directory entry below it; any
    other path is yielded unchanged.
    """
    return itertools.chain.from_iterable(
        recurse_dir(path) if path.is_dir() else (path,)
        for path in paths
    )


def recurse_dir(directory):
    """Lazily yield every non-directory path found anywhere under *directory*."""
    return (path for path in directory.glob("**/*") if not path.is_dir())


def open_image(filepath):
    """Open *filepath* as an image.

    Returns:
        A ``(filepath, image)`` tuple, or ``None`` when the file could not be
        opened or decoded. A warning is logged for every skipped file.
    """
    try:
        with click.open_file(filepath, "rb") as file:
            return (filepath, PILImage.create(file))
    except PIL.UnidentifiedImageError:
        logging.warning("Skipped %s (not an image)", filepath)
        return None
    except Exception as err:
        # Deliberately broad: one unreadable file must not abort the whole run.
        logging.warning("Skipped %s (%s: %s)", filepath, type(err).__name__, err)
        return None


if __name__ == "__main__":
    main()