autotag: allow outputting filename without path or extension.
This is useful when performing batch inference on the Danbooru dataset. In this case the filename is the MD5, so this makes the output return only the MD5 instead of the full path to the image file.
This commit is contained in:
@@ -11,10 +11,11 @@ from more_itertools import ichunked
|
||||
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
|
||||
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
|
||||
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
|
||||
@click.option("--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
|
||||
@click.option("-g/-f", "--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
|
||||
@click.option("-N", "--name-only", is_flag=True, help="Output only the filename without the full path or extension.")
|
||||
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
|
||||
@click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path), required=True)
|
||||
def main(files, threshold, limit, bs, group_tags, model):
|
||||
def main(files, threshold, limit, bs, group_tags, name_only, model):
|
||||
autotagger = Autotagger(model)
|
||||
|
||||
for filepaths in ichunked(get_filepaths(files), bs):
|
||||
@@ -23,12 +24,14 @@ def main(files, threshold, limit, bs, group_tags, model):
|
||||
predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs)
|
||||
|
||||
for filepath, tags in zip(filepaths, predictions):
|
||||
name = filepath.stem if name_only else str(filepath)
|
||||
|
||||
if group_tags:
|
||||
data = { "filename": filepath, "tags": tags }
|
||||
data = { "filename": name, "tags": tags }
|
||||
click.echo(json.dumps(data))
|
||||
else:
|
||||
for tag, score in tags.items():
|
||||
data = { "filename": filepath, "tag": tag, "score": score }
|
||||
data = { "filename": name, "tag": tag, "score": score }
|
||||
click.echo(json.dumps(data))
|
||||
|
||||
def get_filepaths(paths):
|
||||
|
||||
Reference in New Issue
Block a user