min24ss committed on
Commit 86ff856 · verified · 1 Parent(s): efafceb

Upload r-story-test.py

Files changed (1)
r-story-test.py +225 -0
r-story-test.py ADDED
@@ -0,0 +1,225 @@
#!/usr/bin/env python
# coding: utf-8

# ## 1. TSV full data load

# In[1]:


import pandas as pd

df = pd.read_csv("sl_webtoon_full_data_sequential.tsv", sep="\t")

print(df.head())
print("Total sentences:", len(df))
print("Columns:", df.columns.tolist())

# 549
# Columns: ['에피소드', 'scene_text', 'type']
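
# A minimal sanity check (a sketch, not in the original notebook): confirm the
# expected columns exist and drop rows with missing scene_text so empty values
# never reach the RAG-sentence builder below.
required_cols = {'에피소드', 'scene_text', 'type'}
missing = required_cols - set(df.columns)
if missing:
    raise ValueError(f"Missing columns: {missing}")
df = df.dropna(subset=['scene_text']).reset_index(drop=True)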


# In[2]:


import pandas as pd

df = pd.read_csv("sl_webtoon_full_data_sequential.tsv", sep="\t")
print(df.head(3))
print("Columns:", df.columns.tolist(), "Total rows:", len(df))


# In[3]:


df['row_id'] = df.index  # add an index column for tracing back to the source row

df['text'] = df.apply(
    lambda x: f"[{x['에피소드']}] #{x['row_id']} {x['type']} {x['scene_text']}",  # build the RAG sentence
    axis=1
)

print(df['text'].head(3).tolist())


# In[4]:


texts = df['text'].tolist()
print("Final sentence count:", len(texts))
# 549
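
# Optional check (not in the original notebook): sentence-transformers models
# such as jhgan/ko-sroberta-multitask silently truncate long inputs, so it is
# worth confirming no RAG sentence is extreme before embedding.
print("Longest sentence (chars):", max(len(t) for t in texts))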


# ## 2. RAG sentence generation

# In[5]:


# Step 2: build the final RAG sentences
df['row_id'] = df.index  # index for tracing back to the source row
df['text'] = df.apply(
    lambda x: f"[{x['에피소드']}] #{x['row_id']} {x['type']} {x['scene_text']}",
    axis=1
)

print("First 5 examples:")
for t in df['text'].head(5).tolist():
    print("-", t)

texts = df['text'].tolist()
print("\nFinal sentence count:", len(texts))
# 549


# ## 3. Load the Korean embedding model, build the vector DB - solo_leveling_faiss_ko

# In[6]:


from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')

db = FAISS.from_texts(texts, embedding_model)
print("Vector DB built. Total sentences:", len(texts))

db.save_local("solo_leveling_faiss_ko")
print("Saved to the 'solo_leveling_faiss_ko' folder")
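
# Optional sketch (not in the original): LangChain's FAISS store also exposes
# similarity_search_with_score, which returns (Document, distance) pairs where
# lower means closer for the default L2 index. Handy for eyeballing retrieval
# quality before wiring the store into a chain.
for doc, score in db.similarity_search_with_score("마나석", k=3):  # "mana stone"
    print(f"{score:.3f} | {doc.page_content[:80]}")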
92
+
93
+
94
+ # In[7]:
95
+
96
+
97
+ db = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)
98
+
99
+
100
+ query = "๋งˆ๋‚˜์„์ด ๋ญ์ง€?"
101
+ docs = db.similarity_search(query, k=5)
102
+
103
+ for i, doc in enumerate(docs, 1):
104
+ print(f"[{i}] {doc.page_content}")
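
# Optional sketch (not in the original): plain top-k often returns near-duplicate
# scenes from one episode. An MMR retriever trades a little relevance for
# diversity; k and fetch_k here are illustrative values.
mmr_retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 5, "fetch_k": 20})
for doc in mmr_retriever.get_relevant_documents(query):
    print("-", doc.page_content[:80])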


# In[8]:


## RAG check


# In[9]:


from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="kakaocorp/kanana-nano-2.1b-instruct",
    device=0  # assumes a CUDA GPU at index 0
)
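
# Quick smoke test (illustrative prompt, not from the notebook): make sure the
# pipeline produces Korean text before it is used in the story step below.
sample = generator("던전에 들어서며 성진우가 말했다:",  # "Entering the dungeon, Sung Jinwoo said:"
                   max_new_tokens=32, do_sample=True,
                   return_full_text=False)[0]["generated_text"]
print(sample)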


# In[10]:


from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain_community.embeddings import HuggingFaceEmbeddings
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

embedding_model = HuggingFaceEmbeddings(model_name='jhgan/ko-sroberta-multitask')
vectorstore = FAISS.load_local("solo_leveling_faiss_ko", embedding_model, allow_dangerous_deserialization=True)

model_name = "kakaocorp/kanana-nano-2.1b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")

llm_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
llm = HuggingFacePipeline(pipeline=llm_pipeline)

# Korean prompt, kept as-is for the Korean LLM; roughly: "Answer the question
# using the following context. Context / Question / Answer:"
custom_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="다음 문맥을 참고하여 질문에 답하세요.\n\n문맥:\n{context}\n\n질문:\n{question}\n\n답변:"
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
    chain_type="stuff",
    return_source_documents=True,
    chain_type_kwargs={"prompt": custom_prompt}
)

# Question: "What rank hunter is Sung Jinwoo?"
query = "성진우는 몇 급 헌터지?"
result = qa_chain({"query": query})

print("Answer:", result["result"])
print("\nSource documents:")
for doc in result["source_documents"]:
    print(doc.page_content)
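
# Helper sketch (not in the original): later cells repeat the "query the chain,
# join its source documents" pattern, so it can be bundled into one function.
def ask(question: str):
    out = qa_chain({"query": question})
    context = "\n".join(d.page_content for d in out["source_documents"])
    return out["result"], context

# e.g. answer, context = ask("성진우는 몇 급 헌터지?")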


# ## 4. Hwang Dongsuk episode

# In[13]:


# Choices stay in Korean because they are used as retrieval queries against the
# Korean vector DB; English glosses are in the trailing comments. The original
# list embedded "1:", "2:" ... prefixes, which enumerate() then double-numbered
# at print time, so the prefixes are dropped here.
choices = [
    "황동석 무리를 모두 처치한다.",                # kill all of Hwang Dongsuk's group
    "진호를 포함한 황동석 무리를 모두 처치한다.",  # kill the whole group, including Jinho
    "전부 기절시키고 살려둔다.",                   # knock them all out and spare them
    "시스템을 거부하고 그냥 도망친다."             # refuse the System and just run away
]

print("\n[Choices]")
for idx, choice in enumerate(choices, start=1):
    print(f"{idx}. {choice}")

user_idx = int(input("\nEnter choice number: ")) - 1
user_choice = choices[user_idx]
print(f"\n[User choice]: {user_choice}")

result = qa_chain({"query": user_choice})

retrieved_context = "\n".join([doc.page_content for doc in result["source_documents"]])
print("\n[Sample of retrieved supporting documents]")
print(retrieved_context[:600], "...")

# Korean role-play prompt, roughly: "You are Sung Jinwoo from the webtoon
# 'Solo Leveling'. Current situation: ... User choice: ... Write 1-2 concise,
# natural lines of dialogue in Sung Jinwoo's voice; do not produce duplicated
# or near-identical sentences."
prompt = f"""
당신은 웹툰 '나 혼자만 레벨업'의 성진우입니다.
현재 상황:
{retrieved_context}

사용자 선택: {user_choice}

성진우의 말투로 간결하고 자연스러운 대사를 1~2문장 생성하세요.
중복된 내용이나 비슷한 문장은 만들지 마세요.
"""

response = generator(
    prompt,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    return_full_text=False
)[0]["generated_text"]
print("\n[Sung Jinwoo's response]")
print(response)
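
# Sketch (not in the original): the choice -> retrieval -> dialogue flow above,
# wrapped so other episodes can be replayed with a different choice list.
# play_scene is a hypothetical helper reusing qa_chain and generator from above.
def play_scene(choice_list):
    for i, c in enumerate(choice_list, 1):
        print(f"{i}. {c}")
    pick = choice_list[int(input("Enter choice number: ")) - 1]
    result = qa_chain({"query": pick})
    context = "\n".join(d.page_content for d in result["source_documents"])
    out = generator(
        f"당신은 웹툰 '나 혼자만 레벨업'의 성진우입니다.\n"
        f"현재 상황:\n{context}\n\n사용자 선택: {pick}\n\n"
        f"성진우의 말투로 간결하고 자연스러운 대사를 1~2문장 생성하세요.",
        max_new_tokens=200, do_sample=True, temperature=0.6, top_p=0.9,
        return_full_text=False,
    )[0]["generated_text"]
    print(out)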