diff --git a/algorithms/context_adaptive_interpolator/context_adaptive_interpolator.py b/algorithms/context_adaptive_interpolator/context_adaptive_interpolator.py index e6fec66..e9c7b6c 100644 --- a/algorithms/context_adaptive_interpolator/context_adaptive_interpolator.py +++ b/algorithms/context_adaptive_interpolator/context_adaptive_interpolator.py @@ -6,6 +6,7 @@ from wiener_filter import wienerFilter # Assume greyscale PIL image passed. # What about other color channels? See #11. +# `THRESHOLD` seems to have been designed to assume 256 range based images. def contextAdaptiveInterpolator(I, IImage, showProgress = False): rImage = Image.new('L', (IImage.size[0] - 2, IImage.size[1] - 2)) r = rImage.load() diff --git a/datasets/noise_free_test_images/estimate_prnu.py b/datasets/noise_free_test_images/estimate_prnu.py index f411977..41e3fb2 100644 --- a/datasets/noise_free_test_images/estimate_prnu.py +++ b/datasets/noise_free_test_images/estimate_prnu.py @@ -3,6 +3,7 @@ import os from PIL import Image import numpy as np +import matplotlib.pyplot as plt import sys sys.path.insert(0, '../../algorithms/image_utils/') @@ -13,34 +14,56 @@ sys.path.insert(0, '../../algorithms/context_adaptive_interpolator/') from context_adaptive_interpolator import contextAdaptiveInterpolator +from skimage.restoration import denoise_tv_chambolle + datasetPath = 'no_noise_images' # Note that contrarily to `datasets/fake/`, here we do not have images being Gaussian with `scale` `1` but actual images with pixel values between 0 and 255. # In addition to the range difference, note that the distribution in the first set of images was a Gaussian and here is very different and specific. -PRNU_FACTOR = 0.15 +PRNU_FACTORS = [0.1, 0.01] IMAGE_SIZE_SHAPE = (469, 704) np.random.seed(0) #prnuNpArray = 255 * randomGaussianImage(scale = PRNU_FACTOR, size = IMAGE_SIZE_SHAPE) prnuPil = Image.open('prnu.png').convert('F') -prnuNpArray = np.array(prnuPil) * PRNU_FACTOR +prnusNpArray = [np.array(prnuPil) * PRNU_FACTOR for PRNU_FACTOR in PRNU_FACTORS] def isIn256Range(x): return 0 <= x and x <= 255 +imagesPrnuEstimateNpArray = [] + for imageName in os.listdir(datasetPath): if imageName.endswith('.png'): imagePath = f'{datasetPath}/{imageName}' imageWithoutPrnuPil = Image.open(imagePath).convert('F') imageWithoutPrnuNpArray = np.array(imageWithoutPrnuPil) - #showImageWithMatplotlib(imageWithoutPrnuNpArray) - imageWithPrnuNpArray = imageWithoutPrnuNpArray + prnuNpArray - showImageWithMatplotlib(imageWithPrnuNpArray) - break - assert all([isIn256Range(extreme) for extreme in [imageWithPrnuNpArray.max(), imageWithPrnuNpArray.min()]]), 'Adding the PRNU resulted in out of 256 bounds image' - imageWithPrnuPil = toPilImage(imageWithPrnuNpArray) - imageWithPrnuCaiPil = contextAdaptiveInterpolator(imageWithPrnuPil.load(), imageWithPrnuPil) - imageWithPrnuCaiNpArray = np.array(imageWithPrnuCaiPil) - showImageWithMatplotlib(imageWithPrnuCaiNpArray) + + fig, axes = plt.subplots(3, 2) + fig.suptitle('Single PRNU estimation from an image with PRNU') + + axes[0][0].set_title('Actual PRNU') + axes[0][0].imshow(prnuNpArray) + + axes[0][1].axis('off') + + for prnuIndex, prnuNpArray in enumerate(prnusNpArray): + imageWithPrnuNpArray = imageWithoutPrnuNpArray + prnuNpArray + #assert all([isIn256Range(extreme) for extreme in [imageWithPrnuNpArray.max(), imageWithPrnuNpArray.min()]]), 'Adding the PRNU resulted in out of 256 bounds image' + imageWithPrnuPil = toPilImage(imageWithPrnuNpArray) + #imagePrnuEstimatePil = contextAdaptiveInterpolator(imageWithPrnuPil.load(), imageWithPrnuPil) + #imagePrnuEstimateNpArray = np.array(imagePrnuEstimatePil) + imagePrnuEstimateNpArray = imageWithPrnuNpArray - denoise_tv_chambolle(imageWithPrnuNpArray, weight=0.2, channel_axis=-1) + axis = axes[prnuIndex + 1] + + axis[0].set_title(f'Image with PRNU\nPRNU_FACTOR = {PRNU_FACTORS[prnuIndex]}') + axis[0].imshow(imageWithPrnuNpArray) + + axis[1].set_title('PRNU estimate') + axis[1].imshow(imagePrnuEstimateNpArray) break + imagesPrnuEstimateNpArray += [imagePrnuEstimateNpArray] + +plt.tight_layout(pad = 0) +plt.show()