Print correct filenames when files are skipped

This commit is contained in:
dokutan
2022-06-28 19:01:43 +02:00
parent 9fd7697c80
commit b0b93eb471
2 changed files with 6 additions and 4 deletions
+5 -3
View File
@@ -35,12 +35,14 @@ class Autotagger:
except:
print("skipped file " + file.name, file=sys.stderr)
return None
images = list(filter(lambda i: i != None, [create_image(file) for file in files]))
images = [create_image(file) for file in files]
files = [files[i] for i in range(len(files)) if images[i] != None]
images = [image for image in images if image != None]
dl = self.learn.dls.test_dl(images, bs=bs)
batch, _ = self.learn.get_preds(dl=dl)
for scores in batch:
for scores, f in zip(batch, files):
df = DataFrame({ "tag": self.learn.dls.vocab, "score": scores })
df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit)
tags = dict(zip(df.tag, df.score))
yield tags
yield f.name, tags