Gradio Interface

Gradio 是一个用于构建机器学习模型交互式界面的 Python 库。它的主要目标是简化机器学习模型的部署和展示,使非技术用户能够轻松地与模型进行交互,而无需编写任何代码。Gradio 中的 Interface 类是 Gradio 库的核心组件之一,用于创建交互式机器学习模型界面。

pip install gradio==3.44.3

1. Hello World

接下来,我们使用上面两种方式分别构建一个包含文本输入文本框,输出文本框,以及提交按钮的页面。

import gradio as gr


def process_function(num1, operator, num2):
    ret = None
    if operator == '+':
        ret = num1 + num2
    if operator == '-':
        ret = num1 - num2
    if operator == '*':
        ret = num1 * num2
    if operator == '/':
        ret = num1 / num2
    return ret


def test():

    app = gr.Interface(fn=process_function,
                       inputs=['number',
                               gr.Radio(choices=['+', '-', '*', '/'], value='+'),
                               'number'],
                       outputs='text',
                       # 页面标题
                       title='Interface Demo',
                       # 页面控件上方的文字描述
                       description='1. 页面控件上方的文字描述',
                       # 页面控件下方的文字描述
                       article='2. 页面控件下方的文字描述')
    app.launch()


if __name__ == '__main__':
    test()

浏览器访问:http://127.0.0.1:7860/,将会显示如下界面

2. Example Inputs

Examples 是 Interface 中的一个参数,用于为模型提供输入示例,参数是单个示例输入或示例输入的列表。

import gradio as gr


def process_function(num1, num2):

    return num1 + num2


def test():

    app = gr.Interface(fn=process_function,
                       inputs=['number',
                               'number'],
                       outputs='number',
                       examples=[[10, 20], [30, 40], [50, 60]])
    app.launch()


if __name__ == '__main__':
    test()

点击 Example 会将值自动输入到模型输入框中。

3. Interface State

使用

3.1 Global State

记录所有用户的数据。

import gradio as gr

# 用于记录所有用户的输入记录
state_list = []

def process_function(user_input):
    # 记录输入记录
    state_list.append(user_input)
    return '输出是: ' + user_input, ','.join(state_list)

def test():

    user_input = gr.Textbox()
    user_output = gr.Textbox()
    user_states = gr.Textbox()

    app = gr.Interface(fn=process_function, inputs=user_input,
                       outputs=[user_output, user_states])
    app.launch()


if __name__ == '__main__':
    test()

多个客户端共享所有用户的历史输入记录。

3.2 Session State

只记录当前用户的数据。

import gradio as gr

def process_function(user_input, user_state):
    return '输出结果: ' + user_input


def keep_state(user_input, user_state):

    if user_state == '':
        user_states = user_input
    else:
        user_states = user_state + ',' + user_input

    return user_input, user_states


def test():

    with gr.Blocks() as demo:
        user_input = gr.Textbox(label='user_input')
        user_output = gr.Textbox(label='user_output')
        user_state = gr.Textbox(label='input history', visible=True)

        # 首先,将输入和输入历史发送到 keep_state 函数中,将历史拼接到输入历史中
        # 然后,再将输入和新的历史送入到 process_function 函数中进行处理
        user_input.submit(keep_state, [user_input, user_state], [user_input, user_state]).then(process_function, [user_input, user_state], user_output)

    demo.launch()

if __name__ == '__main__':
    test()

每个客户端独享历史输入记录。

4. Iterative Outputs

输出的结果并不是单一结果,而是一个流序列,可以进行如下步骤操作:

  1. 将 fn 处理函数修改为 yield 生成器
  2. 将 app 设置为 queue 模式
import gradio as gr
import time


def process_function(user_input):
    for number in range(int(user_input)):
        time.sleep(1)
        yield number


def test():

    user_input = gr.Number(label='user_input')
    user_output = gr.Number(label='user_output')

    app = gr.Interface(fn=process_function, inputs=user_input, outputs=user_output)
    # 如果处理函数为生成器,则需要调用 queue 函数
    app.queue()
    app.launch()


if __name__ == '__main__':
    test()

5. Progress Bars

使用进度条要注意两点:

  1. 在 fn 函数输入参数右侧增加一个 gr.Progress() 默认参数,不允许创建为局部变量;
  2. app 需要设置为 queue 模式。
import gradio as gr
import time

def process_function(number, progress=gr.Progress()):

    # 提示
    progress(0, desc="正在准备")
    time.sleep(1)
    result = 0.0

    # 迭代
    for num in progress.tqdm(range(int(number) + 1), desc="正在计算"):
        time.sleep(0.1)
        result += num

    return result

def test():
    demo = gr.Interface(process_function, inputs='number', outputs='number')
    demo.queue()
    demo.launch()

if __name__ == "__main__":
    test()

6. Flagging

Flagging 的目的是将用户看到的输入、输出结果提交服务器端。这在有些场景下比较有用,例如:用户可以反馈模型的预测结果。

import gradio as gr
import time

def process_function(number):
    return 'hello'

def test():
    # allow_flagging: never 禁用 auto 每次自动标记 manual 每次手动标记
    # flagging_dir: 默认 flagged,表示标记的记录存储路径
    demo = gr.Interface(process_function,
                        inputs='text',
                        outputs='text',
                        allow_flagging='manual',
                        flagging_dir='flagged', 
                        flagging_options=['错误', '违法'])
    demo.launch()

if __name__ == "__main__":
    test()

7. Alerts

import gradio as gr
import time

def process_function(text):

    # 不显示弹窗
    if text == 'e' or text == 'E':
        gr.Error('这是错误消息')

    if text == 'i' or  text == 'I':
        gr.Info('这是普通消息')

    if text == 'w' or text == 'W':
        gr.Warning('这是警告消息')

    return 'hello'


def test():
    demo = gr.Interface(process_function, inputs='text', outputs='text')
    demo.queue().launch()

if __name__ == "__main__":
    test()

未经允许不得转载:一亩三分地 » Gradio Interface
评论 (0)

7 + 1 =