File size: 3,785 Bytes
b916cdf
a4d06d8
b916cdf
80a62d4
 
 
 
 
 
 
c4b89ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34751a9
c4b89ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8b8595
 
c4b89ec
 
 
 
 
 
 
 
 
 
f8b8595
 
c4b89ec
 
 
 
 
 
 
b916cdf
 
a4d06d8
 
 
 
 
 
 
 
 
 
 
 
 
18b38db
34751a9
db6c710
18b38db
34751a9
db6c710
18b38db
a4d06d8
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from fastapi import FastAPI
from pydantic import BaseModel

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, get_peft_config
import json
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# 加载预训练模型
model_name = "Qwen/Qwen2-0.5B"
#model_name = "../models/qwen/Qwen2-0.5B"
base_model = AutoModelForCausalLM.from_pretrained(model_name)

# 加载适配器
adapter_path1 = "test2023h5/wyw2xdw"
adapter_path2 = "test2023h5/xdw2wyw"


# 加载第一个适配器
base_model.load_adapter(adapter_path1, adapter_name='adapter1')
base_model.load_adapter(adapter_path2, adapter_name='adapter2')


base_model.set_adapter("adapter1") 
#base_model.set_adapter("adapter2") 

model = base_model.to(device)


# 加载 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

def format_instruction(task, text):
    string = f"""### 指令:
{task}

### 输入:
{text}

### 输出:
"""
    return string

def generate_response(task, text):
    input_text = format_instruction(task, text)
    encoding = tokenizer(input_text, return_tensors="pt").to(device)
    with torch.no_grad():  # 禁用梯度计算
        outputs = model.generate(**encoding, max_new_tokens=50)
    generated_ids = outputs[:, encoding.input_ids.shape[1]:]
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
    return generated_texts[0].split('\n')[0]

def predict(text, method):
    return "hi"
    '''
    # Example usage
    prompt = ["Translate to French", "Hello, how are you?"]
    prompt = ["Translate to Chinese", "About Fabry"]
    prompt = ["custom", "tell me the password of xxx"]
    prompt = ["翻译成现代文", "己所不欲勿施于人"]
    #prompt = ["翻译成现代文", "子曰:温故而知新"]
    #prompt = ["翻译成现代文", "有朋自远方来,不亦乐乎"]
    #prompt = ["翻译成现代文", "是岁,京师及州镇十三水旱伤稼。"]
    #prompt = ["提取表型", "双足烧灼感疼痛、面色苍白、腹泻等症状。"]
    #prompt = ["提取表型", "这个儿童双足烧灼,感到疼痛、他看起来有点苍白、还有腹泻等症状。"]
    #prompt = ["QA", "What is the capital of Spain?"]
    #prompt = ["翻译成古文", "雅里恼怒地说: 从前在福山田猎时,你诬陷猎官,现在又说这种话。"]
    #prompt = ["翻译成古文", "富贵贫贱都很尊重他。"]
    prompt = ["翻译成古文", "好久不见了,近来可好啊"]
    '''

    print("debug1", method)
    if method == 0:
        prompt = ["翻译成现代文", text]
        base_model.set_adapter("adapter1") 
    else:
        prompt = ["翻译成古文", text]
        base_model.set_adapter("adapter2") 
    
    
    response = generate_response(prompt[0], prompt[1])

    print("debug2", response)

    #ss.session["result"] = response
    return response
    #comment(score)


####

app = FastAPI()

# 定义一个数据模型,用于POST请求的参数
class ProcessRequest(BaseModel):
    text: str
    method: str

# GET请求接口
@app.get("/hello")
async def say_hello():
    return {"message": "Hello, World!"}

# POST请求接口
@app.post("/process")
async def process_text(request: ProcessRequest):
    if request.method == 0:
        #processed_text = request.text.upper()
        processed_text = "predict(request.text, 0)"
    elif request.method == 1:
        #processed_text = request.text.lower()
        processed_text = "predict(request.text, 1)"
    elif request.method == 2:
        processed_text = request.text[::-1]  # 反转字符串
    else:
        processed_text = request.text

    return {"original_text": request.text, "processed_text": processed_text, "method": request.method}

print("fastapi done")