Calvin committed on
Commit 8f08f12 · 1 Parent(s): 112fcc8
Files changed (1)
  1. app.py +81 -0
app.py ADDED
@@ -0,0 +1,81 @@
+ import json
+ import os
+ from pprint import pprint
+
+ import bitsandbytes as bnb
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ import transformers
+ from datasets import load_dataset
+ from huggingface_hub import notebook_login
+ from peft import (
+     LoraConfig,
+     PeftConfig,
+     PeftModel,
+     get_peft_model,
+     prepare_model_for_kbit_training,
+ )
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+ )
+ import gradio as gr
+
+
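+ # Quantization settings: load the base model in 4-bit NF4 with double
+ # quantization to fit in limited GPU memory; compute runs in bfloat16.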
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+
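+ # Fine-tuned adapter repo on the Hugging Face Hub; its PeftConfig records
+ # which base model the adapter was trained on.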
+ PEFT_MODEL = "cdy3870/Falcon-Fetch-Bot"
+
+ config = PeftConfig.from_pretrained(PEFT_MODEL)
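+ # Load the quantized base model, letting accelerate place layers automatically.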
+ model = AutoModelForCausalLM.from_pretrained(
+     config.base_model_name_or_path,
+     return_dict=True,
+     quantization_config=bnb_config,
+     device_map="auto",
+     trust_remote_code=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+ tokenizer.pad_token = tokenizer.eos_token
+
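+ # Attach the fine-tuned LoRA adapter weights on top of the base model.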
+ model = PeftModel.from_pretrained(model, PEFT_MODEL)
+
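+ # Generation settings shared by all requests: replies are capped at 200 new
+ # tokens and sampled with temperature/top-p of 0.7.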
+ generation_config = model.generation_config
+ generation_config.max_new_tokens = 200
+ generation_config.temperature = 0.7
+ generation_config.top_p = 0.7
+ generation_config.num_return_sequences = 1
+ generation_config.pad_token_id = tokenizer.eos_token_id
+ generation_config.eos_token_id = tokenizer.eos_token_id
+
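+ # Wrap model and tokenizer in a standard text-generation pipeline.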
+ pipeline = transformers.pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+ )
+
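+ # Gradio chat callback. Each message is answered independently: the history
+ # argument is required by ChatInterface but is not folded into the prompt.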
+ def query_model(message, history):
+     prompt = f"""
+ <human>: {message}
+ <assistant>:
+ """.strip()
+
+     result = pipeline(
+         prompt,
+         generation_config=generation_config,
+     )
+
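+     # NOTE: the pipeline output includes the prompt; uncommenting the line
+     # below would return only the text after the "<assistant>:" marker.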
+     # parsed_result = result[0]["generated_text"].split("<assistant>:")[1][1:]
+
+     return result[0]["generated_text"]
+
+
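+ # Build and launch the chat UI.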
+ gr.ChatInterface(
+     query_model,
+     textbox=gr.Textbox(placeholder="Ask anything about Fetch!", container=False, scale=7),
+ ).launch()