mtyrrell committed on
Commit 000787f · 1 Parent(s): 8539509

cleanup and harmonization

Files changed (8)
  1. .gitignore +1 -0
  2. Dockerfile +2 -1
  3. README.md +8 -19
  4. app/main.py +2 -88
  5. app/prompt.py +0 -7
  6. app/utils.py +178 -0
  7. params.cfg +35 -0
  8. requirements.txt +19 -5
.gitignore ADDED
@@ -0,0 +1 @@
+ .DS_Store
Dockerfile CHANGED
@@ -13,7 +13,8 @@ RUN pip install --no-cache-dir -r requirements.txt
 
  # -------- copy source --------
  COPY app ./app
- COPY model_params.cfg .
+ COPY params.cfg .
+ COPY .env* ./
 
  # Ports:
  # • 7860 → Gradio UI (HF Spaces standard)
README.md CHANGED
@@ -8,27 +8,16 @@ pinned: false
  license: mit
  ---
 
- # RAG Generation Service
-
- This is a Retrieval-Augmented Generation (RAG) service that answers questions based on provided context.
-
- ## How to use
-
- 1. Enter your question in the "Query" field
- 2. Paste relevant documents or context in the "Context" field
- 3. Click submit to get an AI-generated answer based on your context
-
- ## Features
-
- - Uses state-of-the-art language models via Hugging Face Inference API
- - Supports multiple model providers
- - Clean, intuitive interface
- - Example queries to get started
+ # Generation Module
+
+ This is an LLM-based generation service designed to be deployed as a modular component of a broader RAG system. The service runs in a Docker container and exposes a Gradio UI on port 7860 as well as an MCP endpoint.
 
  ## Configuration
 
- This Space requires a `HF_TOKEN` environment variable to be set with your Hugging Face access token.
-
- ## Model Support
-
- By default, this uses `meta-llama/Meta-Llama-3-8B-Instruct`, but you can configure different models via environment variables.
+ 1. The module requires an API key (set as an environment variable) for an inference provider to run. Multiple inference providers are supported. Make sure to set the appropriate environment variable:
+    - OpenAI: `OPENAI_API_KEY`
+    - Anthropic: `ANTHROPIC_API_KEY`
+    - Cohere: `COHERE_API_KEY`
+    - HuggingFace: `HF_TOKEN`
+
+ 2. Inference provider and model settings are accessible via `params.cfg`.
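
Once the Space is up, the Gradio endpoint can be exercised from Python. Below is a minimal sketch using `gradio_client`; the Space URL and the default `/predict` endpoint name are illustrative assumptions, not part of this commit.

```python
# Minimal smoke test against the running service (URL is a placeholder).
from gradio_client import Client

client = Client("https://<your-space>.hf.space")  # assumed deployment URL
answer = client.predict(
    "What does the service expose?",                        # query
    "The module exposes a Gradio UI and an MCP endpoint.",  # context
    api_name="/predict",  # default endpoint name for a single gr.Interface
)
print(answer)
```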
app/main.py CHANGED
@@ -1,81 +1,5 @@
- import os, asyncio, logging
  import gradio as gr
- from huggingface_hub import InferenceClient
- from .prompt import build_prompt
-
- # ---------------------------------------------------------------------
- # model / client initialisation
- # ---------------------------------------------------------------------
- HF_TOKEN = os.getenv("HF_TOKEN")
- MODEL_ID = os.getenv("MODEL_ID", "meta-llama/Meta-Llama-3-8B-Instruct")
- MAX_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
- TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
-
- if not HF_TOKEN:
-     raise RuntimeError(
-         "HF_TOKEN env-var missing. "
-     )
-
- client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
-
- # ---------------------------------------------------------------------
- # Core generation function for both Gradio UI and MCP
- # ---------------------------------------------------------------------
- async def _call_llm(prompt: str) -> str:
-     """
-     Try text_generation first (for models/providers that still support it);
-     fall back to chat_completion when the provider is chat-only (Novita, etc.).
-     """
-     try:
-         # hf-inference
-         return await asyncio.to_thread(
-             client.text_generation,
-             prompt,
-             max_new_tokens=MAX_TOKENS,
-             temperature=TEMPERATURE,
-         )
-     except ValueError as e:
-         if "Supported task: conversational" not in str(e):
-             raise  # genuine error → bubble up
-
-         # fallback for Novita
-         messages = [{"role": "user", "content": prompt}]
-         completion = await asyncio.to_thread(
-             client.chat_completion,
-             messages=messages,
-             model=MODEL_ID,
-             max_tokens=MAX_TOKENS,
-             temperature=TEMPERATURE,
-         )
-         return completion.choices[0].message.content.strip()
-
- async def rag_generate(query: str, context: str) -> str:
-     """
-     Generate an answer to a query using provided context through RAG.
-
-     This function takes a user query and relevant context, then uses a language model
-     to generate a comprehensive answer based on the provided information.
-
-     Args:
-         query (str): The user's question or query
-         context (str): The relevant context/documents to use for answering
-
-     Returns:
-         str: The generated answer based on the query and context
-     """
-     if not query.strip():
-         return "Error: Query cannot be empty"
-
-     if not context.strip():
-         return "Error: Context cannot be empty"
-
-     prompt = build_prompt(query, context)
-     try:
-         answer = await _call_llm(prompt)
-         return answer
-     except Exception as e:
-         logging.exception("Generation failed")
-         return f"Error: {str(e)}"
+ from .utils import rag_generate
 
  # ---------------------------------------------------------------------
  # Gradio Interface with MCP support
@@ -102,17 +26,7 @@ ui = gr.Interface(
          show_copy_button=True
      ),
      title="RAG Generation Service",
-     description="Ask questions and get answers based on your provided context. This service is also available as an MCP server for integration with AI applications.",
-     examples=[
-         [
-             "What is the main benefit mentioned?",
-             "Machine learning has revolutionized many industries. The main benefit is increased efficiency and accuracy in data processing."
-         ],
-         [
-             "Who is the CEO?",
-             "Company ABC was founded in 2020. The current CEO is Jane Smith, who has led the company to significant growth."
-         ]
-     ]
+     description="Ask questions based on provided context. Intended for use in RAG pipelines (i.e., context supplied by a semantic retriever service) as an MCP server.",
  )
 
  # Launch with MCP server enabled
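
For orientation, the trimmed `app/main.py` plausibly reduces to roughly the sketch below. Only the imports, the title, `show_copy_button`, and the new description string are visible in the hunks above; the textbox widgets and the `mcp_server=True` launch flag are assumptions inferred from the README and the closing comment, not shown in this diff.

```python
# Hedged reconstruction of the trimmed app/main.py; widget labels and launch
# flags are assumed, not taken from the hunks above.
import gradio as gr
from .utils import rag_generate

ui = gr.Interface(
    fn=rag_generate,  # async functions are supported by gr.Interface
    inputs=[
        gr.Textbox(label="Query", lines=2),      # assumed widget configuration
        gr.Textbox(label="Context", lines=10),   # assumed widget configuration
    ],
    outputs=gr.Textbox(
        label="Answer",
        show_copy_button=True,
    ),
    title="RAG Generation Service",
    description=(
        "Ask questions based on provided context. Intended for use in RAG pipelines "
        "(i.e., context supplied by a semantic retriever service) as an MCP server."
    ),
)

# Launch with MCP server enabled (flag assumed; requires the gradio[mcp] extra).
ui.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True)
```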
app/prompt.py DELETED
@@ -1,7 +0,0 @@
- def build_prompt(question: str, context: str) -> str:
-     return (
-         "You are an expert assistant. Answer the USER question using only the "
-         "CONTEXT provided. If the context is insufficient say 'I don't know.'.\n\n"
-         f"### CONTEXT\n{context}\n\n"
-         f"### USER QUESTION\n{question}\n\n### ASSISTANT ANSWER\n"
-     )
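
The deleted template is not lost: it reappears as the system/user message pair produced by `build_messages` in `app/utils.py` below. A short illustration of the equivalent call follows; note that importing `app.utils` reads `params.cfg` and checks the provider API key at import time.

```python
# The removed build_prompt() is superseded by build_messages() in app/utils.py.
# Importing app.utils loads params.cfg and the provider API key, so both must
# be in place for this snippet to run.
from app.utils import build_messages

msgs = build_messages(
    "What changed in this commit?",
    "The prompt builder moved from app/prompt.py into app/utils.py.",
)
# msgs[0] is a SystemMessage carrying the same "expert assistant" instruction
# that build_prompt used to inline; msgs[1] is a HumanMessage containing the
# "### CONTEXT" and "### USER QUESTION" sections.
print(msgs[1].content)
```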
app/utils.py ADDED
@@ -0,0 +1,178 @@
+ import os, asyncio, logging
+ import configparser
+ import logging
+ from dotenv import load_dotenv
+
+ # LangChain imports
+ from langchain_openai import ChatOpenAI
+ from langchain_anthropic import ChatAnthropic
+ from langchain_cohere import ChatCohere
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+ from langchain_core.messages import SystemMessage, HumanMessage
+
+ # Local .env file
+ load_dotenv()
+
+ def getconfig(configfile_path: str):
+     """
+     Read the config file
+     Params
+     ----------------
+     configfile_path: file path of .cfg file
+     """
+     config = configparser.ConfigParser()
+     try:
+         config.read_file(open(configfile_path))
+         return config
+     except:
+         logging.warning("config file not found")
+
+ # ---------------------------------------------------------------------
+ # Provider-agnostic authentication and configuration
+ # ---------------------------------------------------------------------
+ def get_auth_config(provider: str) -> dict:
+     """Get authentication configuration for different providers"""
+     auth_configs = {
+         "openai": {"api_key": os.getenv("OPENAI_API_KEY")},
+         "huggingface": {"api_key": os.getenv("HF_TOKEN")},
+         "anthropic": {"api_key": os.getenv("ANTHROPIC_API_KEY")},
+         "cohere": {"api_key": os.getenv("COHERE_API_KEY")},
+     }
+
+     if provider not in auth_configs:
+         raise ValueError(f"Unsupported provider: {provider}")
+
+     auth_config = auth_configs[provider]
+     api_key = auth_config.get("api_key")
+
+     if not api_key:
+         raise RuntimeError(f"Missing API key for provider '{provider}'. Please set the appropriate environment variable.")
+
+     return auth_config
+
+ # ---------------------------------------------------------------------
+ # Model / client initialization
+ # ---------------------------------------------------------------------
+ config = getconfig("params.cfg")
+
+ PROVIDER = config.get("generator", "PROVIDER")
+ MODEL = config.get("generator", "MODEL")
+ MAX_TOKENS = int(config.get("generator", "MAX_TOKENS"))
+ TEMPERATURE = float(config.get("generator", "TEMPERATURE"))
+
+ # Set up authentication for the selected provider
+ auth_config = get_auth_config(PROVIDER)
+
+ def get_chat_model():
+     """Initialize the appropriate LangChain chat model based on provider"""
+     common_params = {
+         "temperature": TEMPERATURE,
+         "max_tokens": MAX_TOKENS,
+     }
+
+     if PROVIDER == "openai":
+         return ChatOpenAI(
+             model=MODEL,
+             openai_api_key=auth_config["api_key"],
+             **common_params
+         )
+     elif PROVIDER == "anthropic":
+         return ChatAnthropic(
+             model=MODEL,
+             anthropic_api_key=auth_config["api_key"],
+             **common_params
+         )
+     elif PROVIDER == "cohere":
+         return ChatCohere(
+             model=MODEL,
+             cohere_api_key=auth_config["api_key"],
+             **common_params
+         )
+     elif PROVIDER == "huggingface":
+         # Initialize HuggingFaceEndpoint with explicit parameters
+         llm = HuggingFaceEndpoint(
+             repo_id=MODEL,
+             huggingfacehub_api_token=auth_config["api_key"],
+             task="text-generation",
+             temperature=TEMPERATURE,
+             max_new_tokens=MAX_TOKENS
+         )
+         return ChatHuggingFace(llm=llm)
+     else:
+         raise ValueError(f"Unsupported provider: {PROVIDER}")
+
+ # Initialize provider-agnostic chat model
+ chat_model = get_chat_model()
+
+ # ---------------------------------------------------------------------
+ # Core generation function for both Gradio UI and MCP
+ # ---------------------------------------------------------------------
+ async def _call_llm(messages: list) -> str:
+     """
+     Provider-agnostic LLM call using LangChain.
+
+     Args:
+         messages: List of LangChain message objects
+
+     Returns:
+         Generated response content as string
+     """
+     try:
+         # Use async invoke for better performance
+         response = await chat_model.ainvoke(messages)
+         return response.content.strip()
+     except Exception as e:
+         logging.exception(f"LLM generation failed with provider '{PROVIDER}' and model '{MODEL}': {e}")
+         raise
+
+ def build_messages(question: str, context: str) -> list:
+     """
+     Build messages in LangChain format.
+
+     Args:
+         question: The user's question
+         context: The relevant context for answering
+
+     Returns:
+         List of LangChain message objects
+     """
+     system_content = (
+         "You are an expert assistant. Answer the USER question using only the "
+         "CONTEXT provided. If the context is insufficient say 'I don't know.'"
+     )
+
+     user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
+
+     return [
+         SystemMessage(content=system_content),
+         HumanMessage(content=user_content)
+     ]
+
+
+ async def rag_generate(query: str, context: str) -> str:
+     """
+     Generate an answer to a query using provided context through RAG.
+
+     This function takes a user query and relevant context, then uses a language model
+     to generate a comprehensive answer based on the provided information.
+
+     Args:
+         query (str): The user's question or query
+         context (str): The relevant context/documents to use for answering
+
+     Returns:
+         str: The generated answer based on the query and context
+     """
+     if not query.strip():
+         return "Error: Query cannot be empty"
+
+     if not context.strip():
+         return "Error: Context cannot be empty"
+
+     try:
+         messages = build_messages(query, context)
+         answer = await _call_llm(messages)
+         return answer
+     except Exception as e:
+         logging.exception("Generation failed")
+         return f"Error: {str(e)}"
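
Taken together, the new module can be exercised outside Gradio with a short driver like the sketch below; it assumes `params.cfg` sits in the working directory and the matching provider key is exported, since `utils.py` loads both at import time.

```python
# Minimal standalone driver for app/utils.rag_generate (sketch; assumes
# params.cfg and the provider API key are already configured).
import asyncio

from app.utils import rag_generate

async def main() -> None:
    answer = await rag_generate(
        query="Which port does the service expose?",
        context="The service runs in a Docker container and exposes a Gradio UI on port 7860.",
    )
    print(answer)

if __name__ == "__main__":
    asyncio.run(main())
```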
params.cfg ADDED
@@ -0,0 +1,35 @@
+ [generator]
+ PROVIDER = huggingface
+ MODEL = meta-llama/Meta-Llama-3-8B-Instruct
+ MAX_TOKENS = 512
+ TEMPERATURE = 0.2
+
+ ## OpenAI
+ # [generator]
+ # PROVIDER = openai
+ # MODEL = gpt-4o
+ # MAX_TOKENS = 512
+ # TEMPERATURE = 0.2
+
+ ## Anthropic
+ # [generator]
+ # PROVIDER = anthropic
+ # MODEL = claude-3-haiku-20240307
+ # MAX_TOKENS = 512
+ # TEMPERATURE = 0.2
+
+ ## Cohere
+ # [generator]
+ # PROVIDER = cohere
+ # MODEL = command
+ # MAX_TOKENS = 512
+ # TEMPERATURE = 0.2
+
+
+ ## Environment Variables Required
+
+ # Make sure to set the appropriate environment variables:
+ # - OpenAI: `OPENAI_API_KEY`
+ # - Anthropic: `ANTHROPIC_API_KEY`
+ # - Cohere: `COHERE_API_KEY`
+ # - HuggingFace: `HF_TOKEN`
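
As a sanity check, the `[generator]` block parses with the same `configparser` calls `app/utils.py` uses. The snippet below is a hypothetical pre-flight helper (`check_params` is not part of this commit) that confirms the selected provider also has its API key exported before the service starts.

```python
# Hypothetical pre-flight check for params.cfg; the helper name and mapping
# are illustrative, mirroring the README's provider/env-var pairing.
import configparser
import os

ENV_VARS = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "cohere": "COHERE_API_KEY",
    "huggingface": "HF_TOKEN",
}

def check_params(path: str = "params.cfg") -> None:
    config = configparser.ConfigParser()
    with open(path) as f:
        config.read_file(f)
    provider = config.get("generator", "PROVIDER")
    model = config.get("generator", "MODEL")
    print(f"provider={provider}, model={model}")
    if not os.getenv(ENV_VARS[provider]):
        raise RuntimeError(f"Set {ENV_VARS[provider]} before starting the service")

if __name__ == "__main__":
    check_params()
```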
requirements.txt CHANGED
@@ -1,5 +1,19 @@
- fastapi
- gradio[mcp]>=4.26.0
- huggingface_hub>=0.32.6
- pydantic>=2
- uvicorn[standard]
+ # Core dependencies
+ gradio>=4.0.0
+ gradio[mcp]
+ python-dotenv>=1.0.0
+
+ # LangChain core
+ langchain-core>=0.1.0
+ langchain-community>=0.0.1
+
+ # Provider-specific LangChain packages
+ langchain-openai>=0.1.0
+ langchain-anthropic>=0.1.0
+ langchain-cohere>=0.1.0
+ langchain-together>=0.1.0
+ langchain-huggingface>=0.0.1
+
+ # Additional dependencies that might be needed
+ requests>=2.31.0
+ pydantic>=2.0.0