"""Library for parsing image inputs before feeding them into the model."""
import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from typing import List
[docs]def parse_image(file_id: str, img_lookup, img_w: int, img_h: int, root: str = './'):
"""Read image file and resize it to `(img_w, img_h, 3)`.
:param file_id: ID of the image to read.
:type file_id: str
:param img_lookup: Lookup table assigning file ID to file name.
:type img_lookup: tf.lookup.StaticHashTable
:param img_w: Image width after resizing.
:type img_w: int
:param img_h: Image height after resizing.
:type img_h: int
:param root: Root path where the image is stored. Defaults to current folder `./`.
:type root: str
:return: Decoded image.
:rtype: Tensor
"""
# Retrieve file name from file ID, read image and resize
filename = img_lookup.lookup(file_id)
img = tf.io.read_file(root + filename)
img = tf.io.decode_jpeg(img, channels = 3)
img = tf.image.resize(img, [img_w, img_h])
return img
[docs]def parse_image_and_mask(file_id: str, img_lookup, img_w: int, img_h: int, root_img: str = './', root_mask: str = './masks/'):
"""Read image file, resize it to `(img_w, img_h, 3)` and read object mask and resize it to `(img_w, img_h, 1)`.
:param file_id: ID of the image to read.
:type file_id: str
:param img_lookup: Lookup table assigning file ID to file name.
:type img_lookup: tf.lookup.StaticHashTable
:param img_w: Image width after resizing.
:type img_w: int
:param img_h: Image height after resizing.
:type img_h: int
:param root_img: Root path where the image is stored. Defaults to current folder `./`.
:type root_img: str
:param root_mask: Root path where the mask is stored. Defaults to current folder `./masks/`.
:type root_mask: str
:return: Decoded image and binary mask
:rtype: Tuple[Tensor, Tensor]
"""
# Retrieve file name from file ID, read image and resize
filename = img_lookup.lookup(file_id)
img = tf.io.read_file(root_img + filename)
img = tf.io.decode_jpeg(img, channels=3)
img = tf.image.resize(img, [img_w, img_h])
# Read binary mask and resize it accordingly
mask = tf.io.read_file(root_mask + filename)
mask = tf.io.decode_jpeg(mask, channels=1)
mask = tf.image.resize(mask, [img_w, img_h], method = 'nearest') # use nearest interpolation to preserve edge sharpness
mask = tf.where(mask > 200, 1.0, 0.0) # remove resizing artifacts and normalize to 0.0-1.0
return img, mask
[docs]def parse_to_dataset(file_ids: List[str], labels: List[int], dataset_map):
"""Retrieves file names from **file_ids**, decodes and stores them along with **labels** in a `tf.Dataset`.
:param file_ids: File IDs to parse images from (will be keys of a dictionary mapping to the file names).
:type file_ids: list of strings
:param labels: Class labels corresponding to the images.
:type labels: list of integers
:param dataset_map: Function that maps a tuple of filenames and labels to a tuple of parsed images and labels.
:return: Dataset of decoded images and labels.
:rtype: tf.Dataset
"""
file_ids = tf.constant(file_ids)
dataset = tf.data.Dataset.from_tensor_slices((file_ids, labels))
dataset = dataset.map(dataset_map)
return dataset