kristianfischerai12345 committed
Commit 9334142 · verified · 1 parent: cba04c3

Create app.py

Files changed (1): app.py (+186, -0)
app.py ADDED
@@ -0,0 +1,186 @@
+ import gradio as gr
+ import torch
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+ import time
+
+ # Load model and tokenizer
+ model_id = "kristianfischerai12345/fischgpt-sft"
+ print("Loading FischGPT model...")
+ model = GPT2LMHeadModel.from_pretrained(model_id)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+
+ model.eval()
+ print("Model loaded successfully!")
+
+ def generate_api(user_message, temperature=0.8, max_length=150, top_p=0.9):
+     """
+     API endpoint for FischGPT generation.
+
+     Args:
+         user_message (str): The user's input message
+         temperature (float): Sampling temperature (0.1-2.0)
+         max_length (int): Maximum total sequence length, prompt included (50-300)
+         top_p (float): Top-p (nucleus) sampling threshold (0.1-1.0)
+
+     Returns:
+         dict: Response with generated text and metadata
+     """
+     if not user_message or not user_message.strip():
+         return {
+             "error": "Empty message",
+             "response": None,
+             "metadata": None
+         }
+
+     try:
+         # Format as a single-turn conversation
+         prompt = f"<|user|>{user_message.strip()}<|assistant|>"
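+         # e.g. "Hello" -> "<|user|>Hello<|assistant|>"; the model then
+         # continues the text after the assistant marker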
+         # Tokenize
+         inputs = tokenizer.encode(prompt, return_tensors='pt')
+
+         # Generate
+         start_time = time.time()
+         with torch.no_grad():
+             outputs = model.generate(
+                 inputs,
+                 max_length=max_length,
+                 temperature=float(temperature),
+                 top_p=float(top_p),
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id,
+                 attention_mask=torch.ones_like(inputs)
+             )
+
+         generation_time = time.time() - start_time
+
+         # Decode and extract the response; guard the split in case the
+         # tokenizer strips <|assistant|> as a special token
+         full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         parts = full_text.split("<|assistant|>", 1)
+         response = parts[1].strip() if len(parts) > 1 else full_text.strip()
+
+         # Calculate metrics
+         input_tokens = len(inputs[0])
+         output_tokens = len(outputs[0])
+         new_tokens = output_tokens - input_tokens
+         tokens_per_sec = new_tokens / generation_time if generation_time > 0 else 0
+
+         # Return structured response
+         return {
+             "error": None,
+             "response": response,
+             "metadata": {
+                 "input_tokens": input_tokens,
+                 "output_tokens": output_tokens,
+                 "new_tokens": new_tokens,
+                 "generation_time": round(generation_time, 3),
+                 "tokens_per_second": round(tokens_per_sec, 1),
+                 "model": "FischGPT-SFT",
+                 "parameters": {
+                     "temperature": temperature,
+                     "max_length": max_length,
+                     "top_p": top_p
+                 }
+             }
+         }
+
+     except Exception as e:
+         return {
+             "error": str(e),
+             "response": None,
+             "metadata": None
+         }
+
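+ # Quick local smoke test (illustrative; sampled output varies run to run):
+ #
+ #   result = generate_api("What is machine learning?", temperature=0.7)
+ #   if result["error"] is None:
+ #       print(result["response"])
+ #       print(result["metadata"]["tokens_per_second"], "tokens/sec")
+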
+ # Create minimal Gradio interface for API testing
+ with gr.Blocks(title="FischGPT API") as demo:
+
+     gr.HTML("""
+ <div style="text-align: center; padding: 15px; background: #f0f0f0; border-radius: 10px; margin-bottom: 20px;">
+     <h2>🚀 FischGPT API Backend</h2>
+     <p>Minimal interface for API testing. Use the API endpoint for your custom frontend.</p>
+     <p><strong>API Endpoint:</strong> <code>/api/predict</code></p>
+ </div>
+ """)
+
+     gr.Markdown("""
+ ## 🔌 API Usage
+
+ **Python Example:**
+ ```python
+ import requests
+
+ response = requests.post(
+     "https://kristianfischerai12345-fischgpt-api.hf.space/api/predict",
+     json={
+         "data": [
+             "Explain machine learning",  # user_message
+             0.8,  # temperature
+             150,  # max_length
+             0.9   # top_p
+         ]
+     }
+ )
+
+ result = response.json()
+ print(result["data"][0]["response"])
+ ```
+
+ **JavaScript/React Example:**
+ ```javascript
+ const response = await fetch("https://kristianfischerai12345-fischgpt-api.hf.space/api/predict", {
+     method: "POST",
+     headers: { "Content-Type": "application/json" },
+     body: JSON.stringify({
+         data: [
+             "Explain machine learning", // user_message
+             0.8, // temperature
+             150, // max_length
+             0.9  // top_p
+         ]
+     })
+ });
+
+ const result = await response.json();
+ console.log(result.data[0].response);
+ ```
+ """)
+
+     # Simple test interface
+     gr.Markdown("### Quick Test Interface")
+
+     with gr.Row():
+         user_input = gr.Textbox(label="Test Message", value="Hello, how are you?", scale=2)
+         test_btn = gr.Button("Test API", variant="primary")
+
+     with gr.Row():
+         temperature = gr.Slider(0.1, 2.0, 0.8, label="Temperature")
+         max_length = gr.Slider(50, 300, 150, label="Max Length")
+         top_p = gr.Slider(0.1, 1.0, 0.9, label="Top-p")
+
+     output = gr.JSON(label="API Response")
+
+     # Connect the test interface; api_name exposes this event as the
+     # "predict" API route of the launched app
+     test_btn.click(
+         fn=generate_api,
+         inputs=[user_input, temperature, max_length, top_p],
+         outputs=output,
+         api_name="predict"
+     )
+
+ # Standalone Interface with the same signature. Note that it is never
+ # launched, so the live API route comes from the click event above, not here.
+ api_interface = gr.Interface(
+     fn=generate_api,
+     inputs=[
+         gr.Textbox(label="User Message"),
+         gr.Slider(0.1, 2.0, 0.8, label="Temperature"),
+         gr.Slider(50, 300, 150, label="Max Length"),
+         gr.Slider(0.1, 1.0, 0.9, label="Top-p")
+     ],
+     outputs=gr.JSON(label="Response"),
+     api_name="predict"
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
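For calls from outside the Space, the version-stable client is the official `gradio_client` package, which resolves the correct HTTP route for whatever Gradio version the Space runs (newer releases replaced the raw `/api/predict` REST route used in the examples above). A minimal sketch, assuming the Space is published as `kristianfischerai12345/fischgpt-api` (inferred from the URLs in app.py):

```python
from gradio_client import Client

# Connect to the hosted Space (Space name assumed, see note above)
client = Client("kristianfischerai12345/fischgpt-api")

# Positional inputs mirror generate_api's signature
result = client.predict(
    "Explain machine learning",  # user_message
    0.8,                         # temperature
    150,                         # max_length
    0.9,                         # top_p
    api_name="/predict"
)

print(result["response"])
print(result["metadata"])
```

Because the endpoint's output is a single JSON component, `result` should come back as the plain dict built by `generate_api`, including the `error` field worth checking before using `response`.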