Fta98 committed on
Commit 8c97efe · 1 Parent(s): 9213e75
Files changed (1)
  1. app.py +36 -6
app.py CHANGED
@@ -1,10 +1,16 @@
+import os
+
 import streamlit as st
+import torch
+from huggingface_hub import login
+from peft import PeftModel
 from transformers import AutoModelForCausalLM, LlamaTokenizer
 
+login(token=os.getenv("HUGGINGFACE_API_KEY"))
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 @st.cache_resource
 def load():
-    """
     base_model = AutoModelForCausalLM.from_pretrained(
         "stabilityai/japanese-stablelm-instruct-alpha-7b",
         device_map="auto",
@@ -18,7 +24,6 @@ def load():
         "lora_adapter",
         device_map="auto",
     )
-    """
     model = None
     tokenizer = LlamaTokenizer.from_pretrained(
         "lora_adapter",
@@ -42,8 +47,26 @@ def get_input_token_length(user_query, system_prompt, messages=""):
     input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
     return input_ids.shape[-1]
 
-def generate():
-    pass
+def generate_response(user_query: str, system_prompt: str, messages: str="", temperature: float=0, top_k: int=50, top_p: float=0.95, repetition_penalty: float=1.1):
+    prompt = get_prompt(user_query, system_prompt, messages)
+    inputs = tokenizer(
+        prompt,
+        add_special_tokens=False,
+        return_tensors="pt"
+    ).to(device)
+    max_new_tokens = 2048 - get_input_token_length(user_query, system_prompt, messages)
+    model.eval()
+    with torch.no_grad():
+        tokens = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+        )
+    response = tokenizer.decode(tokens[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
+    return response
 
 
 st.header(":dna: 遺伝カウンセリング対話AI")
@@ -90,8 +113,15 @@ if user_prompt := st.chat_input("質問を送信してください"):
     with st.chat_message("user"):
         st.text(user_prompt)
     st.session_state["messages"].append({"role": "user", "content": user_prompt})
-    token_kength = get_input_token_length(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
-    response = f"{token_kength}: " + get_prompt(user_query=user_prompt, system_prompt=st.session_state["options"]["system_prompt"], messages=st.session_state["messages"])
+    response = generate_response(
+        user_query=user_prompt,
+        system_prompt=st.session_state["options"]["system_prompt"],
+        messages=st.session_state["messages"],
+        temperature=st.session_state["options"]["temperature"],
+        top_k=st.session_state["options"]["top_k"],
+        top_p=st.session_state["options"]["top_p"],
+        repetition_penalty=st.session_state["options"]["repetition_penalty"],
+    )
     with st.chat_message("assistant"):
         st.text(response)
     st.session_state["messages"].append({"role": "assistant", "content": response})
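
For reference, a minimal sketch of how the model pieces referenced in this commit could be wired up: load() still leaves model = None even though the commit imports PeftModel and points the tokenizer and adapter at "lora_adapter". The standalone framing, torch_dtype, and trust_remote_code settings below are assumptions, not part of the diff.

# Sketch only: attaching the "lora_adapter" LoRA weights to the base model with peft.
# torch_dtype and trust_remote_code are assumptions, not taken from this commit.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaTokenizer

base_model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/japanese-stablelm-instruct-alpha-7b",
    device_map="auto",
    torch_dtype=torch.float16,   # assumption
    trust_remote_code=True,      # assumption
)
model = PeftModel.from_pretrained(base_model, "lora_adapter", device_map="auto")
tokenizer = LlamaTokenizer.from_pretrained("lora_adapter")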