96b6a12924
Fix `autotag` so you can pass a filename of `-` to read a file from stdin. This way you can run `docker run --rm -i ghcr.io/danbooru/autotagger autotag - < image.jpg` to perform prediction on a single file outside of Docker.
43 lines
2.0 KiB
Python
Executable File
43 lines
2.0 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
import itertools
import json
from contextlib import ExitStack
from pathlib import Path

import click
from more_itertools import ichunked

from autotagger import Autotagger
|
|
|
|
@click.command(help="Automatically generate tags for a list of images.", context_settings=dict(max_content_width=140))
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
@click.option("--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
@click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path), required=True)
def main(files, threshold, limit, bs, group_tags, model):
    """Tag the given images and print the results as JSON, one object per line.

    Args:
        files: Paths given on the command line (converted to ``pathlib.Path``
            by click). Directories are expanded by ``get_filepaths``; ``-``
            is opened by ``click.open_file`` as standard input.
        threshold: Minimum tag confidence, forwarded to ``Autotagger.predict``.
        limit: Maximum number of tags per image, forwarded to ``predict``.
        bs: Number of images opened and submitted per batch.
        group_tags: When true, print one ``{filename, tags}`` row per image;
            otherwise print one ``{filename, tag, score}`` row per tag.
        model: Path of the model file handed to ``Autotagger``.
    """
    autotagger = Autotagger(model)

    for batch in ichunked(get_filepaths(files), bs):
        filepaths = list(batch)

        # Open the batch through an ExitStack so every handle is released once
        # the batch is done. Going through the context-manager protocol (rather
        # than calling .close()) lets click decide how to treat the "-" stream.
        with ExitStack() as stack:
            handles = [stack.enter_context(click.open_file(filepath, "rb")) for filepath in filepaths]
            predictions = autotagger.predict(handles, threshold=threshold, limit=limit, bs=bs)

            # NOTE(review): predict() is not visible here and may read the
            # handles lazily, so its results are consumed while they are open.
            for filepath, tags in zip(filepaths, predictions):
                # json cannot serialize pathlib.Path objects; emit the path as text.
                filename = str(filepath)

                if group_tags:
                    click.echo(json.dumps({"filename": filename, "tags": tags}))
                else:
                    for tag, score in tags.items():
                        click.echo(json.dumps({"filename": filename, "tag": tag, "score": score}))
|
|
|
|
def get_filepaths(paths):
    """Flatten a mix of file and directory paths into one stream of paths.

    Args:
        paths: Iterable of path objects exposing ``is_dir()`` (for example
            ``pathlib.Path``).

    Returns:
        A lazy iterator. A path that is not a directory is yielded unchanged;
        a directory is replaced by whatever ``recurse_dir`` yields for it.
    """
    per_path = (recurse_dir(path) if path.is_dir() else iter([path]) for path in paths)

    # from_iterable pulls one input path at a time. chain(*per_path) would
    # exhaust `paths` (calling is_dir() on every entry) before yielding anything.
    return itertools.chain.from_iterable(per_path)
|
|
|
|
def recurse_dir(directory):
    """Yield every non-directory entry found anywhere beneath ``directory``.

    Args:
        directory: A path object providing ``glob()`` (for example
            ``pathlib.Path``).

    Yields:
        Each path matched by the recursive ``**/*`` pattern for which
        ``is_dir()`` is false.
    """
    for entry in directory.glob("**/*"):
        if entry.is_dir():
            continue
        yield entry
|
|
|
|
# Script entry point: invoke the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|