diff --git a/autotag b/autotag index a0165f3..96b15f5 100755 --- a/autotag +++ b/autotag @@ -33,7 +33,7 @@ def main(files, threshold, limit, bs, csv, input_file, group_tags, name_only, mo files = [click.open_file(filepath, "rb") for filepath in filepaths] predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs) - for filepath, tags in zip(filepaths, predictions): + for filepath, tags in predictions: output_result(filepath, tags, csv, group_tags, name_only) def output_result(filepath, tags, csv, group_tags, name_only): diff --git a/autotagger/autotagger.py b/autotagger/autotagger.py index c2384ba..af974d8 100644 --- a/autotagger/autotagger.py +++ b/autotagger/autotagger.py @@ -1,6 +1,7 @@ from fastbook import * from pandas import DataFrame, read_csv import timm +import sys class Autotagger: def __init__(self, model_path="models/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"): @@ -32,13 +33,16 @@ class Autotagger: try: return PILImage.create(file) except: + print("skipped file " + file.name, file=sys.stderr) return None - images = list(filter(lambda i: i != None, [create_image(file) for file in files])) + images = [create_image(file) for file in files] + files = [files[i] for i in range(len(files)) if images[i] != None] + images = [image for image in images if image != None] dl = self.learn.dls.test_dl(images, bs=bs) batch, _ = self.learn.get_preds(dl=dl) - for scores in batch: + for scores, f in zip(batch, files): df = DataFrame({ "tag": self.learn.dls.vocab, "score": scores }) df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit) tags = dict(zip(df.tag, df.score)) - yield tags + yield f.name, tags