autotag: add CSV output mode.
This commit is contained in:
@@ -11,11 +11,12 @@ from more_itertools import ichunked
|
||||
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
|
||||
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
|
||||
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
|
||||
@click.option("-c", "--csv", is_flag=True, help="Output CSV instead of JSON.")
|
||||
@click.option("-g/-f", "--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
|
||||
@click.option("-N", "--name-only", is_flag=True, help="Output only the filename without the full path or extension.")
|
||||
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
|
||||
@click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path), required=True)
|
||||
def main(files, threshold, limit, bs, group_tags, name_only, model):
|
||||
def main(files, threshold, limit, bs, csv, group_tags, name_only, model):
|
||||
autotagger = Autotagger(model)
|
||||
|
||||
for filepaths in ichunked(get_filepaths(files), bs):
|
||||
@@ -24,15 +25,24 @@ def main(files, threshold, limit, bs, group_tags, name_only, model):
|
||||
predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs)
|
||||
|
||||
for filepath, tags in zip(filepaths, predictions):
|
||||
name = filepath.stem if name_only else str(filepath)
|
||||
output_result(filepath, tags, csv, group_tags, name_only)
|
||||
|
||||
if group_tags:
|
||||
data = { "filename": name, "tags": tags }
|
||||
click.echo(json.dumps(data))
|
||||
else:
|
||||
for tag, score in tags.items():
|
||||
data = { "filename": name, "tag": tag, "score": score }
|
||||
click.echo(json.dumps(data))
|
||||
def output_result(filepath, tags, csv, group_tags, name_only):
|
||||
name = filepath.stem if name_only else str(filepath)
|
||||
|
||||
if csv and group_tags:
|
||||
tag_names = " ".join(sorted(tags.keys()))
|
||||
click.echo(f"{name},{tag_names}")
|
||||
elif csv and not group_tags:
|
||||
for tag, score in tags.items():
|
||||
click.echo(f"{name},{tag},{score}")
|
||||
elif group_tags:
|
||||
data = { "filename": name, "tags": tags }
|
||||
click.echo(json.dumps(data))
|
||||
else:
|
||||
for tag, score in tags.items():
|
||||
data = { "filename": name, "tag": tag, "score": score }
|
||||
click.echo(json.dumps(data))
|
||||
|
||||
def get_filepaths(paths):
|
||||
files = (recurse_dir(path) if path.is_dir() else iter([path]) for path in paths)
|
||||
|
||||
Reference in New Issue
Block a user