BeveledCube committed on
Commit
93a2217
1 Parent(s): 987e371

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +16 -16
main.py CHANGED
@@ -1,8 +1,4 @@
1
- from fastapi.staticfiles import StaticFiles
2
- from fastapi.responses import FileResponse
3
- from pydantic import BaseModel
4
- from fastapi import FastAPI
5
-
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
 
8
  model_name = "facebook/blenderbot-1B-distill"
@@ -16,22 +12,24 @@ model_name = "facebook/blenderbot-1B-distill"
16
 
17
  # https://www.youtube.com/watch?v=irjYqV6EebU
18
 
19
- app = FastAPI()
20
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
21
  tokenizer = AutoTokenizer.from_pretrained(model_name)
22
 
23
- class req(BaseModel):
24
- prompt: str
25
-
26
  @app.get("/")
27
  def read_root():
28
- return FileResponse(path="templates/index.html", media_type="text/html")
29
 
30
- @app.post("/api")
31
- def read_root(data: req):
32
- print("Prompt:", data.prompt)
33
 
34
- input_text = data.prompt
 
 
 
 
 
35
 
36
  # Tokenize the input text
37
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
@@ -41,6 +39,8 @@ def read_root(data: req):
41
  generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
42
 
43
  answer_data = { "answer": generated_text }
44
- print("Answer:", generated_text)
45
 
46
- return answer_data
 
 
 
1
from flask import Flask, request, render_template, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Seq2seq chatbot checkpoint served by this app.
model_name = "facebook/blenderbot-1B-distill"

# https://www.youtube.com/watch?v=irjYqV6EebU

# Use __name__ as the import name so Flask resolves templates/ and static/
# relative to this module. The previous Flask("AI API") passed a display
# string that is not an importable module, which makes Flask fall back to
# the process's current working directory for its root path — breaking
# render_template("index.html") when the server is started from elsewhere.
app = Flask(__name__)

# Loaded once at import time; first run downloads the model weights,
# so startup is slow and requires network access.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
18
 
 
 
 
19
@app.get("/")
def read_root():
    """Serve the chat front-end page (templates/index.html)."""
    page = render_template("index.html")
    return page
22
 
23
@app.route("/test")
def test_route():
    """Plain-text liveness probe for the server."""
    message = "This is a test route."
    return message
26
 
27
# POST /api — run the submitted prompt through the chatbot model and
# return {"answer": <generated text>} as JSON.
+ @app.route("/api", methods=["POST"])
28
+ def receive_data():
29
+ data = request.get_json()  # parse JSON request body; assumes it carries a "prompt" key — TODO confirm caller contract
30
+ print("Prompt:", data["prompt"])  # raises KeyError if "prompt" is missing
31
+
32
+ input_text = data["prompt"]
33
 
34
  # Tokenize the input text
35
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
 
39
# NOTE(review): output_ids presumably comes from a model.generate(...) call
# in the lines elided by this diff hunk (target lines 36-38) — not visible here.
  generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)  # decode first generated sequence to text
40
 
41
  answer_data = { "answer": generated_text }
42
+ print("Response:", generated_text)
43
 
44
+ return jsonify(answer_data)  # JSON response body: {"answer": ...}
45
+
46
+ app.run(host="0.0.0.0", port=25428, debug=False)  # bind on all interfaces, fixed port, debug off