vladyur commited on
Commit
4f590f0
·
1 Parent(s): 51b3391

Create new file

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import torch
3
+ import tokenizers
4
+ import streamlit as st
5
+ import re
6
+
7
+ from PIL import Image
8
+
9
+
10
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
11
+ def get_model(model_name, model_path):
12
+ tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
13
+ model = transformers.OPTForCasualLM.from_pretrained('big-kek/NeuroSkeptic', device_map='cpu')
14
+
15
+ model.eval()
16
+ return model, tokenizer
17
+
18
+
19
+ def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, length_of_generated=300):
20
+ text += '\n'
21
+ input_ids = tokenizer.encode(text, return_tensors="pt")
22
+ length_of_prompt = len(input_ids[0])
23
+ with torch.no_grad():
24
+ out = model.generate(input_ids,
25
+ do_sample=True,
26
+ num_beams=n_beams,
27
+ temperature=temperature,
28
+ top_p=top_p,
29
+ max_length=length_of_prompt + length_of_generated,
30
+ eos_token_id=tokenizer.eos_token_id
31
+ )
32
+
33
+ return generated = list(map(tokenizer.decode, out))[0]
34
+
35
+
36
+ model, tokenizer = get_model('facebook/opt-13b')
37
+
38
+ # st.title("NeuroKorzh")
39
+
40
+ # image = Image.open('korzh.jpg')
41
+ # st.image(image, caption='НейроКорж')
42
+
43
+ # option = st.selectbox('Выберите своего Коржа', ('Быстрый', 'Глубокий'))
44
+ craziness = st.slider(label='Craziness', min_value=0, max_value=100, value=50, step=5)
45
+ temperature = 2 + craziness / 50.
46
+
47
+ st.markdown("\n")
48
+
49
+ text = st.text_area(label='What are you interested in?', value='Covid - a worldwide conspiracy?', height=80)
50
+ button = st.button('Go')
51
+
52
+ if button:
53
+ try:
54
+ with st.spinner('Finding out the truth'):
55
+ result = predict(text, model, tokenizer, temperature=temperature)
56
+
57
+ st.text_area(label='', value=result, height=1000)
58
+
59
+ except Exception:
60
+ st.error("Ooooops, something went wrong. Please try again and report to me, tg: @vladyur")