Do batch prediction.

Do batch prediction inside the `predict` method instead of calling
`predict` once for each image.
This commit is contained in:
evazion
2022-06-21 20:08:14 -05:00
parent 5c90c7378d
commit 07b84d73fe
4 changed files with 30 additions and 16 deletions
+10 -7
View File
@@ -1,5 +1,5 @@
from fastbook import *
from pandas import read_csv
from pandas import DataFrame, read_csv
import timm
class Autotagger:
@@ -26,10 +26,13 @@ class Autotagger:
return learn
def predict(self, path, threshold=0.01, limit=50):
def predict(self, images, threshold=0.01, limit=50, bs=64):
with self.learn.no_bar(), self.learn.no_logging():
pred = self.learn.predict(path)
scores = [score.item() for score in pred[2]]
results = { tag : score for (tag, score) in zip(self.learn.dls.vocab, scores) if score >= threshold }
results = sorted(results.items(), key = lambda x: -x[1])
return dict(results[:limit])
dl = self.learn.dls.test_dl(images, bs=bs)
batch, _ = self.learn.get_preds(dl=dl)
for scores in batch:
df = DataFrame({ "tag": self.learn.dls.vocab, "score": scores })
df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit)
tags = dict(zip(df.tag, df.score))
yield tags