Gradio 是一个用于构建机器学习模型交互式界面的 Python 库。它的主要目标是简化机器学习模型的部署和展示,使非技术用户能够轻松地与模型进行交互,而无需编写任何代码。Gradio 中的 Interface 类是 Gradio 库的核心组件之一,用于创建交互式机器学习模型界面。
pip install gradio==3.44.3
1. Hello World
接下来,我们使用上面两种方式分别构建一个包含文本输入文本框,输出文本框,以及提交按钮的页面。
import gradio as gr def process_function(num1, operator, num2): ret = None if operator == '+': ret = num1 + num2 if operator == '-': ret = num1 - num2 if operator == '*': ret = num1 * num2 if operator == '/': ret = num1 / num2 return ret def test(): app = gr.Interface(fn=process_function, inputs=['number', gr.Radio(choices=['+', '-', '*', '/'], value='+'), 'number'], outputs='text', # 页面标题 title='Interface Demo', # 页面控件上方的文字描述 description='1. 页面控件上方的文字描述', # 页面控件下方的文字描述 article='2. 页面控件下方的文字描述') app.launch() if __name__ == '__main__': test()
浏览器访问:http://127.0.0.1:7860/,将会显示如下界面
2. Example Inputs
Examples 是 Interface 中的一个参数,用于为模型提供输入示例,参数是单个示例输入或示例输入的列表。
import gradio as gr def process_function(num1, num2): return num1 + num2 def test(): app = gr.Interface(fn=process_function, inputs=['number', 'number'], outputs='number', examples=[[10, 20], [30, 40], [50, 60]]) app.launch() if __name__ == '__main__': test()
点击 Example 会将值自动输入到模型输入框中。
3. Interface State
使用
3.1 Global State
记录所有用户的数据。
import gradio as gr # 用于记录所有用户的输入记录 state_list = [] def process_function(user_input): # 记录输入记录 state_list.append(user_input) return '输出是: ' + user_input, ','.join(state_list) def test(): user_input = gr.Textbox() user_output = gr.Textbox() user_states = gr.Textbox() app = gr.Interface(fn=process_function, inputs=user_input, outputs=[user_output, user_states]) app.launch() if __name__ == '__main__': test()
多个客户端共享所有用户的历史输入记录。
3.2 Session State
只记录当前用户的数据。
import gradio as gr def process_function(user_input, user_state): return '输出结果: ' + user_input def keep_state(user_input, user_state): if user_state == '': user_states = user_input else: user_states = user_state + ',' + user_input return user_input, user_states def test(): with gr.Blocks() as demo: user_input = gr.Textbox(label='user_input') user_output = gr.Textbox(label='user_output') user_state = gr.Textbox(label='input history', visible=True) # 首先,将输入和输入历史发送到 keep_state 函数中,将历史拼接到输入历史中 # 然后,再将输入和新的历史送入到 process_function 函数中进行处理 user_input.submit(keep_state, [user_input, user_state], [user_input, user_state]).then(process_function, [user_input, user_state], user_output) demo.launch() if __name__ == '__main__': test()
每个客户端独享历史输入记录。
4. Iterative Outputs
输出的结果并不是单一结果,而是一个流序列,可以进行如下步骤操作:
- 将 fn 处理函数修改为 yield 生成器
- 将 app 设置为 queue 模式
import gradio as gr import time def process_function(user_input): for number in range(int(user_input)): time.sleep(1) yield number def test(): user_input = gr.Number(label='user_input') user_output = gr.Number(label='user_output') app = gr.Interface(fn=process_function, inputs=user_input, outputs=user_output) # 如果处理函数为生成器,则需要调用 queue 函数 app.queue() app.launch() if __name__ == '__main__': test()
5. Progress Bars
使用进度条要注意两点:
- 在 fn 函数输入参数右侧增加一个 gr.Progress() 默认参数,不允许创建为局部变量;
- app 需要设置为 queue 模式。
import gradio as gr import time def process_function(number, progress=gr.Progress()): # 提示 progress(0, desc="正在准备") time.sleep(1) result = 0.0 # 迭代 for num in progress.tqdm(range(int(number) + 1), desc="正在计算"): time.sleep(0.1) result += num return result def test(): demo = gr.Interface(process_function, inputs='number', outputs='number') demo.queue() demo.launch() if __name__ == "__main__": test()
6. Flagging
Flagging 的目的是将用户看到的输入、输出结果提交服务器端。这在有些场景下比较有用,例如:用户可以反馈模型的预测结果。
import gradio as gr import time def process_function(number): return 'hello' def test(): # allow_flagging: never 禁用 auto 每次自动标记 manual 每次手动标记 # flagging_dir: 默认 flagged,表示标记的记录存储路径 demo = gr.Interface(process_function, inputs='text', outputs='text', allow_flagging='manual', flagging_dir='flagged', flagging_options=['错误', '违法']) demo.launch() if __name__ == "__main__": test()
7. Alerts
import gradio as gr import time def process_function(text): # 不显示弹窗 if text == 'e' or text == 'E': gr.Error('这是错误消息') if text == 'i' or text == 'I': gr.Info('这是普通消息') if text == 'w' or text == 'W': gr.Warning('这是警告消息') return 'hello' def test(): demo = gr.Interface(process_function, inputs='text', outputs='text') demo.queue().launch() if __name__ == "__main__": test()