Do batch prediction.
Do batch prediction inside the `predict` method instead of calling `predict` once for each image.
This commit is contained in:
@@ -5,6 +5,7 @@ from dotenv import load_dotenv
|
||||
from autotagger import Autotagger
|
||||
from base64 import b64encode
|
||||
from flask import Flask, request, render_template, jsonify
|
||||
from fastai.vision.core import PILImage
|
||||
|
||||
load_dotenv()
|
||||
model_path = getenv("MODEL_PATH", "models/model.pth")
|
||||
@@ -21,15 +22,21 @@ def index():
|
||||
@app.route("/evaluate", methods=["POST"])
|
||||
def evaluate():
|
||||
files = request.files.getlist("file")
|
||||
images = [PILImage.create(file) for file in files]
|
||||
threshold = float(request.form.get("threshold", 0.1))
|
||||
output = request.form.get("format", "html")
|
||||
limit = int(request.form.get("limit", 50))
|
||||
|
||||
predictions = autotagger.predict(images, threshold=threshold, limit=limit)
|
||||
|
||||
if output == "html":
|
||||
predictions = [{ "data": b64encode(data).decode(), "tags": autotagger.predict(data, threshold=threshold, limit=limit) } for data in (file.stream.read() for file in files)]
|
||||
return render_template("evaluate.html", predictions=predictions)
|
||||
for file in files:
|
||||
file.seek(0)
|
||||
|
||||
base64data = [b64encode(file.read()).decode() for file in files]
|
||||
return render_template("evaluate.html", predictions=zip(base64data, predictions))
|
||||
elif output == "json":
|
||||
predictions = [{ "filename": file.filename, "tags": autotagger.predict(file.read(), threshold=threshold, limit=limit) } for file in files]
|
||||
predictions = [{ "filename": file.filename, "tags": tags } for file, tags in zip(files, predictions)]
|
||||
return jsonify(predictions)
|
||||
else:
|
||||
return 400
|
||||
|
||||
@@ -12,8 +12,12 @@ from fastai.vision.core import PILImage
|
||||
@click.argument("file", nargs=-1, type=click.File("rb"), required=True)
|
||||
def main(file, threshold, limit, model):
|
||||
autotagger = Autotagger(model)
|
||||
predictions = [{ "filename": f.name, "tags": autotagger.predict(PILImage.create(f), threshold=threshold, limit=limit) } for f in file]
|
||||
click.echo(json.dumps(predictions, indent=2))
|
||||
images = [PILImage.create(f) for f in file]
|
||||
predictions = autotagger.predict(images, threshold=threshold, limit=limit)
|
||||
|
||||
for i, tags in enumerate(predictions):
|
||||
data = { "filename": file[i].name, "tags": tags }
|
||||
click.echo(json.dumps(data))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastbook import *
|
||||
from pandas import read_csv
|
||||
from pandas import DataFrame, read_csv
|
||||
import timm
|
||||
|
||||
class Autotagger:
|
||||
@@ -26,10 +26,13 @@ class Autotagger:
|
||||
|
||||
return learn
|
||||
|
||||
def predict(self, path, threshold=0.01, limit=50):
|
||||
def predict(self, images, threshold=0.01, limit=50, bs=64):
|
||||
with self.learn.no_bar(), self.learn.no_logging():
|
||||
pred = self.learn.predict(path)
|
||||
scores = [score.item() for score in pred[2]]
|
||||
results = { tag : score for (tag, score) in zip(self.learn.dls.vocab, scores) if score >= threshold }
|
||||
results = sorted(results.items(), key = lambda x: -x[1])
|
||||
return dict(results[:limit])
|
||||
dl = self.learn.dls.test_dl(images, bs=bs)
|
||||
batch, _ = self.learn.get_preds(dl=dl)
|
||||
|
||||
for scores in batch:
|
||||
df = DataFrame({ "tag": self.learn.dls.vocab, "score": scores })
|
||||
df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit)
|
||||
tags = dict(zip(df.tag, df.score))
|
||||
yield tags
|
||||
|
||||
@@ -9,15 +9,15 @@
|
||||
<a class="text-xs text-sky-600 hover:text-sky-500 mr-4" href="/">< Back</a>
|
||||
|
||||
<div class="mt-4">
|
||||
{% for prediction in predictions %}
|
||||
{% for base64data, tags in predictions %}
|
||||
<div class="flex flex-col p-2 gap-2 border rounded md:flex-row md:max-h-[80vh]">
|
||||
<div class="flex-1 flex items-center justify-center">
|
||||
<img class="max-w-full max-h-full h-auto" src="data:image/jpg;base64,{{ prediction["data"] | safe }}">
|
||||
<img class="max-w-full max-h-full h-auto" src="data:image/jpg;base64,{{ base64data | safe }}">
|
||||
</div>
|
||||
|
||||
<div class="flex-0 overflow-scroll md:pr-2">
|
||||
<table class="w-full leading-4">
|
||||
{% for tag, score in prediction["tags"].items() %}
|
||||
{% for tag, score in tags.items() %}
|
||||
<tr>
|
||||
<td>
|
||||
<a class="text-sky-600 hover:text-sky-500" href="https://danbooru.donmai.us/wiki_pages/{{ tag | urlencode }}">?</a>
|
||||
@@ -28,7 +28,7 @@
|
||||
{% endfor %}
|
||||
</table>
|
||||
|
||||
<textarea class="w-full text-gray-500 mt-2" rows="4">{{ " ".join(prediction["tags"].keys()) }}</textarea>
|
||||
<textarea class="w-full text-gray-500 mt-2" rows="4">{{ " ".join(tags.keys()) }}</textarea>
|
||||
</div>
|
||||
</div>
|
||||
{% endfor %}
|
||||
|
||||
Reference in New Issue
Block a user