From aaaae88e776a1c24337bfb982a1bb4423e55718c Mon Sep 17 00:00:00 2001 From: evazion Date: Wed, 29 Jun 2022 20:02:18 -0500 Subject: [PATCH] autotag: fix bug when reading file from stdin. Another attempt at fixing #2. b216057 had a bug where you couldn't pass a file through stdin because we tried to read the file twice, which we can't do if the input is a pipe. --- app.py | 4 +++- autotag | 18 +++++++++--------- autotagger/autotagger.py | 3 +-- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/app.py b/app.py index 0d1f4b2..820264b 100755 --- a/app.py +++ b/app.py @@ -4,6 +4,7 @@ from os import getenv from dotenv import load_dotenv from autotagger import Autotagger from base64 import b64encode +from fastai.vision.core import PILImage from flask import Flask, request, render_template, jsonify, abort from werkzeug.exceptions import HTTPException @@ -26,7 +27,8 @@ def evaluate(): output = request.values.get("format", "html") limit = int(request.values.get("limit", 50)) - predictions = autotagger.predict(files, threshold=threshold, limit=limit) + images = [PILImage.create(file) for file in files] + predictions = autotagger.predict(images, threshold=threshold, limit=limit) if output == "html": for file in files: diff --git a/autotag b/autotag index 1d9c93b..ab16621 100755 --- a/autotag +++ b/autotag @@ -5,6 +5,7 @@ import click import itertools import logging import PIL +from fastai.vision.core import PILImage from autotagger import Autotagger from pathlib import Path from more_itertools import ichunked @@ -31,10 +32,13 @@ def main(files, threshold, limit, bs, csv, input_file, group_tags, name_only, mo click.get_current_context().exit() for filepaths in ichunked(paths, bs): - files = list(filter(None, [open_image(filepath) for filepath in filepaths])) - predictions = autotagger.predict(files, threshold=threshold, limit=limit, bs=bs) + paths_with_images = list(filter(None, [open_image(filepath) for filepath in filepaths])) + filepaths = [x[0] for x in paths_with_images] + images = [x[1] for x in paths_with_images] - for file, tags in zip(files, predictions): + predictions = autotagger.predict(images, threshold=threshold, limit=limit, bs=bs) + + for file, tags in zip(filepaths, predictions): output_result(Path(file.name), tags, csv, group_tags, name_only) def output_result(filepath, tags, csv, group_tags, name_only): @@ -63,12 +67,8 @@ def recurse_dir(directory): def open_image(filepath): try: - # Load the image to check that it's a valid file. Open the file twice because PIL closes the underlying file. - file = click.open_file(filepath, "rb") - image = PIL.Image.open(file) - image.load() - image.close() - return click.open_file(filepath, "rb") + with click.open_file(filepath, "rb") as file: + return (filepath, PILImage.create(file)) except PIL.UnidentifiedImageError as err: logging.warning(f"Skipped {filepath} (not an image)") return None diff --git a/autotagger/autotagger.py b/autotagger/autotagger.py index b2388e8..5299e9f 100644 --- a/autotagger/autotagger.py +++ b/autotagger/autotagger.py @@ -32,8 +32,7 @@ class Autotagger: return with self.learn.no_bar(), self.learn.no_logging(): - images = [PILImage.create(file) for file in files] - dl = self.learn.dls.test_dl(images, bs=bs) + dl = self.learn.dls.test_dl(files, bs=bs) batch, _ = self.learn.get_preds(dl=dl) for scores in batch: