Better fix for #2 (Crashes when encountering files that aren't images)

* Only ignore invalid files in the `autotag` script, not in the
  web service. In the web service we want to return an error if given an
  invalid or corrupt file.

* Use the logger to log a warning instead of printing directly to stderr.
This commit is contained in:
evazion
2022-06-29 18:06:11 -05:00
parent 6719ed3a40
commit b2160576fc
2 changed files with 24 additions and 11 deletions
+20 -4
View File
@@ -3,6 +3,8 @@
import json
import click
import itertools
import logging
import PIL
from autotagger import Autotagger
from pathlib import Path
from more_itertools import ichunked
@@ -29,12 +31,11 @@ def main(files, threshold, limit, bs, csv, input_file, group_tags, name_only, mo
click.get_current_context().exit()
for filepaths in ichunked(paths, bs):
filepaths = list(filepaths)
files = [click.open_file(filepath, "rb") for filepath in filepaths]
files = list(filter(None, [open_image(filepath) for filepath in filepaths]))
predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs)
for filepath, tags in zip(filepaths, predictions):
output_result(filepath, tags, csv, group_tags, name_only)
for file, tags in zip(files, predictions):
output_result(Path(file.name), tags, csv, group_tags, name_only)
def output_result(filepath, tags, csv, group_tags, name_only):
name = filepath.stem if name_only else str(filepath)
@@ -60,5 +61,20 @@ def get_filepaths(paths):
def recurse_dir(directory):
    """Return a lazy iterator over every non-directory path beneath `directory`.

    `directory` is a `pathlib.Path`-like object; the recursive glob pattern
    `**/*` visits all descendants, and directories themselves are filtered out.
    """
    descendants = directory.glob("**/*")
    return (entry for entry in descendants if not entry.is_dir())
def open_image(filepath):
    """Open `filepath` for binary reading if it contains a loadable image.

    The file is first fully decoded with PIL purely as a validity check, using
    a throwaway handle. A second, fresh handle is then opened and returned,
    because the handle used for validation is consumed and closed.

    Args:
        filepath: Path of the candidate image file (anything accepted by
            `click.open_file`).

    Returns:
        An open binary file object positioned at the start of the file, or
        `None` if the file is not an image or could not be opened/decoded. A
        warning is logged for every skipped file; no exception propagates.
    """
    # NOTE(review): a bare `import PIL` does not load the `PIL.Image` submodule;
    # this relies on another import having loaded it already - confirm.
    try:
        # Scope the validation handle with `with` so it is closed on every
        # path, including when PIL rejects the file; otherwise each skipped
        # file would leak an open descriptor during a large directory scan.
        with click.open_file(filepath, "rb") as file:
            with PIL.Image.open(file) as image:
                # Image.open() is lazy; load() forces a full decode so that
                # truncated or corrupt files are detected here.
                image.load()
        return click.open_file(filepath, "rb")
    except PIL.UnidentifiedImageError:
        logging.warning(f"Skipped {filepath} (not an image)")
        return None
    except Exception as err:
        # Deliberately broad: any failure on one file must only skip that file
        # rather than abort the whole batch.
        logging.warning(f"Skipped {filepath} ({type(err).__name__}: {str(err)})")
        return None
# Invoke the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
main()
+4 -7
View File
@@ -28,14 +28,11 @@ class Autotagger:
return learn
def predict(self, files, threshold=0.01, limit=50, bs=64):
if not files:
return
with self.learn.no_bar(), self.learn.no_logging():
def create_image(file):
try:
return PILImage.create(file)
except:
print("skipped file " + file.name, file=sys.stderr)
return None
images = list(filter(lambda i: i != None, [create_image(file) for file in files]))
images = [PILImage.create(file) for file in files]
dl = self.learn.dls.test_dl(images, bs=bs)
batch, _ = self.learn.get_preds(dl=dl)