# vibrantGAN / app.py
# Author: datafreak
# Commit 1c850d7 (verified): "improving the image quality and sharpness"
# (Hugging Face file-view residue: raw · history · blame · 2.29 kB)
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageEnhance
from tensorflow_examples.models.pix2pix import pix2pix
# Number of channels in the generator's output image (3 = RGB, matching the
# RGB input produced by preprocess_single_image and the PIL output below).
OUTPUT_CHANNELS = 3
# Build the pix2pix U-Net generator with instance normalization and load the
# trained vibrantGAN weights once, at import time.
# NOTE(review): the weights path is relative to the current working directory,
# so the app presumably must be started from the folder containing the .h5
# file — confirm for the deployment target.
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_g.load_weights("vibrantGAN-generator-g-final.weights.h5")
def preprocess_single_image(image, target_height=256, target_width=256):
    """Convert an input image into a normalized tensor for the generator.

    Args:
        image: A PIL image, or an array-like of shape (H, W) or (H, W, C)
            with C in {1, 2, 3, 4}.
        target_height: Height in pixels the image is resized to.
        target_width: Width in pixels the image is resized to.

    Returns:
        A float32 tensor of shape (target_height, target_width, 3). Assuming
        8-bit pixel values, its values lie in [-1, 1].

    Raises:
        ValueError: If the input cannot be interpreted as an RGB image.
    """
    # PIL modes such as "P", "LA", "1", "I;16" or "CMYK" do not map onto RGB
    # by channel slicing (e.g. CMYK's first three channels are not RGB), so
    # let PIL do the colour conversion whenever we have a PIL image.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    array = np.asarray(image)

    # Normalize plain array inputs to exactly three channels.
    if array.ndim == 2:  # grayscale (H, W)
        array = np.stack([array, array, array], axis=-1)
    elif array.ndim == 3 and array.shape[-1] in (1, 2):  # gray / gray+alpha
        array = np.repeat(array[:, :, :1], 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 4:  # RGBA: drop alpha
        array = array[:, :, :3]
    if array.ndim != 3 or array.shape[-1] != 3:
        raise ValueError(
            f"Expected an image convertible to RGB, got array of shape {array.shape}"
        )

    image = tf.convert_to_tensor(array)
    image = tf.image.resize(image, [target_height, target_width])
    image = tf.cast(image, tf.float32)
    # Map [0, 255] to [-1, 1], the range the generator works in.
    return (image / 127.5) - 1
def process_image(input_image, output_resolution=(512, 512), sharpness=4.0):
    """Run the vibrantGAN generator on one image and return the result.

    Args:
        input_image: PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when nothing has been uploaded.
        output_resolution: ``(height, width)`` of the returned image, in the
            order ``tf.image.resize`` expects. Pass ``None`` to resize the
            output back to the input image's own dimensions.
        sharpness: ``ImageEnhance.Sharpness`` factor applied to the final
            image; ``1.0`` leaves it unchanged.

    Returns:
        The generated image as a ``PIL.Image.Image``, or ``None`` when
        ``input_image`` is ``None``.
    """
    if input_image is None:
        return None

    if output_resolution is None:
        # PIL reports (width, height); tf.image.resize wants (height, width).
        width, height = input_image.size
        output_resolution = (height, width)

    # The generator is called on a batch, so add a leading batch axis.
    batch = tf.expand_dims(preprocess_single_image(input_image), 0)
    # Explicit inference mode (e.g. any dropout layers disabled).
    prediction = generator_g(batch, training=False)

    # Denormalize the generator output from [-1, 1] to [0, 1].
    output_image = tf.clip_by_value(prediction[0] * 0.5 + 0.5, 0.0, 1.0)
    # Bilinear interpolation of values in [0, 1] stays within [0, 1].
    output_image = tf.image.resize(output_image, output_resolution, method='bilinear')

    # Round rather than truncate so the uint8 image is not biased darker.
    output_array = np.clip(np.rint(output_image.numpy() * 255.0), 0, 255).astype(np.uint8)
    output_pil = Image.fromarray(output_array)
    return ImageEnhance.Sharpness(output_pil).enhance(sharpness)
# Gradio UI: one uploaded image in, the generated image out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Input Image"),
    outputs=gr.Image(type="pil", label="Generated Output"),
    title="Image Processing Model",
    description="Upload an image to see the model's output.",
)

# Launch only when executed as a script (e.g. `python app.py`), so that merely
# importing this module does not start a blocking web server.
if __name__ == "__main__":
    demo.launch(debug=True)