sonyps1928 committed on
Commit
760431c
·
1 Parent(s): 6ad91fc

update app

Browse files
Files changed (2) hide show
  1. app.py +81 -155
  2. requirements.txt +4 -3
app.py CHANGED
@@ -1,23 +1,16 @@
1
- from flask import Flask, request, jsonify
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
  import torch
4
- import logging
5
- import os
6
 
7
- # Set up logging
8
- logging.basicConfig(level=logging.INFO)
9
- logger = logging.getLogger(__name__)
10
 
11
- # Initialize Flask app
12
- app = Flask(__name__)
13
-
14
- # Load model and tokenizer globally
15
- logger.info("Loading GPT-2 model and tokenizer...")
16
- model_name = "gpt2"
17
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
18
  model = GPT2LMHeadModel.from_pretrained(model_name)
 
 
 
19
  tokenizer.pad_token = tokenizer.eos_token
20
- logger.info("Model loaded successfully!")
21
 
22
 
23
  def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
@@ -30,7 +23,7 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
30
  with torch.no_grad():
31
  outputs = model.generate(
32
  inputs,
33
- max_length=min(max_length + len(inputs[0]), 512),
34
  temperature=temperature,
35
  top_p=top_p,
36
  top_k=top_k,
@@ -46,150 +39,83 @@ def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
46
  return generated_text[len(prompt):].strip()
47
 
48
  except Exception as e:
49
- logger.error(f"Error generating text: {str(e)}")
50
- return f"Error: {str(e)}"
51
-
52
-
53
- @app.route('/')
54
- def root():
55
- """API information endpoint"""
56
- return jsonify({
57
- "message": "GPT-2 Text Generation API",
58
- "model": model_name,
59
- "endpoints": {
60
- "/": "API information",
61
- "/health": "Health check",
62
- "/generate": "POST - Generate text"
63
- },
64
- "example_request": {
65
- "url": "/generate",
66
- "method": "POST",
67
- "headers": {"Content-Type": "application/json"},
68
- "body": {
69
- "prompt": "Once upon a time",
70
- "max_length": 100,
71
- "temperature": 0.7,
72
- "top_p": 0.9,
73
- "top_k": 50
74
- }
75
- }
76
- })
77
 
78
 
79
- @app.route('/health')
80
- def health():
81
- """Health check endpoint"""
82
- return jsonify({
83
- 'status': 'healthy',
84
- 'model': model_name,
85
- 'framework': 'flask',
86
- 'endpoints_available': ['/health', '/generate', '/']
87
- })
88
-
89
-
90
- @app.route('/generate', methods=['POST'])
91
- def generate():
92
- """Text generation API endpoint"""
93
- try:
94
- # Log the request
95
- logger.info(f"Received generate request from {request.remote_addr}")
96
-
97
- data = request.get_json()
98
-
99
- if not data:
100
- logger.warning("No JSON data provided")
101
- return jsonify({'error': 'No JSON data provided', 'received_content_type': request.content_type}), 400
102
-
103
- # Extract parameters with defaults
104
- prompt = data.get('prompt', '')
105
- max_length = data.get('max_length', 100)
106
- temperature = data.get('temperature', 0.7)
107
- top_p = data.get('top_p', 0.9)
108
- top_k = data.get('top_k', 50)
109
-
110
- if not prompt:
111
- logger.warning("Empty prompt provided")
112
- return jsonify({'error': 'Prompt is required and cannot be empty'}), 400
113
-
114
- # Validate and clamp parameters
115
- max_length = max(10, min(200, int(max_length)))
116
- temperature = max(0.1, min(2.0, float(temperature)))
117
- top_p = max(0.1, min(1.0, float(top_p)))
118
- top_k = max(1, min(100, int(top_k)))
119
-
120
- logger.info(f"Generating text for prompt: '{prompt[:50]}...' with params: max_length={max_length}, temperature={temperature}")
121
-
122
- # Generate text
123
- generated_text = generate_text(prompt, max_length, temperature, top_p, top_k)
124
-
125
- result = {
126
- 'generated_text': generated_text,
127
- 'prompt': prompt,
128
- 'parameters': {
129
- 'max_length': max_length,
130
- 'temperature': temperature,
131
- 'top_p': top_p,
132
- 'top_k': top_k
133
- }
134
- }
135
-
136
- logger.info("Text generation successful")
137
- return jsonify(result)
138
 
139
- except ValueError as e:
140
- logger.error(f"Parameter validation error: {str(e)}")
141
- return jsonify({'error': f'Invalid parameter: {str(e)}'}), 400
142
- except Exception as e:
143
- logger.error(f"Error in /generate: {str(e)}")
144
- return jsonify({'error': f'Internal server error: {str(e)}'}), 500
145
-
146
-
147
- @app.route('/generate', methods=['GET'])
148
- def generate_get():
149
- """GET endpoint for /generate with usage information"""
150
- return jsonify({
151
- 'error': 'Method not allowed',
152
- 'message': 'This endpoint only accepts POST requests',
153
- 'usage': 'Send a POST request with JSON body containing "prompt" field',
154
- 'example': {
155
- 'method': 'POST',
156
- 'headers': {'Content-Type': 'application/json'},
157
- 'body': {'prompt': 'Once upon a time', 'max_length': 100}
158
- }
159
- }), 405
160
-
161
-
162
- @app.errorhandler(404)
163
- def not_found(error):
164
- return jsonify({
165
- 'error': 'Not found',
166
- 'available_endpoints': ['/', '/health', '/generate'],
167
- 'message': 'Check the available endpoints above'
168
- }), 404
169
-
170
-
171
- @app.errorhandler(405)
172
- def method_not_allowed(error):
173
- return jsonify({
174
- 'error': 'Method not allowed',
175
- 'message': 'Check the allowed methods for this endpoint'
176
- }), 405
177
-
178
-
179
- @app.errorhandler(500)
180
- def internal_error(error):
181
- return jsonify({'error': 'Internal server error'}), 500
182
 
183
 
 
184
  if __name__ == "__main__":
185
- # For Hugging Face Spaces
186
- port = int(os.environ.get("PORT", 7860))
187
- host = "0.0.0.0"
188
-
189
- logger.info(f"Starting GPT-2 API server on {host}:{port}")
190
- logger.info("Available endpoints:")
191
- logger.info(" GET / - API information")
192
- logger.info(" GET /health - Health check")
193
- logger.info(" POST /generate - Text generation")
194
-
195
- app.run(host=host, port=port, debug=False)
 
1
+ import gradio as gr
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
  import torch
 
 
4
 
 
 
 
5
 
6
+ # Load model and tokenizer (using smaller GPT-2 for free tier)
7
+ model_name = "gpt2" # You can also use "gpt2-medium" if it fits in memory
 
 
 
 
8
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
9
  model = GPT2LMHeadModel.from_pretrained(model_name)
10
+
11
+
12
+ # Set pad token
13
  tokenizer.pad_token = tokenizer.eos_token
 
14
 
15
 
16
  def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
 
23
  with torch.no_grad():
24
  outputs = model.generate(
25
  inputs,
26
+ max_length=min(max_length + len(inputs[0]), 512), # Limit total length
27
  temperature=temperature,
28
  top_p=top_p,
29
  top_k=top_k,
 
39
  return generated_text[len(prompt):].strip()
40
 
41
  except Exception as e:
42
+ return f"Error generating text: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
+ # Create Gradio interface
46
+ with gr.Blocks(title="GPT-2 Text Generator") as demo:
47
+ gr.Markdown("# GPT-2 Text Generation Server")
48
+ gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")
49
+
50
+ with gr.Row():
51
+ with gr.Column():
52
+ prompt_input = gr.Textbox(
53
+ label="Prompt",
54
+ placeholder="Enter your text prompt here...",
55
+ lines=3
56
+ )
57
+
58
+ with gr.Row():
59
+ max_length = gr.Slider(
60
+ minimum=10,
61
+ maximum=200,
62
+ value=100,
63
+ step=10,
64
+ label="Max Length"
65
+ )
66
+ temperature = gr.Slider(
67
+ minimum=0.1,
68
+ maximum=2.0,
69
+ value=0.7,
70
+ step=0.1,
71
+ label="Temperature"
72
+ )
73
+
74
+ with gr.Row():
75
+ top_p = gr.Slider(
76
+ minimum=0.1,
77
+ maximum=1.0,
78
+ value=0.9,
79
+ step=0.1,
80
+ label="Top-p"
81
+ )
82
+ top_k = gr.Slider(
83
+ minimum=1,
84
+ maximum=100,
85
+ value=50,
86
+ step=1,
87
+ label="Top-k"
88
+ )
89
+
90
+ generate_btn = gr.Button("Generate Text", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ with gr.Column():
93
+ output_text = gr.Textbox(
94
+ label="Generated Text",
95
+ lines=10,
96
+ placeholder="Generated text will appear here..."
97
+ )
98
+
99
+ # Examples
100
+ gr.Examples(
101
+ examples=[
102
+ ["Once upon a time in a distant galaxy,"],
103
+ ["The future of artificial intelligence is"],
104
+ ["In the heart of the ancient forest,"],
105
+ ["The detective walked into the room and noticed"],
106
+ ],
107
+ inputs=prompt_input
108
+ )
109
+
110
+ # Connect the function with explicit API endpoint name
111
+ generate_btn.click(
112
+ fn=generate_text,
113
+ inputs=[prompt_input, max_length, temperature, top_p, top_k],
114
+ outputs=output_text,
115
+ api_name="/predict" # Explicit API endpoint for external calls
116
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
+ # Launch the app
120
  if __name__ == "__main__":
121
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
- flask==2.3.3
2
- transformers==4.35.0
3
- torch==2.1.0
 
 
1
+ gradio>=3.50.0
2
+ transformers>=4.30.0
3
+ torch>=2.0.0
4
+ tokenizers>=0.13.0