julylun commited on
Commit
5697c0f
·
1 Parent(s): 56daf7a

Create two input fields: query, document to rerank document with MonoT5

Browse files
Files changed (1) hide show
  1. app.py +35 -10
app.py CHANGED
@@ -1,17 +1,42 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
- # Tải hình MonoT5
5
- model_name = "castorini/monot5-large-msmarco"
6
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- def rerank(query, document):
10
- # Xử lý reranking giữa truy vấn và tài liệu
11
- inputs = tokenizer([query] * len(document), document, padding=True, truncation=True, return_tensors="pt")
12
- outputs = model.generate(inputs["input_ids"])
13
- decoded_output = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
14
- return decoded_output
 
 
 
 
 
15
 
16
- gr.Interface(fn=rerank, inputs=["text", "text"], outputs="text").launch()
 
 
17
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
4
+ # Load model and tokenizer
5
+ model_name = "castorini/monot5-small-msmarco-10k"
 
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
+
9
+ # Define reranking function
10
+ def rerank(query, documents):
11
+ documents = documents.split("\n") # Split documents by newlines
12
+ reranked_results = []
13
+
14
+ for doc in documents:
15
+ # Combine query and document into a single input
16
+ input_text = f"Query: {query} Document: {doc} Relevant:"
17
+ inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
18
+ outputs = model.generate(**inputs)
19
+ # Decode the output
20
+ relevance = tokenizer.decode(outputs[0], skip_special_tokens=True)
21
+ reranked_results.append((doc, relevance))
22
+
23
+ # Sort by relevance (assuming higher is better)
24
+ reranked_results.sort(key=lambda x: x[1], reverse=True)
25
+ return "\n".join([f"{doc} (Relevance: {rel})" for doc, rel in reranked_results])
26
 
27
+ # Create Gradio interface
28
+ interface = gr.Interface(
29
+ fn=rerank,
30
+ inputs=[
31
+ gr.Textbox(label="Query", placeholder="Enter your query"),
32
+ gr.Textbox(label="Documents (one per line)", lines=5, placeholder="Enter documents to rank"),
33
+ ],
34
+ outputs=gr.Textbox(label="Reranked Documents"),
35
+ title="MonoT5 Reranking",
36
+ description="Provide a query and a list of documents to rerank them using MonoT5."
37
+ )
38
 
39
+ # Launch the app
40
+ if __name__ == "__main__":
41
+ interface.launch()
42