File size: 3,592 Bytes
1d91ffa
 
 
4d16da0
3fcfa56
a130567
3fcfa56
 
2e77d5f
 
1d91ffa
 
 
2e77d5f
 
 
 
 
 
 
 
 
 
 
fd9d58f
2e77d5f
 
 
 
 
3fcfa56
fd9d58f
2e77d5f
fd9d58f
2e77d5f
9ed9be5
a48a101
3fcfa56
1d91ffa
2e77d5f
 
 
 
 
9d0e814
 
 
 
 
 
 
 
 
 
2e77d5f
 
 
 
 
 
 
 
 
 
 
 
a48a101
2e77d5f
 
a48a101
2e77d5f
 
 
 
3fcfa56
a48a101
2e77d5f
1d91ffa
 
fd9d58f
1d91ffa
 
 
a48a101
2e77d5f
a48a101
 
 
 
 
 
3fcfa56
 
 
 
2e77d5f
 
1d91ffa
 
 
2e77d5f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Standard library
import logging
import os
import time

# Third-party
import gradio as gr
import openai
import psutil
import torch
from datasets import load_dataset
from langchain_community.embeddings import HuggingFaceEmbeddings

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# OpenAI credentials: read from the environment instead of hard-coding a
# secret in the source file. (The previously committed literal key is a
# security leak and must be revoked/rotated on the OpenAI dashboard.)
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
if not openai.api_key:
    logger.warning("OPENAI_API_KEY is not set; chat completion calls will fail.")

# Initialize with E5 embedding model, moved to GPU when one is available.
model_name = 'intfloat/e5-base-v2'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_model = HuggingFaceEmbeddings(model_name=model_name)
# HuggingFaceEmbeddings exposes the underlying sentence-transformers model
# as `.client`; relocate it to the chosen device.
embedding_model.client.to(device)

# Eagerly load the train split of each RagBench subset we can search over.
# Keyed by subset name so process_query can look them up by dropdown choice.
datasets = {}
dataset_names = ['covidqa', 'hotpotqa', 'pubmedqa']

for subset in dataset_names:
    loaded = load_dataset("rungalileo/ragbench", subset, split='train')
    datasets[subset] = loaded
    logger.info(f"Loaded {subset}")

def get_system_metrics():
    """Sample current host resource usage.

    Returns a dict with ``cpu_percent`` and ``memory_percent`` as reported
    by psutil at call time.
    """
    cpu_usage = psutil.cpu_percent()
    memory_usage = psutil.virtual_memory().percent
    return {'cpu_percent': cpu_usage, 'memory_percent': memory_usage}

def _find_relevant_contexts(query, dataset_choice):
    """Naive keyword search over the loaded RagBench documents.

    Returns a list of ``(doc_text, dataset_name)`` pairs whose text contains
    any whitespace-separated token of *query*, case-insensitively. An empty
    or whitespace-only query yields no matches.
    """
    keywords = [kw.lower() for kw in query.split()]
    if not keywords:
        return []

    matches = []
    search_names = [dataset_choice] if dataset_choice != "all" else datasets.keys()
    for dataset_name in search_names:
        if dataset_name not in datasets:
            continue
        for doc in datasets[dataset_name]['documents']:
            # Handle both string and list document types.
            doc_text = ' '.join(doc) if isinstance(doc, list) else str(doc)
            # Lowercase once per document (the original recomputed it for
            # every keyword).
            lowered = doc_text.lower()
            if any(kw in lowered for kw in keywords):
                matches.append((doc_text, dataset_name))
    return matches


def process_query(query, dataset_choice="all"):
    """Answer *query* with GPT-3.5, grounded on the first matching document.

    Parameters
    ----------
    query : str
        The user's question; split on whitespace into search keywords.
    dataset_choice : str
        A RagBench subset name, or ``"all"`` to search every loaded subset.

    Returns
    -------
    tuple[str, str]
        ``(answer_text, metrics_display)``; on failure the first element is
        the exception message.
    """
    start_time = time.time()
    try:
        relevant_contexts = _find_relevant_contexts(query, dataset_choice)

        # Only the first hit is fed to the model; the rest are discarded.
        if relevant_contexts:
            context_info = f"From {relevant_contexts[0][1]}: {relevant_contexts[0][0]}"
        else:
            context_info = "Searching across datasets..."

        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a knowledgeable expert using E5 embeddings for precise information retrieval."},
                {"role": "user", "content": f"Context: {context_info}\nQuestion: {query}"}
            ],
            max_tokens=300,
            temperature=0.7,
        )

        metrics = get_system_metrics()
        metrics['processing_time'] = time.time() - start_time

        metrics_display = f"""
        Processing Time: {metrics['processing_time']:.2f}s
        CPU Usage: {metrics['cpu_percent']}%
        Memory Usage: {metrics['memory_percent']}%
        """

        return response.choices[0].message.content.strip(), metrics_display

    except Exception as e:
        # Log the full traceback server-side; surface only the message to
        # the UI so the interface stays responsive on failure.
        logger.exception("process_query failed")
        return str(e), "Performance metrics available on next query"

# Gradio UI: a question box and dataset selector in, the model's answer and
# a performance-metrics readout out.
question_input = gr.Textbox(label="Question", placeholder="Ask your question here")
dataset_selector = gr.Dropdown(
    choices=["all"] + dataset_names,
    label="Select Dataset",
    value="all",
)
answer_output = gr.Textbox(label="Response")
metrics_output = gr.Textbox(label="Performance Metrics")

demo = gr.Interface(
    fn=process_query,
    inputs=[question_input, dataset_selector],
    outputs=[answer_output, metrics_output],
    title="E5-Powered Knowledge Base",
    description="Search across RagBench datasets with performance monitoring",
)

if __name__ == "__main__":
    # Enable request queuing so concurrent submissions are handled in order,
    # then start the Gradio server; debug=True surfaces tracebacks.
    demo.queue()
    demo.launch(debug=True)