Merge pull request #4 from dokutan/master
Print correct filenames when files are skipped
This commit is contained in:
@@ -33,7 +33,7 @@ def main(files, threshold, limit, bs, csv, input_file, group_tags, name_only, mo
|
||||
files = [click.open_file(filepath, "rb") for filepath in filepaths]
|
||||
predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs)
|
||||
|
||||
for filepath, tags in zip(filepaths, predictions):
|
||||
for filepath, tags in predictions:
|
||||
output_result(filepath, tags, csv, group_tags, name_only)
|
||||
|
||||
def output_result(filepath, tags, csv, group_tags, name_only):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastbook import *
|
||||
from pandas import DataFrame, read_csv
|
||||
import timm
|
||||
import sys
|
||||
|
||||
class Autotagger:
|
||||
def __init__(self, model_path="models/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"):
|
||||
@@ -32,13 +33,16 @@ class Autotagger:
|
||||
try:
|
||||
return PILImage.create(file)
|
||||
except:
|
||||
print("skipped file " + file.name, file=sys.stderr)
|
||||
return None
|
||||
images = list(filter(lambda i: i != None, [create_image(file) for file in files]))
|
||||
images = [create_image(file) for file in files]
|
||||
files = [files[i] for i in range(len(files)) if images[i] != None]
|
||||
images = [image for image in images if image != None]
|
||||
dl = self.learn.dls.test_dl(images, bs=bs)
|
||||
batch, _ = self.learn.get_preds(dl=dl)
|
||||
|
||||
for scores in batch:
|
||||
for scores, f in zip(batch, files):
|
||||
df = DataFrame({ "tag": self.learn.dls.vocab, "score": scores })
|
||||
df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit)
|
||||
tags = dict(zip(df.tag, df.score))
|
||||
yield tags
|
||||
yield f.name, tags
|
||||
|
||||
Reference in New Issue
Block a user