mike23415 committed on
Commit
25dc981
·
verified ·
1 Parent(s): e28964d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -395
app.py CHANGED
@@ -1,410 +1,41 @@
1
- import os
2
- import time
3
- import tempfile
4
- import jinja2
5
- import pdfkit
6
- import torch
7
- import logging
8
- import subprocess
9
- from threading import Thread
10
- from flask import Flask, request, send_file, jsonify
11
  from flask_cors import CORS
12
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 
13
 
14
- # Configure cache directories
15
- os.environ['HF_HOME'] = '/app/.cache'
16
- os.environ['XDG_CACHE_HOME'] = '/app/.cache'
17
-
18
- # Configure logging
19
- logging.basicConfig(
20
- level=logging.INFO,
21
- format='%(asctime)s [%(levelname)s] %(message)s'
22
- )
23
-
24
- # Initialize Flask app
25
  app = Flask(__name__)
26
- CORS(app)
27
-
28
- # Global state tracking
29
- model_loaded = False
30
- load_error = None
31
- generator = None
32
-
33
- # Find wkhtmltopdf path
34
- WKHTMLTOPDF_PATH = '/usr/bin/wkhtmltopdf'
35
- if not os.path.exists(WKHTMLTOPDF_PATH):
36
- # Try to find it using which
37
- try:
38
- WKHTMLTOPDF_PATH = subprocess.check_output(['which', 'wkhtmltopdf']).decode().strip()
39
- except:
40
- app.logger.warning("Could not find wkhtmltopdf path. Using default.")
41
- WKHTMLTOPDF_PATH = 'wkhtmltopdf'
42
-
43
- # Configure wkhtmltopdf
44
- pdf_config = pdfkit.configuration(wkhtmltopdf=WKHTMLTOPDF_PATH)
45
 
46
- def load_model():
47
- global model_loaded, load_error, generator
48
- try:
49
- app.logger.info("Starting model loading process")
50
-
51
- # Detect device and dtype automatically
52
- dtype = torch.float16 if torch.cuda.is_available() else torch.float32
53
- device = "cuda" if torch.cuda.is_available() else "cpu"
54
- app.logger.info(f"Device set to use {device}")
55
-
56
- model = AutoModelForCausalLM.from_pretrained(
57
- "gpt2",
58
- use_safetensors=True,
59
- device_map="auto",
60
- torch_dtype=dtype,
61
- low_cpu_mem_usage=True,
62
- offload_folder="offload"
63
- )
64
-
65
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
66
-
67
- # Initialize pipeline without explicit device assignment
68
- generator = pipeline(
69
- 'text-generation',
70
- model=model,
71
- tokenizer=tokenizer,
72
- torch_dtype=dtype
73
- )
74
-
75
- model_loaded = True
76
- app.logger.info(f"Model loaded successfully on {model.device}")
77
-
78
- except Exception as e:
79
- load_error = str(e)
80
- app.logger.error(f"Model loading failed: {load_error}", exc_info=True)
81
-
82
- # Start model loading in background thread
83
- Thread(target=load_model).start()
84
-
85
- # --------------------------------------------------
86
- # IEEE Format Template
87
- # --------------------------------------------------
88
- IEEE_TEMPLATE = """
89
- <!DOCTYPE html>
90
- <html>
91
- <head>
92
- <meta charset="UTF-8">
93
- <title>{{ title }}</title>
94
- <style>
95
- @page { margin: 0.75in; }
96
- body {
97
- font-family: 'Times New Roman', Times, serif;
98
- font-size: 12pt;
99
- line-height: 1.5;
100
- }
101
- .header { text-align: center; margin-bottom: 24pt; }
102
- .two-column { column-count: 2; column-gap: 0.5in; }
103
- h1 { font-size: 14pt; margin: 12pt 0; }
104
- h2 { font-size: 12pt; margin: 12pt 0 6pt 0; }
105
- .abstract { margin-bottom: 24pt; }
106
- .keywords { font-weight: bold; margin: 12pt 0; }
107
- .references { margin-top: 24pt; }
108
- .reference-item { text-indent: -0.5in; padding-left: 0.5in; }
109
- </style>
110
- </head>
111
- <body>
112
- <div class="header">
113
- <h1>{{ title }}</h1>
114
- <div class="author-info">
115
- {% for author in authors %}
116
- {{ author.name }}<br>
117
- {% if author.institution %}{{ author.institution }}<br>{% endif %}
118
- {% if author.email %}Email: {{ author.email }}{% endif %}
119
- {% if not loop.last %}<br>{% endif %}
120
- {% endfor %}
121
- </div>
122
- </div>
123
 
124
- <div class="abstract">
125
- <h2>Abstract</h2>
126
- {{ abstract }}
127
- <div class="keywords">Keywords— {{ keywords }}</div>
128
- </div>
129
- <div class="two-column">
130
- {% for section in sections %}
131
- <h2>{{ section.title }}</h2>
132
- {{ section.content }}
133
- {% endfor %}
134
- </div>
135
- <div class="references">
136
- <h2>References</h2>
137
- {% for ref in references %}
138
- <div class="reference-item">[{{ loop.index }}] {{ ref }}</div>
139
- {% endfor %}
140
- </div>
141
- </body>
142
- </html>
143
- """
144
-
145
- # --------------------------------------------------
146
- # API Endpoints
147
- # --------------------------------------------------
148
- @app.route('/health', methods=['GET'])
149
- def health_check():
150
- return jsonify({
151
- "status": "ok",
152
- "model_loaded": model_loaded,
153
- "device": "cuda" if torch.cuda.is_available() else "cpu"
154
- }), 200
155
 
156
- app.logger.info(f"Health check returning status: {'ready' if model_loaded else 'loading'}, device: {device_info}")
157
- return jsonify({
158
- "status": "ready" if model_loaded else "loading",
159
- "model_loaded": model_loaded,
160
- "device": device_info
161
- }), status_code
162
-
163
- @app.route('/generate', methods=['POST'])
164
- def generate_pdf():
165
- # Check model status
166
- if not model_loaded:
167
- app.logger.error("PDF generation requested but model not loaded")
168
- return jsonify({
169
- "error": "Model not loaded yet",
170
- "status": "loading"
171
- }), 503
172
-
173
- try:
174
- app.logger.info("Processing PDF generation request")
175
-
176
- # Validate input
177
- data = request.json
178
- if not data:
179
- app.logger.error("No data provided in request")
180
- return jsonify({"error": "No data provided"}), 400
181
-
182
- required = ['title', 'authors', 'content']
183
- if missing := [field for field in required if field not in data]:
184
- app.logger.error(f"Missing required fields: {missing}")
185
- return jsonify({
186
- "error": f"Missing fields: {', '.join(missing)}"
187
- }), 400
188
-
189
- app.logger.info(f"Received request with title: {data['title']}")
190
-
191
- # Format content with model
192
- app.logger.info("Formatting content using the model")
193
- formatted = format_content(data['content'])
194
-
195
- app.logger.info("Creating HTML from template")
196
- # Generate HTML
197
- html = jinja2.Template(IEEE_TEMPLATE).render(
198
- title=data['title'],
199
- authors=data['authors'],
200
- abstract=formatted.get('abstract', ''),
201
- keywords=', '.join(formatted.get('keywords', [])),
202
- sections=formatted.get('sections', []),
203
- references=formatted.get('references', [])
204
- )
205
-
206
- # PDF options
207
- options = {
208
- 'page-size': 'Letter',
209
- 'margin-top': '0.75in',
210
- 'margin-right': '0.75in',
211
- 'margin-bottom': '0.75in',
212
- 'margin-left': '0.75in',
213
- 'encoding': 'UTF-8',
214
- 'quiet': ''
215
- }
216
-
217
- # Create temporary PDF
218
- app.logger.info("Generating PDF file")
219
- pdf_path = None
220
-
221
- try:
222
- with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as f:
223
- pdf_path = f.name
224
-
225
- # Generate PDF using xvfb-run as a separate process
226
- html_path = pdf_path + '.html'
227
- with open(html_path, 'w', encoding='utf-8') as f:
228
- f.write(html)
229
-
230
- command = ['xvfb-run', '-a', WKHTMLTOPDF_PATH] + \
231
- [f'--{k}={v}' for k, v in options.items() if v] + \
232
- [html_path, pdf_path]
233
-
234
- app.logger.info(f"Running command: {' '.join(command)}")
235
- result = subprocess.run(command, capture_output=True, text=True)
236
-
237
- if result.returncode != 0:
238
- app.logger.error(f"PDF generation command failed: {result.stderr}")
239
- # Fallback to direct pdfkit if available
240
- app.logger.info("Trying fallback PDF generation with pdfkit")
241
- pdfkit.from_string(html, pdf_path, options=options, configuration=pdf_config)
242
-
243
- # Clean up HTML file
244
- os.remove(html_path)
245
-
246
- app.logger.info(f"PDF generated successfully at {pdf_path}")
247
- return send_file(pdf_path, mimetype='application/pdf', as_attachment=True,
248
- download_name=f"{data['title'].replace(' ', '_')}.pdf")
249
-
250
- except Exception as e:
251
- app.logger.error(f"PDF generation failed: {str(e)}", exc_info=True)
252
- raise
253
-
254
- except Exception as e:
255
- app.logger.error(f"Request processing failed: {str(e)}", exc_info=True)
256
- return jsonify({"error": str(e)}), 500
257
-
258
- finally:
259
- # Clean up temporary file
260
- if 'pdf_path' in locals() and pdf_path:
261
- try:
262
- app.logger.info(f"Cleaning up temporary file {pdf_path}")
263
- os.remove(pdf_path)
264
- except Exception as e:
265
- app.logger.warning(f"Failed to remove temporary file: {str(e)}")
266
-
267
- # --------------------------------------------------
268
- # Content Formatting
269
- # --------------------------------------------------
270
- def parse_formatted_content(text):
271
- """Parse the generated text into structured sections"""
272
- app.logger.info("Parsing formatted content")
273
 
274
  try:
275
- lines = text.split('\n')
276
-
277
- # Default structure
278
- result = {
279
- 'abstract': '',
280
- 'keywords': ['IEEE', 'format', 'research', 'paper'],
281
- 'sections': [],
282
- 'references': []
283
- }
284
-
285
- # Extract abstract (simple approach - first paragraph after "Abstract")
286
- abstract_start = None
287
- for i, line in enumerate(lines):
288
- if line.strip().lower() == 'abstract':
289
- abstract_start = i + 1
290
- break
291
-
292
- if abstract_start:
293
- abstract_text = []
294
- i = abstract_start
295
- while i < len(lines) and not lines[i].strip().lower().startswith('keyword'):
296
- if lines[i].strip():
297
- abstract_text.append(lines[i].strip())
298
- i += 1
299
- result['abstract'] = ' '.join(abstract_text)
300
-
301
- # Extract keywords
302
- for line in lines:
303
- if line.strip().lower().startswith('keyword'):
304
- # Extract keywords from the line
305
- keyword_parts = line.split('—')
306
- if len(keyword_parts) > 1:
307
- keywords = keyword_parts[1].strip().split(',')
308
- result['keywords'] = [k.strip() for k in keywords if k.strip()]
309
- break
310
 
311
- # Extract sections
312
- current_section = None
313
- section_content = []
314
 
315
- # Skip lines until we find a section heading
316
- started = False
317
- for line in lines:
318
- # Very basic heuristic for Roman numerals section headings
319
- if line.strip() and (line.strip()[0].isupper() or line.strip()[0].isdigit()):
320
- started = True
321
- if not started:
322
- continue
323
-
324
- if line.strip() and (line.strip()[0].isupper() or line.strip()[0].isdigit()) and len(line.strip().split()) <= 6:
325
- # This is likely a section heading
326
- if current_section:
327
- # Save the previous section
328
- result['sections'].append({
329
- 'title': current_section,
330
- 'content': '\n'.join(section_content)
331
- })
332
- section_content = []
333
-
334
- current_section = line.strip()
335
- elif current_section and line.strip().lower() == 'references':
336
- # We've reached the references section
337
- if current_section:
338
- # Save the last section
339
- result['sections'].append({
340
- 'title': current_section,
341
- 'content': '\n'.join(section_content)
342
- })
343
- break
344
- elif current_section:
345
- # Add to current section content
346
- section_content.append(line)
347
-
348
- # Extract references
349
- in_references = False
350
- for line in lines:
351
- if line.strip().lower() == 'references':
352
- in_references = True
353
- continue
354
-
355
- if in_references and line.strip():
356
- result['references'].append(line.strip())
357
-
358
- app.logger.info(f"Content parsed into {len(result['sections'])} sections and {len(result['references'])} references")
359
- return result
360
-
361
- except Exception as e:
362
- app.logger.error(f"Error parsing formatted content: {str(e)}", exc_info=True)
363
- # Return a basic structure if parsing fails
364
- return {
365
- 'abstract': 'Error parsing content.',
366
- 'keywords': ['IEEE', 'format'],
367
- 'sections': [{'title': 'Content', 'content': text}],
368
- 'references': []
369
- }
370
-
371
- def format_content(content):
372
- """Format the content using the ML model"""
373
- try:
374
- app.logger.info("Formatting content with ML model")
375
- prompt = f"Format this research content to IEEE standards with sections, abstract, and references:\n\n{str(content)}"
376
-
377
- response = generator(
378
- prompt,
379
- max_new_tokens=1024, # Increased for more complete generation
380
- temperature=0.5, # More deterministic output
381
- do_sample=True,
382
- truncation=True,
383
- num_return_sequences=1
384
- )
385
-
386
- generated_text = response[0]['generated_text']
387
-
388
- # Remove the prompt from the generated text
389
- if prompt in generated_text:
390
- formatted_text = generated_text[len(prompt):].strip()
391
- else:
392
- formatted_text = generated_text
393
-
394
- app.logger.info("Content formatted successfully")
395
-
396
- # Parse the formatted text into structured sections
397
- return parse_formatted_content(formatted_text)
398
 
 
 
 
399
  except Exception as e:
400
- app.logger.error(f"Error formatting content: {str(e)}", exc_info=True)
401
- # Return the original content if formatting fails
402
- return {
403
- 'abstract': 'Content processing error.',
404
- 'keywords': ['IEEE', 'format'],
405
- 'sections': [{'title': 'Content', 'content': str(content)}],
406
- 'references': []
407
- }
408
 
409
  if __name__ == '__main__':
410
  app.run(host='0.0.0.0', port=5000)
 
1
+ from flask import Flask, request, send_file
 
 
 
 
 
 
 
 
 
2
  from flask_cors import CORS
3
+ from rembg import remove
4
+ from PIL import Image
5
+ import io
6
 
 
 
 
 
 
 
 
 
 
 
 
7
  app = Flask(__name__)
8
+ CORS(app) # Enable CORS for all routes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ @app.route('/remove_bg', methods=['POST'])
11
+ def remove_bg():
12
+ # Check if image file is present
13
+ if 'file' not in request.files:
14
+ return {'error': 'No file uploaded'}, 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ file = request.files['file']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # Check that a file was actually selected (an empty filename means none was chosen)
19
+ if file.filename == '':
20
+ return {'error': 'No selected file'}, 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  try:
23
+ # Read image file
24
+ input_image = Image.open(file.stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # Remove background
27
+ output_image = remove(input_image)
 
28
 
29
+ # Convert to bytes
30
+ img_byte_arr = io.BytesIO()
31
+ output_image.save(img_byte_arr, format='PNG')
32
+ img_byte_arr.seek(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ # Return result
35
+ return send_file(img_byte_arr, mimetype='image/png')
36
+
37
  except Exception as e:
38
+ return {'error': str(e)}, 500
 
 
 
 
 
 
 
39
 
40
  if __name__ == '__main__':
41
  app.run(host='0.0.0.0', port=5000)