ZeroTwo3 committed
Commit 389b910 · 0 parents

Duplicate from ZeroTwo3/flask_test

Files changed (5)
  1. .gitattributes +27 -0
  2. README.md +14 -0
  3. app.py +86 -0
  4. requirements.txt +6 -0
  5. static/index.html +5 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Demo_flask
+ emoji: 🏢
+ colorFrom: purple
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 2.8.13
+ app_file: app.py
+ pinned: false
+ license: mit
+ duplicated_from: ZeroTwo3/flask_test
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ import torch
+ from flask import Flask, request, jsonify
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
+
+ app = Flask(__name__)
+
+ # Load the fine-tuned checkpoint if one exists; otherwise fall back to the pre-trained GPT-2 model
+ if os.path.exists("fine_tuned_checkpoint"):
+     model = GPT2LMHeadModel.from_pretrained("fine_tuned_checkpoint")
+ else:
+     model = GPT2LMHeadModel.from_pretrained("gpt2")
+
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+ # Fine-tune the model on the messages in a chat history
+ def fine_tune_model(chat_history):
+     # Prepare the training data as one message per line
+     input_texts = [item["message"] for item in chat_history]
+     with open("train.txt", "w") as f:
+         f.write("\n".join(input_texts))
+
+     # Load the dataset and create a causal-LM data collator (mlm=False)
+     dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
+     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+     # Fine-tune the model; Trainer requires TrainingArguments and a train_dataset
+     training_args = TrainingArguments(output_dir="./training_directory", overwrite_output_dir=True, num_train_epochs=1)
+     trainer = Trainer(model=model, args=training_args, train_dataset=dataset, data_collator=data_collator)
+     trainer.train()
+
+     # Save the fine-tuned model where it will be reloaded on the next startup
+     model.save_pretrained("fine_tuned_checkpoint")
+
+ @app.route("/chat", methods=["POST"])
+ def chat_with_model():
+     request_data = request.get_json()
+     user_input = request_data["user_input"]
+     chat_history = request_data.get("chat_history", [])
+
+     # Append the user message to the chat history
+     chat_history.append({"role": "user", "message": user_input})
+
+     # Generate a response
+     response = generate_response(user_input, chat_history)
+
+     # Append the bot message to the chat history
+     chat_history.append({"role": "bot", "message": response})
+
+     return jsonify({"bot_response": response, "chat_history": chat_history})
+
+ @app.route("/train", methods=["POST"])
+ def train_model():
+     chat_history = request.json["data"]
+
+     # Fine-tune the model with the provided data
+     fine_tune_model(chat_history)
+
+     return "Model trained and updated successfully."
+
+ def generate_response(user_input, chat_history):
+     # Set the maximum number of previous messages to consider
+     max_history = 3
+
+     # Use the last `max_history` messages from the chat history as the prompt
+     inputs = [item["message"] for item in chat_history[-max_history:]]
+     input_text = "\n".join(inputs)
+
+     # Tokenize the input text
+     input_ids = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=True)
+
+     # Generate a response
+     with torch.no_grad():
+         output = model.generate(input_ids, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
+
+     # Decode only the tokens generated after the prompt
+     bot_response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
+
+     return bot_response
+
+ @app.route("/")
+ def index():
+     return jsonify({"status": True})
+
+ if __name__ == "__main__":
+     app.run(host="0.0.0.0", port=7860)
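
For reference, a minimal client sketch for the two endpoints defined above. It is an illustration, not part of the commit: it assumes the app is running locally on port 7860 (the address app.run binds) and that the requests package is installed, which is not listed in requirements.txt.

    # Hypothetical client for the /chat and /train routes in app.py.
    # Assumes the server is reachable at http://localhost:7860 and that
    # the `requests` package is installed (not in requirements.txt).
    import requests

    BASE_URL = "http://localhost:7860"

    # Ask /chat for a reply, threading the chat history through each call
    payload = {"user_input": "Hello, bot!", "chat_history": []}
    data = requests.post(f"{BASE_URL}/chat", json=payload).json()
    print(data["bot_response"])

    # Send the accumulated history to /train to fine-tune the model on it
    resp = requests.post(f"{BASE_URL}/train", json={"data": data["chat_history"]})
    print(resp.text)

Each item in chat_history carries a "role" and a "message" key, and /train reads only the "message" values, so the list returned by /chat can be passed back verbatim.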
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ Flask==2.0.2
+ python-dotenv==0.19.2
+ transformers>=4.0.0
+ torch>=1.8.0
+ nltk>=3.6.0
+ accelerate>=0.20.3
static/index.html ADDED
@@ -0,0 +1,5 @@
+ <h1>
+ This is a Flask test app
+ </h1>
+
+ <p>Hello, Flask!</p>
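
A note on this last file: Flask(__name__) serves the static folder at /static by default, so the page is reachable at /static/index.html rather than at the root route (which app.py maps to a JSON status). A quick local check, assuming the server is running on port 7860:

    # Fetch the page via Flask's default static route (assumes a local server on port 7860)
    import requests

    print(requests.get("http://localhost:7860/static/index.html").text)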