Files
autotagger-win/autotag
T
evazion 0648b529d0 autotag: allow outputting filename without path or extension.
This is useful when performing batch inference on the Danbooru dataset.
In this case the filename is the MD5, so this makes the output return
only the MD5 instead of the full path to the image file.
2022-06-22 16:12:46 -05:00

46 lines
2.2 KiB
Python
Executable File

#!/usr/bin/env python
import json
import click
import itertools
from autotagger import Autotagger
from pathlib import Path
from more_itertools import ichunked
@click.command(help="Automatically generate tags for a list of images.", context_settings=dict(max_content_width=140))
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
@click.option("-g/-f", "--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
@click.option("-N", "--name-only", is_flag=True, help="Output only the filename without the full path or extension.")
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
@click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path), required=True)
def main(files, threshold, limit, bs, group_tags, name_only, model):
    """Tag the given images in batches and print one JSON object per line.

    The CLI help text comes from the explicit ``help=`` passed to
    ``click.command`` above, not from this docstring.

    Args:
        files: Paths from the command line; expanded through ``get_filepaths``.
        threshold: Forwarded to ``Autotagger.predict`` as ``threshold``.
        limit: Forwarded to ``Autotagger.predict`` as ``limit``.
        bs: Number of paths opened and predicted together per batch; also
            forwarded to ``Autotagger.predict`` as ``bs``.
        group_tags: When true, emit one ``{"filename", "tags"}`` object per
            image; otherwise emit one ``{"filename", "tag", "score"}`` object
            per tag.
        name_only: When true, report ``Path.stem`` instead of the full path.
        model: Model path handed to the ``Autotagger`` constructor.
    """
    # Imported locally so this block stays self-contained with respect to the
    # file's import header.
    from contextlib import ExitStack

    autotagger = Autotagger(model)

    for batch in ichunked(get_filepaths(files), bs):
        # Materialize the chunk: the paths are needed twice (open + report).
        batch_paths = list(batch)

        # Scope the handles to the batch so they are released before the next
        # one is opened; previously every opened file stayed open until exit.
        # NOTE(review): click documents that a "-" (stdin) stream is not closed
        # when used as a context manager -- confirm for the pinned click version.
        with ExitStack() as stack:
            handles = [stack.enter_context(click.open_file(path, "rb")) for path in batch_paths]
            predictions = autotagger.predict(handles, threshold=threshold, limit=limit, bs=bs)

            # Consume the predictions while the files are still open, in case
            # predict() reads them lazily.
            for filepath, tags in zip(batch_paths, predictions):
                name = filepath.stem if name_only else str(filepath)

                if group_tags:
                    click.echo(json.dumps({ "filename": name, "tags": tags }))
                    continue

                # Flattened output: one row per (tag, score) pair.
                for tag, score in tags.items():
                    click.echo(json.dumps({ "filename": name, "tag": tag, "score": score }))
def get_filepaths(paths):
    """Lazily iterate over every file path named by *paths*.

    A path whose ``is_dir()`` is true is expanded through ``recurse_dir``;
    any other path is yielded unchanged.

    Args:
        paths: Iterable of path objects providing ``is_dir()``.

    Returns:
        An iterator over the resulting paths, in input order.
    """
    per_path = (recurse_dir(path) if path.is_dir() else (path,) for path in paths)
    # from_iterable pulls from per_path on demand; chain(*per_path) would
    # exhaust it (calling is_dir() on every argument) before yielding anything.
    return itertools.chain.from_iterable(per_path)
def recurse_dir(directory):
    """Return a lazy iterator over the non-directory entries below *directory*.

    Entries are discovered with ``directory.glob("**/*")``; anything for
    which ``is_dir()`` is true is skipped.
    """
    descendants = directory.glob("**/*")
    return (entry for entry in descendants if not entry.is_dir())
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()