From 0648b529d035da569992b0fab76782ead422910f Mon Sep 17 00:00:00 2001 From: evazion Date: Wed, 22 Jun 2022 15:34:35 -0500 Subject: [PATCH] autotag: allow outputting filename without path or extension. This is useful when performing batch inference on the Danbooru dataset. In this case the filename is the MD5, so this makes the output return only the MD5 instead of the full path to the image file. --- autotag | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/autotag b/autotag index 4b4bcf6..d5d8d81 100755 --- a/autotag +++ b/autotag @@ -11,10 +11,11 @@ from more_itertools import ichunked @click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.") @click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.") @click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.") -@click.option("--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.") +@click.option("-g/-f", "--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.") +@click.option("-N", "--name-only", is_flag=True, help="Output only the filename without the full path or extension.") @click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.") @click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path), required=True) -def main(files, threshold, limit, bs, group_tags, model): +def main(files, threshold, limit, bs, group_tags, name_only, model): autotagger = Autotagger(model) for filepaths in ichunked(get_filepaths(files), bs): @@ -23,12 +24,14 @@ def main(files, threshold, limit, bs, group_tags, model): predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs) for filepath, tags in zip(filepaths, predictions): + name = filepath.stem if name_only else str(filepath) + if group_tags: - data = { "filename": filepath, "tags": tags } + data = { "filename": name, "tags": tags } click.echo(json.dumps(data)) else: for tag, score in tags.items(): - data = { "filename": filepath, "tag": tag, "score": score } + data = { "filename": name, "tag": tag, "score": score } click.echo(json.dumps(data)) def get_filepaths(paths):