autotag: accept files from stdin with '-'.

Fix `autotag` so you can pass a filename of '-' to read a file from
stdin. This way you can do this:

  docker run --rm -i ghcr.io/danbooru/autotagger autotag - < image.jpg

...to perform prediction on a single file that lives outside of the Docker container.
This commit is contained in:
evazion
2022-06-22 15:21:30 -05:00
parent e4344ab2d4
commit 96b6a12924
3 changed files with 11 additions and 11 deletions
+1 -3
View File
@@ -5,7 +5,6 @@ from dotenv import load_dotenv
from autotagger import Autotagger
from base64 import b64encode
from flask import Flask, request, render_template, jsonify
from fastai.vision.core import PILImage
load_dotenv()
model_path = getenv("MODEL_PATH", "models/model.pth")
@@ -22,12 +21,11 @@ def index():
@app.route("/evaluate", methods=["POST"])
def evaluate():
files = request.files.getlist("file")
images = [PILImage.create(file) for file in files]
threshold = float(request.form.get("threshold", 0.1))
output = request.form.get("format", "html")
limit = int(request.form.get("limit", 50))
predictions = autotagger.predict(images, threshold=threshold, limit=limit)
predictions = autotagger.predict(files, threshold=threshold, limit=limit)
if output == "html":
for file in files:
+8 -7
View File
@@ -7,19 +7,20 @@ from autotagger import Autotagger
from pathlib import Path
from more_itertools import ichunked
@click.command(help="Automatically generate tags for an image.", context_settings=dict(max_content_width=140))
@click.command(help="Automatically generate tags for a list of images.", context_settings=dict(max_content_width=140))
@click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.")
@click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.")
@click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.")
@click.option("--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.")
@click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.")
@click.argument("file", nargs=-1, type=click.Path(exists=True, allow_dash=True), required=True)
def main(file, threshold, limit, bs, group_tags, model):
@click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path), required=True)
def main(files, threshold, limit, bs, group_tags, model):
autotagger = Autotagger(model)
for filepaths in ichunked(get_filepaths(file), bs):
for filepaths in ichunked(get_filepaths(files), bs):
filepaths = list(filepaths)
predictions = autotagger.predict(filepaths, threshold=threshold, limit=limit, bs=bs)
files = [click.open_file(filepath, "rb") for filepath in filepaths]
predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs)
for filepath, tags in zip(filepaths, predictions):
if group_tags:
@@ -31,11 +32,11 @@ def main(file, threshold, limit, bs, group_tags, model):
click.echo(json.dumps(data))
def get_filepaths(paths):
files = (recurse_dir(p) if Path(p).is_dir() else iter([p]) for p in paths)
files = (recurse_dir(path) if path.is_dir() else iter([path]) for path in paths)
return itertools.chain(*files)
def recurse_dir(directory):
return (path for path in Path(directory).glob("**/*") if not path.is_dir())
return (path for path in directory.glob("**/*") if not path.is_dir())
if __name__ == "__main__":
main()
+2 -1
View File
@@ -26,8 +26,9 @@ class Autotagger:
return learn
def predict(self, images, threshold=0.01, limit=50, bs=64):
def predict(self, files, threshold=0.01, limit=50, bs=64):
with self.learn.no_bar(), self.learn.no_logging():
images = [PILImage.create(file) for file in files]
dl = self.learn.dls.test_dl(images, bs=bs)
batch, _ = self.learn.get_preds(dl=dl)