autotag: allow recursively processing directories.
* Allow passing the `autotag` script a mixed list of files and directories; every file in each directory is processed recursively.
* Allow choosing the output format: one (filename, tag, score) tuple per line, or one (filename, tags) tuple per line.
This commit is contained in:
@@ -2,22 +2,40 @@
|
||||
|
||||
import json
|
||||
import click
|
||||
import itertools
|
||||
from autotagger import Autotagger
|
||||
from fastai.vision.core import PILImage
|
||||
from pathlib import Path
|
||||
from more_itertools import ichunked
|
||||
|
||||
@click.command(help="Automatically generate tags for an image.", context_settings=dict(max_content_width=140))
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
@click.option("--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
@click.argument("file", nargs=-1, type=click.Path(exists=True, allow_dash=True), required=True)
def main(file, threshold, limit, bs, group_tags, model):
    """Tag every given file and print one JSON object per output row.

    Args:
        file: Paths given on the command line; each may be a file or a
            directory. Directories are expanded by ``get_filepaths``.
        threshold: Minimum confidence, forwarded to ``Autotagger.predict``.
        limit: Maximum number of tags per image, forwarded to ``predict``.
        bs: Number of paths handed to ``predict`` per batch.
        group_tags: When true, emit one ``{filename, tags}`` row per file;
            otherwise emit one ``{filename, tag, score}`` row per tag.
        model: Path of the model file passed to ``Autotagger``.
    """
    autotagger = Autotagger(model)

    # Work through the paths one batch at a time so a large directory tree is
    # never loaded all at once.
    for filepaths in ichunked(get_filepaths(file), bs):
        # ichunked yields lazy chunks; materialise the chunk because it is
        # consumed twice (once by predict, once by zip below).
        filepaths = list(filepaths)
        predictions = autotagger.predict(filepaths, threshold=threshold, limit=limit, bs=bs)

        for filepath, tags in zip(filepaths, predictions):
            # click.Path hands file arguments over as plain strings, while
            # recursed directory entries are Path objects; normalise so that
            # `.name` works for both instead of raising AttributeError.
            filename = Path(filepath).name

            if group_tags:
                click.echo(json.dumps({ "filename": filename, "tags": tags }))
            else:
                # `tags` is iterated as a tag -> score mapping here.
                for tag, score in tags.items():
                    click.echo(json.dumps({ "filename": filename, "tag": tag, "score": score }))
|
||||
|
||||
def get_filepaths(paths):
    """Expand a mixed list of files and directories into a flat stream of paths.

    Args:
        paths: Iterable of path-like values (``str`` or ``Path``). A value
            that is an existing directory is replaced by the entries that
            ``recurse_dir`` yields for it; any other value is passed through
            as a single path.

    Returns:
        A lazy iterator of ``pathlib.Path`` objects. Every element is a
        ``Path`` (never a bare ``str``), so callers can rely on attributes
        such as ``.name`` regardless of how the file was specified.
    """
    per_argument = (
        recurse_dir(p) if Path(p).is_dir() else (Path(p),)
        for p in paths
    )
    # from_iterable keeps the outer generator lazy instead of unpacking it.
    return itertools.chain.from_iterable(per_argument)
|
||||
|
||||
def recurse_dir(directory):
    """Lazily yield every non-directory entry found anywhere below ``directory``.

    Args:
        directory: Path (``str`` or ``Path``) of the directory to walk.

    Returns:
        A generator of ``pathlib.Path`` objects for each entry in the tree
        that is not itself a directory.
    """
    # rglob("*") is the recursive form of glob("**/*").
    entries = Path(directory).rglob("*")
    return (entry for entry in entries if not entry.is_dir())
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user