spandana30 commited on
Commit
2e1d572
·
verified ·
1 Parent(s): 8f974c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -10,7 +10,6 @@ from langgraph.graph import StateGraph, END
10
 
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
 
13
- # ✅ Corrected keys: base_id and adapter_id
14
  AGENT_MODEL_CONFIG = {
15
  "product_manager": {
16
  "base_id": "unsloth/gemma-3-1b-it",
@@ -63,17 +62,22 @@ class AgentState(TypedDict):
63
  def agent(prompt_template, state: AgentState, agent_key: str, timing_label: str):
64
  start = time.time()
65
  model, tokenizer = load_agent_model(**AGENT_MODEL_CONFIG[agent_key])
66
- prompt = prompt_template.format(**state)
 
 
 
 
67
  response = call_model(prompt, model, tokenizer)
 
68
  state["messages"].append({"role": agent_key, "content": response})
69
  state["timings"][timing_label] = time.time() - start
70
  gc.collect()
71
  return response
72
 
73
  PROMPTS = {
74
- "product_manager": "You're a Product Manager. Refine this user request:\n{messages[-1][content]}",
75
- "project_manager": "You're a Project Manager. Break down this refined request:\n{messages[-1][content]}",
76
- "software_engineer": "You're a Software Engineer. Generate HTML+CSS code for:\n{messages[-1][content]}",
77
  "qa_engineer": "You're a QA Engineer. Review this HTML:\n{html}\nGive feedback or reply APPROVED."
78
  }
79
 
 
10
 
11
  HF_TOKEN = os.getenv("HF_TOKEN")
12
 
 
13
  AGENT_MODEL_CONFIG = {
14
  "product_manager": {
15
  "base_id": "unsloth/gemma-3-1b-it",
 
62
  def agent(prompt_template, state: AgentState, agent_key: str, timing_label: str):
63
  start = time.time()
64
  model, tokenizer = load_agent_model(**AGENT_MODEL_CONFIG[agent_key])
65
+
66
+ latest_input = state["messages"][-1]["content"]
67
+ html_content = state.get("html", "")
68
+
69
+ prompt = prompt_template.format(user_input=latest_input, html=html_content)
70
  response = call_model(prompt, model, tokenizer)
71
+
72
  state["messages"].append({"role": agent_key, "content": response})
73
  state["timings"][timing_label] = time.time() - start
74
  gc.collect()
75
  return response
76
 
77
  PROMPTS = {
78
+ "product_manager": "You're a Product Manager. Refine this user request:\n{user_input}",
79
+ "project_manager": "You're a Project Manager. Break down this refined request:\n{user_input}",
80
+ "software_engineer": "You're a Software Engineer. Generate HTML+CSS code for:\n{user_input}",
81
  "qa_engineer": "You're a QA Engineer. Review this HTML:\n{html}\nGive feedback or reply APPROVED."
82
  }
83