Backup-bdg committed on
Commit 8203986 · verified · 1 Parent(s): a8279a5

Update app.py

Files changed (1)
  1. app.py +45 -11
app.py CHANGED
@@ -2,19 +2,25 @@ import gradio as gr
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from fastapi import FastAPI, HTTPException
+import uvicorn
+import json
+
+# Initialize FastAPI app
+app = FastAPI()
 
 # Model configuration
 CHECKPOINT = "bigcode/starcoder2-15b"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load tokenizer and model (using bfloat16 for efficiency)
-@spaces.GPU(duration=120)  # Set duration to 120s to handle model loading/generation
+# Load model and tokenizer with ZeroGPU
+@spaces.GPU(duration=120)
 def load_model_and_generate(prompt, max_length=256, temperature=0.2, top_p=0.95):
     try:
         # Initialize tokenizer
         tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
 
-        # Initialize model with bfloat16 for lower memory usage
+        # Initialize model
         model = AutoModelForCausalLM.from_pretrained(
             CHECKPOINT,
             torch_dtype=torch.bfloat16,
@@ -30,9 +36,12 @@ def load_model_and_generate(prompt, max_length=256, temperature=0.2, top_p=0.95)
             torch_dtype=torch.bfloat16
         )
 
+        # Format prompt for chat-like interaction
+        chat_prompt = f"User: {prompt}\nAssistant: Let's interpret this as a coding request. Please provide a code-related prompt, or I'll generate a response based on code context.\n{prompt} ```python\n```"
+
         # Generate response
         result = pipe(
-            prompt,
+            chat_prompt,
             max_length=max_length,
             temperature=temperature,
             top_p=top_p,
@@ -44,23 +53,44 @@ def load_model_and_generate(prompt, max_length=256, temperature=0.2, top_p=0.95)
         )
 
         generated_text = result[0]["generated_text"]
-        return generated_text
+        # Extract response after the prompt
+        response = generated_text[len(chat_prompt):].strip() if generated_text.startswith(chat_prompt) else generated_text
+        return response
     except Exception as e:
         return f"Error: {str(e)}"
 
+# FastAPI endpoint for backdoor-chat
+@app.post("/backdoor-chat")
+async def backdoor_chat(request: dict):
+    try:
+        # Validate input
+        if not isinstance(request, dict) or "message" not in request:
+            raise HTTPException(status_code=400, detail="Request must contain 'message' field")
+
+        prompt = request["message"]
+        max_length = request.get("max_length", 256)
+        temperature = request.get("temperature", 0.2)
+        top_p = request.get("top_p", 0.95)
+
+        # Generate response
+        response = load_model_and_generate(prompt, max_length, temperature, top_p)
+        return {"response": response}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
 # Gradio interface setup
 with gr.Blocks() as demo:
-    gr.Markdown("# StarCoder2-15B Code Generation")
-    gr.Markdown("Enter a code prompt (e.g., 'def print_hello_world():') to generate code using bigcode/starcoder2-15b.")
+    gr.Markdown("# StarCoder2-15B Chat Interface")
+    gr.Markdown("Enter a prompt to generate code or simulate a chat. Use the API endpoint `/backdoor-chat` for programmatic access.")
 
     # Input components
-    prompt = gr.Textbox(label="Code Prompt", placeholder="Enter your code prompt here...")
+    prompt = gr.Textbox(label="Message", placeholder="Enter your message (e.g., 'Write a Python function')")
    max_length = gr.Slider(50, 512, value=256, label="Max Length", step=1)
     temperature = gr.Slider(0.1, 1.0, value=0.2, label="Temperature", step=0.1)
     top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top P", step=0.05)
 
     # Output component
-    output = gr.Textbox(label="Generated Code")
+    output = gr.Textbox(label="Generated Response")
 
     # Submit button
     submit_btn = gr.Button("Generate")
@@ -72,5 +102,9 @@ with gr.Blocks() as demo:
         outputs=output
     )
 
-# Launch the Gradio app
-demo.launch()
+# Mount Gradio app to FastAPI
+app = gr.mount_gradio_app(app, demo, path="/")
+
+# Run the app (for local testing; Hugging Face handles this in Spaces)
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=7860)
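
For reference, a minimal client sketch for the new /backdoor-chat route added in this commit. The base URL (a locally running app on port 7860, matching the uvicorn call above) and the example message are assumptions for illustration, not part of the commit:

import json
import urllib.request

# Hypothetical local endpoint; a deployed Space would use its own host URL.
URL = "http://localhost:7860/backdoor-chat"

payload = {
    "message": "Write a Python function that reverses a string",  # example input
    "max_length": 256,   # optional; defaults mirror the endpoint's request.get(...) fallbacks
    "temperature": 0.2,
    "top_p": 0.95,
}

req = urllib.request.Request(
    URL,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["response"])

Per the handler above, a successful call returns a JSON object with a single "response" field, a missing "message" field yields HTTP 400, and any generation failure surfaces as HTTP 500.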