diff --git a/app.py b/app.py index 88cc34d..ea24e54 100755 --- a/app.py +++ b/app.py @@ -5,7 +5,6 @@ from dotenv import load_dotenv from autotagger import Autotagger from base64 import b64encode from flask import Flask, request, render_template, jsonify -from fastai.vision.core import PILImage load_dotenv() model_path = getenv("MODEL_PATH", "models/model.pth") @@ -22,12 +21,11 @@ def index(): @app.route("/evaluate", methods=["POST"]) def evaluate(): files = request.files.getlist("file") - images = [PILImage.create(file) for file in files] threshold = float(request.form.get("threshold", 0.1)) output = request.form.get("format", "html") limit = int(request.form.get("limit", 50)) - predictions = autotagger.predict(images, threshold=threshold, limit=limit) + predictions = autotagger.predict(files, threshold=threshold, limit=limit) if output == "html": for file in files: diff --git a/autotag b/autotag index 5e39fce..4b4bcf6 100755 --- a/autotag +++ b/autotag @@ -7,19 +7,20 @@ from autotagger import Autotagger from pathlib import Path from more_itertools import ichunked -@click.command(help="Automatically generate tags for an image.", context_settings=dict(max_content_width=140)) +@click.command(help="Automatically generate tags for a list of images.", context_settings=dict(max_content_width=140)) @click.option("-t", "--threshold", default=0.01, type=float, show_default=True, help="The minimum tag confidence level.") @click.option("-n", "--limit", default=50, type=int, show_default=True, help="The maximum number of tags to return per image.") @click.option("-b", "--batch", "bs", default=128, type=int, show_default=True, help="The number of images to process per batch.") @click.option("--group-tags/--flatten-tags", default=True, show_default=True, help="Output rows in {filename, tags} format or {filename, tag, score} format.") @click.option("-m", "--model", default="models/model.pth", type=click.Path(exists=True), show_default=True, help="The model to use.") -@click.argument("file", nargs=-1, type=click.Path(exists=True, allow_dash=True), required=True) -def main(file, threshold, limit, bs, group_tags, model): +@click.argument("files", nargs=-1, type=click.Path(exists=True, allow_dash=True, path_type=Path), required=True) +def main(files, threshold, limit, bs, group_tags, model): autotagger = Autotagger(model) - for filepaths in ichunked(get_filepaths(file), bs): + for filepaths in ichunked(get_filepaths(files), bs): filepaths = list(filepaths) - predictions = autotagger.predict(filepaths, threshold=threshold, limit=limit, bs=bs) + files = [click.open_file(filepath, "rb") for filepath in filepaths] + predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs) for filepath, tags in zip(filepaths, predictions): if group_tags: @@ -31,11 +32,11 @@ def main(file, threshold, limit, bs, group_tags, model): click.echo(json.dumps(data)) def get_filepaths(paths): - files = (recurse_dir(p) if Path(p).is_dir() else iter([p]) for p in paths) + files = (recurse_dir(path) if path.is_dir() else iter([path]) for path in paths) return itertools.chain(*files) def recurse_dir(directory): - return (path for path in Path(directory).glob("**/*") if not path.is_dir()) + return (path for path in directory.glob("**/*") if not path.is_dir()) if __name__ == "__main__": main() diff --git a/autotagger/autotagger.py b/autotagger/autotagger.py index f1e1f5c..d195e3a 100644 --- a/autotagger/autotagger.py +++ b/autotagger/autotagger.py @@ -26,8 +26,9 @@ class Autotagger: return learn - def predict(self, images, threshold=0.01, limit=50, bs=64): + def predict(self, files, threshold=0.01, limit=50, bs=64): with self.learn.no_bar(), self.learn.no_logging(): + images = [PILImage.create(file) for file in files] dl = self.learn.dls.test_dl(images, bs=bs) batch, _ = self.learn.get_preds(dl=dl)