sonyps1928 commited on
Commit
7fe97c0
·
1 Parent(s): 107fb80

update app8

Browse files
Files changed (3) hide show
  1. app.py +234 -209
  2. requirements.txt +3 -4
  3. requirements1.txt +4 -0
app.py CHANGED
@@ -1,18 +1,20 @@
1
- import json
2
- from http.server import HTTPServer, BaseHTTPRequestHandler
3
- from urllib.parse import urlparse, parse_qs
4
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
5
  import torch
6
  import logging
 
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
 
 
 
12
  # Load model and tokenizer globally
13
  logger.info("Loading GPT-2 model and tokenizer...")
14
  model_name = "gpt2"
15
- tokenizer = GPT2LMHeadModel.from_pretrained(model_name)
16
  model = GPT2LMHeadModel.from_pretrained(model_name)
17
  tokenizer.pad_token = tokenizer.eos_token
18
  logger.info("Model loaded successfully!")
@@ -48,220 +50,243 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
48
  return f"Error: {str(e)}"
49
 
50
 
51
- class GPT2Handler(BaseHTTPRequestHandler):
52
- def _set_headers(self, content_type='application/json'):
53
- self.send_response(200)
54
- self.send_header('Content-type', content_type)
55
- self.send_header('Access-Control-Allow-Origin', '*')
56
- self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
57
- self.send_header('Access-Control-Allow-Headers', 'Content-Type')
58
- self.end_headers()
59
-
60
- def _send_error(self, code, message):
61
- self.send_response(code)
62
- self.send_header('Content-type', 'application/json')
63
- self.end_headers()
64
- response = {'error': message}
65
- self.wfile.write(json.dumps(response).encode())
66
-
67
- def do_OPTIONS(self):
68
- self._set_headers()
69
-
70
- def do_GET(self):
71
- parsed_path = urlparse(self.path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- if parsed_path.path == '/':
74
- # Serve a simple HTML interface
75
- self._set_headers('text/html')
76
- html = '''
77
- <!DOCTYPE html>
78
- <html>
79
- <head>
80
- <title>GPT-2 Text Generator</title>
81
- <style>
82
- body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
83
- .container { margin: 20px 0; }
84
- textarea, input, button { margin: 5px 0; padding: 8px; }
85
- textarea { width: 100%; height: 100px; }
86
- button { background: #007bff; color: white; border: none; padding: 10px 20px; cursor: pointer; }
87
- button:hover { background: #0056b3; }
88
- .output { background: #f8f9fa; padding: 15px; border-radius: 5px; min-height: 100px; }
89
- .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 10px; }
90
- label { font-weight: bold; }
91
- </style>
92
- </head>
93
- <body>
94
- <h1>GPT-2 Text Generator</h1>
95
- <p>Enter a prompt and generate text using GPT-2</p>
96
-
97
- <div class="container">
98
- <label for="prompt">Prompt:</label>
99
- <textarea id="prompt" placeholder="Enter your text prompt here...">Once upon a time in a distant galaxy,</textarea>
100
- </div>
101
-
102
- <div class="controls">
103
- <div>
104
- <label for="maxLength">Max Length: <span id="maxLengthValue">100</span></label>
105
- <input type="range" id="maxLength" min="10" max="200" value="100" step="10">
106
- </div>
107
- <div>
108
- <label for="temperature">Temperature: <span id="temperatureValue">0.7</span></label>
109
- <input type="range" id="temperature" min="0.1" max="2.0" value="0.7" step="0.1">
110
- </div>
111
- <div>
112
- <label for="topP">Top-p: <span id="topPValue">0.9</span></label>
113
- <input type="range" id="topP" min="0.1" max="1.0" value="0.9" step="0.1">
114
- </div>
115
- <div>
116
- <label for="topK">Top-k: <span id="topKValue">50</span></label>
117
- <input type="range" id="topK" min="1" max="100" value="50" step="1">
118
- </div>
119
- </div>
120
-
121
- <div class="container">
122
- <button onclick="generateText()" id="generateBtn">Generate Text</button>
123
- </div>
124
-
125
- <div class="container">
126
- <label>Generated Text:</label>
127
- <div id="output" class="output">Generated text will appear here...</div>
128
- </div>
129
-
130
- <script>
131
- // Update slider value displays
132
- document.getElementById('maxLength').oninput = function() {
133
- document.getElementById('maxLengthValue').textContent = this.value;
134
- }
135
- document.getElementById('temperature').oninput = function() {
136
- document.getElementById('temperatureValue').textContent = this.value;
137
- }
138
- document.getElementById('topP').oninput = function() {
139
- document.getElementById('topPValue').textContent = this.value;
140
- }
141
- document.getElementById('topK').oninput = function() {
142
- document.getElementById('topKValue').textContent = this.value;
143
- }
144
-
145
- async function generateText() {
146
- const btn = document.getElementById('generateBtn');
147
- const output = document.getElementById('output');
148
-
149
- btn.disabled = true;
150
- btn.textContent = 'Generating...';
151
- output.textContent = 'Generating text...';
152
-
153
- const data = {
154
- prompt: document.getElementById('prompt').value,
155
- max_length: parseInt(document.getElementById('maxLength').value),
156
- temperature: parseFloat(document.getElementById('temperature').value),
157
- top_p: parseFloat(document.getElementById('topP').value),
158
- top_k: parseInt(document.getElementById('topK').value)
159
- };
160
-
161
- try {
162
- const response = await fetch('/generate', {
163
- method: 'POST',
164
- headers: {'Content-Type': 'application/json'},
165
- body: JSON.stringify(data)
166
- });
167
-
168
- const result = await response.json();
169
-
170
- if (result.error) {
171
- output.textContent = 'Error: ' + result.error;
172
- } else {
173
- output.textContent = result.generated_text;
174
- }
175
- } catch (error) {
176
- output.textContent = 'Error: ' + error.message;
177
- }
178
-
179
- btn.disabled = false;
180
- btn.textContent = 'Generate Text';
181
- }
182
- </script>
183
- </body>
184
- </html>
185
- '''
186
- self.wfile.write(html.encode())
187
 
188
- elif parsed_path.path == '/health':
189
- # Health check endpoint
190
- self._set_headers()
191
- response = {'status': 'healthy', 'model': model_name}
192
- self.wfile.write(json.dumps(response).encode())
193
 
194
- else:
195
- self._send_error(404, 'Not found')
196
-
197
- def do_POST(self):
198
- if self.path == '/generate':
199
- try:
200
- # Get request body
201
- content_length = int(self.headers['Content-Length'])
202
- post_data = self.rfile.read(content_length)
203
- data = json.loads(post_data.decode())
204
-
205
- # Extract parameters
206
- prompt = data.get('prompt', '')
207
- max_length = data.get('max_length', 100)
208
- temperature = data.get('temperature', 0.7)
209
- top_p = data.get('top_p', 0.9)
210
- top_k = data.get('top_k', 50)
211
-
212
- if not prompt:
213
- self._send_error(400, 'Prompt is required')
214
- return
215
-
216
- # Generate text
217
- logger.info(f"Generating text for prompt: {prompt[:50]}...")
218
- generated_text = generate_text(prompt, max_length, temperature, top_p, top_k)
219
 
220
- # Send response
221
- self._set_headers()
222
- response = {'generated_text': generated_text}
223
- self.wfile.write(json.dumps(response).encode())
224
 
225
- except json.JSONDecodeError:
226
- self._send_error(400, 'Invalid JSON')
227
- except Exception as e:
228
- logger.error(f"Error in POST /generate: {str(e)}")
229
- self._send_error(500, str(e))
230
- else:
231
- self._send_error(404, 'Not found')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
- def log_message(self, format, *args):
234
- # Override to use our logger
235
- logger.info(f"{self.address_string()} - {format % args}")
236
 
 
 
 
 
237
 
238
- def run_server(host='localhost', port=8000):
239
- """Start the HTTP server"""
240
- server_address = (host, port)
241
- httpd = HTTPServer(server_address, GPT2Handler)
242
-
243
- logger.info(f"Starting GPT-2 server on http://{host}:{port}")
244
- logger.info(f"Web interface: http://{host}:{port}")
245
- logger.info(f"API endpoint: http://{host}:{port}/generate")
246
- logger.info(f"Health check: http://{host}:{port}/health")
247
-
 
 
 
 
248
  try:
249
- httpd.serve_forever()
250
- except KeyboardInterrupt:
251
- logger.info("Shutting down server...")
252
- httpd.shutdown()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
 
255
  if __name__ == "__main__":
256
- import sys
257
-
258
- # Parse command line arguments
259
- host = 'localhost'
260
- port = 8000
261
-
262
- if len(sys.argv) > 1:
263
- port = int(sys.argv[1])
264
- if len(sys.argv) > 2:
265
- host = sys.argv[2]
266
-
267
- run_server(host, port)
 
1
+ from flask import Flask, request, jsonify, render_template_string
 
 
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
  import torch
4
  import logging
5
+ import os
6
 
7
  # Set up logging
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
+ # Initialize Flask app
12
+ app = Flask(__name__)
13
+
14
  # Load model and tokenizer globally
15
  logger.info("Loading GPT-2 model and tokenizer...")
16
  model_name = "gpt2"
17
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
18
  model = GPT2LMHeadModel.from_pretrained(model_name)
19
  tokenizer.pad_token = tokenizer.eos_token
20
  logger.info("Model loaded successfully!")
 
50
  return f"Error: {str(e)}"
51
 
52
 
53
+ HTML_TEMPLATE = '''
54
+ <!DOCTYPE html>
55
+ <html>
56
+ <head>
57
+ <title>GPT-2 Text Generator</title>
58
+ <style>
59
+ body {
60
+ font-family: Arial, sans-serif;
61
+ max-width: 800px;
62
+ margin: 0 auto;
63
+ padding: 20px;
64
+ background: #f5f5f5;
65
+ }
66
+ .container {
67
+ background: white;
68
+ padding: 20px;
69
+ border-radius: 10px;
70
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
71
+ margin: 20px 0;
72
+ }
73
+ textarea, input, button { margin: 5px 0; padding: 8px; border-radius: 5px; border: 1px solid #ddd; }
74
+ textarea { width: 100%; height: 100px; font-family: monospace; }
75
+ button {
76
+ background: #007bff;
77
+ color: white;
78
+ border: none;
79
+ padding: 10px 20px;
80
+ cursor: pointer;
81
+ border-radius: 5px;
82
+ }
83
+ button:hover { background: #0056b3; }
84
+ button:disabled { background: #ccc; cursor: not-allowed; }
85
+ .output {
86
+ background: #f8f9fa;
87
+ padding: 15px;
88
+ border-radius: 5px;
89
+ min-height: 100px;
90
+ border: 1px solid #e9ecef;
91
+ font-family: monospace;
92
+ white-space: pre-wrap;
93
+ }
94
+ .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 15px 0; }
95
+ label { font-weight: bold; display: block; margin-bottom: 5px; }
96
+ .slider-container { margin: 10px 0; }
97
+ input[type="range"] { width: 100%; }
98
+ .value-display { font-weight: normal; color: #666; }
99
+ h1 { color: #333; text-align: center; }
100
+ .description { text-align: center; color: #666; margin-bottom: 30px; }
101
+ .examples { margin: 20px 0; }
102
+ .example-btn {
103
+ background: #28a745;
104
+ margin: 5px;
105
+ padding: 5px 10px;
106
+ font-size: 12px;
107
+ }
108
+ .example-btn:hover { background: #218838; }
109
+ </style>
110
+ </head>
111
+ <body>
112
+ <div class="container">
113
+ <h1>🤖 GPT-2 Text Generator</h1>
114
+ <p class="description">Enter a prompt and generate text using GPT-2. Powered by Hugging Face Transformers!</p>
115
 
116
+ <div>
117
+ <label for="prompt">Prompt:</label>
118
+ <textarea id="prompt" placeholder="Enter your text prompt here...">Once upon a time in a distant galaxy,</textarea>
119
+ </div>
120
+
121
+ <div class="examples">
122
+ <label>Quick Examples:</label><br>
123
+ <button class="example-btn" onclick="setPrompt('Once upon a time in a distant galaxy,')">Sci-Fi Story</button>
124
+ <button class="example-btn" onclick="setPrompt('The future of artificial intelligence is')">AI Future</button>
125
+ <button class="example-btn" onclick="setPrompt('In the heart of the ancient forest,')">Fantasy</button>
126
+ <button class="example-btn" onclick="setPrompt('The detective walked into the room and noticed')">Mystery</button>
127
+ </div>
128
+
129
+ <div class="controls">
130
+ <div class="slider-container">
131
+ <label for="maxLength">Max Length: <span id="maxLengthValue" class="value-display">100</span></label>
132
+ <input type="range" id="maxLength" min="10" max="200" value="100" step="10">
133
+ </div>
134
+ <div class="slider-container">
135
+ <label for="temperature">Temperature: <span id="temperatureValue" class="value-display">0.7</span></label>
136
+ <input type="range" id="temperature" min="0.1" max="2.0" value="0.7" step="0.1">
137
+ </div>
138
+ <div class="slider-container">
139
+ <label for="topP">Top-p: <span id="topPValue" class="value-display">0.9</span></label>
140
+ <input type="range" id="topP" min="0.1" max="1.0" value="0.9" step="0.1">
141
+ </div>
142
+ <div class="slider-container">
143
+ <label for="topK">Top-k: <span id="topKValue" class="value-display">50</span></label>
144
+ <input type="range" id="topK" min="1" max="100" value="50" step="1">
145
+ </div>
146
+ </div>
147
+
148
+ <div style="text-align: center;">
149
+ <button onclick="generateText()" id="generateBtn">🚀 Generate Text</button>
150
+ </div>
151
+ </div>
152
+
153
+ <div class="container">
154
+ <label>Generated Text:</label>
155
+ <div id="output" class="output">Generated text will appear here...</div>
156
+ </div>
157
+
158
+ <script>
159
+ // Update slider value displays
160
+ document.getElementById('maxLength').oninput = function() {
161
+ document.getElementById('maxLengthValue').textContent = this.value;
162
+ }
163
+ document.getElementById('temperature').oninput = function() {
164
+ document.getElementById('temperatureValue').textContent = this.value;
165
+ }
166
+ document.getElementById('topP').oninput = function() {
167
+ document.getElementById('topPValue').textContent = this.value;
168
+ }
169
+ document.getElementById('topK').oninput = function() {
170
+ document.getElementById('topKValue').textContent = this.value;
171
+ }
172
+
173
+ function setPrompt(text) {
174
+ document.getElementById('prompt').value = text;
175
+ }
176
+
177
+ async function generateText() {
178
+ const btn = document.getElementById('generateBtn');
179
+ const output = document.getElementById('output');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
+ btn.disabled = true;
182
+ btn.textContent = '⏳ Generating...';
183
+ output.textContent = 'Generating text, please wait...';
 
 
184
 
185
+ const data = {
186
+ prompt: document.getElementById('prompt').value,
187
+ max_length: parseInt(document.getElementById('maxLength').value),
188
+ temperature: parseFloat(document.getElementById('temperature').value),
189
+ top_p: parseFloat(document.getElementById('topP').value),
190
+ top_k: parseInt(document.getElementById('topK').value)
191
+ };
192
+
193
+ try {
194
+ const response = await fetch('/generate', {
195
+ method: 'POST',
196
+ headers: {'Content-Type': 'application/json'},
197
+ body: JSON.stringify(data)
198
+ });
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ const result = await response.json();
 
 
 
201
 
202
+ if (result.error) {
203
+ output.textContent = ' Error: ' + result.error;
204
+ } else {
205
+ output.textContent = result.generated_text || 'No text generated';
206
+ }
207
+ } catch (error) {
208
+ output.textContent = ' Network Error: ' + error.message;
209
+ }
210
+
211
+ btn.disabled = false;
212
+ btn.textContent = '🚀 Generate Text';
213
+ }
214
+
215
+ // Allow Enter key to generate (Ctrl+Enter for textarea)
216
+ document.getElementById('prompt').addEventListener('keydown', function(e) {
217
+ if (e.ctrlKey && e.key === 'Enter') {
218
+ generateText();
219
+ }
220
+ });
221
+ </script>
222
+ </body>
223
+ </html>
224
+ '''
225
 
 
 
 
226
 
227
+ @app.route('/')
228
+ def home():
229
+ """Serve the web interface"""
230
+ return render_template_string(HTML_TEMPLATE)
231
 
232
+
233
+ @app.route('/health')
234
+ def health():
235
+ """Health check endpoint"""
236
+ return jsonify({
237
+ 'status': 'healthy',
238
+ 'model': model_name,
239
+ 'framework': 'flask'
240
+ })
241
+
242
+
243
+ @app.route('/generate', methods=['POST'])
244
+ def generate():
245
+ """Text generation API endpoint"""
246
  try:
247
+ data = request.get_json()
248
+
249
+ if not data:
250
+ return jsonify({'error': 'No JSON data provided'}), 400
251
+
252
+ # Extract parameters
253
+ prompt = data.get('prompt', '')
254
+ max_length = data.get('max_length', 100)
255
+ temperature = data.get('temperature', 0.7)
256
+ top_p = data.get('top_p', 0.9)
257
+ top_k = data.get('top_k', 50)
258
+
259
+ if not prompt:
260
+ return jsonify({'error': 'Prompt is required'}), 400
261
+
262
+ # Validate parameters
263
+ max_length = max(10, min(200, int(max_length)))
264
+ temperature = max(0.1, min(2.0, float(temperature)))
265
+ top_p = max(0.1, min(1.0, float(top_p)))
266
+ top_k = max(1, min(100, int(top_k)))
267
+
268
+ # Generate text
269
+ logger.info(f"Generating text for prompt: {prompt[:50]}...")
270
+ generated_text = generate_text(prompt, max_length, temperature, top_p, top_k)
271
+
272
+ return jsonify({'generated_text': generated_text})
273
+
274
+ except Exception as e:
275
+ logger.error(f"Error in /generate: {str(e)}")
276
+ return jsonify({'error': str(e)}), 500
277
+
278
+
279
+ @app.errorhandler(404)
280
+ def not_found(error):
281
+ return jsonify({'error': 'Not found'}), 404
282
+
283
+
284
+ @app.errorhandler(500)
285
+ def internal_error(error):
286
+ return jsonify({'error': 'Internal server error'}), 500
287
 
288
 
289
  if __name__ == "__main__":
290
+ # For Hugging Face Spaces
291
+ port = int(os.environ.get("PORT", 7860))
292
+ app.run(host="0.0.0.0", port=port, debug=False)
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- gradio>=3.50.0
2
- transformers>=4.30.0
3
- torch>=2.0.0
4
- tokenizers>=0.13.0
 
1
+ flask==2.3.3
2
+ transformers==4.35.0
3
+ torch==2.1.0
 
requirements1.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=3.50.0
2
+ transformers>=4.30.0
3
+ torch>=2.0.0
4
+ tokenizers>=0.13.0