|
import gradio as gr |
|
from huggingface_hub import InferenceApi |
|
from duckduckgo_search import DDGS |
|
import requests |
|
import json |
|
from typing import List |
|
from pydantic import BaseModel, Field |
|
import os |
|
|
|
|
|
# Hugging Face API token, read from the environment at import time.
# NOTE(review): if HUGGINGFACE_TOKEN is unset this is None and the
# Authorization header below becomes the literal string "Bearer None" —
# the API call will then fail with an auth error; confirm the env var is set.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
|
|
|
|
|
def duckduckgo_search(query, max_results=5):
    """Run a DuckDuckGo text search and return the raw result hits.

    Parameters
    ----------
    query : str
        Search terms.
    max_results : int, optional
        Maximum number of hits to request (default 5, the previously
        hard-coded limit, so existing callers are unaffected).

    Returns
    -------
    list[dict] | None
        Result dicts from ``DDGS.text`` (typically with 'title', 'href'
        and 'body' keys — presumably; verify against the installed
        duckduckgo_search version).
    """
    # DDGS is a context manager; `with` guarantees its session is closed.
    with DDGS() as ddgs:
        return ddgs.text(query, max_results=max_results)
|
|
|
class CitingSources(BaseModel):
    """Pydantic schema describing a list of cited source URLs.

    NOTE(review): this model is not referenced anywhere else in this
    file — possibly intended for structured-output / function-calling
    use; confirm before removing.
    """

    # URLs of the sources cited by the generated document.
    sources: List[str] = Field(

        ...,

        description="List of sources to cite. Should be an URL of the source."

    )
|
|
|
def get_response_with_search(query):
    """Answer *query* with a web-grounded document from Mistral-7B-Instruct.

    Searches DuckDuckGo for context, builds a Mistral ``[INST]`` prompt
    from the hits, and posts it to the Hugging Face Inference API.

    Parameters
    ----------
    query : str
        The user's request / question.

    Returns
    -------
    tuple[str, str]
        ``(main_content, sources)``.  On any failure the first element
        carries an error message and the second is ``""``.
    """
    search_results = duckduckgo_search(query)

    # Assemble the grounding context. Skip hits without a text body, and
    # use .get() for 'title'/'href' so a hit missing either key cannot
    # raise KeyError (only 'body' was guarded before).
    context = "\n".join(
        f"{hit.get('title', '')}\n{hit['body']}\nSource: {hit.get('href', '')}\n"
        for hit in (search_results or [])
        if 'body' in hit
    )

    prompt = f"""<s>[INST] Using the following context:
{context}
Write a detailed and complete research document that fulfills the following user request: '{query}'
After writing the document, please provide a list of sources used in your response. [/INST]"""

    API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
    headers = {"Authorization": f"Bearer {huggingface_token}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 1000,
            "temperature": 0.7,
            "top_p": 0.95,
            "top_k": 40,
            "repetition_penalty": 1.1
        }
    }

    try:
        # Bound the request so a stalled inference endpoint cannot hang
        # the Gradio callback indefinitely (no timeout = wait forever).
        response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
    except requests.RequestException as exc:
        # Connection errors / timeouts previously propagated and crashed
        # the chat callback; surface them as a chat message instead.
        return f"Error: request to inference API failed ({exc})", ""

    if response.status_code != 200:
        return f"Error: API returned status code {response.status_code}", ""

    result = response.json()
    # The text-generation endpoint returns a list of dicts with a
    # 'generated_text' key; anything else is reported verbatim.
    if not (isinstance(result, list) and len(result) > 0):
        return f"Unexpected response format: {result}", ""

    generated_text = result[0].get('generated_text', 'No text generated')
    return _split_document_and_sources(generated_text)


def _split_document_and_sources(generated_text):
    """Split raw model output into (document, sources).

    Strips the echoed prompt up to and including the '[/INST]' marker,
    then splits on the first 'Sources:' heading; the sources part is ""
    when the model emitted none.
    """
    marker = generated_text.find("[/INST]")
    if marker != -1:
        generated_text = generated_text[marker + len("[/INST]"):].strip()

    parts = generated_text.split("Sources:", 1)
    main_content = parts[0].strip()
    sources = parts[1].strip() if len(parts) > 1 else ""
    return main_content, sources
|
|
|
def chatbot_interface(message, history):
    """Gradio chat callback: return a web-grounded answer for *message*.

    The *history* parameter is part of the gr.ChatInterface callback
    signature and is not used here.
    """
    content, cited = get_response_with_search(message)
    return f"{content}\n\nSources:\n{cited}"
|
|
|
|
|
# Chat UI wired to the search-augmented responder above.
iface = gr.ChatInterface(

    fn=chatbot_interface,

    title="AI-powered Web Search Assistant",

    description="Ask questions, and I'll search the web and provide answers using the Mistral-7B-Instruct model.",

    # Clickable example prompts shown beneath the chat box.
    examples=[

        ["What are the latest developments in AI?"],

        ["Tell me about recent updates on GitHub"],

        ["What are the best hotels in Galapagos, Ecuador?"],

        ["Summarize recent advancements in Python programming"],

    ],

    # NOTE(review): retry_btn/undo_btn/clear_btn were removed from
    # ChatInterface in Gradio 5.x — this code requires a 4.x install;
    # confirm the pinned gradio version.
    retry_btn="Retry",

    undo_btn="Undo",

    clear_btn="Clear",

)
|
|
|
if __name__ == "__main__":

    # share=True asks Gradio to create a temporary public URL in
    # addition to the local server.
    iface.launch(share=True)
|
|