0648b529d0
This is useful when running batch inference on the Danbooru dataset, where each image's filename is its MD5 hash: with this option, the output contains only the MD5 instead of the full path to the image file.
46 lines
2.2 KiB
Python
Executable File
46 lines
2.2 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
import itertools
import json
from contextlib import ExitStack
from pathlib import Path

import click
from more_itertools import ichunked

from autotagger import Autotagger
|
|
|
|
@click.command(help="Automatically generate tags for a list of images.", context_settings=dict(max_content_width=140))
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
@click.option("-g/-f", "--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
@click.option("-N", "--name-only", is_flag=True, help="Output only the filename without the full path or extension.")
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
@click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path), required=True)
def main(files, threshold, limit, bs, group_tags, name_only, model):
    """Tag every image in FILES and print the results as JSON lines.

    Directories in FILES are searched recursively for files. Images are
    opened and passed to ``Autotagger.predict`` in batches of ``bs``.

    With ``--group-tags`` (the default) one ``{"filename", "tags"}`` object
    is printed per image; with ``--flatten-tags`` one
    ``{"filename", "tag", "score"}`` object is printed per predicted tag.
    """
    autotagger = Autotagger(model)

    for batch in ichunked(get_filepaths(files), bs):
        filepaths = list(batch)

        # Close every handle in the batch once its results are printed,
        # rather than leaving up to `bs` files open until garbage collection.
        # Predictions are consumed inside the `with` in case predict() reads
        # the files lazily while its results are iterated.
        with ExitStack() as stack:
            images = [stack.enter_context(click.open_file(filepath, "rb")) for filepath in filepaths]
            predictions = autotagger.predict(images, threshold=threshold, limit=limit, bs=bs)

            for filepath, tags in zip(filepaths, predictions):
                # --name-only keeps just the stem: directory and final
                # extension are dropped (e.g. Danbooru's <md5>.jpg -> <md5>).
                name = filepath.stem if name_only else str(filepath)

                if group_tags:
                    data = { "filename": name, "tags": tags }
                    click.echo(json.dumps(data))
                else:
                    for tag, score in tags.items():
                        data = { "filename": name, "tag": tag, "score": score }
                        click.echo(json.dumps(data))
|
|
|
|
def get_filepaths(paths):
    """Expand *paths* into a lazy iterator of file paths.

    Each directory in *paths* is replaced by every non-directory entry found
    beneath it (via ``recurse_dir``); any other path is yielded unchanged,
    in the order given.
    """
    # chain.from_iterable pulls from *paths* on demand; the previous
    # chain(*...) unpacked (and stat'ed) every argument before the first
    # file could be yielded.
    return itertools.chain.from_iterable(
        recurse_dir(path) if path.is_dir() else (path,) for path in paths
    )
|
|
|
|
def recurse_dir(directory):
    """Return a lazy iterator over every non-directory entry under *directory*, at any depth."""
    # rglob("*") walks the same tree as glob("**/*").
    entries = directory.rglob("*")
    return (entry for entry in entries if not entry.is_dir())
|
|
|
|
if __name__ == "__main__":
    # Run the click command only when executed as a script, not when imported.
    main()
|