arcsu1 commited on
Commit
137ee4f
·
1 Parent(s): fd25dfc

add text gen

Browse files
main.py CHANGED
@@ -6,7 +6,9 @@ import re
6
 
7
  from fastapi import FastAPI
8
  from fastapi.middleware.cors import CORSMiddleware
 
9
  import chatbot
 
10
 
11
 
12
  app = FastAPI()
@@ -31,15 +33,9 @@ class HistoryRequest(BaseModel):
31
  user: list[str]
32
  ai: list[str]
33
 
34
- @app.post("/generate")
35
  async def generate_response(history: HistoryRequest):
36
- # try:
37
- # model
38
- # print(12321323)
39
- # except:
40
- # global model
41
- # # model = chatbot.ChatBot()
42
-
43
  global model
44
 
45
  try:
@@ -63,4 +59,32 @@ async def generate_response(history: HistoryRequest):
63
  response = model.generate_response(history)
64
  return response
65
 
66
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  from fastapi import FastAPI
8
  from fastapi.middleware.cors import CORSMiddleware
9
+
10
  import chatbot
11
+ import textgen
12
 
13
 
14
  app = FastAPI()
 
33
  user: list[str]
34
  ai: list[str]
35
 
36
+ @app.post("/chatbot")
37
  async def generate_response(history: HistoryRequest):
38
+ print("Chatbot request")
 
 
 
 
 
 
39
  global model
40
 
41
  try:
 
59
  response = model.generate_response(history)
60
  return response
61
 
62
+
63
+ class TextGenInput(BaseModel):
64
+ user: str
65
+
66
+ @app.post("/text-gen")
67
+ async def generate_text(input: TextGenInput):
68
+ print("Generating text request")
69
+ global model
70
+ directory = 'models/fine-tuned-gpt2-textgen'
71
+
72
+ try:
73
+ # check if model is already loaded
74
+ if not isinstance(model, textgen.TextGen):
75
+ model = textgen.TextGen(
76
+ GPT2Tokenizer.from_pretrained(directory),
77
+ GPT2LMHeadModel.from_pretrained(directory),
78
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
+ )
80
+ except NameError:
81
+ # if model is not defined, load
82
+ model = textgen.TextGen(
83
+ GPT2Tokenizer.from_pretrained(directory),
84
+ GPT2LMHeadModel.from_pretrained(directory),
85
+ torch.device("cuda" if torch.cuda.is_available() else "cpu")
86
+ )
87
+
88
+ response = model.generate_text(input.user)
89
+
90
+ return response
models/fine-tuned-gpt2-textgen/config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "gpt2",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "GPT2LMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.1,
8
+ "bos_token_id": 50256,
9
+ "embd_pdrop": 0.1,
10
+ "eos_token_id": 50256,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_epsilon": 1e-05,
13
+ "model_type": "gpt2",
14
+ "n_ctx": 1024,
15
+ "n_embd": 768,
16
+ "n_head": 12,
17
+ "n_inner": null,
18
+ "n_layer": 12,
19
+ "n_positions": 1024,
20
+ "reorder_and_upcast_attn": false,
21
+ "resid_pdrop": 0.1,
22
+ "scale_attn_by_inverse_layer_idx": false,
23
+ "scale_attn_weights": true,
24
+ "summary_activation": null,
25
+ "summary_first_dropout": 0.1,
26
+ "summary_proj_to_labels": true,
27
+ "summary_type": "cls_index",
28
+ "summary_use_proj": true,
29
+ "task_specific_params": {
30
+ "text-generation": {
31
+ "do_sample": true,
32
+ "max_length": 50
33
+ }
34
+ },
35
+ "torch_dtype": "float32",
36
+ "transformers_version": "4.42.4",
37
+ "use_cache": true,
38
+ "vocab_size": 50257
39
+ }
models/fine-tuned-gpt2-textgen/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "transformers_version": "4.42.4"
6
+ }
models/fine-tuned-gpt2-textgen/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
models/fine-tuned-gpt2-textgen/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a712602a48d8498e3836a5fe746ebdbc5aeefa1e7ee3175fee71ae21cad5b8f5
3
+ size 497774208
models/fine-tuned-gpt2-textgen/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
models/fine-tuned-gpt2-textgen/tokenizer_config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "50256": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ }
13
+ },
14
+ "bos_token": "<|endoftext|>",
15
+ "clean_up_tokenization_spaces": true,
16
+ "eos_token": "<|endoftext|>",
17
+ "errors": "replace",
18
+ "model_max_length": 1024,
19
+ "pad_token": "<|endoftext|>",
20
+ "tokenizer_class": "GPT2Tokenizer",
21
+ "unk_token": "<|endoftext|>"
22
+ }
models/fine-tuned-gpt2-textgen/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
textgen.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class TextGen:
4
+ def __init__(self,tokenizer,model,device):
5
+ self.tokenizer = tokenizer
6
+ self.model = model
7
+ self.device = device
8
+ self.model.to(self.device)
9
+
10
+
11
+ def generate_text(self, user_input):
12
+ inputs = self.tokenizer.encode(user_input, return_tensors="pt").to(self.device)
13
+
14
+ # generate text
15
+ attention_mask = torch.ones(inputs.shape, device=self.device)
16
+ output = self.model.generate(
17
+ inputs,
18
+ attention_mask=attention_mask,
19
+ num_return_sequences=1,
20
+ max_length=50,
21
+ max_new_tokens=100,
22
+ temperature=0.5,
23
+ repetition_penalty=1.2,
24
+ pad_token_id=self.tokenizer.eos_token_id,
25
+ )
26
+
27
+ generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
28
+
29
+
30
+ return generated_text