File size: 2,953 Bytes
ced2e1f
eb014c4
 
 
e72b795
eb014c4
62a137c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb014c4
62a137c
 
 
 
 
 
eb014c4
62a137c
 
 
e72b795
62a137c
eb014c4
62a137c
 
 
 
e951325
 
 
 
eb014c4
62a137c
 
e72b795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62a137c
eb014c4
62a137c
eb014c4
62a137c
e72b795
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import tempfile

from transformers import pipeline
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from PIL import Image

class StoryGenerator:
    """Turn an image into a short kids' story.

    The image is captioned with a Hugging Face ``image-to-text`` pipeline and
    the caption is then fed, through a prompt template, to one of the hosted
    text-generation models listed in ``text_models``.
    """

    # End-of-sequence marker that some hosted models echo at the end of output.
    _EOS_TOKEN = "</s>"
    # Image modes that can be written as JPEG without conversion.
    _JPEG_MODES = ("L", "RGB", "CMYK")

    def __init__(self, image_model="Salesforce/blip-image-captioning-base"):
        """Load the captioning pipeline and set up the model table and prompt.

        Args:
            image_model: Hugging Face model id used for image captioning.
        """
        self.image_model = image_model
        self.image_to_text = pipeline("image-to-text", model=self.image_model)
        # Display name -> Hugging Face repo id of the text-generation model.
        self.text_models = {
            "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
            "FLAN-T5": "google/flan-t5-large",
            "MPT-7B": "mosaicml/mpt-7b-instruct",
            "Falcon-7B": "tiiuae/falcon-7b-instruct"
        }
        self.prompt_template = PromptTemplate.from_template("""
        You are a kids story writer. Provide a coherent story for kids
        using this simple instruction: {scenario}. The story should have a clear
        beginning, middle, and end. The story should be interesting and engaging for
        kids. The story should be maximum 200 words long. Do not include
        any adult or polemic content.
        Story:
        """)

    @classmethod
    def _clean_story(cls, raw_text):
        """Strip surrounding whitespace and any trailing end-of-sequence tokens.

        ``str.rstrip("</s>")`` must not be used here: it removes any trailing
        run of the characters ``<``, ``/``, ``s`` and ``>``, which truncates
        stories that legitimately end in e.g. "friends".
        """
        text = raw_text.strip()
        while text.endswith(cls._EOS_TOKEN):
            text = text[:-len(cls._EOS_TOKEN)].rstrip()
        return text

    def get_llm(self, model_name):
        """Build the hosted LLM endpoint for a display name in ``text_models``.

        Args:
            model_name: One of the keys of ``self.text_models``.

        Returns:
            A ``HuggingFaceEndpoint`` configured for the selected repo.

        Raises:
            KeyError: If ``model_name`` is not a known model.
        """
        try:
            repo_id = self.text_models[model_name]
        except KeyError:
            available = ", ".join(sorted(self.text_models))
            raise KeyError(
                f"Unknown model {model_name!r}; available models: {available}"
            ) from None
        return HuggingFaceEndpoint(
            repo_id=repo_id,
            temperature=0.5,
            streaming=True
        )

    def img2txt(self, image_path):
        """Convert image to text using Hugging Face pipeline.

        Args:
            image_path: Path of the image file to caption.

        Returns:
            The generated caption of the first pipeline result.
        """
        text = self.image_to_text(image_path)[0]["generated_text"]
        print(f"Image caption: {text}")
        return text

    def generate_story(self, scenario, model_name):
        """Generate a story for ``scenario`` with the selected language model.

        Args:
            scenario: Short description (e.g. an image caption) to write about.
            model_name: One of the keys of ``self.text_models``.

        Returns:
            The generated story with whitespace and EOS markers trimmed.
        """
        llm = self.get_llm(model_name)
        chain = self.prompt_template | llm
        raw_story = chain.invoke(input={"scenario": scenario})
        return self._clean_story(raw_story)

    def generate_story_from_image(self, image, model_name):
        """Generate a story from an image.

        Args:
            image: Either a filesystem path (``str`` or ``os.PathLike``) or an
                in-memory image object exposing ``mode``, ``convert`` and
                ``save`` (a PIL image).
            model_name: One of the keys of ``self.text_models``.

        Returns:
            The generated story text.
        """
        print(f"Received image: {image}")
        print(f"Image type: {type(image)}")

        if isinstance(image, (str, os.PathLike)):  # If it's a file path
            scenario = self.img2txt(os.fspath(image))
            return self.generate_story(scenario, model_name)

        # In-memory image: write it to a unique temporary file so concurrent
        # calls cannot clobber each other and no user file is ever overwritten.
        fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)  # the image is saved by file name, not through this handle
        try:
            # JPEG has no alpha/palette support; convert such images first.
            if image.mode not in self._JPEG_MODES:
                image = image.convert("RGB")
            image.save(temp_image_path)
            scenario = self.img2txt(temp_image_path)
            return self.generate_story(scenario, model_name)
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)

# Demo entry point: caption the bundled sample image and print its story.
if __name__ == "__main__":
    generator = StoryGenerator()
    example_image_path = os.path.join("assets", "image.jpg")
    if not os.path.exists(example_image_path):
        # Nothing to caption; tell the user where the sample was expected.
        print(f"Example image not found at {example_image_path}")
    else:
        story = generator.generate_story_from_image(
            example_image_path,
            "Mistral-7B",
        )
        print(story)