dgnk007 commited on
Commit
ebb732a
·
1 Parent(s): f3a6ce3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -31
app.py CHANGED
@@ -1,37 +1,78 @@
1
- import gradio as gr
2
  import pip
3
  import os
4
  pip.main(['install', 'transformers'])
5
  pip.main(['install', 'torch'])
6
  pip.main(['install', 'pymongo'])
7
- import pymongo
8
  from transformers import pipeline
9
- from datetime import datetime
10
- model_name_or_path='dgnk007/eagle'
11
- generate=pipeline('text-generation',model=model_name_or_path)
12
-
13
- myclient = pymongo.MongoClient(os.environ['DB_URI'])
14
- mydb = myclient["eagle"]
15
- def store_at_db(prompt,response):
16
- rawusage = mydb["rawusage"]
17
- rawusage.insert_one({"prompt":prompt,"response":response,"current_time":datetime.now()})
18
-
19
- def generate_response(message):
20
- prompt_template=f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n {message}\n\n### Response:\n"
21
- response=generate(prompt_template,max_length=1024,return_full_text=False,eos_token_id=21017,pad_token_id=50256)
22
- store_at_db(message,response[0]['generated_text'])
23
- return response[0]['generated_text']
24
-
25
- examples = [
26
- ["Give three tips for staying healthy."],
27
- ["How can we reduce air pollution?"],
28
- ["Explain what an API is."],
29
- ]
30
- demo = gr.Interface(
31
- fn=generate_response,
32
- inputs="text",
33
- outputs="text",
34
- examples=examples,
35
- #flagging_callback=gr.HuggingFaceDatasetSaver(hf_token=os.environ['DatasetKey'],dataset_name=os.environ['DatasetName'],private=True,verbose=True )
36
- )
37
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pip
2
  import os
3
  pip.main(['install', 'transformers'])
4
  pip.main(['install', 'torch'])
5
  pip.main(['install', 'pymongo'])
6
+ import gradio as gr
7
  from transformers import pipeline
8
+ import pymongo
9
+
10
+ mongo_client = pymongo.MongoClient(os.environ['DB_URI'])
11
+ db = mongo_client["gradio_db"]
12
+ btn_disable=gr.Button.update(interactive=False)
13
+ btn_enable=gr.Button.update(interactive=True)
14
+ generator = pipeline("text-generation", model="dgnk007/eagle2")
15
+
16
+ def store_in_mongodb(collection_name, data):
17
+ collection = db[collection_name]
18
+ return collection.insert_one(data)
19
+
20
+
21
+ def generate_text(message,sequences):
22
+ prompt_template=f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n {message}\n\n### Response:\n\n"
23
+ generated_text = generator(prompt_template, max_length=1024,return_full_text=False,eos_token_id=21017,pad_token_id=50256, num_return_sequences=sequences)
24
+ return generated_text
25
+
26
+ def general_function(input_text):
27
+ output_text = generate_text(input_text,1)[0]['generated_text']
28
+ store_in_mongodb("general_collection", {"input": input_text, "output": output_text})
29
+ return output_text
30
+
31
+ def arena_function(input_text):
32
+ output_text = generate_text(input_text,2)
33
+ data_to_store = {
34
+ "input": input_text,
35
+ "r1": output_text[0]['generated_text'],
36
+ "r2": output_text[1]['generated_text'],
37
+ }
38
+ id=store_in_mongodb("arena_collection", data_to_store)
39
+ return output_text[0]['generated_text'], output_text[1]['generated_text'], id.inserted_id,btn_enable,btn_enable,btn_enable,btn_enable
40
+
41
+ general_interface = gr.Interface(fn=general_function, inputs=gr.Textbox(label="Enter your text here:", min_width=600), outputs="text")
42
+
43
+ def reward_click(id,reward):
44
+ db["arena_collection"].update_one(
45
+ {"_id": id},
46
+ {"$set": {"reward": reward}}
47
+ )
48
+ return btn_disable,btn_disable,btn_disable,btn_disable
49
+
50
+ with gr.Blocks() as arena_interface:
51
+ obid=gr.State([])
52
+ with gr.Row():
53
+ with gr.Column():
54
+ input_box = gr.Textbox(label="Enter your text here:", min_width=600)
55
+ prompt = gr.Button("Submit", variant="primary")
56
+ with gr.Row():
57
+ gr.Examples(['what is google?','what is youtube?'], input_box,)
58
+ with gr.Row():
59
+ output_block = [
60
+ gr.Textbox(label="Response 1", interactive=False),
61
+ gr.Textbox(label="Response 2", interactive=False),
62
+ obid
63
+ ]
64
+ with gr.Row():
65
+ tie=gr.Button(value="Tie",size='sm',interactive=False)
66
+ r1=gr.Button(value="Response 1 Wins",variant='primary',interactive=False)
67
+ r2=gr.Button(value="Response 2 Wins",variant='primary',interactive=False)
68
+ bad=gr.Button(value="Both are Bad",variant='secondary',interactive=False)
69
+ buttonGroup=[tie,r1,r2,bad]
70
+ prompt.click(fn=arena_function, inputs=input_box, outputs=output_block+buttonGroup)
71
+ tie.click(fn=reward_click,inputs=[obid,gr.State('tie')],outputs=buttonGroup)
72
+ r1.click(fn=reward_click,inputs=[obid,gr.State('r1')],outputs=buttonGroup)
73
+ r2.click(fn=reward_click,inputs=[obid,gr.State('r2')],outputs=buttonGroup)
74
+ bad.click(fn=reward_click,inputs=[obid,gr.State('bad')],outputs=buttonGroup)
75
+ demo = gr.TabbedInterface([general_interface, arena_interface], ["General", "Arena"])
76
+
77
+
78
+ demo.launch()