diff --git a/datasets/raise/utils.py b/datasets/raise/utils.py index 98668c3..525c4a8 100644 --- a/datasets/raise/utils.py +++ b/datasets/raise/utils.py @@ -16,7 +16,7 @@ import imageio class Color(Enum): RED = auto() GREEN_RIGHT = auto() - GREEN_BOTTOM = auto() + GREEN_LEFT = auto() BLUE = auto() def __str__(self): @@ -117,29 +117,44 @@ RAW_IMAGE_FILE_EXTENSIONS = [ 'nef', ] +SUPPORTED_RAW_PATTERNS = [ + # RG + # GB + [[0, 1], + [3, 2]], + # GB + # RG + [[3, 2], + [0, 1]], +] + def getRawColorChannel(raw, color): colorDesc = raw.color_desc.decode('ascii') assert colorDesc == 'RGBG' - assert np.array_equal(raw.raw_pattern, np.array([[0, 1], [3, 2]], dtype = np.uint8)) - # RG - # GB + assert any(np.array_equal(raw.raw_pattern, np.array(supportedRawPattern, dtype = np.uint8)) for supportedRawPattern in SUPPORTED_RAW_PATTERNS) rawImageVisible = raw.raw_image_visible.copy() - redRawImageVisible = rawImageVisible[::2, ::2] - greenRightRawImageVisible = rawImageVisible[::2, 1::2] + topLeftRawImageVisible = rawImageVisible[::2, ::2] + topRightRawImageVisible = rawImageVisible[::2, 1::2] - greenBottomRawImageVisible = rawImageVisible[1::2, ::2] - blueRawImageVisible = rawImageVisible[1::2, 1::2] + bottomLeftRawImageVisible = rawImageVisible[1::2, ::2] + bottomRightRawImageVisible = rawImageVisible[1::2, 1::2] - match color: - case Color.RED: - imageNpArray = redRawImageVisible - case Color.GREEN_RIGHT: - imageNpArray = greenRightRawImageVisible - case Color.GREEN_BOTTOM: - imageNpArray = greenBottomRawImageVisible - case Color.BLUE: - imageNpArray = blueRawImageVisible + if np.array_equal(raw.raw_pattern, np.array(SUPPORTED_RAW_PATTERNS[0], dtype = np.uint8)): + imageNpArrays = { + Color.RED: topLeftRawImageVisible, + Color.GREEN_RIGHT: topRightRawImageVisible, + Color.GREEN_LEFT: bottomLeftRawImageVisible, + Color.BLUE: bottomRightRawImageVisible, + } + else: # elif np.array_equal(raw.raw_pattern, np.array(SUPPORTED_RAW_PATTERNS[1], dtype = np.uint8)): + imageNpArrays = { + Color.GREEN_LEFT: topLeftRawImageVisible, + Color.BLUE: topRightRawImageVisible, + Color.RED: bottomLeftRawImageVisible, + Color.GREEN_RIGHT: bottomRightRawImageVisible, + } + imageNpArray = imageNpArrays[color] return imageNpArray def isARawImage(imageFilePath): @@ -178,7 +193,7 @@ def mergeSingleColorChannelImagesAccordingToBayerFilter(singleColorChannelImages ''' multipleColorsImage[::2,::2] = singleColorChannelImages[Color.RED] multipleColorsImage[1::2,::2] = singleColorChannelImages[Color.GREEN_RIGHT] - multipleColorsImage[::2,1::2] = singleColorChannelImages[Color.GREEN_BOTTOM] + multipleColorsImage[::2,1::2] = singleColorChannelImages[Color.GREEN_LEFT] multipleColorsImage[1::2,1::2] = singleColorChannelImages[Color.BLUE] return multipleColorsImage