LinkLinkWu commited on
Commit
8ac7c32
·
verified ·
1 Parent(s): 01ae9b7

Update func.py

Browse files
Files changed (1) hide show
  1. func.py +36 -38
func.py CHANGED
@@ -32,59 +32,57 @@ def img2text(img: Union[Image.Image, str, Path]) -> str:
32
  img = Image.open(img)
33
  return _get_captioner()(img)[0]["generated_text"]
34
 
35
- # Step2. Text Generation (Based on Caption)
 
36
  import torch
37
- from transformers import AutoTokenizer, AutoModelForCausalLM
38
 
39
- _MODEL_NAME = "aspis/gpt2-genre-story-generation"
40
- _PROMPT = (
41
  "Write a funny and warm children's story (50-100 words) for ages 3-10, "
42
  "fully and strictly based on this scene: {caption}\nStory:"
43
  )
44
 
45
- _tokenizer, _model = None, None
46
- def _load_story_model():
47
- """Lazy-load tokenizer / model once."""
48
- global _tokenizer, _model
49
- if _model is None:
50
- _tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
51
- _model = AutoModelForCausalLM.from_pretrained(_MODEL_NAME)
52
- if torch.cuda.is_available():
53
- _model = _model.to("cuda")
54
- return _tokenizer, _model
 
 
 
 
 
 
55
 
56
 
57
  def text2story(caption: str) -> str:
58
  """
59
- Generate a 50-100-word children’s story from an image caption.
60
 
61
  Args:
62
- caption: Scene description string.
63
 
64
  Returns:
65
- Story text (≤100 words).
66
  """
67
- tok, mdl = _load_story_model()
68
-
69
- prompt = _PROMPT.format(caption=caption)
70
- inputs = tok(prompt, return_tensors="pt", add_special_tokens=False)
71
- if mdl.device.type == "cuda":
72
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
73
-
74
- gen_ids = mdl.generate(
75
- **inputs,
76
- max_new_tokens=150,
77
- do_sample=True,
78
- top_p=0.9,
79
- temperature=0.8,
80
- pad_token_id=tok.eos_token_id,
81
- repetition_penalty=1.1
82
- )[0]
83
-
84
- # drop prompt, decode, keep ≤100 words, end at last period
85
- story_ids = gen_ids[inputs["input_ids"].shape[-1]:]
86
- story = tok.decode(story_ids, skip_special_tokens=True).strip()
87
- story = story[: story.rfind(".") + 1] if "." in story else story
88
  return " ".join(story.split()[:100])
89
 
90
  # Step3. Text to Audio
 
32
  img = Image.open(img)
33
  return _get_captioner()(img)[0]["generated_text"]
34
 
35
+ # Step 2. Caption ➜ Children’s story (DeepSeek-R1 1.5 B)
36
+ # -------------------------------------------------------------------
37
  import torch
38
+ from transformers import pipeline
39
 
40
+ _GEN_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
41
+ _PROMPT_TMPL = (
42
  "Write a funny and warm children's story (50-100 words) for ages 3-10, "
43
  "fully and strictly based on this scene: {caption}\nStory:"
44
  )
45
 
46
+ _generator = None
47
+ def _get_generator():
48
+ """Lazy-load DeepSeek generator once (GPU if available)."""
49
+ global _generator
50
+ if _generator is None:
51
+ _generator = pipeline(
52
+ "text-generation",
53
+ model=_GEN_MODEL,
54
+ device=0 if torch.cuda.is_available() else -1,
55
+ # common decoding params – can still be overridden in the call
56
+ max_new_tokens=150,
57
+ do_sample=True,
58
+ top_p=0.9,
59
+ temperature=0.8,
60
+ )
61
+ return _generator
62
 
63
 
64
  def text2story(caption: str) -> str:
65
  """
66
+ Generate a 100-word children’s story from the image caption.
67
 
68
  Args:
69
+ caption: scene description string.
70
 
71
  Returns:
72
+ Story text (plain string, trimmed to ≤100 words).
73
  """
74
+ prompt = _PROMPT_TMPL.format(caption=caption)
75
+ gen = _get_generator()(
76
+ prompt,
77
+ return_full_text=False # only the completion, not the prompt
78
+ )[0]["generated_text"]
79
+
80
+ # ensure last sentence is closed
81
+ story = gen.strip()
82
+ if "." in story:
83
+ story = story[: story.rfind(".") + 1]
84
+
85
+ # hard cap at 100 words
 
 
 
 
 
 
 
 
 
86
  return " ".join(story.split()[:100])
87
 
88
  # Step3. Text to Audio