# NOTE: "Spaces: / Sleeping / Sleeping" lines were Hugging Face Spaces page residue, not code.
import os | |
import gradio as gr | |
from langchain.agents import initialize_agent,AgentType | |
from langchain.chat_models import AzureChatOpenAI | |
from langchain.chains.conversation.memory import ConversationBufferWindowMemory | |
import torch | |
from transformers import BlipProcessor,BlipForConditionalGeneration | |
import requests | |
from PIL import Image | |
from langchain.tools import BaseTool | |
from langchain.chains import LLMChain | |
from langchain import PromptTemplate, FewShotPromptTemplate | |
# --- Azure OpenAI configuration, read from environment variables ---
OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
OPENAI_API_BASE=os.getenv("OPENAI_API_BASE")
# NOTE(review): env var name is lowercase "deployment_name" (unlike the others) — confirm it matches the hosting config.
DEP_NAME=os.getenv("deployment_name")
# Shared chat model: used by the agent below and by toChinese() for translation.
llm=AzureChatOpenAI(deployment_name=DEP_NAME,openai_api_base=OPENAI_API_BASE,openai_api_key=OPENAI_API_KEY,openai_api_version="2023-03-15-preview",model_name="gpt-3.5-turbo")
# --- BLIP image-captioning model (downloaded at import time; GPU when available) ---
image_to_text_model="Salesforce/blip-image-captioning-large"
device= 'cuda' if torch.cuda.is_available() else 'cpu'
processor=BlipProcessor.from_pretrained(image_to_text_model)
model=BlipForConditionalGeneration.from_pretrained(image_to_text_model).to(device)
def descImage(image_url):
    """Generate an English caption for a local image file using BLIP.

    Args:
        image_url: local filesystem path to the image (despite the name,
            this is a path, not an HTTP URL — it comes from gr.Image(type="filepath")).

    Returns:
        The decoded caption string produced by the module-level BLIP model.
    """
    # Image.open is lazy; the context manager guarantees the underlying
    # file handle is closed (the original leaked it).
    with Image.open(image_url) as img:
        image_obj = img.convert('RGB')
    inputs = processor(image_obj, return_tensors='pt').to(device)
    outputs = model.generate(**inputs)
    return processor.decode(outputs[0], skip_special_tokens=True)
def toChinese(en: str):
    """Translate the English text *en* into Chinese via the shared LLM."""
    translate_prompt = PromptTemplate(
        input_variables=["en"],
        template="翻译下面语句到中文\n{en}",
    )
    translation_chain = LLMChain(llm=llm, prompt=translate_prompt)
    return translation_chain.run(en)
class DescTool(BaseTool):
    """LangChain tool that captions an image given its local file path."""

    name = "Describe Image Tool"
    description = "use this tool to describe an image"

    def _run(self, url: str):
        # Delegate to the module-level BLIP captioning helper.
        return descImage(url)

    def _arun(self, query: str):
        # Async execution is not supported for this tool.
        raise NotImplementedError('未实现')
# The image-description tool is the only tool exposed to the agent.
tools=[DescTool()]
# Windowed conversation memory: keep only the last 5 exchanges.
memory=ConversationBufferWindowMemory(
    memory_key='chat_history',
    k=5,
    return_messages=True
)
# Conversational ReAct agent; capped at 3 reasoning iterations, and on early
# stop it generates a final answer instead of failing.
agent=initialize_agent(
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    tools=tools,
    llm=llm,
    verbose=False,
    max_iterations=3,
    early_stopping_method='generate',
    memory=memory
)
def reset_user_input():
    """Clear the user-input textbox after a message is submitted."""
    cleared = gr.update(value='')
    return cleared
def reset_state():
    """Reset both the chatbot display and the history state to empty lists."""
    empty_chat, empty_history = [], []
    return empty_chat, empty_history
def predict(file, input, chatbot, history):
    """Handle one chat turn: ask the agent about the uploaded image.

    Args:
        file: local path of the uploaded image (from gr.Image(type="filepath")).
        input: the user's question text.
        chatbot: list of (user, bot) pairs rendered by the Chatbot widget.
        history: Gradio State, passed through unchanged.

    Yields:
        The updated (chatbot, history) pair (generator so Gradio streams the update).
    """
    # Combine question and image path in one message so the agent can hand
    # the path to DescTool. (Original built this as f""+input+"\n"+file.)
    query = f"{input}\n{file}"
    result = agent(query)
    # The agent answers in English; translate for the Chinese-language UI.
    answer = toChinese(result['output'])
    # Original appended `input` then overwrote chatbot[-1]; appending the
    # finished pair directly has the same net effect.
    chatbot.append((input, answer))
    yield chatbot, history
# --- Gradio UI: image upload + chat window, with an input row below ---
# Custom CSS sizes the chat area relative to the viewport height.
with gr.Blocks(css=".chat-blocks{height:calc(100vh - 332px);} .mychat{flex:1} .mychat .block{min-height:100%} .mychat .block .wrap{max-height: calc(100vh - 330px);} .myinput{flex:initial !important;min-height:180px}") as demo:
    title = '图像识别'
    demo.title=title
    with gr.Column(elem_classes="chat-blocks"):
        with gr.Row(elem_classes="mychat"):
            # type="filepath" hands predict() a local path string that descImage can open.
            file = gr.Image(type="filepath")
            chatbot = gr.Chatbot(label="图像识别", show_label=False)
        with gr.Column(elem_classes="myinput"):
            # .style(...) is the legacy (pre-4.x) Gradio styling API — presumably
            # this app is pinned to an old gradio release; confirm before upgrading.
            user_input = gr.Textbox(show_label=False, placeholder="请输入...", lines=1).style(
                container=False)
            submitBtn = gr.Button("提交", variant="primary", elem_classes="btn1")
            emptyBtn = gr.Button("清除历史").style(container=False)
    history = gr.State([])
    # Submit triggers two handlers: predict streams chat updates, then the textbox is cleared.
    submitBtn.click(predict, [file,user_input, chatbot,history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])
    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
# Queue requests (API endpoint disabled) and serve the app.
demo.queue(api_open=False,concurrency_count=20).launch()