Initial commit.
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from .autotagger import Autotagger
|
||||
@@ -0,0 +1,35 @@
|
||||
from fastbook import *
|
||||
from pandas import read_csv
|
||||
import timm
|
||||
|
||||
class Autotagger:
|
||||
def __init__(self, model_path="models/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"):
|
||||
self.model_path = model_path
|
||||
self.learn = self.init_model(data_path=data_path, tags_path=tags_path, model_path=model_path)
|
||||
|
||||
def init_model(self, model_path="model/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"):
|
||||
df = read_csv(data_path)
|
||||
vocab = json.load(open(tags_path))
|
||||
|
||||
dblock = DataBlock(
|
||||
blocks=(ImageBlock, MultiCategoryBlock(vocab=vocab)),
|
||||
get_x = lambda df: Path("test") / df["filename"],
|
||||
get_y = lambda df: df["tags"].split(" "),
|
||||
item_tfms = Resize(224, method = ResizeMethod.Squish),
|
||||
batch_tfms = [RandomErasing()]
|
||||
)
|
||||
|
||||
dls = dblock.dataloaders(df)
|
||||
learn = vision_learner(dls, "resnet152", pretrained=False)
|
||||
model_file = open(model_path, "rb")
|
||||
learn.load(model_file, with_opt=False)
|
||||
|
||||
return learn
|
||||
|
||||
def predict(self, path, threshold=0.01, limit=50):
|
||||
with self.learn.no_bar(), self.learn.no_logging():
|
||||
pred = self.learn.predict(path)
|
||||
scores = [score.item() for score in pred[2]]
|
||||
results = { tag : score for (tag, score) in zip(self.learn.dls.vocab, scores) if score >= threshold }
|
||||
results = sorted(results.items(), key = lambda x: -x[1])
|
||||
return dict(results[:limit])
|
||||
Reference in New Issue
Block a user