Files
autotagger-win/autotagger/autotagger.py
T
evazion b975102633 Fix race condition in predict method.
Fix a race condition here:

  with self.learn.no_bar(), self.learn.no_logging():
    dl = self.learn.dls.test_dl(files, bs=bs)
    batch, _ = self.learn.get_preds(dl=dl)

The calls to `no_bar()` and maybe `no_logging()` were unsafe in a
multithreaded environment because they modified the shared `self.learn`
object. This led to random exceptions and deadlocks when gunicorn was
under load in production.
2022-06-30 00:13:35 -05:00

46 lines
1.7 KiB
Python

from fastbook import *
from pandas import DataFrame, read_csv
from fastai.imports import noop
from fastai.callback.progress import ProgressCallback
import timm
import sys
class Autotagger:
    """Image tagger backed by a fastai learner that is built once and reused.

    The learner is created in ``__init__`` (via ``init_model``) and stored on
    ``self.learn``; ``predict`` then uses that shared learner for every call.
    """

    def __init__(self, model_path="models/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"):
        """Store the weights path and build the learner.

        Args:
            model_path: Path of the saved model weights to load.
            data_path: CSV file read with ``read_csv`` to build the dataloaders.
            tags_path: JSON file holding the tag vocabulary.
        """
        self.model_path = model_path
        self.learn = self.init_model(data_path=data_path, tags_path=tags_path, model_path=model_path)

    def init_model(self, model_path="model/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"):
        """Build the dataloaders, create a resnet152 learner and load its weights.

        NOTE(review): the default ``model_path`` here ("model/model.pth")
        differs from the one in ``__init__`` ("models/model.pth"). ``__init__``
        always passes the path explicitly, so this only affects direct callers
        -- confirm which default is intended.

        Args:
            model_path: Path of the saved model weights to load.
            data_path: CSV file read with ``read_csv``; the ``filename`` and
                ``tags`` columns are used by the data block.
            tags_path: JSON file holding the tag vocabulary.

        Returns:
            The learner, with weights loaded and progress/logging disabled.
        """
        df = read_csv(data_path)

        # Close the vocabulary file as soon as it has been parsed.
        with open(tags_path) as tags_file:
            vocab = json.load(tags_file)

        dblock = DataBlock(
            blocks=(ImageBlock, MultiCategoryBlock(vocab=vocab)),
            get_x=lambda df: Path("test") / df["filename"],
            get_y=lambda df: df["tags"].split(" "),
            item_tfms=Resize(224, method=ResizeMethod.Squish),
            batch_tfms=[RandomErasing()],
        )
        dls = dblock.dataloaders(df)
        learn = vision_learner(dls, "resnet152", pretrained=False)

        # Load from an open handle and release it once the weights are read.
        with open(model_path, "rb") as model_file:
            learn.load(model_file, with_opt=False)

        # Disable the progress bar and logging once, up front, so that
        # predict() never has to wrap calls in no_bar()/no_logging(); those
        # context managers modify the shared learner and are unsafe when
        # several threads call predict() concurrently.
        learn.remove_cb(ProgressCallback)
        learn.logger = noop
        return learn

    def predict(self, files, threshold=0.01, limit=50, bs=64):
        """Yield one ``{tag: score}`` dict per row of predictions for ``files``.

        This is a generator: nothing runs until it is iterated, and a falsy
        ``files`` yields nothing.

        Args:
            files: Items handed to ``self.learn.dls.test_dl``.
            threshold: Minimum score a tag needs to be kept (inclusive).
            limit: Maximum number of tags kept per row, highest scores first.
            bs: Batch size passed to ``test_dl``.

        Yields:
            A dict mapping tag to score, ordered by descending score.
        """
        if not files:
            return

        dl = self.learn.dls.test_dl(files, bs=bs)
        batch, _ = self.learn.get_preds(dl=dl)
        vocab = self.learn.dls.vocab

        for scores in batch:
            df = DataFrame({"tag": vocab, "score": scores})
            # Keep tags at or above the threshold, best first, capped at `limit`.
            df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit)
            yield dict(zip(df.tag, df.score))