tkdehf2 committed on
Commit 8943a88 · verified · 1 Parent(s): 45372dd

Update app.py

Files changed (1)
  1. app.py +15 -55
app.py CHANGED
@@ -1,59 +1,19 @@
+import streamlit as st
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
+st.title("Automatic Diary Generator")
 
-def generate_diary(emotion, num_samples=1, max_length=100, temperature=0.7):
-    # Load the tokenizer and model that will generate a diary from the emotion
-    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-    model = GPT2LMHeadModel.from_pretrained("gpt2")
-
-    # Build a prefix sentence based on the emotion
-    if emotion == "happy":
-        prefix = "I feel good today. "
-    elif emotion == "sad":
-        prefix = "I feel sad. "
-    elif emotion == "angry":
-        prefix = "I feel anger welling up. "
-    else:
-        prefix = "I feel strange today. "
-
-    # Tokenize the prefix to create the input sequence
-    input_sequence = tokenizer.encode(prefix, return_tensors="pt")
-
-    # Generate text with the model
-    output = model.generate(
-        input_sequence,
-        max_length=max_length,
-        num_return_sequences=num_samples,
-        temperature=temperature,
-        pad_token_id=tokenizer.eos_token_id
-    )
-
-    # Return the generated diary entries
-    return [tokenizer.decode(output_sequence, skip_special_tokens=True) for output_sequence in output]
-
-def main():
-    print("Starting emotion input.")
-    while True:
-        # Get the emotion from the user
-        try:
-            emotion = input("Enter today's emotion (happy, sad, angry, etc.): ").strip()  # strip leading/trailing whitespace
-            if emotion.lower() in ['happy', 'sad', 'angry']:
-                break
-            else:
-                print("Invalid emotion. Please try again.")
-        except EOFError:
-            print("\nUser input has ended. Exiting the program.")
-            return
-
-    # When an empty string was entered
-    if not emotion:
-        print("Please enter an emotion.")
-
-    # Generate the diary
-    diary_entries = generate_diary(emotion)
+keywords = st.text_input("Enter 5 keywords (separated by commas)", "")
+keyword_list = [kw.strip() for kw in keywords.split(",")]
+
+if len(keyword_list) == 5 and st.button("Write diary"):
+    # Load the model and tokenizer
+    model = GPT2LMHeadModel.from_pretrained("gpt2")
+    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+    # Generate text conditioned on the keywords
+    input_ids = tokenizer.encode(" ".join(keyword_list), return_tensors="pt")
+    output = model.generate(input_ids, max_length=500, num_return_sequences=1, do_sample=True, top_k=50, top_p=0.95, num_beams=5)
+
     # Print the generated diary
-    print("Today's diary:")
-    for i, entry in enumerate(diary_entries, start=1):
-        print(f"{i}. {entry}")
-
-if __name__ == "__main__":
-    main()
+    diary = tokenizer.decode(output[0], skip_special_tokens=True)
+    st.write(diary)
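
The new generation path can be exercised outside Streamlit with the minimal sketch below. It is not part of the commit: the keywords are hypothetical, and pad_token_id is set explicitly here only to avoid the usual GPT-2 generation warning (the diff leaves it unset); the remaining parameters mirror the diff.

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Same setup as the updated app.py: join the keywords into a single prompt.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

keyword_list = ["beach", "coffee", "rain", "friends", "music"]  # hypothetical example keywords
input_ids = tokenizer.encode(" ".join(keyword_list), return_tensors="pt")

# do_sample=True combined with num_beams=5 runs beam-search multinomial sampling.
output = model.generate(
    input_ids,
    max_length=500,
    num_return_sequences=1,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    num_beams=5,
    pad_token_id=tokenizer.eos_token_id,  # not in the diff; silences a warning
)
print(tokenizer.decode(output[0], skip_special_tokens=True))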