soiz committed
Commit 5f32ce4 · verified · 1 Parent(s): e5cd3b4

Update app.py

Files changed (1)
  1. app.py +23 -38
app.py CHANGED
@@ -1,43 +1,28 @@
-import os
-os.system("pip install flask")

-from flask import Flask, request, send_file, abort
-from io import BytesIO
-from PIL import Image
-import random

-app = Flask(__name__)

-# Stub to simulate loading a model
-def load_model(model_name):
-    # Add the model-loading logic here
-    # Here, as a simple example, we generate a blank image with PIL
-    def model(prompt):
-        img = Image.new('RGB', (256, 256), color=(73, 109, 137))
-        return img
-    return model

-# Load the models
-models_load = {f'model_{i}': load_model(f'model_{i}') for i in range(1, 7)}

-@app.route('/generate_image', methods=['GET'])
-def generate_image():
-    prompt = request.args.get('prompt')
-    model_str = request.args.get('model')
-
-    if model_str not in models_load:
-        abort(404, description="Model not found")
-
-    model = models_load[model_str]
-    noise = str(random.randint(0, 99999999999))
-    image = model(f'{prompt} {noise}')
-
-    # Save the image as binary and return it as the response
-    img_bytes = BytesIO()
-    image.save(img_bytes, format='PNG')
-    img_bytes.seek(0)
-
-    return send_file(img_bytes, mimetype='image/png')
-
-if __name__ == '__main__':
-    app.run(host='0.0.0.0', port=5000)
+from fastapi import FastAPI, Query
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch

+app = FastAPI()

+# Load the model and tokenizer
+def load_prompter():
+    prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "left"
+    return prompter_model, tokenizer

+prompter_model, prompter_tokenizer = load_prompter()

+@app.get("/generate")
+async def generate(text: str = Query(..., description="Input text to be processed by the model")):
+    input_ids = prompter_tokenizer(text.strip() + " Rephrase:", return_tensors="pt").input_ids
+    eos_id = prompter_tokenizer.eos_token_id
+    outputs = prompter_model.generate(input_ids, do_sample=False, max_new_tokens=75, num_beams=8, num_return_sequences=1, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0)
+    output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    res = output_texts[0].replace(text + " Rephrase:", "").strip()
+    return {"result": res}

+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
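
For reference, a minimal sketch of calling the new endpoint from a client, assuming the server is running with the host/port from the __main__ block above and that the requests package is installed (neither the client code nor the example input text is part of this commit):

import requests

# Hypothetical client call; "A photo of a cat" is just an example input.
resp = requests.get(
    "http://localhost:8000/generate",
    params={"text": "A photo of a cat"},
)
resp.raise_for_status()
print(resp.json()["result"])  # the rephrased (Promptist-optimized) prompt

Note that the endpoint decodes with beam search (do_sample=False, num_beams=8) and length_penalty=-1.0, which penalizes longer beams, so the rephrased prompt is deterministic for a given input and biased toward shorter outputs.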