Dylan committed
Commit 350f8a0 · 1 parent: 0c35e90

calling gemma

Files changed (2)
  1. agents.py +20 -9
  2. app.py +1 -0
agents.py CHANGED

@@ -31,6 +31,9 @@ def build_graph():
     workflow = StateGraph(State)
 
     # Add nodes
+    # workflow.add_node("caption_image", caption_image_dummy)
+    # workflow.add_node("describe_with_voice", describe_with_voice_dummy)
+
     workflow.add_node("caption_image", caption_image)
     workflow.add_node("describe_with_voice", describe_with_voice)
 
@@ -55,16 +58,21 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
 )
 
 
-def describe_with_voice(state: State) -> State:
-    state["description"] = "Dummy description"
+def describe_with_voice_dummy(state: State) -> State:
+    print("Describe")
+    voice = state["voice"]
+    state["description"] = f"Dummy description from {voice}"
     return state
 
 
-def caption_image(state: State) -> State:
-    state["caption"] = "Dummy caption"
+def caption_image_dummy(state: State) -> State:
+    print("Caption")
+    voice = state["voice"]
+    state["caption"] = f"Dummy caption from {voice}"
+    return state
 
 
-def describe_with_voice2(state: State) -> State:
+def describe_with_voice(state: State) -> State:
     caption = state["caption"]
     voice = state["voice"]
 
@@ -75,7 +83,10 @@ def describe_with_voice2(state: State) -> State:
         "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
     }
     messages = [
-        {"role": "system", "content": [voice_prompts.get(voice)]},
+        {
+            "role": "system",
+            "content": [{"type": "text", "text": voice_prompts.get(voice)}],
+        },
         {
             "role": "user",
             "content": [
@@ -93,7 +104,7 @@ def describe_with_voice2(state: State) -> State:
     input_len = inputs["input_ids"].shape[-1]
 
     with torch.inference_mode():
-        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
         generation = generation[0][input_len:]
 
     description = processor.decode(generation, skip_special_tokens=True)
@@ -103,7 +114,7 @@ def describe_with_voice2(state: State) -> State:
     return state
 
 
-def caption_image2(state: State) -> State:
+def caption_image(state: State) -> State:
     # image is PIL
     image = state["image"]
 
@@ -136,7 +147,7 @@ def caption_image2(state: State) -> State:
     input_len = inputs["input_ids"].shape[-1]
 
     with torch.inference_mode():
-        generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
         generation = generation[0][input_len:]
 
     caption = processor.decode(generation, skip_special_tokens=True)
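
For reference, a minimal sketch of what the updated describe_with_voice node looks like end to end after this commit. The checkpoint name, the State shape, the apply_chat_template call, and the user-message wording are assumptions filled in around the hunks above; only the typed system message and max_new_tokens=1000 come from this diff.

# Hedged reconstruction of the describe_with_voice node; parts not shown in the
# diff (checkpoint, State, chat-template call, user prompt) are assumptions.
from typing import TypedDict

import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

MODEL_ID = "google/gemma-3-4b-it"  # assumed checkpoint; not visible in the diff
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Gemma3ForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
)


class State(TypedDict, total=False):  # assumed shape of the LangGraph state
    image: object
    caption: str
    voice: str
    description: str


def describe_with_voice(state: State) -> State:
    caption = state["caption"]
    voice = state["voice"]
    voice_prompts = {
        "sarcastic teenager": "You are a sarcastic and disinterested teenager.",
    }
    # System content is now a list of typed parts, matching the chat template.
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": voice_prompts.get(voice)}],
        },
        {
            "role": "user",
            # Assumed wording; the diff elides the user message body.
            "content": [{"type": "text", "text": f"Describe this image: {caption}"}],
        },
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
        generation = generation[0][input_len:]

    state["description"] = processor.decode(generation, skip_special_tokens=True)
    return state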
app.py CHANGED

@@ -12,6 +12,7 @@ def process_and_display(image, voice):
 
     # Run the graph
     result = graph.invoke(state)
+    print(result)
 
     # Return the caption and description
    return result["caption"], result["description"]