From 07b84d73fecdb886aea9793c0f34d4a42ee7cf7b Mon Sep 17 00:00:00 2001 From: evazion Date: Tue, 21 Jun 2022 20:08:14 -0500 Subject: [PATCH] Do batch prediction. Do batch prediction inside the `predict` method instead of calling `predict` once for each image. --- app.py | 13 ++++++++++--- autotag | 8 ++++++-- autotagger/autotagger.py | 17 ++++++++++------- templates/evaluate.html | 8 ++++---- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/app.py b/app.py index 75459d6..88cc34d 100755 --- a/app.py +++ b/app.py @@ -5,6 +5,7 @@ from dotenv import load_dotenv from autotagger import Autotagger from base64 import b64encode from flask import Flask, request, render_template, jsonify +from fastai.vision.core import PILImage load_dotenv() model_path = getenv("MODEL_PATH", "models/model.pth") @@ -21,15 +22,21 @@ def index(): @app.route("/evaluate", methods=["POST"]) def evaluate(): files = request.files.getlist("file") + images = [PILImage.create(file) for file in files] threshold = float(request.form.get("threshold", 0.1)) output = request.form.get("format", "html") limit = int(request.form.get("limit", 50)) + predictions = autotagger.predict(images, threshold=threshold, limit=limit) + if output == "html": - predictions = [{ "data": b64encode(data).decode(), "tags": autotagger.predict(data, threshold=threshold, limit=limit) } for data in (file.stream.read() for file in files)] - return render_template("evaluate.html", predictions=predictions) + for file in files: + file.seek(0) + + base64data = [b64encode(file.read()).decode() for file in files] + return render_template("evaluate.html", predictions=zip(base64data, predictions)) elif output == "json": - predictions = [{ "filename": file.filename, "tags": autotagger.predict(file.read(), threshold=threshold, limit=limit) } for file in files] + predictions = [{ "filename": file.filename, "tags": tags } for file, tags in zip(files, predictions)] return jsonify(predictions) else: return 400 diff --git a/autotag b/autotag index 6bd5669..570fc9b 100755 --- a/autotag +++ b/autotag @@ -12,8 +12,12 @@ from fastai.vision.core import PILImage @click.argument("file", nargs=-1, type=click.File("rb"), required=True) def main(file, threshold, limit, model): autotagger = Autotagger(model) - predictions = [{ "filename": f.name, "tags": autotagger.predict(PILImage.create(f), threshold=threshold, limit=limit) } for f in file] - click.echo(json.dumps(predictions, indent=2)) + images = [PILImage.create(f) for f in file] + predictions = autotagger.predict(images, threshold=threshold, limit=limit) + + for i, tags in enumerate(predictions): + data = { "filename": file[i].name, "tags": tags } + click.echo(json.dumps(data)) if __name__ == "__main__": main() diff --git a/autotagger/autotagger.py b/autotagger/autotagger.py index 5678d87..f1e1f5c 100644 --- a/autotagger/autotagger.py +++ b/autotagger/autotagger.py @@ -1,5 +1,5 @@ from fastbook import * -from pandas import read_csv +from pandas import DataFrame, read_csv import timm class Autotagger: @@ -26,10 +26,13 @@ class Autotagger: return learn - def predict(self, path, threshold=0.01, limit=50): + def predict(self, images, threshold=0.01, limit=50, bs=64): with self.learn.no_bar(), self.learn.no_logging(): - pred = self.learn.predict(path) - scores = [score.item() for score in pred[2]] - results = { tag : score for (tag, score) in zip(self.learn.dls.vocab, scores) if score >= threshold } - results = sorted(results.items(), key = lambda x: -x[1]) - return dict(results[:limit]) + dl = self.learn.dls.test_dl(images, bs=bs) + batch, _ = self.learn.get_preds(dl=dl) + + for scores in batch: + df = DataFrame({ "tag": self.learn.dls.vocab, "score": scores }) + df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit) + tags = dict(zip(df.tag, df.score)) + yield tags diff --git a/templates/evaluate.html b/templates/evaluate.html index c615667..4a41175 100644 --- a/templates/evaluate.html +++ b/templates/evaluate.html @@ -9,15 +9,15 @@ < Back
- {% for prediction in predictions %} + {% for base64data, tags in predictions %}
- +
- {% for tag, score in prediction["tags"].items() %} + {% for tag, score in tags.items() %}
? @@ -28,7 +28,7 @@ {% endfor %}
- +
{% endfor %}