Image-to-Story / app.py
fffiloni's picture
update clip interrogator call
bc54443 verified
raw
history blame
4.6 kB
import os
import re

import gradio as gr
from gradio_client import Client, handle_file
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face access token read from the environment (None when HF_TOKEN is
# unset); it is handed to both `from_pretrained` calls below.
hf_token = os.environ.get('HF_TOKEN')

# Llama-2 chat checkpoint used for story generation. The model is converted to
# half precision and moved to the GPU once, at import time.
# NOTE(review): `use_auth_token` is deprecated in recent transformers releases
# in favour of `token` - confirm the pinned version before switching.
model_path = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, use_auth_token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_path, use_auth_token=hf_token).half().cuda()

# Remote client for the CLIP Interrogator Space. It is created exactly once:
# an earlier, immediately-overwritten construction of the same name was
# redundant and only cost an extra connection at start-up.
clipi_client = Client("https://fffiloni-clip-interrogator-2.hf.space/")
def llama_gen_story(prompt):
    """Ask the module-level Llama-2 chat model for a story matching `prompt`.

    `prompt` is inserted into a fixed ``[INST] <<SYS>> ... [/INST]`` template,
    tokenized with the module-level `tokenizer`, and passed to
    `model.generate` on the GPU. The first generated sequence is decoded with
    special tokens skipped, and every ``[INST] ... [/INST]`` span is removed
    from the decoded text before it is returned.
    """
    chat_template = """[INST] <<SYS>>\nYou are a storyteller. You'll be given an image description and some keyword about the image.
For that given you'll be asked to generate a story that you think could fit very well with the image provided.
Always answer with a cool story, while being safe as possible. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n{} [/INST]"""
    full_prompt = chat_template.format(prompt)

    # Tokenize, then move only the input ids onto the GPU for generation.
    encoded = tokenizer(full_prompt, return_tensors='pt')
    output_ids = model.generate(encoded.input_ids.cuda(), max_new_tokens=4096)
    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Non-greedy match; DOTALL lets the span cross the template's line breaks.
    instruction_span = re.compile(r'\[INST\].*?\[/INST\]', flags=re.DOTALL)
    return instruction_span.sub('', decoded)
def get_text_after_colon(input_text):
    """Return what follows the first ":" in `input_text`, whitespace-stripped.

    When `input_text` contains no colon it is returned unchanged (and is not
    stripped).
    """
    _, colon, remainder = input_text.partition(":")
    # An empty separator means no ":" was found: hand the input back untouched.
    return remainder.strip() if colon else input_text
def infer(image_input, audience):
    """Caption an image with CLIP Interrogator, then have Llama2 write a story.

    Args:
        image_input: Path of the image to describe (the UI's Image component
            is declared with ``type="filepath"``).
        audience: Target-audience label inserted into the Llama2 request.

    Returns:
        The story text: everything after the first ":" of the Llama2 output,
        with every line break doubled so lines are separated by an empty line.

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail early with a readable message instead of erroring inside handle_file.
    if not image_input:
        raise gr.Error('Please provide an image first.')

    gr.Info('Calling CLIP Interrogator ...')
    # Use the module-level `clipi_client`; the name `client` is not defined
    # anywhere in this module, so calling it raised NameError.
    clipi_result = clipi_client.predict(
        image=handle_file(image_input),
        mode="best",
        best_max_flavors=4,
        api_name="/clipi2"
    )
    print(clipi_result)

    # NOTE(review): indexing with [0] presumes the endpoint returns a sequence
    # whose first item is the caption - confirm against the Space's API.
    llama_q = f"""
I'll give you a simple image caption, please provide a fictional story for a {audience} audience that would fit well with the image. Please be creative, do not worry and only generate a cool fictional story.
Here's the image description:
'{clipi_result[0]}'
"""
    gr.Info('Calling Llama2 ...')
    result = llama_gen_story(llama_q)
    print(f"Llama2 result: {result}")

    # Drop any lead-in the model wrote before the first colon.
    result = get_text_after_colon(result)

    # Split on line breaks and re-join with an empty line between the pieces.
    paragraphs = result.split('\n')
    return '\n\n'.join(paragraphs)
# Custom stylesheet handed to gr.Blocks(css=...): caps the width of the
# element with id "col-container" and centres it, and enlarges the font and
# line height of the textarea inside the element with id "story".
css="""
#col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
div#story textarea {
font-size: 1.5em;
line-height: 1.4em;
}
"""
# Gradio UI: an image + audience selector on one side, the generated story on
# the other, wired to `infer`.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been flattened - confirm the placement of gr.Examples and of
# the click handler against the deployed layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Page title and tagline, rendered as raw HTML inside Markdown.
        gr.Markdown(
            """
<h1 style="text-align: center">Image to Story</h1>
<p style="text-align: center">Upload an image, get a story made by Llama2 !</p>
"""
        )
        with gr.Row():
            with gr.Column():
                # type="filepath" means `infer` receives a path string.
                image_in = gr.Image(label="Image input", type="filepath", elem_id="image-in")
                audience = gr.Radio(label="Target Audience", choices=["Children", "Adult"], value="Children")
                submit_btn = gr.Button('Tell me a story')
            with gr.Column():
                #caption = gr.Textbox(label="Generated Caption")
                # elem_id="story" is the hook used by the custom CSS.
                story = gr.Textbox(label="generated Story", elem_id="story")
        # Example inputs (image path + audience); cache_examples=False is
        # passed, so outputs are not cached ahead of time.
        gr.Examples(examples=[["./examples/crabby.png", "Children"],["./examples/hopper.jpeg", "Adult"]],
            fn=infer,
            inputs=[image_in, audience],
            outputs=[story],
            cache_examples=False
        )
    # Run `infer` on button click and show its return value in the story box.
    submit_btn.click(fn=infer, inputs=[image_in, audience], outputs=[story])
# Enable the request queue with max_size=12, then start the app with
# ssr_mode=False.
demo.queue(max_size=12).launch(ssr_mode=False)